# SIFT-1M Dimensionality Reduction Experiment: PCA -> Autoencoder

This notebook implements a two-stage dimensionality reduction pipeline:
1.  **PCA**: Reduce SIFT-1M data (128d) to $x$ dimensions, where $x \in \{16, 32, ..., 128\}$.
2.  **Autoencoder**: Map the $x$-dimensional PCA output to 64 dimensions (a reduction when $x > 64$, an expansion when $x < 64$).
    *   Loss Function: Reconstruction Loss + 3 * Distance Loss.

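Finally, we quantize every embedding to INT3 using per-dimension quantiles and evaluate Recall@100, Recall@500, and Recall@1000, comparing the PCA -> Autoencoder pipeline against PCA-only and Autoencoder-only (128 -> 64) baselines.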

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.decomposition import PCA
from sklearn.metrics import pairwise_distances
import os
import time
import pandas as pd
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

# Compute device: CUDA when torch reports a usable GPU, CPU otherwise
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {DEVICE}")

# Experiment-wide settings
SIFT_DIM, FINAL_DIM = 128, 64  # raw descriptor width / autoencoder latent width
BATCH_SIZE = 256
EPOCHS = 50  # kept low for faster experimentation; can be increased
PATIENCE = 5

In [None]:
import urllib.request
import tarfile
import shutil
import socket

def download_sift1m():
    """
    Make sure the three SIFT1M files exist under ``sift1m/``.

    If ``sift_base.fvecs``, ``sift_query.fvecs`` and
    ``sift_groundtruth.ivecs`` are already present nothing is downloaded.
    Otherwise ``sift.tar.gz`` is fetched over FTP and only those three
    members are copied out of the archive.

    Returns:
        bool: True when all three files are present on return; False when
        the download or extraction failed or the archive lacked a file
        (a hint for manual download is printed in that case).
    """
    data_dir = 'sift1m'
    files = [
        "sift_base.fvecs",
        "sift_query.fvecs",
        "sift_groundtruth.ivecs"
    ]

    def missing_files():
        return [name for name in files
                if not os.path.exists(os.path.join(data_dir, name))]

    # Create data directory
    os.makedirs(data_dir, exist_ok=True)

    if not missing_files():
        print("All SIFT1M files exist.")
        return True

    print("Downloading SIFT1M dataset...")

    tar_url = "ftp://ftp.irisa.fr/local/texmex/corpus/sift.tar.gz"
    tar_path = "sift1m/sift.tar.gz"

    # The timeout is process-global, so remember the old value and restore it
    previous_timeout = socket.getdefaulttimeout()
    try:
        print(f"Downloading from {tar_url}...")
        socket.setdefaulttimeout(30)
        urllib.request.urlretrieve(tar_url, tar_path)
        print("Download complete. Extracting...")

        wanted = set(files)
        with tarfile.open(tar_path, "r:gz") as tar:
            for member in tar:
                # Copy only the regular files we need, addressed by basename,
                # so member paths in the archive never decide where we write.
                name = os.path.basename(member.name)
                if not member.isfile() or name not in wanted:
                    continue
                dst = os.path.join(data_dir, name)
                part = dst + ".part"
                try:
                    with tar.extractfile(member) as src, open(part, 'wb') as out:
                        shutil.copyfileobj(src, out)
                    # Publish atomically: an interrupted copy must not look
                    # like a complete file on the next run.
                    os.replace(part, dst)
                finally:
                    if os.path.exists(part):
                        os.remove(part)

        still_missing = missing_files()
        if still_missing:
            raise FileNotFoundError(
                "archive did not contain: " + ", ".join(still_missing)
            )

        print("Dataset ready.")
        return True

    except Exception as e:
        # Deliberate best-effort boundary: report and let the caller decide
        print(f"Download failed: {e}")
        print("Please manually download sift.tar.gz from http://corpus-texmex.irisa.fr/ and extract to sift1m/ folder.")
        return False

    finally:
        socket.setdefaulttimeout(previous_timeout)
        # Remove the (possibly partial) archive on success and on failure
        if os.path.exists(tar_path):
            os.remove(tar_path)

def read_fvecs(filename):
    """
    Read a .fvecs file into a float32 array of shape (n_vectors, dim).

    Each record is one int32 dimension header followed by ``dim`` float32
    values; the dimension is taken from the first record.

    Raises:
        ValueError: if the file is empty, its size is not a whole number of
            records, or a record header disagrees with the first one.
    """
    if os.path.getsize(filename) % 4 != 0:
        raise ValueError(f"{filename}: size is not a multiple of 4 bytes")

    # Read everything as int32 so the headers can be validated, then
    # reinterpret the payload columns as float32 (same item size).
    raw = np.fromfile(filename, dtype=np.int32)
    if raw.size == 0:
        raise ValueError(f"{filename}: file is empty")

    d = int(raw[0])
    if d <= 0 or raw.size % (d + 1) != 0:
        raise ValueError(
            f"{filename}: {raw.size} values do not form whole records of dimension {d}"
        )

    records = raw.reshape(-1, d + 1)
    if not np.all(records[:, 0] == d):
        raise ValueError(f"{filename}: inconsistent per-vector dimension headers")

    return records[:, 1:].view(np.float32).copy()

def read_ivecs(filename):
    """
    Read a .ivecs file into an int32 array of shape (n_vectors, dim).

    Each record is one int32 dimension header followed by ``dim`` int32
    values; the dimension is taken from the first record.

    Raises:
        ValueError: if the file is empty, its size is not a whole number of
            records, or a record header disagrees with the first one.
    """
    if os.path.getsize(filename) % 4 != 0:
        raise ValueError(f"{filename}: size is not a multiple of 4 bytes")

    raw = np.fromfile(filename, dtype=np.int32)
    if raw.size == 0:
        raise ValueError(f"{filename}: file is empty")

    d = int(raw[0])
    if d <= 0 or raw.size % (d + 1) != 0:
        raise ValueError(
            f"{filename}: {raw.size} values do not form whole records of dimension {d}"
        )

    records = raw.reshape(-1, d + 1)
    if not np.all(records[:, 0] == d):
        raise ValueError(f"{filename}: inconsistent per-vector dimension headers")

    # Drop the header column; copy so the result does not pin the raw buffer
    return records[:, 1:].copy()

# Abort early when the dataset could not be made available
if not download_sift1m():
    raise RuntimeError("Failed to download dataset and files are missing.")

# Read base vectors, query vectors and ground-truth neighbour ids
data_dir = 'sift1m'
try:
    print("Loading SIFT1M dataset...")
    base_vectors = read_fvecs(os.path.join(data_dir, 'sift_base.fvecs'))
    query_vectors = read_fvecs(os.path.join(data_dir, 'sift_query.fvecs'))
    ground_truth = read_ivecs(os.path.join(data_dir, 'sift_groundtruth.ivecs'))
    for label, array in (
        ("Base vectors", base_vectors),
        ("Query vectors", query_vectors),
        ("Ground truth", ground_truth),
    ):
        print(f"{label}: {array.shape}")
except Exception as e:
    print(f"Error loading data: {e}")
    raise  # nothing downstream can run without the data

In [None]:
class PerDimensionQuantileQuantizer:
    """
    Quantize to INT3 using per-dimension quantiles.

    For each dimension, ``fit`` computes 7 thresholds (the 1/8, 2/8, ..., 7/8
    quantiles) from the training data; ``transform`` counts how many
    thresholds each value reaches (0-7) and shifts that to the range -4..3.

    Args:
        device: torch device (or device string) used for all computation.
            Defaults to the module-level ``DEVICE``.
    """

    # Row cap for the subsampled retry when the full quantile call fails
    _MAX_QUANTILE_ROWS = 100000
    # Rows per transform step; bounds the (rows, 7, D) comparison tensor
    _TRANSFORM_CHUNK_ROWS = 65536

    def __init__(self, device=None):
        self.thresholds = None
        self.device = DEVICE if device is None else torch.device(device)

    def _prepare(self, data):
        """Return ``data`` as a float32 tensor on ``self.device``."""
        return torch.as_tensor(data).to(self.device).float()

    def fit(self, data):
        """
        Compute the (7, D) threshold table from ``data`` of shape (N, D).

        If ``torch.quantile`` raises RuntimeError on the full input and N
        exceeds 100000, the quantiles are recomputed on an evenly strided
        subsample of at most 100000 rows; otherwise the error propagates.
        """
        data = self._prepare(data)
        N, D = data.shape

        # Calculate quantiles for each dimension
        q_vals = torch.tensor([0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875], device=self.device)

        try:
            self.thresholds = torch.quantile(data, q_vals, dim=0)  # (7, D)
        except RuntimeError:
            cap = self._MAX_QUANTILE_ROWS
            if N <= cap:
                raise
            # Ceiling division: the stride is always >= 2 here, so the retry
            # really runs on fewer rows than the call that just failed.
            step = -(-N // cap)
            self.thresholds = torch.quantile(data[::step], q_vals, dim=0)

    def transform(self, data):
        """
        Map ``data`` of shape (N, D) to int8 codes in -4..3.

        Raises:
            ValueError: if ``fit`` has not been called.
        """
        if self.thresholds is None:
            raise ValueError("Quantizer not fitted")

        data = self._prepare(data)
        thresholds = self.thresholds.unsqueeze(0)  # (1, 7, D)
        codes = torch.empty(data.shape, dtype=torch.int8, device=data.device)

        chunk = self._TRANSFORM_CHUNK_ROWS
        for start in range(0, data.shape[0], chunk):
            rows = data[start:start + chunk]
            reached = rows.unsqueeze(1) >= thresholds  # (rows, 7, D)
            # Number of thresholds reached is 0-7; shift to -4..3
            codes[start:start + chunk] = reached.sum(dim=1) - 4

        return codes

    def fit_transform(self, data):
        """Fit on ``data`` and return its quantized codes."""
        self.fit(data)
        return self.transform(data)

def apply_quantization(train_data, query_data):
    """
    Quantize both sets with thresholds learned from ``train_data`` only.

    Returns:
        tuple: (train codes, query codes), both as NumPy arrays on the host.
    """
    quantizer = PerDimensionQuantileQuantizer()
    quantizer.fit(train_data)
    # Thresholds come from the training set; both sets are then coded with them
    train_codes = quantizer.transform(train_data).cpu().numpy()
    query_codes = quantizer.transform(query_data).cpu().numpy()
    return train_codes, query_codes

In [None]:
class AutoEncoder(nn.Module):
    """
    Symmetric autoencoder with one hidden layer on each side.

    The encoder maps input_dim -> hidden -> latent_dim and the decoder maps
    latent_dim -> hidden -> input_dim, where hidden is the floored midpoint
    of input_dim and latent_dim. Each half is
    Linear -> BatchNorm1d -> ReLU -> Linear.
    """

    def __init__(self, input_dim, latent_dim=64):
        super().__init__()

        # Hidden width: (input_dim + latent_dim) // 2
        width = (input_dim + latent_dim) // 2

        def half(n_in, n_out):
            return nn.Sequential(
                nn.Linear(n_in, width),
                nn.BatchNorm1d(width),
                nn.ReLU(),
                nn.Linear(width, n_out),
            )

        # Built encoder-first so parameter creation order is unchanged
        self.encoder = half(input_dim, latent_dim)
        self.decoder = half(latent_dim, input_dim)

    def forward(self, x):
        """Return ``(latent code, reconstruction)`` for ``x``."""
        latent = self.encoder(x)
        return latent, self.decoder(latent)

    def encode(self, x):
        """Return only the latent code for ``x``."""
        return self.encoder(x)

def train_autoencoder(data, input_dim, latent_dim=64, epochs=50, batch_size=256, patience=5, distance_loss_weight=3.0):
    """
    Train an AutoEncoder with reconstruction loss plus weighted distance loss.

    Per batch the loss is
    ``MSE(reconstruction, input)
    + distance_loss_weight * MSE(cdist(latent, latent), cdist(input, input))``,
    i.e. a 1:3 ratio with the default weight.

    Args:
        data: training vectors of shape (N, input_dim), NumPy array or tensor.
        input_dim: width of the input vectors.
        latent_dim: width of the latent code.
        epochs: maximum number of passes over ``data``.
        batch_size: mini-batch size; must be at least 2.
        patience: stop after this many consecutive epochs in which the mean
            training loss failed to improve by more than 1e-4.
        distance_loss_weight: multiplier applied to the distance loss.

    Returns:
        tuple: ``(model, history)``. ``history`` maps 'loss', 'recon_loss'
        and 'dist_loss' to per-epoch means. The model is returned in training
        mode with the weights of the last epoch that ran (early stopping does
        not restore an earlier checkpoint).

    Raises:
        ValueError: if ``batch_size`` < 2 or ``data`` has fewer than 2 rows.
    """
    # BatchNorm1d cannot normalise a single sample in training mode
    if batch_size < 2:
        raise ValueError("batch_size must be at least 2 (BatchNorm1d needs more than one sample per batch)")

    # Accept arrays or tensors of any numeric dtype
    data_tensor = torch.as_tensor(data, dtype=torch.float32)
    num_samples = data_tensor.shape[0]
    if num_samples < 2:
        raise ValueError("data must contain at least 2 samples")

    model = AutoEncoder(input_dim, latent_dim).to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    criterion_recon = nn.MSELoss()
    criterion_dist = nn.MSELoss()

    # Create DataLoader. A leftover batch of exactly one sample would crash
    # BatchNorm1d, so drop it in that case only.
    dataset = torch.utils.data.TensorDataset(data_tensor)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=(num_samples % batch_size == 1),
    )

    model.train()

    best_loss = float('inf')
    patience_counter = 0
    history = {'loss': [], 'recon_loss': [], 'dist_loss': []}

    # Use tqdm for progress
    pbar = tqdm(range(epochs), desc=f"Training AE (Input: {input_dim})", leave=False)

    for epoch in pbar:
        total_loss = 0
        total_recon = 0
        total_dist = 0
        num_batches = 0

        for batch in dataloader:
            batch_data = batch[0].to(DEVICE)

            optimizer.zero_grad()
            encoded, reconstructed = model(batch_data)

            # 1. Reconstruction Loss
            recon_loss = criterion_recon(reconstructed, batch_data)

            # 2. Distance Preservation Loss: MSE between the in-batch
            # pairwise Euclidean distance matrices of input and latent space
            dist_input = torch.cdist(batch_data, batch_data, p=2)
            dist_latent = torch.cdist(encoded, encoded, p=2)
            dist_loss = criterion_dist(dist_latent, dist_input)

            # Total Loss = Recon + distance_loss_weight * Distance
            loss = recon_loss + distance_loss_weight * dist_loss

            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            total_recon += recon_loss.item()
            total_dist += dist_loss.item()
            num_batches += 1

        avg_loss = total_loss / num_batches
        avg_recon = total_recon / num_batches
        avg_dist = total_dist / num_batches

        history['loss'].append(avg_loss)
        history['recon_loss'].append(avg_recon)
        history['dist_loss'].append(avg_dist)

        pbar.set_postfix({'Loss': f"{avg_loss:.4f}", 'Recon': f"{avg_recon:.4f}", 'Dist': f"{avg_dist:.4f}"})

        # Early stopping on the mean training loss
        if avg_loss < best_loss - 1e-4:
            best_loss = avg_loss
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                break

    # Breaking out leaves the iterator unexhausted, so close the bar explicitly
    pbar.close()

    return model, history

In [None]:
def evaluate_recall(queries_encoded, database_encoded, ground_truth, k_values=(100, 500, 1000), device=None):
    """
    Calculate Recall@K in the encoded space.

    Recall@K = (ground-truth neighbours found among the K nearest retrieved
    database items) / (number of ground-truth neighbours for that query),
    averaged over all queries. For SIFT1M the ground truth lists 100
    neighbours per query, so the denominator is usually 100.

    Args:
        queries_encoded: (num_queries, dim) array or tensor.
        database_encoded: (num_database, dim) array or tensor.
        ground_truth: per-query rows of true neighbour indices into the database.
        k_values: cutoffs to report. A cutoff larger than the database is
            evaluated against the whole database.
        device: torch device for the distance computation; defaults to the
            module-level ``DEVICE``.

    Returns:
        dict: mapping each k to its mean recall (empty when ``k_values`` is empty).
    """
    print("Calculating distances and recall...")

    k_values = list(k_values)
    if not k_values:
        return {}

    device = DEVICE if device is None else torch.device(device)

    # as_tensor accepts NumPy arrays and tensors of any numeric dtype
    # (the quantized codes arrive as int8)
    queries_t = torch.as_tensor(queries_encoded, dtype=torch.float32, device=device)
    database_t = torch.as_tensor(database_encoded, dtype=torch.float32, device=device)
    ground_truth = np.asarray(ground_truth)

    num_queries = queries_t.shape[0]
    recalls = {k: [] for k in k_values}

    # topk cannot return more neighbours than the database holds
    max_k = min(max(k_values), database_t.shape[0])

    # Queries are processed in batches so the (batch, num_database) distance
    # matrix stays small: 100 x 1M float32 is about 400 MB.
    batch_size = 100

    for start in range(0, num_queries, batch_size):
        q_batch = queries_t[start:start + batch_size]

        dists = torch.cdist(q_batch, database_t)

        # Smallest distances first; topk avoids a full sort
        _, indices = torch.topk(dists, k=max_k, dim=1, largest=False)
        indices = indices.cpu().numpy()

        for offset, retrieved in enumerate(indices):
            gt_set = set(ground_truth[start + offset].tolist())
            retrieved = retrieved.tolist()

            for k in k_values:
                hits = len(gt_set.intersection(retrieved[:k]))
                recalls[k].append(hits / len(gt_set))

    # Average recall over queries
    return {k: np.mean(v) for k, v in recalls.items()}

In [None]:
# Main Experiment Loop
x_values = range(16, SIFT_DIM + 1, 16) # 16, 32, ..., 128
results = []


@torch.no_grad()
def encode_in_batches(data, model, batch_size=10000):
    """
    Encode ``data`` with ``model.encode`` in chunks of ``batch_size`` rows.

    Gradient tracking is disabled for every call. The caller is expected to
    have put ``model`` in eval mode. Returns the stacked codes as one NumPy
    array on the host.
    """
    encoded_list = []
    for i in range(0, len(data), batch_size):
        batch = torch.FloatTensor(data[i:i+batch_size]).to(DEVICE)
        encoded_list.append(model.encode(batch).cpu().numpy())
    return np.vstack(encoded_list)


print("=== Baseline: Autoencoder Only (128 -> 64) ===")
# Train AE on raw data (128 dim), using the first 100000 base vectors
ae_only_train_data = base_vectors[:100000]
ae_only_model, _ = train_autoencoder(
    ae_only_train_data,
    input_dim=SIFT_DIM,
    latent_dim=FINAL_DIM,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    patience=PATIENCE,
    distance_loss_weight=3.0
)

# Eval mode so BatchNorm uses its running statistics while encoding
ae_only_model.eval()
base_ae_only = encode_in_batches(base_vectors, ae_only_model)
query_ae_only = encode_in_batches(query_vectors, ae_only_model)

# Quantize AE Only
print("Quantizing AE Only results...")
base_ae_only_q, query_ae_only_q = apply_quantization(base_ae_only, query_ae_only)

# Evaluate AE Only
recalls_ae_only = evaluate_recall(query_ae_only_q, base_ae_only_q, ground_truth, k_values=[100, 500, 1000])
print(f"AE Only Results: {recalls_ae_only}")


print(f"\nStarting PCA experiment for x values: {list(x_values)}")

for x in x_values:
    print(f"\n=== Processing x = {x} ===")

    # 1. PCA (whitened), fitted on the first 100000 base vectors
    print(f"Step 1: PCA Reduction to {x} dimensions...")
    pca_train_data = base_vectors[:100000]
    pca = PCA(n_components=x, whiten=True).fit(pca_train_data)

    base_pca = pca.transform(base_vectors)
    query_pca = pca.transform(query_vectors)

    # Baseline: PCA output quantized and evaluated directly
    print(f"Evaluating PCA Only ({x} dim)...")
    base_pca_q, query_pca_q = apply_quantization(base_pca, query_pca)
    recalls_pca_only = evaluate_recall(
        query_pca_q, base_pca_q, ground_truth, k_values=[100, 500, 1000]
    )
    print(f"PCA Only ({x}) Results: {recalls_pca_only}")

    # 2. Autoencoder trained on the PCA output of the same 100000 vectors
    print(f"Step 2: Training Autoencoder ({x} -> {FINAL_DIM})...")
    ae_train_data = base_pca[:100000]
    ae_model, history = train_autoencoder(
        ae_train_data,
        input_dim=x,
        latent_dim=FINAL_DIM,
        epochs=EPOCHS,
        distance_loss_weight=3.0,
    )

    # 3. Encode the whole base and query sets with the trained model
    print("Step 3: Encoding data...")
    ae_model.eval()
    with torch.no_grad():
        base_encoded = encode_in_batches(base_pca, ae_model)
        query_encoded = encode_in_batches(query_pca, ae_model)

    # Quantize the PCA -> AE codes
    print("Quantizing PCA -> AE results...")
    base_encoded_q, query_encoded_q = apply_quantization(base_encoded, query_encoded)

    # 4. Recall of the full pipeline
    print("Step 4: Evaluating Recall...")
    recalls_pca_ae = evaluate_recall(
        query_encoded_q, base_encoded_q, ground_truth, k_values=[100, 500, 1000]
    )

    print(f"PCA -> AE ({x}->64) Results: {recalls_pca_ae}")

    # One result row per x: PCA -> AE, then PCA only, then the constant
    # AE-only baseline, each at 100 / 500 / 1000 (column order matters)
    row = {'x': x}
    for prefix, recalls in (
        ('pca_ae', recalls_pca_ae),
        ('pca_only', recalls_pca_only),
        ('ae_only', recalls_ae_only),
    ):
        for k in (100, 500, 1000):
            row[f'{prefix}_recall@{k}'] = recalls[k]
    results.append(row)

# Collect everything into a table, show it and persist it
df_results = pd.DataFrame(results)
print("\nFinal Results:")
print(df_results)
df_results.to_csv('sift1m_pca_ae_comparison_results.csv', index=False)

In [None]:
# Visualization
# One figure per recall cutoff: the PCA -> AE and PCA-only curves over x,
# plus the AE-only baseline as a horizontal line. Each figure is saved as
# recall_at_<k>.png and then shown.
for k in (100, 500, 1000):
    plt.figure(figsize=(10, 6))
    plt.plot(df_results['x'], df_results[f'pca_ae_recall@{k}'], marker='o', label='PCA(x) -> AE(64)')
    plt.plot(df_results['x'], df_results[f'pca_only_recall@{k}'], marker='s', label='PCA(x) Only')
    # The AE-only value is identical in every row; take the first one by
    # position rather than by index label.
    plt.axhline(y=df_results[f'ae_only_recall@{k}'].iloc[0], color='r', linestyle='--', label='AE(128->64) Only')

    plt.title(f'Recall@{k} Comparison (Quantized INT3)')
    plt.xlabel('PCA Dimension (x)')
    plt.ylabel(f'Recall@{k}')
    plt.grid(True)
    plt.legend()
    plt.xticks(list(x_values))
    plt.savefig(f'recall_at_{k}.png')
    plt.show()