# DocMatchNet-JEPA Ablation Studies
=================================
Measure the contribution of each model component by ablating it, repeating every configuration across seeds `[42, 123, 456]`.

In [None]:
# ============================================================
# CELL 1: Setup
# ============================================================
import os
import json
import time
import copy
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from tqdm import tqdm

# Pick the accelerator once and reuse the answer for the GPU-name report.
_cuda_available = torch.cuda.is_available()
device = torch.device('cuda' if _cuda_available else 'cpu')
print(f'Using device: {device}')
if _cuda_available:
    print(f'GPU: {torch.cuda.get_device_name()}')

# Output directory for the ablation JSON written later in the notebook.
os.makedirs('/kaggle/working/results', exist_ok=True)

In [None]:
# ============================================================
# CELL 2: Load Data
# ============================================================
DATA_DIR = '/kaggle/input/docmatchnet-jepa-data/data'


def _load_artifact(stem):
    """Load the serialized object stored at ``{DATA_DIR}/{stem}.pt``.

    SECURITY NOTE: ``weights_only=False`` lets ``torch.load`` unpickle arbitrary
    Python objects. Only point DATA_DIR at data you trust.
    """
    return torch.load(f'{DATA_DIR}/{stem}.pt', weights_only=False)


doctor_embeddings = _load_artifact('doctor_embeddings')
case_embeddings = _load_artifact('case_embeddings')
clinical_features = _load_artifact('clinical_features')
pastwork_features = _load_artifact('pastwork_features')
logistics_features = _load_artifact('logistics_features')
trust_features = _load_artifact('trust_features')
context_features = _load_artifact('context_features')
relevance_labels = _load_artifact('relevance_labels')
doctor_indices = _load_artifact('doctor_indices')
case_metadata = _load_artifact('case_metadata')
splits = _load_artifact('splits')

print('Loaded data successfully')
print(f"Train/Val/Test: {len(splits['train'])}/{len(splits['val'])}/{len(splits['test'])}")

In [None]:
# ============================================================
# CELL 3: Datasets
# ============================================================
class AblationTrainDataset(Dataset):
    """Training dataset yielding one (case, positive doctor, negative doctor) triplet per item.

    Item ``idx`` is built for case ``indices[idx]``:

    * The positive is drawn uniformly from candidates with relevance >= 3. When
      none exist, it is drawn from the candidates tied at the maximum relevance.
    * The negative is drawn, with probability 0.2, from non-positive relevance-2
      ("hard") candidates when any exist. Otherwise it is drawn from non-positive
      candidates with relevance <= 1, falling back to any non-positive candidate.
    * Negatives are never sampled from the positive pool. Only when every
      candidate ties at the top relevance, so that no negative exists, is a
      different tied candidate used.

    Args:
        indices: case indices belonging to this split.
        case_emb: per-case embeddings, indexed by case index.
        doc_emb: doctor embeddings, indexed by the global doctor index.
        doc_indices: ``doc_indices[case, j]`` is the global doctor index of the
            j-th candidate of ``case``.
        clinical, pastwork, logistics, trust: per-(case, candidate) features.
        context: per-case context features.
        relevance: per-(case, candidate) relevance labels.
    """

    def __init__(self, indices, case_emb, doc_emb, doc_indices, clinical, pastwork, logistics, trust, context, relevance):
        self.indices = indices
        self.case_emb = case_emb
        self.doc_emb = doc_emb
        self.doc_indices = doc_indices
        self.clinical = clinical
        self.pastwork = pastwork
        self.logistics = logistics
        self.trust = trust
        self.context = context
        self.relevance = relevance

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        case_idx = self.indices[idx]
        rel = self.relevance[case_idx]

        pos_mask = rel >= 3
        if not pos_mask.any():
            # No clearly relevant doctor: treat the top-relevance ones as positives.
            pos_mask = rel == rel.max()
        not_pos = ~pos_mask

        # Restrict every negative pool to non-positives. Otherwise the fallbacks
        # could select a positive as the "negative", e.g. when the max relevance
        # is 1 (rel <= 1), 2 (rel == 2 hard pool) or 4 (rel < max keeps rel 3).
        neg_mask = (rel <= 1) & not_pos
        if not neg_mask.any():
            neg_mask = not_pos

        hard_neg_mask = (rel == 2) & not_pos

        pos_pool = torch.where(pos_mask)[0]
        neg_pool = torch.where(neg_mask)[0]
        hard_pool = torch.where(hard_neg_mask)[0]

        pos_local = pos_pool[torch.randint(len(pos_pool), (1,))].item()

        if len(hard_pool) > 0 and torch.rand(1).item() < 0.2:
            neg_local = hard_pool[torch.randint(len(hard_pool), (1,))].item()
        elif len(neg_pool) > 0:
            neg_local = neg_pool[torch.randint(len(neg_pool), (1,))].item()
        else:
            # Every candidate ties at the top relevance, so no genuine negative
            # exists. Use another candidate instead of calling randint on an
            # empty pool, which raises.
            others = torch.where(torch.arange(len(rel), device=rel.device) != pos_local)[0]
            if len(others) == 0:
                others = pos_pool
            neg_local = others[torch.randint(len(others), (1,))].item()

        pos_global = self.doc_indices[case_idx, pos_local]
        neg_global = self.doc_indices[case_idx, neg_local]

        return {
            'case_embedding': self.case_emb[case_idx],
            'pos_doctor_embedding': self.doc_emb[pos_global],
            'neg_doctor_embedding': self.doc_emb[neg_global],
            'pos_clinical': self.clinical[case_idx, pos_local],
            'neg_clinical': self.clinical[case_idx, neg_local],
            'pos_pastwork': self.pastwork[case_idx, pos_local],
            'neg_pastwork': self.pastwork[case_idx, neg_local],
            'pos_logistics': self.logistics[case_idx, pos_local],
            'neg_logistics': self.logistics[case_idx, neg_local],
            'pos_trust': self.trust[case_idx, pos_local],
            'neg_trust': self.trust[case_idx, neg_local],
            'context': self.context[case_idx]
        }

class AblationEvalDataset(Dataset):
    """Per-case evaluation dataset that exposes every candidate doctor of a case.

    Item ``idx`` refers to case ``indices[idx]``. It returns a dict holding:

    * the case embedding;
    * the embeddings of all its candidate doctors, looked up through ``doc_indices``;
    * the per-candidate feature groups;
    * the case context;
    * the per-candidate relevance labels.

    The constructor only stores references to the given tensors.
    """

    def __init__(self, indices, case_emb, doc_emb, doc_indices, clinical, pastwork, logistics, trust, context, relevance):
        self.indices, self.relevance, self.context = indices, relevance, context
        self.case_emb, self.doc_emb, self.doc_indices = case_emb, doc_emb, doc_indices
        self.clinical, self.pastwork = clinical, pastwork
        self.logistics, self.trust = logistics, trust

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        case = self.indices[idx]
        # Fields that are sliced directly by the case index, in output order.
        by_case = (
            ('clinical', self.clinical),
            ('pastwork', self.pastwork),
            ('logistics', self.logistics),
            ('trust', self.trust),
            ('context', self.context),
            ('relevance', self.relevance),
        )
        item = {
            'case_embedding': self.case_emb[case],
            'doctor_embeddings': self.doc_emb[self.doc_indices[case]],
        }
        for key, source in by_case:
            item[key] = source[case]
        return item

In [None]:
# ============================================================
# CELL 4: Model + Losses + Metrics
# ============================================================
class DocMatchNetJEPA(nn.Module):
    """Doctor-matching model with context-gated handcrafted feature groups.

    Pipeline, as implemented in ``forward``:
      1. ``patient_encoder`` / ``doctor_encoder`` map ``embed_dim`` inputs to
         ``latent_dim`` latents.
      2. Four feature groups are each encoded to ``gate_dim``. Their input
         widths are clinical: 4, pastwork: 5, logistics: 5, trust: 3.
      3. Each encoded group is multiplied element-wise by a sigmoid gate computed
         from ``[patient_latent, context]``, or by a constant 0.5 when ``no_gates``.
      4. ``predictor`` maps ``[patient_latent, gated groups]`` to a predicted
         ideal-doctor latent. It and the doctor latent are then projected to 128-d.
      5. ``score`` is the cosine similarity of the two projections, rescaled
         from [-1, 1] to [0, 1].

    Args:
        embed_dim: size of the input patient/doctor embeddings.
        latent_dim: size of the shared latent space.
        gate_dim: width of each encoded feature group and of its gate.
        context_dim: size of the context vector concatenated into the gate input.
        dropout: dropout rate used in the encoders and the predictor.
        no_gates: if True, every gate is the constant 0.5. The gate MLPs are
            still constructed, but ``forward`` never calls them.

    NOTE: submodules are created in a fixed order. Default weight init draws
    from the global torch RNG, so reordering them changes the initial weights
    obtained for a given seed.
    """

    def __init__(self, embed_dim=384, latent_dim=256, gate_dim=32, context_dim=8, dropout=0.1, no_gates=False):
        super().__init__()
        self.gate_dim = gate_dim
        self.no_gates = no_gates

        # Two structurally identical but independently parameterised towers.
        self.patient_encoder = nn.Sequential(
            nn.Linear(embed_dim, 512), nn.LayerNorm(512), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(512, latent_dim), nn.LayerNorm(latent_dim)
        )
        self.doctor_encoder = nn.Sequential(
            nn.Linear(embed_dim, 512), nn.LayerNorm(512), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(512, latent_dim), nn.LayerNorm(latent_dim)
        )

        # Input widths here must match the last dimension of the feature tensors.
        self.clinical_encoder = self._make_encoder(4, gate_dim)
        self.pastwork_encoder = self._make_encoder(5, gate_dim)
        self.logistics_encoder = self._make_encoder(5, gate_dim)
        self.trust_encoder = self._make_encoder(3, gate_dim)

        # Gates see the patient latent plus the case context. Their attribute
        # names contain 'gate'; two-stage training selects them by that substring.
        gate_input_dim = latent_dim + context_dim
        self.clinical_gate = self._make_gate(gate_input_dim, gate_dim)
        self.pastwork_gate = self._make_gate(gate_input_dim, gate_dim)
        self.logistics_gate = self._make_gate(gate_input_dim, gate_dim)
        self.trust_gate = self._make_gate(gate_input_dim, gate_dim)
        self._init_gate_biases()

        # Input: patient latent concatenated with the four gated groups.
        self.predictor = nn.Sequential(
            nn.Linear(latent_dim + gate_dim * 4, 256),
            nn.LayerNorm(256), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(256, 256), nn.LayerNorm(256), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(256, latent_dim)
        )

        # 128-d projection heads; their outputs are compared by cosine similarity.
        self.predictor_proj = nn.Sequential(nn.Linear(latent_dim, latent_dim), nn.GELU(), nn.Linear(latent_dim, 128))
        self.doctor_proj = nn.Sequential(nn.Linear(latent_dim, latent_dim), nn.GELU(), nn.Linear(latent_dim, 128))
        # Learnable temperature stored in log space (initial value 0.07), exposed as exp().
        self.log_temperature = nn.Parameter(torch.log(torch.tensor(0.07)))

    def _make_encoder(self, in_d, out_d):
        """Two Linear -> BatchNorm1d -> GELU layers mapping ``in_d`` to ``out_d``.

        NOTE(review): BatchNorm1d in training mode cannot normalise a batch
        holding a single sample, so training batches should have size > 1.
        """
        return nn.Sequential(
            nn.Linear(in_d, out_d), nn.BatchNorm1d(out_d), nn.GELU(),
            nn.Linear(out_d, out_d), nn.BatchNorm1d(out_d), nn.GELU()
        )

    def _make_gate(self, in_d, out_d):
        """Gate MLP ``in_d -> 64 -> out_d`` with a sigmoid, so gate values lie in (0, 1)."""
        return nn.Sequential(nn.Linear(in_d, 64), nn.GELU(), nn.Linear(64, out_d), nn.Sigmoid())

    def _init_gate_biases(self):
        """Set the bias of each gate's final Linear (index -2, just before the Sigmoid).

        A positive bias (clinical, +0.4) shifts that gate's initial output above
        0.5, and a negative one (trust, -0.4) shifts it below. Pastwork and
        logistics start centred at 0.
        """
        nn.init.constant_(self.clinical_gate[-2].bias, 0.4)
        nn.init.constant_(self.pastwork_gate[-2].bias, 0.0)
        nn.init.constant_(self.logistics_gate[-2].bias, 0.0)
        nn.init.constant_(self.trust_gate[-2].bias, -0.4)

    def compute_gates(self, patient_latent, context):
        """Return a dict of per-group gate tensors of shape (batch, gate_dim).

        With ``no_gates`` all four entries are the same constant-0.5 tensor,
        which matches ``patient_latent``'s device and dtype.
        """
        if self.no_gates:
            batch = patient_latent.shape[0]
            g = torch.full((batch, self.gate_dim), 0.5, device=patient_latent.device, dtype=patient_latent.dtype)
            return {'clinical': g, 'pastwork': g, 'logistics': g, 'trust': g}
        gate_input = torch.cat([patient_latent, context], dim=-1)
        return {
            'clinical': self.clinical_gate(gate_input),
            'pastwork': self.pastwork_gate(gate_input),
            'logistics': self.logistics_gate(gate_input),
            'trust': self.trust_gate(gate_input)
        }

    def forward(self, patient_emb, doctor_emb, clinical, pastwork, logistics, trust, context):
        """Score patient/doctor pairs row by row.

        Inputs presumably are 2-D (batch, features) tensors with matching batch
        sizes (BatchNorm1d and ``torch.cat(dim=-1)`` require this) — TODO confirm.

        Returns a dict with:
            score: (batch, 1) rescaled cosine similarity in [0, 1].
            predicted_ideal: un-normalised 128-d projection of the predictor output.
            doctor_embedding: un-normalised 128-d projection of the doctor latent.
            temperature: ``exp(log_temperature)``.
            gates: the dict returned by ``compute_gates``.
        """
        patient_latent = self.patient_encoder(patient_emb)
        doctor_latent = self.doctor_encoder(doctor_emb)

        enc_clinical = self.clinical_encoder(clinical)
        enc_pastwork = self.pastwork_encoder(pastwork)
        enc_logistics = self.logistics_encoder(logistics)
        enc_trust = self.trust_encoder(trust)

        gates = self.compute_gates(patient_latent, context)

        gated = torch.cat([
            gates['clinical'] * enc_clinical,
            gates['pastwork'] * enc_pastwork,
            gates['logistics'] * enc_logistics,
            gates['trust'] * enc_trust
        ], dim=-1)

        pred = self.predictor(torch.cat([patient_latent, gated], dim=-1))
        pred_proj = self.predictor_proj(pred)
        doc_proj = self.doctor_proj(doctor_latent)

        # Cosine similarity mapped from [-1, 1] to [0, 1].
        p = F.normalize(pred_proj, dim=-1)
        d = F.normalize(doc_proj, dim=-1)
        score = ((p * d).sum(dim=-1, keepdim=True) + 1) / 2

        return {
            'score': score,
            'predicted_ideal': pred_proj,
            'doctor_embedding': doc_proj,
            'temperature': self.log_temperature.exp(),
            'gates': gates
        }

    def get_parameter_groups(self, base_lr, doctor_lr_mult=0.05):
        """Split parameters into two optimizer groups.

        The doctor tower (``doctor_encoder`` + ``doctor_proj``) uses
        ``base_lr * doctor_lr_mult``. Every other parameter uses ``base_lr``.
        """
        doc_params = list(self.doctor_encoder.parameters()) + list(self.doctor_proj.parameters())
        doc_ids = set(id(p) for p in doc_params)
        other = [p for p in self.parameters() if id(p) not in doc_ids]
        return [
            {'params': other, 'lr': base_lr},
            {'params': doc_params, 'lr': base_lr * doctor_lr_mult}
        ]

def infonce_loss(pred, target, temperature):
    """Symmetric InfoNCE loss over a batch of paired embeddings.

    Row i of ``pred`` is the positive for row i of ``target``. All other rows in
    the batch act as negatives. The cross-entropies of both directions
    (pred->target and target->pred) are averaged. ``temperature`` is a tensor,
    clamped from below at 1e-8 to avoid division by zero.
    """
    safe_temp = temperature.clamp(min=1e-8)
    similarity = F.normalize(pred, dim=-1) @ F.normalize(target, dim=-1).T
    logits = similarity / safe_temp
    targets = torch.arange(len(logits), device=logits.device)
    forward_ce = F.cross_entropy(logits, targets)
    backward_ce = F.cross_entropy(logits.T, targets)
    return (forward_ce + backward_ce) / 2

def mse_loss(pred, target):
    """Mean squared error between ``pred`` and ``target``, averaged over all elements."""
    return F.mse_loss(pred, target, reduction='mean')

def ranking_loss(pos_score, neg_score, margin=0.1):
    """Pairwise hinge loss: penalise pairs where ``pos_score`` fails to beat ``neg_score`` by ``margin``.

    Returns the mean of ``max(0, margin - (pos_score - neg_score))``.
    """
    gap = pos_score - neg_score
    return (margin - gap).clamp_min(0).mean()

def vicreg_gate_loss(gates_dict):
    """Variance + covariance regulariser averaged over every gate tensor in ``gates_dict``.

    For each (batch, dim) gate tensor:

    * variance term: mean over dims of ``relu(1 - var)``, using population
      variance across the batch;
    * covariance term: mean of squared off-diagonal entries of the sample
      covariance. It is 0 when the batch or the gate width is 1.

    Each gate contributes ``variance + 0.1 * covariance``. The sum is divided by
    the number of gates.
    """
    def _terms(g):
        n_rows, n_cols = g.shape[0], g.shape[1]
        variance_term = F.relu(1.0 - g.var(dim=0, unbiased=False)).mean()
        if n_rows > 1 and n_cols > 1:
            centered = g - g.mean(dim=0, keepdim=True)
            cov = (centered.T @ centered) / (n_rows - 1)
            # Zero the diagonal so only cross-dimension covariances are penalised.
            off_diag = cov * (1 - torch.eye(n_cols, device=g.device, dtype=g.dtype))
            covariance_term = off_diag.pow(2).mean()
        else:
            covariance_term = torch.tensor(0.0, device=g.device)
        return variance_term, covariance_term

    total = 0.0
    for gate_vals in gates_dict.values():
        variance_term, covariance_term = _terms(gate_vals)
        total = total + variance_term + 0.1 * covariance_term
    return total / len(gates_dict)

def map_score(scores, labels, threshold=2):
    """Average precision of one ranked candidate list (MAP once averaged over cases).

    Candidates are ranked by descending ``scores``. A candidate counts as
    relevant when ``labels >= threshold``. Returns 0.0 when nothing is relevant.
    """
    ranked_relevant = (labels >= threshold)[np.argsort(-scores)]
    if not ranked_relevant.any():
        return 0.0
    hit_ranks = np.flatnonzero(ranked_relevant) + 1          # 1-based ranks of relevant hits
    precision_at_hits = np.arange(1, len(hit_ranks) + 1) / hit_ranks
    return float(np.mean(precision_at_hits))

def mrr_score(scores, labels, threshold=2):
    """Reciprocal rank of the first relevant candidate (``labels >= threshold``).

    Candidates are ranked by descending ``scores``. Returns 0.0 when nothing is relevant.
    """
    hits = np.flatnonzero((labels >= threshold)[np.argsort(-scores)])
    return float(1.0 / (hits[0] + 1)) if hits.size else 0.0

def ndcg_at_k(scores, labels, k=5):
    """Normalised DCG of the top-``k`` ranking induced by descending ``scores``.

    Uses exponential gain ``2**label - 1`` and discount ``log2(rank + 2)`` for
    0-based rank. Returns 0.0 when the ideal DCG is 0 (no candidate has positive
    gain). Gains are computed in float64, and the ideal ordering sorts gains
    rather than negating labels. Negation would wrap around for unsigned label
    dtypes (e.g. uint8) and produce a wrong IDCG, possibly NDCG > 1.
    """
    scores = np.asarray(scores, dtype=np.float64)
    gains = np.power(2.0, np.asarray(labels, dtype=np.float64)) - 1.0
    top = np.argsort(-scores)[:k]
    discounts = np.log2(np.arange(len(top)) + 2.0)
    dcg = float(np.sum(gains[top] / discounts))
    ideal = np.sort(gains)[::-1][:k]
    idcg = float(np.sum(ideal / discounts[:len(ideal)]))
    return dcg / idcg if idcg > 0 else 0.0

In [None]:
# ============================================================
# CELL 5: Training + Evaluation Runner
# ============================================================
BASE_CFG = dict(
    # Two-stage schedule: stage 1 trains with gate MLPs frozen, stage 2 trains everything.
    stage1_epochs=10,
    stage2_epochs=20,
    # Epoch budget used instead when two_stage is False.
    single_stage_epochs=30,
    batch_size=128,
    lr=1e-4,
    weight_decay=1e-5,
    # Doctor tower (encoder + projection) trains at lr * doctor_lr_mult.
    doctor_lr_mult=0.05,
    # Weight of the gate variance/covariance regulariser.
    lambda_gate=0.05,
    # Early stopping: validation epochs without NDCG@5 improvement before stopping.
    patience=6,
    loss_type='infonce',  # one of 'infonce', 'mse', 'ranking'
    two_stage=True,
    gate_dim=32,
    no_gates=False,
)

def set_seed(seed):
    """Seed the Python, NumPy and PyTorch RNGs, plus every CUDA device when CUDA is available."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def build_loaders(seed, batch_size):
    """Build the train / val / test DataLoaders for one run.

    The training loader shuffles with a generator seeded by ``seed``. The
    evaluation loaders yield one case (with all its candidates) per batch, which
    is what ``evaluate_model`` expects.

    Returns:
        (train_loader, val_loader, test_loader)
    """
    # Tensors shared by every split, in the dataset constructors' argument order.
    shared = (
        case_embeddings, doctor_embeddings, doctor_indices,
        clinical_features, pastwork_features, logistics_features, trust_features,
        context_features, relevance_labels,
    )
    train_ds = AblationTrainDataset(splits['train'], *shared)
    val_ds = AblationEvalDataset(splits['val'], *shared)
    test_ds = AblationEvalDataset(splits['test'], *shared)

    g = torch.Generator()
    g.manual_seed(seed)
    # The feature encoders use BatchNorm1d, which raises in training mode on a
    # batch of size 1. Drop the final batch only when it would be that size.
    drop_last = len(train_ds) % batch_size == 1
    train_loader = DataLoader(
        train_ds, batch_size=batch_size, shuffle=True, num_workers=2,
        generator=g, drop_last=drop_last,
    )
    val_loader = DataLoader(val_ds, batch_size=1, shuffle=False)
    test_loader = DataLoader(test_ds, batch_size=1, shuffle=False)
    return train_loader, val_loader, test_loader

def evaluate_model(model, loader):
    """Compute mean NDCG@5, MAP and MRR of ``model`` over the cases in ``loader``.

    ``loader`` must yield one case per batch (batch_size=1), as ``build_loaders``
    configures for val/test. All candidates of a case are scored in a single
    batched forward pass. With ``model.eval()``, BatchNorm uses running
    statistics and dropout is disabled, so each row's score does not depend on
    the other rows. This matches scoring candidates one at a time, up to
    floating-point rounding.

    Returns:
        dict with keys 'ndcg@5', 'map', 'mrr' (floats).
    """
    model.eval()
    ndcg5_vals, map_vals, mrr_vals = [], [], []

    with torch.no_grad():
        for batch in loader:
            docs = batch['doctor_embeddings'].squeeze(0).to(device)
            n_docs = docs.shape[0]
            # Broadcast the (1, D) case embedding and context to every candidate row.
            case_emb = batch['case_embedding'].to(device).expand(n_docs, -1)
            context = batch['context'].to(device).expand(n_docs, -1)
            clinical = batch['clinical'].squeeze(0).to(device)
            pastwork = batch['pastwork'].squeeze(0).to(device)
            logistics = batch['logistics'].squeeze(0).to(device)
            trust = batch['trust'].squeeze(0).to(device)
            labels = batch['relevance'].squeeze(0).cpu().numpy()

            out = model(case_emb, docs, clinical, pastwork, logistics, trust, context)
            scores = out['score'].squeeze(-1).double().cpu().numpy()

            ndcg5_vals.append(ndcg_at_k(scores, labels, 5))
            map_vals.append(map_score(scores, labels))
            mrr_vals.append(mrr_score(scores, labels))

    return {
        'ndcg@5': float(np.mean(ndcg5_vals)),
        'map': float(np.mean(map_vals)),
        'mrr': float(np.mean(mrr_vals))
    }

def train_with_config(cfg, seed):
    """Train one model under ``cfg`` with ``seed`` and return its test-set metrics.

    With ``cfg['two_stage']``, parameters whose name contains 'gate' are frozen
    for ``stage1_epochs`` epochs (no validation). Then everything is unfrozen
    and trained for up to ``stage2_epochs`` epochs with early stopping.
    Otherwise the model trains for up to ``single_stage_epochs`` epochs with
    early stopping. Early stopping tracks validation NDCG@5 with
    ``cfg['patience']``, and the best validation checkpoint is restored before
    testing.

    Returns:
        dict from ``evaluate_model`` on the test split ('ndcg@5', 'map', 'mrr').
    """
    set_seed(seed)
    train_loader, val_loader, test_loader = build_loaders(seed, cfg['batch_size'])

    model = DocMatchNetJEPA(gate_dim=cfg['gate_dim'], no_gates=cfg['no_gates']).to(device)
    groups = model.get_parameter_groups(cfg['lr'], doctor_lr_mult=cfg['doctor_lr_mult'])
    optimizer = AdamW(groups, weight_decay=cfg['weight_decay'])
    # Stepped once per epoch (at the end of train_epoch).
    scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)

    def train_epoch(epoch_idx):
        """Run one training epoch; return the mean total loss over its batches."""
        model.train()
        total_loss = 0.0
        n = 0
        for batch in tqdm(train_loader, desc=f'Seed {seed} Epoch {epoch_idx}', leave=False):
            case_emb = batch['case_embedding'].to(device)
            pos_doc = batch['pos_doctor_embedding'].to(device)
            neg_doc = batch['neg_doctor_embedding'].to(device)
            pos_c = batch['pos_clinical'].to(device)
            neg_c = batch['neg_clinical'].to(device)
            pos_p = batch['pos_pastwork'].to(device)
            neg_p = batch['neg_pastwork'].to(device)
            pos_l = batch['pos_logistics'].to(device)
            neg_l = batch['neg_logistics'].to(device)
            pos_t = batch['pos_trust'].to(device)
            neg_t = batch['neg_trust'].to(device)
            context = batch['context'].to(device)

            optimizer.zero_grad()

            # The negative pair is forwarded for every loss type, which also
            # updates BatchNorm running stats, though only 'ranking' uses its output.
            pos_out = model(case_emb, pos_doc, pos_c, pos_p, pos_l, pos_t, context)
            neg_out = model(case_emb, neg_doc, neg_c, neg_p, neg_l, neg_t, context)

            if cfg['loss_type'] == 'infonce':
                core = infonce_loss(pos_out['predicted_ideal'], pos_out['doctor_embedding'], pos_out['temperature'])
            elif cfg['loss_type'] == 'mse':
                core = mse_loss(pos_out['predicted_ideal'], pos_out['doctor_embedding'])
            elif cfg['loss_type'] == 'ranking':
                core = ranking_loss(pos_out['score'], neg_out['score'])
            else:
                raise ValueError(f"Unknown loss_type: {cfg['loss_type']}")

            gate_reg = vicreg_gate_loss(pos_out['gates'])
            loss = core + cfg['lambda_gate'] * gate_reg

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            total_loss += loss.item()
            n += 1

        scheduler.step()
        return total_loss / max(n, 1)

    def fit_with_early_stopping(n_epochs, epoch_offset=0):
        """Train up to ``n_epochs``, validating after each epoch.

        Returns a CPU copy of the state dict with the best validation NDCG@5,
        or None when no epoch ran.
        """
        best_state = None
        best_val = -1.0
        patience_counter = 0
        for ep in range(n_epochs):
            train_epoch(epoch_offset + ep)
            v = evaluate_model(model, val_loader)['ndcg@5']
            if v > best_val:
                best_val = v
                best_state = {k: t.detach().cpu().clone() for k, t in model.state_dict().items()}
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= cfg['patience']:
                    break
        return best_state

    if cfg['two_stage']:
        # Stage 1: gate MLPs frozen, so the rest of the model trains under their initial gating.
        for name, p in model.named_parameters():
            if 'gate' in name:
                p.requires_grad = False

        for ep in range(cfg['stage1_epochs']):
            train_epoch(ep)

        # Stage 2: unfreeze everything and early-stop on validation NDCG@5.
        for p in model.parameters():
            p.requires_grad = True

        best_state = fit_with_early_stopping(cfg['stage2_epochs'], epoch_offset=cfg['stage1_epochs'])
    else:
        best_state = fit_with_early_stopping(cfg['single_stage_epochs'])

    if best_state is not None:
        model.load_state_dict(best_state)

    return evaluate_model(model, test_loader)

In [None]:
# ============================================================
# CELL 6: Define Ablations and Run
# ============================================================
SEEDS = [42, 123, 456]

# Each entry overrides BASE_CFG keys for one ablation; {} is the full model.
ABLATIONS = {
    'full_docmatchnet_jepa': {},
    'no_gates': {'no_gates': True},
    'no_twostage': {'two_stage': False},
    'mse_loss': {'loss_type': 'mse'},
    'ranking_loss': {'loss_type': 'ranking'},
    'symmetric_lr': {'doctor_lr_mult': 1.0},
    'no_vicreg': {'lambda_gate': 0.0},
    'gate_dim_16': {'gate_dim': 16},
    'gate_dim_64': {'gate_dim': 64}
}

# Partial results are rewritten here after every ablation, so a crash or
# session timeout does not lose the ablations that already finished.
ABLATION_CHECKPOINT_PATH = '/kaggle/working/results/ablation_results.json'


def summarize_runs(values):
    """Return mean, population std (np.std default, ddof=0) and raw values of per-seed metrics."""
    return {
        'mean': float(np.mean(values)),
        'std': float(np.std(values)),
        'runs': values
    }


ablation_results = {}

for ablation_name, override in ABLATIONS.items():
    print('\n' + '=' * 70)
    print(f'Running ablation: {ablation_name}')
    print('=' * 70)

    cfg = copy.deepcopy(BASE_CFG)
    cfg.update(override)

    # Per-seed test metrics; key order fixes the order in the saved JSON.
    metric_runs = {'ndcg@5': [], 'map': [], 'mrr': []}
    time_runs = []

    for seed in SEEDS:
        print(f'  Seed: {seed}')
        start = time.time()
        metrics = train_with_config(cfg, seed)
        elapsed_min = (time.time() - start) / 60.0

        for key, values in metric_runs.items():
            values.append(float(metrics[key]))
        time_runs.append(elapsed_min)

        print(f"    NDCG@5={metrics['ndcg@5']:.4f}, MAP={metrics['map']:.4f}, MRR={metrics['mrr']:.4f}, Time={elapsed_min:.1f} min")

    ablation_results[ablation_name] = {
        **{key: summarize_runs(values) for key, values in metric_runs.items()},
        'training_time_min': float(np.mean(time_runs))
    }

    with open(ABLATION_CHECKPOINT_PATH, 'w') as f:
        json.dump(ablation_results, f, indent=2)

print('\nAblation runs complete.')

In [None]:
# ============================================================
# CELL 7: Save Results
# ============================================================
# Persist the aggregated ablation metrics, then print a one-line NDCG@5 summary per ablation.
out_path = '/kaggle/working/results/ablation_results.json'
with open(out_path, 'w') as fh:
    json.dump(ablation_results, fh, indent=2)
print(f'Saved: {out_path}')

print('\nSummary (NDCG@5 mean ± std):')
for name, ndcg in ((n, m['ndcg@5']) for n, m in ablation_results.items()):
    print(f"  {name}: {ndcg['mean']:.4f} ± {ndcg['std']:.4f}")