## 1. 安装依赖包

In [None]:
!pip install umap-learn -q

## 2. 导入必要的库

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.decomposition import PCA
from sklearn.metrics import pairwise_distances
import umap
import os
import time
import pandas as pd
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import urllib.request
import tarfile

# Select the compute device: CUDA when torch reports it as available,
# otherwise the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

# Experiment parameters: SIFT_DIM is passed as the autoencoder input width and
# REDUCED_DIM is the target width of the PCA / autoencoder reductions.
SIFT_DIM = 128
REDUCED_DIM = 64

## 3. 下载和解压 SIFT1M 数据集

In [None]:
def _sift1m_files_present(data_dir, filenames):
    """Return True when every name in *filenames* exists inside *data_dir*."""
    return all(os.path.exists(os.path.join(data_dir, name)) for name in filenames)


def _extract_sift1m_files(tar_path, data_dir, filenames):
    """Copy the wanted archive members straight into *data_dir*.

    Only regular-file members whose base name is listed in *filenames* are
    written, each under its base name. Nothing else in the archive is
    unpacked, so a member path can never land outside *data_dir*.
    """
    import shutil

    wanted = set(filenames)
    with tarfile.open(tar_path, "r:gz") as tar:
        for member in tar:
            name = os.path.basename(member.name)
            if not member.isfile() or name not in wanted:
                continue
            source = tar.extractfile(member)
            if source is None:
                continue
            with source, open(os.path.join(data_dir, name), "wb") as target:
                shutil.copyfileobj(source, target)


def _install_via_urllib(tar_url, tar_path, data_dir, filenames, timeout=30):
    """Download the archive with urllib and extract it; return True on success."""
    import shutil

    try:
        print(f"尝试从 {tar_url} 下载...")
        # A per-request timeout avoids changing the process-wide socket default.
        with urllib.request.urlopen(tar_url, timeout=timeout) as response, \
                open(tar_path, "wb") as out:
            shutil.copyfileobj(response, out)
        print("下载完成，正在解压...")
        _extract_sift1m_files(tar_path, data_dir, filenames)
        return True
    except Exception as e:  # download boundary: report and let the caller fall back
        print(f"FTP下载失败: {e}")
        return False


def _install_via_wget(tar_url, tar_path, data_dir, filenames):
    """Download the archive with the wget executable and extract it; return True on success."""
    import subprocess

    print("尝试使用 wget 下载...")
    try:
        # Argument list (no shell) so the URL/path are never shell-interpreted.
        completed = subprocess.run(["wget", tar_url, "-O", tar_path], check=False)
        if completed.returncode != 0 or not os.path.exists(tar_path):
            return False
        print("wget 下载成功，正在解压...")
        _extract_sift1m_files(tar_path, data_dir, filenames)
        return True
    except Exception as e2:  # wget missing, corrupt archive, disk errors, ...
        print(f"wget 失败: {e2}")
        return False


def download_sift1m():
    """Make sure the SIFT1M files exist under ``sift1m/``.

    If the three required files are already present nothing is downloaded.
    Otherwise the archive is fetched (urllib first, then the ``wget``
    executable as a fallback), the required files are extracted into
    ``sift1m/`` and the archive is deleted again.

    Returns:
        bool: True when all required files are present afterwards, False when
        the automatic download failed (manual instructions are printed).
    """
    data_dir = 'sift1m'
    os.makedirs(data_dir, exist_ok=True)

    files = [
        "sift_base.fvecs",
        "sift_query.fvecs",
        "sift_groundtruth.ivecs"
    ]

    if _sift1m_files_present(data_dir, files):
        print("所有文件已存在，跳过下载")
        return True

    print("正在下载 SIFT1M 数据集...")

    tar_url = "ftp://ftp.irisa.fr/local/texmex/corpus/sift.tar.gz"
    tar_path = "sift1m/sift.tar.gz"

    try:
        installed = (
            _install_via_urllib(tar_url, tar_path, data_dir, files)
            or _install_via_wget(tar_url, tar_path, data_dir, files)
        )
    finally:
        # Never leave a (possibly partial) archive behind.
        if os.path.exists(tar_path):
            os.remove(tar_path)

    # Report success only when the files really are there.
    if installed and _sift1m_files_present(data_dir, files):
        print("数据集准备完成")
        return True

    print("无法自动下载数据集。")
    print("请手动下载 sift.tar.gz 从 http://corpus-texmex.irisa.fr/")
    print("并解压到 sift1m/ 目录下。")
    return False

# Fetch the dataset (download_sift1m skips the download when the files are
# already present). The boolean result is not checked here.
print("开始下载SIFT1M数据集...")
download_sift1m()

## 4. INT3 量化函数

In [None]:
def quantize_to_int3(data, *, min_val=None, max_val=None):
    """Quantize values to INT3, i.e. the 8 integer levels -4 .. 3.

    The range ``[min_val, max_val]`` is mapped linearly onto ``[-4, 3]``,
    rounded, and clipped. By default the range is the 1st..99th percentile of
    *data* itself, which limits the influence of outliers.

    Args:
        data: ``numpy`` array (or array-like) or ``torch.Tensor``.
        min_val: Optional lower bound of the range. Pass the same bounds for
            several arrays (e.g. database and queries) to quantize them on one
            shared scale. Defaults to the 1st percentile of *data*.
        max_val: Optional upper bound; defaults to the 99th percentile.

    Returns:
        ``int8`` array of the same shape as *data*; a ``torch.Tensor`` on the
        input's device when the input was a tensor.
    """
    is_torch = isinstance(data, torch.Tensor)

    if is_torch:
        # detach() so tensors that require grad can be converted as well.
        arr = data.detach().cpu().numpy()
    else:
        arr = np.asarray(data)

    if arr.size == 0:
        # Percentiles of an empty array are undefined; nothing to quantize.
        quantized = np.zeros(arr.shape, dtype=np.int8)
    else:
        if min_val is None:
            min_val = np.percentile(arr, 1)
        if max_val is None:
            max_val = np.percentile(arr, 99)

        # Scale to [-4, 3]; the epsilon guards against a zero-width range.
        scaled = (arr - min_val) / (max_val - min_val + 1e-8) * 7 - 4

        # Round, then clip values outside the percentile range.
        quantized = np.clip(np.round(scaled), -4, 3).astype(np.int8)

    if is_torch:
        return torch.from_numpy(quantized).to(data.device)
    return quantized

## 5. AutoEncoder 模型定义

In [None]:
class AutoEncoder(nn.Module):
    """Fully connected autoencoder with BatchNorm + ReLU hidden layers.

    The encoder maps ``input_dim -> 128 -> 96 -> latent_dim`` and the decoder
    mirrors it with ``latent_dim -> 96 -> 128 -> input_dim``. The last layer
    of each half is a plain linear layer without normalization or activation.
    """

    def __init__(self, input_dim=128, latent_dim=64):
        super().__init__()
        self.encoder = self._stack((input_dim, 128, 96, latent_dim))
        self.decoder = self._stack((latent_dim, 96, 128, input_dim))

    @staticmethod
    def _stack(widths):
        """Build Linear layers for consecutive *widths*; all but the last get BatchNorm1d + ReLU."""
        layers = []
        final = len(widths) - 2
        for position, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.append(nn.Linear(fan_in, fan_out))
            if position < final:
                layers.append(nn.BatchNorm1d(fan_out))
                layers.append(nn.ReLU())
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return ``(latent code, reconstruction)`` for the batch *x*."""
        latent = self.encoder(x)
        return latent, self.decoder(latent)

    def encode(self, x):
        """Return only the latent code for the batch *x*."""
        return self.encoder(x)

def train_autoencoder(data, input_dim=128, latent_dim=64, epochs=100, batch_size=256, patience=5, min_delta=1e-4):
    """Train an AutoEncoder on *data* with MSE reconstruction loss and early stopping.

    Args:
        data: Training vectors, convertible by ``torch.FloatTensor``.
        input_dim: Width of the input vectors.
        latent_dim: Width of the latent code.
        epochs: Maximum number of epochs.
        batch_size: Mini-batch size (must be at least 2 because the model
            contains BatchNorm layers).
        patience: Stop after this many consecutive epochs without an
            improvement of more than *min_delta* in the mean epoch loss.
        min_delta: Minimum loss decrease that counts as an improvement.

    Returns:
        The trained model, left in training mode, holding the weights of the
        last epoch that ran.

    Raises:
        ValueError: If *data* has fewer than 2 samples or *batch_size* < 2.
    """
    num_samples = len(data)
    if num_samples < 2:
        raise ValueError("train_autoencoder needs at least 2 samples")
    if batch_size < 2:
        raise ValueError("batch_size must be at least 2 (BatchNorm cannot train on 1 sample)")

    model = AutoEncoder(input_dim, latent_dim).to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.MSELoss()

    data_tensor = torch.FloatTensor(data).to(DEVICE)

    # BatchNorm1d raises in training mode on a batch with a single sample, so
    # drop the trailing batch in the one case where it would have size 1.
    drop_last = num_samples % batch_size == 1
    dataset = torch.utils.data.TensorDataset(data_tensor)
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=True, drop_last=drop_last
    )
    num_batches = len(dataloader)

    model.train()

    best_loss = float('inf')
    patience_counter = 0

    pbar = tqdm(range(epochs), desc="训练AutoEncoder")
    for epoch in pbar:
        total_loss = 0

        for (batch_data,) in dataloader:
            optimizer.zero_grad()
            _, reconstructed = model(batch_data)
            loss = criterion(reconstructed, batch_data)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / num_batches
        pbar.set_postfix({'loss': f'{avg_loss:.6f}'})

        # Early stopping: count epochs that fail to beat the best loss by min_delta.
        if avg_loss < best_loss - min_delta:
            best_loss = avg_loss
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print(f"Early stopping triggered at epoch {epoch+1} (Best Loss: {best_loss:.6f})")
            break

    return model

## 6. 五种数据处理方法

In [None]:
def method1_direct_int3(data):
    """Method 1: quantize the vectors to INT3 without reducing their dimensionality."""
    quantized = quantize_to_int3(data)
    return quantized

def method2_average_pooling_int3(data):
    """Method 2: average each pair of adjacent dimensions, then quantize to INT3.

    The dimensionality is halved (128 -> 64 for SIFT). The pair count is
    derived from the input, so any even dimensionality is accepted.

    Raises:
        ValueError: If the per-vector dimensionality is odd.
    """
    num_vectors = data.shape[0]
    dim = int(np.prod(data.shape[1:]))
    if dim % 2:
        raise ValueError(f"average pooling needs an even dimensionality, got {dim}")
    pairs = data.reshape(num_vectors, dim // 2, 2)
    averaged = np.mean(pairs, axis=2)
    return quantize_to_int3(averaged)

def method3_pca_int3(data, query_data=None):
    """Method 3: project onto REDUCED_DIM whitened PCA components, then quantize to INT3.

    The PCA is fitted on *data* only. Returns ``(pca, data_quantized)``, or
    ``(pca, data_quantized, query_quantized)`` when *query_data* is given.
    Database and query projections are quantized by separate
    quantize_to_int3 calls.
    """
    print("  - 训练PCA...")
    # whiten=True rescales the projected components before quantization.
    projector = PCA(n_components=REDUCED_DIM, whiten=True)
    projector.fit(data)

    db_codes = quantize_to_int3(projector.transform(data))
    if query_data is None:
        return projector, db_codes

    query_codes = quantize_to_int3(projector.transform(query_data))
    return projector, db_codes, query_codes

def method4_max_pooling_int3(data, query_data=None):
    """Method 4: max-magnitude pooling over adjacent pairs, then quantize to INT3.

    From each pair of adjacent dimensions the value with the larger absolute
    value is kept, sign included (the first one on ties), halving the
    dimensionality (128 -> 64 for SIFT). The pair count is derived from the
    input, so any even dimensionality is accepted.

    Returns:
        ``(None, data_quantized)``, or ``(None, data_quantized,
        query_quantized)`` when *query_data* is given. The leading ``None``
        stands in for the model object returned by the other methods.

    Raises:
        ValueError: If the per-vector dimensionality is odd.
    """

    def max_magnitude_pool(arr):
        num_vectors = arr.shape[0]
        dim = int(np.prod(arr.shape[1:]))
        if dim % 2:
            raise ValueError(f"max-magnitude pooling needs an even dimensionality, got {dim}")
        pairs = arr.reshape(num_vectors, dim // 2, 2)
        first = pairs[:, :, 0]
        second = pairs[:, :, 1]
        # Keep the original (signed) value whose magnitude is larger.
        return np.where(np.abs(first) >= np.abs(second), first, second)

    data_quantized = quantize_to_int3(max_magnitude_pool(data))
    if query_data is None:
        return None, data_quantized

    query_quantized = quantize_to_int3(max_magnitude_pool(query_data))
    return None, data_quantized, query_quantized

def method5_autoencoder_int3(data, query_data=None, epochs=100):
    """Method 5: encode with a trained AutoEncoder (REDUCED_DIM latent), then quantize to INT3.

    The autoencoder is trained on *data* for at most *epochs* epochs
    (train_autoencoder applies early stopping). Returns
    ``(model, data_quantized)``, or ``(model, data_quantized,
    query_quantized)`` when *query_data* is given.
    """
    print("  - 训练AutoEncoder...")
    ae_model = train_autoencoder(data, SIFT_DIM, REDUCED_DIM, epochs=epochs)
    ae_model.eval()

    def _encode_and_quantize(vectors):
        # Encode the whole array in one forward pass, then quantize on the CPU.
        batch = torch.FloatTensor(vectors).to(DEVICE)
        latent = ae_model.encode(batch).cpu().numpy()
        return quantize_to_int3(latent)

    with torch.no_grad():
        outputs = [ae_model, _encode_and_quantize(data)]
        if query_data is not None:
            outputs.append(_encode_and_quantize(query_data))

    return tuple(outputs)

## 7. 距离计算和评估函数

In [None]:
def calculate_l2_distances(queries, database):
    """Return the pairwise Euclidean (L2) distances between *queries* and *database*."""
    distance_matrix = pairwise_distances(queries, database, metric='euclidean')
    return distance_matrix

def evaluate_recall_at_k(distances, ground_truth, k_values=(1, 10, 100)):
    """Compute mean Recall@K from a query-by-database distance matrix.

    For each query and each K, the K nearest database indices (smallest
    distance) are compared with the first ``min(K, len(gt))`` ground-truth
    indices; recall is the fraction of those ground-truth indices retrieved.
    A query with no ground-truth entries contributes 0.0.

    Args:
        distances: 2-D array, one row of distances per query.
        ground_truth: Per-query sequences of true neighbor indices, nearest
            first. Rows may have different lengths.
        k_values: Iterable of K values to evaluate.

    Returns:
        dict mapping ``'recall@K'`` to the mean recall over all queries.
    """
    ranking = np.argsort(distances, axis=1)

    recalls = {}
    for k in k_values:
        top_k_predictions = ranking[:, :k]

        query_recalls = []
        for i, gt_i in enumerate(ground_truth):
            # Ground-truth rows may be shorter than k (jagged input).
            limit = min(k, len(gt_i))
            if limit == 0:
                query_recalls.append(0.0)
                continue

            true_neighbors = set(gt_i[:limit])
            hits = true_neighbors.intersection(top_k_predictions[i])
            query_recalls.append(len(hits) / len(true_neighbors))

        recalls[f'recall@{k}'] = np.mean(query_recalls)

    return recalls

## 8. 加载 SIFT1M 数据集

In [None]:
def read_fvecs(filename):
    """Read a ``.fvecs`` file into a ``float32`` array of shape ``(n, d)``.

    Each record is a 4-byte int32 dimension ``d`` followed by ``d`` float32
    values. The dimension is taken from the first record and every record
    header is checked against it.

    Raises:
        ValueError: If the file is empty, its size is not a whole number of
            records, or a record header disagrees with the first one.
    """
    # Read everything as int32 so the headers can be validated exactly; the
    # payload is reinterpreted as float32 afterwards (same 4-byte item size).
    raw = np.fromfile(filename, dtype=np.int32)
    if raw.size == 0:
        raise ValueError(f"{filename}: file is empty")

    dim = int(raw[0])
    if dim <= 0 or raw.size % (dim + 1) != 0:
        raise ValueError(f"{filename}: size does not match records of dimension {dim}")

    records = raw.reshape(-1, dim + 1)
    if not np.all(records[:, 0] == dim):
        raise ValueError(f"{filename}: inconsistent per-record dimension headers")

    # copy() makes the payload contiguous so it can be viewed as float32.
    return records[:, 1:].copy().view(np.float32)

def read_ivecs(filename):
    """Read a ``.ivecs`` file into an ``int32`` array of shape ``(n, d)``.

    Each record is a 4-byte int32 dimension ``d`` followed by ``d`` int32
    values. The dimension is taken from the first record and every record
    header is checked against it.

    Raises:
        ValueError: If the file is empty, its size is not a whole number of
            records, or a record header disagrees with the first one.
    """
    raw = np.fromfile(filename, dtype=np.int32)
    if raw.size == 0:
        raise ValueError(f"{filename}: file is empty")

    dim = int(raw[0])
    if dim <= 0 or raw.size % (dim + 1) != 0:
        raise ValueError(f"{filename}: size does not match records of dimension {dim}")

    records = raw.reshape(-1, dim + 1)
    if not np.all(records[:, 0] == dim):
        raise ValueError(f"{filename}: inconsistent per-record dimension headers")

    # Drop the header column; copy so the result does not pin the raw buffer.
    return records[:, 1:].copy()

# Load the base vectors, the query vectors and the ground-truth file from the
# sift1m/ directory.
print("载入SIFT1M数据集...")
base_vectors = read_fvecs('sift1m/sift_base.fvecs')
query_vectors = read_fvecs('sift1m/sift_query.fvecs')
ground_truth = read_ivecs('sift1m/sift_groundtruth.ivecs')

# Report the shape of each loaded array.
print(f"Base vectors shape: {base_vectors.shape}")
print(f"Query vectors shape: {query_vectors.shape}")
print(f"Ground truth shape: {ground_truth.shape}")

## 9. 运行实验

In [None]:
def calculate_top_k_ground_truth(queries, database, k=100, query_batch_size=100, db_batch_size=50000):
    """Compute the exact top-*k* nearest database indices for every query.

    Squared L2 distances are evaluated on DEVICE with the expansion
    ``|q|^2 + |d|^2 - 2 q.d``. Queries and database are both processed in
    batches to bound memory; the per-chunk distances are gathered on the CPU
    before the top-k selection.

    Args:
        queries: 2-D float array of query vectors.
        database: 2-D float array of database vectors.
        k: Number of nearest neighbors to return per query.
        query_batch_size: Number of queries processed at a time.
        db_batch_size: Number of database vectors moved to DEVICE at a time.

    Returns:
        ``int64`` array of shape ``(len(queries), k)`` holding database
        indices ordered from nearest to farthest.

    Raises:
        ValueError: If *k* is negative or larger than the database.
    """
    print(f"正在計算 Top-{k} Ground Truth (Query: {len(queries)}, DB: {len(database)})...")

    num_queries = len(queries)
    num_db = len(database)
    if not 0 <= k <= num_db:
        raise ValueError(f"k={k} must be between 0 and the database size {num_db}")
    if num_queries == 0:
        # np.vstack cannot stack an empty list; return an empty result instead.
        return np.empty((0, k), dtype=np.int64)

    all_indices = []

    for i in tqdm(range(0, num_queries, query_batch_size), desc="Computing GT"):
        q_end = min(i + query_batch_size, num_queries)
        q_tensor = torch.FloatTensor(queries[i:q_end]).to(DEVICE)
        q_sq = torch.sum(q_tensor**2, dim=1, keepdim=True)

        dists_batch = []

        # Walk over the database in chunks so only one chunk is on DEVICE.
        for j in range(0, num_db, db_batch_size):
            db_end = min(j + db_batch_size, num_db)
            db_chunk = torch.FloatTensor(database[j:db_end]).to(DEVICE)

            db_sq_chunk = torch.sum(db_chunk**2, dim=1)
            term2 = -2 * torch.matmul(q_tensor, db_chunk.t())

            dists_chunk = q_sq + db_sq_chunk + term2
            dists_batch.append(dists_chunk.cpu())

            del db_chunk, db_sq_chunk, term2, dists_chunk
            torch.cuda.empty_cache()

        full_dists_batch = torch.cat(dists_batch, dim=1)

        # Smallest distances first.
        _, indices = torch.topk(full_dists_batch, k=k, dim=1, largest=False)
        all_indices.append(indices.numpy())

        del full_dists_batch, indices, q_tensor, q_sq
        torch.cuda.empty_cache()

    return np.vstack(all_indices)

def evaluate_recall_batched(query_vectors, db_vectors, gt_top_k, retrieval_depths, k_true_neighbors):
    """Evaluate recall of the true top-K neighbors at several retrieval depths.

    Queries are processed in batches of 100. For each query the database is
    ranked by Euclidean distance, and for every depth ``r`` the number of
    ground-truth neighbors found among the first ``r`` ranked items is counted.

    ``recall@r = total hits / (number of queries * k_true_neighbors)``

    Args:
        query_vectors: 2-D array of (quantized) query vectors.
        db_vectors: 2-D array of (quantized) database vectors.
        gt_top_k: Per-query ground-truth neighbor indices.
        retrieval_depths: Iterable of retrieval depths ``r``.
        k_true_neighbors: Number of ground-truth neighbors per query, used as
            the normalizer.

    Returns:
        dict mapping ``'recall@r'`` to the recall at that depth.

    Raises:
        ValueError: If *query_vectors* is empty (recall is undefined) or
            *retrieval_depths* is empty.
    """
    max_depth = max(retrieval_depths)
    num_queries = len(query_vectors)
    if num_queries == 0:
        raise ValueError("query_vectors is empty; recall is undefined")
    query_batch_size = 100

    total_hits = {r: 0 for r in retrieval_depths}

    # Distances are computed in float32 regardless of the quantized dtype.
    if query_vectors.dtype != np.float32:
        query_vectors = query_vectors.astype(np.float32)
    if db_vectors.dtype != np.float32:
        db_vectors = db_vectors.astype(np.float32)

    for i in tqdm(range(0, num_queries, query_batch_size), desc="Evaluating Batches", leave=False):
        q_end = min(i + query_batch_size, num_queries)
        q_batch = query_vectors[i:q_end]
        gt_batch = gt_top_k[i:q_end]

        dists = pairwise_distances(q_batch, db_vectors, metric='euclidean')
        sorted_indices = np.argsort(dists, axis=1)[:, :max_depth]

        for gt_row, ranked in zip(gt_batch, sorted_indices):
            # Build the ground-truth set once per query, not once per depth.
            gt_set = set(gt_row)
            for r in retrieval_depths:
                total_hits[r] += len(gt_set.intersection(ranked[:r]))

    recalls = {}
    for r in retrieval_depths:
        recalls[f'recall@{r}'] = total_hits[r] / (num_queries * float(k_true_neighbors))

    return recalls

def run_experiment_part(part_name, base_vectors, query_vectors, gt_top_k, retrieval_depths, k_true_neighbors):
    """Run all five quantization methods for one database size.

    For each method the database and query vectors are quantized, recall is
    evaluated with evaluate_recall_batched, and the wall-clock time of
    quantization plus evaluation is recorded.

    Returns:
        pandas.DataFrame with one row per method and the columns ``method``,
        ``time`` and one ``recall@r`` column per retrieval depth.
    """
    banner = '=' * 80
    print(f"\n{banner}")
    print(f"實驗部分: {part_name}")
    print(f"Database Size: {len(base_vectors)}")
    print(f"Retrieval Depths: {retrieval_depths}")
    print(f"Target GT Size (K): {k_true_neighbors}")
    print(f"{banner}\n")

    # (announcement, result label, callable returning (db_codes, query_codes)).
    # Methods 1-2 quantize database and queries in two separate calls; methods
    # 3-5 take both and return (model, db_codes, query_codes).
    pipelines = [
        (
            "\n[方法1] 直接INT3量化 (128維)...",
            'Method 1: Direct INT3',
            lambda: (method1_direct_int3(base_vectors), method1_direct_int3(query_vectors)),
        ),
        (
            "\n[方法2] 平均池化降維 + INT3 (64維)...",
            'Method 2: AvgPooling + INT3',
            lambda: (method2_average_pooling_int3(base_vectors), method2_average_pooling_int3(query_vectors)),
        ),
        (
            "\n[方法3] PCA降維 + INT3 (64維)...",
            'Method 3: PCA + INT3',
            lambda: method3_pca_int3(base_vectors, query_vectors)[1:],
        ),
        (
            "\n[方法4] Max Magnitude Pooling降維 + INT3 (64維)...",
            'Method 4: Max Mag Pooling + INT3',
            lambda: method4_max_pooling_int3(base_vectors, query_vectors)[1:],
        ),
        (
            "\n[方法5] AutoEncoder降維 + INT3 (64維)...",
            'Method 5: AutoEncoder + INT3',
            # At most 100 epochs; early stopping may end training sooner.
            lambda: method5_autoencoder_int3(base_vectors, query_vectors, epochs=100)[1:],
        ),
    ]

    rows = []
    for announcement, label, build_codes in pipelines:
        print(announcement)
        started = time.time()
        db_codes, query_codes = build_codes()
        recalls = evaluate_recall_batched(query_codes, db_codes, gt_top_k, retrieval_depths, k_true_neighbors)
        rows.append({'method': label, 'time': time.time() - started, **recalls})
        print(recalls)

    return pd.DataFrame(rows)

# --- Main execution flow ---
# Runs the five-method comparison on three database sizes (10k, 100k, 1M).
# Each part recomputes its own top-K ground truth against the database subset
# it uses, then collects a result table in all_experiment_results.

# Prepare the data: the full base set and the first NUM_QUERIES queries, as float32.
NUM_QUERIES = 1000
base_vectors_full = base_vectors.astype(np.float32)
query_vectors_subset = query_vectors[:NUM_QUERIES].astype(np.float32)

all_experiment_results = []

# --- Part 1: 10k database vectors ---
db_size_1 = 10000
db_subset_1 = base_vectors_full[:db_size_1]
gt_k_1 = 100
print(f"\n正在計算 Part 1 Ground Truth (Top-{gt_k_1})...")
gt_1 = calculate_top_k_ground_truth(query_vectors_subset, db_subset_1, k=gt_k_1)

df_1 = run_experiment_part(
    part_name="Part 1 (10k DB)",
    base_vectors=db_subset_1,
    query_vectors=query_vectors_subset,
    gt_top_k=gt_1,
    retrieval_depths=[1000],
    k_true_neighbors=gt_k_1
)
df_1['Experiment'] = '10k DB'
all_experiment_results.append(df_1)
display(df_1)

# --- Part 2: 100k database vectors ---
db_size_2 = 100000
db_subset_2 = base_vectors_full[:db_size_2]
gt_k_2 = 100
print(f"\n正在計算 Part 2 Ground Truth (Top-{gt_k_2})...")
gt_2 = calculate_top_k_ground_truth(query_vectors_subset, db_subset_2, k=gt_k_2)

df_2 = run_experiment_part(
    part_name="Part 2 (100k DB)",
    base_vectors=db_subset_2,
    query_vectors=query_vectors_subset,
    gt_top_k=gt_2,
    retrieval_depths=[1000, 10000],
    k_true_neighbors=gt_k_2
)
df_2['Experiment'] = '100k DB'
all_experiment_results.append(df_2)
display(df_2)

# --- Part 3: 1M database vectors ---
# db_size_3 is informational only; the full base set is used without slicing.
db_size_3 = 1000000
db_subset_3 = base_vectors_full # full base set, no slicing
gt_k_3 = 100
print(f"\n正在計算 Part 3 Ground Truth (Top-{gt_k_3})...")
gt_3 = calculate_top_k_ground_truth(query_vectors_subset, db_subset_3, k=gt_k_3)

df_3 = run_experiment_part(
    part_name="Part 3 (1M DB)",
    base_vectors=db_subset_3,
    query_vectors=query_vectors_subset,
    gt_top_k=gt_3,
    retrieval_depths=[1000, 10000],
    k_true_neighbors=gt_k_3
)
df_3['Experiment'] = '1M DB'
all_experiment_results.append(df_3)
display(df_3)

# Save all results: concatenate the three tables and write them to one CSV.
final_df = pd.concat(all_experiment_results, ignore_index=True)
final_df.to_csv("sift1m_3parts_experiment_results.csv", index=False)
print("\n所有實驗完成，結果已保存。")

## 10. 可视化结果

In [None]:
# Visualize the results.
# One figure with three panels, one per experiment part (Part 1, Part 2, Part 3).

fig, axes = plt.subplots(1, 3, figsize=(20, 6))
fig.suptitle('Recall of Top-100 True Neighbors @ Different Retrieval Depths', fontsize=16, fontweight='bold')

# Part 1: 10k DB, Recall@1000 (single bar per method, value printed above each bar)
ax1 = axes[0]
df1 = all_experiment_results[0]
# Shorten the tick labels: "Method N: name" -> "MN\nname".
# The same labels are reused for all three panels.
methods = df1['method'].str.replace('Method ', 'M').str.replace(': ', '\n', 1)
bars1 = ax1.bar(range(len(df1)), df1['recall@1000'], color=['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd'])
ax1.set_xticks(range(len(df1)))
ax1.set_xticklabels(methods, rotation=45, ha='right', fontsize=9)
ax1.set_ylabel('Recall@1000', fontsize=12, fontweight='bold')
ax1.set_title('Part 1: 10k DB (100@1000)', fontsize=14, fontweight='bold')
ax1.set_ylim([0, 1.05])
ax1.grid(axis='y', linestyle='--', alpha=0.7)
for bar in bars1:
    height = bar.get_height()
    ax1.text(bar.get_x() + bar.get_width()/2., height + 0.01, f'{height:.3f}', ha='center', va='bottom', fontsize=9)

# Part 2: 100k DB, Recall@1000 & Recall@10000
ax2 = axes[1]
df2 = all_experiment_results[1]
# Grouped bars so that @1000 and @10000 are shown side by side per method.
x = np.arange(len(df2))
width = 0.35
bars2_1 = ax2.bar(x - width/2, df2['recall@1000'], width, label='@1000', color='#1f77b4')
bars2_2 = ax2.bar(x + width/2, df2['recall@10000'], width, label='@10000', color='#ff7f0e')

ax2.set_xticks(x)
ax2.set_xticklabels(methods, rotation=45, ha='right', fontsize=9)
ax2.set_ylabel('Recall', fontsize=12, fontweight='bold')
ax2.set_title('Part 2: 100k DB', fontsize=14, fontweight='bold')
ax2.set_ylim([0, 1.05])
ax2.legend()
ax2.grid(axis='y', linestyle='--', alpha=0.7)

# Part 3: 1M DB, Recall@1000 & Recall@10000 (same grouped layout, reusing `width`)
ax3 = axes[2]
df3 = all_experiment_results[2]
x = np.arange(len(df3))
bars3_1 = ax3.bar(x - width/2, df3['recall@1000'], width, label='@1000', color='#1f77b4')
bars3_2 = ax3.bar(x + width/2, df3['recall@10000'], width, label='@10000', color='#ff7f0e')

ax3.set_xticks(x)
ax3.set_xticklabels(methods, rotation=45, ha='right', fontsize=9)
ax3.set_ylabel('Recall', fontsize=12, fontweight='bold')
ax3.set_title('Part 3: 1M DB', fontsize=14, fontweight='bold')
ax3.set_ylim([0, 1.05])
ax3.legend()
ax3.grid(axis='y', linestyle='--', alpha=0.7)

# Save the figure to disk, then show it inline.
plt.tight_layout()
plt.savefig('sift1m_3parts_results.png', dpi=300, bbox_inches='tight')
plt.show()

## 11. 实验总结

In [None]:
print("\n" + "="*80)
print("實驗總結")
print("="*80)

# The experiment cell collects one result table per part in
# all_experiment_results and labels each via its 'Experiment' column.
# (The names EXPERIMENTS / all_results used here previously are never defined
# in this notebook, so the loop raised NameError.)
for results_df in all_experiment_results:
    print(f"\n配置: {results_df['Experiment'].iloc[0]}")
    display(results_df)

print("\n所有實驗完成！")