# Baseline Training and Evaluation
===============================
Train and evaluate four baselines: StaticMCDA (fixed weights, evaluated without training), SimpleMLP, NeuralRanker, and DINModel.

A GPU is recommended for training the neural baselines.

In [None]:
# ============================================================
# CELL 1: Setup
# ============================================================
import os
import json
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from tqdm import tqdm

# Run on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')
if device.type == 'cuda':
    print(f'GPU: {torch.cuda.get_device_name()}')

# Fixed seeds for the NumPy and PyTorch random number generators.
np.random.seed(42)
torch.manual_seed(42)

# Checkpoints and the results JSON are written here by later cells.
os.makedirs('/kaggle/working/results', exist_ok=True)

In [None]:
# ============================================================
# CELL 2: Load Data
# ============================================================
DATA_DIR = '/kaggle/input/docmatchnet-data'

# Every artifact is loaded onto the CPU, whatever device it was saved from.
# The datasets index these tensors inside DataLoader worker processes and the
# evaluation code calls .numpy() on the labels; both require CPU tensors.
# Batches are moved to `device` later, inside the training/evaluation loops.
# NOTE(review): torch.load unpickles the files, so DATA_DIR must only point at
# trusted data.
doctor_embeddings = torch.load(f'{DATA_DIR}/doctor_embeddings.pt', map_location='cpu')
case_embeddings = torch.load(f'{DATA_DIR}/case_embeddings.pt', map_location='cpu')
clinical_features = torch.load(f'{DATA_DIR}/clinical_features.pt', map_location='cpu')
pastwork_features = torch.load(f'{DATA_DIR}/pastwork_features.pt', map_location='cpu')
logistics_features = torch.load(f'{DATA_DIR}/logistics_features.pt', map_location='cpu')
trust_features = torch.load(f'{DATA_DIR}/trust_features.pt', map_location='cpu')
context_features = torch.load(f'{DATA_DIR}/context_features.pt', map_location='cpu')
relevance_labels = torch.load(f'{DATA_DIR}/relevance_labels.pt', map_location='cpu')
doctor_indices = torch.load(f'{DATA_DIR}/doctor_indices.pt', map_location='cpu')
# `splits` is indexed with 'train' / 'val' / 'test' below; `case_metadata` is
# read through its 'context_category' entry by the evaluation dataset.
splits = torch.load(f'{DATA_DIR}/splits.pt', map_location='cpu')
case_metadata = torch.load(f'{DATA_DIR}/case_metadata.pt', map_location='cpu')

print('Loaded tensors successfully')
print(f"Train/Val/Test: {len(splits['train'])}/{len(splits['val'])}/{len(splits['test'])}")

In [None]:
# ============================================================
# CELL 3: Datasets
# ============================================================
class PairwiseBaselineDataset(Dataset):
    """Training dataset that yields one (positive, negative) doctor pair per case.

    Each item is built for a single case: one candidate doctor is drawn at
    random from a "positive" pool and one from a "negative" pool, both defined
    by the case's relevance labels. Sampling uses the global torch RNG, so the
    pair returned for a given index changes from call to call.

    Args:
        indices: Case indices belonging to this split.
        case_emb: Case embeddings, indexed by case.
        doc_emb: Doctor embeddings, indexed by global doctor id.
        doc_indices: Indexed as ``[case, candidate]``; maps a case's local
            candidate slot to a global doctor id.
        clinical, pastwork, logistics, trust: Per-candidate feature tables,
            indexed as ``[case, candidate]``.
        context: Per-case context features, indexed by case.
        relevance: Relevance labels, indexed by case; each row holds one label
            per candidate.
        pos_threshold: Candidates with ``label >= pos_threshold`` are positives.
        neg_threshold: Candidates with ``label <= neg_threshold`` are negatives.
    """

    def __init__(self, indices, case_emb, doc_emb, doc_indices, clinical, pastwork, logistics, trust,
                 context, relevance, pos_threshold=3, neg_threshold=1):
        self.indices = indices
        self.case_emb = case_emb
        self.doc_emb = doc_emb
        self.doc_indices = doc_indices
        self.clinical = clinical
        self.pastwork = pastwork
        self.logistics = logistics
        self.trust = trust
        self.context = context
        self.relevance = relevance
        self.pos_threshold = pos_threshold
        self.neg_threshold = neg_threshold

    def __len__(self):
        return len(self.indices)

    def _pair_masks(self, rel):
        """Return boolean (positive, negative) candidate masks for one case.

        The masks never overlap unless every candidate carries the same label,
        so a sampled negative is strictly less relevant than the sampled
        positive whenever such a pair exists.
        """
        pos_mask = rel >= self.pos_threshold
        if not pos_mask.any():
            # No clearly relevant doctor: treat the best-labelled ones as positives.
            pos_mask = rel == rel.max()

        # Never let a candidate be both a positive and a negative.
        neg_mask = (rel <= self.neg_threshold) & ~pos_mask
        if not neg_mask.any():
            # No clearly irrelevant doctor: contrast the top label with the rest.
            best = rel == rel.max()
            pos_mask, neg_mask = best, ~best
        if not neg_mask.any():
            # All labels are identical, so no preference pair exists. Fall back
            # to sampling from every candidate rather than crashing the loader.
            neg_mask = torch.ones_like(pos_mask)
        return pos_mask, neg_mask

    def __getitem__(self, idx):
        case_idx = self.indices[idx]
        rel = self.relevance[case_idx]

        pos_mask, neg_mask = self._pair_masks(rel)
        pos_pool = torch.where(pos_mask)[0]
        neg_pool = torch.where(neg_mask)[0]

        # Local slots index the per-case feature tables ...
        pos_local = pos_pool[torch.randint(len(pos_pool), (1,))].item()
        neg_local = neg_pool[torch.randint(len(neg_pool), (1,))].item()

        # ... while global ids index the shared doctor embedding table.
        pos_global = self.doc_indices[case_idx, pos_local]
        neg_global = self.doc_indices[case_idx, neg_local]

        return {
            'case_embedding': self.case_emb[case_idx],
            'pos_doctor_embedding': self.doc_emb[pos_global],
            'neg_doctor_embedding': self.doc_emb[neg_global],
            'pos_clinical': self.clinical[case_idx, pos_local],
            'neg_clinical': self.clinical[case_idx, neg_local],
            'pos_pastwork': self.pastwork[case_idx, pos_local],
            'neg_pastwork': self.pastwork[case_idx, neg_local],
            'pos_logistics': self.logistics[case_idx, pos_local],
            'neg_logistics': self.logistics[case_idx, neg_local],
            'pos_trust': self.trust[case_idx, pos_local],
            'neg_trust': self.trust[case_idx, neg_local],
            'context': self.context[case_idx]
        }

class EvalDataset(Dataset):
    """Evaluation dataset returning a case together with all of its candidates.

    Unlike the pairwise training dataset, nothing is sampled: every item
    carries the full candidate list of one case, plus the relevance labels
    and the case's context category used for stratified reporting.
    """

    # Attributes that are simply looked up by case index when building an item.
    _PER_CASE_FIELDS = ('clinical', 'pastwork', 'logistics', 'trust', 'context', 'relevance')

    def __init__(self, indices, case_emb, doc_emb, doc_indices, clinical, pastwork, logistics, trust, context, relevance, metadata):
        self.indices = indices
        self.case_emb, self.doc_emb = case_emb, doc_emb
        self.doc_indices = doc_indices
        self.clinical, self.pastwork = clinical, pastwork
        self.logistics, self.trust = logistics, trust
        self.context = context
        self.relevance = relevance
        self.metadata = metadata

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        case_idx = self.indices[idx]

        sample = {
            'case_embedding': self.case_emb[case_idx],
            # Global doctor ids of this case's candidates select their embeddings.
            'doctor_embeddings': self.doc_emb[self.doc_indices[case_idx]],
        }
        for field in self._PER_CASE_FIELDS:
            sample[field] = getattr(self, field)[case_idx]
        sample['context_category'] = self.metadata['context_category'][case_idx]
        return sample

# Training split: one randomly sampled (positive, negative) pair per case.
train_ds = PairwiseBaselineDataset(
    indices=splits['train'],
    case_emb=case_embeddings,
    doc_emb=doctor_embeddings,
    doc_indices=doctor_indices,
    clinical=clinical_features,
    pastwork=pastwork_features,
    logistics=logistics_features,
    trust=trust_features,
    context=context_features,
    relevance=relevance_labels,
)

# Validation and test splits: the full candidate list of every case.
val_ds = EvalDataset(
    indices=splits['val'],
    case_emb=case_embeddings,
    doc_emb=doctor_embeddings,
    doc_indices=doctor_indices,
    clinical=clinical_features,
    pastwork=pastwork_features,
    logistics=logistics_features,
    trust=trust_features,
    context=context_features,
    relevance=relevance_labels,
    metadata=case_metadata,
)
test_ds = EvalDataset(
    indices=splits['test'],
    case_emb=case_embeddings,
    doc_emb=doctor_embeddings,
    doc_indices=doctor_indices,
    clinical=clinical_features,
    pastwork=pastwork_features,
    logistics=logistics_features,
    trust=trust_features,
    context=context_features,
    relevance=relevance_labels,
    metadata=case_metadata,
)

# Evaluation loaders yield one case per batch, in a fixed order.
train_loader = DataLoader(dataset=train_ds, batch_size=128, shuffle=True, num_workers=2)
val_loader = DataLoader(dataset=val_ds, batch_size=1, shuffle=False)
test_loader = DataLoader(dataset=test_ds, batch_size=1, shuffle=False)

print(f'Train batches: {len(train_loader)}')

In [None]:
# ============================================================
# CELL 4: Baseline Models
# ============================================================
class StaticMCDA:
    """Fixed-weight multi-criteria scorer; it has no trainable parameters.

    Each feature group is collapsed to one number by a weighted sum over its
    last dimension, and the four group scores are then combined with the
    group-level weights in ``self.weights``.
    """

    def __init__(self):
        # Group-level weights applied to the four per-group scores.
        self.weights = {'clinical': 0.40, 'pastwork': 0.25, 'logistics': 0.25, 'trust': 0.10}
        # Within-group weights, one entry per feature of that group.
        self.clinical_weights = [0.55, 0.20, 0.15, 0.10]
        self.pastwork_weights = [0.30, 0.25, 0.20, 0.15, 0.10]
        self.logistics_weights = [0.30, 0.25, 0.20, 0.15, 0.10]
        self.trust_weights = [0.50, 0.30, 0.20]

    def score(self, clinical, pastwork, logistics, trust):
        """Return the weighted score for each row of the given feature tensors.

        The last dimension of every argument must match the length of the
        corresponding within-group weight list; it is summed away.
        """
        groups = (
            ('clinical', clinical, self.clinical_weights),
            ('pastwork', pastwork, self.pastwork_weights),
            ('logistics', logistics, self.logistics_weights),
            ('trust', trust, self.trust_weights),
        )

        total = None
        for group_name, features, feature_weights in groups:
            # Weights follow the dtype and device of the features they multiply.
            w = torch.tensor(feature_weights, dtype=features.dtype, device=features.device)
            contribution = self.weights[group_name] * (features * w).sum(-1)
            total = contribution if total is None else total + contribution
        return total

class SimpleMLP(nn.Module):
    """Feed-forward baseline that scores one (case, doctor) pair.

    All inputs are concatenated along the last dimension and passed through
    three hidden layers (256, 128, 64 units, each followed by ReLU and
    dropout 0.2) and a sigmoid output unit.
    """

    def __init__(self, embed_dim=384):
        super().__init__()
        # Two embeddings plus 4 + 5 + 5 + 3 + 8 hand-crafted features
        # (presumably clinical, pastwork, logistics, trust, context widths).
        widths = [2 * embed_dim + 4 + 5 + 5 + 3 + 8, 256, 128, 64]

        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.extend([nn.Linear(fan_in, fan_out), nn.ReLU(), nn.Dropout(0.2)])
        layers.extend([nn.Linear(widths[-1], 1), nn.Sigmoid()])
        self.network = nn.Sequential(*layers)

    def forward(self, patient_emb, doctor_emb, clinical, pastwork, logistics, trust, context):
        """Return ``{'score': tensor}`` holding one sigmoid score per input row."""
        joined = torch.cat(
            (patient_emb, doctor_emb, clinical, pastwork, logistics, trust, context),
            dim=-1,
        )
        return {'score': self.network(joined)}

class NeuralRanker(nn.Module):
    """Attention-based baseline that scores one (case, doctor) pair.

    The case embedding (query) attends over the doctor embedding (key/value)
    after both are projected to ``hidden_dim``. The normalised result is
    concatenated with an encoding of the 25 hand-crafted features and mapped
    to a sigmoid score.
    """

    def __init__(self, embed_dim=384, hidden_dim=256):
        super().__init__()
        self.patient_proj = nn.Linear(embed_dim, hidden_dim)
        self.doctor_proj = nn.Linear(embed_dim, hidden_dim)
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=hidden_dim, num_heads=4, batch_first=True
        )
        self.layer_norm = nn.LayerNorm(hidden_dim)
        # 25 = total width of the concatenated hand-crafted feature groups.
        self.feature_encoder = nn.Sequential(
            nn.Linear(25, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
        )
        self.scorer = nn.Sequential(
            nn.Linear(hidden_dim + 64, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid(),
        )

    def forward(self, patient_emb, doctor_emb, clinical, pastwork, logistics, trust, context):
        """Return ``{'score': tensor}`` holding one sigmoid score per input row."""
        # Both sides become length-1 sequences for the batch-first attention layer.
        query = self.patient_proj(patient_emb).unsqueeze(1)
        key_value = self.doctor_proj(doctor_emb).unsqueeze(1)
        attended = self.cross_attention(query, key_value, key_value)[0]
        attended = self.layer_norm(attended.squeeze(1))

        handcrafted = torch.cat((clinical, pastwork, logistics, trust, context), dim=-1)
        encoded = self.feature_encoder(handcrafted)

        return {'score': self.scorer(torch.cat((attended, encoded), dim=-1))}

class DINModel(nn.Module):
    """Gated-interaction baseline that scores one (case, doctor) pair.

    The case (embedding + context) and the doctor (embedding + 17 hand-crafted
    features) are each encoded to 128 dimensions. A small network computes a
    sigmoid gate from the two encodings and their element-wise product; the
    gated doctor encoding is then mapped to a sigmoid score.
    """

    def __init__(self, embed_dim=384):
        super().__init__()
        self.case_encoder = nn.Linear(embed_dim + 8, 128)
        self.doctor_encoder = nn.Linear(embed_dim + 17, 128)
        self.attention = nn.Sequential(
            nn.Linear(128 * 3, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )
        self.scorer = nn.Sequential(
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid(),
        )

    def forward(self, patient_emb, doctor_emb, clinical, pastwork, logistics, trust, context):
        """Return ``{'score': tensor}`` holding one sigmoid score per input row."""
        case_input = torch.cat((patient_emb, context), dim=-1)
        case_enc = F.relu(self.case_encoder(case_input))

        handcrafted = torch.cat((clinical, pastwork, logistics, trust), dim=-1)
        doctor_input = torch.cat((doctor_emb, handcrafted), dim=-1)
        doc_enc = F.relu(self.doctor_encoder(doctor_input))

        # One scalar gate per row, computed from both encodings and their product.
        gate_input = torch.cat((case_enc, doc_enc, case_enc * doc_enc), dim=-1)
        gate = torch.sigmoid(self.attention(gate_input))

        return {'score': self.scorer(gate * doc_enc)}

In [None]:
# ============================================================
# CELL 5: Metrics and Evaluation
# ============================================================
def ndcg_at_k(scores, labels, k):
    """Compute NDCG@k for one ranked list using the exponential gain 2**label - 1.

    Args:
        scores: Predicted scores, one per candidate (higher ranks first).
        labels: Graded relevance labels aligned with ``scores``.
        k: Cut-off; only the top ``k`` positions contribute.

    Returns:
        DCG@k divided by the ideal DCG@k as a Python float, or 0.0 when the
        ideal DCG is zero (for example when every label is 0).
    """
    # Work in float64 so that negating or exponentiating integer (in particular
    # unsigned) label arrays cannot wrap around, and so plain lists are accepted.
    scores = np.asarray(scores, dtype=np.float64)
    labels = np.asarray(labels, dtype=np.float64)
    gains = np.exp2(labels) - 1.0

    # Stable sort keeps tied scores in their original order, which makes the
    # metric reproducible.
    top = np.argsort(-scores, kind='stable')[:k]
    # Position i (0-based) is discounted by 1 / log2(i + 2).
    discounts = 1.0 / np.log2(np.arange(top.size) + 2.0)

    dcg = float(np.sum(gains[top] * discounts))
    ideal_gains = np.sort(gains)[::-1][:top.size]
    idcg = float(np.sum(ideal_gains * discounts))
    return dcg / idcg if idcg > 0 else 0.0

def evaluate_model(model_or_mcda, dataloader, device, is_mcda=False):
    """Rank every case's candidate doctors and report NDCG@5 and NDCG@10.

    Args:
        model_or_mcda: A model called as ``model(case, doctor, clinical,
            pastwork, logistics, trust, context)['score']``, or, when
            ``is_mcda`` is true, an object exposing ``score(clinical,
            pastwork, logistics, trust)``.
        dataloader: Yields batches with the keys 'case_embedding',
            'doctor_embeddings', 'clinical', 'pastwork', 'logistics', 'trust',
            'context', 'relevance' and 'context_category'. Any batch size is
            accepted; every case in a batch is scored separately.
        device: Device the inputs are moved to before scoring.
        is_mcda: Selects the ``score`` interface and skips ``eval()``.

    Returns:
        ``{'overall': {...}, 'stratified': {...}}``. 'overall' holds mean/std
        of NDCG@5 and NDCG@10 over all cases; 'stratified' holds mean/std of
        NDCG@5 per context category. Mean and std are NaN if no case was seen.
    """
    if not is_mcda:
        # Switch off dropout; the model is left in eval mode afterwards.
        model_or_mcda.eval()

    def _summary(values):
        # Mean/std as plain floats so the result is JSON serialisable.
        if not values:
            return {'mean': float('nan'), 'std': float('nan')}
        return {'mean': float(np.mean(values)), 'std': float(np.std(values))}

    ndcg5, ndcg10 = [], []
    context_res = {}

    with torch.no_grad():
        for batch in tqdm(dataloader, desc='Evaluating', leave=False):
            case_emb = batch['case_embedding'].to(device)
            docs = batch['doctor_embeddings'].to(device)
            clinical = batch['clinical'].to(device)
            pastwork = batch['pastwork'].to(device)
            logistics = batch['logistics'].to(device)
            trust = batch['trust'].to(device)
            context = batch['context'].to(device)
            relevance = batch['relevance']
            categories = batch['context_category']

            for b in range(case_emb.shape[0]):
                n_docs = docs.shape[1]
                if is_mcda:
                    out = model_or_mcda.score(clinical[b], pastwork[b], logistics[b], trust[b])
                else:
                    # Score all candidates of the case in one forward pass by
                    # repeating the case-level inputs once per candidate.
                    out = model_or_mcda(
                        case_emb[b:b + 1].expand(n_docs, -1),
                        docs[b], clinical[b], pastwork[b], logistics[b], trust[b],
                        context[b:b + 1].expand(n_docs, -1),
                    )['score']

                scores = out.reshape(-1).double().cpu().numpy()
                labels = relevance[b].cpu().numpy()

                ctx = categories[b]
                if torch.is_tensor(ctx):
                    # Tensors hash by identity, so use the plain value as the key.
                    ctx = ctx.item()

                n5 = ndcg_at_k(scores, labels, 5)
                n10 = ndcg_at_k(scores, labels, 10)
                ndcg5.append(n5)
                ndcg10.append(n10)
                context_res.setdefault(ctx, []).append(n5)

    overall = {'ndcg@5': _summary(ndcg5), 'ndcg@10': _summary(ndcg10)}
    stratified = {k: _summary(v) for k, v in context_res.items()}
    return {'overall': overall, 'stratified': stratified}

def ranking_loss(pos_score, neg_score, margin=0.1):
    """Mean pairwise hinge loss.

    A pair contributes zero once the positive score exceeds the negative
    score by at least ``margin``; otherwise it contributes the shortfall.
    """
    score_gap = pos_score - neg_score
    shortfall = F.relu(margin - score_gap)
    return shortfall.mean()

In [None]:
# ============================================================
# CELL 6: Training Utilities
# ============================================================
def train_baseline(model, train_loader, val_loader, device, mode='bce', epochs=50, patience=10, lr=1e-4):
    """Train a pairwise baseline with early stopping on validation NDCG@5.

    Args:
        model: Network called as ``model(case, doctor, clinical, pastwork,
            logistics, trust, context)['score']``.
        train_loader: Yields batches of (positive, negative) pairs.
        val_loader: Passed to ``evaluate_model`` after every epoch.
        device: Device the model and batches are moved to.
        mode: ``'bce'`` trains with binary cross-entropy (positive -> 1,
            negative -> 0); any other value uses the margin ranking loss.
        epochs: Maximum number of epochs.
        patience: Number of consecutive non-improving epochs before stopping.
        lr: AdamW learning rate.

    Returns:
        ``(model, history)``: the model with its best validation weights
        restored, and a dict of per-epoch 'train_loss' and 'val_ndcg5' lists.
    """
    model = model.to(device)
    optimizer = AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
    scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)

    # Per-doctor inputs, in the order the model's forward() expects them.
    doctor_keys = ('doctor_embedding', 'clinical', 'pastwork', 'logistics', 'trust')

    history = {'train_loss': [], 'val_ndcg5': []}
    best_ndcg, best_state = -1.0, None
    patience_counter = 0

    for epoch in range(epochs):
        model.train()
        running_loss, n_batches = 0.0, 0

        for batch in tqdm(train_loader, desc=f'Train {mode} epoch {epoch}', leave=False):
            case_emb = batch['case_embedding'].to(device)
            context = batch['context'].to(device)
            pos_inputs = [batch[f'pos_{key}'].to(device) for key in doctor_keys]
            neg_inputs = [batch[f'neg_{key}'].to(device) for key in doctor_keys]

            optimizer.zero_grad()

            # Positive pass first, then negative: both share the case inputs.
            pos_out = model(case_emb, *pos_inputs, context)['score']
            neg_out = model(case_emb, *neg_inputs, context)['score']

            if mode != 'bce':
                loss = ranking_loss(pos_out, neg_out, margin=0.1)
            else:
                pos_term = F.binary_cross_entropy(pos_out, torch.ones_like(pos_out))
                neg_term = F.binary_cross_entropy(neg_out, torch.zeros_like(neg_out))
                loss = (pos_term + neg_term) / 2.0

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            running_loss += loss.item()
            n_batches += 1

        avg_loss = running_loss / max(n_batches, 1)
        val_metrics = evaluate_model(model, val_loader, device, is_mcda=False)
        val_ndcg5 = val_metrics['overall']['ndcg@5']['mean']
        history['train_loss'].append(avg_loss)
        history['val_ndcg5'].append(val_ndcg5)

        print(f'Epoch {epoch}: loss={avg_loss:.4f}, val_ndcg5={val_ndcg5:.4f}')

        if val_ndcg5 > best_ndcg:
            # New best: snapshot the weights on the CPU and reset the counter.
            best_ndcg = val_ndcg5
            patience_counter = 0
            best_state = {
                name: tensor.detach().cpu().clone()
                for name, tensor in model.state_dict().items()
            }
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f'Early stopping at epoch {epoch}')
                break

        scheduler.step()

    if best_state is not None:
        model.load_state_dict(best_state)
    return model, history

In [None]:
# ============================================================
# CELL 7: Run All Baselines
# ============================================================
# Test-set metrics and training history of every baseline, keyed by name.
all_results = {}

# StaticMCDA has nothing to train, so it is evaluated directly and gets an
# empty history.
print('\n[1/4] StaticMCDA')
mcda = StaticMCDA()
all_results['StaticMCDA'] = {
    **evaluate_model(mcda, test_loader, device, is_mcda=True),
    'history': {'train_loss': [], 'val_ndcg5': []},
}

print('\n[2/4] SimpleMLP (BCE)')
mlp, mlp_hist = train_baseline(
    SimpleMLP(), train_loader, val_loader, device,
    mode='bce', epochs=50, patience=10, lr=1e-4,
)
torch.save(mlp.state_dict(), '/kaggle/working/results/best_simple_mlp.pt')
all_results['SimpleMLP'] = {
    **evaluate_model(mlp, test_loader, device, is_mcda=False),
    'history': mlp_hist,
}

print('\n[3/4] NeuralRanker (Ranking Loss)')
ranker, ranker_hist = train_baseline(
    NeuralRanker(), train_loader, val_loader, device,
    mode='ranking', epochs=50, patience=10, lr=1e-4,
)
torch.save(ranker.state_dict(), '/kaggle/working/results/best_neural_ranker.pt')
all_results['NeuralRanker'] = {
    **evaluate_model(ranker, test_loader, device, is_mcda=False),
    'history': ranker_hist,
}

print('\n[4/4] DINModel (Ranking Loss)')
din, din_hist = train_baseline(
    DINModel(), train_loader, val_loader, device,
    mode='ranking', epochs=50, patience=10, lr=1e-4,
)
torch.save(din.state_dict(), '/kaggle/working/results/best_din_model.pt')
all_results['DINModel'] = {
    **evaluate_model(din, test_loader, device, is_mcda=False),
    'history': din_hist,
}

print('\nAll baselines finished.')

In [None]:
# ============================================================
# CELL 8: Save and Summary
# ============================================================
# Persist the metrics and histories of every baseline as one JSON document.
with open('/kaggle/working/results/baseline_results.json', 'w') as f:
    json.dump(all_results, f, indent=2)

print('Saved: /kaggle/working/results/baseline_results.json')
print('\nSummary (NDCG@5 mean ± std):')
# One line per baseline, in the order the baselines were run.
for name, res in all_results.items():
    ndcg5_stats = res['overall']['ndcg@5']
    print(f"  {name}: {ndcg5_stats['mean']:.4f} ± {ndcg5_stats['std']:.4f}")