# 🚀 ANMI 2.0: Adaptive Negative Mining Intelligence

## A Complete Implementation for Training Dense Retrievers with ELO-Calibrated Hard Negatives

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/)

---

### What This Notebook Covers

This notebook implements the **ANMI 2.0 (Adaptive Negative Mining Intelligence)** framework, which synthesizes:

1. **Contrastive Learning Theory** - InfoNCE loss with temperature scaling
2. **Hard Negative Mining** - Finding challenging negatives for better training
3. **ELO-Based Calibration** - Using pairwise comparisons to estimate document quality
4. **Hybrid Loss** - Combining contrastive and regression objectives

### The Core Problem: The Laffer Curve of Negative Mining

```
Performance
    │
    │           ┌──── Sweet Spot
    │          /│\
    │         / │ \        ← Laffer Curve
    │        /  │  \
    │       /   │   \
    │──────/────┼────\────────
    │  Random   │  Too Hard
    │  (boring) │  (false negatives!)
    └───────────┴──────────────→ Mining Difficulty
```

**Key Insight**: The hardest negatives provide the most gradient signal, BUT they're also most likely to be **false negatives** (actually relevant documents mislabeled as negative). ANMI 2.0 uses ELO scores to find the "Goldilocks zone" of difficulty.

---

**Author**: ANMI Research Team  
**Last Updated**: December 2024  
**Runtime**: ~20-30 minutes on Colab T4 GPU


## 1. Setup & Installation

First, let's install the required packages and check our GPU.


In [None]:
# ============================================================
# 📦 INSTALLATION
# ============================================================
# Install required packages (takes ~2 minutes)

%pip install -q sentence-transformers datasets transformers torch numpy scipy tqdm
%pip install -q rank_bm25  # For BM25 baseline

# Check GPU availability
import torch
print(f"🔥 PyTorch version: {torch.__version__}")
print(f"🎮 CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"   GPU: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    print("⚠️ No GPU detected! This notebook will be slow on CPU.")
    print("   Go to Runtime > Change runtime type > GPU")


In [None]:
# ============================================================
# 📚 IMPORTS
# ============================================================

import numpy as np
from scipy.stats import norm
from typing import List, Dict, Tuple, Optional
from dataclasses import dataclass
from collections import defaultdict
import random
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset as TorchDataset

from sentence_transformers import SentenceTransformer, InputExample, losses
from datasets import load_dataset
from rank_bm25 import BM25Okapi

# Set seeds for reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# Device configuration
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"✅ Using device: {DEVICE}")


## 2. Load Dataset: SciFact

We'll use the **SciFact** dataset from BEIR - it's small enough to run on Colab's free tier while still being challenging (scientific claims with high false negative rates ~20%).

| Dataset | Queries | Corpus | Domain | Why We Use It |
|---------|---------|--------|--------|---------------|
| SciFact | 300 (test) | 5,183 docs | Scientific | Small, challenging, high FN rate |


In [None]:
# ============================================================
# 📊 LOAD SCIFACT DATASET
# ============================================================

print("📥 Loading SciFact dataset from HuggingFace...")

# Load corpus (documents)
corpus_dataset = load_dataset("BeIR/scifact", "corpus", split="corpus")
print(f"   Corpus: {len(corpus_dataset)} documents")

# Load queries
queries_dataset = load_dataset("BeIR/scifact", "queries", split="queries")
print(f"   Queries: {len(queries_dataset)} queries")

# Load relevance judgments (qrels)
# Note: SciFact has train/test splits - we'll use a portion for training
qrels_dataset = load_dataset("BeIR/scifact-qrels", split="train")
print(f"   Relevance judgments: {len(qrels_dataset)} pairs")

# Build lookup dictionaries
corpus = {str(doc["_id"]): doc["text"] for doc in corpus_dataset}
queries = {str(q["_id"]): q["text"] for q in queries_dataset}

# Build positive mapping: query_id -> list of relevant doc_ids
positives = defaultdict(list)
for qrel in qrels_dataset:
    if qrel["score"] > 0:  # Only positive relevance
        positives[str(qrel["query-id"])].append(str(qrel["corpus-id"]))

# Filter to queries that have positives
query_ids = [qid for qid in positives.keys() if qid in queries]
print(f"\n✅ Loaded {len(query_ids)} queries with positive labels")
print(f"   Average positives per query: {np.mean([len(positives[q]) for q in query_ids]):.1f}")

# Show example
example_qid = query_ids[0]
print(f"\n📝 Example Query (ID: {example_qid}):")
print(f"   Query: {queries[example_qid][:100]}...")
print(f"   Positives: {len(positives[example_qid])} documents")


## 3. Core ANMI Components

### 3.1 The ELO Engine

The heart of ANMI is estimating **absolute quality scores** from **pairwise comparisons**.

**Why ELO?**
- Pairwise judgments ("Is doc A better than doc B for this query?") are more reliable than absolute judgments ("Is doc A relevant?")
- ELO/Thurstone models convert O(n) pairwise comparisons into absolute scores
- This lets us find the "Goldilocks zone" of difficulty

**The Math (Thurstone Model):**
```
P(doc_i > doc_j | query) = Φ((e_i - e_j) / (σ√2))
```
Where:
- `e_i` = ELO score of document i
- `Φ` = Standard normal CDF
- `σ` = Standard deviation of the per-document noise
- Higher ELO = more relevant to query
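
To make the formula concrete, here is a minimal sketch of the Thurstone win probability using only the standard library. The function name `thurstone_win_probability` and the default `sigma=1.0` are illustrative assumptions, not part of the notebook's classes.

```python
import math


def thurstone_win_probability(elo_i: float, elo_j: float, sigma: float = 1.0) -> float:
    """P(doc_i > doc_j | query) = Phi((e_i - e_j) / (sigma * sqrt(2)))."""
    z = (elo_i - elo_j) / (sigma * math.sqrt(2.0))
    # Standard normal CDF written in terms of the error function
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))


# Equal scores give a coin flip; a higher score wins more often than not
print(thurstone_win_probability(0.0, 0.0))
print(thurstone_win_probability(1.0, 0.0))
print(thurstone_win_probability(0.0, 1.0))
```

The two directed probabilities for a pair always sum to 1, which is why the estimator only needs one preference value per edge.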


In [None]:
# ============================================================
# 🎯 ELO ENGINE: Sparse ELO Estimation
# ============================================================

class SparseELOEstimator:
    """
    Estimates ELO scores from sparse pairwise comparisons using Thurstone model.
    
    Key Innovation: Instead of comparing all O(n²) pairs, we only compare O(n*k) pairs
    using a sparse graph in which every node has at most k neighbors. This is sound because:
    - A union of Hamiltonian cycles is always connected (ELO differences are well-defined)
    - Random regular graphs of degree >= 3 have O(log n) diameter with high probability,
      which limits how far estimation error propagates
    
    Args:
        comparison_degree: Number of comparisons per document (k in k-regular graph)
        max_iterations: Maximum gradient ascent iterations for MLE
        tolerance: Convergence threshold
    """
    
    def __init__(
        self,
        comparison_degree: int = 4,  # Lower for Colab speed
        max_iterations: int = 50,
        tolerance: float = 1e-3,
    ):
        self.k = comparison_degree
        self.max_iter = max_iterations
        self.tol = tolerance
    
    def _build_k_regular_graph(self, n: int) -> List[Tuple[int, int]]:
        """
        Build a sparse comparison graph by unioning k/2 random Hamiltonian cycles.
        
        A Hamiltonian cycle visits every node exactly once, forming a closed loop.
        Unioning k/2 such cycles gives every node at most k neighbors (exactly k
        when no two cycles share an edge, because shared edges are deduplicated).
        
        Example for n=5, k=4 (2 edge-disjoint cycles):
            Cycle 1: 0 → 3 → 1 → 4 → 2 → 0
            Cycle 2: 0 → 1 → 2 → 3 → 4 → 0
            Combined: Each node has 4 edges
        """
        edges = set()
        num_cycles = max(1, self.k // 2)
        
        for _ in range(num_cycles):
            # Generate random permutation (Hamiltonian cycle)
            perm = np.random.permutation(n).tolist()
            
            # Add cycle edges: perm[0]→perm[1]→...→perm[n-1]→perm[0]
            for i in range(n):
                edge = tuple(sorted([perm[i], perm[(i + 1) % n]]))
                edges.add(edge)
        
        return list(edges)
    
    def _fit_thurstone(
        self,
        preferences: Dict[Tuple[int, int], float],
        edges: List[Tuple[int, int]],
        n: int,
    ) -> np.ndarray:
        """
        Fit Thurstone model via gradient ascent on log-likelihood.
        
        The Thurstone model assumes observed preferences come from comparing
        noisy estimates of true quality. If doc i has quality e_i and noise ε_i ~ N(0,σ²):
            P(i > j) = P(e_i + ε_i > e_j + ε_j) = Φ((e_i - e_j) / (σ√2))
        
        We maximize log-likelihood (with σ√2 fixed to 1):
            ℓ(e) = Σ [w_ij * log(Φ(e_i - e_j)) + (1-w_ij) * log(Φ(e_j - e_i))]
        """
        e = np.zeros(n)  # Initialize ELO scores to zero
        
        for iteration in range(self.max_iter):
            grad = np.zeros(n)
            
            for (i, j) in edges:
                w_ij = preferences.get((i, j), 0.5)
                delta = e[i] - e[j]
                
                # Compute gradient using inverse Mills ratio
                # The inverse Mills ratio λ(x) = φ(x)/Φ(x) appears naturally
                # in the gradient of the Thurstone log-likelihood
                phi_delta = norm.pdf(delta)
                Phi_delta = norm.cdf(delta)
                Phi_neg_delta = 1 - Phi_delta
                
                # Avoid division by zero
                lambda_pos = phi_delta / max(Phi_delta, 1e-10)
                lambda_neg = phi_delta / max(Phi_neg_delta, 1e-10)
                
                grad[i] += w_ij * lambda_pos - (1 - w_ij) * lambda_neg
                grad[j] += -w_ij * lambda_pos + (1 - w_ij) * lambda_neg
            
            # Project onto constraint manifold (mean = 0 for identifiability)
            grad = grad - grad.mean()
            
            # Check convergence
            if np.abs(grad).max() < self.tol:
                break
            
            # Decaying step size for stability
            eta = 1.0 / (1 + 0.1 * iteration)
            
            # Update and re-center
            e = e + eta * grad
            e = e - e.mean()
        
        # Scale to interpretable ELO range (centered at 1000, ~200 points = significant diff)
        e = e * 200 + 1000
        
        return e
    
    def estimate(
        self,
        doc_scores: np.ndarray,  # Scores from a ranker (proxy for pairwise prefs)
    ) -> np.ndarray:
        """
        Estimate ELO scores for documents.
        
        In a full implementation, we'd use a pairwise model (cross-encoder).
        For this demo, we convert point-wise scores to pairwise preferences.
        
        Args:
            doc_scores: Array of scores for each document (higher = more relevant)
            
        Returns:
            Array of ELO scores
        """
        n = len(doc_scores)
        
        # Build comparison graph
        edges = self._build_k_regular_graph(n)
        
        # Convert point-wise scores to pairwise preferences using Bradley-Terry
        # P(i > j) = σ(s_i - s_j) where σ is sigmoid
        preferences = {}
        for (i, j) in edges:
            score_diff = doc_scores[i] - doc_scores[j]
            # Sigmoid with temperature for smoother preferences
            preferences[(i, j)] = 1 / (1 + np.exp(-score_diff * 5))
        
        # Fit Thurstone model
        elos = self._fit_thurstone(preferences, edges, n)
        
        return elos

# Test the ELO estimator
print("🧪 Testing ELO Estimator...")
test_scores = np.array([0.9, 0.7, 0.5, 0.3, 0.1])  # 5 docs with decreasing scores
estimator = SparseELOEstimator(comparison_degree=4)
test_elos = estimator.estimate(test_scores)
print(f"   Input scores:  {test_scores}")
print(f"   Output ELOs:   {test_elos.round(0)}")
print(f"   ✅ ELO order matches score order: {np.all(np.diff(test_elos) < 0)}")
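
The sparse comparison graph can also be inspected on its own. The sketch below rebuilds the union-of-cycles construction with the standard library and compares its edge count with the all-pairs count. `union_of_cycles` and `degrees` are hypothetical helper names, and `n = 50` is an arbitrary example size. Because overlapping edges are deduplicated, a node's degree can fall below k.

```python
import random
from collections import Counter


def union_of_cycles(n: int, k: int, seed: int = 0) -> set:
    """Union of k // 2 random Hamiltonian cycles over nodes 0..n-1 (undirected edges)."""
    rng = random.Random(seed)
    edges = set()
    for _ in range(max(1, k // 2)):
        perm = list(range(n))
        rng.shuffle(perm)
        for i in range(n):
            a, b = perm[i], perm[(i + 1) % n]
            edges.add((min(a, b), max(a, b)))  # store each edge once
    return edges


def degrees(edges: set) -> Counter:
    """Count how many distinct neighbors each node has."""
    deg = Counter()
    for a, b in edges:
        deg[a] += 1
        deg[b] += 1
    return deg


n, k = 50, 4
edges = union_of_cycles(n, k)
deg = degrees(edges)
full_pairs = n * (n - 1) // 2

print(f"sparse comparisons: {len(edges)} (upper bound {n * k // 2}), all pairs: {full_pairs}")
print(f"min degree: {min(deg.values())}, max degree: {max(deg.values())}")
```

Each individual cycle already connects all nodes, so the union is connected regardless of how many edges the cycles happen to share.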


### 3.2 ELO-Gap Based Selection

Once we have ELO scores, we select negatives based on the **gap** from the positive document.

**The Goldilocks Principle**: Negatives should be hard enough to provide learning signal, but not so hard they're likely false negatives.


In [None]:
# ============================================================
# 🎚️ ELO-GAP BASED SELECTOR
# ============================================================

class ELOGapSelector:
    """
    Selects and weights negatives based on ELO gap from positive.
    
    The key insight is that "difficulty" varies per query. Rank 20 for one
    query might be harder than rank 5 for another. ELO gaps normalize this:
    - Gap < 100: Too close to positive → likely false negative → REJECT
    - Gap 100-200: Borderline → include with reduced weight
    - Gap 200-400: Goldilocks zone → optimal learning signal
    - Gap 400-600: Medium → good for early curriculum stages
    - Gap > 600: Easy → low learning signal
    """
    
    def __init__(
        self,
        danger_zone: float = 100,
        goldilocks_zone: Tuple[float, float] = (200, 400),
    ):
        self.danger_zone = danger_zone
        self.goldilocks = goldilocks_zone
    
    def select_and_weight(
        self,
        positive_elo: float,
        candidate_elos: List[Tuple[int, float]],  # [(idx, elo), ...]
        num_negatives: int = 10,
        curriculum_tier: int = 4,  # 1=easy only, 4=all difficulty levels
    ) -> List[Tuple[int, float]]:
        """
        Select negatives and assign weights based on ELO gap.
        
        Args:
            positive_elo: ELO score of the positive document
            candidate_elos: List of (doc_idx, elo_score) for candidates
            num_negatives: Number of negatives to select
            curriculum_tier: Current training phase (1-4)
            
        Returns:
            List of (doc_idx, weight) tuples
        """
        weighted_candidates = []
        
        for idx, elo in candidate_elos:
            gap = positive_elo - elo  # Higher positive ELO = larger gap
            
            # Assign weight based on gap category
            if gap < self.danger_zone:
                # Danger zone: too likely to be false negative
                weight = 0.0
                tier_required = 99  # Never include
            elif gap < self.goldilocks[0]:
                # Soft negative zone (between the danger zone and the Goldilocks lower bound)
                weight = 0.5
                tier_required = 4  # Only in final curriculum stage
            elif gap < self.goldilocks[1]:
                # Goldilocks zone: optimal difficulty
                weight = 1.0
                tier_required = 3
            elif gap < 600:
                # Medium difficulty
                weight = 0.7
                tier_required = 2
            else:
                # Easy negatives
                weight = 0.3
                tier_required = 1  # Always include
            
            # Apply curriculum filter
            if curriculum_tier >= tier_required and weight > 0:
                weighted_candidates.append((idx, weight, gap))
        
        # Sort by preference: Goldilocks zone first, then by gap
        weighted_candidates.sort(key=lambda x: (
            0 if self.goldilocks[0] <= x[2] < self.goldilocks[1] else 1,
            x[2]
        ))
        
        # Return top-k
        return [(idx, weight) for idx, weight, _ in weighted_candidates[:num_negatives]]

# Test the selector
print("🧪 Testing ELO Gap Selector...")
selector = ELOGapSelector()

# Simulate: positive has ELO 1200, candidates have various ELOs
test_positive_elo = 1200
test_candidates = [
    (0, 1150),  # Gap 50 → Danger zone
    (1, 1050),  # Gap 150 → Soft
    (2, 900),   # Gap 300 → Goldilocks ✓
    (3, 850),   # Gap 350 → Goldilocks ✓
    (4, 700),   # Gap 500 → Medium
    (5, 400),   # Gap 800 → Easy
]

selected = selector.select_and_weight(test_positive_elo, test_candidates, num_negatives=4)
print(f"   Positive ELO: {test_positive_elo}")
print(f"   Selected negatives:")
for idx, weight in selected:
    gap = test_positive_elo - test_candidates[idx][1]
    print(f"      Doc {idx}: ELO={test_candidates[idx][1]}, Gap={gap}, Weight={weight}")


### 3.3 Hybrid Loss Function

The ANMI 2.0 loss combines two objectives:

1. **Weighted InfoNCE** - Contrastive loss with soft negative weights
2. **ELO MSE** - Regression to match predicted scores with ELO targets

**Why Hybrid?**
- InfoNCE shapes the embedding geometry (alignment + uniformity)
- MSE provides calibration and reduces false negative damage
- When a false negative has high ELO, MSE gradient opposes InfoNCE gradient → damage mitigation!
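
As an illustration of the soft weighting, the following sketch evaluates the weighted InfoNCE term for a single query in plain Python. The function `weighted_info_nce` and the similarity values are made up for the example. Weights enter as `log(w_i)` added to each negative's logit, which is algebraically the same as multiplying `exp(s_i / tau)` by `w_i`.

```python
import math


def weighted_info_nce(pos_sim, neg_sims, weights, tau=0.07):
    """-log(exp(s+/tau) / (exp(s+/tau) + sum_i w_i * exp(s_i/tau))), computed stably."""
    # Fold each weight into its logit as log(w_i); zero-weight negatives are dropped
    logits = [pos_sim / tau]
    logits += [s / tau + math.log(w) for s, w in zip(neg_sims, weights) if w > 0]
    m = max(logits)  # subtract the max before exponentiating to avoid overflow
    log_denominator = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_denominator - pos_sim / tau


pos, negs = 0.8, [0.75, 0.4, 0.1]  # the first negative is suspiciously close to the positive
hard_as_true_negative = weighted_info_nce(pos, negs, [1.0, 1.0, 1.0])
hard_down_weighted = weighted_info_nce(pos, negs, [0.5, 1.0, 1.0])
hard_excluded = weighted_info_nce(pos, negs, [0.0, 1.0, 1.0])

print(hard_as_true_negative, hard_down_weighted, hard_excluded)
```

Lowering the weight of the hardest negative lowers the loss it contributes, and a weight of 0 gives the same value as leaving that negative out entirely.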


In [None]:
# ============================================================
# 🔥 HYBRID LOSS FUNCTION
# ============================================================

class HybridLoss(nn.Module):
    """
    Hybrid loss combining Weighted InfoNCE and ELO MSE.
    
    L = α * L_InfoNCE_weighted + (1-α) * L_MSE
    
    CRITICAL IMPLEMENTATION NOTES:
    1. We use a learnable projection head (elo_head) to map dot products to ELO space.
       This avoids gradient scale mismatch between log-scale InfoNCE and squared MSE.
    
    2. Negative weights are applied INSIDE the softmax denominator, allowing
       "soft" exclusion of uncertain negatives rather than hard binary decisions.
    
    Args:
        alpha: Mixing coefficient (0=pure MSE, 1=pure InfoNCE)
        temperature: Softmax temperature for InfoNCE
    """
    
    def __init__(
        self,
        alpha: float = 0.6,
        temperature: float = 0.07,
    ):
        super().__init__()
        self.alpha = alpha
        self.tau = temperature
        
        # Learnable projection from dot products to ELO space
        # This is CRITICAL for stable training - avoids gradient mismatch
        self.elo_head = nn.Sequential(
            nn.Linear(1, 16),
            nn.ReLU(),
            nn.Linear(16, 1),
        )
    
    def forward(
        self,
        query_emb: torch.Tensor,       # [batch_size, hidden_dim]
        positive_emb: torch.Tensor,    # [batch_size, hidden_dim]
        negative_embs: torch.Tensor,   # [batch_size, num_neg, hidden_dim]
        negative_weights: torch.Tensor, # [batch_size, num_neg]
        elo_targets: torch.Tensor,     # [batch_size, 1+num_neg] - target ELO scores
    ) -> Tuple[torch.Tensor, Dict]:
        """
        Compute hybrid loss.
        
        Returns:
            total_loss: Scalar loss tensor
            metrics: Dict with component losses for logging
        """
        batch_size = query_emb.size(0)
        
        # === Compute Similarities ===
        # Positive: [batch_size]
        pos_sim = torch.sum(query_emb * positive_emb, dim=-1)
        
        # Negative: [batch_size, num_neg]
        neg_sim = torch.bmm(
            negative_embs,
            query_emb.unsqueeze(-1)
        ).squeeze(-1)
        
        # === Weighted InfoNCE Loss ===
        # Scale by temperature
        pos_sim_scaled = pos_sim / self.tau
        neg_sim_scaled = neg_sim / self.tau
        
        # Apply weights to negatives in the denominator by adding log(w_i) to each
        # negative logit. Weight=0 becomes -inf, which removes that negative from the softmax.
        log_weights = torch.log(negative_weights.clamp_min(1e-10))
        log_weights = log_weights.masked_fill(negative_weights <= 0, float("-inf"))
        
        # InfoNCE: -log(exp(pos) / (exp(pos) + Σ w_i * exp(neg_i)))
        # logsumexp avoids the float32 overflow that exp(sim / tau) hits for large logits
        logits = torch.cat([pos_sim_scaled.unsqueeze(-1), neg_sim_scaled + log_weights], dim=-1)
        loss_nce = torch.logsumexp(logits, dim=-1) - pos_sim_scaled
        loss_nce = loss_nce.mean()
        
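        # === MSE Loss with Learnable Projection ===
        # Concatenate all similarities: [batch_size, 1+num_neg]
        all_sims = torch.cat([pos_sim.unsqueeze(-1), neg_sim], dim=-1)
        
        # Project to ELO space via learnable head
        pred_elo = self.elo_head(all_sims.unsqueeze(-1)).squeeze(-1)
        
        # Targets arrive on the estimator's scale (e * 200 + 1000). Undo that scaling
        # so the squared error stays on a magnitude comparable to InfoNCE.
        norm_targets = (elo_targets - 1000.0) / 200.0
        
        # MSE loss
        loss_mse = F.mse_loss(pred_elo, norm_targets)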
        
        # === Combine ===
        total_loss = self.alpha * loss_nce + (1 - self.alpha) * loss_mse
        
        metrics = {
            "loss": total_loss.item(),
            "nce": loss_nce.item(),
            "mse": loss_mse.item(),
        }
        
        return total_loss, metrics

# Test the hybrid loss
print("🧪 Testing Hybrid Loss...")
loss_fn = HybridLoss(alpha=0.6, temperature=0.07)

# Dummy inputs
batch_size, hidden_dim, num_neg = 4, 64, 5
dummy_query = torch.randn(batch_size, hidden_dim)
dummy_pos = torch.randn(batch_size, hidden_dim)
dummy_neg = torch.randn(batch_size, num_neg, hidden_dim)
dummy_weights = torch.tensor([[1.0, 0.7, 0.5, 0.3, 0.0]] * batch_size)
dummy_elos = torch.randn(batch_size, 1 + num_neg) * 200 + 1000

loss, metrics = loss_fn(dummy_query, dummy_pos, dummy_neg, dummy_weights, dummy_elos)
print(f"   Total loss: {metrics['loss']:.4f}")
print(f"   InfoNCE:    {metrics['nce']:.4f}")
print(f"   MSE:        {metrics['mse']:.4f}")
print("   ✅ Loss computation successful!")


## 4. Mining Pipeline

Now we'll put it all together: retrieve candidates, estimate ELOs, select negatives.


In [None]:
# ============================================================
# ⛏️ COMPLETE MINING PIPELINE
# ============================================================

class ANMIMiner:
    """
    Complete ANMI mining pipeline:
    1. BM25 retrieval for initial candidates
    2. Dense scoring for ELO estimation  
    3. ELO-gap based selection with soft weights
    
    This runs OFFLINE before training (as per the production fix).
    """
    
    def __init__(
        self,
        corpus: Dict[str, str],
        encoder_model: str = "sentence-transformers/all-MiniLM-L6-v2",
        num_candidates: int = 50,  # BM25 top-k
        num_negatives: int = 7,    # Final negatives per query
    ):
        self.corpus = corpus
        self.corpus_ids = list(corpus.keys())
        self.corpus_texts = list(corpus.values())
        self.num_candidates = num_candidates
        self.num_negatives = num_negatives
        
        print("🔧 Initializing ANMI Miner...")
        
        # Initialize BM25 index
        print("   Building BM25 index...")
        tokenized_corpus = [doc.lower().split() for doc in self.corpus_texts]
        self.bm25 = BM25Okapi(tokenized_corpus)
        
        # Initialize encoder for scoring
        print(f"   Loading encoder: {encoder_model}")
        self.encoder = SentenceTransformer(encoder_model, device=DEVICE)
        
        # Pre-encode corpus (this takes a minute)
        print("   Encoding corpus (this may take a minute)...")
        self.corpus_embeddings = self.encoder.encode(
            self.corpus_texts,
            convert_to_tensor=True,
            show_progress_bar=True,
            batch_size=64,
        )
        
        # Initialize components
        self.elo_estimator = SparseELOEstimator(comparison_degree=4)
        self.selector = ELOGapSelector()
        
        print("   ✅ Miner ready!")
    
    def mine_for_query(
        self,
        query: str,
        positive_ids: List[str],
        curriculum_tier: int = 4,
    ) -> Optional[Dict]:
        """
        Mine negatives for a single query.
        
        Returns dict with:
            - query: str
            - positive: str (first positive doc)
            - positive_elo: float
            - negatives: List[str]
            - negative_weights: List[float]
            - negative_elos: List[float]
        """
        # Step 1: BM25 retrieval for candidates
        query_tokens = query.lower().split()
        bm25_scores = self.bm25.get_scores(query_tokens)
        top_indices = np.argsort(bm25_scores)[::-1][:self.num_candidates]
        
        candidate_ids = [self.corpus_ids[i] for i in top_indices]
        candidate_texts = [self.corpus_texts[i] for i in top_indices]
        
        # Remove positives from candidates, keeping each candidate's corpus index
        positive_set = set(positive_ids)
        filtered = [(cid, ctxt, int(ci)) for cid, ctxt, ci in zip(candidate_ids, candidate_texts, top_indices)
                    if cid not in positive_set]
        
        if len(filtered) < 3:
            # Not enough negatives, return None
            return None
        
        candidate_ids = [x[0] for x in filtered]
        candidate_texts = [x[1] for x in filtered]
        corpus_indices = [x[2] for x in filtered]
        
        # Step 2: Dense scoring for ELO estimation
        query_emb = self.encoder.encode(query, convert_to_tensor=True)
        # Index the pre-computed embeddings directly instead of an O(n) list lookup per candidate
        candidate_embs = self.corpus_embeddings[corpus_indices]
        
        # Compute similarities
        sims = torch.nn.functional.cosine_similarity(
            query_emb.unsqueeze(0), candidate_embs
        ).cpu().numpy()
        
        # Get positive embedding and similarity
        pos_id = positive_ids[0]
        pos_text = self.corpus[pos_id]
        pos_idx = self.corpus_ids.index(pos_id)
        pos_sim = torch.nn.functional.cosine_similarity(
            query_emb.unsqueeze(0), 
            self.corpus_embeddings[pos_idx].unsqueeze(0)
        ).item()
        
        # Step 3: ELO estimation
        all_sims = np.concatenate([[pos_sim], sims])
        elos = self.elo_estimator.estimate(all_sims)
        
        pos_elo = elos[0]
        candidate_elos = [(i, elos[i+1]) for i in range(len(candidate_ids))]
        
        # Step 4: ELO-gap selection
        selected = self.selector.select_and_weight(
            positive_elo=pos_elo,
            candidate_elos=candidate_elos,
            num_negatives=self.num_negatives,
            curriculum_tier=curriculum_tier,
        )
        
        if len(selected) == 0:
            return None
        
        # Build result
        neg_texts = [candidate_texts[idx] for idx, _ in selected]
        neg_weights = [weight for _, weight in selected]
        neg_elos = [elos[idx + 1] for idx, _ in selected]
        
        return {
            "query": query,
            "positive": pos_text,
            "positive_elo": pos_elo,
            "negatives": neg_texts,
            "negative_weights": neg_weights,
            "negative_elos": neg_elos,
        }
    
    def mine_dataset(
        self,
        queries: Dict[str, str],
        positives: Dict[str, List[str]],
        query_ids: List[str],
        curriculum_tier: int = 4,
    ) -> List[Dict]:
        """
        Mine negatives for all queries.
        """
        examples = []
        
        for qid in tqdm(query_ids, desc=f"Mining (tier={curriculum_tier})"):
            if qid not in queries or qid not in positives:
                continue
            
            result = self.mine_for_query(
                query=queries[qid],
                positive_ids=positives[qid],
                curriculum_tier=curriculum_tier,
            )
            
            if result is not None:
                examples.append(result)
        
        return examples

# Initialize the miner
print("🚀 Initializing ANMI Miner...")
miner = ANMIMiner(
    corpus=corpus,
    encoder_model="sentence-transformers/all-MiniLM-L6-v2",
    num_candidates=50,
    num_negatives=7,
)


In [None]:
# ============================================================
# 📦 OFFLINE MINING (Run ONCE before training)
# ============================================================
# This is the expensive step - we pre-compute everything so training is fast

print("⛏️ Running OFFLINE mining...")
print("   This pre-computes negatives, ELOs, and weights.")
print("   Training will then be fast (just forward/backward passes).\n")

# Mine once at curriculum tier 4, which admits every difficulty level
mined_data = miner.mine_dataset(
    queries=queries,
    positives=positives,
    query_ids=query_ids[:100],  # Use subset for Colab speed
    curriculum_tier=4,  # All difficulty levels
)

print(f"\n✅ Mined {len(mined_data)} training examples")

# Show statistics
avg_weight = np.mean([np.mean(ex["negative_weights"]) for ex in mined_data])
avg_neg = np.mean([len(ex["negatives"]) for ex in mined_data])
print(f"   Average negatives per query: {avg_neg:.1f}")
print(f"   Average negative weight: {avg_weight:.2f}")

# Show one example
print(f"\n📝 Example mined data:")
ex = mined_data[0]
print(f"   Query: {ex['query'][:80]}...")
print(f"   Positive ELO: {ex['positive_elo']:.0f}")
print(f"   Negatives: {len(ex['negatives'])}")
for i, (w, e) in enumerate(zip(ex["negative_weights"][:3], ex["negative_elos"][:3])):
    gap = ex["positive_elo"] - e
    print(f"      Neg {i}: ELO={e:.0f}, Gap={gap:.0f}, Weight={w:.1f}")


## 5. Training

Now we train the model using our pre-mined data with the hybrid loss.
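
The training function keeps α fixed, although its docstring notes that a curriculum could progressively increase it. A minimal sketch of one such schedule follows. It is purely hypothetical: the function `linear_alpha_schedule` and the endpoints 0.4 and 0.8 are assumptions chosen for illustration, not values from this notebook.

```python
def linear_alpha_schedule(epoch: int, num_epochs: int,
                          alpha_start: float = 0.4, alpha_end: float = 0.8) -> float:
    """Linearly interpolate alpha from alpha_start (first epoch) to alpha_end (last epoch)."""
    if num_epochs <= 1:
        return alpha_end  # a single epoch has nothing to interpolate over
    fraction = epoch / (num_epochs - 1)
    return alpha_start + fraction * (alpha_end - alpha_start)


# Hypothetical usage: recompute alpha at the start of every epoch
for epoch in range(3):
    alpha = linear_alpha_schedule(epoch, num_epochs=3)
    print(f"epoch {epoch}: alpha = {alpha:.2f}")
```

Starting with a lower α leans on the MSE term early and shifts weight toward InfoNCE later; whether that ordering helps would need to be checked empirically.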


In [None]:
# ============================================================
# 📊 DATASET CLASS
# ============================================================

class ANMIDataset(TorchDataset):
    """
    PyTorch Dataset for ANMI training.
    
    Handles padding negatives to fixed size and converting to tensors.
    """
    
    def __init__(self, mined_data: List[Dict], encoder: SentenceTransformer, num_negatives: int = 7):
        self.data = mined_data
        self.encoder = encoder
        self.num_negatives = num_negatives
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        ex = self.data[idx]
        
        # Pad negatives to fixed size if needed
        negatives = ex["negatives"][:self.num_negatives]
        weights = ex["negative_weights"][:self.num_negatives]
        elos = ex["negative_elos"][:self.num_negatives]
        
        # Pad if not enough negatives
        while len(negatives) < self.num_negatives:
            negatives.append("")  # Empty string for padding
            weights.append(0.0)   # Zero weight (ignored in loss)
            elos.append(500.0)    # Placeholder ELO
        
        return {
            "query": ex["query"],
            "positive": ex["positive"],
            "positive_elo": ex["positive_elo"],
            "negatives": negatives,
            "negative_weights": torch.tensor(weights, dtype=torch.float32),
            "negative_elos": torch.tensor(elos, dtype=torch.float32),
        }


def collate_fn(batch):
    """Custom collate function for batching."""
    return {
        "queries": [ex["query"] for ex in batch],
        "positives": [ex["positive"] for ex in batch],
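        # Cast to float32: the mined ELOs are NumPy float64 and would otherwise upcast the targets
        "positive_elos": torch.tensor([float(ex["positive_elo"]) for ex in batch], dtype=torch.float32),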
        "negatives": [ex["negatives"] for ex in batch],  # List of lists
        "negative_weights": torch.stack([ex["negative_weights"] for ex in batch]),
        "negative_elos": torch.stack([ex["negative_elos"] for ex in batch]),
    }

# Create dataset
dataset = ANMIDataset(mined_data, miner.encoder, num_negatives=7)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)

print(f"✅ Created dataset with {len(dataset)} examples")
print(f"   Batch size: 8")
print(f"   Batches per epoch: {len(dataloader)}")


In [None]:
# ============================================================
# 🏋️ TRAINING LOOP
# ============================================================

def train_anmi(
    encoder: SentenceTransformer,
    dataloader: DataLoader,
    num_epochs: int = 3,
    learning_rate: float = 2e-5,
    alpha: float = 0.6,
):
    """
    Train encoder with ANMI hybrid loss.
    
    Key features:
    - Pre-mined negatives with ELO-calibrated weights (OFFLINE)
    - Hybrid loss: α * InfoNCE + (1-α) * MSE
    - Curriculum: could progressively increase alpha
    """
    # Initialize loss and optimizer
    loss_fn = HybridLoss(alpha=alpha, temperature=0.07).to(DEVICE)
    
    # Combine encoder and loss head parameters
    all_params = list(encoder.parameters()) + list(loss_fn.parameters())
    optimizer = torch.optim.AdamW(all_params, lr=learning_rate)
    
    encoder.train()
    history = []
    
    print(f"🏋️ Training for {num_epochs} epochs...")
    print(f"   Learning rate: {learning_rate}")
    print(f"   Alpha (NCE weight): {alpha}")
    print()
    
    for epoch in range(num_epochs):
        epoch_losses = {"loss": [], "nce": [], "mse": []}
        
        pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")
        
        for batch in pbar:
            optimizer.zero_grad()
            
            # Encode queries, positives and negatives in one forward pass.
            # encoder.encode() runs under torch.no_grad() and switches the
            # model to eval mode, so it would block gradients to the encoder.
            batch_size = len(batch["queries"])
            num_neg = len(batch["negatives"][0])
            flat_negatives = [neg for negs in batch["negatives"] for neg in negs]

            texts = batch["queries"] + batch["positives"] + flat_negatives
            features = encoder.tokenize(texts)
            features = {
                k: v.to(DEVICE) if torch.is_tensor(v) else v
                for k, v in features.items()
            }
            embs = encoder(features)["sentence_embedding"]

            # Split back into queries, positives, and [batch, num_neg, dim] negatives
            query_embs = embs[:batch_size]
            pos_embs = embs[batch_size:2 * batch_size]
            neg_embs = embs[2 * batch_size:].reshape(batch_size, num_neg, -1)
            
            # Get weights and ELOs
            weights = batch["negative_weights"].to(DEVICE)
            pos_elos = batch["positive_elos"].to(DEVICE)
            neg_elos = batch["negative_elos"].to(DEVICE)
            
            # Combine ELOs: [batch_size, 1+num_neg]
            all_elos = torch.cat([pos_elos.unsqueeze(-1), neg_elos], dim=-1)
            
            # Compute loss
            loss, metrics = loss_fn(
                query_emb=query_embs,
                positive_emb=pos_embs,
                negative_embs=neg_embs,
                negative_weights=weights,
                elo_targets=all_elos,
            )
            
            # Backward
            loss.backward()
            torch.nn.utils.clip_grad_norm_(all_params, 1.0)
            optimizer.step()
            
            # Track
            epoch_losses["loss"].append(metrics["loss"])
            epoch_losses["nce"].append(metrics["nce"])
            epoch_losses["mse"].append(metrics["mse"])
            
            pbar.set_postfix({
                "loss": f"{metrics['loss']:.3f}",
                "nce": f"{metrics['nce']:.3f}",
                "mse": f"{metrics['mse']:.3f}",
            })
        
        # Epoch summary
        avg_loss = np.mean(epoch_losses["loss"])
        avg_nce = np.mean(epoch_losses["nce"])
        avg_mse = np.mean(epoch_losses["mse"])
        
        history.append({
            "epoch": epoch + 1,
            "loss": avg_loss,
            "nce": avg_nce,
            "mse": avg_mse,
        })
        
        print(f"   Epoch {epoch+1}: Loss={avg_loss:.4f}, NCE={avg_nce:.4f}, MSE={avg_mse:.4f}")
    
    return encoder, history

# Train the model
print("🚀 Starting ANMI Training...")
print("=" * 60)

trained_encoder, training_history = train_anmi(
    encoder=miner.encoder,
    dataloader=dataloader,
    num_epochs=3,
    learning_rate=2e-5,
    alpha=0.6,
)

print("\n✅ Training complete!")
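
The docstring of `train_anmi` notes that a curriculum could progressively increase alpha. A minimal sketch of one such schedule follows, assuming a linear ramp across epochs. The function name `linear_alpha_schedule` and the starting value 0.3 are hypothetical; only the end value 0.6 matches the `alpha` used above.

```python
def linear_alpha_schedule(
    epoch: int,
    num_epochs: int,
    alpha_start: float = 0.3,  # assumed starting weight, not from the notebook
    alpha_end: float = 0.6,    # matches the alpha used in train_anmi above
) -> float:
    """Linearly interpolate alpha from alpha_start (epoch 0) to alpha_end (last epoch)."""
    if num_epochs <= 1:
        return alpha_end
    frac = epoch / (num_epochs - 1)
    return alpha_start + frac * (alpha_end - alpha_start)


# Alpha per epoch for a 3-epoch run
schedule = [round(linear_alpha_schedule(e, 3), 2) for e in range(3)]
print(schedule)  # [0.3, 0.45, 0.6]
```

How the scheduled value reaches the loss depends on how `HybridLoss` stores `alpha`, which this sketch does not assume.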


## 6. Evaluation

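Let's evaluate the trained model on held-out queries. To compare against a baseline, run the same function on an untrained encoder. Note that `train_anmi` updates `miner.encoder` in place, so the baseline has to be measured before training or on a freshly loaded copy of the model.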


In [None]:
# ============================================================
# 📈 EVALUATION
# ============================================================

def evaluate_retrieval(
    encoder: SentenceTransformer,
    queries: Dict[str, str],
    corpus: Dict[str, str],
    positives: Dict[str, List[str]],
    query_ids: List[str],
    top_k: int = 10,
) -> Dict[str, float]:
    """
    Evaluate retrieval performance.
    
    Metrics:
    - MRR@k: Mean Reciprocal Rank
    - Recall@k: Fraction of positives in top-k
    """
    encoder.eval()
    
    corpus_ids = list(corpus.keys())
    corpus_texts = list(corpus.values())
    
    # Encode corpus
    print("   Encoding corpus for evaluation...")
    corpus_embs = encoder.encode(
        corpus_texts,
        convert_to_tensor=True,
        show_progress_bar=True,
        batch_size=64,
    )
    
    mrr_scores = []
    recall_scores = []
    
    print("   Computing metrics...")
    for qid in tqdm(query_ids, desc="Evaluating"):
        if qid not in queries or qid not in positives:
            continue
        
        query_text = queries[qid]
        pos_ids = set(positives[qid])
        
        # Encode query
        query_emb = encoder.encode(query_text, convert_to_tensor=True)
        
        # Compute similarities
        sims = torch.nn.functional.cosine_similarity(
            query_emb.unsqueeze(0), corpus_embs
        )
        
        # Get top-k
        top_indices = torch.argsort(sims, descending=True)[:top_k].cpu().numpy()
        top_ids = [corpus_ids[i] for i in top_indices]
        
        # MRR: Reciprocal of first positive's rank
        mrr = 0.0
        for rank, doc_id in enumerate(top_ids, 1):
            if doc_id in pos_ids:
                mrr = 1.0 / rank
                break
        mrr_scores.append(mrr)
        
        # Recall: Fraction of positives in top-k
        hits = len(pos_ids.intersection(set(top_ids)))
        recall = hits / len(pos_ids) if pos_ids else 0.0
        recall_scores.append(recall)
    
    return {
        f"MRR@{top_k}": np.mean(mrr_scores),
        f"Recall@{top_k}": np.mean(recall_scores),
    }

# Evaluate on test queries (different from training)
test_query_ids = query_ids[100:150]  # Use queries we didn't train on

print("📊 Evaluating ANMI-trained model...")
anmi_metrics = evaluate_retrieval(
    encoder=trained_encoder,
    queries=queries,
    corpus=corpus,
    positives=positives,
    query_ids=test_query_ids,
    top_k=10,
)

print(f"\n✅ ANMI Results:")
for metric, value in anmi_metrics.items():
    print(f"   {metric}: {value:.4f}")
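
To make the two metrics concrete, here is a small worked example that applies the same MRR@k and Recall@k definitions used in `evaluate_retrieval` to a single hand-made ranking. The helper names and the document IDs are illustrative assumptions.

```python
from typing import List, Set


def mrr_at_k(ranked_ids: List[str], pos_ids: Set[str], k: int) -> float:
    """Reciprocal rank of the first relevant document in the top-k, else 0."""
    for rank, doc_id in enumerate(ranked_ids[:k], 1):
        if doc_id in pos_ids:
            return 1.0 / rank
    return 0.0


def recall_at_k(ranked_ids: List[str], pos_ids: Set[str], k: int) -> float:
    """Fraction of all relevant documents that appear in the top-k."""
    if not pos_ids:
        return 0.0
    return len(pos_ids & set(ranked_ids[:k])) / len(pos_ids)


# Toy ranking: d2 is the first relevant hit (rank 2), d1 appears at rank 5,
# and d8 is relevant but never retrieved.
ranked = ["d7", "d2", "d9", "d4", "d1"]
relevant = {"d2", "d1", "d8"}

print(mrr_at_k(ranked, relevant, k=3))     # 0.5
print(recall_at_k(ranked, relevant, k=3))  # 1 of 3 relevant docs found
print(recall_at_k(ranked, relevant, k=5))  # 2 of 3 relevant docs found
```

MRR only rewards the first hit, while Recall keeps improving as more relevant documents enter the top-k, so the two can move independently.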


## 7. Summary

### What We Implemented

1. **Sparse ELO Estimator** - Converts a sparse set of O(n) pairwise comparisons (a k-regular comparison graph rather than all O(n²) pairs) into absolute quality scores using Thurstone MLE

2. **ELO-Gap Selector** - Selects negatives based on quality gap, not rank, with soft continuous weights

3. **Hybrid Loss** - Combines InfoNCE (geometry) + MSE (calibration) with learnable projection head

4. **Offline Mining Pipeline** - Pre-computes everything expensive so training is fast
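
For reference, the hybrid objective in item 3 can be written out as below. The overall form (α times InfoNCE plus (1-α) times MSE) and the values α = 0.6 and τ = 0.07 come from the training code. The way the soft weights $w_j$ enter the InfoNCE denominator and the shape of the learnable head $g$ are assumptions, since both depend on the `HybridLoss` implementation.

```latex
% s^+ : similarity between query and positive
% s_j^- : similarity between query and the j-th of K negatives
% w_j : ELO-calibrated weight of negative j (placement assumed)
% g : learnable projection head mapping a similarity to an ELO estimate (assumed)
% e_i : ELO target, with i = 0 the positive and i = 1..K the negatives
\mathcal{L} = \alpha \, \mathcal{L}_{\mathrm{NCE}} + (1 - \alpha) \, \mathcal{L}_{\mathrm{MSE}}

\mathcal{L}_{\mathrm{NCE}} = -\log
  \frac{\exp(s^{+} / \tau)}
       {\exp(s^{+} / \tau) + \sum_{j=1}^{K} w_j \exp(s_j^{-} / \tau)}

\mathcal{L}_{\mathrm{MSE}} = \frac{1}{K + 1} \sum_{i=0}^{K} \bigl( g(s_i) - e_i \bigr)^2
```

Under this reading, a negative with weight near zero (a likely false negative) contributes almost nothing to the contrastive denominator, while the MSE term still pulls its predicted score toward its ELO target.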

### Key Innovations

| Component | Traditional | ANMI 2.0 |
|-----------|-------------|----------|
| Negative Selection | Rank-based | ELO gap-based |
| Negative Weighting | Binary (include/exclude) | Soft continuous |
| Loss Function | Pure InfoNCE | Hybrid (InfoNCE + MSE) |
| False Negative Handling | Ad-hoc threshold | ELO danger zone |
| Computation | Online (slow) | Offline mining (fast training) |

### Production Considerations Addressed

1. **Learnable ELO head** - Fixes gradient scale mismatch between InfoNCE and MSE
2. **Offline mining** - Avoids 10-50x slowdown from online validation

### Next Steps

- Try curriculum learning (progressive difficulty)
- Compare with baselines on larger datasets
- Implement full cross-encoder pairwise model for ELO estimation

---

**References:**
- ANMI 2.0 Paper (forthcoming)
- [Tevatron](https://github.com/texttron/tevatron) - Dense retrieval toolkit
- [Dense Text Retrieval Survey](https://arxiv.org/abs/2211.14876) - Zhao et al.
