<a href="https://colab.research.google.com/github/Rongxuan-Zhou/CS6120_project/blob/main/notebooks/3_index_construction.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [8]:
# 1. Environment Setup
!pip install -q faiss-cpu sentence-transformers nltk
from google.colab import drive
drive.mount('/content/drive')

import os
PROJECT_PATH = "/content/drive/MyDrive/CS6120_project"
os.chdir(PROJECT_PATH)

# GPU detection
import torch
print(f"Available GPU: {torch.cuda.is_available()}")
print("Note: Using CPU version of FAISS for compatibility")

# Create necessary directories
os.makedirs("models/indexes", exist_ok=True)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Available GPU: True
Note: Using CPU version of FAISS for compatibility


In [9]:
# 2. Load fine-tuned SBERT model
from sentence_transformers import SentenceTransformer

# 加载微调过的模型
model = SentenceTransformer("models/sbert_model")
model.to('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Model loaded successfully: {model}")

# 显示模型架构信息
print(f"Model architecture: {model.get_sentence_embedding_dimension()}d embeddings")

Model loaded successfully: SentenceTransformer(
  (0): Transformer({'max_seq_length': 384, 'do_lower_case': False}) with Transformer model: MPNetModel 
  (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
  (2): Normalize()
)
Model architecture: 768d embeddings


In [10]:
# 3. Build FAISS index (based on src/index_builder.py)
import faiss
import numpy as np
import json
from tqdm import tqdm
import time

# 清理 GPU 缓存
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# Load data
print("Loading data...")
with open("data/processed/combined.json") as f:
    data = json.load(f)
    # 合并所有数据以创建更全面的索引
    corpus = data["train"] + data["val"] + data["test"]

print(f"Loaded {len(corpus)} documents")

# Batch encoding with timing
print("Generating embeddings...")
start_time = time.time()
batch_size = 128
embeddings = []
for i in tqdm(range(0, len(corpus), batch_size)):
    batch = corpus[i:i+batch_size]
    emb = model.encode(batch, show_progress_bar=False)
    embeddings.append(emb)

embeddings = np.vstack(embeddings)
encoding_time = time.time() - start_time
dimension = embeddings.shape[1]
print(f"Generated {len(embeddings)} embeddings of dimension {dimension}")
print(f"Encoding completed in {encoding_time:.2f} seconds ({len(corpus)/encoding_time:.2f} docs/sec)")

# 归一化向量以便使用内积计算余弦相似度
print("Normalizing vectors...")
faiss.normalize_L2(embeddings)

# Create flat FAISS index (精确搜索)
print("Building flat index...")
index_flat = faiss.IndexFlatIP(dimension)
index_flat.add(embeddings)
print(f"Flat index built with {index_flat.ntotal} vectors")

# 创建 HNSW 索引（更快的检索）
print("Building HNSW index...")
M = 16  # 每个节点的连接数
ef_construction = 200  # 构建时的搜索宽度
index_hnsw = faiss.IndexHNSWFlat(dimension, M)
index_hnsw.hnsw.efConstruction = ef_construction
index_hnsw.add(embeddings)
print(f"HNSW index built with {index_hnsw.ntotal} vectors")

# 创建 IVF-PQ 索引（更小的内存占用）
print("Building IVF-PQ index...")
nlist = min(100, len(corpus) // 50)  # 聚类中心数
m = 8  # 子向量数
bits = 8  # 每个子向量的位数
quantizer = faiss.IndexFlatL2(dimension)
index_ivfpq = faiss.IndexIVFPQ(quantizer, dimension, nlist, m, bits)
index_ivfpq.train(embeddings)
index_ivfpq.add(embeddings)
print(f"IVF-PQ index built with {index_ivfpq.ntotal} vectors")

# 保存向量维度信息，便于后续加载
embedding_info = {
    "dimension": dimension,
    "count": len(embeddings),
    "corpus_size": len(corpus)
}

with open(os.path.join("models/indexes", "embedding_info.json"), 'w') as f:
    json.dump(embedding_info, f)

Loading data...
Loaded 11000 documents
Generating embeddings...


100%|██████████| 86/86 [00:19<00:00,  4.37it/s]


Generated 11000 embeddings of dimension 768
Encoding completed in 19.68 seconds (559.02 docs/sec)
Normalizing vectors...
Building flat index...
Flat index built with 11000 vectors
Building HNSW index...
HNSW index built with 11000 vectors
Building IVF-PQ index...
IVF-PQ index built with 11000 vectors


In [11]:
# 4. Test indexes with sample queries and measure performance
import time

test_queries = [
    "How does social media affect mental health?",
    "Best programming languages to learn",
    "Artificial intelligence applications",
    "Climate change solutions and mitigation strategies",
    "Nutrition advice for athletes performance"
]

print("Testing indexes with sample queries...")
# 对测试查询进行编码
query_embeddings = model.encode(test_queries)

# 归一化查询向量
faiss.normalize_L2(query_embeddings)

# 设置返回结果数量
k = 5

# 性能测试变量
num_runs = 5
flat_times = []
hnsw_times = []
ivfpq_times = []

# Flat 索引搜索（最精确但最慢）
print("\nFlat index search results:")
for _ in range(num_runs):
    start = time.time()
    D_flat, I_flat = index_flat.search(query_embeddings, k)
    flat_times.append(time.time() - start)

avg_flat_time = sum(flat_times) / len(flat_times)
print(f"Flat index average search time: {avg_flat_time*1000:.2f} ms")

for i, query in enumerate(test_queries):
    print(f"\nQuery: {query}")
    for j in range(min(3, k)):  # 只显示前3个结果
        print(f"  Match {j+1}: (Score: {D_flat[i][j]:.4f})")
        print(f"  {corpus[I_flat[i][j]][:100]}...")

# HNSW 索引搜索（快速近似）
print("\nHNSW index search results:")
for _ in range(num_runs):
    start = time.time()
    D_hnsw, I_hnsw = index_hnsw.search(query_embeddings, k)
    hnsw_times.append(time.time() - start)

avg_hnsw_time = sum(hnsw_times) / len(hnsw_times)
print(f"HNSW index average search time: {avg_hnsw_time*1000:.2f} ms")
print(f"Speedup vs flat index: {avg_flat_time/avg_hnsw_time:.2f}x")

# 计算与精确搜索的重合度
hnsw_overlap = 0
for i in range(len(test_queries)):
    overlap = len(set(I_flat[i][:k]) & set(I_hnsw[i][:k]))
    hnsw_overlap += overlap / k
hnsw_overlap /= len(test_queries)

print(f"HNSW average overlap with flat search: {hnsw_overlap:.2%}")

# IVF-PQ 索引搜索（紧凑型）
print("\nIVF-PQ index search results:")
index_ivfpq.nprobe = 10  # 搜索时检查的聚类数量
for _ in range(num_runs):
    start = time.time()
    D_ivfpq, I_ivfpq = index_ivfpq.search(query_embeddings, k)
    ivfpq_times.append(time.time() - start)

avg_ivfpq_time = sum(ivfpq_times) / len(ivfpq_times)
print(f"IVF-PQ index average search time: {avg_ivfpq_time*1000:.2f} ms")
print(f"Speedup vs flat index: {avg_flat_time/avg_ivfpq_time:.2f}x")
print(f"Memory usage vs flat index: ~{8/m/bits:.2f}x reduction")

# 计算与精确搜索的重合度
ivfpq_overlap = 0
for i in range(len(test_queries)):
    overlap = len(set(I_flat[i][:k]) & set(I_ivfpq[i][:k]))
    ivfpq_overlap += overlap / k
ivfpq_overlap /= len(test_queries)

print(f"IVF-PQ average overlap with flat search: {ivfpq_overlap:.2%}")

# 尝试调整 IVF-PQ 参数提高准确性
print("\nTrying to improve IVF-PQ accuracy by increasing nprobe...")
index_ivfpq.nprobe = 30  # 增加检查的聚类数量
D_ivfpq_improved, I_ivfpq_improved = index_ivfpq.search(query_embeddings, k)

# 计算改进后的重合度
ivfpq_improved_overlap = 0
for i in range(len(test_queries)):
    overlap = len(set(I_flat[i][:k]) & set(I_ivfpq_improved[i][:k]))
    ivfpq_improved_overlap += overlap / k
ivfpq_improved_overlap /= len(test_queries)

print(f"IVF-PQ with nprobe=30 overlap with flat search: {ivfpq_improved_overlap:.2%}")

Testing indexes with sample queries...

Flat index search results:
Flat index average search time: 2.20 ms

Query: How does social media affect mental health?
  Match 1: (Score: 0.3095)
  The Social Cognitive Theory is relevant to health communication. First, the theory deals with cognit...
  Match 2: (Score: 0.3061)
  Practitioners of magnetic field therapy believe that interactions between the body, the earth, and o...
  Match 3: (Score: 0.2739)
  Psychiatrists need to be able to take in complex information and synthesize it to reach a conclusion...

Query: Best programming languages to learn
  Match 1: (Score: 0.3534)
  R is a programming language: you do data analysis in R by writing scripts and functions in the R pro...
  Match 2: (Score: 0.3201)
  An integrated development environment (IDE) is a programming environment that has been packaged as a...
  Match 3: (Score: 0.3157)
  Furthermore, there is no loss of language ability or language learning ability over time. Age is not...

In [13]:
# 5. Save indexes and corpus information
print("Saving indexes...")
index_dir = os.path.join(PROJECT_PATH, "models/indexes")
os.makedirs(index_dir, exist_ok=True)

# 保存所有索引类型
print("Saving flat index...")
faiss.write_index(index_flat, os.path.join(index_dir, "flat_index.faiss"))

print("Saving HNSW index...")
faiss.write_index(index_hnsw, os.path.join(index_dir, "hnsw_index.faiss"))

print("Saving IVF-PQ index...")
faiss.write_index(index_ivfpq, os.path.join(index_dir, "ivfpq_index.faiss"))

# 保存文档数据，用于后续检索时显示结果
print("Saving corpus texts...")
with open(os.path.join(index_dir, "corpus_texts.json"), 'w') as f:
    json.dump(corpus, f)

# 保存索引配置信息
index_config = {
    "flat_index": {"type": "IndexFlatIP", "dimension": dimension},
    "hnsw_index": {"type": "IndexHNSWFlat", "dimension": dimension, "M": M, "efConstruction": ef_construction},
    "ivfpq_index": {"type": "IndexIVFPQ", "dimension": dimension, "nlist": nlist, "m": m, "bits": bits, "recommended_nprobe": 30}
}

with open(os.path.join(index_dir, "index_config.json"), 'w') as f:
    json.dump(index_config, f)

print("\nAll indexes saved successfully to:", index_dir)

Saving indexes...
Saving flat index...
Saving HNSW index...
Saving IVF-PQ index...
Saving corpus texts...

All indexes saved successfully to: /content/drive/MyDrive/CS6120_project/models/indexes


In [14]:
# 6. Create a simple retrieval function for future use
def retrieve_documents(query, index_type="hnsw", top_k=5):
    """
    简单的文档检索函数，可以在其他笔记本中重用

    参数:
    - query: 查询字符串
    - index_type: 使用的索引类型 ("flat", "hnsw", "ivfpq")
    - top_k: 返回的结果数量

    返回:
    - 包含文档和相似度分数的列表
    """
    # 编码查询
    query_embedding = model.encode([query])
    faiss.normalize_L2(query_embedding)

    # 选择索引
    if index_type == "flat":
        index = index_flat
    elif index_type == "hnsw":
        index = index_hnsw
    elif index_type == "ivfpq":
        index = index_ivfpq
        index.nprobe = 30  # 设置合适的nprobe值
    else:
        raise ValueError(f"Unknown index type: {index_type}")

    # 执行搜索
    D, I = index.search(query_embedding, top_k)

    # 构建结果
    results = []
    for i in range(top_k):
        if i < len(I[0]) and I[0][i] >= 0:  # 确保索引有效
            results.append({
                "score": float(D[0][i]),
                "text": corpus[I[0][i]],
                "index": int(I[0][i])
            })

    return results

# 测试检索函数
demo_query = "How to improve productivity while working from home?"
print("\nTesting retrieval function with query:", demo_query)
results = retrieve_documents(demo_query, index_type="hnsw", top_k=3)

for i, result in enumerate(results):
    print(f"Result {i+1} (Score: {result['score']:.4f}):")
    print(f"{result['text'][:150]}...")
    print()

print("Retrieval function is ready for use in other notebooks")


Testing retrieval function with query: How to improve productivity while working from home?
Result 1 (Score: 1.4247):
7. Buy energy efficient devices: Energy efficient devices cost more up front but over years of use, they’re going to save you money. This hold true fo...

Result 2 (Score: 1.4763):
Energy efficiency – doing more with less energy – benefits you, your country, and the world. The benefits of energy efficiency are numerous. But the t...

Result 3 (Score: 1.4916):
Returning to work ■ If you have been off work for a long time, an informal visit during lunchtime or coffee breaks can help you catch up. ■ Your emplo...

Retrieval function is ready for use in other notebooks
