## 2. 导入必要的库

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.decomposition import PCA
from sklearn.metrics import pairwise_distances
import umap
import os
import time
import pandas as pd
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import urllib.request
import tarfile

# Select the compute device: CUDA when available, otherwise CPU.
# All tensor work in the notebook (quantization, distance search, AutoEncoder) targets DEVICE.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

# Dimensionality settings: SIFT descriptors are 128-d; the reduction methods target REDUCED_DIM.
SIFT_DIM = 128
REDUCED_DIM = 64

## 3. 下载和解压 SIFT1M 数据集

In [None]:
def _urlretrieve_with_timeout(url, dest, timeout):
    """Download ``url`` to ``dest`` with a socket timeout, restoring the previous global default afterwards."""
    import socket

    # urlretrieve has no timeout argument, so the process-wide default is
    # set temporarily and always restored, even if the download fails.
    previous_timeout = socket.getdefaulttimeout()
    socket.setdefaulttimeout(timeout)
    try:
        urllib.request.urlretrieve(url, dest)
    finally:
        socket.setdefaulttimeout(previous_timeout)


def _extract_sift_archive(tar_path, data_dir, files):
    """Extract the SIFT tarball into ``data_dir``, move ``files`` out of ``sift/``, and delete the tarball.

    Raises:
        FileNotFoundError: if any of ``files`` is still missing afterwards.
    """
    import shutil

    with tarfile.open(tar_path, "r:gz") as tar:
        if hasattr(tarfile, "data_filter"):
            # PEP 706 'data' filter rejects absolute paths and '..' members.
            # It exists on Python 3.12+ and on recent security patch releases.
            tar.extractall(path=data_dir, filter="data")
        else:
            tar.extractall(path=data_dir)

    # The archive extracts into <data_dir>/sift/; flatten the needed files into data_dir.
    extracted_dir = os.path.join(data_dir, "sift")
    if os.path.isdir(extracted_dir):
        for filename in files:
            src = os.path.join(extracted_dir, filename)
            if os.path.exists(src):
                # os.replace overwrites an existing destination.
                os.replace(src, os.path.join(data_dir, filename))
        # Best-effort removal of whatever else was extracted.
        shutil.rmtree(extracted_dir, ignore_errors=True)

    if os.path.exists(tar_path):
        os.remove(tar_path)

    missing = [f for f in files if not os.path.exists(os.path.join(data_dir, f))]
    if missing:
        raise FileNotFoundError(f"解压后缺少文件: {missing}")


def download_sift1m():
    """
    Make sure the SIFT1M base/query/groundtruth files exist under ``sift1m/``.

    If all three files already exist, nothing is downloaded. Otherwise
    ``sift.tar.gz`` is fetched from the IRISA FTP server with urllib (30 s
    socket timeout). If that or the extraction fails, it is fetched again
    with ``wget``. The archive is extracted and the three files are moved
    from ``sift1m/sift/`` to ``sift1m/``.

    Returns:
        bool: True when the three files are in place afterwards, False when
        both attempts failed (manual download instructions are printed).
    """
    data_dir = 'sift1m'
    os.makedirs(data_dir, exist_ok=True)

    files = [
        "sift_base.fvecs",
        "sift_query.fvecs",
        "sift_groundtruth.ivecs"
    ]

    if all(os.path.exists(os.path.join(data_dir, filename)) for filename in files):
        print("所有文件已存在，跳过下载")
        return True

    print("正在下载 SIFT1M 数据集...")

    tar_url = "ftp://ftp.irisa.fr/local/texmex/corpus/sift.tar.gz"
    tar_path = os.path.join(data_dir, "sift.tar.gz")

    try:
        print(f"尝试从 {tar_url} 下载...")
        _urlretrieve_with_timeout(tar_url, tar_path, timeout=30)
        print("下载完成，正在解压...")
        _extract_sift_archive(tar_path, data_dir, files)
        print("数据集准备完成")
        return True
    except Exception as e:
        # Broad on purpose: any network, FTP or archive failure falls back to wget.
        print(f"FTP下载失败: {e}")
        print("尝试使用 wget 下载...")

    try:
        import subprocess

        # Argument list, no shell: avoids quoting issues with the path/URL.
        result = subprocess.run(["wget", tar_url, "-O", tar_path], check=False)
        if result.returncode == 0 and os.path.exists(tar_path):
            print("wget 下载成功，正在解压...")
            _extract_sift_archive(tar_path, data_dir, files)
            return True
    except Exception as e2:
        # Includes FileNotFoundError when wget is not installed.
        print(f"wget 失败: {e2}")

    print("无法自动下载数据集。")
    print("请手动下载 sift.tar.gz 从 http://corpus-texmex.irisa.fr/")
    print("并解压到 sift1m/ 目录下。")
    return False

# Download the dataset if needed. The returned success flag is not checked
# here; if the download failed, opening the files in the loading cell fails.
print("开始下载SIFT1M数据集...")
download_sift1m()

## 4. INT3 量化函数

In [None]:
def quantize_to_int3(data, *, value_range=None):
    """
    Quantize ``data`` to INT3, i.e. the 8 integer levels -4..3.

    The clipping range is the 1st..99th percentile of the input, which limits
    the influence of outliers. For inputs with more than 100k elements, the
    percentiles are estimated on a strided sample of about 100k elements,
    because ``torch.quantile`` rejects very large tensors. Values are then
    mapped linearly so that the range spans [-4, 3], rounded, and clipped.

    NOTE: each call estimates its own range. Quantizing the database and the
    queries in separate calls can therefore use slightly different scales.
    Pass ``value_range`` to share one scale.

    Args:
        data: ``torch.Tensor`` (computed on ``DEVICE``) or NumPy array-like, any shape.
        value_range: optional ``(min_val, max_val)`` to use instead of the
            percentile estimate.

    Returns:
        ``torch.int8`` tensor on ``DEVICE`` (tensor input) or ``np.int8``
        array (otherwise), with the same shape as ``data``.
    """
    sample_limit = 100000

    if isinstance(data, torch.Tensor):
        arr = data.to(DEVICE).float()  # quantile needs floating point
        if arr.numel() == 0:
            return arr.to(torch.int8)
        if value_range is None:
            # reshape (unlike view) also works for non-contiguous tensors.
            flat = arr.reshape(-1)
            if flat.numel() > sample_limit:
                flat = flat[::max(1, flat.numel() // sample_limit)]
            min_val = torch.quantile(flat, 0.01)
            max_val = torch.quantile(flat, 0.99)
        else:
            min_val, max_val = value_range
        scaled = (arr - min_val) / (max_val - min_val + 1e-8) * 7 - 4
        return torch.clamp(torch.round(scaled), -4, 3).to(torch.int8)

    arr = np.asarray(data)
    if arr.size == 0:
        return arr.astype(np.int8)
    if value_range is None:
        flat = arr.ravel()
        if flat.size > sample_limit:
            flat = flat[::max(1, flat.size // sample_limit)]
        min_val = np.percentile(flat, 1)
        max_val = np.percentile(flat, 99)
    else:
        min_val, max_val = value_range
    scaled = (arr - min_val) / (max_val - min_val + 1e-8) * 7 - 4
    return np.clip(np.round(scaled), -4, 3).astype(np.int8)

## 5. AutoEncoder 模型定义

In [None]:
class AutoEncoder(nn.Module):
    """
    Symmetric MLP autoencoder: input_dim -> 128 -> 96 -> latent_dim -> 96 -> 128 -> input_dim.

    Each hidden Linear layer is followed by BatchNorm1d and ReLU. The
    bottleneck layer (last encoder layer) and the output layer (last decoder
    layer) are plain Linear layers.
    """

    def __init__(self, input_dim=128, latent_dim=64):
        super().__init__()
        # Encoder: input_dim -> 128 -> 96 -> latent_dim
        self.encoder = nn.Sequential(*self._stack([input_dim, 128, 96, latent_dim]))
        # Decoder mirrors the encoder: latent_dim -> 96 -> 128 -> input_dim
        self.decoder = nn.Sequential(*self._stack([latent_dim, 96, 128, input_dim]))

    @staticmethod
    def _stack(widths):
        """Linear layers between consecutive widths; BatchNorm1d + ReLU after all but the last."""
        layers = []
        final_idx = len(widths) - 2
        for idx, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.append(nn.Linear(fan_in, fan_out))
            if idx != final_idx:
                layers.extend([nn.BatchNorm1d(fan_out), nn.ReLU()])
        return layers

    def forward(self, x):
        """Return ``(latent_code, reconstruction)`` for the batch ``x``."""
        latent = self.encoder(x)
        return latent, self.decoder(latent)

    def encode(self, x):
        """Return only the latent code of ``x``."""
        return self.encoder(x)

def train_autoencoder(data, input_dim=128, latent_dim=64, epochs=100, batch_size=256, patience=5, min_delta=1e-4):
    """
    Train an ``AutoEncoder`` to reconstruct ``data`` (MSE loss, Adam lr=1e-3) with early stopping.

    Training stops once the epoch-average training loss has not improved by
    more than ``min_delta`` for ``patience`` consecutive epochs. Before
    returning, the weights (and BatchNorm buffers) from the epoch with the
    lowest epoch-average loss are restored. The result is therefore the best
    model seen, not whatever the final, possibly worse, epoch left behind.

    Args:
        data: array-like of shape (num_samples, input_dim); converted to a float32 tensor on ``DEVICE``.
        input_dim: dimensionality of the input vectors.
        latent_dim: size of the bottleneck.
        epochs: maximum number of epochs.
        batch_size: mini-batch size.
        patience: epochs without sufficient improvement before stopping.
        min_delta: minimum loss decrease that counts as an improvement.

    Returns:
        The trained ``AutoEncoder``, left in training mode.

    Raises:
        ValueError: if ``data`` has fewer than 2 samples (BatchNorm1d needs
            at least two samples per batch in training mode).
    """
    model = AutoEncoder(input_dim, latent_dim).to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.MSELoss()

    data_tensor = torch.FloatTensor(data).to(DEVICE)
    if len(data_tensor) < 2:
        raise ValueError("train_autoencoder needs at least 2 samples (BatchNorm1d)")

    dataset = torch.utils.data.TensorDataset(data_tensor)
    # A trailing batch of exactly one sample would make BatchNorm1d raise in
    # training mode, so drop it in that case only.
    drop_last = len(dataset) % batch_size == 1
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=drop_last)

    model.train()

    best_loss = float('inf')        # reference for the patience counter (uses min_delta)
    best_state_loss = float('inf')  # lowest loss seen, for the weight snapshot
    best_state = None
    patience_counter = 0

    pbar = tqdm(range(epochs), desc="训练AutoEncoder")
    for epoch in pbar:
        total_loss = 0
        num_batches = 0

        for (batch_data,) in dataloader:
            optimizer.zero_grad()
            _, reconstructed = model(batch_data)
            loss = criterion(reconstructed, batch_data)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        avg_loss = total_loss / num_batches
        pbar.set_postfix({'loss': f'{avg_loss:.6f}'})

        # Snapshot the weights of the best epoch so far, including buffers
        # such as BatchNorm running statistics.
        if avg_loss < best_state_loss:
            best_state_loss = avg_loss
            best_state = {name: t.detach().clone() for name, t in model.state_dict().items()}

        # Early stopping check
        if avg_loss < best_loss - min_delta:
            best_loss = avg_loss
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print(f"Early stopping triggered at epoch {epoch+1} (Best Loss: {best_loss:.6f})")
            break

    if best_state is not None:
        model.load_state_dict(best_state)
    return model

## 6. 五种数据处理方法

In [None]:
def method1_direct_int3(data):
    """Method 1: quantize the full-dimensional (128-d for SIFT) vectors straight to INT3, with no reduction.

    A NumPy array is first turned into a tensor on ``DEVICE``. Any other
    input is handed to ``quantize_to_int3`` unchanged.
    """
    as_tensor = torch.from_numpy(data).to(DEVICE) if isinstance(data, np.ndarray) else data
    return quantize_to_int3(as_tensor)

def method2_average_pooling_int3(data):
    """
    Method 2: average each pair of adjacent dimensions, then quantize to INT3.

    For 128-d SIFT vectors this gives 64 dimensions. Any even input
    dimension D is accepted and gives D/2 outputs. A NumPy array is moved to
    ``DEVICE`` as a tensor first.
    """
    if isinstance(data, np.ndarray):
        data = torch.from_numpy(data).to(DEVICE)

    # (N, D) -> (N, D/2, 2): the last axis holds each adjacent pair.
    pairs = data.reshape(data.shape[0], -1, 2)
    averaged = pairs.float().mean(dim=2)
    return quantize_to_int3(averaged)

def method3_pca_int3(data, query_data=None):
    """
    Method 3: reduce to ``REDUCED_DIM`` dimensions with whitened PCA, then quantize to INT3.

    PCA is fitted on the CPU with scikit-learn. The projection is then
    applied on ``DEVICE`` using the fitted mean and components, mirroring
    ``PCA.transform``: center, project, and (because ``whiten=True``) divide
    by ``sqrt(explained_variance_)``.

    Args:
        data: database vectors (NumPy array or tensor); used to fit PCA and then transformed.
        query_data: optional query vectors (NumPy array or tensor, CPU or
            GPU) projected with the same PCA.

    Returns:
        ``(pca, data_quantized)``, or ``(pca, data_quantized, query_quantized)``
        when ``query_data`` is given. The codes come from ``quantize_to_int3``.
    """
    print("  - 训练PCA (CPU)...")
    # scikit-learn fitting needs a CPU NumPy array.
    data_np = data.cpu().numpy() if isinstance(data, torch.Tensor) else data

    pca = PCA(n_components=REDUCED_DIM, whiten=True)
    pca.fit(data_np)

    # Move the fitted parameters to DEVICE once, outside the transform.
    mean = torch.from_numpy(pca.mean_).float().to(DEVICE)
    components_t = torch.from_numpy(pca.components_).float().to(DEVICE).T
    scale = None
    if pca.whiten:
        scale = torch.sqrt(torch.from_numpy(pca.explained_variance_).float().to(DEVICE))

    def transform_gpu(X):
        """Apply the fitted PCA to X on DEVICE. NumPy arrays and CPU tensors are both moved there first."""
        if isinstance(X, np.ndarray):
            X = torch.from_numpy(X)
        X = X.to(DEVICE).float()
        X_transformed = torch.matmul(X - mean, components_t)
        if scale is not None:
            X_transformed = X_transformed / scale
        return X_transformed

    data_quantized = quantize_to_int3(transform_gpu(data_np))

    if query_data is not None:
        query_quantized = quantize_to_int3(transform_gpu(query_data))
        return pca, data_quantized, query_quantized

    return pca, data_quantized

def method4_max_pooling_int3(data, query_data=None):
    """
    Method 4: max-magnitude pooling over adjacent dimension pairs, then quantize to INT3.

    From each pair (a, b), the original signed value with the larger
    absolute value is kept; ties keep ``a``. This maps 128-d SIFT vectors to
    64-d, and in general any even dimension D to D/2.

    Returns:
        ``(None, data_quantized)``, or ``(None, data_quantized, query_quantized)``
        when ``query_data`` is given. ``None`` fills the model slot used by
        the other methods.
    """

    def max_magnitude_pool_torch(arr):
        if isinstance(arr, np.ndarray):
            arr = torch.from_numpy(arr).to(DEVICE)
        # (N, D) -> (N, D/2, 2): the last axis holds each adjacent pair.
        pairs = arr.reshape(arr.shape[0], -1, 2)
        a = pairs[:, :, 0]
        b = pairs[:, :, 1]
        # Keep whichever original value has the larger magnitude.
        return torch.where(torch.abs(a) >= torch.abs(b), a, b)

    data_quantized = quantize_to_int3(max_magnitude_pool_torch(data))

    if query_data is not None:
        query_quantized = quantize_to_int3(max_magnitude_pool_torch(query_data))
        return None, data_quantized, query_quantized

    return None, data_quantized

def method5_autoencoder_int3(data, query_data=None, epochs=100):
    """
    Method 5: reduce SIFT_DIM -> REDUCED_DIM with a trained AutoEncoder, then quantize to INT3.

    The AutoEncoder is trained on ``data`` (``epochs`` is the cap;
    ``train_autoencoder`` may stop early). It is then switched to eval mode
    and used without gradients to encode ``data`` and, optionally,
    ``query_data`` on ``DEVICE``.

    Returns:
        ``(ae_model, data_quantized)``, or ``(ae_model, data_quantized, query_quantized)``
        when ``query_data`` is given.
    """
    print("  - 训练AutoEncoder...")

    def to_device_float(arr):
        # NumPy -> float32 tensor on DEVICE; tensors are cast and moved.
        if isinstance(arr, np.ndarray):
            return torch.FloatTensor(arr).to(DEVICE)
        return arr.float().to(DEVICE)

    # The training helper builds its own tensor/DataLoader from a NumPy array.
    train_input = data.cpu().numpy() if isinstance(data, torch.Tensor) else data
    ae_model = train_autoencoder(train_input, SIFT_DIM, REDUCED_DIM, epochs=epochs)
    ae_model.eval()

    with torch.no_grad():
        data_quantized = quantize_to_int3(ae_model.encode(to_device_float(data)))
        if query_data is None:
            return ae_model, data_quantized
        query_quantized = quantize_to_int3(ae_model.encode(to_device_float(query_data)))

    return ae_model, data_quantized, query_quantized

## 7. 距离计算和评估函数

In [None]:
def calculate_l2_distances(queries, database):
    """Return the dense matrix of Euclidean distances, one row per query and one column per database vector."""
    distance_matrix = pairwise_distances(queries, database, metric='euclidean')
    return distance_matrix

def evaluate_recall_at_k(distances, ground_truth, k_values=(1, 10, 100)):
    """
    Compute Recall@K from a dense query-by-database distance matrix.

    For each query, the K database indices with the smallest distances are
    compared with the first ``min(K, len(gt))`` ground-truth neighbours.
    Recall is the fraction of those neighbours that were retrieved, averaged
    over queries. A query with an empty ground-truth list counts as 0.

    Args:
        distances: 2-D array; ``distances[i, j]`` is the distance from query i to database item j.
        ground_truth: sequence of neighbour-index lists, one per query (may be jagged).
        k_values: iterable of K values; the default is ``(1, 10, 100)``.
            An immutable tuple replaces the old shared list default.

    Returns:
        dict mapping ``'recall@{K}'`` to the mean recall over queries.
    """
    # Sort once; each K then takes a prefix of the ranking.
    sorted_indices = np.argsort(distances, axis=1)

    recalls = {}
    for k in k_values:
        top_k_predictions = sorted_indices[:, :k]

        query_recalls = []
        for i, gt_i in enumerate(ground_truth):
            # Jagged ground truth: a query may have fewer than k neighbours.
            limit = min(k, len(gt_i))
            if limit == 0:
                query_recalls.append(0.0)
                continue

            true_neighbors = set(gt_i[:limit])
            hits = true_neighbors.intersection(top_k_predictions[i])
            query_recalls.append(len(hits) / len(true_neighbors))

        recalls[f'recall@{k}'] = np.mean(query_recalls)

    return recalls

## 8. 加载 SIFT1M 数据集

In [None]:
def read_fvecs(filename):
    """Read a ``.fvecs`` file into a (num_vectors, d) float32 array.

    Each record is an int32 dimension ``d`` followed by ``d`` float32
    values. Every record is assumed to have the ``d`` of the first one.
    """
    # Read the whole file as int32: the header is a real int32, and the
    # payload is reinterpreted bit-for-bit as float32 below.
    words = np.fromfile(filename, dtype=np.int32)
    dim = words[0]
    records = words.view(np.float32).reshape(-1, dim + 1)
    # Drop the per-record dimension header column.
    return records[:, 1:].copy()

def read_ivecs(filename):
    """Read a ``.ivecs`` file into a (num_vectors, d) int32 array.

    Each record is an int32 dimension ``d`` followed by ``d`` int32 values.
    Every record is assumed to have the ``d`` of the first one.
    """
    words = np.fromfile(filename, dtype=np.int32)
    row_width = words[0] + 1  # header + payload
    # Drop the per-record dimension header column.
    return words.reshape(-1, row_width)[:, 1:].copy()

# Load the SIFT1M files downloaded above: database ("base") vectors, query
# vectors and the dataset's ground-truth neighbour indices. The experiment
# cells below recompute their own top-k ground truth with
# calculate_top_k_ground_truth.
print("载入SIFT1M数据集...")
base_vectors = read_fvecs('sift1m/sift_base.fvecs')
query_vectors = read_fvecs('sift1m/sift_query.fvecs')
ground_truth = read_ivecs('sift1m/sift_groundtruth.ivecs')

print(f"Base vectors shape: {base_vectors.shape}")
print(f"Query vectors shape: {query_vectors.shape}")
print(f"Ground truth shape: {ground_truth.shape}")

## 9. 运行实验

In [None]:
def calculate_top_k_ground_truth(queries, database, k=100):
    """
    Compute the exact top-``k`` nearest neighbours (squared L2) of every query.

    Queries are processed 100 at a time and the database in chunks of 50,000
    rows, so device memory stays bounded by one (100 x 50,000) distance block.
    Each chunk is reduced to its own top-``k`` on the device and merged with
    the running best. This avoids building and copying a full
    (100 x len(database)) distance matrix to the CPU for every query batch.

    Args:
        queries: array-like of shape (num_queries, dim).
        database: array-like of shape (num_db, dim).
        k: number of neighbours to return; clamped to ``len(database)``.

    Returns:
        np.ndarray of int64 database indices with shape
        (num_queries, min(k, num_db)), ordered by increasing distance.
    """
    print(f"正在計算 Top-{k} Ground Truth (Query: {len(queries)}, DB: {len(database)})...")

    num_queries = len(queries)
    num_db = len(database)
    k = min(k, num_db)  # topk cannot return more items than exist
    if num_queries == 0 or k <= 0:
        return np.empty((num_queries, max(k, 0)), dtype=np.int64)

    query_batch_size = 100  # queries per batch
    db_batch_size = 50000   # database rows per chunk

    all_indices = []
    for i in tqdm(range(0, num_queries, query_batch_size), desc="Computing GT"):
        q_tensor = torch.as_tensor(queries[i:i + query_batch_size], dtype=torch.float32, device=DEVICE)
        q_sq = torch.sum(q_tensor ** 2, dim=1, keepdim=True)

        best_dists = None
        best_idx = None
        for j in range(0, num_db, db_batch_size):
            db_chunk = torch.as_tensor(database[j:j + db_batch_size], dtype=torch.float32, device=DEVICE)
            # |q - x|^2 = |q|^2 + |x|^2 - 2 q.x
            dists = q_sq + torch.sum(db_chunk ** 2, dim=1) - 2 * torch.matmul(q_tensor, db_chunk.t())

            chunk_dists, chunk_idx = torch.topk(dists, k=min(k, dists.shape[1]), dim=1, largest=False)
            chunk_idx = chunk_idx + j  # chunk-local -> global database index

            if best_dists is None:
                best_dists, best_idx = chunk_dists, chunk_idx
            else:
                # Merge the running best with this chunk's candidates and keep the k smallest.
                merged_dists = torch.cat([best_dists, chunk_dists], dim=1)
                merged_idx = torch.cat([best_idx, chunk_idx], dim=1)
                best_dists, order = torch.topk(merged_dists, k=min(k, merged_dists.shape[1]), dim=1, largest=False)
                best_idx = torch.gather(merged_idx, 1, order)

            del db_chunk, dists, chunk_dists, chunk_idx

        all_indices.append(best_idx.cpu().numpy())
        del q_tensor, q_sq, best_dists, best_idx

    torch.cuda.empty_cache()
    return np.vstack(all_indices)

def evaluate_recall_batched(query_vectors, db_vectors, gt_top_k, retrieval_depths, k_true_neighbors):
    """
    Evaluate recall of an exact L2 search over (possibly INT3-quantized) vectors.

    For each query, the ``max(retrieval_depths)`` nearest database vectors
    are found on ``DEVICE``, in batches of 100 queries. Recall at depth
    ``r`` is::

        sum over queries of |ground-truth neighbours ∩ first r results|
        / (num_queries * k_true_neighbors)

    Args:
        query_vectors: (num_queries, dim) NumPy array or tensor; cast to float32 on ``DEVICE``.
        db_vectors: (num_db, dim) NumPy array or tensor; cast to float32 on ``DEVICE``.
        gt_top_k: 2-D integer array; row i holds query i's true neighbour indices.
        retrieval_depths: iterable of depths r. Depths beyond ``num_db`` see every vector.
        k_true_neighbors: ground-truth size used to normalise recall.

    Returns:
        dict with ``'recall@{r}'`` for every depth, plus ``'search_time'``
        (seconds spent on distances, top-k and hit counting) and ``'qps'``.
    """
    num_queries = len(query_vectors)
    query_batch_size = 100

    total_hits = {r: 0 for r in retrieval_depths}

    # Float32 tensors on DEVICE (INT3 codes arrive as int8).
    query_vectors = torch.as_tensor(query_vectors).to(DEVICE).float()
    db_vectors = torch.as_tensor(db_vectors).to(DEVICE).float()

    # torch.topk fails when k exceeds the database size.
    max_depth = min(max(retrieval_depths), db_vectors.shape[0])

    start_time = time.time()

    # DB squared norms computed once (1M floats = 4MB).
    db_sq = torch.sum(db_vectors ** 2, dim=1)

    for i in tqdm(range(0, num_queries, query_batch_size), desc="Evaluating Batches", leave=False):
        q_end = min(i + query_batch_size, num_queries)
        q_batch = query_vectors[i:q_end]
        gt_batch = gt_top_k[i:q_end]

        # |x-y|^2 = |x|^2 + |y|^2 - 2xy. For 1M vectors, the (100, 1M) float32
        # result is about 400MB, which fits on a typical GPU.
        q_sq = torch.sum(q_batch ** 2, dim=1, keepdim=True)
        term2 = -2 * torch.matmul(q_batch, db_vectors.t())
        dists = q_sq + db_sq + term2

        _, sorted_indices_tensor = torch.topk(dists, k=max_depth, dim=1, largest=False)
        sorted_indices = sorted_indices_tensor.cpu().numpy()

        for ranked, gt_row in zip(sorted_indices, gt_batch):
            # cum_hits[p] = number of true neighbours among the first p+1
            # results. topk indices are distinct, so this equals the set
            # intersection size at every depth.
            cum_hits = np.cumsum(np.isin(ranked, gt_row))
            for r in retrieval_depths:
                depth_r = min(r, cum_hits.size)
                if depth_r > 0:
                    total_hits[r] += int(cum_hits[depth_r - 1])

        del dists, term2, q_sq, sorted_indices_tensor

    search_time = time.time() - start_time

    denominator = num_queries * float(k_true_neighbors)
    recalls = {
        f'recall@{r}': (total_hits[r] / denominator if denominator else 0.0)
        for r in retrieval_depths
    }
    recalls['search_time'] = search_time
    recalls['qps'] = num_queries / search_time if search_time > 0 else 0

    return recalls

def run_experiment_part(part_name, base_vectors, query_vectors, gt_top_k, retrieval_depths, k_true_neighbors):
    """
    Run all five reduction + INT3 methods on one database and collect their metrics.

    For each method, the database and query codes are produced, recall is
    evaluated with ``evaluate_recall_batched``, and a row is recorded. The
    row holds the method label, the wall time for encoding plus evaluation,
    and every metric returned by the evaluator.

    NOTE(review): methods 1 and 2 quantize the database and the queries in
    separate calls, so each side gets its own INT3 range estimate.

    Returns:
        pandas.DataFrame with one row per method.
    """
    separator = '=' * 80
    print(f"\n{separator}")
    print(f"實驗部分: {part_name}")
    print(f"Database Size: {len(base_vectors)}")
    print(f"Retrieval Depths: {retrieval_depths}")
    print(f"Target GT Size (K): {k_true_neighbors}")
    print(f"{separator}\n")

    def encode_separately(method):
        # Database first, then queries, each in its own call.
        return lambda: (method(base_vectors), method(query_vectors))

    def encode_jointly(method, **kwargs):
        # Method takes both sets and returns (model, db_codes, query_codes).
        def run():
            _, db_codes, query_codes = method(base_vectors, query_vectors, **kwargs)
            return db_codes, query_codes
        return run

    plan = (
        ("\n[方法1] 直接INT3量化 (128維)...", 'Method 1: Direct INT3',
         encode_separately(method1_direct_int3)),
        ("\n[方法2] 平均池化降維 + INT3 (64維)...", 'Method 2: AvgPooling + INT3',
         encode_separately(method2_average_pooling_int3)),
        ("\n[方法3] PCA降維 + INT3 (64維)...", 'Method 3: PCA + INT3',
         encode_jointly(method3_pca_int3)),
        ("\n[方法4] Max Magnitude Pooling降維 + INT3 (64維)...", 'Method 4: Max Mag Pooling + INT3',
         encode_jointly(method4_max_pooling_int3)),
        # Epochs capped at 100; train_autoencoder stops early when the loss plateaus.
        ("\n[方法5] AutoEncoder降維 + INT3 (64維)...", 'Method 5: AutoEncoder + INT3',
         encode_jointly(method5_autoencoder_int3, epochs=100)),
    )

    rows = []
    for header, label, encode in plan:
        print(header)
        started = time.time()
        db_codes, query_codes = encode()
        metrics = evaluate_recall_batched(query_codes, db_codes, gt_top_k, retrieval_depths, k_true_neighbors)
        rows.append({'method': label, 'time': time.time() - started, **metrics})
        print(f"Recall: {metrics}, QPS: {metrics.get('qps', 0):.2f}")

    return pd.DataFrame(rows)

In [None]:
# Run the experiments.
# Configuration (first 1000 queries; ground truth = exact top-100 neighbours
# computed on the unquantized float vectors):
# Part 1: first 100K database vectors, Recall 100@1000, 100@5000, 100@10000
# Part 2: full 1M database,            Recall 100@1000, 100@5000, 100@10000
# NOTE(review): display() below assumes an IPython/Jupyter session.

all_experiment_results = []  # One results DataFrame per part (Recall, Time, QPS columns)

# --- Part 1: 100K Database ---
print("\n" + "="*80)
print("Part 1: First 100K Database")
print("="*80)

db_100k = base_vectors[:100000].astype(np.float32)
query_subset = query_vectors[:1000].astype(np.float32)

# Exact top-100 ground truth for the 100K database
print("Computing Ground Truth for 100K DB...")
gt_100k = calculate_top_k_ground_truth(query_subset, db_100k, k=100)

# Run Part 1
retrieval_depths_1 = [1000, 5000, 10000]
k_true_neighbors = 100

results_df_1 = run_experiment_part(
    "Part 1: 100K DB",
    db_100k,
    query_subset,
    gt_100k,
    retrieval_depths_1,
    k_true_neighbors
)

all_experiment_results.append(results_df_1)

print("\nPart 1 Results:")
display(results_df_1)

# --- Part 2: 1M Database ---
print("\n" + "="*80)
print("Part 2: Full 1M Database")
print("="*80)

db_1m = base_vectors.astype(np.float32)

# Exact top-100 ground truth for the full 1M database
print("Computing Ground Truth for 1M DB...")
gt_1m = calculate_top_k_ground_truth(query_subset, db_1m, k=100)

# Run Part 2 (same depths and GT size as Part 1)
retrieval_depths_2 = [1000, 5000, 10000]

results_df_2 = run_experiment_part(
    "Part 2: 1M DB",
    db_1m,
    query_subset,
    gt_1m,
    retrieval_depths_2,
    k_true_neighbors
)

all_experiment_results.append(results_df_2)

print("\nPart 2 Results:")
display(results_df_2)

print("\n" + "="*80)
print("所有實驗完成！")
print("="*80)

## 10. 可视化结果

In [None]:
# Visualize the results: one column per experiment part, with Recall bars on
# the top row and QPS bars on the bottom row. The layout adapts to the
# number of parts.

# Compatibility with older runs that stored the results list as `all_results`.
if 'all_experiment_results' not in locals() and 'all_results' in locals():
    all_experiment_results = all_results

num_experiments = len(all_experiment_results)
print(f"Detected {num_experiments} experiment results.")

if num_experiments > 0:
    # Layout: row 0 = Recall, row 1 = QPS, one column per experiment.
    ncols = num_experiments

    # squeeze=False always returns a (2, ncols) array of Axes, so the
    # single-column case needs no special reshaping.
    fig, axes = plt.subplots(2, ncols, figsize=(6 * ncols, 12), squeeze=False)

    fig.suptitle('Recall & QPS Analysis', fontsize=16, fontweight='bold')

    # Panel titles, matching the database sizes used by the experiment cell.
    titles = []
    if num_experiments == 2:
        titles = ['100K DB', '1M DB']
    elif num_experiments == 3:
        titles = ['10K DB', '100K DB', '1M DB']
    else:
        titles = [f'Experiment {i+1}' for i in range(num_experiments)]

    for i in range(num_experiments):
        df = all_experiment_results[i]
        title = titles[i] if i < len(titles) else f'Exp {i+1}'

        # --- Row 1: Recall ---
        ax_recall = axes[0, i]

        # Short two-line tick labels, e.g. "M1\nDirect INT3".
        methods = df['method'].str.replace('Method ', 'M').str.replace(': ', '\n', 1)
        x = np.arange(len(df))

        recall_cols = [c for c in df.columns if 'recall@' in c]
        # Sort by K when every column ends in an integer (e.g. 'recall@1000');
        # otherwise keep the DataFrame's column order.
        try:
            recall_cols.sort(key=lambda col_name: int(col_name.split('@')[1]))
        except (ValueError, IndexError):
            pass

        if len(recall_cols) == 1:
            # Single recall metric: annotated bars
            col = recall_cols[0]
            bars = ax_recall.bar(x, df[col], color='#1f77b4')
            for bar in bars:
                height = bar.get_height()
                ax_recall.text(bar.get_x() + bar.get_width()/2., height + 0.01, f'{height:.3f}', ha='center', va='bottom', fontsize=9)
            ax_recall.set_ylabel(col, fontsize=12, fontweight='bold')
        else:
            # Several recall metrics: grouped bars centred on each method's tick
            width = 0.8 / len(recall_cols)
            for j, col in enumerate(recall_cols):
                offset = (j - len(recall_cols)/2 + 0.5) * width
                ax_recall.bar(x + offset, df[col], width, label=f'@{col.split("@")[1]}')
            ax_recall.legend()
            ax_recall.set_ylabel('Recall', fontsize=12, fontweight='bold')

        ax_recall.set_xticks(x)
        ax_recall.set_xticklabels(methods, rotation=45, ha='right', fontsize=9)
        ax_recall.set_title(f'{title} (Recall)', fontsize=14, fontweight='bold')
        ax_recall.set_ylim([0, 1.1])
        ax_recall.grid(axis='y', linestyle='--', alpha=0.7)

        # --- Row 2: QPS ---
        ax_qps = axes[1, i]
        bars_qps = ax_qps.bar(x, df['qps'], color='teal')

        ax_qps.set_xticks(x)
        ax_qps.set_xticklabels(methods, rotation=45, ha='right', fontsize=9)
        ax_qps.set_ylabel('QPS (Queries/sec)', fontsize=12, fontweight='bold')
        ax_qps.set_title(f'{title} (QPS)', fontsize=14, fontweight='bold')
        ax_qps.grid(axis='y', linestyle='--', alpha=0.7)

        for bar in bars_qps:
            height = bar.get_height()
            ax_qps.text(bar.get_x() + bar.get_width()/2., height, f'{int(height)}', ha='center', va='bottom', fontsize=9)

    plt.tight_layout()
    plt.subplots_adjust(top=0.92, hspace=0.4)
    plt.savefig('sift1m_results_recall_qps.png', dpi=300, bbox_inches='tight')
    plt.show()

    # --- Save Data ---
    print("\nSaving results to CSV...")
    for i, df in enumerate(all_experiment_results):
        safe_title = titles[i].replace(" ", "_") if i < len(titles) else f"part{i+1}"
        fname = f'sift1m_part{i+1}_{safe_title}.csv'
        df.to_csv(fname, index=False)
        print(f"Saved {fname}")
else:
    print("No results to visualize.")

## 11. 实验总结

In [None]:
# Experiment summary: print every part's results table.
# NOTE(review): display() assumes an IPython/Jupyter session.
print("\n" + "="*80)
print("實驗總結")
print("="*80)

# Compatibility with older runs that stored the results list as `all_results`.
if 'all_experiment_results' not in locals() and 'all_results' in locals():
    all_experiment_results = all_results

if 'all_experiment_results' in locals():
    for i, results_df in enumerate(all_experiment_results):
        print(f"\nExperiment Part {i+1}")
        display(results_df)
else:
    print("No results found.")

print("\n所有實驗完成！")

In [None]:
# 12. Dimension Reduction Analysis (128 -> 16)
# Analyzing Recall Rate vs Dimension for AvgPooling and PCA
# Configuration: 1M Database, Recall of Top-1000 GT @ Depth 10000

def adaptive_avg_pool_int3(data, target_dim):
    """
    Reduce each vector to ``target_dim`` values with adaptive average pooling, then quantize to INT3.

    ``adaptive_avg_pool1d`` averages the D input positions over
    ``target_dim`` adaptively sized windows, so any target up to D works,
    including ones that do not divide D.
    """
    source = torch.from_numpy(data) if isinstance(data, np.ndarray) else data
    source = source.float().to(DEVICE)
    # The pooling op expects (N, C, L): treat each vector as a single channel of length D.
    pooled = nn.functional.adaptive_avg_pool1d(source.unsqueeze(1), target_dim).squeeze(1)
    return quantize_to_int3(pooled)

def run_dimension_sweep(base_vectors, query_vectors, gt_top_k, k_true_neighbors=1000, retrieval_depth=10000,
                        target_dims=None, pca_fit_size=50000):
    """
    Compare AvgPooling+INT3 and PCA+INT3 recall across target dimensions.

    For every target dimension, each method reduces the database and the
    queries, quantizes them to INT3, and is scored with
    ``evaluate_recall_batched`` at a single retrieval depth.

    Args:
        base_vectors: database vectors (NumPy array; sliced for PCA fitting).
        query_vectors: query vectors.
        gt_top_k: ground-truth neighbour indices, one row per query.
        k_true_neighbors: ground-truth size used to normalise recall.
        retrieval_depth: number of retrieved candidates per query.
        target_dims: dimensions to test. Defaults to 128, 112, ..., 16.
        pca_fit_size: PCA is fitted on the first ``pca_fit_size`` database
            vectors (for speed), then applied to all of them.

    Returns:
        pandas.DataFrame with columns Dimension, Method, Recall, Time
        (Time covers reduction, quantization and evaluation).
    """
    target_dims = list(range(128, 15, -16)) if target_dims is None else list(target_dims)
    metric_key = f'recall@{retrieval_depth}'
    results = []

    print(f"Starting Dimension Sweep: {target_dims}")
    print(f"Database Size: {len(base_vectors)}")
    print(f"Target GT Size: {k_true_neighbors}")
    print(f"Retrieval Depth: {retrieval_depth}")

    for dim in target_dims:
        print(f"\nTesting Target Dimension: {dim}")

        # --- Method: Adaptive Avg Pooling ---
        print("  Running Avg Pooling...")
        t0 = time.time()
        db_avg = adaptive_avg_pool_int3(base_vectors, dim)
        q_avg = adaptive_avg_pool_int3(query_vectors, dim)
        recalls_avg = evaluate_recall_batched(q_avg, db_avg, gt_top_k, [retrieval_depth], k_true_neighbors)
        time_avg = time.time() - t0

        results.append({
            'Dimension': dim,
            'Method': 'AvgPooling',
            'Recall': recalls_avg[metric_key],
            'Time': time_avg
        })
        print(f"    AvgPool {metric_key}: {recalls_avg[metric_key]:.4f}")

        # --- Method: PCA ---
        print("  Running PCA...")
        t0 = time.time()
        fit_size = min(len(base_vectors), pca_fit_size)
        pca = PCA(n_components=dim, whiten=True)
        pca.fit(base_vectors[:fit_size])

        db_pca_q = quantize_to_int3(pca.transform(base_vectors))
        q_pca_q = quantize_to_int3(pca.transform(query_vectors))

        recalls_pca = evaluate_recall_batched(q_pca_q, db_pca_q, gt_top_k, [retrieval_depth], k_true_neighbors)
        time_pca = time.time() - t0

        results.append({
            'Dimension': dim,
            'Method': 'PCA',
            'Recall': recalls_pca[metric_key],
            'Time': time_pca
        })
        print(f"    PCA {metric_key}: {recalls_pca[metric_key]:.4f}")

    return pd.DataFrame(results)

# Set up data for the sweep (full 1M database, first 1000 queries).
# NOTE(review): both names are created only when base_vectors_full is
# missing; if base_vectors_full exists without query_vectors_subset, the GT
# call below raises NameError.
if 'base_vectors_full' not in locals():
    base_vectors_full = base_vectors.astype(np.float32)
    query_vectors_subset = query_vectors[:1000].astype(np.float32)

# Use the full 1M database; evaluate Top-1000 ground truth at depth 10000.
db_subset_sweep = base_vectors_full 
k_gt = 1000
depth = 10000

print(f"Computing GT for sweep (1M DB, Top-{k_gt})...")
# Always recomputed: earlier cells used k=100, and this sweep needs k=1000.
gt_sweep = calculate_top_k_ground_truth(query_vectors_subset, db_subset_sweep, k=k_gt)

# Run the sweep
sweep_df = run_dimension_sweep(db_subset_sweep, query_vectors_subset, gt_sweep, k_true_neighbors=k_gt, retrieval_depth=depth)

# Visualization: recall vs. dimension, one line per method
plt.figure(figsize=(10, 6))
avg_df = sweep_df[sweep_df['Method'] == 'AvgPooling']
pca_df = sweep_df[sweep_df['Method'] == 'PCA']

plt.plot(avg_df['Dimension'], avg_df['Recall'], marker='o', linewidth=2, label='Avg Pooling + INT3')
plt.plot(pca_df['Dimension'], pca_df['Recall'], marker='s', linewidth=2, label='PCA + INT3')

plt.xlabel('Dimension', fontsize=12)
plt.ylabel(f'Recall@{depth} (Top-{k_gt} GT)', fontsize=12)
plt.title(f'Recall vs Dimension (1M DB, {k_gt}@{depth})', fontsize=14)
plt.legend(fontsize=11)
plt.grid(True, linestyle='--', alpha=0.7)
plt.gca().invert_xaxis() # Display 128 on left, 16 on right
plt.xticks(list(range(128, 15, -16)))  # same dimensions as the sweep's default list
plt.tight_layout()
plt.show()

# Save results; the bare expression on the last line displays the table in the notebook.
csv_filename = 'sift1m_1M_1000at10000_sweep_results.csv'
sweep_df.to_csv(csv_filename, index=False)
print(f"Sweep results saved to {csv_filename}")
sweep_df

In [None]:
# 13. PCA Dimension Difference Analysis
# Analyze absolute difference between Query and Top-100 NN in PCA space per dimension

def analyze_pca_diff(base_vectors, query_vectors, gt_indices, n_components=128, fit_size=50000):
    """Compute the mean |query - neighbor| per PCA component over ground-truth pairs.

    A PCA with ``whiten=True`` is fitted on the first ``fit_size`` rows of
    ``base_vectors``. Each query and each of its ground-truth neighbors are
    projected into that space, and the absolute per-component difference is
    averaged over every (query, neighbor) pair.

    Args:
        base_vectors: Database vectors; their rows are addressed by ``gt_indices``.
        query_vectors: Query vectors, one per row.
        gt_indices: 2-D integer array whose row ``i`` holds the neighbor row
            indices (into ``base_vectors``) for query ``i``. Must have at least
            ``len(query_vectors)`` rows; any extra rows are ignored.
        n_components: Forwarded to ``PCA(n_components=...)``.
        fit_size: Maximum number of leading base vectors used to fit the PCA.

    Returns:
        Tuple ``(mean_abs_diff, explained_variance_ratio)``, each with one entry
        per fitted PCA component.

    Raises:
        ValueError: If ``gt_indices`` is not 2-D, has fewer rows than there are
            queries, or there are no (query, neighbor) pairs to average over.
    """
    gt_indices = np.asarray(gt_indices)
    n_queries = len(query_vectors)
    if gt_indices.ndim != 2 or gt_indices.shape[0] < n_queries:
        raise ValueError(
            f"gt_indices must be 2-D with at least {n_queries} rows, "
            f"got shape {gt_indices.shape}"
        )
    gt_indices = gt_indices[:n_queries]
    k_neighbors = gt_indices.shape[1]
    if n_queries == 0 or k_neighbors == 0:
        # The mean below would be 0/0; fail loudly instead of returning NaNs.
        raise ValueError("need at least one query and one ground-truth neighbor")

    print("Fitting PCA...")
    # Fit PCA on a leading subset for speed
    pca = PCA(n_components=n_components, whiten=True)
    pca.fit(base_vectors[:min(len(base_vectors), fit_size)])

    print("Transforming data...")
    q_pca = pca.transform(query_vectors)
    # Only base rows referenced by the ground truth are needed, so project just
    # those (deduplicated) rather than materializing the whole database in PCA space.
    needed_rows, inverse = np.unique(gt_indices, return_inverse=True)
    # np.unique's inverse is flat in NumPy 1.x but input-shaped in 2.x; normalize it.
    inverse = inverse.reshape(gt_indices.shape)
    neighbor_pca = pca.transform(base_vectors[needed_rows])

    # Running sum of absolute differences per component (float64 accumulator).
    # n_components_ is the fitted count, valid even when n_components is None/float.
    total_abs_diff = np.zeros(pca.n_components_)

    print("Calculating differences...")
    for i in tqdm(range(n_queries), desc="Analyzing Queries"):
        neighbor_vecs = neighbor_pca[inverse[i]]  # (k_neighbors, n_components_)
        # |q - n| summed over this query's neighbors
        total_abs_diff += np.abs(neighbor_vecs - q_pca[i]).sum(axis=0)

    # Mean absolute difference per component over all (query, neighbor) pairs
    mean_abs_diff = total_abs_diff / (n_queries * k_neighbors)

    return mean_abs_diff, pca.explained_variance_ratio_

# Setup Data (reuse arrays built by earlier cells when present)
if 'base_vectors_full' not in locals():
    base_vectors_full = base_vectors.astype(np.float32)
    query_vectors_subset = query_vectors[:1000].astype(np.float32)
elif 'query_vectors_subset' not in locals():
    # base_vectors_full can exist without the query subset; build it so the
    # GT and analysis calls below do not raise NameError.
    query_vectors_subset = query_vectors[:1000].astype(np.float32)

# Ensure GT for Top-100 exists for exactly these queries
k_target = 100
if (
    'gt_sweep' in locals()
    and gt_sweep.shape[0] == len(query_vectors_subset)
    and gt_sweep.shape[1] >= k_target
):
    # Reuse the deeper sweep GT by truncation.
    # NOTE(review): assumes calculate_top_k_ground_truth returns neighbors ordered
    # nearest-first, so the first k_target columns are the true top-k — confirm.
    gt_100 = gt_sweep[:, :k_target]
else:
    print("Computing GT for Top-100...")
    gt_100 = calculate_top_k_ground_truth(query_vectors_subset, base_vectors_full, k=k_target)

# Run Analysis
mean_diffs, explained_var = analyze_pca_diff(base_vectors_full, query_vectors_subset, gt_100)
# 1-based component indices, derived from the result so plots and CSV follow n_components
dims = range(1, len(mean_diffs) + 1)

# Visualization
plt.figure(figsize=(15, 6))

# Plot 1: Mean Absolute Difference
plt.subplot(1, 2, 1)
plt.bar(dims, mean_diffs, color='skyblue')
plt.xlabel('PCA Component (Dimension)')
plt.ylabel('Mean Absolute Difference')
plt.title(f'Mean Abs Diff between Query and Top-{k_target} NN per PCA Dimension')
plt.grid(axis='y', linestyle='--', alpha=0.7)

# Plot 2: Explained Variance (for context)
plt.subplot(1, 2, 2)
plt.plot(dims, explained_var, marker='.', linestyle='-', color='orange')
plt.xlabel('PCA Component')
plt.ylabel('Explained Variance Ratio')
plt.title('PCA Explained Variance Ratio')
plt.grid(True, linestyle='--', alpha=0.7)

plt.tight_layout()
plt.show()

# Save to CSV
diff_df = pd.DataFrame({
    'Dimension': dims,
    'Mean_Abs_Diff': mean_diffs,
    'Explained_Variance': explained_var,
})
diff_df.to_csv('sift1m_pca_128_diff_analysis.csv', index=False)
print("Analysis saved to sift1m_pca_128_diff_analysis.csv")
diff_df.head()