# Efficiency Benchmarking
=======================
Measure inference latency (single-sample and 100-doctor batch scoring) and parameter count for each model.

In [None]:
import time
import json
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Run on the first CUDA device when one is present, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Using device: {device}')

class SimpleMLP(nn.Module):
    """Baseline scorer: concatenate every input and feed it through an MLP.

    The seven inputs are joined along the last dimension, so the first linear
    layer expects ``2 * embed_dim + 25`` features (the 25 being 4 + 5 + 5 + 3 + 8
    for the clinical, pastwork, logistics, trust and context groups). Three
    Linear/ReLU/Dropout(0.2) stages (256 -> 128 -> 64) are followed by a
    Linear(64, 1) + Sigmoid head, giving one score in (0, 1) per row.
    """

    def __init__(self, embed_dim=384):
        super().__init__()
        n_side_features = 4 + 5 + 5 + 3 + 8
        widths = [2 * embed_dim + n_side_features, 256, 128, 64]
        # Layers are created in the same order as the Sequential indices, which
        # fixes both the init RNG sequence and the state_dict keys.
        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU(), nn.Dropout(0.2)]
        layers += [nn.Linear(widths[-1], 1), nn.Sigmoid()]
        self.network = nn.Sequential(*layers)

    def forward(self, patient_emb, doctor_emb, clinical, pastwork, logistics, trust, context):
        parts = (patient_emb, doctor_emb, clinical, pastwork, logistics, trust, context)
        joined = torch.cat(parts, dim=-1)
        return self.network(joined)

class NeuralRanker(nn.Module):
    """Cross-attention ranker over patient/doctor embeddings plus side features.

    Both embeddings are projected to ``hidden_dim``. A single 4-head
    cross-attention step uses the patient projection as query and the doctor
    projection as key/value, and its output is LayerNorm-ed. The five side
    feature groups are concatenated and encoded to 64 dims. The two
    representations are concatenated and scored by an MLP ending in a sigmoid.
    """

    def __init__(self, embed_dim=384, hidden_dim=256):
        super().__init__()
        # Projections into a shared hidden space for the attention step.
        self.patient_proj = nn.Linear(embed_dim, hidden_dim)
        self.doctor_proj = nn.Linear(embed_dim, hidden_dim)
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=hidden_dim, num_heads=4, batch_first=True
        )
        self.layer_norm = nn.LayerNorm(hidden_dim)
        # 25 input features: the concatenated clinical/pastwork/logistics/trust/context
        # groups (4 + 5 + 5 + 3 + 8 with the widths used by the benchmark inputs).
        self.feature_encoder = nn.Sequential(
            nn.Linear(25, 64), nn.ReLU(),
            nn.Linear(64, 64),
        )
        self.scorer = nn.Sequential(
            nn.Linear(hidden_dim + 64, 128), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(128, 64), nn.ReLU(),
            nn.Linear(64, 1), nn.Sigmoid(),
        )

    def forward(self, patient_emb, doctor_emb, clinical, pastwork, logistics, trust, context):
        # Treat each side as a length-1 sequence so attention sees (batch, 1, hidden),
        # assuming (batch, embed_dim) embeddings.
        query = self.patient_proj(patient_emb).unsqueeze(1)
        key_value = self.doctor_proj(doctor_emb).unsqueeze(1)
        attended = self.cross_attention(query, key_value, key_value)[0]
        pair_repr = self.layer_norm(attended.squeeze(1))

        side_inputs = torch.cat((clinical, pastwork, logistics, trust, context), dim=-1)
        side_repr = self.feature_encoder(side_inputs)

        return self.scorer(torch.cat((pair_repr, side_repr), dim=-1))

class DINModel(nn.Module):
    """Attention-gated scorer: weight the doctor encoding by its relevance to the case.

    The case side (patient embedding + context) and the doctor side (doctor
    embedding + clinical/pastwork/logistics/trust features) are each encoded
    to 128 dims with a ReLU. A small MLP over ``[case, doctor, case * doctor]``
    produces a sigmoid weight that scales the doctor encoding before the final
    sigmoid-headed scorer.
    """

    def __init__(self, embed_dim=384):
        super().__init__()
        # Case side: patient embedding + 8 context features.
        self.case_encoder = nn.Linear(embed_dim + 8, 128)
        # Doctor side: doctor embedding + 17 features (4 + 5 + 5 + 3 with the
        # widths used by the benchmark inputs).
        self.doctor_encoder = nn.Linear(embed_dim + 17, 128)
        # Attention MLP sees the case and doctor encodings plus their product.
        self.attention = nn.Sequential(
            nn.Linear(3 * 128, 64), nn.ReLU(),
            nn.Linear(64, 1),
        )
        self.scorer = nn.Sequential(
            nn.Linear(128, 64), nn.ReLU(),
            nn.Linear(64, 1), nn.Sigmoid(),
        )

    def forward(self, patient_emb, doctor_emb, clinical, pastwork, logistics, trust, context):
        case_vec = F.relu(self.case_encoder(torch.cat((patient_emb, context), dim=-1)))

        doctor_inputs = torch.cat((doctor_emb, clinical, pastwork, logistics, trust), dim=-1)
        doctor_vec = F.relu(self.doctor_encoder(doctor_inputs))

        interaction = torch.cat((case_vec, doctor_vec, case_vec * doctor_vec), dim=-1)
        weight = self.attention(interaction).sigmoid()
        return self.scorer(weight * doctor_vec)

class DocMatchNetOriginal(nn.Module):
    """Gated-fusion matcher built on patient/doctor cross-attention.

    The patient and doctor embeddings are projected to ``hidden_dim`` and
    combined with one 4-head cross-attention step (patient as query, doctor
    as key and value), then LayerNorm-ed. That interaction vector,
    concatenated with ``context``, drives four sigmoid gates. The gates scale
    four small feature encoders (clinical, pastwork, logistics, trust). The
    gated encodings are concatenated and scored by an MLP ending in a sigmoid.
    """

    def __init__(self, embed_dim=384, hidden_dim=256, gate_dim=32):
        super().__init__()

        def feature_encoder(n_features):
            # Two-layer ReLU MLP mapping one feature group to ``gate_dim``.
            return nn.Sequential(
                nn.Linear(n_features, gate_dim), nn.ReLU(),
                nn.Linear(gate_dim, gate_dim), nn.ReLU(),
            )

        def gate_network():
            # Sigmoid gate over [interaction, context]; 8 is the context width.
            return nn.Sequential(
                nn.Linear(hidden_dim + 8, 64), nn.ReLU(),
                nn.Linear(64, gate_dim), nn.Sigmoid(),
            )

        # Keep this construction order: it fixes the parameter-init RNG
        # sequence and the state_dict key order.
        self.patient_proj = nn.Linear(embed_dim, hidden_dim)
        self.doctor_proj = nn.Linear(embed_dim, hidden_dim)
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=hidden_dim, num_heads=4, batch_first=True
        )
        self.norm = nn.LayerNorm(hidden_dim)
        self.ce, self.pe, self.le, self.te = (
            feature_encoder(4), feature_encoder(5), feature_encoder(5), feature_encoder(3)
        )
        self.cg, self.pg, self.lg, self.tg = (
            gate_network(), gate_network(), gate_network(), gate_network()
        )
        self.scorer = nn.Sequential(
            nn.Linear(4 * gate_dim, 64), nn.ReLU(),
            nn.Linear(64, 1), nn.Sigmoid(),
        )

    def forward(self, patient_emb, doctor_emb, clinical, pastwork, logistics, trust, context):
        query = self.patient_proj(patient_emb).unsqueeze(1)
        key_value = self.doctor_proj(doctor_emb).unsqueeze(1)
        attended, _ = self.cross_attention(query, key_value, key_value)
        interaction = self.norm(attended.squeeze(1))

        gate_input = torch.cat([interaction, context], dim=-1)
        branches = (
            (self.cg, self.ce, clinical),
            (self.pg, self.pe, pastwork),
            (self.lg, self.le, logistics),
            (self.tg, self.te, trust),
        )
        gated = [gate_net(gate_input) * enc_net(group) for gate_net, enc_net, group in branches]
        return self.scorer(torch.cat(gated, dim=-1))

class DocMatchNetJEPA(nn.Module):
    """Latent-prediction matcher scored by cosine similarity.

    ``forward`` works in five steps:

    1. Encode the patient and doctor embeddings into ``latent_dim`` latents.
    2. Gate four small feature encoders (clinical, pastwork, logistics,
       trust) with sigmoid gates conditioned on ``[patient latent, context]``.
    3. Predict a latent from the patient latent plus the gated features.
    4. Project the prediction and the doctor latent to 128 dims.
    5. Score the pair as their cosine similarity, rescaled from [-1, 1]
       to [0, 1].

    Args:
        embed_dim: width of the patient/doctor input embeddings.
        latent_dim: width of the encoder latents and the predictor output.
        gate_dim: output width of each gated feature encoder.
        context_dim: width of the ``context`` input consumed by the gates.
        dropout: dropout rate inside the patient/doctor encoders.
    """

    def __init__(self, embed_dim=384, latent_dim=256, gate_dim=32, context_dim=8, dropout=0.1):
        super().__init__()
        self.patient_encoder = nn.Sequential(
            nn.Linear(embed_dim, 512), nn.LayerNorm(512), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(512, latent_dim), nn.LayerNorm(latent_dim)
        )
        self.doctor_encoder = nn.Sequential(
            nn.Linear(embed_dim, 512), nn.LayerNorm(512), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(512, latent_dim), nn.LayerNorm(latent_dim)
        )
        self.ce = nn.Sequential(nn.Linear(4, gate_dim), nn.GELU(), nn.Linear(gate_dim, gate_dim), nn.GELU())
        self.pe = nn.Sequential(nn.Linear(5, gate_dim), nn.GELU(), nn.Linear(gate_dim, gate_dim), nn.GELU())
        self.le = nn.Sequential(nn.Linear(5, gate_dim), nn.GELU(), nn.Linear(gate_dim, gate_dim), nn.GELU())
        self.te = nn.Sequential(nn.Linear(3, gate_dim), nn.GELU(), nn.Linear(gate_dim, gate_dim), nn.GELU())
        # Gates read [patient latent, context]. Size them from ``context_dim``
        # rather than a hard-coded 8 so a non-default context width builds
        # matching layers (the default of 8 keeps the original architecture).
        gate_in = latent_dim + context_dim
        self.cg = nn.Sequential(nn.Linear(gate_in, 64), nn.GELU(), nn.Linear(64, gate_dim), nn.Sigmoid())
        self.pg = nn.Sequential(nn.Linear(gate_in, 64), nn.GELU(), nn.Linear(64, gate_dim), nn.Sigmoid())
        self.lg = nn.Sequential(nn.Linear(gate_in, 64), nn.GELU(), nn.Linear(64, gate_dim), nn.Sigmoid())
        self.tg = nn.Sequential(nn.Linear(gate_in, 64), nn.GELU(), nn.Linear(64, gate_dim), nn.Sigmoid())
        self.predictor = nn.Sequential(
            nn.Linear(latent_dim + gate_dim * 4, 256), nn.LayerNorm(256), nn.GELU(), nn.Linear(256, latent_dim)
        )
        self.pp = nn.Sequential(nn.Linear(latent_dim, latent_dim), nn.GELU(), nn.Linear(latent_dim, 128))
        self.dp = nn.Sequential(nn.Linear(latent_dim, latent_dim), nn.GELU(), nn.Linear(latent_dim, 128))
        # Learnable log-temperature initialised to log(0.07). It is not used by
        # forward(), but it is a registered parameter and so counts towards
        # the parameter totals. Presumably consumed by a training loss
        # elsewhere; TODO confirm.
        self.log_temperature = nn.Parameter(torch.log(torch.tensor(0.07)))

    def forward(self, patient_emb, doctor_emb, clinical, pastwork, logistics, trust, context):
        pl = self.patient_encoder(patient_emb)
        dl = self.doctor_encoder(doctor_emb)
        gi = torch.cat([pl, context], dim=-1)
        gc, gp, gl, gt = self.cg(gi), self.pg(gi), self.lg(gi), self.tg(gi)
        feat = torch.cat([gc*self.ce(clinical), gp*self.pe(pastwork), gl*self.le(logistics), gt*self.te(trust)], dim=-1)
        pred = self.predictor(torch.cat([pl, feat], dim=-1))
        pred_proj = self.pp(pred)
        doc_proj = self.dp(dl)
        pn = F.normalize(pred_proj, dim=-1)
        dn = F.normalize(doc_proj, dim=-1)
        # Cosine similarity of unit vectors lies in [-1, 1]; map it to [0, 1].
        score = ((pn * dn).sum(dim=-1, keepdim=True) + 1) / 2
        return score

def benchmark_model(model, device, n_warmup=50, n_measure=200, *, batch_size=1, embed_dim=384):
    """Time repeated forward passes of ``model`` on random inputs.

    The model is switched to eval mode and moved to ``device``.
    ``nn.Module.to`` mutates the module in place, so the caller's model is
    left on ``device`` afterwards. One set of random inputs is created
    directly on ``device`` and reused for every call. The inputs are two
    ``embed_dim``-wide embeddings plus feature groups of width 4, 5, 5, 3
    and 8, passed in the models' forward order.

    Args:
        model: module whose forward takes (patient_emb, doctor_emb, clinical,
            pastwork, logistics, trust, context).
        device: ``torch.device`` (or device string such as ``'cpu'``).
        n_warmup: untimed calls made first so one-off initialisation is not measured.
        n_measure: number of timed calls; must be at least 1.
        batch_size: rows per input tensor (keyword-only).
        embed_dim: width of the two embedding inputs (keyword-only).

    Returns:
        dict with ``mean_ms``, ``std_ms``, ``p50_ms``, ``p95_ms`` and ``p99_ms``
        latency statistics in milliseconds, as plain floats.

    Raises:
        ValueError: if ``n_measure`` is less than 1 (the statistics would be undefined).
    """
    if n_measure < 1:
        raise ValueError(f'n_measure must be >= 1, got {n_measure}')

    device = torch.device(device)
    is_cuda = device.type == 'cuda'
    model.eval()
    model = model.to(device)

    def sync():
        # CUDA kernels launch asynchronously; wait so perf_counter brackets real work.
        if is_cuda:
            torch.cuda.synchronize()

    def rand(width):
        return torch.randn(batch_size, width, device=device)

    inputs = (rand(embed_dim), rand(embed_dim), rand(4), rand(5), rand(5), rand(3), rand(8))

    times = []
    with torch.no_grad():
        for _ in range(n_warmup):
            model(*inputs)
        sync()

        for _ in range(n_measure):
            sync()
            start = time.perf_counter()
            model(*inputs)
            sync()
            times.append((time.perf_counter() - start) * 1000)

    return {
        'mean_ms': float(np.mean(times)),
        'std_ms': float(np.std(times)),
        'p50_ms': float(np.percentile(times, 50)),
        'p95_ms': float(np.percentile(times, 95)),
        'p99_ms': float(np.percentile(times, 99))
    }

def count_parameters(model):
    """Return ``(total, trainable)`` element counts over ``model.parameters()``.

    ``trainable`` only includes parameters with ``requires_grad`` set.
    """
    total = trainable = 0
    for param in model.parameters():
        n_elements = param.numel()
        total += n_elements
        if param.requires_grad:
            trainable += n_elements
    return total, trainable

models_to_bench = {
    'SimpleMLP': SimpleMLP(),
    'NeuralRanker': NeuralRanker(),
    'DINModel': DINModel(),
    'DocMatchNet-Original': DocMatchNetOriginal(),
    'DocMatchNet-JEPA': DocMatchNetJEPA()
}

efficiency_results = {}

# Only run the GPU benchmark when CUDA is actually present. Falling back to
# the CPU here would re-time the CPU and store it under ``gpu_latency_ms``,
# which mislabels the saved results. This is loop-invariant, so decide once.
gpu_device = torch.device('cuda') if torch.cuda.is_available() else None

for model_name, model in models_to_bench.items():
    print(f"\nBenchmarking {model_name}...")
    total_params, trainable_params = count_parameters(model)

    gpu_results = benchmark_model(model, gpu_device) if gpu_device is not None else None
    # benchmark_model moves the model in place; the CPU run goes last, so the
    # model ends up on the CPU.
    cpu_results = benchmark_model(model, torch.device('cpu'))

    efficiency_results[model_name] = {
        'total_params': int(total_params),
        'trainable_params': int(trainable_params),
        'gpu_latency_ms': gpu_results,  # None when CUDA is unavailable
        'cpu_latency_ms': cpu_results
    }

    print(f"  Parameters: {trainable_params:,}")
    if gpu_results is None:
        print("  GPU latency: n/a (CUDA not available)")
    else:
        print(f"  GPU latency: {gpu_results['mean_ms']:.2f} ± {gpu_results['std_ms']:.2f} ms")
    print(f"  CPU latency: {cpu_results['mean_ms']:.2f} ± {cpu_results['std_ms']:.2f} ms")

N_DOCTORS = 100
N_BATCH_REPEATS = 10


def _median_ms(fn, n_repeats=N_BATCH_REPEATS):
    """Call ``fn`` once as warm-up, then return its median wall time in ms over ``n_repeats`` calls."""
    def sync():
        # CUDA kernels launch asynchronously; wait so perf_counter brackets real work.
        if device.type == 'cuda':
            torch.cuda.synchronize()

    fn()  # warm-up: keep lazy CUDA / allocator initialisation out of the timed region
    samples = []
    for _ in range(n_repeats):
        sync()
        start = time.perf_counter()
        fn()
        sync()
        samples.append((time.perf_counter() - start) * 1000)
    return float(np.median(samples))


print(f'\n\nBatch scoring benchmark ({N_DOCTORS} doctors per case):')
for model_name, model in models_to_bench.items():
    model = model.to(device).eval()

    # Generate every input before timing so RNG and host-to-device copies are
    # not counted as scoring time.
    patient_emb = torch.randn(1, 384, device=device)
    context = torch.randn(1, 8, device=device)
    # Candidate-side inputs in forward order: doctor_emb, clinical, pastwork, logistics, trust.
    doctor_feats = [torch.randn(N_DOCTORS, w, device=device) for w in (384, 4, 5, 5, 3)]
    # One (doctor_emb, clinical, pastwork, logistics, trust) tuple of 1-row tensors per candidate.
    per_doctor = list(zip(*(t.split(1) for t in doctor_feats)))
    # Case-level inputs repeated so one batched forward scores every candidate.
    patient_batch = patient_emb.repeat(N_DOCTORS, 1)
    context_batch = context.repeat(N_DOCTORS, 1)

    with torch.no_grad():
        seq_ms = _median_ms(lambda: [model(patient_emb, *doc, context) for doc in per_doctor])
        batch_ms = _median_ms(lambda: model(patient_batch, *doctor_feats, context_batch))

    # Record alongside the latency results so the timings reach the saved JSON.
    efficiency_results.setdefault(model_name, {})['batch_scoring'] = {
        'n_doctors': N_DOCTORS,
        'sequential_ms': seq_ms,
        'batched_ms': batch_ms,
    }
    print(f"  {model_name}: {seq_ms:.1f} ms for {N_DOCTORS} doctors sequentially, "
          f"{batch_ms:.1f} ms batched")

from pathlib import Path

# Create the results directory if needed; open() alone raises
# FileNotFoundError when /kaggle/working/results/ does not exist yet.
results_path = Path('/kaggle/working/results/efficiency_results.json')
results_path.parent.mkdir(parents=True, exist_ok=True)
with results_path.open('w', encoding='utf-8') as f:
    json.dump(efficiency_results, f, indent=2, default=str)

print('\n✅ Efficiency results saved!')