# Sample Efficiency Comparison
============================
Train models with varying data sizes to compare data efficiency.

In [None]:
# ============================================================
# CELL 1: Setup
# ============================================================
import os
import json
import random
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from tqdm import tqdm
import matplotlib.pyplot as plt

# Training-subset sizes (number of training cases) swept by the study.
DATA_SIZES = [100, 250, 500, 1000, 2500, 5000, 10500]
# Names of the ranking approaches that are compared at every subset size.
MODELS = [
    'StaticMCDA',
    'NeuralRanker',
    'DocMatchNet-Original',
    'DocMatchNet-JEPA',
]
# Each (model, size) combination is repeated once per seed.
SEEDS = [42, 123, 456]

# Run on the GPU when one is visible to PyTorch, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f'Using device: {device}')
    print(f'GPU: {torch.cuda.get_device_name()}')
else:
    device = torch.device('cpu')
    print(f'Using device: {device}')

# Output directory for the JSON results and the plot written further down.
os.makedirs('/kaggle/working/results', exist_ok=True)

In [None]:
# ============================================================
# CELL 2: Load Data
# ============================================================
DATA_DIR = '/kaggle/input/docmatchnet-jepa-data/data'


def _load_artifact(name):
    """Load ``<DATA_DIR>/<name>.pt`` with ``torch.load`` and return its content."""
    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
    # Python objects, which can execute code. That is only acceptable for a
    # trusted dataset; consider weights_only=True if every artifact is a plain
    # tensor/dict -- confirm against the actual files before switching.
    return torch.load(f'{DATA_DIR}/{name}.pt', weights_only=False)


# Embedding tables.
doctor_embeddings = _load_artifact('doctor_embeddings')
case_embeddings = _load_artifact('case_embeddings')
# Feature groups, indexed below as [case] or [case, candidate].
clinical_features = _load_artifact('clinical_features')
pastwork_features = _load_artifact('pastwork_features')
logistics_features = _load_artifact('logistics_features')
trust_features = _load_artifact('trust_features')
context_features = _load_artifact('context_features')
# Per-case relevance grades and the candidate -> doctor-table index mapping.
relevance_labels = _load_artifact('relevance_labels')
doctor_indices = _load_artifact('doctor_indices')
# Mapping with 'train' / 'val' / 'test' entries holding case indices.
splits = _load_artifact('splits')

TRAIN_FULL, VAL_IDX, TEST_IDX = (
    np.array(splits[part]) for part in ('train', 'val', 'test')
)

print(f'Full train size: {len(TRAIN_FULL)}')
print(f'Val/Test size: {len(VAL_IDX)}/{len(TEST_IDX)}')

In [None]:
# ============================================================
# CELL 3: Dataset Classes
# ============================================================
class PairDataset(Dataset):
    """Training dataset yielding one random (positive, negative) doctor pair per case.

    Reads the module-level tensors (``relevance_labels``, ``doctor_indices``,
    the embedding tables and the feature tables) loaded earlier in the file.

    Pool selection for a case:
      * positives: candidates with relevance >= 3; when none exist, the
        candidates carrying the maximum relevance of that case.
      * negatives: candidates with relevance <= 1; when none exist, the
        candidates strictly below the maximum relevance of that case.

    NOTE(review): when every candidate of a case has the same relevance the
    negative pool stays empty and ``torch.randint(0, ...)`` raises -- the data
    is presumably free of such cases; verify.
    """

    def __init__(self, indices):
        # Case indices this dataset draws from (copied into a numpy array).
        self.indices = np.array(indices)

    def __len__(self):
        return len(self.indices)

    @staticmethod
    def _pick(mask):
        """Return one uniformly drawn position at which ``mask`` is true."""
        pool = torch.where(mask)[0]
        choice = torch.randint(pool.shape[0], (1,))
        return pool[choice].item()

    def __getitem__(self, idx):
        case_idx = int(self.indices[idx])
        grades = relevance_labels[case_idx]
        top_grade = grades.max()

        positives = grades >= 3
        if not positives.any():
            positives = grades == top_grade

        negatives = grades <= 1
        if not negatives.any():
            negatives = grades < top_grade

        # Positive is drawn before the negative so the RNG stream is consumed
        # in a fixed order.
        pos_local = self._pick(positives)
        neg_local = self._pick(negatives)

        # Candidate slots are local to the case; the embedding table is
        # addressed through the doctor_indices lookup.
        pos_global = doctor_indices[case_idx, pos_local]
        neg_global = doctor_indices[case_idx, neg_local]

        sample = {
            'case_embedding': case_embeddings[case_idx],
            'pos_doctor_embedding': doctor_embeddings[pos_global],
            'neg_doctor_embedding': doctor_embeddings[neg_global],
        }
        sample['pos_clinical'] = clinical_features[case_idx, pos_local]
        sample['neg_clinical'] = clinical_features[case_idx, neg_local]
        sample['pos_pastwork'] = pastwork_features[case_idx, pos_local]
        sample['neg_pastwork'] = pastwork_features[case_idx, neg_local]
        sample['pos_logistics'] = logistics_features[case_idx, pos_local]
        sample['neg_logistics'] = logistics_features[case_idx, neg_local]
        sample['pos_trust'] = trust_features[case_idx, pos_local]
        sample['neg_trust'] = trust_features[case_idx, neg_local]
        sample['context'] = context_features[case_idx]
        return sample

class EvalDataset(Dataset):
    """Evaluation dataset returning every candidate doctor of a case at once.

    Each item bundles the case embedding, the embeddings of all candidates
    (looked up through ``doctor_indices``), the four per-candidate feature
    groups, the case context and the relevance labels. All of these are read
    from module-level tensors loaded earlier in the file.
    """

    def __init__(self, indices):
        # Case indices this dataset iterates over (copied into a numpy array).
        self.indices = np.array(indices)

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        case_idx = int(self.indices[idx])
        # Rows of the doctor embedding table belonging to this case's candidates.
        candidate_rows = doctor_indices[case_idx]

        item = {'case_embedding': case_embeddings[case_idx]}
        item['doctor_embeddings'] = doctor_embeddings[candidate_rows]
        item['clinical'] = clinical_features[case_idx]
        item['pastwork'] = pastwork_features[case_idx]
        item['logistics'] = logistics_features[case_idx]
        item['trust'] = trust_features[case_idx]
        item['context'] = context_features[case_idx]
        item['relevance'] = relevance_labels[case_idx]
        return item

In [None]:
# ============================================================
# CELL 4: Model Definitions
# ============================================================
class StaticMCDA:
    """Fixed-weight scoring baseline with no trainable parameters.

    Every feature group is reduced to a weighted sum over its last axis and
    the four group scores are blended with the constant weights
    0.40 / 0.25 / 0.25 / 0.10 (clinical / past work / logistics / trust).
    """

    # Blend weights of the group scores, in the order c, p, l, t.
    _GROUP_BLEND = (0.40, 0.25, 0.25, 0.10)

    def __init__(self):
        # Per-feature weights inside each group.
        self.cw = torch.tensor([0.55, 0.20, 0.15, 0.10])
        self.pw = torch.tensor([0.30, 0.25, 0.20, 0.15, 0.10])
        self.lw = torch.tensor([0.30, 0.25, 0.20, 0.15, 0.10])
        self.tw = torch.tensor([0.50, 0.30, 0.20])

    def score(self, c, p, l, t):
        """Return the blended score; the last axis of each input is summed away.

        All weight vectors are moved to the device and dtype of ``c``.
        """
        groups = ((c, self.cw), (p, self.pw), (l, self.lw), (t, self.tw))
        group_scores = [
            (features * weights.to(c.device, c.dtype)).sum(-1)
            for features, weights in groups
        ]

        # Accumulate left to right so the summation order is fixed.
        total = self._GROUP_BLEND[0] * group_scores[0]
        for blend, group_score in zip(self._GROUP_BLEND[1:], group_scores[1:]):
            total = total + blend * group_score
        return total

class NeuralRanker(nn.Module):
    """Neural baseline: cross-attention over projected embeddings plus an MLP on features.

    ``forward`` returns ``{'score': tensor}`` with sigmoid-squashed values.
    The feature encoder expects the concatenation of clinical, past-work,
    logistics, trust and context features to have 25 entries on the last axis.

    NOTE(review): the attention is applied to one query token and one
    key/value token, so the softmax runs over a single element; the attention
    output therefore appears to depend only on the doctor projection and not
    on ``patient_proj`` -- confirm this is intended.
    """

    def __init__(self, embed_dim=384, hidden_dim=256):
        super().__init__()
        self.patient_proj = nn.Linear(embed_dim, hidden_dim)
        self.doctor_proj = nn.Linear(embed_dim, hidden_dim)
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=4,
            batch_first=True,
        )
        self.layer_norm = nn.LayerNorm(hidden_dim)
        self.feature_encoder = nn.Sequential(
            nn.Linear(25, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
        )
        self.scorer = nn.Sequential(
            nn.Linear(hidden_dim + 64, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        )

    def forward(self, pe, de, c, p, l, t, ctx):
        # Turn each projected embedding into a length-1 sequence for attention.
        patient_tok = torch.unsqueeze(self.patient_proj(pe), 1)
        doctor_tok = torch.unsqueeze(self.doctor_proj(de), 1)
        attended, _ = self.cross_attention(patient_tok, doctor_tok, doctor_tok)
        interaction = self.layer_norm(torch.squeeze(attended, 1))

        structured = torch.cat([c, p, l, t, ctx], dim=-1)
        encoded = self.feature_encoder(structured)

        joint = torch.cat([interaction, encoded], dim=-1)
        return {'score': self.scorer(joint)}

class DocMatchNetOriginal(nn.Module):
    """Gated multi-group scorer: context-conditioned gates modulate encoded feature groups.

    Each of the four feature groups (4, 5, 5 and 3 input features) is encoded
    to ``gate_dim`` channels. Four sigmoid gates, computed from the normalised
    attention output concatenated with the context vector (the gate input
    width is ``hidden_dim + 8``), scale the encoded groups element-wise before
    a small MLP produces a sigmoid score. ``forward`` returns
    ``{'score': tensor}``.

    NOTE(review): attention runs over one query token and one key/value
    token, so the softmax is over a single element and the attention output
    appears independent of ``patient_proj`` -- confirm this is intended.
    """

    def __init__(self, embed_dim=384, hidden_dim=256, gate_dim=32):
        super().__init__()

        def group_encoder(n_features):
            # Two-layer ReLU MLP mapping a feature group to gate_dim channels.
            return nn.Sequential(
                nn.Linear(n_features, gate_dim),
                nn.ReLU(),
                nn.Linear(gate_dim, gate_dim),
                nn.ReLU(),
            )

        def gate():
            # Sigmoid gate over gate_dim channels from [interaction, context].
            return nn.Sequential(
                nn.Linear(hidden_dim + 8, 64),
                nn.ReLU(),
                nn.Linear(64, gate_dim),
                nn.Sigmoid(),
            )

        self.patient_proj = nn.Linear(embed_dim, hidden_dim)
        self.doctor_proj = nn.Linear(embed_dim, hidden_dim)
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=4,
            batch_first=True,
        )
        self.norm = nn.LayerNorm(hidden_dim)

        # Group encoders: clinical, past work, logistics, trust.
        self.ce = group_encoder(4)
        self.pe = group_encoder(5)
        self.le = group_encoder(5)
        self.te = group_encoder(3)

        # One gate per feature group, in the same order.
        self.cg = gate()
        self.pg = gate()
        self.lg = gate()
        self.tg = gate()

        self.scorer = nn.Sequential(
            nn.Linear(gate_dim * 4, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid(),
        )

    def forward(self, pe, de, c, p, l, t, ctx):
        # Length-1 sequences for the attention layer.
        patient_tok = torch.unsqueeze(self.patient_proj(pe), 1)
        doctor_tok = torch.unsqueeze(self.doctor_proj(de), 1)
        attended, _ = self.cross_attention(patient_tok, doctor_tok, doctor_tok)
        interaction = self.norm(torch.squeeze(attended, 1))

        gate_input = torch.cat([interaction, ctx], dim=-1)
        gated_groups = [
            self.cg(gate_input) * self.ce(c),
            self.pg(gate_input) * self.pe(p),
            self.lg(gate_input) * self.le(l),
            self.tg(gate_input) * self.te(t),
        ]

        fused = torch.cat(gated_groups, dim=-1)
        return {'score': self.scorer(fused)}

class DocMatchNetJEPA(nn.Module):
    """Embedding-prediction variant: predicts a doctor representation and scores by cosine.

    The patient latent together with the gated feature groups feeds a
    predictor; its projection is compared against the projection of the
    encoded doctor embedding. ``forward`` returns a dict with

      * ``score``: cosine similarity of the two projections rescaled to
        [0, 1], with a trailing axis of size 1,
      * ``predicted_ideal``: the (un-normalised) projected prediction,
      * ``doctor_embedding``: the (un-normalised) projected doctor latent,
      * ``temperature``: ``exp(log_temperature)``, initialised to 0.07.

    Gate inputs are the patient latent concatenated with the context vector
    (the gate input width is ``latent_dim + 8``).
    """

    def __init__(self, embed_dim=384, latent_dim=256, gate_dim=32):
        super().__init__()

        def tower():
            # Embedding -> latent encoder with layer norm after each linear layer.
            return nn.Sequential(
                nn.Linear(embed_dim, 512),
                nn.LayerNorm(512),
                nn.GELU(),
                nn.Linear(512, latent_dim),
                nn.LayerNorm(latent_dim),
            )

        def group_encoder(n_features):
            # Two-layer GELU MLP mapping a feature group to gate_dim channels.
            return nn.Sequential(
                nn.Linear(n_features, gate_dim),
                nn.GELU(),
                nn.Linear(gate_dim, gate_dim),
                nn.GELU(),
            )

        def gate():
            # Sigmoid gate over gate_dim channels from [patient latent, context].
            return nn.Sequential(
                nn.Linear(latent_dim + 8, 64),
                nn.GELU(),
                nn.Linear(64, gate_dim),
                nn.Sigmoid(),
            )

        def projector():
            # Latent -> 128-d projection head.
            return nn.Sequential(
                nn.Linear(latent_dim, latent_dim),
                nn.GELU(),
                nn.Linear(latent_dim, 128),
            )

        self.patient_encoder = tower()
        self.doctor_encoder = tower()

        # Group encoders: clinical, past work, logistics, trust.
        self.ce = group_encoder(4)
        self.pe = group_encoder(5)
        self.le = group_encoder(5)
        self.te = group_encoder(3)

        # One gate per feature group, in the same order.
        self.cg = gate()
        self.pg = gate()
        self.lg = gate()
        self.tg = gate()

        self.predictor = nn.Sequential(
            nn.Linear(latent_dim + gate_dim * 4, 256),
            nn.LayerNorm(256),
            nn.GELU(),
            nn.Linear(256, latent_dim),
        )
        self.pp = projector()
        self.dp = projector()
        # Learnable temperature, stored in log space.
        self.log_temperature = nn.Parameter(torch.log(torch.tensor(0.07)))

    def forward(self, pe, de, c, p, l, t, ctx):
        patient_latent = self.patient_encoder(pe)
        doctor_latent = self.doctor_encoder(de)

        gate_input = torch.cat([patient_latent, ctx], dim=-1)
        gated_groups = [
            self.cg(gate_input) * self.ce(c),
            self.pg(gate_input) * self.pe(p),
            self.lg(gate_input) * self.le(l),
            self.tg(gate_input) * self.te(t),
        ]
        gated = torch.cat(gated_groups, dim=-1)

        predicted = self.predictor(torch.cat([patient_latent, gated], dim=-1))
        pred_proj = self.pp(predicted)
        doc_proj = self.dp(doctor_latent)

        # Cosine similarity in [-1, 1], rescaled to [0, 1].
        cosine = (F.normalize(pred_proj, dim=-1) * F.normalize(doc_proj, dim=-1)).sum(dim=-1, keepdim=True)
        score = (cosine + 1) / 2

        return {
            'score': score,
            'predicted_ideal': pred_proj,
            'doctor_embedding': doc_proj,
            'temperature': self.log_temperature.exp(),
        }

    def get_parameter_groups(self, base_lr, doctor_lr_mult=0.05):
        """Return two optimizer parameter groups.

        The first group holds every parameter outside the doctor branch at
        ``base_lr``; the second holds the parameters of ``doctor_encoder`` and
        ``dp`` at ``base_lr * doctor_lr_mult``.
        """
        doctor_side = [*self.doctor_encoder.parameters(), *self.dp.parameters()]
        doctor_ids = {id(param) for param in doctor_side}
        remaining = [param for param in self.parameters() if id(param) not in doctor_ids]
        return [
            {'params': remaining, 'lr': base_lr},
            {'params': doctor_side, 'lr': base_lr * doctor_lr_mult},
        ]

In [None]:
# ============================================================
# CELL 5: Train/Eval Utilities
# ============================================================
def set_seed(seed):
    """Seed Python's ``random``, NumPy and PyTorch (CPU, plus all CUDA devices if available)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

def ndcg_at_5(scores, labels):
    """Return NDCG@5 of ``scores`` against graded ``labels`` as a Python float.

    Both arguments are 1-D numpy arrays indexed by candidate. Gains are
    ``2**label - 1`` and the discount at zero-based rank ``r`` is
    ``log2(r + 2)``. Returns 0.0 when the ideal DCG is not positive.
    """
    def dcg_over(ranking):
        # Discounted cumulative gain of the candidates listed in `ranking`.
        return sum(
            (2**labels[i]-1)/np.log2(rank+2)
            for rank, i in enumerate(ranking)
        )

    achieved = dcg_over(np.argsort(-scores)[:5])
    best_possible = dcg_over(np.argsort(-labels)[:5])

    if best_possible > 0:
        return float(achieved/best_possible)
    return 0.0

def evaluate_ndcg5(model_name, model, loader):
    """Return the mean NDCG@5 of ``model`` over all cases yielded by ``loader``.

    Args:
        model_name: ``'StaticMCDA'`` selects the rule-based ``model.score``
            path; any other name calls ``model(...)`` and reads ``['score']``.
        model: the scorer; neural models are switched to eval mode.
        loader: yields one case per batch (the leading batch axis of the
            per-candidate tensors is squeezed away, so batch_size must be 1).

    Returns:
        Mean NDCG@5 as a Python float; NaN when ``loader`` yields no batch.

    All candidates of a case are scored in a single batched call instead of
    one forward pass (and one device-to-host sync) per candidate.
    """
    is_static = model_name == 'StaticMCDA'
    if not is_static:
        model.eval()

    vals = []
    with torch.no_grad():
        for batch in loader:
            docs = batch['doctor_embeddings'].squeeze(0).to(device)
            c = batch['clinical'].squeeze(0).to(device)
            p = batch['pastwork'].squeeze(0).to(device)
            l = batch['logistics'].squeeze(0).to(device)
            t = batch['trust'].squeeze(0).to(device)
            labels = batch['relevance'].squeeze(0).cpu().numpy()

            if is_static:
                scores = model.score(c, p, l, t)
            else:
                # Repeat the single case embedding / context row for every
                # candidate so the whole candidate list is one batch.
                n_docs = docs.shape[0]
                pe = batch['case_embedding'].to(device).expand(n_docs, -1)
                ctx = batch['context'].to(device).expand(n_docs, -1)
                scores = model(pe, docs, c, p, l, t, ctx)['score']

            # Flatten (n_docs, 1) or (n_docs,) to a 1-D array of scores.
            scores = scores.reshape(-1).cpu().numpy()
            vals.append(ndcg_at_5(scores, labels))

    if not vals:
        # Same value np.mean([]) would give, without the RuntimeWarning.
        return float('nan')
    return float(np.mean(vals))

def _make_model_and_optimizer(model_name):
    """Instantiate the trainable model named ``model_name`` on ``device`` with its AdamW optimizer."""
    if model_name == 'NeuralRanker':
        model = NeuralRanker().to(device)
        optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
    elif model_name == 'DocMatchNet-Original':
        model = DocMatchNetOriginal().to(device)
        optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
    elif model_name == 'DocMatchNet-JEPA':
        model = DocMatchNetJEPA().to(device)
        # The doctor branch trains at 5% of the base learning rate.
        optimizer = AdamW(model.get_parameter_groups(1e-4, doctor_lr_mult=0.05), weight_decay=1e-5)
    else:
        raise ValueError(model_name)
    return model, optimizer


def _batch_loss(model_name, model, batch):
    """Compute the training loss of ``model`` on one ``PairDataset`` batch."""
    pe = batch['case_embedding'].to(device)
    ctx = batch['context'].to(device)

    po = model(
        pe,
        batch['pos_doctor_embedding'].to(device),
        batch['pos_clinical'].to(device),
        batch['pos_pastwork'].to(device),
        batch['pos_logistics'].to(device),
        batch['pos_trust'].to(device),
        ctx,
    )

    if model_name == 'DocMatchNet-JEPA':
        # Symmetric in-batch contrastive loss: the positive doctor of every
        # other case in the batch serves as a negative. The sampled negative
        # pair is not part of this loss, so it is not forwarded at all.
        pred = F.normalize(po['predicted_ideal'], dim=-1)
        target = F.normalize(po['doctor_embedding'], dim=-1)
        logits = pred @ target.T / torch.clamp(po['temperature'], min=1e-8)
        labels = torch.arange(logits.shape[0], device=logits.device)
        return (F.cross_entropy(logits, labels) + F.cross_entropy(logits.T, labels)) / 2

    no = model(
        pe,
        batch['neg_doctor_embedding'].to(device),
        batch['neg_clinical'].to(device),
        batch['neg_pastwork'].to(device),
        batch['neg_logistics'].to(device),
        batch['neg_trust'].to(device),
        ctx,
    )

    # Margin ranking term: positive should outscore negative by at least 0.1.
    rank = F.relu(0.1 - (po['score'] - no['score'])).mean()
    if model_name == 'NeuralRanker':
        return rank

    # DocMatchNet-Original additionally pushes scores towards 1 / 0.
    bce = (F.binary_cross_entropy(po['score'], torch.ones_like(po['score'])) +
           F.binary_cross_entropy(no['score'], torch.zeros_like(no['score']))) / 2
    return rank + bce


def train_model(model_name, train_idx, seed):
    """Train ``model_name`` on the cases in ``train_idx`` and return test NDCG@5.

    Args:
        model_name: one of ``'StaticMCDA'``, ``'NeuralRanker'``,
            ``'DocMatchNet-Original'``, ``'DocMatchNet-JEPA'``.
        train_idx: case indices used for training (ignored by StaticMCDA,
            which has nothing to train).
        seed: seed passed to ``set_seed`` before anything else happens.

    Returns:
        NDCG@5 on ``TEST_IDX`` as a float. Trainable models are evaluated with
        the weights of the epoch that scored best on ``VAL_IDX``.

    Raises:
        ValueError: if ``model_name`` is not one of the names above.

    Training runs for at most 12 epochs and stops early after 4 consecutive
    epochs without a validation improvement.
    """
    set_seed(seed)

    if model_name == 'StaticMCDA':
        # Rule-based baseline: only the test loader is needed.
        test_loader = DataLoader(EvalDataset(TEST_IDX), batch_size=1, shuffle=False)
        return evaluate_ndcg5(model_name, StaticMCDA(), test_loader)

    train_loader = DataLoader(PairDataset(train_idx), batch_size=128, shuffle=True, num_workers=2)
    val_loader = DataLoader(EvalDataset(VAL_IDX), batch_size=1, shuffle=False)
    test_loader = DataLoader(EvalDataset(TEST_IDX), batch_size=1, shuffle=False)

    model, optimizer = _make_model_and_optimizer(model_name)

    scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=5, T_mult=2)
    best_val = -1.0
    best_state = None
    patience = 4
    pc = 0  # epochs since the last validation improvement

    epochs = 12
    for _ in range(epochs):
        model.train()
        for batch in train_loader:
            optimizer.zero_grad()
            loss = _batch_loss(model_name, model, batch)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

        # The schedule advances once per epoch, not per batch.
        scheduler.step()

        val_ndcg = evaluate_ndcg5(model_name, model, val_loader)
        if val_ndcg > best_val:
            best_val = val_ndcg
            # Snapshot on the CPU so later updates cannot alias the copy.
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            pc = 0
        else:
            pc += 1
            if pc >= patience:
                break

    if best_state is not None:
        model.load_state_dict(best_state)

    return evaluate_ndcg5(model_name, model, test_loader)

In [None]:
# ============================================================
# CELL 6: Run Sample-Efficiency Study
# ============================================================
# Nested results: model name -> str(data size) -> {'mean': ..., 'std': ...}.
sample_efficiency_results = {}
for m in MODELS:
    sample_efficiency_results[m] = {}

banner = '=' * 70

for model_name in MODELS:
    print('\n' + banner)
    print(f'Model: {model_name}')
    print(banner)

    for data_size in DATA_SIZES:
        print(f'  Data size: {data_size}')

        # Never ask for more cases than the training split holds.
        n = min(data_size, len(TRAIN_FULL))
        run_scores = []

        for seed in SEEDS:
            set_seed(seed)
            # The subset is drawn without replacement from a generator seeded
            # with the run seed, so it depends only on (seed, n).
            rng = np.random.default_rng(seed)
            sub_idx = rng.choice(TRAIN_FULL, size=n, replace=False)

            score = train_model(model_name, sub_idx, seed)
            run_scores.append(float(score))
            print(f'    Seed {seed}: NDCG@5={score:.4f}')

        mean_score = float(np.mean(run_scores))
        std_score = float(np.std(run_scores))
        sample_efficiency_results[model_name][str(data_size)] = {
            'mean': mean_score,
            'std': std_score,
        }

        print(f"    Mean±Std: {mean_score:.4f} ± {std_score:.4f}")

In [None]:
# ============================================================
# CELL 7: Save Results
# ============================================================
# Persist the aggregated mean/std table as pretty-printed JSON.
out_path = '/kaggle/working/results/sample_efficiency_results.json'
serialized = json.dumps(sample_efficiency_results, indent=2)
with open(out_path, 'w') as f:
    f.write(serialized)

print(f'Saved: {out_path}')

In [None]:
# ============================================================
# CELL 8: Learning Curves + JEPA Efficiency Insight
# ============================================================
# One learning curve per model: mean NDCG@5 with a +/- one std band.
plt.figure(figsize=(10, 6))
for model_name in MODELS:
    per_size = [sample_efficiency_results[model_name][str(s)] for s in DATA_SIZES]
    ys = [entry['mean'] for entry in per_size]
    es = [entry['std'] for entry in per_size]
    centre = np.array(ys)
    spread = np.array(es)
    plt.plot(DATA_SIZES, ys, marker='o', label=model_name)
    plt.fill_between(DATA_SIZES, centre - spread, centre + spread, alpha=0.15)

plt.xscale('log')
plt.xlabel('Training Samples (log scale)')
plt.ylabel('NDCG@5')
plt.title('Sample Efficiency Curves')
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()

plot_path = '/kaggle/working/results/sample_efficiency_curves.png'
plt.savefig(plot_path, dpi=150)
plt.show()

# Mean-score curves keyed by integer data size, for the comparison below.
jepa_curve = {
    int(size): stats['mean']
    for size, stats in sample_efficiency_results['DocMatchNet-JEPA'].items()
}
orig_curve = {
    int(size): stats['mean']
    for size, stats in sample_efficiency_results['DocMatchNet-Original'].items()
}
# Reference level: the Original model's mean score at the largest data size.
orig_full = sample_efficiency_results['DocMatchNet-Original'][str(max(DATA_SIZES))]['mean']

def min_size_for_target(curve, target):
    """Return the smallest key of ``curve`` whose value is >= ``target``, or None if there is none."""
    qualifying = (size for size in sorted(curve) if curve[size] >= target)
    return next(qualifying, None)

# Look for the strictest threshold (95%, then 90%, then 85% of the Original
# model's full-data score) that JEPA reaches with fewer samples than Original.
claim = None
for pct in [0.95, 0.90, 0.85]:
    target = orig_full * pct
    j_size = min_size_for_target(jepa_curve, target)
    o_size = min_size_for_target(orig_curve, target)

    # Skip thresholds that either model never reaches.
    if j_size is None or o_size is None:
        continue
    # Skip thresholds where JEPA is not strictly more data-efficient.
    if o_size <= 0 or j_size >= o_size:
        continue

    reduction = (1.0 - (j_size / o_size)) * 100.0
    claim = (pct, target, j_size, o_size, reduction)
    break

print('\nKey Insight:')
if claim is None:
    print('No clear data-efficiency crossover found at 95/90/85% thresholds with current runs.')
else:
    pct, target, j_size, o_size, reduction = claim
    print(f"DocMatchNet-JEPA reaches {pct*100:.0f}% of Original full-data performance (target={target:.4f})")
    print(f"with {j_size} samples vs {o_size} for Original ({reduction:.1f}% less data).")

print(f'\nSaved plot: {plot_path}')