# Hybrid Retrieval System (implemented based on src/dynamic_weighting.py)

In [None]:
# 1. Environment Setup
!pip install -q rank_bm25 faiss-gpu
from google.colab import drive
drive.mount('/content/drive')

import os
PROJECT_PATH = "/content/drive/MyDrive/CS6120_project"
os.chdir(PROJECT_PATH)

# Memory monitoring
import psutil
print(f"Available memory: {psutil.virtual_memory().available/1024**3:.2f} GB")

In [None]:
# 2. Load components
import json

import faiss

from src.bm25_retriever import BM25Retriever
from src.dynamic_weighting import DynamicWeighting

# Load the corpus. The encoding is stated explicitly so non-ASCII text is
# decoded the same way regardless of the platform's default locale.
with open("data/processed/combined.json", encoding="utf-8") as f:
    corpus = json.load(f)["train"]

# Only the raw text of each document is needed by the retrievers below.
texts = [doc["text"] for doc in corpus]

# Initialize BM25 over the document texts.
bm25 = BM25Retriever(texts)

# Load the prebuilt FAISS index from disk.
index = faiss.read_index("indices/sbert_faiss.index")

# Ask FAISS directly whether it can see a GPU; this avoids depending on
# `torch`, which this notebook never imports. When a GPU is present, move
# the index to device 0.
if faiss.get_num_gpus() > 0:
    res = faiss.StandardGpuResources()
    index = faiss.index_cpu_to_gpu(res, 0, index)

# Initialize the dynamic weighting module with both retrieval back ends.
weighting = DynamicWeighting(bm25, index)

In [None]:
# 3. Hybrid retrieval demo
# numpy is imported here because no earlier cell brings `np` into scope.
import numpy as np

query = "机器学习在自然语言处理中的应用"

# Per-component scores, kept so they can be shown next to the final score.
# NOTE(review): the indexing below assumes each score array has one entry
# per document, in the same order as `texts` -- confirm in src/.
bm25_scores = bm25.get_scores(query)
vector_scores = weighting.get_vector_scores(query)

# Combined scores produced by the dynamic weighting module.
final_scores = weighting.calculate_hybrid_scores(query)

# Get Top-K results. argsort is ascending, so take the last `top_k`
# positions and reverse them to list the best match first.
top_k = 5
indices = np.argsort(final_scores)[-top_k:][::-1]
for rank, idx in enumerate(indices, start=1):
    # Show only the first 200 characters of each document.
    print(f"Rank {rank}: {texts[idx][:200]}...")
    print(f"BM25 Score: {bm25_scores[idx]:.4f}, Vector Score: {vector_scores[idx]:.4f}, Final Score: {final_scores[idx]:.4f}")
    print("-"*50)

In [None]:
# 4. Weight analysis
import matplotlib.pyplot as plt

# Synthetic queries of 5, 10, ..., 45 repeated tokens stand in for
# queries of different lengths.
query_lengths = range(5, 50, 5)

# One (bm25_weight, vector_weight) pair per simulated query length.
weight_pairs = [
    weighting.calculate_weights(" ".join(["word"] * n_words))
    for n_words in query_lengths
]
bm25_weights = [w_bm25 for w_bm25, _ in weight_pairs]
vector_weights = [w_vector for _, w_vector in weight_pairs]

# Plot both weights against query length on a single set of axes.
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(query_lengths, bm25_weights, label='BM25 Weight')
ax.plot(query_lengths, vector_weights, label='Vector Weight')
ax.set_xlabel('Query Length')
ax.set_ylabel('Weight')
ax.set_title('Dynamic Weighting by Query Length')
ax.legend()
plt.show()