In [None]:
!pip install -q datasets
!pip3 install boto3 requests requests_aws4auth argparse opensearch-py

### 1. Download Experiment Data - squad_v2 (下载实验数据squad_v2)

In [None]:
from datasets import load_dataset

# Name of the dataset to fetch; the loaded object is indexed by split name
# ("train" here, "validation" in later cells).
dataset_name = "squad_v2"
dataset = load_dataset(dataset_name)

# Show the loaded dataset, then one training record as a sanity check.
# Later cells read the "context" and "question" fields of these records.
print(dataset)
sample = dataset["train"][0]
print(sample)

### 2. Setup OpenSearch Index & Model For Experiment(创建用于测试的AOS索引 & 模型) 

- Setup Sparse vector model(创建Sparse Vector模型)
  + 进入OpenSearch的Integration页面, 由于咱们OpenSearch集群是部署在VPC中的，所以选择“Configure VPC Domain”，会弹出一个Cloudformation模版填写。
    ![integration_1.png](./integration_1.png)<br>
    + vpc请选择OpenSearch所在的vpc，security group选择为OpenSearch同一个
    + 子网subnet请选择对应的Private subnet 
    <br>
  + 验证部署的neural-sparse模型
    进入CloudFormation对应stack，切换到Outputs, 获取modelId, ConnectorId 以及SageMaker endpoint
    ![nerual-sparse.png](./nerual-sparse.png)

- Setup Cohere Multilingual Model & ingestion pipeline(创建cohere模型以及ingest的pipeline)

In [None]:
# OpenSearch domain endpoint used by the setup script and by get_aos_client();
# the value here is a bare host name without a scheme.
aos_endpoint='vpc-domain66ac69e0-2m4jji7cweof-4fefsofiqdzu3hxammxwq5hth4.us-west-2.es.amazonaws.com'
# TODO: uncomment the next line and replace the placeholder with the quoted
# sparse model id. Later cells reference sparse_model_id and will raise
# NameError while it stays commented out.
# sparse_model_id=<sparse_model_id> # look for sparse_model_id in the output of Cloudformation
# Name of the index that the ingest and benchmark cells read from and write to.
index_name="aos-retrieval"

In [None]:
# Run the setup script, passing the endpoint, sparse model id and index name
# defined in the previous cell. sparse_model_id must be set before running this.
# The dense model id needed by the next cell is taken from this cell's output.
!python3 setup_model_and_pipeline.py --aos_endpoint {aos_endpoint} --sparse_model_id {sparse_model_id} --index_name {index_name}

- Extract dense_model_id for query embedding from the output of the previous cell(根据上个Cell的输出提取query向量化的模型ID)

In [None]:
# Placeholder: replace <dense_model_id> with the quoted model id taken from the
# output of the setup cell. As written this line is not valid Python and the
# cell will not run until the placeholder is filled in.
dense_model_id=<dense_model_id>

### 3. Ingest Data（执行数据摄入）

In [None]:
import json
from tqdm import tqdm
from setup_model_and_pipeline import get_aos_client

def deduplicate_dataset(dataset):
    """Return the distinct ``context`` values of ``dataset`` as a list.

    Args:
        dataset: iterable of rows, each supporting ``row["context"]``.

    Returns:
        A list holding each distinct context once, in order of first
        appearance. Keeping first-seen order (instead of going through a
        ``set``) makes the result identical from run to run, so ingestion
        batches are reproducible.
    """
    # dict keys are unique and preserve insertion order.
    return list(dict.fromkeys(row["context"] for row in dataset))

def build_bulk_body(index_name,sources_list):
    """Interleave bulk ``index`` action lines with their documents.

    Args:
        index_name: value placed under ``_index`` in every action line.
        sources_list: documents to index, in order.

    Returns:
        A flat list ``[action, doc, action, doc, ...]`` where each action is
        ``{"index": {"_index": index_name}}`` followed by its document.
    """
    return [
        entry
        for doc in sources_list
        # A fresh action dict is created for every document.
        for entry in ({ "index" : { "_index" : index_name} }, doc)
    ]

def ingest_dataset(dataset,aos_client,index_name, bulk_size=50):
    """Deduplicate the passages of ``dataset`` and bulk-index them.

    Each distinct context is sent as a document ``{"content": <context>}`` to
    ``index_name`` in batches of ``bulk_size``; the index is refreshed once
    after the last batch.

    Args:
        dataset: iterable of rows with a ``"context"`` field.
        aos_client: client exposing ``bulk(body, request_timeout=...)`` and
            ``indices.refresh(index=..., request_timeout=...)``.
        index_name: target index name.
        bulk_size: number of documents per bulk request; must be positive.

    Raises:
        ValueError: if ``bulk_size`` is not positive.
        RuntimeError: if a bulk response reports errors; the message carries
            the per-item errors found in the response.
    """
    if bulk_size <= 0:
        # A negative step would make the range empty and silently ingest nothing.
        raise ValueError(f"bulk_size must be positive, got {bulk_size}")

    print("Deduplicating dataset...")
    context_list = deduplicate_dataset(dataset)
    # 19029 for train, 1204 for validation
    print(f"Finished deduplication. Total number of passages: {len(context_list)}")

    for start_idx in tqdm(range(0,len(context_list),bulk_size)):
        contexts = context_list[start_idx:start_idx+bulk_size]
        response = aos_client.bulk(
            build_bulk_body(index_name, [{"content":context} for context in contexts]),
            # set a large timeout because a new sparse encoding endpoint need warm up
            request_timeout=100
        )
        if response["errors"]:
            # Surface the failing items instead of a bare assertion, so the
            # cause is visible and the check is not stripped under `python -O`.
            failures = [
                action.get("index", {}).get("error")
                for action in response.get("items", [])
                if action.get("index", {}).get("error")
            ]
            raise RuntimeError(
                f"Bulk ingest into '{index_name}' failed for batch starting at "
                f"{start_idx}: {len(failures)} item error(s), first errors: {failures[:3]}"
            )

    aos_client.indices.refresh(index=index_name,request_timeout=100)

# Create one client for the configured endpoint and reuse it to ingest both
# splits into the same index: train first, then validation.
aos_client = get_aos_client(aos_endpoint)
for split_name in ("train", "validation"):
    ingest_dataset(dataset=dataset[split_name],aos_client=aos_client,index_name=index_name)

### 4. Search benchmark （查询性能测试）

In [None]:
from search_func import search_by_bm25, search_by_dense, search_by_sparse, search_by_dense_sparse, search_by_dense_bm25

In [None]:
# Split whose questions are used as benchmark queries.
QUERY_DATASET_TYPE = "validation"
# Number of queries to run: at most 1000, capped at the size of the chosen split.
QUERY_DATASET_SIZE = min(1000, len(dataset[QUERY_DATASET_TYPE]))
# Number of hits requested per query when computing recall.
RECALL_K = 4

In [None]:
def calculate_recall_rate(dataset, index_name, aos_client, data_size, query_body_lambda, recall_k=4):
    """Measure how often a question's own passage is among the top hits.

    For each of the first ``data_size`` rows, the row's ``question`` is turned
    into a search body by ``query_body_lambda`` and sent to ``index_name``
    requesting ``recall_k`` hits. A row counts as a hit when its ``context``
    equals the ``_source.content`` of any returned document.

    Args:
        dataset: object exposing ``select(indices)`` that yields rows with
            ``"question"`` and ``"context"`` fields.
        index_name: index to search.
        aos_client: client exposing ``search(index=..., size=..., body=...)``.
        data_size: number of leading rows to evaluate.
        query_body_lambda: callable mapping the query text to a search body.
        recall_k: number of hits requested per query.

    Returns:
        The recall as ``hit_cnt / data_size``, or ``0.0`` when ``data_size``
        is 0. The same summary is also printed.
    """
    hit_cnt = 0
    miss_cnt = 0
    # Pass the total explicitly so the progress bar can show completion.
    for item in tqdm(dataset.select(range(data_size)), total=data_size):
        query = item['question']
        content = item['context']
        response = aos_client.search(index=index_name,size=recall_k, body=query_body_lambda(query))
        docs = [hit["_source"]['content'] for hit in response["hits"]["hits"]]
        if content in docs:
            hit_cnt += 1
        else:
            miss_cnt += 1
    # Guard the empty case so an empty evaluation reports 0.0 instead of
    # raising ZeroDivisionError.
    recall = hit_cnt/data_size if data_size else 0.0
    print(f"hit:{hit_cnt}, miss:{miss_cnt}, recall@{recall_k}:{recall}")
    return recall

In [None]:
# bm25

def _bm25_query_body(query_text):
    """Build a lexical ``match`` query against the ``content`` field."""
    match_clause = {"content" : query_text}
    return {"query": {"match": match_clause}}

calculate_recall_rate(
    dataset=dataset[QUERY_DATASET_TYPE],
    index_name=index_name,
    aos_client=aos_client,
    data_size=QUERY_DATASET_SIZE,
    query_body_lambda=_bm25_query_body,
    recall_k=RECALL_K,
)

In [None]:
# dense

calculate_recall_rate(
    dataset = dataset[QUERY_DATASET_TYPE],
    index_name = index_name,
    aos_client = aos_client,
    data_size = QUERY_DATASET_SIZE,
    query_body_lambda = lambda query_text: {
        "query": {
            "neural": {
                "dense_embedding": {
                  "query_text": query_text,
                  "model_id": dense_model_id,
                  # Use the notebook-level constant: `recall_k` is a parameter
                  # of calculate_recall_rate and is not visible inside this
                  # lambda, so referencing it here raised NameError.
                  "k": RECALL_K
                }
            }
        }
    },
    recall_k=RECALL_K
)

In [None]:
# sparse

def _sparse_query_body(query_text):
    """Build a ``neural_sparse`` query against the ``sparse_embedding`` field."""
    sparse_params = {
        "query_text": query_text,
        "model_id": sparse_model_id,
        "max_token_score": 3.5,
    }
    return {"query": {"neural_sparse": {"sparse_embedding": sparse_params}}}

calculate_recall_rate(
    dataset=dataset[QUERY_DATASET_TYPE],
    index_name=index_name,
    aos_client=aos_client,
    data_size=QUERY_DATASET_SIZE,
    query_body_lambda=_sparse_query_body,
    recall_k=RECALL_K,
)

In [None]:
# dense+sparse

def _dense_sparse_query_body(query_text):
    """Build a ``hybrid`` query: neural_sparse sub-query first, then neural."""
    sparse_clause = {
        "neural_sparse": {
            "sparse_embedding": {
                "query_text": query_text,
                "model_id": sparse_model_id,
                "max_token_score": 3.5,
            }
        }
    }
    dense_clause = {
        "neural": {
            "dense_embedding": {
                "query_text": query_text,
                "model_id": dense_model_id,
                "k": 10,
            }
        }
    }
    # Sub-query order is kept as sparse, then dense.
    return {"query": {"hybrid": {"queries": [sparse_clause, dense_clause]}}}

calculate_recall_rate(
    dataset=dataset[QUERY_DATASET_TYPE],
    index_name=index_name,
    aos_client=aos_client,
    data_size=QUERY_DATASET_SIZE,
    query_body_lambda=_dense_sparse_query_body,
    recall_k=RECALL_K,
)

In [None]:
# dense+bm25

def _dense_bm25_query_body(query_text):
    """Build a ``hybrid`` query: lexical match sub-query first, then neural."""
    lexical_clause = {"match": {"content" : query_text}}
    dense_clause = {
        "neural": {
            "dense_embedding": {
                "query_text": query_text,
                "model_id": dense_model_id,
                "k": 10,
            }
        }
    }
    # Sub-query order is kept as match, then dense.
    return {"query": {"hybrid": {"queries": [lexical_clause, dense_clause]}}}

calculate_recall_rate(
    dataset=dataset[QUERY_DATASET_TYPE],
    index_name=index_name,
    aos_client=aos_client,
    data_size=QUERY_DATASET_SIZE,
    query_body_lambda=_dense_bm25_query_body,
    recall_k=RECALL_K,
)