# Locality-Sensitive Hashing (LSH) using FAISS (Optimized)

This notebook evaluates approximate nearest neighbor search using Facebook's FAISS library with LSH indexing.

## Method: FAISS IndexLSH
FAISS (Facebook AI Similarity Search) provides highly optimized C++ implementations of various ANN methods, including LSH.

**Key advantages:**
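- ✅ Production-grade C++ implementation, typically much faster than a pure-Python LSH (the exact speedup depends on data and hardware)
- ✅ Optimized for modern CPUs with SIMD instructions
- ✅ Developed at Meta (formerly Facebook) and widely used in industry
- ✅ FAISS as a library is designed to scale to billion-vector datasets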

## Installation
```bash
# CPU version (recommended for this project)
pip install faiss-cpu

# OR GPU version (if you have CUDA)
pip install faiss-gpu
```

## Metrics
- **Recall@10**: Fraction of the true 10 nearest neighbors that appear among the 10 returned results, averaged over queries
- **Query time**: Average milliseconds per query
- **Build time**: Index construction time in seconds
- **Memory**: Estimated index size in MB
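
As a rough illustration of the recall metric, the sketch below computes Recall@k from two arrays of neighbor ids. The helper name `recall_at_k` and the toy ids are hypothetical; how `ANNEvaluator` computes recall internally is not shown in this notebook and may differ.

```python
import numpy as np

def recall_at_k(true_ids, found_ids):
    """Mean fraction of each query's true neighbors present in the returned ids."""
    true_ids = np.asarray(true_ids)
    found_ids = np.asarray(found_ids)
    k = true_ids.shape[1]
    # Count, per query, how many returned ids are among the true neighbors.
    hits = [len(set(t) & set(f))
            for t, f in zip(true_ids.tolist(), found_ids.tolist())]
    return float(np.mean(hits)) / k

# Two queries, k = 3: the first recovers 2 of 3 true neighbors, the second all 3.
true_ids = [[0, 1, 2], [5, 6, 7]]
found_ids = [[2, 0, 9], [7, 6, 5]]
recall = recall_at_k(true_ids, found_ids)
print(f"Recall@3 = {recall:.3f}")
```

The order of the returned ids does not matter under this definition; only set membership counts.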

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import time
import sys
import json
import os

# Import FAISS
try:
    import faiss
    print(f"✓ FAISS version: {faiss.__version__}")
except ImportError:
    print("ERROR: FAISS not installed!")
    print("Install with: pip install faiss-cpu")
    raise

# Import our custom modules
sys.path.append('..')
from evaluator import ANNEvaluator
from datasets import DatasetLoader, print_dataset_info

# Plotting style
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (14, 10)

print("✓ Imports successful!")

## 1. Configuration

In [None]:
# Configuration
K = 10  # Number of nearest neighbors
DATA_DIR = "data"

# Datasets
DATASETS = [
    'sift',
    'gist',
    'deep1b'
]

# LSH parameters for FAISS
# nbits: length of the binary code per vector
# (more bits = better recall, but more memory and slower Hamming scans)
HASH_BITS_OPTIONS = [64, 128, 256, 512, 1024]  # longer codes than in our custom implementation

# Subset sizes
N_TRAIN = None
N_TEST = 1000

# Results directory
RESULTS_DIR = "../results"
os.makedirs(RESULTS_DIR, exist_ok=True)

print("Configuration:")
print(f"  k (neighbors) = {K}")
print(f"  Datasets: {DATASETS}")
print(f"  LSH hash bits: {HASH_BITS_OPTIONS}")
print("\nNote: FAISS IndexLSH uses a single binary code per vector, so nbits is the total code length (not tables × bits)")

## 2. FAISS LSH Implementation

FAISS `IndexLSH` uses binary hash codes:
- Projects each vector to a binary code of the specified length (one bit per random projection)
- Ranks database vectors by Hamming distance between codes, scanning all stored codes rather than probing hash buckets
- Optimized C++ implementation with SIMD
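
The idea behind these binary codes can be sketched in plain NumPy. This is a minimal illustration of sign-of-random-projection hashing with an exhaustive Hamming scan, not FAISS's actual code: the names `planes`, `encode`, and `hamming_search` are hypothetical, and FAISS applies its own projection and optimized bit operations.

```python
import numpy as np

rng = np.random.default_rng(0)
dim, nbits, n_db = 32, 64, 1000

# Random hyperplanes, one per hash bit (stand-in for the index's projection).
planes = rng.standard_normal((dim, nbits)).astype(np.float32)

def encode(X):
    """Map float vectors to packed binary codes (bit j = 1 if projection j > 0)."""
    bits = (X @ planes) > 0
    return np.packbits(bits, axis=1)  # shape (n, nbits // 8), dtype uint8

def hamming_search(query_codes, db_codes, k):
    """Rank all database codes by Hamming distance to each query code."""
    xor = query_codes[:, None, :] ^ db_codes[None, :, :]
    dists = np.unpackbits(xor, axis=2).sum(axis=2)
    idx = np.argsort(dists, axis=1, kind="stable")[:, :k]
    return idx, np.take_along_axis(dists, idx, axis=1)

X_db = rng.standard_normal((n_db, dim)).astype(np.float32)
db_codes = encode(X_db)

# Query with slightly perturbed copies of the first five database vectors.
X_q = X_db[:5] + 0.01 * rng.standard_normal((5, dim)).astype(np.float32)
idx, dists = hamming_search(encode(X_q), db_codes, k=3)
print(idx[:, 0])  # nearest id per query by Hamming distance
```

Vectors separated by a small angle agree on most bits, so a small Hamming distance is a proxy for high similarity. Longer codes make that proxy more accurate at the cost of more memory and more work per comparison.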

In [None]:
class FAISSLSHIndex:
    """
    Wrapper for FAISS IndexLSH with our evaluator interface.
    """
    def __init__(self, dim, nbits):
        """
        Args:
            dim: Dimension of input vectors
            nbits: Number of hash bits (length of binary code)
        """
        self.dim = dim
        self.nbits = nbits
        
        # Create FAISS LSH index
        self.index = faiss.IndexLSH(dim, nbits)
        
    def fit(self, X_train):
        """Build LSH index"""
        print(f"  Building FAISS LSH index: {X_train.shape[0]:,} points, {self.nbits} bits")
        
        # Ensure data is float32 (FAISS requirement)
        if X_train.dtype != np.float32:
            X_train = X_train.astype(np.float32)
        
        # Ensure C-contiguous (FAISS requirement)
        if not X_train.flags['C_CONTIGUOUS']:
            X_train = np.ascontiguousarray(X_train)
        
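        # Train: with IndexLSH's default settings there is nothing to learn, so this
        # call is cheap; kept so the wrapper follows the usual FAISS train/add pattern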
        start = time.time()
        self.index.train(X_train)
        train_time = time.time() - start
        
        # Add vectors
        start = time.time()
        self.index.add(X_train)
        add_time = time.time() - start
        
        print(f"    Train time: {train_time:.2f}s")
        print(f"    Add time: {add_time:.2f}s")
        print(f"    Total vectors: {self.index.ntotal:,}")
        
        return self
    
    def query(self, X_test, k):
        """
        Query for k nearest neighbors.
        
        Returns:
            indices: (n_queries, k) array
            distances: (n_queries, k) array
        """
        # Ensure data is float32 and C-contiguous
        if X_test.dtype != np.float32:
            X_test = X_test.astype(np.float32)
        if not X_test.flags['C_CONTIGUOUS']:
            X_test = np.ascontiguousarray(X_test)
        
        # Search
        distances, indices = self.index.search(X_test, k)
        
        return indices, distances
    
    def get_memory_usage(self):
        """Estimate memory usage in MB"""
        # IndexLSH stores one packed binary code per vector, rounded up to whole bytes
        n_vectors = self.index.ntotal
        bytes_per_vector = (self.nbits + 7) // 8
        bytes_total = n_vectors * bytes_per_vector
        
        # Rough 10% allowance for the projection matrix and bookkeeping
        # (IndexLSH keeps a flat array of codes, not a hash table)
        overhead = bytes_total * 0.1
        
        return (bytes_total + overhead) / (1024**2)


def build_faiss_lsh_index(X_train, nbits):
    """Wrapper for evaluator"""
    dim = X_train.shape[1]
    index = FAISSLSHIndex(dim, nbits)
    index.fit(X_train)
    return index


def query_faiss_lsh_index(index, X_test, k):
    """Wrapper for evaluator"""
    return index.query(X_test, k)


print("✓ FAISS LSH implementation ready!")
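
The two input checks in `fit` and `query` (float32 dtype, C-contiguous layout) can also be done in one call. The sketch below shows `np.ascontiguousarray` with a `dtype` argument on a toy array; it is an alternative idiom, not a change to the wrapper above.

```python
import numpy as np

# A float64, column-sliced array: wrong dtype and not C-contiguous.
X = np.arange(24, dtype=np.float64).reshape(4, 6)[:, ::2]

# Converts dtype and layout together; returns the input unchanged if it already conforms.
X_ready = np.ascontiguousarray(X, dtype=np.float32)

print(X.dtype, X.flags['C_CONTIGUOUS'])              # float64 False
print(X_ready.dtype, X_ready.flags['C_CONTIGUOUS'])  # float32 True
```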

## 3. Run LSH Experiments

In [None]:
# Store all results
all_results = {}

# Initialize dataset loader
loader = DatasetLoader(data_dir=DATA_DIR)

for dataset_name in DATASETS:
    print(f"\n{'='*70}")
    print(f"DATASET: {dataset_name.upper()}")
    print(f"{'='*70}")
    
    # Print dataset info
    print_dataset_info(dataset_name)
    
    # Load dataset
    try:
        X_train, X_test = loader.load_dataset(
            dataset_name,
            n_train=N_TRAIN,
            n_test=N_TEST
        )
    except FileNotFoundError as e:
        print(f"ERROR: Could not load {dataset_name} dataset.")
        print(f"  {e}")
        print(f"  Skipping this dataset...\n")
        continue
    
    d = X_train.shape[1]
    
    # Initialize evaluator
    evaluator = ANNEvaluator(X_train, X_test, k=K)
    
    # Compute ground truth
    evaluator.compute_ground_truth()
    
    # Store results for this dataset
    dataset_results = []
    
    # Sweep over hash bits
    for nbits in HASH_BITS_OPTIONS:
        print(f"\n{'='*60}")
        print(f"Evaluating: FAISS LSH (nbits={nbits})")
        print(f"{'='*60}")
        
        # Create index builder (bind nbits now so the closure cannot pick up a later loop value)
        def build_fn(X, nbits=nbits):
            return build_faiss_lsh_index(X, nbits)
        
        # Evaluate
        try:
            results = evaluator.evaluate(
                index_builder=build_fn,
                query_func=query_faiss_lsh_index,
                method_name=f"FAISS-LSH-{dataset_name.upper()}-{nbits}bits"
            )
            
            # Add metadata
            results['dataset'] = dataset_name
            results['d'] = d
            results['lsh_nbits'] = nbits
            results['library'] = 'FAISS'
            
            dataset_results.append(results)
            
            # Save individual result
            result_path = os.path.join(RESULTS_DIR, f"lsh_faiss_{dataset_name}_{nbits}bits.json")
            with open(result_path, 'w') as f:
                json.dump(results, f, indent=2)
            
            print(f"\n✓ Results saved to {result_path}")
            
        except Exception as e:
            print(f"ERROR during evaluation: {e}")
            import traceback
            traceback.print_exc()
            continue
    
    # Store results for this dataset
    all_results[dataset_name] = dataset_results

print(f"\n{'='*70}")
print("ALL LSH EXPERIMENTS COMPLETE")
print(f"{'='*70}\n")

## 4. Visualize Results

In [None]:
# Create per-dataset plots
for dataset_name, results_list in all_results.items():
    if not results_list:
        continue
    
    print(f"\nPlotting results for {dataset_name.upper()}...")
    
    # Extract data
    nbits_list = [r['lsh_nbits'] for r in results_list]
    recalls = [r['recall@k'] for r in results_list]
    query_times = [r['avg_query_time_ms'] for r in results_list]
    memories = [r['memory_mb'] for r in results_list]
    build_times = [r['build_time_s'] for r in results_list]
    d = results_list[0]['d']
    
    # Create subplots
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    fig.suptitle(f'FAISS LSH Results: {dataset_name.upper()} (d={d})', 
                 fontsize=16, fontweight='bold')
    
    # 1. Recall vs Hash Bits
    axes[0, 0].plot(nbits_list, recalls, marker='o', linewidth=2.5, markersize=10, color='steelblue')
    axes[0, 0].axhline(y=0.90, color='orange', linestyle='--', alpha=0.5, label='90% Target')
    axes[0, 0].set_xlabel('Number of Hash Bits', fontsize=11, fontweight='bold')
    axes[0, 0].set_ylabel(f'Recall@{K}', fontsize=11, fontweight='bold')
    axes[0, 0].set_title('Recall vs. Hash Bits', fontsize=12, fontweight='bold')
    axes[0, 0].set_xscale('log', base=2)
    axes[0, 0].grid(True, alpha=0.3)
    axes[0, 0].legend()
    
    # Annotate good configs
    for nb, r in zip(nbits_list, recalls):
        if r >= 0.90:
            axes[0, 0].annotate(f'{r:.2f}', (nb, r), textcoords="offset points", 
                               xytext=(0,5), ha='center', fontsize=9, color='green', fontweight='bold')
    
    # 2. Query Time vs Hash Bits
    axes[0, 1].plot(nbits_list, query_times, marker='s', linewidth=2.5, markersize=10, color='coral')
    axes[0, 1].set_xlabel('Number of Hash Bits', fontsize=11, fontweight='bold')
    axes[0, 1].set_ylabel('Query Time (ms)', fontsize=11, fontweight='bold')
    axes[0, 1].set_title('Query Time vs. Hash Bits', fontsize=12, fontweight='bold')
    axes[0, 1].set_xscale('log', base=2)
    axes[0, 1].grid(True, alpha=0.3)
    
    # 3. Memory Usage vs Hash Bits
    axes[1, 0].plot(nbits_list, memories, marker='^', linewidth=2.5, markersize=10, 
                   color='mediumseagreen')
    axes[1, 0].set_xlabel('Number of Hash Bits', fontsize=11, fontweight='bold')
    axes[1, 0].set_ylabel('Memory (MB)', fontsize=11, fontweight='bold')
    axes[1, 0].set_title('Memory Usage vs. Hash Bits', fontsize=12, fontweight='bold')
    axes[1, 0].set_xscale('log', base=2)
    axes[1, 0].grid(True, alpha=0.3)
    
    # 4. Recall-Query Time Pareto
    scatter = axes[1, 1].scatter(query_times, recalls, s=150, c=nbits_list, 
                                cmap='viridis', alpha=0.7, edgecolors='black', linewidth=2,
                                norm=plt.matplotlib.colors.LogNorm())
    axes[1, 1].plot(query_times, recalls, 'k--', alpha=0.3, linewidth=1)
    
    # Annotate
    for i, nb in enumerate(nbits_list):
        axes[1, 1].annotate(f'{nb}', (query_times[i], recalls[i]), 
                           textcoords="offset points", xytext=(5,5), ha='left', fontsize=9)
    
    axes[1, 1].set_xlabel('Query Time (ms)', fontsize=11, fontweight='bold')
    axes[1, 1].set_ylabel(f'Recall@{K}', fontsize=11, fontweight='bold')
    axes[1, 1].set_title('Pareto Frontier', fontsize=12, fontweight='bold')
    axes[1, 1].grid(True, alpha=0.3)
    cbar = plt.colorbar(scatter, ax=axes[1, 1])
    cbar.set_label('Hash Bits', fontsize=10)
    
    plt.tight_layout()
    plot_path = os.path.join(RESULTS_DIR, f'lsh_faiss_{dataset_name}_tradeoffs.png')
    plt.savefig(plot_path, dpi=150, bbox_inches='tight')
    plt.show()
    
    print(f"✓ Plot saved to {plot_path}")
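
The fourth panel plots every configuration. To pick out only the non-dominated ones (no other configuration is both faster and more accurate), a small helper like the one below could be used. The function name `pareto_front` and the numbers are made up for illustration; they are not measured results.

```python
def pareto_front(points):
    """Return the points not dominated by any other.

    points: list of (query_time_ms, recall, label) tuples, where lower
    time and higher recall are better.
    """
    front = []
    best_recall = float("-inf")
    # Sort by time ascending; on ties, put the higher recall first.
    for t, r, label in sorted(points, key=lambda p: (p[0], -p[1])):
        if r > best_recall:
            front.append((t, r, label))
            best_recall = r
    return front

# Made-up (query_time_ms, recall, nbits) values for illustration only.
configs = [(0.5, 0.40, 64), (0.9, 0.62, 128), (1.2, 0.60, 256), (2.0, 0.85, 512)]
front = pareto_front(configs)
print([label for _, _, label in front])  # [64, 128, 512]
```

In this toy example the 256-bit point is dropped because the 128-bit point is both faster and has higher recall.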

## 5. Summary Statistics

In [None]:
# Flatten all results
all_results_flat = []
for dataset_name, results_list in all_results.items():
    all_results_flat.extend(results_list)

if all_results_flat:
    print("\n" + "="*100)
    print("SUMMARY: FAISS LSH")
    print("="*100)
    
    # Header (recall column follows the configured K)
    recall_header = f"Recall@{K}"
    print(f"{'Dataset':<10} {'Hash Bits':<12} {recall_header:<12} {'Query(ms)':<12} "
          f"{'Build(s)':<10} {'Memory(MB)':<12}")
    print("-" * 100)
    
    # Sort by dataset, then nbits
    sorted_results = sorted(all_results_flat, key=lambda x: (x['dataset'], x['lsh_nbits']))
    
    for r in sorted_results:
        marker = "✓" if r['recall@k'] >= 0.90 else " "
        
        print(f"{marker} {r['dataset'].upper():<9} {r['lsh_nbits']:<12} "
              f"{r['recall@k']:<12.4f} {r['avg_query_time_ms']:<12.3f} "
              f"{r['build_time_s']:<10.2f} {r['memory_mb']:<12.2f}")
    
    # Save summary
    summary = {
        'experiment': 'LSH (FAISS)',
        'k': K,
        'n_test': N_TEST,
        'library': 'faiss-cpu',
        'index_type': 'IndexLSH',
        'configurations_tested': len(all_results_flat),
        'results_by_dataset': all_results
    }
    
    summary_path = os.path.join(RESULTS_DIR, 'lsh_faiss_summary.json')
    with open(summary_path, 'w') as f:
        json.dump(summary, f, indent=2)
    
    print(f"\n✓ Summary saved to {summary_path}")

## 6. Best Configurations

In [None]:
if all_results:
    print("\n" + "="*80)
    print("BEST CONFIGURATIONS")
    print("="*80 + "\n")
    
    for dataset_name, results_list in all_results.items():
        if not results_list:
            continue
        
        print(f"\n{dataset_name.upper()}:")
        print("-" * 60)
        
        # Best recall
        best_recall = max(results_list, key=lambda x: x['recall@k'])
        print(f"\n  Best Recall:")
        print(f"    Hash bits: {best_recall['lsh_nbits']}")
        print(f"    Recall@{K}: {best_recall['recall@k']:.4f}")
        print(f"    Query time: {best_recall['avg_query_time_ms']:.3f} ms")
        
        # Fastest with decent recall
        decent_recall = [r for r in results_list if r['recall@k'] >= 0.80]
        if decent_recall:
            fastest = min(decent_recall, key=lambda x: x['avg_query_time_ms'])
            print(f"\n  Fastest (Recall ≥ 0.80):")
            print(f"    Hash bits: {fastest['lsh_nbits']}")
            print(f"    Recall@{K}: {fastest['recall@k']:.4f}")
            print(f"    Query time: {fastest['avg_query_time_ms']:.3f} ms")
        
        # Most memory efficient with decent recall
        if decent_recall:
            mem_efficient = min(decent_recall, key=lambda x: x['memory_mb'])
            print(f"\n  Most Memory Efficient (Recall ≥ 0.80):")
            print(f"    Hash bits: {mem_efficient['lsh_nbits']}")
            print(f"    Recall@{K}: {mem_efficient['recall@k']:.4f}")
            print(f"    Memory: {mem_efficient['memory_mb']:.2f} MB")
        
        print()

## 7. Key Observations

**FAISS Advantages:**
- ✅ Fast C++ implementation
- ✅ Library designed to scale to billion-vector datasets
- ✅ Memory-efficient binary codes
- ✅ SIMD-optimized distance computations

**Expected Patterns:**
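- **More hash bits** → Better recall, but more memory and slower queries
- **Fewer hash bits** → Faster queries and less memory, but lower recall
- Build time should be short (it only computes one binary code per vector)
- Query time should grow linearly with the number of indexed vectors, because `IndexLSH` compares the query code against every stored code; the speedup over exact search comes from cheap Hamming comparisons on compact codes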
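
The memory side of this trade-off is simple arithmetic. The sketch below compares packed binary codes against raw float32 storage for an assumed collection of 1,000,000 vectors of dimension 128; these sizes are illustrative assumptions, not values read from the datasets in this notebook, and the helper names are hypothetical.

```python
def code_bytes(nbits):
    """Bytes needed to store one binary code of nbits bits, rounded up."""
    return (nbits + 7) // 8

def codes_mb(n_vectors, nbits):
    """Size in MB (1 MB = 1024**2 bytes) of n_vectors packed binary codes."""
    return n_vectors * code_bytes(nbits) / 1024**2

# Assumed example: 1,000,000 vectors of dimension 128 stored as float32.
n, dim = 1_000_000, 128
raw_mb = n * dim * 4 / 1024**2

for nbits in [64, 128, 256, 512, 1024]:
    mb = codes_mb(n, nbits)
    print(f"{nbits:>5} bits: {mb:8.2f} MB  ({raw_mb / mb:.1f}x smaller than raw float32)")
```

Code size grows linearly with `nbits`, so doubling the code length doubles both the memory and the work per Hamming comparison.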

**Comparison Points:**
- Baseline KNN: Recall=1.0 by definition, but slower queries
- JL+KNN: Good recall with dimensionality reduction
- FAISS LSH: Expected to give the fastest queries; whether recall is acceptable depends on `nbits` and the dataset

**Next Steps:**
1. Compare all methods (Baseline, JL, LSH)
2. Create unified comparison plots
3. Analyze which method works best for which dataset