## 2. 导入必要的库

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.decomposition import PCA
from sklearn.metrics import pairwise_distances
import umap
import os
import time
import pandas as pd
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import urllib.request
import tarfile

# Select the compute device: CUDA when available, otherwise CPU
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

# Dimensionality settings: original SIFT vector size and the reduced size
SIFT_DIM = 128
REDUCED_DIM = 64

## 3. 下载和解压 SIFT1M 数据集

In [None]:
def _install_sift_archive(tar_path, data_dir, files):
    """Unpack the downloaded tarball and move the wanted files into ``data_dir``.

    The archive is extracted into ``data_dir``; the entries listed in ``files``
    are then moved out of the ``sift/`` sub-directory into ``data_dir`` itself,
    and both the sub-directory and the tarball are removed.
    """
    import shutil

    with tarfile.open(tar_path, "r:gz") as tar:
        if hasattr(tarfile, "data_filter"):
            # Extraction filters are available: refuse members that would
            # escape data_dir or create special files.
            tar.extractall(path=data_dir, filter="data")
        else:
            tar.extractall(path=data_dir)

    # After extraction the files live under a "sift/" sub-directory.
    extracted_dir = os.path.join(data_dir, "sift")
    if os.path.exists(extracted_dir):
        for filename in files:
            src = os.path.join(extracted_dir, filename)
            if os.path.exists(src):
                # os.replace overwrites an existing destination file.
                os.replace(src, os.path.join(data_dir, filename))
        # Cleanup is best effort: leftovers do not invalidate the dataset.
        shutil.rmtree(extracted_dir, ignore_errors=True)

    if os.path.exists(tar_path):
        os.remove(tar_path)


def download_sift1m():
    """Download the SIFT1M dataset into ``sift1m/`` unless it is already there.

    The download is attempted with ``urllib`` first and, if that fails for any
    reason, with the external ``wget`` command.

    Returns:
        True when the files were already present or a download path completed,
        False when both download attempts failed (manual instructions are
        printed in that case).
    """
    import socket
    import subprocess

    data_dir = 'sift1m'
    os.makedirs(data_dir, exist_ok=True)

    files = [
        "sift_base.fvecs",
        "sift_query.fvecs",
        "sift_groundtruth.ivecs"
    ]

    # Skip the download when every required file is already on disk.
    if all(os.path.exists(os.path.join(data_dir, name)) for name in files):
        print("所有文件已存在，跳过下载")
        return True

    print("正在下载 SIFT1M 数据集...")

    tar_url = "ftp://ftp.irisa.fr/local/texmex/corpus/sift.tar.gz"
    tar_path = "sift1m/sift.tar.gz"

    try:
        print(f"尝试从 {tar_url} 下载...")
        # Apply a 30 s socket timeout for this download only, then restore the
        # previous process-wide default so unrelated sockets are unaffected.
        previous_timeout = socket.getdefaulttimeout()
        socket.setdefaulttimeout(30)
        try:
            urllib.request.urlretrieve(tar_url, tar_path)
        finally:
            socket.setdefaulttimeout(previous_timeout)
        print("下载完成，正在解压...")

        _install_sift_archive(tar_path, data_dir, files)
        print("数据集准备完成")
        return True

    except Exception as e:
        # Deliberately broad: any failure of the first attempt triggers the fallback.
        print(f"FTP下载失败: {e}")
        print("尝试使用 wget 下载...")

        try:
            # Argument list instead of a shell string: no shell parsing involved.
            res = subprocess.run(["wget", tar_url, "-O", tar_path], check=False).returncode
            if res == 0 and os.path.exists(tar_path):
                print("wget 下载成功，正在解压...")
                _install_sift_archive(tar_path, data_dir, files)
                return True
        except Exception as e2:
            # Covers a missing wget binary as well as extraction errors.
            print(f"wget 失败: {e2}")

        print("无法自动下载数据集。")
        print("请手动下载 sift.tar.gz 从 http://corpus-texmex.irisa.fr/")
        print("并解压到 sift1m/ 目录下。")
        return False

# Download the dataset when this cell runs; the boolean result is not checked here
print("开始下载SIFT1M数据集...")
download_sift1m()

## 4. INT3 量化函数

In [None]:
def quantize_to_int3(data):
    """Quantize values to INT3 codes in [-4, 3] (8 discrete levels), stored as int8.

    The 1st and 99th percentiles of the input are used as the scaling range so
    that extreme values have less influence; values outside that range are
    clipped. For inputs with more than 100000 elements the percentiles are
    estimated from a uniformly strided sample of the flattened data.

    Args:
        data: a ``torch.Tensor`` (moved to ``DEVICE`` if it is elsewhere) or a
            NumPy array / array-like.

    Returns:
        An int8 tensor (for tensor input) or int8 NumPy array with the same
        shape as ``data``.
    """
    max_sample = 100000  # percentiles are estimated from roughly this many elements

    if isinstance(data, torch.Tensor):
        if data.device != DEVICE:
            data = data.to(DEVICE)

        arr = data.float()  # torch.quantile requires a floating dtype

        # reshape (unlike view) also accepts non-contiguous tensors such as
        # transposes, copying only when it has to.
        flat = arr.reshape(-1)
        if flat.numel() > max_sample:
            flat = flat[::max(1, flat.numel() // max_sample)]
        min_val = torch.quantile(flat, 0.01)
        max_val = torch.quantile(flat, 0.99)

        # Map [min_val, max_val] linearly onto [-4, 3]; epsilon avoids division by zero.
        scaled = (arr - min_val) / (max_val - min_val + 1e-8) * 7 - 4
        return torch.clamp(torch.round(scaled), -4, 3).to(torch.int8)

    arr = np.asarray(data)
    flat = arr.ravel()
    if flat.size > max_sample:
        flat = flat[::max(1, flat.size // max_sample)]
    min_val = np.percentile(flat, 1)
    max_val = np.percentile(flat, 99)

    scaled = (arr - min_val) / (max_val - min_val + 1e-8) * 7 - 4
    return np.clip(np.round(scaled), -4, 3).astype(np.int8)

class PerDimensionQuantileQuantizer:
    """Quantize to INT3 using per-dimension quantiles.

    ``fit`` computes, for every column, the seven thresholds at the
    1/8, 2/8, ..., 7/8 quantiles of the training data. ``transform`` maps each
    value to the number of thresholds it reaches (0-7) minus 4, i.e. a code in
    [-4, 3]. On the data used for fitting, the codes of each dimension are
    therefore spread approximately uniformly over the 8 bins.
    """

    def __init__(self):
        self.thresholds = None  # (7, D) tensor once fitted
        self.device = DEVICE

    def _prepare(self, data):
        """Return ``data`` as a float32 tensor on ``self.device``."""
        if isinstance(data, np.ndarray):
            data = torch.from_numpy(data)
        if data.device != self.device:
            data = data.to(self.device)
        return data.float()

    def fit(self, data):
        """Compute the per-dimension thresholds from 2-D ``data`` of shape (N, D).

        Returns:
            ``self``, so calls can be chained.

        Raises:
            RuntimeError: if ``torch.quantile`` fails and N <= 100000, so no
                subsampled retry is attempted.
        """
        data = self._prepare(data)
        num_rows, _ = data.shape  # also rejects inputs that are not 2-D

        q_vals = torch.tensor([0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875], device=self.device)

        try:
            self.thresholds = torch.quantile(data, q_vals, dim=0)  # (7, D)
        except RuntimeError:
            # torch.quantile can fail on large inputs; retry on a strided
            # sample of roughly 100000 rows.
            if num_rows > 100000:
                step = num_rows // 100000
                self.thresholds = torch.quantile(data[::step], q_vals, dim=0)
            else:
                raise
        return self

    def transform(self, data):
        """Map ``data`` of shape (N, D) to int8 codes in [-4, 3].

        Raises:
            ValueError: if the quantizer has not been fitted.
        """
        if self.thresholds is None:
            raise ValueError("Quantizer not fitted")

        data = self._prepare(data)

        # Start at -4 and add one for every threshold a value reaches.
        # Accumulating one threshold at a time needs only (N, D) temporaries
        # rather than a full (N, 7, D) comparison tensor.
        codes = torch.full(data.shape, -4, dtype=torch.int8, device=data.device)
        for threshold in self.thresholds:  # each threshold has shape (D,)
            codes += data >= threshold
        return codes

    def fit_transform(self, data):
        """Fit on ``data`` and return its quantized codes."""
        return self.fit(data).transform(data)

## 5. AutoEncoder 模型定义

In [None]:
# 5. 定义AutoEncoder模型和训练函数

class AutoEncoder(nn.Module):
    """Two-layer MLP autoencoder with a Tanh-bounded latent code.

    Args:
        input_dim: size of the input (and reconstructed) vectors.
        encoding_dim: size of the latent code.
        hidden_dim: width of the single hidden layer in both the encoder and
            the decoder (defaults to 96).
    """

    def __init__(self, input_dim, encoding_dim, hidden_dim=96):
        super().__init__()

        # Encoder: input -> hidden -> latent. Tanh keeps the code in [-1, 1],
        # a bounded range that is convenient for later quantization.
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(True),
            nn.Linear(hidden_dim, encoding_dim),
            nn.Tanh(),
        )

        # Decoder: latent -> hidden -> input. No output activation: the
        # reconstruction is a plain linear output.
        self.decoder = nn.Sequential(
            nn.Linear(encoding_dim, hidden_dim),
            nn.ReLU(True),
            nn.Linear(hidden_dim, input_dim),
        )

    def forward(self, x):
        """Return ``(encoded, decoded)`` for the batch ``x``."""
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return encoded, decoded

    def encode(self, x):
        """Return only the latent code for the batch ``x``."""
        return self.encoder(x)

def _grouped_pairwise_distances(encoded, group_size=6):
    """Return one pairwise L2 distance matrix per consecutive group of latent columns."""
    group_dists = []
    for start_idx in range(0, encoded.shape[1], group_size):
        # Slicing clamps at the last column, so the final group may be narrower.
        sub = encoded[:, start_idx:start_idx + group_size]
        group_dists.append(torch.cdist(sub, sub, p=2))
    return group_dists


def _threshold_penalty(group_dists, dist_input, threshold_config):
    """Penalize per-group latent distances above the target threshold for input-space neighbors."""
    T = threshold_config.get('target_threshold', 0.5)
    limit = threshold_config.get('input_limit', 150.0)

    # Pairs whose input-space distance is within the limit count as neighbors.
    neighbor_mask = (dist_input <= limit)
    if neighbor_mask.sum() == 0:
        return 0

    penalty_accum = 0.0
    for d_g in group_dists:
        d_val = d_g[neighbor_mask]
        # d * (exp(ReLU(d - T)) - 1): zero while d <= T, grows exponentially above T.
        excess = torch.relu(d_val - T)
        penalty_accum += (d_val * (torch.exp(excess) - 1.0)).mean()

    return penalty_accum * 0.1


def train_autoencoder(data, input_dim, encoding_dim, epochs=50, batch_size=256, 
                      distance_loss_weight=0.1, use_grouped_latent_dist=False, threshold_config=None):
    """Train an AutoEncoder with a reconstruction loss and a distance-preservation loss.

    Per batch the loss is
    ``(1 - w) * MSE(reconstruction, input) + w * MSE(latent_dist, input_dist) + threshold_term``
    where ``w`` is ``distance_loss_weight`` and the distances are pairwise L2
    distances within the batch.

    Args:
        data: training vectors, NumPy array or tensor of shape (N, D).
        input_dim: kept for interface compatibility; the value actually used is
            always ``data.shape[1]``.
        encoding_dim: size of the latent code.
        epochs: maximum number of epochs (early stopping may end training sooner).
        batch_size: mini-batch size.
        distance_loss_weight: weight ``w`` of the distance-preservation term.
        use_grouped_latent_dist: when True, the latent distance is the sum of
            L2 distances computed over consecutive groups of 6 latent columns
            instead of one L2 distance over the whole code.
        threshold_config: optional dict
            ``{'target_threshold': float, 'input_limit': float}``. Only used
            together with ``use_grouped_latent_dist``: for pairs with input
            distance <= ``input_limit`` it adds the penalty
            ``d_g * (exp(ReLU(d_g - target_threshold)) - 1)`` on each group
            distance ``d_g``, averaged per group and scaled by 0.1.

    Returns:
        ``(model, history)`` where ``history`` maps ``'loss'``,
        ``'recon_loss'`` and ``'dist_loss'`` to lists of per-epoch averages.
        The returned model holds the weights of the last epoch that ran.

    Raises:
        ValueError: if ``data`` contains no rows.
    """
    if len(data) == 0:
        # Without batches the per-epoch averages would divide by zero.
        raise ValueError("training data is empty")

    input_dim = data.shape[1]

    model = AutoEncoder(input_dim, encoding_dim).to(DEVICE)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Halve the learning rate when the epoch loss stops improving.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

    # Early-stopping parameters
    patience = 10
    min_delta = 1e-4

    # The whole training set is kept on DEVICE as float32.
    if isinstance(data, np.ndarray):
        data_tensor = torch.FloatTensor(data).to(DEVICE)
    else:
        # Cast as well: the model weights are float32, other dtypes would not match.
        data_tensor = data.to(DEVICE).float()

    dataset = torch.utils.data.TensorDataset(data_tensor)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

    model.train()

    best_loss = float('inf')
    patience_counter = 0

    history = {'loss': [], 'recon_loss': [], 'dist_loss': []}

    use_threshold = bool(threshold_config) and isinstance(threshold_config, dict)

    pbar = tqdm(range(epochs), desc="训练AutoEncoder")
    for epoch in pbar:
        total_loss = 0
        total_recon_loss = 0
        total_dist_loss = 0
        num_batches = 0

        for batch in dataloader:
            batch_data = batch[0]

            optimizer.zero_grad()
            encoded, reconstructed = model(batch_data)

            # 1. Reconstruction loss
            recon_loss = criterion(reconstructed, batch_data)

            # 2. Distance-preservation loss: compare pairwise distances within the batch
            dist_input = torch.cdist(batch_data, batch_data, p=2)

            threshold_term = 0

            if use_grouped_latent_dist:
                group_dists = _grouped_pairwise_distances(encoded)
                dist_latent = sum(group_dists)
                if use_threshold:
                    threshold_term = _threshold_penalty(group_dists, dist_input, threshold_config)
            else:
                dist_latent = torch.cdist(encoded, encoded, p=2)

            dist_loss = nn.functional.mse_loss(dist_latent, dist_input)

            loss = (1-distance_loss_weight) * recon_loss + distance_loss_weight * dist_loss + threshold_term

            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            total_recon_loss += recon_loss.item()
            total_dist_loss += dist_loss.item()
            num_batches += 1

        avg_loss = total_loss / num_batches
        avg_recon = total_recon_loss / num_batches
        avg_dist = total_dist_loss / num_batches

        history['loss'].append(avg_loss)
        history['recon_loss'].append(avg_recon)
        history['dist_loss'].append(avg_dist)

        pbar.set_postfix({
            'loss': f'{avg_loss:.4f}', 
            'recon': f'{avg_recon:.4f}', 
            'dist': f'{avg_dist:.4f}'
        })

        scheduler.step(avg_loss)

        # Early stopping on the total loss: reset on an improvement larger than min_delta.
        if avg_loss < best_loss - min_delta:
            best_loss = avg_loss
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print(f"Early stopping triggered at epoch {epoch+1} (Best Loss: {best_loss:.6f})")
            break

    return model, history

## 6. 五种数据处理方法

In [None]:
def apply_quantization(train_data, query_data, quantizer):
    """Quantize train (and optionally query) data with a function or a fitted-style object.

    A quantizer exposing both ``fit`` and ``transform`` is treated as stateful:
    it is fitted on ``train_data`` via ``fit_transform`` and then applied to
    ``query_data``. Anything else is called directly on each input.

    Returns:
        ``(train_q, query_q)``; ``query_q`` is None when ``query_data`` is None.
    """
    has_query = query_data is not None
    is_stateful = hasattr(quantizer, 'fit') and hasattr(quantizer, 'transform')

    if not is_stateful:
        # Stateless callable: train and query are quantized independently.
        quantized_train = quantizer(train_data)
        quantized_query = quantizer(query_data) if has_query else None
        return quantized_train, quantized_query

    # Stateful: fit on exactly the train data passed in (possibly already
    # dimension-reduced), then reuse the fitted state for the queries.
    quantized_train = quantizer.fit_transform(train_data)
    quantized_query = quantizer.transform(query_data) if has_query else None
    return quantized_train, quantized_query

def method1_direct_int3(data, query_data=None, quantizer=quantize_to_int3):
    """Method 1: quantize the vectors to INT3 directly, without dimensionality reduction.

    Returns:
        ``(None, db_codes, query_codes)`` when ``query_data`` is given,
        otherwise just ``db_codes``.
    """
    has_query = query_data is not None

    # NumPy inputs are turned into tensors on DEVICE; tensors pass through unchanged.
    if isinstance(data, np.ndarray):
        data = torch.from_numpy(data).to(DEVICE)
    if has_query and isinstance(query_data, np.ndarray):
        query_data = torch.from_numpy(query_data).to(DEVICE)

    db_codes, query_codes = apply_quantization(data, query_data, quantizer)

    return (None, db_codes, query_codes) if has_query else db_codes

def method2_average_pooling_int3(data, query_data=None, quantizer=quantize_to_int3):
    """Method 2: average each pair of adjacent dimensions, then quantize to INT3.

    The width is halved (e.g. 128 -> 64). The number of pairs is derived from
    the input width, so any even dimensionality is accepted; an odd width
    makes the reshape fail.

    Returns:
        ``(None, db_codes, query_codes)`` when ``query_data`` is given,
        otherwise just ``db_codes``.
    """
    def process(arr):
        if isinstance(arr, np.ndarray):
            arr = torch.from_numpy(arr).to(DEVICE)
        # (N, D) -> (N, D/2, 2): the last axis holds one adjacent pair
        pairs = arr.reshape(arr.shape[0], arr.shape[1] // 2, 2)
        return torch.mean(pairs.float(), dim=2)

    data_reduced = process(data)
    query_reduced = process(query_data) if query_data is not None else None

    db_q, q_q = apply_quantization(data_reduced, query_reduced, quantizer)

    if query_data is not None:
        return None, db_q, q_q
    return db_q

def method3_pca_int3(data, query_data=None, quantizer=quantize_to_int3):
    """Method 3: reduce to REDUCED_DIM dimensions with whitened PCA, then quantize to INT3.

    PCA is fitted with scikit-learn on the CPU; the projection of the data and
    query vectors is then done with torch on DEVICE.

    Returns:
        ``(pca, db_codes, query_codes)`` when ``query_data`` is given,
        otherwise ``(pca, db_codes)``.
    """
    print("  - 训练PCA (CPU)...")

    # scikit-learn needs a NumPy array for fitting
    if isinstance(data, torch.Tensor):
        data_np = data.cpu().numpy()
    else:
        data_np = data

    # whiten=True scales each component to unit variance
    pca = PCA(n_components=REDUCED_DIM, whiten=True)
    pca.fit(data_np)

    # Copy the fitted parameters to DEVICE once
    mean = torch.from_numpy(pca.mean_).float().to(DEVICE)
    components = torch.from_numpy(pca.components_).float().to(DEVICE)
    explained_variance = torch.from_numpy(pca.explained_variance_).float().to(DEVICE)
    whiten_scale = torch.sqrt(explained_variance) if pca.whiten else None

    def transform_gpu(X):
        if isinstance(X, np.ndarray):
            X = torch.from_numpy(X)
        # Tensors living on another device are moved too, so the arithmetic
        # below never mixes devices.
        X = X.to(DEVICE).float()

        # Center, then project onto the principal components
        X_transformed = torch.matmul(X - mean, components.T)

        if whiten_scale is not None:
            X_transformed = X_transformed / whiten_scale

        return X_transformed

    # Pass the original input: a tensor already on DEVICE needs no extra copy
    data_reduced = transform_gpu(data)
    query_reduced = transform_gpu(query_data) if query_data is not None else None

    db_q, q_q = apply_quantization(data_reduced, query_reduced, quantizer)

    if query_data is not None:
        return pca, db_q, q_q

    return pca, db_q

def method4_max_pooling_int3(data, query_data=None, quantizer=quantize_to_int3):
    """Method 4: max-magnitude pooling over adjacent pairs, then quantize to INT3.

    For each pair of adjacent dimensions, the original (signed) value with the
    larger absolute value is kept; ties keep the first element. The width is
    halved (e.g. 128 -> 64). The number of pairs is derived from the input
    width, so any even dimensionality is accepted.

    Returns:
        ``(None, db_codes, query_codes)`` when ``query_data`` is given,
        otherwise ``(None, db_codes)``.
    """

    def max_magnitude_pool_torch(arr):
        if isinstance(arr, np.ndarray):
            arr = torch.from_numpy(arr).to(DEVICE)

        # (N, D) -> (N, D/2, 2): the last axis holds one adjacent pair
        pairs = arr.reshape(arr.shape[0], arr.shape[1] // 2, 2)
        a = pairs[:, :, 0]
        b = pairs[:, :, 1]
        # Keep whichever element of the pair has the larger magnitude
        return torch.where(torch.abs(a) >= torch.abs(b), a, b)

    data_reduced = max_magnitude_pool_torch(data)
    query_reduced = max_magnitude_pool_torch(query_data) if query_data is not None else None

    db_q, q_q = apply_quantization(data_reduced, query_reduced, quantizer)

    if query_data is not None:
        return None, db_q, q_q

    return None, db_q

def method5_autoencoder_int3(data, query_data=None, epochs=100, quantizer=quantize_to_int3, distance_loss_weight=0.1, use_grouped_latent_dist=False, threshold_config=None):
    """Method 5: reduce to REDUCED_DIM dimensions with a trained AutoEncoder, then quantize to INT3.

    Returns:
        ``(ae_model, db_codes, query_codes, history)`` when ``query_data`` is
        given, otherwise ``(ae_model, db_codes, history)``.
    """
    print(f"  - 训练AutoEncoder (Dist Weight: {distance_loss_weight}, Grouped: {use_grouped_latent_dist}, Threshold: {threshold_config is not None})...")

    # Training is always fed a NumPy array
    training_input = data.cpu().numpy() if isinstance(data, torch.Tensor) else data

    # epochs is an upper bound; train_autoencoder applies early stopping
    ae_model, history = train_autoencoder(
        training_input,
        SIFT_DIM,
        REDUCED_DIM,
        epochs=epochs,
        distance_loss_weight=distance_loss_weight,
        use_grouped_latent_dist=use_grouped_latent_dist,
        threshold_config=threshold_config,
    )
    ae_model.eval()

    def to_device_float(vectors):
        # float32 tensor on DEVICE from either a NumPy array or a tensor
        if isinstance(vectors, np.ndarray):
            return torch.FloatTensor(vectors).to(DEVICE)
        return vectors.float().to(DEVICE)

    has_query = query_data is not None

    with torch.no_grad():
        data_reduced = ae_model.encode(to_device_float(data))
        query_reduced = ae_model.encode(to_device_float(query_data)) if has_query else None
        db_q, q_q = apply_quantization(data_reduced, query_reduced, quantizer)

    if has_query:
        return ae_model, db_q, q_q, history
    return ae_model, db_q, history

## 7. 距离计算和评估函数

In [None]:
def calculate_l2_distances(queries, database):
    """Return the matrix of Euclidean (L2) distances between queries and database rows."""
    distance_matrix = pairwise_distances(X=queries, Y=database, metric='euclidean')
    return distance_matrix

def evaluate_recall_at_k(distances, ground_truth, k_values=(1, 10, 100)):
    """Compute Recall@K from a full query-to-database distance matrix.

    For each query, the K database indices with the smallest distance are
    compared with the first ``min(K, len(gt))`` ground-truth indices; recall is
    the fraction of those ground-truth indices that were retrieved. A query
    with no ground truth contributes 0.0.

    Args:
        distances: array of shape (num_queries, num_database).
        ground_truth: per-query sequences of true neighbor indices, ordered
            nearest first; rows may have different lengths.
        k_values: iterable of K values to evaluate.

    Returns:
        dict mapping ``'recall@K'`` to the mean recall over all queries
        (0.0 when ``ground_truth`` is empty).
    """
    ranked_indices = np.argsort(distances, axis=1)

    recalls = {}
    for k in k_values:
        top_k_predictions = ranked_indices[:, :k]

        query_recalls = []
        for i, gt_i in enumerate(ground_truth):
            # Rows may be shorter than k
            limit = min(k, len(gt_i))
            if limit == 0:
                query_recalls.append(0.0)
                continue

            true_neighbors = set(gt_i[:limit])
            hits = true_neighbors.intersection(top_k_predictions[i])
            query_recalls.append(len(hits) / len(true_neighbors))

        # Avoid the NaN (and warning) np.mean produces for an empty list
        recalls[f'recall@{k}'] = np.mean(query_recalls) if query_recalls else 0.0

    return recalls

## 8. 加载 SIFT1M 数据集

In [None]:
def read_fvecs(filename):
    """Read a .fvecs file and return its vectors as a float32 array of shape (n, d).

    Each record is a 4-byte int32 dimension ``d`` followed by ``d`` float32
    values. The dimension is taken from the first record.
    """
    with open(filename, 'rb') as fh:
        raw = np.fromfile(fh, dtype=np.float32)

    # The first 4 bytes are really an int32 header; reinterpret them to get d
    dim = raw[:1].view(np.int32)[0]

    # Every row is [header, v_1 .. v_d]; drop the header column
    records = raw.reshape(-1, dim + 1)
    return records[:, 1:].copy()

def read_ivecs(filename):
    """Read a .ivecs file and return its vectors as an int32 array of shape (n, d).

    Each record is a 4-byte int32 dimension ``d`` followed by ``d`` int32
    values. The dimension is taken from the first record.
    """
    with open(filename, 'rb') as fh:
        raw = np.fromfile(fh, dtype=np.int32)

    dim = raw[0]

    # Every row is [header, v_1 .. v_d]; drop the header column
    records = raw.reshape(-1, dim + 1)
    return records[:, 1:].copy()

# Load the base vectors, query vectors and ground-truth neighbor indices from
# the files under sift1m/, then report the shape of each array.
print("载入SIFT1M数据集...")
base_vectors = read_fvecs('sift1m/sift_base.fvecs')
query_vectors = read_fvecs('sift1m/sift_query.fvecs')
ground_truth = read_ivecs('sift1m/sift_groundtruth.ivecs')

print(f"Base vectors shape: {base_vectors.shape}")
print(f"Query vectors shape: {query_vectors.shape}")
print(f"Ground truth shape: {ground_truth.shape}")

## 9. 运行实验

In [None]:
def calculate_top_k_ground_truth(queries, database, k=100):
    """Compute the indices of the k nearest database vectors for every query.

    Squared L2 distances are computed on DEVICE as |q|^2 + |x|^2 - 2 q.x, with
    queries processed 100 at a time and the database 50000 rows at a time to
    bound memory use. The top-k selection runs on the CPU over the
    concatenated distance rows.

    Returns:
        NumPy array of shape (num_queries, k) with neighbor indices, nearest first.
    """
    print(f"正在計算 Top-{k} Ground Truth (Query: {len(queries)}, DB: {len(database)})...")

    total_queries = len(queries)
    total_db = len(database)
    queries_per_step = 100   # queries handled per outer iteration
    db_rows_per_step = 50000  # database rows handled per inner iteration

    neighbor_blocks = []

    for q_start in tqdm(range(0, total_queries, queries_per_step), desc="Computing GT"):
        q_stop = min(q_start + queries_per_step, total_queries)
        q_tensor = torch.FloatTensor(queries[q_start:q_stop]).to(DEVICE)
        q_sq = torch.sum(q_tensor**2, dim=1, keepdim=True)

        partial_dists = []

        for db_start in range(0, total_db, db_rows_per_step):
            db_stop = min(db_start + db_rows_per_step, total_db)
            db_chunk = torch.FloatTensor(database[db_start:db_stop]).to(DEVICE)

            db_sq_chunk = torch.sum(db_chunk**2, dim=1)
            cross_term = -2 * torch.matmul(q_tensor, db_chunk.t())

            chunk_dists = q_sq + db_sq_chunk + cross_term
            # Keep only the CPU copy so device memory can be released per chunk
            partial_dists.append(chunk_dists.cpu())

            del db_chunk, db_sq_chunk, cross_term, chunk_dists
            torch.cuda.empty_cache()

        all_dists = torch.cat(partial_dists, dim=1)

        # Smallest distances first
        _, nearest = torch.topk(all_dists, k=k, dim=1, largest=False)
        neighbor_blocks.append(nearest.numpy())

        del all_dists, nearest, q_tensor, q_sq
        torch.cuda.empty_cache()

    return np.vstack(neighbor_blocks)

def evaluate_recall_batched(query_vectors, db_vectors, gt_top_k, retrieval_depths, k_true_neighbors):
    """Evaluate recall at several retrieval depths with batched distance computation on DEVICE.

    For every query, the ``r`` nearest database vectors (squared L2) are
    retrieved for each ``r`` in ``retrieval_depths`` and intersected with that
    query's row of ``gt_top_k``:

        recall@r = total hits / (num_queries * k_true_neighbors)

    Args:
        query_vectors: (num_queries, dim) NumPy array or tensor (any numeric dtype).
        db_vectors: (db_size, dim) NumPy array or tensor (any numeric dtype).
        gt_top_k: per-query ground-truth neighbor indices, aligned with the queries.
            NOTE(review): rows are assumed to hold exactly ``k_true_neighbors``
            indices; longer rows could yield recall above 1 — confirm with callers.
        retrieval_depths: iterable of retrieval depths ``r``.
        k_true_neighbors: number of true neighbors per query used as the denominator.

    Returns:
        dict with one ``'recall@r'`` entry per depth plus ``'search_time'``
        (seconds, including the hit counting) and ``'qps'``.
    """
    num_queries = len(query_vectors)
    query_batch_size = 100

    total_hits = {r: 0 for r in retrieval_depths}

    def _as_device_float(vectors):
        # float32 tensor on DEVICE; quantized integer codes must be cast for the matmul
        if isinstance(vectors, np.ndarray):
            vectors = torch.from_numpy(vectors)
        return vectors.to(DEVICE).float()

    query_vectors = _as_device_float(query_vectors)
    db_vectors = _as_device_float(db_vectors)

    # torch.topk raises when k exceeds the number of candidates, so cap the
    # depth at the database size; slicing below tolerates r > max_depth.
    max_depth = min(max(retrieval_depths), db_vectors.shape[0])

    start_time = time.time()

    # Squared norms of the database vectors are reused by every query batch
    db_sq = torch.sum(db_vectors**2, dim=1)

    for i in tqdm(range(0, num_queries, query_batch_size), desc="Evaluating Batches", leave=False):
        q_end = min(i + query_batch_size, num_queries)
        q_batch = query_vectors[i:q_end]  # (Batch, Dim)
        gt_batch = gt_top_k[i:q_end]

        # |x - y|^2 = |x|^2 + |y|^2 - 2 x.y
        q_sq = torch.sum(q_batch**2, dim=1, keepdim=True)  # (Batch, 1)
        term2 = -2 * torch.matmul(q_batch, db_vectors.t())  # (Batch, DB_Size)
        dists = q_sq + db_sq + term2

        # Indices of the max_depth smallest distances, nearest first
        _, sorted_indices_tensor = torch.topk(dists, k=max_depth, dim=1, largest=False)

        # Set intersection is done on the CPU
        sorted_indices = sorted_indices_tensor.cpu().numpy()

        # The ground-truth sets do not depend on r: build them once per batch
        gt_sets = [set(np.asarray(row).tolist()) for row in gt_batch]

        for r in retrieval_depths:
            retrieved_indices = sorted_indices[:, :r]
            for gt_set, retrieved_row in zip(gt_sets, retrieved_indices):
                total_hits[r] += len(gt_set.intersection(retrieved_row.tolist()))

        # Release the large intermediates before the next batch
        del dists, term2, q_sq, sorted_indices_tensor

    end_time = time.time()
    search_time = end_time - start_time

    # With zero queries the denominator is 0; report recall 0 instead of raising
    denominator = num_queries * float(k_true_neighbors)
    recalls = {}
    for r in retrieval_depths:
        recalls[f'recall@{r}'] = total_hits[r] / denominator if denominator else 0.0

    recalls['search_time'] = search_time
    recalls['qps'] = num_queries / search_time if search_time > 0 else 0

    return recalls

def method5_autoencoder_int3_with_timing(data, query_data=None, epochs=100):
    """Method 5 with timing: AutoEncoder reduction to REDUCED_DIM dimensions, then INT3 quantization.

    Returns:
        ``(ae_model, data_quantized, query_quantized, train_time, inference_time)``
        when ``query_data`` is given, otherwise
        ``(ae_model, data_quantized, train_time, inference_time)``.
        Times are in seconds; the inference time covers encoding and quantization.
    """
    print("  - 训练AutoEncoder...")

    # Training time
    train_start = time.time()
    # train_autoencoder returns (model, history); only the model is needed here
    ae_model, _history = train_autoencoder(data, SIFT_DIM, REDUCED_DIM, epochs=epochs)
    train_time = time.time() - train_start

    ae_model.eval()

    # Inference time
    inference_start = time.time()
    with torch.no_grad():
        data_tensor = torch.FloatTensor(data).to(DEVICE)
        # Codes are moved to the CPU, so quantization takes the NumPy path
        data_reduced = ae_model.encode(data_tensor).cpu().numpy()
        data_quantized = quantize_to_int3(data_reduced)

        if query_data is not None:
            query_tensor = torch.FloatTensor(query_data).to(DEVICE)
            query_reduced = ae_model.encode(query_tensor).cpu().numpy()
            query_quantized = quantize_to_int3(query_reduced)
            inference_time = time.time() - inference_start
            return ae_model, data_quantized, query_quantized, train_time, inference_time

    inference_time = time.time() - inference_start
    return ae_model, data_quantized, train_time, inference_time

def run_experiment_part(part_name, base_vectors, query_vectors, gt_top_k, retrieval_depths, k_true_neighbors, quantizer=quantize_to_int3, quantizer_name="Standard"):
    """
    Run the five reduction + INT3 pipelines on one database and score them.

    Args:
        part_name: Label printed in the banner.
        base_vectors: Database vectors handed to each method function.
        query_vectors: Query vectors handed to each method function.
        gt_top_k: Ground-truth neighbour indices, passed to
            evaluate_recall_batched.
        retrieval_depths: Retrieval depths, passed to evaluate_recall_batched.
        k_true_neighbors: Ground-truth size K, passed to
            evaluate_recall_batched.
        quantizer: Accepted for interface compatibility.
            NOTE(review): this body does not forward it to any method
            function, so it has no effect here -- confirm whether the method
            functions are supposed to receive it.
        quantizer_name: Label used in the log output only.

    Returns:
        ``(results_df, time_df)``, two DataFrames with one row per method.
        ``results_df`` has 'method', 'time' (wall-clock seconds for encoding
        plus evaluation) and every entry of the dict returned by
        evaluate_recall_batched. ``time_df`` has 'method', 'total_time',
        'search_time' and 'qps'; the last two are taken from
        evaluate_recall_batched, so they cover the search only and contain
        no AutoEncoder training time.
    """
    print(f"\n{'='*80}")
    print(f"實驗部分: {part_name} (Quantizer: {quantizer_name})")
    print(f"Database Size: {len(base_vectors)}")
    print(f"Retrieval Depths: {retrieval_depths}")
    print(f"Target GT Size (K): {k_true_neighbors}")
    print(f"{'='*80}\n")

    results = []
    # Previously this list was returned without ever being created (NameError).
    time_results = []

    def _score(label, db_q, q_q, t0):
        """Evaluate one method and append its row to both result lists."""
        recalls = evaluate_recall_batched(q_q, db_q, gt_top_k, retrieval_depths, k_true_neighbors)
        elapsed = time.time() - t0
        results.append({'method': label, 'time': elapsed, **recalls})
        time_results.append({
            'method': label,
            'total_time': elapsed,
            'search_time': recalls.get('search_time'),
            'qps': recalls.get('qps'),
        })
        print(recalls)

    # Method 1
    print(f"\n[方法1] 直接INT3量化 (128維) - {quantizer_name}...")
    t0 = time.time()
    db_m1 = method1_direct_int3(base_vectors)
    q_m1 = method1_direct_int3(query_vectors)
    _score('Method 1: Direct INT3', db_m1, q_m1, t0)

    # Method 2
    print(f"\n[方法2] 平均池化降維 + INT3 (64維) - {quantizer_name}...")
    t0 = time.time()
    db_m2 = method2_average_pooling_int3(base_vectors)
    q_m2 = method2_average_pooling_int3(query_vectors)
    _score('Method 2: AvgPooling + INT3', db_m2, q_m2, t0)

    # Method 3
    print(f"\n[方法3] PCA降維 + INT3 (64維) - {quantizer_name}...")
    t0 = time.time()
    _, db_m3, q_m3 = method3_pca_int3(base_vectors, query_vectors)
    _score('Method 3: PCA + INT3', db_m3, q_m3, t0)

    # Method 4
    print(f"\n[方法4] Max Magnitude Pooling降維 + INT3 (64維) - {quantizer_name}...")
    t0 = time.time()
    _, db_m4, q_m4 = method4_max_pooling_int3(base_vectors, query_vectors)
    _score('Method 4: Max Mag Pooling + INT3', db_m4, q_m4, t0)

    # Method 5
    print("\n[方法5] AutoEncoder降維 + INT3 (64維)...")
    t0 = time.time()
    # Epochs set to 100, relying on early stopping.
    ae_out = method5_autoencoder_int3(base_vectors, query_vectors, epochs=100)
    # Index instead of unpacking: another call site in this notebook unpacks a
    # fourth element (training history) from the same function, so a fixed
    # 3-way unpack would fail if that is the active definition.
    db_m5, q_m5 = ae_out[1], ae_out[2]
    _score('Method 5: AutoEncoder + INT3', db_m5, q_m5, t0)

    return pd.DataFrame(results), pd.DataFrame(time_results)

In [None]:
# Run the experiments.
# Configuration:
# Part 1: 100K DB, Recall 100@1000, 100@5000, 100@10000
# Part 2: 1M DB, Recall 100@1000, 100@5000, 100@10000

all_experiment_results, all_time_results = [], []  # recall frames / timing frames

# --- Part 1: 100K Database ---
print("\n" + "="*80)
print("Part 1: First 100K Database")
print("="*80)

db_100k = base_vectors[:100000].astype(np.float32)
query_subset = query_vectors[:1000].astype(np.float32)

# Ground truth restricted to the 100K database.
print("Computing Ground Truth for 100K DB...")
gt_100k = calculate_top_k_ground_truth(query_subset, db_100k, k=100)

# Run Part 1.
retrieval_depths_1 = [1000, 5000, 10000]
k_true_neighbors = 100

results_df_1, time_df_1 = run_experiment_part(
    "Part 1: 100K DB", db_100k, query_subset, gt_100k, retrieval_depths_1, k_true_neighbors
)
all_experiment_results += [results_df_1]
all_time_results += [time_df_1]

for caption, frame in (
    ("\nPart 1 Recall Results:", results_df_1),
    ("\nPart 1 Time Results (不含 AE 訓練時間):", time_df_1),
):
    print(caption)
    display(frame)

# --- Part 2: 1M Database ---
print("\n" + "="*80)
print("Part 2: Full 1M Database")
print("="*80)

db_1m = base_vectors.astype(np.float32)

# Ground truth over the full database.
print("Computing Ground Truth for 1M DB...")
gt_1m = calculate_top_k_ground_truth(query_subset, db_1m, k=100)

# Run Part 2.
retrieval_depths_2 = [1000, 5000, 10000]

results_df_2, time_df_2 = run_experiment_part(
    "Part 2: 1M DB", db_1m, query_subset, gt_1m, retrieval_depths_2, k_true_neighbors
)
all_experiment_results += [results_df_2]
all_time_results += [time_df_2]

for caption, frame in (
    ("\nPart 2 Recall Results:", results_df_2),
    ("\nPart 2 Time Results (不含 AE 訓練時間):", time_df_2),
):
    print(caption)
    display(frame)

print("\n" + "="*80)
print("所有實驗完成！")
print("="*80)

In [None]:
# Run the experiments.
# Configuration:
# Part 1: 100K DB, Recall 100@1000, 100@5000, 100@10000
# Part 2: 1M DB, Recall 100@1000, 100@5000, 100@10000
# Two quantization schemes are compared: Standard (Percentile) vs
# Quantile (Per-Dimension Rank-based).

all_experiment_results = []  # one results DataFrame per (part, quantizer) run

# Quantizers to compare.
# Note: for class-based quantizers an instance is passed; the author's intent
# is that run_experiment_part reuses it (each fit overwriting the last).
quantizers = [
    ("Standard", quantize_to_int3),
    ("Quantile", PerDimensionQuantileQuantizer())
]

# --- Part 1: 100K Database ---
print("\n" + "="*80)
print("Part 1: First 100K Database")
print("="*80)

db_100k = base_vectors[:100000].astype(np.float32)
query_subset = query_vectors[:1000].astype(np.float32)

# Ground truth restricted to the 100K database.
print("Computing Ground Truth for 100K DB...")
gt_100k = calculate_top_k_ground_truth(query_subset, db_100k, k=100)

# Run Part 1 for both quantizers.
retrieval_depths_1 = [1000, 5000, 10000]
k_true_neighbors = 100

for q_name, q_obj in quantizers:
    outcome = run_experiment_part(
        f"Part 1: 100K DB ({q_name})",
        db_100k,
        query_subset,
        gt_100k,
        retrieval_depths_1,
        k_true_neighbors,
        quantizer=q_obj,
        quantizer_name=q_name
    )
    # run_experiment_part returns (results_df, time_df); keep only the results
    # frame so that all_experiment_results holds DataFrames, not tuples.
    results_df = outcome[0] if isinstance(outcome, tuple) else outcome
    all_experiment_results.append(results_df)
    print(f"\nPart 1 ({q_name}) Results:")
    display(results_df)


# --- Part 2: 1M Database ---
print("\n" + "="*80)
print("Part 2: Full 1M Database")
print("="*80)

db_1m = base_vectors.astype(np.float32)

# Ground truth over the full database.
print("Computing Ground Truth for 1M DB...")
gt_1m = calculate_top_k_ground_truth(query_subset, db_1m, k=100)

# Run Part 2 for both quantizers.
retrieval_depths_2 = [1000, 5000, 10000]

for q_name, q_obj in quantizers:
    outcome = run_experiment_part(
        f"Part 2: 1M DB ({q_name})",
        db_1m,
        query_subset,
        gt_1m,
        retrieval_depths_2,
        k_true_neighbors,
        quantizer=q_obj,
        quantizer_name=q_name
    )
    # Same normalisation as in Part 1.
    results_df = outcome[0] if isinstance(outcome, tuple) else outcome
    all_experiment_results.append(results_df)
    print(f"\nPart 2 ({q_name}) Results:")
    display(results_df)

print("\n" + "="*80)
print("所有實驗完成！")
print("="*80)

## 10. 可视化结果

In [None]:
# Visualise the recall results.
# One subplot is drawn per entry of all_experiment_results (however many
# parts / quantizer runs were executed); within each subplot every
# 'recall@<depth>' column becomes one bar group per method.

if not all_experiment_results:
    print("No results found.")
else:
    n_parts = len(all_experiment_results)
    # squeeze=False keeps `axes` 2-D even when there is a single subplot.
    fig, axes = plt.subplots(1, n_parts, figsize=(max(7, 6.5 * n_parts), 6), squeeze=False)
    fig.suptitle('Recall of Top-100 True Neighbors @ Different Retrieval Depths', fontsize=16, fontweight='bold')
    palette = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd']

    for part_no, (ax, df) in enumerate(zip(axes[0], all_experiment_results), start=1):
        # Tick labels come from this frame (not from the first one), e.g.
        # "Method 1: Direct INT3" -> "M1\nDirect INT3".
        tick_labels = (
            df['method']
            .str.replace('Method ', 'M', regex=False)
            .str.replace(': ', '\n', n=1, regex=False)
        )
        depth_cols = [c for c in df.columns if str(c).startswith('recall@')]
        x = np.arange(len(df))
        width = 0.8 / max(len(depth_cols), 1)

        for k, col in enumerate(depth_cols):
            # Centre the group of bars on each tick.
            offset = (k - (len(depth_cols) - 1) / 2) * width
            bars = ax.bar(x + offset, df[col], width,
                          label=col.replace('recall', ''),
                          color=palette[k % len(palette)])
            for bar in bars:
                height = bar.get_height()
                ax.text(bar.get_x() + bar.get_width() / 2., height + 0.01,
                        f'{height:.3f}', ha='center', va='bottom', fontsize=7)

        ax.set_xticks(x)
        ax.set_xticklabels(tick_labels, rotation=45, ha='right', fontsize=9)
        ax.set_ylabel('Recall', fontsize=12, fontweight='bold')
        ax.set_title(f'Experiment Part {part_no}', fontsize=14, fontweight='bold')
        ax.set_ylim([0, 1.05])
        if depth_cols:
            ax.legend()
        ax.grid(axis='y', linestyle='--', alpha=0.7)

    plt.tight_layout()
    # File name kept unchanged so existing references to the image still work.
    plt.savefig('sift1m_3parts_results.png', dpi=300, bbox_inches='tight')
    plt.show()

## 11. 实验总结

In [None]:
print("\n" + "="*80)
print("實驗總結")
print("="*80)

# Backwards compatibility: fall back to the older variable name `all_results`.
if 'all_results' in locals() and 'all_experiment_results' not in locals():
    all_experiment_results = all_results

if 'all_experiment_results' not in locals():
    print("No results found.")
else:
    for part_no, results_df in enumerate(all_experiment_results, start=1):
        print(f"\nExperiment Part {part_no}")
        display(results_df)

print("\n所有實驗完成！")

In [None]:
# 12. Dimension Reduction Analysis (128 -> 16)
# Analyzing Recall Rate vs Dimension for AvgPooling and PCA
# Configuration: 1M Database, Recall of Top-1000 GT @ Depth 10000

def adaptive_avg_pool_int3(data, target_dim):
    """
    Reduce vectors to `target_dim` with adaptive average pooling, then quantize.

    Args:
        data: NumPy array or torch tensor; the shape comments in this cell
            assume (N, 128).
        target_dim: Output dimensionality of the pooling step.

    Returns:
        Whatever quantize_to_int3 returns for the pooled (N, target_dim) tensor.
    """
    tensor = torch.from_numpy(data) if isinstance(data, np.ndarray) else data
    tensor = tensor.float().to(DEVICE)

    # adaptive_avg_pool1d pools the last axis, so add a singleton channel axis:
    # (N, D) -> (N, 1, D) -> (N, 1, target_dim) -> (N, target_dim)
    pooled = nn.functional.adaptive_avg_pool1d(tensor.unsqueeze(1), target_dim).squeeze(1)

    return quantize_to_int3(pooled)

def run_dimension_sweep(base_vectors, query_vectors, gt_top_k, k_true_neighbors=1000, retrieval_depth=10000):
    """
    Sweep the target dimension (128 down to 16 in steps of 16) and record recall.

    For every dimension two pipelines are evaluated with
    evaluate_recall_batched at a single retrieval depth:
    adaptive average pooling + INT3, and whitened PCA (fitted on at most the
    first 50,000 base vectors) + INT3.

    Args:
        base_vectors: Database vectors.
        query_vectors: Query vectors.
        gt_top_k: Ground-truth neighbour indices for the queries.
        k_true_neighbors: Ground-truth size used to normalise recall.
        retrieval_depth: Retrieval depth to evaluate.

    Returns:
        DataFrame with columns 'Dimension', 'Method', 'Recall', 'Time'
        (two rows per dimension; 'Time' is wall-clock seconds including
        the evaluation).
    """
    # Dimensions: 128, 112, 96, ..., 16
    target_dims = list(range(128, 15, -16))
    metric_key = f'recall@{retrieval_depth}'
    rows = []

    def _evaluate(dim, method, tag, db_q, q_q, started):
        # Score one quantized pipeline and log it under its short tag.
        recalls = evaluate_recall_batched(q_q, db_q, gt_top_k, [retrieval_depth], k_true_neighbors)
        elapsed = time.time() - started
        rows.append({
            'Dimension': dim,
            'Method': method,
            'Recall': recalls[metric_key],
            'Time': elapsed
        })
        print(f"    {tag} {metric_key}: {recalls[metric_key]:.4f}")

    print(f"Starting Dimension Sweep: {target_dims}")
    print(f"Database Size: {len(base_vectors)}")
    print(f"Target GT Size: {k_true_neighbors}")
    print(f"Retrieval Depth: {retrieval_depth}")

    for dim in target_dims:
        print(f"\nTesting Target Dimension: {dim}")

        # --- Adaptive average pooling ---
        print("  Running Avg Pooling...")
        started = time.time()
        db_avg = adaptive_avg_pool_int3(base_vectors, dim)
        q_avg = adaptive_avg_pool_int3(query_vectors, dim)
        _evaluate(dim, 'AvgPooling', 'AvgPool', db_avg, q_avg, started)

        # --- PCA ---
        print("  Running PCA...")
        started = time.time()
        # Fit on a prefix of the database (max 50k rows) to keep fitting fast.
        fit_size = min(len(base_vectors), 50000)
        pca = PCA(n_components=dim, whiten=True)
        pca.fit(base_vectors[:fit_size])

        db_pca_q = quantize_to_int3(pca.transform(base_vectors))
        q_pca_q = quantize_to_int3(pca.transform(query_vectors))
        _evaluate(dim, 'PCA', 'PCA', db_pca_q, q_pca_q, started)

    return pd.DataFrame(rows)

# Prepare the data for the sweep (full 1M database).
# The float32 copies are created only when `base_vectors_full` is absent.
if 'base_vectors_full' not in locals():
    base_vectors_full = base_vectors.astype(np.float32)
    query_vectors_subset = query_vectors[:1000].astype(np.float32)

# Full 1M database, Top-1000 ground truth, retrieval depth 10000.
db_subset_sweep = base_vectors_full
k_gt, depth = 1000, 10000

print(f"Computing GT for sweep (1M DB, Top-{k_gt})...")
# Always recomputed: earlier cells computed ground truth with k=100 only.
gt_sweep = calculate_top_k_ground_truth(query_vectors_subset, db_subset_sweep, k=k_gt)

# Run the sweep.
sweep_df = run_dimension_sweep(
    db_subset_sweep,
    query_vectors_subset,
    gt_sweep,
    k_true_neighbors=k_gt,
    retrieval_depth=depth,
)

# Plot recall against dimension, one line per method.
plt.figure(figsize=(10, 6))
avg_df = sweep_df[sweep_df['Method'] == 'AvgPooling']
pca_df = sweep_df[sweep_df['Method'] == 'PCA']

for frame, marker, legend in (
    (avg_df, 'o', 'Avg Pooling + INT3'),
    (pca_df, 's', 'PCA + INT3'),
):
    plt.plot(frame['Dimension'], frame['Recall'], marker=marker, linewidth=2, label=legend)

plt.xlabel('Dimension', fontsize=12)
plt.ylabel(f'Recall@{depth} (Top-{k_gt} GT)', fontsize=12)
plt.title(f'Recall vs Dimension (1M DB, {k_gt}@{depth})', fontsize=14)
plt.legend(fontsize=11)
plt.grid(True, linestyle='--', alpha=0.7)
plt.gca().invert_xaxis()  # 128 on the left, 16 on the right
plt.xticks(list(range(128, 15, -16)))
plt.tight_layout()
plt.show()

# Persist the sweep table.
csv_filename = 'sift1m_1M_1000at10000_sweep_results.csv'
sweep_df.to_csv(csv_filename, index=False)
print(f"Sweep results saved to {csv_filename}")
sweep_df

In [None]:
# 13. PCA Dimension Difference Analysis
# Analyze absolute difference between Query and Top-100 NN in PCA space per dimension

def analyze_pca_diff(base_vectors, query_vectors, gt_indices, n_components=128):
    """
    Mean absolute per-component difference between queries and their true
    neighbours in whitened PCA space.

    Args:
        base_vectors: Database vectors; PCA is fitted on at most the first
            50,000 rows and then applied to all of them.
        query_vectors: Query vectors.
        gt_indices: 2-D index array, one row of neighbour indices (into
            base_vectors) per query.
        n_components: Number of PCA components.

    Returns:
        (mean_abs_diff, explained_variance_ratio): the mean |query - neighbour|
        per PCA component over all (query, neighbour) pairs, and the fitted
        PCA's ``explained_variance_ratio_``.
    """
    print("Fitting PCA...")
    # Fit on a prefix of the database for speed.
    pca = PCA(n_components=n_components, whiten=True)
    pca.fit(base_vectors[:min(len(base_vectors), 50000)])

    print("Transforming data...")
    # The whole database is transformed rather than only the rows referenced
    # by gt_indices (the original author judged this affordable in memory).
    q_pca = pca.transform(query_vectors)
    db_pca = pca.transform(base_vectors)

    n_queries = len(query_vectors)
    k_neighbors = gt_indices.shape[1]

    # Running per-component sum of absolute differences.
    total_abs_diff = np.zeros(n_components)

    print("Calculating differences...")
    for qi in tqdm(range(n_queries), desc="Analyzing Queries"):
        # (k, n_components) neighbours minus the (n_components,) query.
        deltas = np.abs(db_pca[gt_indices[qi]] - q_pca[qi])
        total_abs_diff += deltas.sum(axis=0)

    # Every query contributes k_neighbors pairs.
    count = n_queries * k_neighbors
    mean_abs_diff = total_abs_diff / count

    return mean_abs_diff, pca.explained_variance_ratio_

# Prepare data (created only when `base_vectors_full` is absent).
if 'base_vectors_full' not in locals():
    base_vectors_full = base_vectors.astype(np.float32)
    query_vectors_subset = query_vectors[:1000].astype(np.float32)

# Obtain Top-100 ground truth: reuse a wide enough `gt_sweep`, else compute it.
k_target = 100
if not ('gt_sweep' in locals() and gt_sweep.shape[1] >= k_target):
    print("Computing GT for Top-100...")
    gt_100 = calculate_top_k_ground_truth(query_vectors_subset, base_vectors_full, k=k_target)
else:
    gt_100 = gt_sweep[:, :k_target]

# Run the analysis.
mean_diffs, explained_var = analyze_pca_diff(base_vectors_full, query_vectors_subset, gt_100)

# Visualisation: two panels side by side.
plt.figure(figsize=(15, 6))
component_ids = range(1, 129)

# Panel 1: mean absolute difference per component.
plt.subplot(1, 2, 1)
plt.bar(component_ids, mean_diffs, color='skyblue')
plt.xlabel('PCA Component (Dimension)')
plt.ylabel('Mean Absolute Difference')
plt.title('Mean Abs Diff between Query and Top-100 NN per PCA Dimension')
plt.grid(axis='y', linestyle='--', alpha=0.7)

# Panel 2: explained variance ratio, for context.
plt.subplot(1, 2, 2)
plt.plot(component_ids, explained_var, marker='.', linestyle='-', color='orange')
plt.xlabel('PCA Component')
plt.ylabel('Explained Variance Ratio')
plt.title('PCA Explained Variance Ratio')
plt.grid(True, linestyle='--', alpha=0.7)

plt.tight_layout()
plt.show()

# Save the per-component table.
diff_df = pd.DataFrame({
    'Dimension': range(1, 129),
    'Mean_Abs_Diff': mean_diffs,
    'Explained_Variance': explained_var
})
diff_df.to_csv('sift1m_pca_128_diff_analysis.csv', index=False)
print("Analysis saved to sift1m_pca_128_diff_analysis.csv")
diff_df.head()

## 14. 降維後資料分佈分析 (尚未量化)
分析各種降維方法處理後，但尚未進行 INT3 量化前的資料分佈情形。
這有助於了解不同方法產生的數值範圍與分佈特性，進而評估量化策略的適用性。
- **Overall Histogram**: 所有維度數值的整體分佈。
- **Per-Dimension Heatmap**: 每一維度的數值分佈 (X軸為數值，Y軸為維度索引，顏色為頻率)。

In [None]:
# 14. Data Distribution Analysis (Before Quantization)

def analyze_distribution(data, method_name, save_prefix):
    """
    Plot, save and show the value distribution of a 2-D data matrix.

    Two figures are produced:
    1. ``{save_prefix}_overall_hist.png``: 100-bin histogram of all values.
    2. ``{save_prefix}_dim_heatmap.png``: one 100-bin histogram per column
       (shared bin edges spanning the global min..max), stacked into a
       heatmap with value on the x axis and dimension index on the y axis.
       Each row is divided by its own maximum count.

    Args:
        data: 2-D NumPy array or torch tensor (tensors are moved to CPU).
        method_name: Text used in titles and log lines.
        save_prefix: File-name prefix for the two PNG files.
    """
    values = data.cpu().numpy() if isinstance(data, torch.Tensor) else data

    _, n_dims = values.shape
    print(f"Analyzing {method_name}: Shape {values.shape}")

    # --- Figure 1: histogram over every value (100 bins) ---
    plt.figure(figsize=(10, 6))
    plt.hist(values.flatten(), bins=100, color='skyblue', edgecolor='black', alpha=0.7)
    plt.title(f'{method_name} - Overall Value Distribution')
    plt.xlabel('Value')
    plt.ylabel('Frequency')
    plt.grid(True, linestyle='--', alpha=0.5)

    filename_hist = f'{save_prefix}_overall_hist.png'
    plt.savefig(filename_hist, dpi=300, bbox_inches='tight')
    plt.show()
    print(f"Saved {filename_hist}")

    # --- Figure 2: per-dimension histograms stacked into a heatmap ---
    # Shared bin edges across dimensions so rows are directly comparable.
    v_min, v_max = np.min(values), np.max(values)
    edges = np.linspace(v_min, v_max, 101)  # 101 edges -> 100 bins

    hist_matrix = np.array(
        [np.histogram(values[:, d], bins=edges)[0] for d in range(n_dims)],
        dtype=float,
    )

    # Scale each row by its peak so the plot shows distribution shape,
    # not absolute counts; the epsilon guards against an all-zero row.
    row_peaks = hist_matrix.max(axis=1, keepdims=True)
    hist_matrix_norm = hist_matrix / (row_peaks + 1e-9)

    plt.figure(figsize=(12, max(6, n_dims/4)))  # height grows with dimension count
    # extent = [left, right, bottom, top]
    plt.imshow(hist_matrix_norm, aspect='auto', origin='lower',
               extent=[v_min, v_max, 0, n_dims], cmap='viridis')
    plt.colorbar(label='Normalized Frequency')
    plt.title(f'{method_name} - Per-Dimension Distribution Heatmap')
    plt.xlabel('Value')
    plt.ylabel('Dimension Index')

    filename_heatmap = f'{save_prefix}_dim_heatmap.png'
    plt.savefig(filename_heatmap, dpi=300, bbox_inches='tight')
    plt.show()
    print(f"Saved {filename_heatmap}")

# Prepare data: a 10000-vector subset keeps the analysis fast.
subset_size = 10000
# base_vectors must already be loaded by an earlier cell.
if 'base_vectors' in locals():
    data_subset = base_vectors[:subset_size].astype(np.float32)
    data_tensor = torch.from_numpy(data_subset).to(DEVICE)

    print(f"Using subset of {subset_size} vectors for distribution analysis.")

    # --- Method 1: Original (128 dims) ---
    print("\nAnalyzing Method 1 (Original)...")
    analyze_distribution(data_subset, "Method 1 (Original 128D)", "dist_m1")

    # --- Method 2: Avg Pooling (64 dims) ---
    # Adjacent pairs of components are averaged.
    print("\nAnalyzing Method 2 (Avg Pooling)...")
    reshaped = data_tensor.reshape(len(data_tensor), 64, 2)
    m2_data = reshaped.mean(dim=2)
    analyze_distribution(m2_data, "Method 2 (Avg Pooling 64D)", "dist_m2")

    # --- Method 3: PCA (64 dims) ---
    print("\nAnalyzing Method 3 (PCA)...")
    pca = PCA(n_components=64, whiten=True)
    m3_data = pca.fit_transform(data_subset)
    analyze_distribution(m3_data, "Method 3 (PCA 64D)", "dist_m3")

    # --- Method 4: Max Mag Pooling (64 dims) ---
    # From each adjacent pair keep the element with the larger magnitude
    # (the first one on ties).
    print("\nAnalyzing Method 4 (Max Mag Pooling)...")
    reshaped = data_tensor.reshape(len(data_tensor), 64, 2)
    a, b = reshaped[:, :, 0], reshaped[:, :, 1]
    mask = a.abs() >= b.abs()
    m4_data = torch.where(mask, a, b)
    analyze_distribution(m4_data, "Method 4 (Max Mag Pooling 64D)", "dist_m4")

    # --- Method 5: AutoEncoder (64 dims) ---
    print("\nAnalyzing Method 5 (AutoEncoder)...")
    # A throw-away AutoEncoder trained only for this analysis.
    ae_model = train_autoencoder(data_subset, 128, 64, epochs=50)
    ae_model.eval()
    with torch.no_grad():
        m5_data = ae_model.encode(data_tensor)
    analyze_distribution(m5_data, "Method 5 (AutoEncoder 64D)", "dist_m5")
else:
    print("Error: base_vectors not found. Please run previous cells to load data.")

## 15. AutoEncoder Loss Comparison Experiment
比較有無加入距離保留損失 (Distance Preservation Loss) 的 AutoEncoder 訓練過程與最終效果。
- **Baseline**: Distance Loss Weight = 0.0
- **Proposed**: Distance Loss Weight = 0.1 (or other value)
使用 100K Database 進行快速驗證。

In [None]:
# 15. AutoEncoder Loss Comparison Experiment

def run_ae_loss_comparison(base_vectors, query_vectors, gt_top_k, retrieval_depths, k_true_neighbors, quantizer=quantize_to_int3):
    """
    Compare AutoEncoder training with and without the distance-preservation loss.

    For each distance-loss weight in [0.0, 0.1] an AutoEncoder pipeline is run
    (50 epochs) on the first 100,000 base vectors and first 1,000 queries, its
    recall is evaluated, and the loss histories are plotted in three panels
    (total / reconstruction / distance), saved to 'ae_loss_comparison.png'.

    Args:
        base_vectors: Full database; only the first 100,000 rows are used.
        query_vectors: Queries; only the first 1,000 rows are used.
        gt_top_k: Ground truth for the queries. It is not recomputed here, so
            it is assumed to refer to that same 100K subset.
        retrieval_depths: Depths passed to evaluate_recall_batched.
        k_true_neighbors: Ground-truth size passed to evaluate_recall_batched.
        quantizer: Forwarded to method5_autoencoder_int3.

    Returns:
        DataFrame with one row per weight: 'Method' plus the entries returned
        by evaluate_recall_batched.
    """
    print("Running AutoEncoder Loss Comparison Experiment...")

    # 0.0 = no distance loss (baseline), 0.1 = with distance loss.
    weights = [0.0, 0.1]
    rows = []
    histories = {}

    # 100K subset keeps training and evaluation fast.
    db_subset = base_vectors[:100000].astype(np.float32)
    query_subset = query_vectors[:1000].astype(np.float32)

    for w in weights:
        label = f"AE (w={w})"
        print(f"\nTraining {label}...")

        # The 4th returned element is the training history (per-epoch losses).
        _, db_q, q_q, history = method5_autoencoder_int3(
            db_subset,
            query_subset,
            epochs=50,
            quantizer=quantizer,
            distance_loss_weight=w
        )
        histories[label] = history

        print(f"Evaluating {label}...")
        recalls = evaluate_recall_batched(q_q, db_q, gt_top_k, retrieval_depths, k_true_neighbors)

        rows.append({'Method': label, **recalls})
        print(f"Result: {recalls}")

    # --- Loss curves: one panel per history key ---
    panels = (
        ('loss', 'Total Loss vs Epoch', 'Loss'),
        ('recon_loss', 'Reconstruction Loss vs Epoch', 'MSE Loss'),
        ('dist_loss', 'Distance Preservation Loss vs Epoch', 'MSE Loss'),
    )
    plt.figure(figsize=(18, 5))
    for position, (key, title, ylabel) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        for label, hist in histories.items():
            plt.plot(hist[key], label=label)
        plt.title(title)
        plt.xlabel('Epoch')
        plt.ylabel(ylabel)
        plt.legend()
        plt.grid(True, linestyle='--', alpha=0.5)

    plt.tight_layout()
    plt.savefig('ae_loss_comparison.png')
    plt.show()

    # --- Results table ---
    df_res = pd.DataFrame(rows)
    print("\nComparison Results (Recall):")
    display(df_res)
    return df_res

# Run the comparison.
# Ground truth for the 100K database is required: reuse `gt_100k` / `db_100k`
# from earlier cells when both exist, otherwise compute them first.
if not ('gt_100k' in locals() and 'db_100k' in locals()):
    print("Calculating GT for 100K DB...")
    db_100k = base_vectors[:100000].astype(np.float32)
    query_subset = query_vectors[:1000].astype(np.float32)
    gt_100k = calculate_top_k_ground_truth(query_subset, db_100k, k=100)
else:
    print("Using existing 100K GT and DB...")

run_ae_loss_comparison(base_vectors, query_vectors, gt_100k, [1000, 5000, 10000], 100)

## 16. PCA Whitening vs No Whitening Comparison
比較 PCA 降維時是否使用 Whitening 對最終檢索召回率的影響。
- **配置**: PCA (64 dims) + INT3 Min-Max Quantization (Standard Percentile, `quantize_to_int3`) + L2 Distance
- **比較**: `whiten=True` vs `whiten=False`
- **指標**: Recall@100, @500, @1000 (Top-100 GT)
- **數據集**: 100K Database Subset

In [None]:
# 16. PCA Whitening Comparison

def method3_pca_int3_configurable(data, query_data=None, quantizer=quantize_to_int3, whiten=True):
    """
    PCA reduction to REDUCED_DIM with optional whitening, followed by quantization.

    PCA is fitted with scikit-learn on `data`; the projection itself is then
    redone with torch on DEVICE (centre, project, optionally divide by the
    square root of the explained variance).

    Args:
        data: Database vectors (NumPy array or torch tensor).
        query_data: Optional query vectors projected with the same PCA.
        quantizer: Passed to apply_quantization.
        whiten: Whether to whiten the projected components.

    Returns:
        ``(db_q, q_q)`` as returned by apply_quantization. When `query_data`
        is None, None is passed to apply_quantization as the query input.
    """
    label = "PCA+Whitening" if whiten else "PCA (No Whitening)"
    print(f"  - Training {label}...")

    # scikit-learn needs NumPy input for fitting.
    data_np = data.cpu().numpy() if isinstance(data, torch.Tensor) else data

    pca = PCA(n_components=REDUCED_DIM, whiten=whiten)
    pca.fit(data_np)

    # Move the fitted parameters to DEVICE so the projection runs there.
    mean = torch.from_numpy(pca.mean_).float().to(DEVICE)
    components = torch.from_numpy(pca.components_).float().to(DEVICE)
    explained_variance = torch.from_numpy(pca.explained_variance_).float().to(DEVICE)

    def transform_gpu(X):
        if isinstance(X, np.ndarray):
            X = torch.from_numpy(X).to(DEVICE)
        # Centre, then project onto the principal axes.
        projected = torch.matmul(X.float() - mean, components.T)
        if not pca.whiten:
            return projected
        # Whitening: scale each component by its standard deviation.
        return projected / torch.sqrt(explained_variance)

    data_reduced = transform_gpu(data_np)
    query_reduced = None if query_data is None else transform_gpu(query_data)

    db_q, q_q = apply_quantization(data_reduced, query_reduced, quantizer)

    return db_q, q_q

def run_pca_whitening_comparison(base_vectors, query_vectors, gt_top_k, retrieval_depths, k_true_neighbors):
    """
    Compare PCA with and without whitening under the standard INT3 quantizer.

    Both configurations are run on the first 100,000 base vectors and first
    1,000 queries, evaluated with evaluate_recall_batched, plotted as grouped
    bars (saved to 'pca_whitening_comparison.png') and displayed as a table.

    Args:
        base_vectors: Full database; only the first 100,000 rows are used.
        query_vectors: Queries; only the first 1,000 rows are used.
        gt_top_k: Ground truth for the queries; not recomputed here, so it is
            assumed to refer to that same 100K subset.
        retrieval_depths: Depths to evaluate and plot.
        k_true_neighbors: Ground-truth size used by evaluate_recall_batched.

    Returns:
        DataFrame with 'Method', 'Quantizer', 'Encoding Time' and the entries
        returned by evaluate_recall_batched, one row per configuration.
    """
    print("Running PCA Whitening vs No Whitening Comparison...")

    # 100K subset.
    db_subset = base_vectors[:100000].astype(np.float32)
    query_subset = query_vectors[:1000].astype(np.float32)

    # Min-Max (standard percentile) quantization is used for both runs,
    # not the quantile quantizer.
    quantizer = quantize_to_int3

    rows = []
    for use_whiten in (False, True):
        label = "PCA + Whitening" if use_whiten else "PCA (No Whitening)"

        started = time.time()
        db_q, q_q = method3_pca_int3_configurable(
            db_subset,
            query_subset,
            quantizer=quantizer,
            whiten=use_whiten
        )
        t_enc = time.time() - started

        print(f"Evaluating {label}...")
        recalls = evaluate_recall_batched(q_q, db_q, gt_top_k, retrieval_depths, k_true_neighbors)

        rows.append({
            'Method': label,
            'Quantizer': 'Min-Max (Standard)',
            'Encoding Time': t_enc,
            **recalls
        })
        print(f"Result: {recalls}")

    df_res = pd.DataFrame(rows)

    def _recalls_for(method_label):
        # Recall values of one configuration, ordered like retrieval_depths.
        row = df_res[df_res['Method'] == method_label].iloc[0]
        return [row[f'recall@{k}'] for k in retrieval_depths]

    vals_no_white = _recalls_for('PCA (No Whitening)')
    vals_white = _recalls_for('PCA + Whitening')

    # Grouped bar chart: one group per depth, one bar per configuration.
    plt.figure(figsize=(10, 6))
    x = np.arange(len(retrieval_depths))
    width = 0.35
    series = (
        (-width/2, vals_no_white, 'PCA (No Whitening)', '#1f77b4'),
        (width/2, vals_white, 'PCA + Whitening', '#ff7f0e'),
    )

    for shift, vals, legend, color in series:
        plt.bar(x + shift, vals, width, label=legend, color=color)

    plt.xlabel('Retrieval Depth')
    plt.ylabel(f'Recall (Top-{k_true_neighbors} GT)')
    plt.title('PCA Whitening Comparison (100K DB, Min-Max Quantization)')
    plt.xticks(x, [f'@{k}' for k in retrieval_depths])
    plt.legend()
    plt.grid(axis='y', linestyle='--', alpha=0.7)

    # Numeric value above every bar.
    for shift, vals, _, _ in series:
        for i, v in enumerate(vals):
            plt.text(i + shift, v + 0.01, f'{v:.3f}', ha='center', va='bottom', fontsize=9)

    plt.tight_layout()
    plt.savefig('pca_whitening_comparison.png')
    plt.show()

    print("\nDetailed Results:")
    display(df_res)
    return df_res

# Run the experiment.
# Retrieval depths: 100, 500, 1000; ground truth: Top-100.
depths = [100, 500, 1000]
k_gt = 100

if not ('gt_100k' in locals() and 'db_100k' in locals()):
    print("Calculating GT for 100K DB...")
    db_100k = base_vectors[:100000].astype(np.float32)
    query_subset = query_vectors[:1000].astype(np.float32)
    gt_100k = calculate_top_k_ground_truth(query_subset, db_100k, k=k_gt)
else:
    print("Using existing 100K GT and DB...")
    # The cached ground truth must hold at least k_gt neighbours per query.
    if gt_100k.shape[1] < k_gt:
         print(f"Existing GT has k={gt_100k.shape[1]}, need {k_gt}. Recalculating...")
         gt_100k = calculate_top_k_ground_truth(query_subset, db_100k, k=k_gt)

run_pca_whitening_comparison(base_vectors, query_vectors, gt_100k, depths, k_gt)

## 17. Nearest Neighbor Distribution Analysis (Single Query)
分析特定查詢向量 (Query Vector) 的 Top-100 Nearest Neighbors 在整體數據空間中的分佈情形。
1. **Distance Distribution**: 該查詢向量到資料庫所有點的距離分佈，並標示出 Top-100 NN 的位置。
2. **2D PCA Projection**: 將 **(隨機採樣背景點 + Top-100 NN + Query)** 組成的數據集，使用 **PCA** 降維到 2 維平面。
   - **降維方法**: 使用 `sklearn.decomposition.PCA(n_components=2)`。
   - **數據來源**: 用於擬合 PCA 的數據包含了 10000 個隨機背景點、100 個目標近鄰點和 1 個查詢點。這確保了投影平面能夠捕捉到該局部查詢區域和整體分佈的主要變異方向。
   - **目的**: 觀察在高維空間中與 Query 最近的那些點，在降維後的 2D 平面上是否依然聚集在 Query 附近，或者因為維度壓縮而與背景點混雜在一起。

In [None]:
# 17. Top-100 NN Distribution Analysis

def analyze_nn_distribution(query_idx, db_vectors, query_vectors, gt_indices, sample_size=10000):
    """
    Visualise where one query's ground-truth neighbours lie within the database.

    A two-panel figure is drawn, saved as ``query_{query_idx}_distribution.png``
    and shown:
      * left  - density histogram of the Euclidean distance from the query to
        every row of ``db_vectors``, overlaid with the histogram of the
        neighbours' distances and a vertical line at the largest of them;
      * right - 2-D PCA projection fitted on the stacked
        (background sample, neighbours, query) rows.

    Args:
        query_idx: Row of ``query_vectors`` and ``gt_indices`` to analyse.
        db_vectors: Database array; indexed by the entries of ``gt_indices``.
        query_vectors: Query array.
        gt_indices: Per-query neighbour indices into ``db_vectors``.
        sample_size: Upper bound on the number of non-neighbour rows drawn
            (without replacement) as background for the projection.
    """
    print(f"\nAnalyzing Query Index: {query_idx}")

    query_row = query_vectors[query_idx].reshape(1, -1)
    neighbor_ids = gt_indices[query_idx]

    # ---- Panel 1: distance histogram ------------------------------------
    all_dists = pairwise_distances(query_row, db_vectors, metric='euclidean').flatten()
    neighbor_dists = all_dists[neighbor_ids]

    plt.figure(figsize=(15, 6))
    plt.subplot(1, 2, 1)

    # Grey backdrop: distance to every database row
    plt.hist(all_dists, bins=100, color='lightgray', label='All DB Vectors', density=True)

    # The farthest ground-truth neighbour marks the search radius
    radius = np.max(neighbor_dists)
    plt.axvline(x=radius, color='r', linestyle='--', linewidth=2, label=f'radius (Top-100) = {radius:.2f}')

    # Each histogram is normalised on its own (density=True), which keeps the
    # small neighbour set visible next to the full database
    plt.hist(neighbor_dists, bins=20, color='red', alpha=0.5, label='Top-100 NN', density=True)

    plt.title(f'Distance Distribution (Query {query_idx})')
    plt.xlabel('L2 Distance')
    plt.ylabel('Density')
    plt.legend()
    plt.grid(True, linestyle=':', alpha=0.6)

    # ---- Panel 2: PCA projection ----------------------------------------
    # Neighbours are excluded from the background pool so no point is drawn twice
    is_background = np.ones(len(db_vectors), dtype=bool)
    is_background[neighbor_ids] = False
    pool_ids = np.flatnonzero(is_background)

    if len(pool_ids) > sample_size:
        picked_ids = np.random.choice(pool_ids, sample_size, replace=False)
    else:
        picked_ids = pool_ids

    background = db_vectors[picked_ids]
    neighbors = db_vectors[neighbor_ids]
    n_background = len(background)

    # Fit on all three groups at once; row order is background, neighbours, query
    projected = PCA(n_components=2).fit_transform(np.vstack([background, neighbors, query_row]))
    background_2d = projected[:n_background]
    neighbors_2d = projected[n_background:-1]
    query_2d = projected[-1]

    plt.subplot(1, 2, 2)
    plt.scatter(background_2d[:, 0], background_2d[:, 1], c='lightgray', alpha=0.5, s=10, label='Background (Random Subset)')
    plt.scatter(neighbors_2d[:, 0], neighbors_2d[:, 1], c='red', alpha=0.8, s=20, label='Top-100 NN')
    plt.scatter(query_2d[0], query_2d[1], c='blue', marker='*', s=200, label='Query')

    plt.title(f'PCA 2D Projection (Query {query_idx})')
    plt.xlabel('PC 1')
    plt.ylabel('PC 2')
    plt.legend()
    plt.grid(True, linestyle=':', alpha=0.6)

    plt.tight_layout()
    filename = f'query_{query_idx}_distribution.png'
    plt.savefig(filename)
    plt.show()
    print(f"Saved visualization to {filename}")

# Run Analysis on a few fixed queries
# Using 100K subset for speed
# All three cached objects are used by the loop below, so rebuild them together
# whenever any one is missing (keeps the subset and its ground truth consistent).
if 'db_100k' not in locals() or 'query_subset' not in locals() or 'gt_100k' not in locals():
    db_100k = base_vectors[:100000].astype(np.float32)
    query_subset = query_vectors[:1000].astype(np.float32)
    gt_100k = calculate_top_k_ground_truth(query_subset, db_100k, k=100)

import random
random.seed(42)
# analyze_nn_distribution draws its background sample with np.random.choice, so
# NumPy's global generator is what has to be seeded for reproducible figures.
np.random.seed(42)
test_queries = [0, 10, 50]  # fixed query indices

for q_idx in test_queries:
    analyze_nn_distribution(q_idx, db_100k, query_subset, gt_100k)

## 18. Aggregate Dimension Analysis (NN vs Avg vs Farthest)
統計分析所有 Query 的 Top-100 Nearest Neighbors、整體分佈 (Average)、以及最遠點 (Farthest) 在每個維度上的距離差異。
- **Metric**: Absolute Difference in each dimension ($|q_d - x_d|$).
- **Scope**: Average over all queries.
- **Targets**:
    1.  **Top-100 NN**: Average distance of top-100 neighbors.
    2.  **Average Data**: Average distance of randomly sampled background data.
    3.  **Farthest Data**: Distance of the single farthest data point (L2-based).


In [None]:
# 18. Aggregate Dimension Analysis

def analyze_dimension_stats(db_vectors, query_vectors, gt_indices, sample_size_for_avg=5000, batch_size=100):
    """
    Compute mean absolute per-dimension differences between queries and three reference sets.

    For every query q, |q_d - x_d| is averaged per dimension d over:
      1. its ground-truth neighbours ``db_vectors[gt_indices[q]]``,
      2. one fixed random sample of ``sample_size_for_avg`` database rows,
      3. the single database row with the largest L2 distance to q.
    Each of the three per-dimension profiles is then averaged over all queries.

    Args:
        db_vectors: Database as a NumPy array or torch tensor; moved to DEVICE.
        query_vectors: Queries as a NumPy array or torch tensor (needs ``shape``).
        gt_indices: Per-query neighbour indices into ``db_vectors``, one row per query.
        sample_size_for_avg: Number of database rows in the "average data" sample.
        batch_size: Number of queries processed per iteration.

    Returns:
        Tuple ``(nn, average, farthest)`` of NumPy arrays, one value per dimension.
    """
    print("Running Aggregate Dimension Analysis...")

    num_queries = len(query_vectors)
    dim = query_vectors.shape[1]

    # Running sums over queries, one entry per dimension
    total_nn_diff = torch.zeros(dim, device=DEVICE)
    total_avg_diff = torch.zeros(dim, device=DEVICE)
    total_far_diff = torch.zeros(dim, device=DEVICE)

    # Prepare DB tensor
    if isinstance(db_vectors, np.ndarray):
        db_tensor = torch.from_numpy(db_vectors).to(DEVICE)
    else:
        db_tensor = db_vectors.to(DEVICE)

    # One fixed random sample stands in for "average data" across all queries
    perm = torch.randperm(len(db_tensor))[:sample_size_for_avg]
    db_sample = db_tensor[perm]

    # Squared norms are reused by every farthest-point scan
    db_sq = torch.sum(db_tensor**2, dim=1)

    chunk_size = 50000

    for i in tqdm(range(0, num_queries, batch_size), desc="Analyzing Queries"):
        q_end = min(i + batch_size, num_queries)
        # as_tensor accepts both NumPy arrays and tensors, mirroring the DB handling
        q_batch = torch.as_tensor(query_vectors[i:q_end]).to(DEVICE)  # (Batch, D)

        # --- 1. Top-K NN Diff ---
        # A single batched gather replaces the per-query loop: (Batch, K, D)
        nn_idx = torch.as_tensor(gt_indices[i:q_end], dtype=torch.long, device=db_tensor.device)
        nn_vecs = db_tensor[nn_idx]
        # Mean over the K neighbours of each query -> (Batch, D)
        batch_nn_diff = torch.abs(nn_vecs - q_batch.unsqueeze(1)).mean(dim=1)
        total_nn_diff += batch_nn_diff.sum(dim=0)

        # --- 2. Average Data Diff ---
        # (Batch, 1, D) - (1, Sample, D) -> (Batch, Sample, D).
        # With the defaults this is 100 * 5000 * 128 float32 values, ~250MB.
        avg_diffs = torch.abs(q_batch.unsqueeze(1) - db_sample.unsqueeze(0))
        total_avg_diff += avg_diffs.mean(dim=1).sum(dim=0)

        # --- 3. Farthest Data Diff ---
        # The DB is scanned in chunks so the full (Batch, N) matrix is never held;
        # squared distances suffice because only the argmax is needed.
        q_sq = torch.sum(q_batch**2, dim=1, keepdim=True)
        max_dists = torch.full((len(q_batch),), float('-inf')).to(DEVICE)
        max_indices = torch.full((len(q_batch),), -1, dtype=torch.long).to(DEVICE)

        for j in range(0, len(db_tensor), chunk_size):
            db_end = min(j + chunk_size, len(db_tensor))
            db_chunk = db_tensor[j:db_end]
            db_sq_chunk = db_sq[j:db_end]

            # ||q - x||^2 = ||q||^2 + ||x||^2 - 2 q.x
            dists = q_sq + db_sq_chunk - 2 * torch.matmul(q_batch, db_chunk.t())

            curr_max, curr_idx = torch.max(dists, dim=1)

            # Keep the chunk winner only where it beats the running maximum
            mask = curr_max > max_dists
            max_dists[mask] = curr_max[mask]
            max_indices[mask] = curr_idx[mask] + j  # chunk-local -> global index

        farthest_vecs = db_tensor[max_indices]  # (Batch, D)
        total_far_diff += torch.abs(farthest_vecs - q_batch).sum(dim=0)

    # Average over all queries
    avg_nn_per_dim = total_nn_diff / num_queries
    avg_data_per_dim = total_avg_diff / num_queries
    avg_far_per_dim = total_far_diff / num_queries

    return avg_nn_per_dim.cpu().numpy(), avg_data_per_dim.cpu().numpy(), avg_far_per_dim.cpu().numpy()

# Run it
# Uses the 100K subset ('db_100k') with its 1000 queries ('query_subset') for faster execution.
# The call below reads all three cached objects, so they are rebuilt together
# whenever any one of them is missing.
if 'db_100k' not in locals() or 'query_subset' not in locals() or 'gt_100k' not in locals():
    # Fallback setup
    db_100k = base_vectors[:100000].astype(np.float32)
    query_subset = query_vectors[:1000].astype(np.float32)
    # Recalc GT
    gt_100k = calculate_top_k_ground_truth(query_subset, db_100k, k=100)

nn_stats, avg_stats, far_stats = analyze_dimension_stats(db_100k, query_subset, gt_100k)

# Statistics Summary: each per-dimension profile collapsed to a single mean
print("\n" + "="*50)
print("Aggregate Dimension Statistics (Averaged over Queries)")
print("="*50)
print(f"Mean Abs Diff (Top-100 NN): {np.mean(nn_stats):.4f}")
print(f"Mean Abs Diff (Average Data): {np.mean(avg_stats):.4f}")
print(f"Mean Abs Diff (Farthest Data): {np.mean(far_stats):.4f}")
print("-" * 50)

# Visualization: one curve per reference set, x axis = dimension index
dims = range(len(nn_stats))
plt.figure(figsize=(12, 6))

for series, line_style in (
    (nn_stats, dict(label='Top-100 NNs', color='red', linewidth=1.5)),
    (avg_stats, dict(label='Average Data (Random Sample)', color='blue', alpha=0.6, linewidth=1)),
    (far_stats, dict(label='Farthest Data', color='green', alpha=0.6, linewidth=1)),
):
    plt.plot(dims, series, **line_style)

plt.xlabel('Dimension Index')
plt.ylabel('Mean Absolute Difference')
plt.title('Mean Diff per Dimension: NN vs Avg vs Far')
plt.legend()
plt.grid(True, linestyle='--', alpha=0.5)
plt.tight_layout()
plt.savefig('aggregate_dim_stats.png')
plt.show()

# Visualize Summary Bars: each profile reduced to its mean across dimensions
labels = ['Top-100 NN', 'Average Data', 'Farthest Data']
colors = ['red', 'blue', 'green']
means = [np.mean(profile) for profile in (nn_stats, avg_stats, far_stats)]

plt.figure(figsize=(8, 5))
plt.bar(labels, means, color=colors, alpha=0.7)
plt.ylabel('Mean Absolute Diff (Averaged across Dims)')
plt.title('Overall Comparison of Distances')

# Print the numeric value on top of every bar
for i, v in zip(range(len(means)), means):
    plt.text(i, v, f'{v:.3f}', ha='center', va='bottom')

plt.tight_layout()
plt.show()


## 19. Grouped Distance Experiment (Full SIFT1M Dataset)
在 **完整 SIFT1M (1M Vectors)** 數據集上比較四種方法。
針對 64維的方法，使用 **Grouped L2 Distance** 作為檢索度量。

**Grouped L2 Distance 實現方式**:
將 64 維向量切分為多個小組 (Group)。
- 設定 `group_size = 6`。
- 分組方式：Dimension 0-5, 6-11, ..., 最後一組為 60-63 (4維)。
- 計算公式：先計算每個 Group 的 L2 Distance，然後將所有 Groups 的距離相加。
  $D_{total}(x, y) = \sum_{g} ||x_g - y_g||_2$

**Methods**:
1. **Direct INT3 (128D)**: Standard L2.
2. **AvgPool (64D)**: Min-Max Quantization + Grouped L2.
3. **PCA (64D)**: Percentile Quantization + Grouped L2.
4. **AutoEncoder (64D)**: Recon+Dist Loss + Percentile Quantization + Grouped L2.

In [None]:
# 20. Threshold Experiment (Grouped L2 with Threshold Filtering)

def estimate_safe_threshold(q_vecs, db_vecs, gt, group_size=6, sample_size=1000, percentile=98):
    """
    Estimate two distance percentiles from sampled queries and their ground-truth neighbours.

    For each sampled query, distances to its neighbours ``db_vecs[gt[i]]`` are
    collected in two ways:
      1. per dimension group of width ``group_size`` (L2 inside the group); a
         narrower trailing group is scaled up by ``group_size / its width``;
      2. over all dimensions at once (plain L2).
    The ``percentile``-th percentile of each pool is returned.

    The distances are measured in whatever space the passed vectors live in, so
    the second value describes the original space only when original vectors
    are passed.

    Args:
        q_vecs: Queries as a NumPy array or torch tensor.
        db_vecs: Database as a NumPy array or torch tensor.
        gt: Per-query neighbour indices into ``db_vecs``, aligned with ``q_vecs``.
        group_size: Width of each dimension group.
        sample_size: Maximum number of queries sampled (without replacement).
        percentile: Percentile in [0, 100] taken from each distance pool.

    Returns:
        ``(threshold, input_limit)`` as Python floats: the grouped-distance
        percentile and the full-vector distance percentile.
    """
    print(f"Estimating thresholds from {len(q_vecs)} queries (target sample {sample_size})...")

    n_samples = min(len(q_vecs), sample_size)
    indices = np.random.choice(len(q_vecs), n_samples, replace=False)

    if isinstance(q_vecs, np.ndarray):
        q_sample = torch.from_numpy(q_vecs[indices]).to(DEVICE).float()
    else:
        q_sample = q_vecs[indices].to(DEVICE).float()

    gt_sample = gt[indices]

    all_group_dists = []
    all_full_dists = []  # plain L2 distance to the true neighbours

    D = q_sample.shape[1]

    for i, neighbor_indices in enumerate(gt_sample):
        q = q_sample[i:i+1]  # (1, D)

        # Only this query's neighbours are fetched, so the full DB never moves.
        # Both branches end on DEVICE so cdist sees matching devices.
        if isinstance(db_vecs, torch.Tensor):
            neighbors = db_vecs[neighbor_indices].to(DEVICE).float()
        else:
            neighbors = torch.from_numpy(db_vecs[neighbor_indices]).to(DEVICE).float()

        # 1. Grouped distances
        for start_idx in range(0, D, group_size):
            end_idx = min(start_idx + group_size, D)
            current_group_size = end_idx - start_idx

            vals = torch.cdist(q[:, start_idx:end_idx], neighbors[:, start_idx:end_idx], p=2).flatten()

            # A narrower trailing group is rescaled to be comparable with full groups
            if current_group_size < group_size:
                vals = vals * (group_size / current_group_size)

            all_group_dists.append(vals)

        # 2. Full-vector distance in the space of the passed vectors
        all_full_dists.append(torch.cdist(q, neighbors, p=2).flatten())

    all_dists = torch.cat(all_group_dists)
    all_input_dists = torch.cat(all_full_dists)

    threshold = torch.quantile(all_dists, percentile / 100.0).item()
    input_limit = torch.quantile(all_input_dists, percentile / 100.0).item()

    print(f"Estimated Base Threshold (P{percentile}): {threshold:.4f}")
    print(f"Estimated Input Neighbor Limit (P{percentile}): {input_limit:.4f}")
    return threshold, input_limit

def evaluate_recall_grouped_threshold_batched(query_vectors, db_vectors, gt_top_k, retrieval_depths, k_true_neighbors, group_size=6, threshold=None):
    """
    Evaluate recall under a grouped L2 distance with optional hard threshold filtering.

    The distance between a query and a database row is the sum of the L2
    distances computed inside consecutive dimension groups of width
    ``group_size`` (the last group may be narrower). When ``threshold`` is
    given, a row is rejected as soon as any group distance exceeds it (scaled
    by ``width / group_size`` for a narrower group); rejected rows never count
    as retrieved, even if fewer than the requested depth survive.

    Args:
        query_vectors: Queries as a NumPy array or torch tensor.
        db_vectors: Database as a NumPy array or torch tensor.
        gt_top_k: Per-query ground-truth indices, aligned with ``query_vectors``.
        retrieval_depths: Iterable of depths r at which recall@r is reported.
        k_true_neighbors: Number of true neighbours per query (recall denominator).
        group_size: Width of each dimension group.
        threshold: Per-group distance limit, or None to disable filtering.

    Returns:
        Dict with one ``'recall@{r}'`` entry per depth plus ``'time'`` (seconds
        spent in the evaluation loop).
    """
    max_depth = max(retrieval_depths)
    num_queries = len(query_vectors)
    query_batch_size = 100

    total_hits = {r: 0 for r in retrieval_depths}

    if isinstance(query_vectors, np.ndarray):
        q_vectors = torch.from_numpy(query_vectors).to(DEVICE).float()
    else:
        q_vectors = query_vectors.to(DEVICE).float()

    if isinstance(db_vectors, np.ndarray):
        d_vectors = torch.from_numpy(db_vectors).to(DEVICE).float()
    else:
        d_vectors = db_vectors.to(DEVICE).float()

    N = d_vectors.shape[0]
    D = q_vectors.shape[1]
    use_threshold = threshold is not None
    # topk raises when k exceeds the row length, so cap it at the DB size
    top_k = min(max_depth, N)
    # Group boundaries are identical for every batch
    group_bounds = [(start, min(start + group_size, D)) for start in range(0, D, group_size)]

    start_time = time.time()

    for i in tqdm(range(0, num_queries, query_batch_size), desc=f"Eval Threshold={threshold if threshold else 'None'}", leave=False):
        q_end = min(i + query_batch_size, num_queries)
        q_batch = q_vectors[i:q_end]  # (B, D)
        gt_batch = gt_top_k[i:q_end]
        B = q_batch.shape[0]

        total_dist = torch.zeros((B, N), device=DEVICE)
        # The validity mask is only needed (and only allocated) when filtering
        valid_mask = torch.ones((B, N), dtype=torch.bool, device=DEVICE) if use_threshold else None

        for start_idx, end_idx in group_bounds:
            dist_sub = torch.cdist(q_batch[:, start_idx:end_idx], d_vectors[:, start_idx:end_idx], p=2)

            if use_threshold:
                current_group_size = end_idx - start_idx
                current_threshold = threshold
                if current_group_size < group_size:
                    # Scale threshold for smaller groups: T * (current / base)
                    current_threshold = threshold * (current_group_size / group_size)
                valid_mask &= (dist_sub <= current_threshold)

            total_dist += dist_sub

        if use_threshold:
            total_dist.masked_fill_(~valid_mask, float('inf'))

        top_vals, top_idx = torch.topk(total_dist, k=top_k, dim=1, largest=False)
        retrieved = top_idx.cpu().numpy()
        # Rejected rows carry an infinite distance, yet topk still returns them
        # when fewer than k rows survive; they must not be counted as hits.
        survived = torch.isfinite(top_vals).cpu().numpy() if use_threshold else None

        for j in range(len(gt_batch)):
            # Hit flags are computed once per query and reused for every depth
            is_hit = np.isin(retrieved[j], gt_batch[j])
            if survived is not None:
                is_hit &= survived[j]
            for r in retrieval_depths:
                total_hits[r] += int(is_hit[:r].sum())

    time_taken = time.time() - start_time

    recalls = {}
    for r in retrieval_depths:
        recalls[f'recall@{r}'] = total_hits[r] / (num_queries * float(k_true_neighbors))

    recalls['time'] = time_taken
    return recalls

def run_threshold_experiment(base_vectors, query_vectors, gt_top_k, retrieval_depths, k_true_neighbors):
    """
    Compare retrieval methods under grouped L2 distance, with and without a hard threshold.

    Steps:
      1. Estimate an input-space neighbour limit from the original vectors (P98).
      2. Estimate the per-group threshold from AvgPool-quantized vectors (P95).
      3. Evaluate the direct INT3 baseline with a single group (plain L2).
      4. For AvgPool, PCA and two AutoEncoder variants, evaluate grouped L2
         (group_size=6) both without and with the estimated threshold.
      5. Write all rows to ``result_threshold_experiment.csv``, display them and
         call ``plot_threshold_results``.

    Args:
        base_vectors: Database vectors (NumPy array or torch tensor).
        query_vectors: Query vectors (NumPy array or torch tensor).
        gt_top_k: Per-query ground-truth indices into ``base_vectors``.
        retrieval_depths: Depths at which recall is reported.
        k_true_neighbors: Number of true neighbours per query.

    Returns:
        pandas.DataFrame with one row per (method, config) evaluation.
    """
    torch.cuda.empty_cache()
    print("Running Threshold Experiment (Full SIFT1M Dataset)...")

    db_full = base_vectors
    query_full = query_vectors

    results = []
    # Progress lines report the deepest requested depth rather than a fixed
    # 'recall@1000', which is absent when 1000 is not among retrieval_depths.
    report_key = f"recall@{max(retrieval_depths)}"

    # --- Setup A: input-space limit from the ORIGINAL vectors ---
    print("\n[Setup] Estimating Input Space Limit (128D)...")
    q_sub = query_full[:1000]
    gt_sub = gt_top_k[:1000]

    # Only the second return value (full-vector L2 percentile) is used here;
    # one group spanning every dimension avoids pointless per-group work.
    _, input_limit_est = estimate_safe_threshold(q_sub, db_full, gt_sub, group_size=q_sub.shape[1], percentile=98)

    # --- Setup B: per-group threshold from AvgPool-quantized vectors ---
    print("\n[Setup] Estimating Group Threshold using AvgPool (64D) Quantized Data...")
    quantizer = quantize_to_int3
    _, db_est, q_est = method2_average_pooling_int3(db_full, q_sub, quantizer=quantizer)

    # The estimator casts its inputs to float itself
    safe_threshold_est, _ = estimate_safe_threshold(q_est, db_est, gt_sub, group_size=6, percentile=95)

    print(f"Set Global Threshold T={safe_threshold_est:.4f}")
    print(f"Set Input Neighbor Limit L={input_limit_est:.4f}")

    # Config for AE training
    threshold_config = {
        'target_threshold': safe_threshold_est,
        'input_limit': input_limit_est
    }

    # --- 1. Baseline INT3 (128D) ---
    print("\n[Baseline] Direct INT3 (128D)...")
    quantizer = quantize_to_int3
    _, db_q, q_q = method1_direct_int3(db_full, query_full, quantizer=quantizer)
    # The baseline is labelled 'Standard': a single group covering every
    # dimension makes the grouped distance equal to plain L2.
    res = evaluate_recall_grouped_threshold_batched(q_q, db_q, gt_top_k, retrieval_depths, k_true_neighbors, group_size=db_q.shape[1], threshold=None)
    results.append({'Method': 'Baseline (128D)', 'Config': 'Standard', **res})

    # (name, method, quantizer, extra kwargs); a 'train_subset' key marks AE methods
    methods_config = [
        ('AvgPool (64D)', method2_average_pooling_int3, quantize_to_int3, {}),
        ('PCA (64D)', method3_pca_int3, PerDimensionQuantileQuantizer(), {}),
        ('AutoEncoder (64D, No Penalty)', method5_autoencoder_int3, PerDimensionQuantileQuantizer(),
             {'train_subset': True, 'distance_loss_weight': 0.1, 'use_grouped_latent_dist': True, 'threshold_config': None}),
        ('AutoEncoder (64D, Threshold Penalty)', method5_autoencoder_int3, PerDimensionQuantileQuantizer(),
             {'train_subset': True, 'distance_loss_weight': 0.1, 'use_grouped_latent_dist': True, 'threshold_config': threshold_config})
    ]

    for name, method_func, quantizer_inst, kwargs in methods_config:
        print(f"\n[{name}] Processing...")
        torch.cuda.empty_cache()

        if 'train_subset' in kwargs:
            # Train the autoencoder on the first 100K rows only
            train_sub = db_full[:100000]
            if isinstance(train_sub, torch.Tensor):
                train_sub = train_sub.cpu().numpy()

            ae_model, _, _ = method_func(train_sub, None, epochs=50, quantizer=quantizer_inst,
                                         distance_loss_weight=kwargs.get('distance_loss_weight', 0.1),
                                         use_grouped_latent_dist=kwargs.get('use_grouped_latent_dist', False),
                                         threshold_config=kwargs.get('threshold_config', None))

            # Encode the full database and all queries with the trained model
            print("  Encoding Full DB...")
            ae_model.eval()
            with torch.no_grad():
                if isinstance(db_full, np.ndarray):
                    db_tensor = torch.from_numpy(db_full).to(DEVICE)
                else:
                    db_tensor = db_full.to(DEVICE)
                full_db_reduced = ae_model.encode(db_tensor.float())

                if isinstance(query_full, np.ndarray):
                    q_tensor = torch.from_numpy(query_full).to(DEVICE)
                else:
                    q_tensor = query_full.to(DEVICE)
                full_q_reduced = ae_model.encode(q_tensor.float())

            # Quantizer is fitted on the database codes and reused for the queries
            db_q = quantizer_inst.fit_transform(full_db_reduced)
            q_q = quantizer_inst.transform(full_q_reduced)

        else:
            _, db_q, q_q = method_func(db_full, query_full, quantizer=quantizer_inst)

        # 1. Grouped No Threshold
        print("  Eval: Grouped No Threshold...")
        res_no = evaluate_recall_grouped_threshold_batched(q_q, db_q, gt_top_k, retrieval_depths, k_true_neighbors, group_size=6, threshold=None)
        results.append({'Method': name, 'Config': 'Grouped No Threshold', **res_no})
        print(f"  -> {res_no[report_key]:.4f}")

        # 2. Grouped With Threshold
        print(f"  Eval: Grouped With Threshold (T={safe_threshold_est:.4f})...")
        res_th = evaluate_recall_grouped_threshold_batched(q_q, db_q, gt_top_k, retrieval_depths, k_true_neighbors, group_size=6, threshold=safe_threshold_est)
        results.append({'Method': name, 'Config': 'Grouped + Threshold', **res_th})
        print(f"  -> {res_th[report_key]:.4f}")

    # Summary
    df_res = pd.DataFrame(results)
    csv_filename = "result_threshold_experiment.csv"
    df_res.to_csv(csv_filename, index=False)
    display(df_res)
    plot_threshold_results(df_res)
    return df_res
    
# Launch the threshold experiment only once both SIFT arrays are loaded
if {'base_vectors', 'query_vectors'} <= set(locals()):
    if 'gt_1m' not in locals():
        # Ground truth against the full database is computed on first use (k=100)
        gt_1m = calculate_top_k_ground_truth(query_vectors, base_vectors, k=100)
    run_threshold_experiment(base_vectors, query_vectors, gt_1m, [100, 500, 1000], 100)