In [1]:
# @title 1. Installation & Imports
!pip install -q sentence-transformers datasets scipy networkx

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, util
from transformers import AutoTokenizer, AutoModel
import numpy as np
import networkx as nx
from scipy.stats import norm
from tqdm.auto import tqdm
from dataclasses import dataclass
from typing import List, Dict, Tuple
import random

# Set random seed for reproducibility
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Running on {device}")

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/75.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m75.1/75.1 kB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/488.0 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m488.0/488.0 kB[0m [31m18.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m511.6/511.6 kB[0m [31m37.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m119.7/119.7 kB[0m [31m11.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m150.3/150.3 kB[0m [31m15.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m193.9/193.9 kB[0m [31m20.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━



Running on cpu


Below is the mathematical core of ANMI:SparseELOEstimator: Estimates global ELO scores from sparse pairwise comparisons ($O(n)$ complexity).ELOGapSelector: Determines if a negative is "Safe", "Hard", or "Dangerous" based on the ELO gap.

In [6]:
# @title 2. Core ANMI Engine (ELO & Selection) - FIXED

class SparseELOEstimator:
    """
    Estimates ELO scores.
    OPTIMIZED: Uses Pointwise scoring -> Pairwise difference.
    This is O(N) instead of O(K*N), making it much faster for the CrossEncoder.
    """
    def __init__(self, k=4, max_iter=50, tol=1e-4):
        self.k = k
        self.max_iter = max_iter
        self.tol = tol

    def estimate(self, documents: List[str], query: str, pairwise_model) -> Dict[int, float]:
        n = len(documents)
        if n < 2: return {0: 1200.0}

        # 1. Score all documents directly (Pointwise)
        # This fixes the "3 arguments" issue by only sending [Query, Doc]
        pairs_to_score = [[query, doc] for doc in documents]

        # Returns raw logits/scores (unbounded)
        scores = pairwise_model.predict(pairs_to_score, batch_size=32, show_progress_bar=False)

        # 2. Build k-regular graph for ELO calculation
        # Even though we have exact scores, we simulate the graph structure
        # to maintain the ANMI theoretical sparse constraint (or we can just use scores directly)

        # Since we have pointwise scores, we can actually skip the iterative Thurstone
        # solver and map scores directly to ELO if we assume transitivity.
        # However, to keep the graph-theory component:

        edges = set()
        indices = list(range(n))
        for _ in range(max(1, self.k // 2)):
            np.random.shuffle(indices)
            for i in range(n):
                u, v = indices[i], indices[(i + 1) % n]
                edges.add(tuple(sorted((u, v))))
        edge_list = list(edges)

        # 3. Fit ELO
        elo_scores = np.zeros(n)
        lr = 1.0

        for it in range(self.max_iter):
            grad = np.zeros(n)
            for u, v in edge_list:
                # Derive pairwise preference from pointwise score difference
                # P(u > v) = sigmoid(s_u - s_v)
                diff = scores[u] - scores[v]
                w_uv = 1 / (1 + np.exp(-diff))

                delta = elo_scores[u] - elo_scores[v]
                sigma = 1 / (1 + np.exp(-delta))

                g = w_uv - sigma
                grad[u] += g
                grad[v] -= g

            grad -= np.mean(grad)
            if np.max(np.abs(grad)) < self.tol: break
            elo_scores += lr * grad
            lr *= 0.9

        elo_scores = (elo_scores - np.mean(elo_scores)) * 400 + 1200
        return {i: s for i, s in enumerate(elo_scores)}


class ELOGapSelector:
    """
    Selects negatives based on ELO gap from positive.
    Gap = Positive_ELO - Negative_ELO
    """
    def __init__(self):
        self.danger_zone = 50    # Gap < 50: Too risky
        self.hard_zone = 200     # 50-200: Hard negative
        self.easy_zone = 400     # > 400: Easy negative

    def weight(self, pos_elo, neg_elo):
        gap = pos_elo - neg_elo
        if gap < self.danger_zone: return 0.0
        if gap < self.hard_zone: return 1.0
        if gap < self.easy_zone: return 0.5
        return 0.1

This implements the Fixed Hybrid Loss discussed in the critique. It includes a learnable elo_head to map dot-products to ELO space, preventing gradient scale mismatch.

In [7]:
# @title 3. ANMI Hybrid Loss Function

class ANMIHybridLoss(nn.Module):
    def __init__(self, alpha=0.6, tau=0.07):
        super().__init__()
        self.alpha = alpha  # Mixing coef (0.6 means 60% InfoNCE, 40% MSE)
        self.tau = tau
        # Learnable projection: DotProduct -> ELO Space
        self.elo_head = nn.Linear(1, 1)

    def forward(self, q_emb, p_emb, n_embs, n_weights, n_elos, p_elo):
        """
        q_emb: [B, Dim]
        p_emb: [B, Dim]
        n_embs: [B, K, Dim]
        n_weights: [B, K] (ELO-derived weights)
        n_elos: [B, K] (Target ELOs)
        p_elo: [B] (Target Positive ELO)
        """
        B, K, D = n_embs.shape

        # 1. Compute Dot Products (Sim Scores)
        p_sim = torch.sum(q_emb * p_emb, dim=-1, keepdim=True) # [B, 1]
        n_sim = torch.bmm(n_embs, q_emb.unsqueeze(-1)).squeeze(-1) # [B, K]

        # === COMPONENT A: Weighted InfoNCE ===
        # Denominator: exp(p/tau) + sum(w_i * exp(n_i/tau))
        # Note: We detach weights to prevent backprop through the ELO engine
        p_exp = torch.exp(p_sim / self.tau)
        n_exp = torch.exp(n_sim / self.tau) * n_weights.to(q_emb.device)

        loss_nce = -torch.log(p_exp / (p_exp + n_exp.sum(dim=-1, keepdim=True)))

        # === COMPONENT B: ELO MSE (Regularization) ===
        # Project all scores to ELO space
        all_sims = torch.cat([p_sim, n_sim], dim=1) # [B, K+1]
        all_elos = torch.cat([p_elo.unsqueeze(1), n_elos], dim=1).to(q_emb.device)

        # Learnable mapping applied to similarity scores
        pred_elos = self.elo_head(all_sims.unsqueeze(-1)).squeeze(-1)

        # Normalize ELO targets for numerical stability (Standard Scaler approx)
        target_elos = (all_elos - 1200) / 400
        pred_elos_norm = (pred_elos - 1200) / 400

        loss_mse = F.mse_loss(pred_elos_norm, target_elos)

        # Combine
        return self.alpha * loss_nce.mean() + (1 - self.alpha) * loss_mse

To make this run instantly without massive downloads, we generate a synthetic "Technical Support" dataset.

Concepts: Python, SQL, Java, AWS.

Logic: A query about "Python" should match a doc about "Python".

Hard Negative: A doc about "Java" (both are code, but wrong language).

In [8]:
# @title 4. Generate Synthetic Dataset (Micro-Corpus)

def generate_micro_corpus():
    topics = ["Python", "Java", "SQL", "AWS", "Docker"]
    actions = ["install", "debug", "configure", "deploy", "optimize"]
    corpus = []
    queries = []
    ground_truth = {} # query_idx -> pos_doc_idx

    print("Generating synthetic data...")

    # Create Documents
    doc_id = 0
    for t in topics:
        for a in actions:
            # Create 5 variations per topic-action pair
            for i in range(5):
                text = f"Guide to {a} {t} in production environment version {i}."
                corpus.append(text)
                doc_id += 1

    # Create Queries
    q_id = 0
    for t in topics:
        for a in actions:
            text = f"How do I {a} {t}?"
            queries.append(text)
            # Assign a random positive from the correct topic/action group
            # Simple heuristic: find docs containing both keywords
            candidates = [i for i, d in enumerate(corpus) if t in d and a in d]
            ground_truth[q_id] = candidates[0] # Pick first as "True Positive"
            q_id += 1

    print(f"Generated {len(corpus)} docs and {len(queries)} queries.")
    return corpus, queries, ground_truth

corpus, queries, ground_truth = generate_micro_corpus()

Generating synthetic data...
Generated 125 docs and 25 queries.


This is the heavy lifting. We use a Cross-Encoder to find hard negatives and assign ELO scores.

Miner: Simple Dense Retrieval (using a tiny model for speed).

Oracle: cross-encoder/ms-marco-TinyBERT-L-2 (Small but effective).

In [9]:
# @title 5. Phase 1: Offline Mining & ELO Estimation - FIXED

from sentence_transformers import CrossEncoder # <--- THIS WAS MISSING

# 1. Setup Models
print("Loading models...")
# Fast bi-encoder for candidate generation
retriever = SentenceTransformer('all-MiniLM-L6-v2', device=device)

# Oracle: Use CrossEncoder class, NOT SentenceTransformer
oracle = CrossEncoder('cross-encoder/ms-marco-TinyBERT-L-2', device=device)

# 2. Encode Corpus (Offline Indexing)
print("Encoding corpus...")
corpus_embs = retriever.encode(corpus, convert_to_tensor=True, show_progress_bar=False)

# 3. Mining Loop
elo_engine = SparseELOEstimator(k=6)
selector = ELOGapSelector()

training_data = []

print("Starting Mining Phase...")
for q_idx, query_text in tqdm(enumerate(queries), total=len(queries)):
    pos_doc_idx = ground_truth[q_idx]
    pos_doc_text = corpus[pos_doc_idx]

    # A. Retrieve Candidates (Top-20)
    q_emb = retriever.encode(query_text, convert_to_tensor=True)
    hits = util.semantic_search(q_emb, corpus_embs, top_k=20)[0]
    candidate_indices = [h['corpus_id'] for h in hits]

    if pos_doc_idx not in candidate_indices:
        candidate_indices.append(pos_doc_idx)

    # Get texts
    candidate_texts = [corpus[i] for i in candidate_indices]

    # B. Estimate ELO Scores (The Oracle Step)
    elo_map = elo_engine.estimate(candidate_texts, query_text, oracle)

    # Map back to global indices
    doc_elos = {candidate_indices[i]: score for i, score in elo_map.items()}
    pos_elo = doc_elos[pos_doc_idx]

    # C. Select & Weight Negatives
    negatives = []
    neg_weights = []
    neg_elos = []

    for idx in candidate_indices:
        if idx == pos_doc_idx: continue

        elo = doc_elos[idx]
        weight = selector.weight(pos_elo, elo)

        if weight > 0:
            negatives.append(corpus[idx])
            neg_weights.append(weight)
            neg_elos.append(elo)

    if len(negatives) > 0:
        training_data.append({
            "query": query_text,
            "positive": pos_doc_text,
            "negatives": negatives[:8],
            "neg_weights": neg_weights[:8],
            "neg_elos": neg_elos[:8],
            "pos_elo": pos_elo
        })

print(f"Mining complete. Created {len(training_data)} curated training examples.")

Loading models...


README.md: 0.00B [00:00, ?B/s]

Encoding corpus...
Starting Mining Phase...


  0%|          | 0/25 [00:00<?, ?it/s]

Mining complete. Created 25 curated training examples.


Now we train a new model from scratch using the mined data. Notice how simple this loop is because all the heavy calculation (validation, weighting) happened in Phase 1

In [12]:
# @title 6. Phase 2: Online Training (ANMI)

# Initialize a fresh Student Model
student_model = SentenceTransformer('all-MiniLM-L6-v2', device=device)
# Add a linear projection head for ELO MSE
criterion = ANMIHybridLoss(alpha=0.5).to(device)
optimizer = torch.optim.AdamW(list(student_model.parameters()) + list(criterion.parameters()), lr=2e-5)

# Simple Collate Function for padding
def collate_fn(batch):
    queries = [b['query'] for b in batch]
    positives = [b['positive'] for b in batch]
    # Flatten negatives for encoding, reshape later
    negatives = [n for b in batch for n in b['negatives']]

    # Prepare Tensors
    weights = torch.tensor([b['neg_weights'] for b in batch], dtype=torch.float)
    neg_elos = torch.tensor([b['neg_elos'] for b in batch], dtype=torch.float)
    pos_elos = torch.tensor([b['pos_elo'] for b in batch], dtype=torch.float)

    return queries, positives, negatives, weights, neg_elos, pos_elos

# Dataloader
loader = DataLoader(training_data, batch_size=8, shuffle=True, collate_fn=collate_fn)

print("Starting ANMI Training...")
student_model.train()

for epoch in range(20): # Fast training
    total_loss = 0
    for batch in loader:
        q_txt, p_txt, n_txt, weights, n_elos, p_elos = batch

        # 1. Forward Pass (Student)
        q_emb = student_model.encode(q_txt, convert_to_tensor=True)
        p_emb = student_model.encode(p_txt, convert_to_tensor=True)
        n_embs_flat = student_model.encode(n_txt, convert_to_tensor=True)

        # Reshape negatives [B * K, D] -> [B, K, D]
        # Note: In real production, handle variable num negatives via masking
        # Here we assume fixed size for demo simplicity or truncate
        k = len(n_txt) // len(q_txt)
        n_embs = n_embs_flat.view(len(q_txt), k, -1)

        # 2. Compute Hybrid Loss
        loss = criterion(q_emb, p_emb, n_embs, weights, n_elos, p_elos)

        # 3. Update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1}: Loss = {total_loss/len(loader):.4f}")

print("Training finished!")

Starting ANMI Training...
Epoch 1: Loss = 4.5594
Epoch 2: Loss = 4.5362
Epoch 3: Loss = 4.5363
Epoch 4: Loss = 4.5896
Epoch 5: Loss = 4.4987
Epoch 6: Loss = 4.5148
Epoch 7: Loss = 4.5272
Epoch 8: Loss = 4.5272
Epoch 9: Loss = 4.5363
Epoch 10: Loss = 4.4603
Epoch 11: Loss = 4.5896
Epoch 12: Loss = 4.5099
Epoch 13: Loss = 4.5362
Epoch 14: Loss = 4.4614
Epoch 15: Loss = 4.5261
Epoch 16: Loss = 4.4578
Epoch 17: Loss = 4.5261
Epoch 18: Loss = 4.4744
Epoch 19: Loss = 4.4554
Epoch 20: Loss = 4.4578
Training finished!


Let's verify that the model actually learned to rank relevant documents higher.

In [13]:
# @title 7. Validation / Inference

student_model.eval()

# Test Query
test_query = "How do I configure Docker?"
print(f"Query: {test_query}\n")

# Encode
q_emb = student_model.encode(test_query, convert_to_tensor=True)
doc_embs = student_model.encode(corpus, convert_to_tensor=True)

# Search
hits = util.semantic_search(q_emb, doc_embs, top_k=3)[0]

print("Top Results:")
for i, hit in enumerate(hits):
    print(f"{i+1}. {corpus[hit['corpus_id']]} (Score: {hit['score']:.4f})")

Query: How do I configure Docker?

Top Results:
1. Guide to configure Docker in production environment version 3. (Score: 0.7963)
2. Guide to configure Docker in production environment version 1. (Score: 0.7940)
3. Guide to configure Docker in production environment version 4. (Score: 0.7888)
