# 以图搜图

本 Notebook 将引导您完成设置和使用基于 Amazon OpenSearch 的图像搜索系统的全过程。

## 1. 环境准备

首先，我们需要准备 Python 环境：安装所需的依赖项，并解压示例图片压缩包 `imgs.zip`。

In [None]:
# Install the Python dependencies used by this notebook (IPython shell magic)
!pip install boto3 opensearch-py numpy tqdm opencv-python pillow

In [None]:
# Extract the sample images; later cells read them from ./imgs
# NOTE(review): on a re-run unzip may prompt before overwriting existing files; consider `unzip -o`.
!unzip imgs.zip

## 2. 设置环境变量

接下来，我们需要设置一些环境变量，这些变量将在后续步骤中使用。

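**【注意】**：需要预先创建名为 `ImgSearch` 的 IAM role（即运行本 Notebook 的 SageMaker 执行角色），建议授予 `AmazonSageMakerFullAccess`、`AmazonBedrockFullAccess` 以及 `aoss:*` 权限。另外，需要在 Bedrock 控制台的 Model access 中开通所需 embedding 模型（本例使用 `amazon.titan-embed-image-v1`）的访问权限。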

In [None]:
# Edit these two values for your own account/region
REGION = 'us-west-2'
# S3 bucket of the corresponding SageMaker environment; the images are uploaded here.
# NOTE(review): the numeric suffix looks like an AWS account ID — replace with your own bucket.
BUCKET_NAME = 'sagemaker-us-west-2-687752207838'

In [None]:
import os
import glob

# Settings shared by the rest of the notebook
OPENSEARCH_INDEX_NAME = 'image-index'
OPENSEARCH_COLLECTION_NAME = 'image-search-collection'

# S3 key prefix under which images are uploaded and later listed for indexing
Prefix="imgs"

# Embedding dimension requested from Bedrock; also the knn_vector dimension of the index
EMBEDDING_LENGTH = 256
# Bedrock model used to embed images
EMBEDDING_MODEL_ID = 'amazon.titan-embed-image-v1'

# The SageMaker Execution Role Name
ROLE_NAME = 'ImgSearch'
# IAM role name of the EC2 instance running Dify; included as a principal in the data access policy
ROLE_NAME_EC2 = 'DifyEc2Role'

# Names of the OpenSearch Serverless (aoss) encryption, network and data access policies
encryption_policy_name = 'image-search-encryption-policy'
network_policy_name = 'image-search-network-policy'
access_policy_name = 'image-search-access-policy'

## 拷贝数据到S3路径

In [None]:
import boto3
import os
import glob

s3_client = boto3.client('s3')
local_directory = './imgs'  # local directory holding the images

# Only .png files from the local directory are uploaded.
png_files = glob.glob(os.path.join(local_directory, '*.png'))

# Each file lands at s3://BUCKET_NAME/<Prefix>/<file name>.
for png_path in png_files:
    object_key = f"{Prefix}/{os.path.basename(png_path)}"
    print(f"上传文件: {png_path} 到 {BUCKET_NAME}/{object_key}")
    s3_client.upload_file(png_path, BUCKET_NAME, object_key)

print(f"成功上传了 {len(png_files)} 个 PNG 文件到 S3")

## 3. 创建 OpenSearch Serverless Collection

现在，我们将创建一个 OpenSearch Serverless Collection，用于存储图像嵌入向量。

In [None]:
import boto3
from opensearchpy import OpenSearch, RequestsHttpConnection, AWSV4SignerAuth
import base64
import json
import time
import os
import logging

# NOTE(review): this emits DEBUG-level logs for boto3 resources; remove it if the output is too noisy.
boto3.set_stream_logger('boto3.resources', logging.DEBUG)
# AWS configuration
region = REGION  # e.g. 'us-west-2'
service = 'aoss'
credentials = boto3.Session().get_credentials()

# SigV4 request signer for the OpenSearch Serverless data plane (signing name "aoss").
awsauth = AWSV4SignerAuth(credentials, region, service)

# OpenSearch Serverless control-plane client (security/access policies, collections).
aoss_client = boto3.client(service_name="opensearchserverless", region_name=REGION)

# Resolve the account ID once (a single STS round-trip) and build the ARNs of the
# two principals granted access in the data access policy.
account_id = boto3.client('sts').get_caller_identity()['Account']
role_arn = f"arn:aws:iam::{account_id}:role/{ROLE_NAME}"
role_arn_ec2 = f"arn:aws:iam::{account_id}:role/{ROLE_NAME_EC2}"

In [None]:
# Show which IAM principals will be granted access to the collection.
for label, arn in (("SageMaker Ingestion Role", role_arn), ("Dify Search Role", role_arn_ec2)):
    print(f"{label}: {arn}")

In [None]:
# Encryption policy: the collection is encrypted with an AWS-owned key.
encryption_policy_doc = {
    'Rules': [{'Resource': [f'collection/{OPENSEARCH_COLLECTION_NAME}'],
               'ResourceType': 'collection'}],
    'AWSOwnedKey': True,
}
try:
    security_policy = aoss_client.create_security_policy(
        name=encryption_policy_name,
        policy=json.dumps(encryption_policy_doc),
        type='encryption',
    )
except aoss_client.exceptions.ConflictException:
    # A policy with this name already exists (e.g. the notebook is being re-run).
    print(f"加密策略 {encryption_policy_name} 已存在")
else:
    print(f"创建加密策略: {encryption_policy_name}")

In [None]:
# Network policy for the collection with AllowFromPublic=True.
# NOTE(review): for production consider restricting network access (e.g. to a VPC endpoint).
network_policy_doc = [
    {
        'Rules': [{'Resource': [f'collection/{OPENSEARCH_COLLECTION_NAME}'],
                   'ResourceType': 'collection'}],
        'AllowFromPublic': True,
    }
]
try:
    network_policy = aoss_client.create_security_policy(
        name=network_policy_name,
        policy=json.dumps(network_policy_doc),
        type='network',
    )
except aoss_client.exceptions.ConflictException:
    # A policy with this name already exists (e.g. the notebook is being re-run).
    print(f"网络策略 {network_policy_name} 已存在")
else:
    print(f"创建网络策略: {network_policy_name}")

In [None]:
# Data access policy: grants both roles (SageMaker ingestion and Dify search) collection-level
# item permissions on this collection and index-level permissions on every index ('index/*/*').
collection_rule = {
    'Resource': [f'collection/{OPENSEARCH_COLLECTION_NAME}'],
    'Permission': [
        'aoss:CreateCollectionItems',
        'aoss:DeleteCollectionItems',
        'aoss:UpdateCollectionItems',
        'aoss:DescribeCollectionItems',
    ],
    'ResourceType': 'collection',
}
index_rule = {
    'Resource': ['index/*/*'],
    'Permission': [
        'aoss:CreateIndex',
        'aoss:DeleteIndex',
        'aoss:UpdateIndex',
        'aoss:DescribeIndex',
        'aoss:ReadDocument',
        'aoss:WriteDocument',
    ],
    'ResourceType': 'index',
}
access_policy_doc = [
    {
        'Rules': [collection_rule, index_rule],
        'Principal': [role_arn, role_arn_ec2],
        'Description': 'Complete data access policy',
    }
]
try:
    access_policy = aoss_client.create_access_policy(
        name=access_policy_name,
        policy=json.dumps(access_policy_doc),
        type='data',
    )
except aoss_client.exceptions.ConflictException:
    # A policy with this name already exists (e.g. the notebook is being re-run).
    print(f"访问策略 {access_policy_name} 已存在")
else:
    print(f"创建访问策略: {access_policy_name}")

In [None]:
# Pause so the policies created above can take effect before the collection is created.
# This is a fixed 10-second sleep, not a readiness check.
print("等待策略生效...")
time.sleep(10)
print("继续执行...")

In [None]:
# Create the VECTORSEARCH collection; an existing one with the same name is reused.
collection_name = OPENSEARCH_COLLECTION_NAME
try:
    response = aoss_client.create_collection(name=collection_name, type='VECTORSEARCH')
except aoss_client.exceptions.ConflictException:
    print(f"集合 {collection_name} 已存在")
else:
    created_name = response['createCollectionDetail']['name']
    print(f"集合已创建: {created_name}")

In [None]:
# Poll the collection status every 10 s until it reaches a terminal state (ACTIVE or FAILED).
print("等待集合变为活动状态...")
# Generous upper bound so a collection that never shows up cannot hang this cell forever.
deadline = time.monotonic() + 30 * 60
while True:
    summaries = aoss_client.list_collections(collectionFilters={'name':OPENSEARCH_COLLECTION_NAME})['collectionSummaries']
    # The listing may still be empty right after create_collection; treat that as
    # "not ready yet" instead of crashing with an IndexError.
    status = summaries[0]['status'] if summaries else 'NOT_FOUND'
    print(f"当前状态: {status}")
    if status in ('ACTIVE', 'FAILED'):
        break
    if time.monotonic() > deadline:
        raise TimeoutError(f"等待集合 {collection_name} 超时，最后状态: {status}")
    time.sleep(10)

# Do not report success for a collection that ended in FAILED.
if status == 'FAILED':
    raise RuntimeError(f"集合 {collection_name} 创建失败，状态: FAILED")
print(f"集合 {collection_name} 已激活")

In [None]:
# Look up the collection to build its data-plane endpoint and the OpenSearch client.
summary = aoss_client.list_collections(collectionFilters={'name': OPENSEARCH_COLLECTION_NAME})['collectionSummaries']
collection = summary[0]
collection_arn, collection_id = collection['arn'], collection['id']

host = f"{collection_id}.{region}.aoss.amazonaws.com"
print(f"OpenSearch 端点: {host}")

# HTTPS client for the collection, requests signed with the SigV4 auth built earlier.
os_client = OpenSearch(
    hosts=[{'host': host, 'port': 443}],
    http_auth=awsauth,
    use_ssl=True,
    verify_certs=True,
    connection_class=RequestsHttpConnection,
)

In [None]:
# Full HTTPS URL of the collection endpoint, printed so it can be copied
collection_endpoint=f"https://{collection_id}.{region}.aoss.amazonaws.com"
print(collection_endpoint)

In [None]:
# Create the k-NN index that stores one document per image.

index_body = {
    "settings": {
        # Enable k-NN (vector) search on this index.
        "index.knn": True
    },
    "mappings": {
        "properties": {
            # Image embedding; its dimension must equal the length requested from Bedrock.
            "pic_emb": {
                "type": "knn_vector",
                "dimension": EMBEDDING_LENGTH,
                # NOTE(review): `similarity` is not a documented parameter of OpenSearch's
                # knn_vector mapping (the documented knob is `method.space_type`, whose
                # supported values depend on engine/version). If the service rejects it,
                # index creation below fails and the exception is re-raised — confirm.
                "similarity": "cosine",
                "method": {
                    "name": "hnsw",
                    "engine": "faiss"
                }
            },
            "s3_uri": {
                "type": "keyword"
            },
            "pic_name":  {
                "type": "keyword"
            },
            "pic_hash": {
                "type": "keyword"
            }
        }
    }
}

if os_client.indices.exists(index=OPENSEARCH_INDEX_NAME):
    print(f"索引 {OPENSEARCH_INDEX_NAME} 已存在")
else:
    try:
        os_client.indices.create(index=OPENSEARCH_INDEX_NAME, body=index_body)
    except Exception as e:
        print(f"异常：{e}， 请注意当前role是否能操作OpenSearch Serverless")
        # Re-raise instead of continuing: the bulk import and the k-NN search in later
        # cells depend on this index and its knn_vector mapping existing.
        raise
    print(f"索引 {OPENSEARCH_INDEX_NAME} 已创建")

## 4. 导入数据到 OpenSearch

现在，我们将导入图像数据到 OpenSearch。首先，我们需要定义一些辅助函数。

In [None]:
# 导入所需的库
import boto3
from opensearchpy import helpers
import cv2
import os
import base64
from PIL import Image
from datetime import datetime
from tqdm import tqdm
from io import BytesIO
from PIL import Image
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
import numpy as np
import time
import hashlib
from botocore.exceptions import ClientError

# Image size limits.
# NOTE(review): these are not referenced by the cells shown here.
MAX_IMAGE_HEIGHT: int = 2048
MAX_IMAGE_WIDTH: int = 2048

# Bedrock runtime client, pinned to REGION like the other regional clients. Model access
# is enabled per region, so relying on the default region could target the wrong one.
bedrock = boto3.client('bedrock-runtime', region_name=REGION)

In [None]:
# Bedrock Runtime error codes treated as transient, i.e. worth retrying with backoff.
_RETRYABLE_BEDROCK_ERRORS = frozenset({
    "ThrottlingException",
    "ServiceUnavailableException",
    "ModelTimeoutException",
    "ModelNotReadyException",
    "InternalServerException",
})


def getEmbeddings(inputImageB64, max_retries=10, initial_delay=2, text=None, output_embedding_length=1024):
    """Embed an image (and/or text) with the Bedrock model ``EMBEDDING_MODEL_ID``.

    Args:
        inputImageB64: Base64-encoded image bytes as a str, or None to embed text only.
        max_retries: Maximum number of InvokeModel attempts.
        initial_delay: Seconds to wait before the first retry; doubles on each further retry.
        text: Optional text input; omitted from the request when None.
        output_embedding_length: Requested embedding dimension.

    Returns:
        ``np.ndarray`` of dtype float32 with shape (1, dim).

    Raises:
        ClientError: Immediately for non-transient errors (e.g. validation or access
            errors), or after the last attempt for transient ones.
    """
    def exponential_delay(attempt):
        return initial_delay * (2 ** attempt)

    # Build the request once. Optional inputs are left out rather than sent as JSON null.
    request_body = {"embeddingConfig": {"outputEmbeddingLength": output_embedding_length}}
    if text is not None:
        request_body["inputText"] = text
    if inputImageB64 is not None:
        request_body["inputImage"] = inputImageB64
    body = json.dumps(request_body)

    for attempt in range(max_retries):
        try:
            response = bedrock.invoke_model(
                body=body,
                modelId=EMBEDDING_MODEL_ID,
                accept="application/json",
                contentType="application/json")
            response_body = json.loads(response.get("body").read())
            return np.array([response_body.get("embedding")]).astype(np.float32)
        except ClientError as e:
            error_code = e.response.get("Error", {}).get("Code")
            # Retrying a non-transient error (bad request, missing model access) only
            # adds minutes of backoff before the same failure; surface it right away.
            if error_code not in _RETRYABLE_BEDROCK_ERRORS or attempt == max_retries - 1:
                raise

            delay = exponential_delay(attempt)
            print(f"{e}")
            print(f"请求失败。{delay} 秒后重试...")
            time.sleep(delay)

    # Only reachable when max_retries <= 0 (no attempt was made).
    raise Exception("达到最大重试次数。无法获取嵌入向量。")

def list_s3_images(s3_client, bucket, prefix):
    """Return the keys under ``bucket``/``prefix`` that look like image files.

    All result pages from ``list_objects_v2`` are walked; pages without a
    'Contents' entry are skipped. A key matches when its lower-cased name ends
    in one of the known image extensions. Keys are returned in listing order.
    """
    suffixes = ('.jpg', '.jpeg', '.png', '.webp', '.gif', '.bmp', '.tiff')
    listing = s3_client.get_paginator('list_objects_v2').paginate(Bucket=bucket, Prefix=prefix)
    return [
        entry['Key']
        for page in listing
        if 'Contents' in page
        for entry in page['Contents']
        if entry['Key'].lower().endswith(suffixes)
    ]

In [None]:
def _parse_s3_uri(s3_uri):
    """Split 's3://bucket/prefix' into (bucket, prefix); raise ValueError for other strings."""
    scheme = "s3://"
    if not s3_uri.startswith(scheme):
        raise ValueError(f"Expected an s3:// URI, got {s3_uri!r}")
    bucket, _, prefix = s3_uri[len(scheme):].partition("/")
    if not bucket:
        raise ValueError(f"Missing bucket name in {s3_uri!r}")
    return bucket, prefix


def _flush_actions(actions):
    """Bulk-index ``actions`` into ``os_client`` and return (succeeded, failed) counts."""
    if not actions:
        return 0, 0
    try:
        # raise_on_error=False: count per-document failures instead of raising on the
        # first one (with the default, the returned failure count is always 0).
        return helpers.bulk(os_client, actions, stats_only=True, raise_on_error=False)
    except Exception as e:
        # Transport-level failure: treat the whole batch as failed and keep going.
        print(f"批量导入时出错: {e}")
        return 0, len(actions)


def import_data_to_opensearch(s3_image_path, batch_size=100):
    """Embed every image under ``s3_image_path`` and bulk-index it into OPENSEARCH_INDEX_NAME.

    Args:
        s3_image_path: 's3://bucket/prefix' to scan for image files.
        batch_size: Number of documents sent per bulk request.

    Each document stores the embedding, its S3 URI, the file name and an MD5 hash of
    the embedding. A per-image failure is reported and counted, and the loop moves on.
    A summary line with the success/failure counts is printed at the end.
    """
    bucket, prefix = _parse_s3_uri(s3_image_path)
    s3_client = boto3.client('s3')
    print(s3_image_path)
    image_s3_path_list = list_s3_images(s3_client, bucket, prefix)

    successful_imports = 0
    failed_imports = 0

    actions = []
    for image_key in tqdm(image_s3_path_list, desc="处理图像"):
        s3_uri = f"s3://{bucket}/{image_key}"
        try:
            print(f"Processing image: {s3_uri}")
            response = s3_client.get_object(Bucket=bucket, Key=image_key)
            image_content = response['Body'].read()
            image_base64 = base64.b64encode(image_content).decode('utf-8')
            embedding = getEmbeddings(
                image_base64,
                text=None,
                output_embedding_length=EMBEDDING_LENGTH
            )[0].tolist()
        except Exception as e:
            print(f"处理图像 {s3_uri} 时出错: {e}")
            failed_imports += 1
            continue

        # Content fingerprint: MD5 of the embedding's string form (not of the file name).
        pic_hash = hashlib.md5(str(embedding).encode('utf-8')).hexdigest()
        actions.append({
            '_index': OPENSEARCH_INDEX_NAME,
            '_source': {
                'pic_emb': embedding,
                's3_uri': s3_uri,  # S3 location of the source image
                'pic_name': os.path.basename(image_key),
                'pic_hash': pic_hash
            }
        })

        # Flush full batches; the batch is cleared even if the bulk request fails,
        # so documents are never re-sent.
        if len(actions) >= batch_size:
            success, failed = _flush_actions(actions)
            successful_imports += success
            failed_imports += failed
            actions = []

    # Send whatever is left in the last partial batch.
    success, failed = _flush_actions(actions)
    successful_imports += success
    failed_imports += failed

    print(f"导入完成: {successful_imports} 成功, {failed_imports} 失败")

#### 将 S3 路径下的图片导入 OpenSearch

In [None]:
# Embed and index the images under s3://BUCKET_NAME/Prefix/
s3_image_path=f"s3://{BUCKET_NAME}/{Prefix}/"
import_data_to_opensearch(s3_image_path)

## 5. 测试效果

In [None]:
def search_by_aos_knn(os_client, q_embedding, index_name, size=10):
    """Run an approximate k-NN query on ``pic_emb`` and return the top ``size`` hits.

    No explicit sort is needed: closer vectors score higher (the original notes say
    scores are normalized to 0.0~1.0).
    Exact k-NN syntax: https://opensearch.org/docs/latest/search-plugins/knn/knn-score-script/
    Approximate k-NN syntax (used here): https://opensearch.org/docs/latest/search-plugins/knn/approximate-knn/

    Returns:
        A list of dicts with keys 'score', 's3_uri', 'pic_name' and 'id'.
    """
    knn_clause = {"pic_emb": {"vector": q_embedding, "k": size}}
    query = {"size": size, "query": {"knn": knn_clause}}

    hits = os_client.search(body=query, index=index_name)["hits"]["hits"]

    results = []
    for hit in hits:
        source = hit['_source']
        results.append({
            'score': hit['_score'],
            's3_uri': source['s3_uri'],
            'pic_name': source['pic_name'],
            "id": hit["_id"],
        })
    return results

In [None]:
query_image_path = './imgs/car_0.png'

# Read the query image, then embed it exactly as the indexed images were embedded.
with open(query_image_path, 'rb') as f:
    bytes_data = f.read()
input_image_base64 = base64.b64encode(bytes_data).decode('utf-8')
embedding = getEmbeddings(input_image_base64, text=None, output_embedding_length=EMBEDDING_LENGTH)[0]

# Retrieve and print the 10 nearest images.
opensearch_knn_respose = search_by_aos_knn(
    os_client=os_client,
    q_embedding=embedding,
    index_name=OPENSEARCH_INDEX_NAME,
    size=10,
)
for result in opensearch_knn_respose:
    print(result)

## 6. 清理Index

In [None]:
def delete_aoss_index(os_client, index_name):
    """Delete ``index_name`` if it exists.

    Returns:
        The delete response when the index existed. Otherwise an
        ``{"acknowledged": True, "message": ...}`` dict stating that it was absent.

    Raises:
        Any exception from the client; it is printed first, then re-raised.
    """
    try:
        if not os_client.indices.exists(index=index_name):
            print(f"索引 '{index_name}' 不存在，无需删除")
            return {"acknowledged": True, "message": f"索引 '{index_name}' 不存在"}

        response = os_client.indices.delete(index=index_name)
        print(f"索引 '{index_name}' 已成功删除")
        return response
    except Exception as exc:
        print(f"删除索引 '{index_name}' 时发生错误: {str(exc)}")
        raise

# Clean up: delete the image index created earlier in this notebook
delete_aoss_index(os_client, OPENSEARCH_INDEX_NAME)