## 2. 导入必要的库

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.decomposition import PCA
from sklearn.metrics import pairwise_distances
import umap
import os
import time
import pandas as pd
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import urllib.request
import tarfile

# Use the GPU when PyTorch can see one, otherwise run everything on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {DEVICE}")

# Experiment-wide sizes: raw SIFT descriptor dimension and the target dimension
# of the 64-d reduction methods.
SIFT_DIM, REDUCED_DIM = 128, 64

## 3. 下载和解压 SIFT1M 数据集

In [None]:
def download_sift1m():
    """
    Download and unpack the SIFT1M dataset into ./sift1m/.

    Tries urllib (FTP) first and falls back to the ``wget`` command-line tool.
    After extraction, sift_base.fvecs, sift_query.fvecs and
    sift_groundtruth.ivecs are moved from sift1m/sift/ up into sift1m/, and the
    downloaded archive is deleted.

    Returns:
        bool: True if the files already existed or were installed, False if
        both download attempts failed (manual-download instructions are printed).
    """
    import shutil
    import socket
    import subprocess

    data_dir = 'sift1m'
    os.makedirs(data_dir, exist_ok=True)

    files = [
        "sift_base.fvecs",
        "sift_query.fvecs",
        "sift_groundtruth.ivecs"
    ]

    # Skip everything when all needed files are already in place.
    if all(os.path.exists(os.path.join(data_dir, name)) for name in files):
        print("所有文件已存在，跳过下载")
        return True

    print("正在下载 SIFT1M 数据集...")

    tar_url = "ftp://ftp.irisa.fr/local/texmex/corpus/sift.tar.gz"
    tar_path = "sift1m/sift.tar.gz"

    def _extract_and_install():
        """Unpack tar_path, move the needed files into data_dir and clean up."""
        with tarfile.open(tar_path, "r:gz") as tar:
            # The 'data' filter rejects absolute paths and '..' members. It is
            # only passed when this Python's tarfile supports it.
            if hasattr(tarfile, "data_filter"):
                tar.extractall(path=data_dir, filter="data")
            else:
                tar.extractall(path=data_dir)

        # The archive unpacks into sift1m/sift/; move the needed files up one level.
        extracted_dir = os.path.join(data_dir, "sift")
        if os.path.exists(extracted_dir):
            for name in files:
                src = os.path.join(extracted_dir, name)
                if os.path.exists(src):
                    # os.replace overwrites an existing destination.
                    os.replace(src, os.path.join(data_dir, name))
            # Best-effort cleanup: leftover extracted files are harmless.
            shutil.rmtree(extracted_dir, ignore_errors=True)

        if os.path.exists(tar_path):
            os.remove(tar_path)

    # --- Attempt 1: urllib over FTP ---
    try:
        print(f"尝试从 {tar_url} 下载...")
        previous_timeout = socket.getdefaulttimeout()
        socket.setdefaulttimeout(30)
        try:
            urllib.request.urlretrieve(tar_url, tar_path)
        finally:
            # Restore the old default so the 30 s timeout does not leak into
            # every later socket created in this kernel.
            socket.setdefaulttimeout(previous_timeout)
        print("下载完成，正在解压...")
        _extract_and_install()
        print("数据集准备完成")
        return True
    except Exception as e:  # network, FTP and archive errors all lead to the fallback
        print(f"FTP下载失败: {e}")

    # --- Attempt 2: wget (argument list, no shell) ---
    print("尝试使用 wget 下载...")
    try:
        res = subprocess.run(["wget", tar_url, "-O", tar_path], check=False)
        if res.returncode == 0 and os.path.exists(tar_path):
            print("wget 下载成功，正在解压...")
            _extract_and_install()
            print("数据集准备完成")
            return True
    except Exception as e2:  # e.g. wget not installed, corrupt archive
        print(f"wget 失败: {e2}")

    # Do not leave a partial or corrupt archive behind.
    if os.path.exists(tar_path):
        os.remove(tar_path)

    print("无法自动下载数据集。")
    print("请手动下载 sift.tar.gz 从 http://corpus-texmex.irisa.fr/")
    print("并解压到 sift1m/ 目录下。")
    return False

# Download the dataset (no-op when the files are already present).
print("开始下载SIFT1M数据集...")
if not download_sift1m():
    # Stop here with a clear message instead of a FileNotFoundError later in the
    # data-loading cell.
    raise RuntimeError("SIFT1M 数据集不可用：请按上方提示手动下载并解压到 sift1m/ 目录。")

## 4. INT3 量化函数

In [None]:
def quantize_to_int3(data, *, min_val=None, max_val=None):
    """
    Quantize data to INT3, i.e. the 8 integer levels -4 .. 3.

    By default the clipping range is the 1st..99th percentile of ``data``, so a
    few extreme values do not stretch the scale. The range is mapped linearly
    onto [-4, 3], then the values are rounded and clipped.

    Args:
        data: numpy array (or array-like) or torch tensor of any shape.
        min_val, max_val: optional explicit clipping range (keyword-only).
            Passing the range computed on a database when quantizing its
            queries puts both on the same scale. A bound left as None is taken
            from the corresponding percentile of ``data``.

    Returns:
        int8 values in [-4, 3] with the shape of ``data``. The result is a numpy
        array for numpy input, and a tensor on ``data.device`` for tensor input.
    """
    is_torch = isinstance(data, torch.Tensor)
    # detach() lets tensors that require grad be converted as well.
    arr = data.detach().cpu().numpy() if is_torch else np.asarray(data)

    if min_val is None or max_val is None:
        # One call (a single partition pass) yields both percentiles.
        low, high = np.percentile(arr, [1, 99])
        min_val = low if min_val is None else min_val
        max_val = high if max_val is None else max_val

    # Map [min_val, max_val] onto [-4, 3]; the epsilon avoids division by zero
    # for constant input.
    scaled = (arr - min_val) / (max_val - min_val + 1e-8) * 7 - 4

    # Round to the nearest level and clip anything outside the range.
    quantized = np.clip(np.round(scaled), -4, 3).astype(np.int8)

    if is_torch:
        return torch.from_numpy(quantized).to(data.device)
    return quantized

## 5. AutoEncoder 模型定义

In [None]:
class AutoEncoder(nn.Module):
    """
    Fully connected autoencoder: input_dim -> 128 -> 96 -> latent_dim and back.

    Each hidden stage is Linear + BatchNorm1d + ReLU. The latent code and the
    reconstruction come from plain Linear layers with no activation.
    """

    def __init__(self, input_dim=128, latent_dim=64):
        super().__init__()

        # Encoder: two hidden stages, then a linear bottleneck.
        encoder_layers = [
            nn.Linear(input_dim, 128), nn.BatchNorm1d(128), nn.ReLU(),
            nn.Linear(128, 96), nn.BatchNorm1d(96), nn.ReLU(),
            nn.Linear(96, latent_dim),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        # Decoder: mirror image of the encoder, back up to input_dim.
        decoder_layers = [
            nn.Linear(latent_dim, 96), nn.BatchNorm1d(96), nn.ReLU(),
            nn.Linear(96, 128), nn.BatchNorm1d(128), nn.ReLU(),
            nn.Linear(128, input_dim),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Return ``(latent_code, reconstruction)`` for the batch ``x``."""
        code = self.encode(x)
        return code, self.decoder(code)

    def encode(self, x):
        """Return only the latent code of ``x``."""
        return self.encoder(x)

def train_autoencoder(data, input_dim=128, latent_dim=64, epochs=100, batch_size=256, patience=5, min_delta=1e-4):
    """
    Train an AutoEncoder on ``data`` with MSE reconstruction loss and early stopping.

    Args:
        data: array-like of shape (N, input_dim); converted with torch.FloatTensor
            and moved to DEVICE as a whole.
        input_dim, latent_dim: layer sizes passed to AutoEncoder.
        epochs: maximum number of passes over the data.
        batch_size: mini-batch size; batches are reshuffled every epoch.
        patience: stop after this many consecutive epochs whose mean loss did
            not improve on the best so far by at least ``min_delta``.
        min_delta: minimum decrease in mean loss that counts as an improvement.

    Returns:
        The model, still in train mode, with the weights (and BatchNorm
        statistics) from the end of the epoch that reached the lowest mean
        training loss.

    Raises:
        ValueError: if ``data`` has fewer than 2 rows. BatchNorm1d cannot train
            on a single sample.
    """
    if len(data) < 2:
        raise ValueError("train_autoencoder needs at least 2 samples (BatchNorm1d)")

    model = AutoEncoder(input_dim, latent_dim).to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.MSELoss()

    data_tensor = torch.FloatTensor(data).to(DEVICE)
    dataset = torch.utils.data.TensorDataset(data_tensor)
    # BatchNorm1d raises on a one-sample training batch, so drop a trailing
    # batch that would contain exactly one row.
    drop_last = len(dataset) > batch_size and len(dataset) % batch_size == 1
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=True, drop_last=drop_last)

    model.train()

    best_loss = float('inf')
    best_state = None
    patience_counter = 0

    pbar = tqdm(range(epochs), desc="训练AutoEncoder")
    for epoch in pbar:
        total_loss = 0.0
        num_batches = 0

        for (batch_data,) in dataloader:
            optimizer.zero_grad()
            _, reconstructed = model(batch_data)
            loss = criterion(reconstructed, batch_data)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        avg_loss = total_loss / num_batches
        pbar.set_postfix({'loss': f'{avg_loss:.6f}'})

        # Early-stopping bookkeeping; snapshot the weights of every new best epoch.
        if avg_loss < best_loss - min_delta:
            best_loss = avg_loss
            patience_counter = 0
            best_state = {name: t.detach().clone() for name, t in model.state_dict().items()}
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print(f"Early stopping triggered at epoch {epoch+1} (Best Loss: {best_loss:.6f})")
            break

    # Return the best model, not the one left after `patience` worse epochs.
    if best_state is not None:
        model.load_state_dict(best_state)
    return model

## 6. 五种数据处理方法

In [None]:
def method1_direct_int3(data):
    """
    Method 1: quantize the raw 128-d vectors straight to INT3 (no reduction).

    The clipping range comes from ``data`` itself. The database and the queries
    are quantized in separate calls, so each gets its own scale.
    """
    quantized = quantize_to_int3(data)
    return quantized

def method2_average_pooling_int3(data):
    """
    Method 2: average each pair of adjacent dimensions, then quantize to INT3.

    Dimensions 2i and 2i+1 are averaged into output dimension i, so 128-d input
    becomes 64-d. Any even input dimension is accepted, and the output always
    has half as many dimensions.

    Raises:
        ValueError: if the number of dimensions is odd.
    """
    arr = np.asarray(data)
    num_dims = arr.shape[1]
    if num_dims % 2:
        raise ValueError(f"average pooling needs an even dimension, got {num_dims}")
    averaged = arr.reshape(arr.shape[0], num_dims // 2, 2).mean(axis=2)
    return quantize_to_int3(averaged)

def method3_pca_int3(data, query_data=None):
    """
    Method 3: project onto the first REDUCED_DIM principal components, then INT3.

    PCA is fit on ``data`` with ``whiten=True``, which scales each component to
    unit variance before the shared quantization grid is applied. If
    ``query_data`` is given, it is projected with the same fitted PCA. The two
    sets are quantized in separate calls, each with its own percentile range.

    Returns:
        ``(pca, data_quantized, query_quantized)`` when ``query_data`` is given,
        otherwise ``(pca, data_quantized)``.
    """
    print("  - 训练PCA...")
    pca = PCA(n_components=REDUCED_DIM, whiten=True)
    pca.fit(data)

    outputs = [pca, quantize_to_int3(pca.transform(data))]
    if query_data is not None:
        outputs.append(quantize_to_int3(pca.transform(query_data)))
    return tuple(outputs)

def method4_max_pooling_int3(data, query_data=None):
    """
    Method 4: max-magnitude pooling over adjacent dimension pairs, then INT3.

    For each pair (2i, 2i+1), the component with the larger absolute value is
    kept, sign included; on a tie the first is kept. This halves the dimension
    (128 -> 64) and works for any even input dimension.

    Returns:
        ``(None, data_quantized, query_quantized)`` when ``query_data`` is given,
        otherwise ``(None, data_quantized)``. The leading ``None`` mirrors the
        fitted-model slot returned by the other methods.

    Raises:
        ValueError: if the number of dimensions is odd.
    """

    def max_magnitude_pool(arr):
        """Keep the larger-magnitude element of each adjacent pair of columns."""
        arr = np.asarray(arr)
        num_dims = arr.shape[1]
        if num_dims % 2:
            raise ValueError(f"max-magnitude pooling needs an even dimension, got {num_dims}")
        pairs = arr.reshape(arr.shape[0], num_dims // 2, 2)
        first, second = pairs[:, :, 0], pairs[:, :, 1]
        return np.where(np.abs(first) >= np.abs(second), first, second)

    data_quantized = quantize_to_int3(max_magnitude_pool(data))
    if query_data is None:
        return None, data_quantized

    query_quantized = quantize_to_int3(max_magnitude_pool(query_data))
    return None, data_quantized, query_quantized

def method5_autoencoder_int3(data, query_data=None, epochs=100):
    """
    Method 5: encode with a freshly trained AutoEncoder to REDUCED_DIM, then INT3.

    The autoencoder is trained on ``data``. ``epochs`` is the upper bound, since
    train_autoencoder stops early. The model is then put in eval mode and
    encodes ``data`` (and ``query_data`` if given), each in one batch, without
    autograd.

    Returns:
        ``(ae_model, data_quantized, query_quantized)`` when ``query_data`` is
        given, otherwise ``(ae_model, data_quantized)``.
    """
    print("  - 训练AutoEncoder...")
    ae_model = train_autoencoder(data, SIFT_DIM, REDUCED_DIM, epochs=epochs)
    ae_model.eval()

    def encode_and_quantize(vectors):
        inputs = torch.FloatTensor(vectors).to(DEVICE)
        return quantize_to_int3(ae_model.encode(inputs).cpu().numpy())

    with torch.no_grad():
        data_quantized = encode_and_quantize(data)
        if query_data is None:
            return ae_model, data_quantized
        query_quantized = encode_and_quantize(query_data)
    return ae_model, data_quantized, query_quantized

## 7. 距离计算和评估函数

In [None]:
def calculate_l2_distances(queries, database):
    """Return the (n_queries, n_database) matrix of Euclidean distances."""
    distance_matrix = pairwise_distances(queries, database, metric='euclidean')
    return distance_matrix

def evaluate_recall_at_k(distances, ground_truth, k_values=(1, 10, 100)):
    """
    Compute Recall@K from a full query-by-database distance matrix.

    For each query i and each K:
    - the true set is the distinct ids among the first min(K, len(gt_i))
      ground-truth entries;
    - recall is |top-K retrieved ∩ true set| / |true set|;
    - a query with an empty ground-truth row scores 0.

    Values are averaged over queries.

    Args:
        distances: (n_queries, n_database) array; smaller means closer.
        ground_truth: per-query sequences of neighbour ids, nearest first
            (rows may have different lengths).
        k_values: iterable of K values. The default is a tuple so it cannot be
            mutated across calls.

    Returns:
        dict mapping ``'recall@K'`` to the mean recall (empty for no K values).
    """
    distances = np.asarray(distances)
    k_values = list(k_values)
    if not k_values:
        return {}

    n_db = distances.shape[1]
    max_k = min(max(k_values), n_db)
    if 0 < max_k < n_db:
        # Only the max_k closest items matter. Select them in linear time per
        # row, then sort just those instead of argsorting every full row.
        candidates = np.argpartition(distances, max_k - 1, axis=1)[:, :max_k]
        order = np.argsort(np.take_along_axis(distances, candidates, axis=1), axis=1)
        sorted_indices = np.take_along_axis(candidates, order, axis=1)
    else:
        sorted_indices = np.argsort(distances, axis=1)

    recalls = {}
    for k in k_values:
        top_k_predictions = sorted_indices[:, :k]

        query_recalls = []
        for i, gt_i in enumerate(ground_truth):
            # Rows may be shorter than k (jagged ground truth).
            limit = min(k, len(gt_i))
            if limit == 0:
                query_recalls.append(0.0)
                continue

            true_neighbors = set(gt_i[:limit])
            pred_neighbors = set(top_k_predictions[i])
            query_recalls.append(len(true_neighbors & pred_neighbors) / len(true_neighbors))

        recalls[f'recall@{k}'] = np.mean(query_recalls)

    return recalls

## 8. 加载 SIFT1M 数据集

In [None]:
def read_fvecs(filename):
    """
    Read a .fvecs file into a float32 array of shape (n, d).

    Every record is an int32 dimension ``d`` followed by ``d`` float32 values.
    The file is loaded as int32 words; the leading dimension column is dropped
    and the remaining words are reinterpreted as float32.
    """
    words = np.fromfile(filename, dtype=np.int32)
    dim = words[0]
    vectors = words.reshape(-1, dim + 1)[:, 1:].copy()
    return vectors.view(np.float32)

def read_ivecs(filename):
    """
    Read an .ivecs file into an int32 array of shape (n, d).

    Every record is an int32 dimension ``d`` followed by ``d`` int32 values;
    the leading dimension column is dropped.
    """
    words = np.fromfile(filename, dtype=np.int32)
    width = words[0]
    return words.reshape(-1, width + 1)[:, 1:].copy()

# Load the SIFT1M base vectors, queries and official ground truth.
print("载入SIFT1M数据集...")
base_vectors = read_fvecs('sift1m/sift_base.fvecs')
query_vectors = read_fvecs('sift1m/sift_query.fvecs')
ground_truth = read_ivecs('sift1m/sift_groundtruth.ivecs')

for label, loaded in (("Base vectors", base_vectors),
                      ("Query vectors", query_vectors),
                      ("Ground truth", ground_truth)):
    print(f"{label} shape: {loaded.shape}")

## 9. 运行实验

In [None]:
def calculate_top_k_ground_truth(queries, database, k=100, query_batch_size=100, db_batch_size=50000):
    """
    Compute the exact top-``k`` nearest neighbours (L2) of every query on DEVICE.

    The database is streamed to DEVICE in chunks of ``db_batch_size`` rows, and
    each chunk is uploaded exactly once. For every chunk and every batch of
    ``query_batch_size`` queries, the chunk's closest rows are merged into a
    running per-query top-``k``. Only the (num_queries, k) running result stays
    on the device, so full (batch, N) distance rows are never copied to the CPU.

    Ranking uses ||x||^2 - 2 q.x. The dropped ||q||^2 term is constant per query
    and does not change the order.

    Args:
        queries: (Q, D) array-like of query vectors.
        database: (N, D) array-like of database vectors.
        k: neighbours per query; must satisfy 1 <= k <= N.
        query_batch_size, db_batch_size: batch sizes that bound device memory.

    Returns:
        (Q, k) int64 numpy array of database indices, nearest first (torch.topk
        returns sorted results). Exact ties may come in any order.

    Raises:
        ValueError: if ``k`` is not in [1, N].
    """
    num_queries = len(queries)
    num_db = len(database)
    print(f"正在計算 Top-{k} Ground Truth (Query: {num_queries}, DB: {num_db})...")

    if not 0 < k <= num_db:
        raise ValueError(f"k must be in [1, {num_db}], got {k}")
    if num_queries == 0:
        return np.empty((0, k), dtype=np.int64)

    q_all = torch.as_tensor(np.asarray(queries, dtype=np.float32), device=DEVICE)
    # Running best distances / indices; inf and -1 are placeholders that real
    # candidates always replace because the database has at least k rows.
    best_d = torch.full((num_queries, k), float('inf'), device=DEVICE)
    best_i = torch.full((num_queries, k), -1, dtype=torch.long, device=DEVICE)

    for start in tqdm(range(0, num_db, db_batch_size), desc="Computing GT"):
        end = min(start + db_batch_size, num_db)
        chunk = torch.as_tensor(np.asarray(database[start:end], dtype=np.float32), device=DEVICE)
        chunk_sq = torch.sum(chunk ** 2, dim=1)
        # The last chunk may hold fewer than k rows.
        k_chunk = min(k, end - start)

        for qs in range(0, num_queries, query_batch_size):
            qe = min(qs + query_batch_size, num_queries)
            dists = chunk_sq - 2.0 * torch.matmul(q_all[qs:qe], chunk.t())
            cand_d, cand_i = torch.topk(dists, k=k_chunk, dim=1, largest=False)

            merged_d = torch.cat([best_d[qs:qe], cand_d], dim=1)
            merged_i = torch.cat([best_i[qs:qe], cand_i + start], dim=1)
            top_d, pos = torch.topk(merged_d, k=k, dim=1, largest=False)
            best_d[qs:qe] = top_d
            best_i[qs:qe] = torch.gather(merged_i, 1, pos)

        del chunk, chunk_sq

    return best_i.cpu().numpy()

def evaluate_recall_batched(query_vectors, db_vectors, gt_top_k, retrieval_depths, k_true_neighbors):
    """
    Evaluate "k_true_neighbors @ r" recall for several retrieval depths r.

    Queries are processed in batches of 100. For each query, the r database
    vectors closest in Euclidean distance (in the given, possibly quantized,
    space) are retrieved and compared with the query's row of ``gt_top_k``:

        recall@r = sum_q |retrieved_r(q) ∩ gt(q)| / (num_queries * k_true_neighbors)

    Args:
        query_vectors, db_vectors: numpy arrays or torch tensors, converted to
            float32 numpy.
        gt_top_k: (num_queries, K) array of true neighbour ids.
        retrieval_depths: iterable of depths r.
        k_true_neighbors: normaliser; expected to equal the width of gt_top_k.

    Returns:
        dict mapping ``'recall@r'`` to the recall value.

    Raises:
        ValueError: if there are no queries.

    Note:
        Quantized codes produce many exact distance ties. Which tied items fall
        inside the top r is arbitrary, as it already was with a full argsort.
    """
    retrieval_depths = list(retrieval_depths)
    max_depth = max(retrieval_depths)
    num_queries = len(query_vectors)
    query_batch_size = 100

    if num_queries == 0:
        raise ValueError("evaluate_recall_batched: no queries given")

    total_hits = {r: 0 for r in retrieval_depths}

    # Accept tensor inputs (e.g. quantize_to_int3 output on DEVICE).
    if isinstance(query_vectors, torch.Tensor):
        query_vectors = query_vectors.cpu().numpy()
    if isinstance(db_vectors, torch.Tensor):
        db_vectors = db_vectors.cpu().numpy()
    query_vectors = np.asarray(query_vectors, dtype=np.float32)
    db_vectors = np.asarray(db_vectors, dtype=np.float32)

    num_db = len(db_vectors)
    top_n = min(max_depth, num_db)

    for i in tqdm(range(0, num_queries, query_batch_size), desc="Evaluating Batches", leave=False):
        q_end = min(i + query_batch_size, num_queries)
        dists = pairwise_distances(query_vectors[i:q_end], db_vectors, metric='euclidean')

        if 0 < top_n < num_db:
            # Select the top_n closest in linear time per row and sort only
            # those. A full argsort over every database column is far slower.
            candidates = np.argpartition(dists, top_n - 1, axis=1)[:, :top_n]
            order = np.argsort(np.take_along_axis(dists, candidates, axis=1), axis=1)
            ranked = np.take_along_axis(candidates, order, axis=1)
        else:
            ranked = np.argsort(dists, axis=1)

        for j, gt_row in enumerate(gt_top_k[i:q_end]):
            # Build the true-neighbour set once per query and reuse it for every depth.
            gt_set = set(np.asarray(gt_row).tolist())
            for r in retrieval_depths:
                total_hits[r] += len(gt_set.intersection(ranked[j, :r].tolist()))

    # Normalise by the number of true neighbours per query (e.g. 100).
    return {f'recall@{r}': total_hits[r] / (num_queries * float(k_true_neighbors))
            for r in retrieval_depths}

def method5_autoencoder_int3_with_timing(data, query_data=None, epochs=100):
    """
    Method 5 with timing: AutoEncoder reduction to REDUCED_DIM followed by INT3.

    Same pipeline as method5_autoencoder_int3, but it also reports wall-clock
    seconds for:
    - training (train_autoencoder);
    - inference (encoding and quantizing ``data`` and, if given, ``query_data``).

    Returns:
        ``(ae_model, data_quantized, query_quantized, train_time, inference_time)``
        when ``query_data`` is given, otherwise
        ``(ae_model, data_quantized, train_time, inference_time)``.
    """
    print("  - 训练AutoEncoder...")

    started = time.time()
    ae_model = train_autoencoder(data, SIFT_DIM, REDUCED_DIM, epochs=epochs)
    train_time = time.time() - started

    ae_model.eval()

    def encode_and_quantize(vectors):
        inputs = torch.FloatTensor(vectors).to(DEVICE)
        # Copying the code to the CPU waits for device work, so the timing
        # below includes the encoder's GPU time.
        return quantize_to_int3(ae_model.encode(inputs).cpu().numpy())

    started = time.time()
    with torch.no_grad():
        data_quantized = encode_and_quantize(data)
        query_quantized = None if query_data is None else encode_and_quantize(query_data)
    inference_time = time.time() - started

    if query_quantized is None:
        return ae_model, data_quantized, train_time, inference_time
    return ae_model, data_quantized, query_quantized, train_time, inference_time

def run_experiment_part(part_name, base_vectors, query_vectors, gt_top_k, retrieval_depths, k_true_neighbors):
    """
    Run all five methods on one database / query configuration.

    For methods 1-4, "process" time covers encoding base and queries, and
    "eval" time covers evaluate_recall_batched. Method 5 reports AutoEncoder
    training time separately.

    Returns:
        (recalls_df, times_df), one row per method.
        - recalls_df: the 'time' column includes AE training for method 5,
          which also gets 'time_no_train' and 'train_time' columns.
        - times_df: 'total_time' excludes AE training.
    """
    rule = '=' * 80
    print(f"\n{rule}")
    print(f"實驗部分: {part_name}")
    print(f"Database Size: {len(base_vectors)}")
    print(f"Retrieval Depths: {retrieval_depths}")
    print(f"Target GT Size (K): {k_true_neighbors}")
    print(f"{rule}\n")

    results = []
    time_results = []  # timing rows, excluding AE training time

    def timed_recall(query_codes, db_codes):
        """Evaluate recall and return it together with the elapsed seconds."""
        started = time.time()
        recall = evaluate_recall_batched(query_codes, db_codes, gt_top_k, retrieval_depths, k_true_neighbors)
        return recall, time.time() - started

    # (banner, recall-table name, time-table name, encoder returning (db_codes, query_codes))
    simple_methods = [
        ("\n[方法1] 直接INT3量化 (128維)...", 'Method 1: Direct INT3', 'M1: Direct INT3',
         lambda: (method1_direct_int3(base_vectors), method1_direct_int3(query_vectors))),
        ("\n[方法2] 平均池化降維 + INT3 (64維)...", 'Method 2: AvgPooling + INT3', 'M2: AvgPooling',
         lambda: (method2_average_pooling_int3(base_vectors), method2_average_pooling_int3(query_vectors))),
        ("\n[方法3] PCA降維 + INT3 (64維)...", 'Method 3: PCA + INT3', 'M3: PCA',
         lambda: method3_pca_int3(base_vectors, query_vectors)[1:]),
        ("\n[方法4] Max Magnitude Pooling降維 + INT3 (64維)...", 'Method 4: Max Mag Pooling + INT3', 'M4: MaxMagPool',
         lambda: method4_max_pooling_int3(base_vectors, query_vectors)[1:]),
    ]

    for banner, long_name, short_name, encode in simple_methods:
        print(banner)
        started = time.time()
        db_codes, query_codes = encode()
        process_time = time.time() - started

        recalls, eval_time = timed_recall(query_codes, db_codes)
        total_time = process_time + eval_time

        results.append({'method': long_name, 'time': total_time, **recalls})
        time_results.append({'method': short_name, 'process_time': process_time,
                             'eval_time': eval_time, 'total_time': total_time})
        print(recalls)

    # Method 5 also reports the AutoEncoder's training time separately.
    print("\n[方法5] AutoEncoder降維 + INT3 (64維)...")
    _, db_m5, q_m5, ae_train_time, ae_inference_time = method5_autoencoder_int3_with_timing(
        base_vectors, query_vectors, epochs=100)
    recalls_m5, eval_time_m5 = timed_recall(q_m5, db_m5)

    total_no_train = ae_inference_time + eval_time_m5
    total_with_train = ae_train_time + ae_inference_time + eval_time_m5

    results.append({'method': 'Method 5: AutoEncoder + INT3', 'time': total_with_train,
                    'time_no_train': total_no_train, 'train_time': ae_train_time, **recalls_m5})
    time_results.append({'method': 'M5: AutoEncoder', 'process_time': ae_inference_time,
                         'eval_time': eval_time_m5, 'total_time': total_no_train,
                         'train_time': ae_train_time})
    print(recalls_m5)
    print(f"  AE 訓練時間: {ae_train_time:.2f}s, 推理時間: {ae_inference_time:.2f}s")

    return pd.DataFrame(results), pd.DataFrame(time_results)

In [None]:
# Run the experiments.
# Configuration:
# Part 1: 100K DB, Recall 100@1000, 100@5000, 100@10000
# Part 2: 1M DB, Recall 100@1000, 100@5000, 100@10000
# Both parts use the first 1000 SIFT queries and a Top-100 ground truth
# recomputed with calculate_top_k_ground_truth on the corresponding database.
# `display` is provided by IPython/Jupyter; this cell only runs inside a notebook.

all_experiment_results = []  # recall DataFrames, one per part (read by the plotting cell)
all_time_results = []  # timing DataFrames, one per part (read by the plotting cell)

# --- Part 1: 100K Database ---
print("\n" + "="*80)
print("Part 1: First 100K Database")
print("="*80)

db_100k = base_vectors[:100000].astype(np.float32)
query_subset = query_vectors[:1000].astype(np.float32)

# Ground truth (Top-100) for the 100K database.
print("Computing Ground Truth for 100K DB...")
gt_100k = calculate_top_k_ground_truth(query_subset, db_100k, k=100)

# Run Part 1.
retrieval_depths_1 = [1000, 5000, 10000]
k_true_neighbors = 100

results_df_1, time_df_1 = run_experiment_part(
    "Part 1: 100K DB",
    db_100k,
    query_subset,
    gt_100k,
    retrieval_depths_1,
    k_true_neighbors
)

all_experiment_results.append(results_df_1)
all_time_results.append(time_df_1)

print("\nPart 1 Recall Results:")
display(results_df_1)
print("\nPart 1 Time Results (不含 AE 訓練時間):")
display(time_df_1)

# --- Part 2: 1M Database ---
print("\n" + "="*80)
print("Part 2: Full 1M Database")
print("="*80)

# NOTE(review): astype() copies by default, so this holds a second full copy of
# the base vectors in memory.
db_1m = base_vectors.astype(np.float32)

# Ground truth (Top-100) for the full 1M database. It is recomputed here rather
# than taken from the official sift_groundtruth.ivecs loaded earlier.
print("Computing Ground Truth for 1M DB...")
gt_1m = calculate_top_k_ground_truth(query_subset, db_1m, k=100)

# Run Part 2 (same depths and K as Part 1).
retrieval_depths_2 = [1000, 5000, 10000]

results_df_2, time_df_2 = run_experiment_part(
    "Part 2: 1M DB",
    db_1m,
    query_subset,
    gt_1m,
    retrieval_depths_2,
    k_true_neighbors
)

all_experiment_results.append(results_df_2)
all_time_results.append(time_df_2)

print("\nPart 2 Recall Results:")
display(results_df_2)
print("\nPart 2 Time Results (不含 AE 訓練時間):")
display(time_df_2)

print("\n" + "="*80)
print("所有實驗完成！")
print("="*80)

## 10. 可视化结果

In [None]:
# Plot the experiment results (Part 1 = 100K DB, Part 2 = 1M DB):
# recall of the Top-100 true neighbours at depths 1000 / 5000 / 10000, plus timings.

part_titles = ['Part 1: 100K DB', 'Part 2: 1M DB']
recall_depths = [1000, 5000, 10000]
recall_parts = [all_experiment_results[0], all_experiment_results[1]]

# --- Figure 1: recall per method, one bar per retrieval depth ---
fig, axes = plt.subplots(1, 2, figsize=(16, 7))
fig.suptitle('Recall of Top-100 True Neighbors @ Different Retrieval Depths', fontsize=16, fontweight='bold')

colors = ['#1f77b4', '#ff7f0e', '#2ca02c']

# Tick labels come from Part 1: "Method N: Name" -> "MN\nName".
df1 = recall_parts[0]
methods = df1['method'].str.replace('Method ', 'M').str.replace(': ', '\n', 1)
x = np.arange(len(df1))
width = 0.25

for ax, df, title in zip(axes, recall_parts, part_titles):
    for offset, depth, color in zip((-width, 0, width), recall_depths, colors):
        ax.bar(x + offset, df[f'recall@{depth}'], width, label=f'@{depth}', color=color)
    ax.set_xticks(x)
    ax.set_xticklabels(methods, rotation=45, ha='right', fontsize=9)
    ax.set_ylabel('Recall', fontsize=12, fontweight='bold')
    ax.set_title(f'{title} (100@1000, 100@5000, 100@10000)', fontsize=14, fontweight='bold')
    ax.set_ylim([0, 1.05])
    ax.legend(loc='lower right')
    ax.grid(axis='y', linestyle='--', alpha=0.7)

plt.tight_layout()
plt.savefig('sift1m_2parts_recall_results.png', dpi=300, bbox_inches='tight')
plt.show()

# --- Figure 2: total time per method (AE training excluded) ---
fig2, axes2 = plt.subplots(1, 2, figsize=(16, 7))
fig2.suptitle('Processing Time Comparison (Excluding AE Training Time)', fontsize=16, fontweight='bold')

time_colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd']
time_df_1, time_df_2 = all_time_results[0], all_time_results[1]
# Method labels for the time plots are also taken from Part 1.
time_methods = time_df_1['method']

for ax, tdf, title in zip(axes2, (time_df_1, time_df_2), part_titles):
    bars = ax.bar(range(len(tdf)), tdf['total_time'], color=time_colors)
    ax.set_xticks(range(len(tdf)))
    ax.set_xticklabels(time_methods, rotation=45, ha='right', fontsize=9)
    ax.set_ylabel('Time (seconds)', fontsize=12, fontweight='bold')
    ax.set_title(f'{title} - Processing Time', fontsize=14, fontweight='bold')
    ax.grid(axis='y', linestyle='--', alpha=0.7)

    # Write each bar's value just above it.
    for bar in bars:
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2., height + 0.5, f'{height:.2f}s', ha='center', va='bottom', fontsize=9)

plt.tight_layout()
plt.savefig('sift1m_2parts_time_results.png', dpi=300, bbox_inches='tight')
plt.show()

# --- Figure 3: processing vs. evaluation time per method ---
fig3, axes3 = plt.subplots(1, 2, figsize=(16, 7))
fig3.suptitle('Detailed Time Breakdown (Excluding AE Training Time)', fontsize=16, fontweight='bold')

x = np.arange(len(time_df_1))
width = 0.35

for ax, tdf, title in zip(axes3, (time_df_1, time_df_2), part_titles):
    ax.bar(x - width/2, tdf['process_time'], width, label='Processing', color='#2ca02c')
    ax.bar(x + width/2, tdf['eval_time'], width, label='Evaluation', color='#d62728')
    ax.set_xticks(x)
    ax.set_xticklabels(time_methods, rotation=45, ha='right', fontsize=9)
    ax.set_ylabel('Time (seconds)', fontsize=12, fontweight='bold')
    ax.set_title(f'{title} - Time Breakdown', fontsize=14, fontweight='bold')
    ax.legend()
    ax.grid(axis='y', linestyle='--', alpha=0.7)

plt.tight_layout()
plt.savefig('sift1m_2parts_time_breakdown.png', dpi=300, bbox_inches='tight')
plt.show()

print("\n圖表已儲存！")

## 11. 实验总结

In [None]:
# Experiment summary: show the recall and timing tables of every part.
# all_experiment_results / all_time_results are filled by the experiment cell
# (one DataFrame per part, Part 1 first). `display` is provided by IPython.
print("\n" + "="*80)
print("實驗總結")
print("="*80)

part_names = ["Part 1: 100K DB", "Part 2: 1M DB"]
for idx, (results_df, time_df) in enumerate(zip(all_experiment_results, all_time_results)):
    label = part_names[idx] if idx < len(part_names) else f"Part {idx + 1}"
    print(f"\n配置: {label}")
    display(results_df)
    display(time_df)

print("\n所有實驗完成！")

In [None]:
# 12. Dimension Reduction Analysis (128 -> 16)
# Analyzing Recall Rate vs Dimension for AvgPooling and PCA
# Configuration: 1M Database, Recall of Top-1000 GT @ Depth 10000

def adaptive_avg_pool_int3(data, target_dim):
    """
    Reduce each row to ``target_dim`` values with 1-D adaptive average pooling, then INT3.

    ``data`` (numpy array or tensor, rows are vectors) is cast to float32, moved
    to DEVICE and pooled with torch's adaptive_avg_pool1d, one row per sample.

    Returns:
        quantize_to_int3 applied to the pooled (N, target_dim) tensor, i.e. an
        int8 tensor on DEVICE.
    """
    tensor = torch.from_numpy(data).float() if isinstance(data, np.ndarray) else data.float()
    # adaptive_avg_pool1d expects (batch, channels, length); each row is one channel.
    pooled = nn.functional.adaptive_avg_pool1d(tensor.to(DEVICE).unsqueeze(1), target_dim)
    return quantize_to_int3(pooled.squeeze(1))

def run_dimension_sweep(base_vectors, query_vectors, gt_top_k, k_true_neighbors=1000, retrieval_depth=10000):
    """
    Sweep the reduced dimension from 128 down to 16 (step 16) for two methods.

    For every target dimension it evaluates, at ``retrieval_depth``, the recall
    of the ``k_true_neighbors`` true neighbours for:
    - adaptive average pooling + INT3;
    - PCA + INT3 (whitened, fit on at most the first 50,000 database rows).

    'Time' covers encoding (and PCA fitting) plus evaluation.

    Returns:
        DataFrame with columns Dimension, Method ('AvgPooling' / 'PCA'), Recall, Time.
    """
    # 128, 112, 96, ..., 16
    target_dims = list(range(128, 15, -16))
    metric_key = f'recall@{retrieval_depth}'
    rows = []

    print(f"Starting Dimension Sweep: {target_dims}")
    print(f"Database Size: {len(base_vectors)}")
    print(f"Target GT Size: {k_true_neighbors}")
    print(f"Retrieval Depth: {retrieval_depth}")

    def record(dim, method, label, db_codes, query_codes, started):
        """Evaluate one encoding, store its row and print the recall."""
        recall = evaluate_recall_batched(query_codes, db_codes, gt_top_k, [retrieval_depth], k_true_neighbors)[metric_key]
        rows.append({'Dimension': dim, 'Method': method, 'Recall': recall, 'Time': time.time() - started})
        print(f"    {label} {metric_key}: {recall:.4f}")

    for dim in target_dims:
        print(f"\nTesting Target Dimension: {dim}")

        print("  Running Avg Pooling...")
        started = time.time()
        db_codes = adaptive_avg_pool_int3(base_vectors, dim)
        query_codes = adaptive_avg_pool_int3(query_vectors, dim)
        record(dim, 'AvgPooling', 'AvgPool', db_codes, query_codes, started)

        print("  Running PCA...")
        started = time.time()
        # Fit on at most the first 50,000 rows to keep PCA fitting fast.
        pca = PCA(n_components=dim, whiten=True)
        pca.fit(base_vectors[:min(len(base_vectors), 50000)])
        db_codes = quantize_to_int3(pca.transform(base_vectors))
        query_codes = quantize_to_int3(pca.transform(query_vectors))
        record(dim, 'PCA', 'PCA', db_codes, query_codes, started)

    return pd.DataFrame(rows)

# Setup Data for Sweep (Using 1M Full DB)
# Ensure we have the data. At notebook top level locals() is the global
# namespace, so this reuses arrays created by an earlier run of this cell;
# both names are (re)created together.
if 'base_vectors_full' not in locals():
    base_vectors_full = base_vectors.astype(np.float32)
    query_vectors_subset = query_vectors[:1000].astype(np.float32)

# Use Full 1M Database
db_subset_sweep = base_vectors_full 
k_gt = 1000
depth = 10000

print(f"Computing GT for sweep (1M DB, Top-{k_gt})...")
# The experiment cells computed only Top-100 ground truth; this sweep needs
# Top-1000, so it is always recomputed here. The next cell slices its Top-100
# from gt_sweep.
gt_sweep = calculate_top_k_ground_truth(query_vectors_subset, db_subset_sweep, k=k_gt)

# Run the sweep: recall of the Top-1000 true neighbours at depth 10000 per dimension.
sweep_df = run_dimension_sweep(db_subset_sweep, query_vectors_subset, gt_sweep, k_true_neighbors=k_gt, retrieval_depth=depth)

# Visualization: one recall-vs-dimension line per method.
plt.figure(figsize=(10, 6))
avg_df = sweep_df[sweep_df['Method'] == 'AvgPooling']
pca_df = sweep_df[sweep_df['Method'] == 'PCA']

plt.plot(avg_df['Dimension'], avg_df['Recall'], marker='o', linewidth=2, label='Avg Pooling + INT3')
plt.plot(pca_df['Dimension'], pca_df['Recall'], marker='s', linewidth=2, label='PCA + INT3')

plt.xlabel('Dimension', fontsize=12)
plt.ylabel(f'Recall@{depth} (Top-{k_gt} GT)', fontsize=12)
plt.title(f'Recall vs Dimension (1M DB, {k_gt}@{depth})', fontsize=14)
plt.legend(fontsize=11)
plt.grid(True, linestyle='--', alpha=0.7)
plt.gca().invert_xaxis() # Display 128 on left, 16 on right
# Ticks match the dimensions used by run_dimension_sweep (128, 112, ..., 16).
plt.xticks(list(range(128, 15, -16)))
plt.tight_layout()
plt.show()

# Save Results (the bare expression on the last line displays the table).
csv_filename = 'sift1m_1M_1000at10000_sweep_results.csv'
sweep_df.to_csv(csv_filename, index=False)
print(f"Sweep results saved to {csv_filename}")
sweep_df

In [None]:
# 13. PCA Dimension Difference Analysis
# Analyze absolute difference between Query and Top-100 NN in PCA space per dimension

def analyze_pca_diff(base_vectors, query_vectors, gt_indices, n_components=128):
    """
    Mean |query - true neighbour| per PCA component, in whitened PCA space.

    PCA (``whiten=True``) is fit on at most the first 50,000 base vectors. Every
    query and its ground-truth neighbours are projected, and the absolute
    per-component difference is averaged over all (query, neighbour) pairs.

    Only the database rows that actually occur in ``gt_indices`` are
    transformed, not the whole database.

    Args:
        base_vectors: (N, D) database array.
        query_vectors: (Q, D) query array.
        gt_indices: (Q, K) array of neighbour ids into ``base_vectors``.
        n_components: number of PCA components (default 128).

    Returns:
        (mean_abs_diff, explained_variance_ratio), both of length n_components.

    Raises:
        ValueError: if there are no queries or no neighbours.
    """
    gt_indices = np.asarray(gt_indices)
    n_queries = len(query_vectors)
    if n_queries == 0 or gt_indices.size == 0:
        raise ValueError("analyze_pca_diff needs at least one query and one neighbour")

    print("Fitting PCA...")
    fit_size = min(len(base_vectors), 50000)
    pca = PCA(n_components=n_components, whiten=True)
    pca.fit(base_vectors[:fit_size])

    print("Transforming data...")
    q_pca = pca.transform(query_vectors)
    # Project only the referenced database rows. local_idx maps each gt id to
    # its row in neighbor_pca; it is reshaped because the shape of np.unique's
    # inverse for N-d input differs between NumPy versions.
    needed_rows, local_idx = np.unique(gt_indices, return_inverse=True)
    local_idx = local_idx.reshape(gt_indices.shape)
    neighbor_pca = pca.transform(base_vectors[needed_rows])

    k_neighbors = gt_indices.shape[1]
    total_abs_diff = np.zeros(n_components)

    print("Calculating differences...")
    for i in tqdm(range(n_queries), desc="Analyzing Queries"):
        # |q - n| for all K neighbours of query i, summed over the neighbours.
        abs_diff = np.abs(neighbor_pca[local_idx[i]] - q_pca[i])
        total_abs_diff += abs_diff.sum(axis=0)

    mean_abs_diff = total_abs_diff / (n_queries * k_neighbors)
    return mean_abs_diff, pca.explained_variance_ratio_

# Inputs: reuse the 1M database / first-1000-query subset if an earlier cell
# created them (at notebook top level locals() is the global namespace).
if 'base_vectors_full' not in locals():
    base_vectors_full = base_vectors.astype(np.float32)
    query_vectors_subset = query_vectors[:1000].astype(np.float32)

# Top-100 ground truth: slice it from the sweep's GT when that is wide enough,
# otherwise compute it. Slicing assumes gt_sweep columns are nearest-first, as
# produced by calculate_top_k_ground_truth via torch.topk.
k_target = 100
if 'gt_sweep' in locals() and gt_sweep.shape[1] >= k_target:
    gt_100 = gt_sweep[:, :k_target]
else:
    print("Computing GT for Top-100...")
    gt_100 = calculate_top_k_ground_truth(query_vectors_subset, base_vectors_full, k=k_target)

# Run the analysis (analyze_pca_diff uses 128 components by default).
mean_diffs, explained_var = analyze_pca_diff(base_vectors_full, query_vectors_subset, gt_100)
# 1-based component numbers, derived from the result instead of hard-coding 128.
component_ids = np.arange(1, len(mean_diffs) + 1)

plt.figure(figsize=(15, 6))

# Left: mean |query - neighbour| per PCA component.
plt.subplot(1, 2, 1)
plt.bar(component_ids, mean_diffs, color='skyblue')
plt.xlabel('PCA Component (Dimension)')
plt.ylabel('Mean Absolute Difference')
plt.title(f'Mean Abs Diff between Query and Top-{k_target} NN per PCA Dimension')
plt.grid(axis='y', linestyle='--', alpha=0.7)

# Right: explained variance ratio per component, for context.
plt.subplot(1, 2, 2)
plt.plot(component_ids, explained_var, marker='.', linestyle='-', color='orange')
plt.xlabel('PCA Component')
plt.ylabel('Explained Variance Ratio')
plt.title('PCA Explained Variance Ratio')
plt.grid(True, linestyle='--', alpha=0.7)

plt.tight_layout()
plt.show()

# Save to CSV (the bare expression on the last line displays the first rows).
diff_df = pd.DataFrame({
    'Dimension': component_ids,
    'Mean_Abs_Diff': mean_diffs,
    'Explained_Variance': explained_var
})
diff_df.to_csv('sift1m_pca_128_diff_analysis.csv', index=False)
print("Analysis saved to sift1m_pca_128_diff_analysis.csv")
diff_df.head()