# 以图搜图

本 Notebook 将引导您完成设置和使用基于 Amazon OpenSearch 的图像搜索系统的全过程。

## 1. 环境准备

首先，我们需要创建一个 Python 环境并安装所需的依赖项。

In [None]:
# 安装依赖项
!pip install boto3 opensearch-py requests_aws4auth numpy tqdm opencv-python pillow

## 2. 设置环境变量

接下来，我们需要设置一些环境变量，这些变量将在后续步骤中使用。

**【注意】**：需要预先创建名为 'ImgSearch' 的 IAM role，建议为其授予 AmazonSageMakerFullAccess、AmazonBedrockFullAccess 以及 aoss:* 权限。另外，还需要在 Bedrock 控制台的 Model access 页面中开启各类 embedding 模型的访问权限。

In [None]:
import os

# Shared configuration constants read by the later cells of this notebook.
OPENSEARCH_INDEX_NAME = 'image-index'
OPENSEARCH_COLLECTION_NAME = 'image-search-collection'
REGION = 'us-west-2'

# S3 bucket the source images are uploaded to; indexed `s3_uri` values point into it.
# NOTE(review): the name embeds a specific AWS account id - replace with your own bucket.
BUCKET_NAME = 'sagemaker-us-west-2-687752207838'

# Embedding size: used both as the `dimension` of the knn_vector field and as the
# `outputEmbeddingLength` requested from the embedding model, so the two must match.
EMBEDDING_LENGTH = 256
EMBEDDING_MODEL_ID = 'amazon.titan-embed-image-v1'

# Name of the IAM role that is made the principal of the data access policy.
ROLE_NAME = 'ImgSearch'

# Names of the OpenSearch Serverless (aoss) policies created below.
encryption_policy_name = 'image-search-encryption-policy'
network_policy_name = 'image-search-network-policy'
access_policy_name = 'image-search-access-policy'

## 3. 创建 OpenSearch Serverless Collection

现在，我们将创建一个 OpenSearch Serverless Collection，用于存储图像嵌入向量。

In [None]:
import boto3
from opensearchpy import OpenSearch, RequestsHttpConnection, AWSV4SignerAuth
from requests_aws4auth import AWS4Auth
import base64
import json
import time
import os
import logging

# Attach a stream handler at DEBUG level to the 'boto3.resources' logger.
boto3.set_stream_logger('boto3.resources', logging.DEBUG)

# AWS settings shared by the OpenSearch Serverless calls below.
service = 'aoss'
region = REGION  # e.g. 'us-west-2'
credentials = boto3.Session().get_credentials()

# Request signer handed to the OpenSearch client as its http_auth.
awsauth = AWSV4SignerAuth(credentials, region, service)

# Control-plane client used to manage policies and collections.
aoss_client = boto3.client("opensearchserverless", region_name=region)

# ARN of ROLE_NAME inside the account of the current caller identity.
role_arn = "arn:aws:iam::{}:role/{}".format(
    boto3.client('sts').get_caller_identity()['Account'],
    ROLE_NAME,
)

In [None]:
role_arn

In [None]:
# Create the encryption policy for the collection (AWS-owned key).
# A ConflictException means a policy with this name already exists; it is left as is.
try:
    security_policy = aoss_client.create_security_policy(
        name=encryption_policy_name,
        type='encryption',
        policy=json.dumps({
            'Rules': [
                {
                    'Resource': [f'collection/{OPENSEARCH_COLLECTION_NAME}'],
                    'ResourceType': 'collection',
                },
            ],
            'AWSOwnedKey': True,
        }),
    )
except aoss_client.exceptions.ConflictException:
    print(f"加密策略 {encryption_policy_name} 已存在")
else:
    print(f"创建加密策略: {encryption_policy_name}")

In [None]:
# Create the network policy for the collection; it allows access from public networks.
# A ConflictException means a policy with this name already exists; it is left as is.
try:
    network_policy = aoss_client.create_security_policy(
        name=network_policy_name,
        type='network',
        policy=json.dumps([
            {
                'Rules': [
                    {
                        'Resource': [f'collection/{OPENSEARCH_COLLECTION_NAME}'],
                        'ResourceType': 'collection',
                    },
                ],
                'AllowFromPublic': True,
            },
        ]),
    )
except aoss_client.exceptions.ConflictException:
    print(f"网络策略 {network_policy_name} 已存在")
else:
    print(f"创建网络策略: {network_policy_name}")

In [None]:
# Create the data access policy that lets `role_arn` manage the collection and
# read/write its indexes. A ConflictException means a policy with this name
# already exists; the existing policy is left untouched.
try:
    access_policy = aoss_client.create_access_policy(
        name = access_policy_name,
        policy = json.dumps(
        [
            {
                'Rules': [
                    {
                        'Resource': ['collection/' + OPENSEARCH_COLLECTION_NAME],
                        'Permission': [
                            'aoss:CreateCollectionItems',
                            'aoss:DeleteCollectionItems',
                            'aoss:UpdateCollectionItems',
                            'aoss:DescribeCollectionItems',
                        ],
                        'ResourceType': 'collection'
                    },
                    {
                        # Scope index permissions to this collection only; the
                        # previous 'index/*/*' pattern covered the indexes of
                        # every collection in the account.
                        'Resource': ['index/' + OPENSEARCH_COLLECTION_NAME + '/*'],
                        'Permission': [
                            'aoss:CreateIndex',
                            'aoss:DeleteIndex',
                            'aoss:UpdateIndex',
                            'aoss:DescribeIndex',
                            'aoss:ReadDocument',
                            'aoss:WriteDocument',
                        ],
                        'ResourceType': 'index'
                    }
                ],
                'Principal': [role_arn],
                'Description': 'Complete data access policy'
            }
        ]),
        type = 'data'
    )

    print(f"创建访问策略: {access_policy_name}")
except aoss_client.exceptions.ConflictException:
    print(f"访问策略 {access_policy_name} 已存在")

In [None]:
# Fixed 10-second pause intended to let the newly created policies take effect
# before the notebook continues.
print("等待策略生效...")
time.sleep(10)
print("继续执行...")

In [None]:
# Create the vector-search collection. A ConflictException means a collection
# with this name already exists; it is reused as is.
collection_name = OPENSEARCH_COLLECTION_NAME
try:
    response = aoss_client.create_collection(name=collection_name, type='VECTORSEARCH')
except aoss_client.exceptions.ConflictException:
    print(f"集合 {collection_name} 已存在")
else:
    print(f"集合已创建: {response['createCollectionDetail']['name']}")

In [None]:
# Poll every 10 seconds until the collection reports ACTIVE.
# A FAILED status, or a collection that cannot be found, stops the notebook
# with an error instead of falling through to the "activated" message.
print("等待集合变为活动状态...")
while True:
    summaries = aoss_client.list_collections(
        collectionFilters={'name': OPENSEARCH_COLLECTION_NAME}
    )['collectionSummaries']
    if not summaries:
        raise RuntimeError(f"未找到集合 {OPENSEARCH_COLLECTION_NAME}")

    status = summaries[0]['status']
    print(f"当前状态: {status}")
    if status == 'ACTIVE':
        break
    if status == 'FAILED':
        raise RuntimeError(f"集合 {OPENSEARCH_COLLECTION_NAME} 创建失败")
    time.sleep(10)

print(f"集合 {collection_name} 已激活")

In [None]:
# Look up the collection summary; its id is the first label of the endpoint host name.
collection = aoss_client.list_collections(
    collectionFilters={'name': OPENSEARCH_COLLECTION_NAME}
)['collectionSummaries'][0]

collection_id = collection['id']
collection_arn = collection['arn']

host = f"{collection_id}.{region}.aoss.amazonaws.com"
print(f"OpenSearch 端点: {host}")

# Data-plane client: HTTPS on port 443, requests signed through `awsauth`.
os_client = OpenSearch(
    hosts=[{'host': host, 'port': 443}],
    http_auth=awsauth,
    use_ssl=True,
    verify_certs=True,
    connection_class=RequestsHttpConnection,
)

In [None]:
collection_arn

In [None]:
# Create the k-NN index if it does not exist yet.

# Index definition: one knn_vector field for the image embedding plus two
# keyword fields describing where the source image lives.
index_body = {
    "settings": {
        "index.knn": True
    },
    "mappings": {
        "properties": {
            "pic_emb": {
                "type": "knn_vector",
                # Must equal the outputEmbeddingLength requested from the embedding model.
                "dimension": EMBEDDING_LENGTH,
                # NOTE(review): the k-NN plugin documents the distance function as
                # `space_type` (e.g. "cosinesimil"), normally inside `method`;
                # confirm that this "similarity" key is honoured and not silently
                # ignored, otherwise the index may fall back to its default metric.
                "similarity": "cosine",
                "method": {
                    "name": "hnsw",
                    "engine": "faiss"
                }
            },
            "s3_uri": {
                "type": "keyword"
            },
            "pic_name":  { 
                "type" : "keyword" 
            }
        }
    }
}

# Creation errors are printed rather than raised; the message hints that the
# current role may lack permission to operate on OpenSearch Serverless.
if not os_client.indices.exists(index=OPENSEARCH_INDEX_NAME):
    try:
        os_client.indices.create(index=OPENSEARCH_INDEX_NAME, body=index_body)
        print(f"索引 {OPENSEARCH_INDEX_NAME} 已创建")
    except Exception as e:
        print(f"异常：{e}， 请注意当前role是否能操作OpenSearch Serverless")
else:
    print(f"索引 {OPENSEARCH_INDEX_NAME} 已存在")

## 4. 导入数据到 OpenSearch

现在，我们将导入图像数据到 OpenSearch。首先，我们需要定义一些辅助函数。

In [None]:
# 导入所需的库
import boto3
from opensearchpy import OpenSearch, RequestsHttpConnection, helpers
from requests_aws4auth import AWS4Auth
import cv2
import os
import base64
from PIL import Image
from datetime import datetime
from tqdm import tqdm
from io import BytesIO
from PIL import Image
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
import numpy as np
import time
from botocore.exceptions import ClientError

# Upper bounds, in pixels, applied to an image before it is sent for embedding.
MAX_IMAGE_HEIGHT: int = 2048
MAX_IMAGE_WIDTH: int = 2048

# Bedrock runtime client. Pinned to REGION like the OpenSearch Serverless client,
# so it does not depend on whatever default region the session happens to have.
bedrock = boto3.client('bedrock-runtime', region_name=REGION)

In [None]:
# 定义获取嵌入向量的函数
def getEmbeddings(inputImageB64, max_retries=10, initial_delay=2, text=None, output_embedding_length=1024):
    """Request an embedding from the Bedrock model ``EMBEDDING_MODEL_ID``.

    Transient failures (throttling, timeouts, 5xx responses) are retried with
    exponential backoff. Other client errors, such as access or validation
    failures, are raised immediately instead of being retried for minutes.

    Args:
        inputImageB64: Base64-encoded image as ``str``, or ``None`` for a
            text-only request.
        max_retries: Maximum number of ``invoke_model`` attempts.
        initial_delay: Seconds to wait after the first failed attempt; the
            wait doubles after every further failure.
        text: Optional text to embed together with (or instead of) the image.
        output_embedding_length: Forwarded as
            ``embeddingConfig.outputEmbeddingLength``.

    Returns:
        A float32 ``numpy.ndarray`` whose single row is the ``embedding``
        value of the model response.

    Raises:
        ValueError: If neither an image nor a text is supplied.
        botocore.exceptions.ClientError: On a non-retryable error, or when the
            last allowed attempt fails.
        Exception: If ``max_retries`` is less than 1, so no attempt is made.
    """
    if inputImageB64 is None and text is None:
        raise ValueError("getEmbeddings requires inputImageB64, text, or both")

    # Error codes that indicate a temporary condition worth retrying.
    retryable_codes = {
        "ThrottlingException",
        "ModelTimeoutException",
        "ModelNotReadyException",
        "ServiceUnavailableException",
        "InternalServerException",
    }

    def is_retryable(error):
        details = getattr(error, "response", None) or {}
        code = details.get("Error", {}).get("Code")
        status = details.get("ResponseMetadata", {}).get("HTTPStatusCode")
        if code in retryable_codes:
            return True
        return isinstance(status, int) and (status in (408, 429) or status >= 500)

    # Build the payload once; leave absent inputs out instead of sending JSON null.
    request_body = {
        "embeddingConfig": {
            "outputEmbeddingLength": output_embedding_length
        }
    }
    if text is not None:
        request_body["inputText"] = text
    if inputImageB64 is not None:
        request_body["inputImage"] = inputImageB64
    body = json.dumps(request_body)

    for attempt in range(max_retries):
        try:
            response = bedrock.invoke_model(
                body=body,
                modelId=EMBEDDING_MODEL_ID,
                accept="application/json",
                contentType="application/json")
        except ClientError as e:
            # Give up on the last attempt or when retrying cannot help.
            if attempt == max_retries - 1 or not is_retryable(e):
                raise

            delay = initial_delay * (2 ** attempt)
            print(f"{e}")
            print(f"请求失败。{delay} 秒后重试...")
            time.sleep(delay)
        else:
            response_body = json.loads(response.get("body").read())
            return np.array([response_body.get("embedding")]).astype(np.float32)

    # Only reachable when max_retries < 1.
    raise Exception("达到最大重试次数。无法获取嵌入向量。")

In [None]:
# 定义图像处理函数
def resizeandGetByteData(imageFile):
    """Load an image, shrink it to fit the size limits, and return it PNG-encoded.

    The image is downscaled only when its width exceeds ``MAX_IMAGE_WIDTH`` or
    its height exceeds ``MAX_IMAGE_HEIGHT``. The aspect ratio is preserved, so
    the picture is not stretched into a square.

    Args:
        imageFile: Path or binary file object accepted by ``PIL.Image.open``.

    Returns:
        bytes: PNG encoding of the (possibly downscaled) image.
    """
    with Image.open(imageFile) as image:
        width, height = image.size
        if width > MAX_IMAGE_WIDTH or height > MAX_IMAGE_HEIGHT:
            # thumbnail() resizes in place to fit the (width, height) box,
            # keeps the aspect ratio and never enlarges the image.
            image.thumbnail((MAX_IMAGE_WIDTH, MAX_IMAGE_HEIGHT))
        with BytesIO() as output:
            image.save(output, 'png')
            return output.getvalue()

def embed_img(img_patch_pair, preprocessing_img_dir, text=None, output_embedding_length=1024):
    """Embed one image patch, reporting failures in the return value.

    Args:
        img_patch_pair: Sequence whose first item is the source image name or
            path and whose second item is the patch file name.
        preprocessing_img_dir: Directory containing the patch file.
        text: Optional text passed through to ``getEmbeddings``.
        output_embedding_length: Embedding size passed through to
            ``getEmbeddings``.

    Returns:
        tuple: ``(image_name, patch_name, embedding, None)`` on success, or
        ``(image_name, patch_name, None, error_text)`` on failure. The names
        are ``None`` when ``img_patch_pair`` itself could not be unpacked.
    """
    # Bind the names up front so the except branch can always build its tuple.
    image_name = patch_name = None
    try:
        image_name, patch_name = img_patch_pair[0], img_patch_pair[1]

        patch_path = os.path.join(preprocessing_img_dir, patch_name)
        with open(patch_path, 'rb') as f:
            bytes_data = resizeandGetByteData(f)

        # The file is closed before the (possibly slow, retried) remote call.
        input_image_base64 = base64.b64encode(bytes_data).decode('utf-8')
        embedding = getEmbeddings(input_image_base64, text=text, output_embedding_length=output_embedding_length)[0]
        return (image_name, patch_name, embedding, None)
    except Exception as e:
        # Some exceptions have an empty str(); fall back to repr() so the error
        # slot is never falsy for a failed patch.
        return (image_name, patch_name, None, str(e) or repr(e))

def process_embeddings_unordered(img_patch_pairs, preprocessing_img_dir, output_embedding_length=1024):
    """Embed all patches concurrently and build one index document per patch.

    Patches are embedded on a thread pool. Documents are appended as workers
    finish, so the result order is completion order and not input order.
    Failed patches are printed, counted and left out of the result.

    Args:
        img_patch_pairs: Iterable of ``(image_path, patch_name)`` pairs.
        preprocessing_img_dir: Directory containing the patch files.
        output_embedding_length: Embedding size passed to ``embed_img``.

    Returns:
        list[dict]: Documents with the keys ``s3_uri``, ``pic_emb`` and
        ``pic_name``.
    """
    failed_entries = []
    documents = []

    # Workers mostly wait on network I/O, so use more threads than cores.
    max_workers = min(32, (os.cpu_count() or 1) * 4)

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = {
            executor.submit(embed_img, pair, preprocessing_img_dir, None, output_embedding_length): pair
            for pair in img_patch_pairs
        }

        for future in tqdm(as_completed(futures), total=len(futures), desc="处理中"):
            pair = futures[future]
            try:
                image_id, patch_name, embedding, error = future.result()
            except Exception as exc:
                # Workers normally report failures in their return value; this
                # guard keeps one raising worker from aborting the whole batch.
                image_id, patch_name, embedding = None, None, None
                error = str(exc) or repr(exc)

            if error or embedding is None:
                print(f'错误: {error}')
                failed_entries.append((pair, error))
                continue

            # The index stores only the file name of the source image.
            image_name = os.path.basename(image_id)
            documents.append({
                "s3_uri": f"s3://{BUCKET_NAME}/{image_name}",
                "pic_emb": embedding,
                "pic_name": image_name
            })

    if failed_entries:
        print(f"共有 {len(failed_entries)} 个图像块处理失败")

    return documents

In [None]:
def split_connected_components(image_path, output_dir="object_patches", min_size=(20, 20)):
    """Split the non-transparent regions of an image into separate PNG patches.

    Each connected non-transparent region is cropped to its bounding box and
    saved in ``output_dir`` as ``<image stem>_<index>.png``. Images without an
    alpha channel (e.g. JPEG) are treated as fully opaque, so they yield one
    patch covering the whole image instead of failing.

    Args:
        image_path: Path of the image to split.
        output_dir: Directory the patches are written to; created if missing.
        min_size: ``(width, height)`` threshold. A region is skipped only when
            it is smaller than the threshold in both dimensions.

    Returns:
        list[tuple[str, str]]: ``(image_path, patch_file_name)`` for every
        patch that was saved.

    Raises:
        FileNotFoundError: If the image cannot be loaded.
        ValueError: If the image has an unsupported number of channels.
    """
    img_patch_pairs = []

    img_name_without_ext = Path(image_path).stem

    os.makedirs(output_dir, exist_ok=True)

    # IMREAD_UNCHANGED keeps the alpha channel when the file has one.
    img = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
    if img is None:
        raise FileNotFoundError(f"无法加载图像: {image_path}")

    # Normalise to 4-channel BGRA so the alpha-based logic below always applies.
    if img.ndim == 2:
        img_bgra = cv2.cvtColor(img, cv2.COLOR_GRAY2BGRA)
    elif img.shape[2] == 3:
        img_bgra = cv2.cvtColor(img, cv2.COLOR_BGR2BGRA)
    elif img.shape[2] == 4:
        img_bgra = img
    else:
        raise ValueError(f"不支持的通道数 {img.shape[2]}: {image_path}")

    # Binarise alpha: values above 1 count as non-transparent (255).
    alpha_channel = img_bgra[:, :, 3]
    _, binary_mask = cv2.threshold(alpha_channel, 1, 255, cv2.THRESH_BINARY)

    num_labels, _labels, stats, _centroids = cv2.connectedComponentsWithStats(binary_mask)

    # Label 0 is the background; patch indexes start at 0 for label 1.
    for idx, label in enumerate(range(1, num_labels)):
        x = stats[label][cv2.CC_STAT_LEFT]
        y = stats[label][cv2.CC_STAT_TOP]
        w = stats[label][cv2.CC_STAT_WIDTH]
        h = stats[label][cv2.CC_STAT_HEIGHT]

        # Skipped only when BOTH dimensions are below the threshold.
        if w < min_size[0] and h < min_size[1]:
            continue

        # Crop the bounding box, transparent pixels included.
        patch_bgra = img_bgra[y:y+h, x:x+w]

        # Pillow expects RGBA channel order.
        patch_rgba = cv2.cvtColor(patch_bgra, cv2.COLOR_BGRA2RGBA)
        patch_pil = Image.fromarray(patch_rgba)

        patch_img_name = f'{img_name_without_ext}_{idx}.png'
        img_patch_pairs.append((image_path, patch_img_name))
        patch_pil.save(os.path.join(output_dir, patch_img_name))

    return img_patch_pairs

In [None]:
# 定义导入数据的函数
def import_data_to_opensearch(image_dir, max_images=20):
    """Split, embed, upload and index the images found in a local directory.

    Steps: list the image files, cut each one into connected-component
    patches, embed every patch, upload the original images to ``BUCKET_NAME``
    and bulk-index one document per patch into ``OPENSEARCH_INDEX_NAME``.

    Args:
        image_dir: Directory containing the images to import.
        max_images: Maximum number of images to process, taken in sorted file
            name order. ``None`` processes every image. Defaults to 20.
    """
    s3 = boto3.client('s3')
    bucket = BUCKET_NAME
    index_name = OPENSEARCH_INDEX_NAME
    embedding_length = EMBEDDING_LENGTH

    # Working directory for the generated patches.
    preprocessing_img_dir = "object_patches"
    os.makedirs(preprocessing_img_dir, exist_ok=True)

    img_patch_pairs = []

    # Match extensions case-insensitively and sort so the selected subset does
    # not depend on the directory listing order.
    image_extensions = ('.jpg', '.jpeg', '.png', '.webp')
    image_list = sorted(
        name for name in os.listdir(image_dir)
        if name.lower().endswith(image_extensions)
    )
    if max_images is not None:
        image_list = image_list[:max_images]

    print(f"找到 {len(image_list)} 张图像")

    # Cut every image into patches in parallel; a failing image is reported
    # and skipped.
    with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor:
        futures = {
            executor.submit(split_connected_components, os.path.join(image_dir, k), preprocessing_img_dir): k
            for k in image_list
        }

        for future in tqdm(as_completed(futures), total=len(futures), desc="裁剪连通区域"):
            try:
                img_patch_pairs += future.result()
            except Exception as e:
                print(f"处理图像时出错: {e}")

    print(f"共生成 {len(img_patch_pairs)} 个图像块")

    documents = process_embeddings_unordered(
        img_patch_pairs,
        preprocessing_img_dir,
        output_embedding_length=embedding_length
    )

    print(f"生成了 {len(documents)} 个文档")

    # The object key is the plain file name, matching the indexed s3_uri.
    for image_name in tqdm(image_list, desc="上传图像到 S3"):
        s3.upload_file(os.path.join(image_dir, image_name), bucket, image_name)

    actions = [
        {
            "_index": index_name,
            "_source": document
        }
        for document in documents
    ]

    print("导入数据到 OpenSearch...")
    helpers.bulk(os_client, actions, request_timeout=300)
    print("数据导入完成")

#### 导入本地路径下的图片

In [None]:
import_data_to_opensearch('./imgs')

## 5. 测试效果

In [None]:
def search_by_aos_knn(os_client, q_embedding, index_name, size=10):
    """Run an approximate k-NN query on the ``pic_emb`` field.

    No sort order is specified in the query. The original note for this
    function states that closer vectors receive higher, normalised
    (0.0~1.0) scores.

    Query syntax references:
      exact k-NN:       https://opensearch.org/docs/latest/search-plugins/knn/knn-score-script/
      approximate k-NN: https://opensearch.org/docs/latest/search-plugins/knn/approximate-knn/

    Args:
        os_client: Client object exposing ``search(body=..., index=...)``.
        q_embedding: Query vector.
        index_name: Name of the index to search.
        size: Used both as the number of hits to return and as ``k``.

    Returns:
        list[dict]: One dict per hit with the keys ``score``, ``s3_uri`` and
        ``pic_name``.
    """
    knn_clause = {"vector": q_embedding, "k": size}
    request = {
        "size": size,
        "query": {"knn": {"pic_emb": knn_clause}},
    }

    hits = os_client.search(body=request, index=index_name)["hits"]["hits"]

    results = []
    for hit in hits:
        source = hit['_source']
        results.append({
            'score': hit['_score'],
            's3_uri': source['s3_uri'],
            'pic_name': source['pic_name'],
        })
    return results

In [None]:
# Embed a query image and look up its nearest neighbours in the index.
query_image_path='./object_patches/Picture1_3.png'

# Run the query image through the same helper used at indexing time (size
# limit + PNG encoding) so query and indexed embeddings come from identically
# prepared input.
with open(query_image_path, 'rb') as f:
    bytes_data = resizeandGetByteData(f)

# The file is closed before the remote embedding call.
input_image_base64 = base64.b64encode(bytes_data).decode('utf-8')
embedding = getEmbeddings(input_image_base64, text=None, output_embedding_length=EMBEDDING_LENGTH)[0]

opensearch_knn_respose = search_by_aos_knn(os_client=os_client, q_embedding=embedding, index_name=OPENSEARCH_INDEX_NAME, size=10)
for result in opensearch_knn_respose:
    print(result)