# Gate Activation Analysis
========================
Deep analysis of context-aware gate behavior for paper Section V.D.

In [None]:
# ============================================================
# CELL 1: Load Model and Data
# ============================================================
import os
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from scipy import stats
from torch.utils.data import Dataset

# Prefer the GPU when torch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')

DATA_DIR = '/kaggle/input/docmatchnet-jepa-data/data'


def _load_artifact(stem):
    """Return the object stored in ``<DATA_DIR>/<stem>.pt``.

    NOTE: ``weights_only=False`` allows ``torch.load`` to unpickle arbitrary
    Python objects, so this helper must only be pointed at trusted files.
    """
    return torch.load(f'{DATA_DIR}/{stem}.pt', weights_only=False)


# Embeddings for doctors and cases.
doctor_embeddings = _load_artifact('doctor_embeddings')
case_embeddings = _load_artifact('case_embeddings')
# Feature groups consumed by the four gated branches of the model.
clinical_features = _load_artifact('clinical_features')
pastwork_features = _load_artifact('pastwork_features')
logistics_features = _load_artifact('logistics_features')
trust_features = _load_artifact('trust_features')
# Context vector, labels, candidate lists, metadata and split definitions.
context_features = _load_artifact('context_features')
relevance_labels = _load_artifact('relevance_labels')
doctor_indices = _load_artifact('doctor_indices')
case_metadata = _load_artifact('case_metadata')
splits = _load_artifact('splits')

class DocMatchDatasetEval(Dataset):
    """Evaluation dataset that yields one case together with its candidate doctors.

    ``indices`` lists the case indices belonging to this split; item ``idx``
    of the dataset is the case ``indices[idx]``.  ``doc_indices[case]`` holds
    the rows of ``doc_emb`` that are candidates for that case.  All other
    tables are indexed directly by the case index.
    """

    # Tables copied into each sample under their own name, in output order.
    _PER_CASE_FIELDS = ('clinical', 'pastwork', 'logistics', 'trust',
                        'context', 'relevance')

    def __init__(self, indices, case_emb, doc_emb, doc_indices,
                 clinical, pastwork, logistics, trust, context,
                 relevance, metadata):
        # Split definition and embedding tables.
        self.indices = indices
        self.case_emb, self.doc_emb = case_emb, doc_emb
        self.doc_indices = doc_indices
        # Per-case feature groups.
        self.clinical, self.pastwork = clinical, pastwork
        self.logistics, self.trust = logistics, trust
        self.context = context
        # Labels and metadata (``metadata['context_category']`` is read per case).
        self.relevance = relevance
        self.metadata = metadata

    def __len__(self):
        """Number of cases in this split."""
        return len(self.indices)

    def __getitem__(self, idx):
        """Return a dict with the embeddings, features, labels and context
        category of the ``idx``-th case of the split."""
        case = self.indices[idx]
        candidate_rows = self.doc_indices[case]

        item = {
            'case_embedding': self.case_emb[case],
            'doctor_embeddings': self.doc_emb[candidate_rows],
        }
        for field in self._PER_CASE_FIELDS:
            item[field] = getattr(self, field)[case]
        item['context_category'] = self.metadata['context_category'][case]
        return item

class DocMatchNetJEPA(nn.Module):
    """Context-gated JEPA-style scorer for (patient case, doctor) pairs.

    The patient and doctor embeddings are encoded into a latent space.  Four
    small feature branches (clinical, pastwork, logistics, trust) are encoded
    separately and multiplied element-wise by sigmoid gates that are computed
    from the patient latent and the context vector only.  A predictor maps
    the patient latent plus gated features to a predicted "ideal doctor"
    latent; the score is the cosine similarity between the projected
    prediction and the projected doctor latent, rescaled to ``[0, 1]``.

    Submodule names, ``nn.Sequential`` positions and construction order are
    part of the checkpoint format (``state_dict`` keys) and must not change.
    """

    # Branch name -> width of that branch's raw feature vector, in the order
    # the branches are registered and concatenated.
    _BRANCH_INPUT_DIMS = (('clinical', 4), ('pastwork', 5),
                          ('logistics', 5), ('trust', 3))
    # Initial bias of each gate's last Linear layer (the one feeding Sigmoid).
    _GATE_BIAS_INIT = (('clinical', 0.4), ('pastwork', 0.0),
                       ('logistics', 0.0), ('trust', -0.4))

    def __init__(self, embed_dim=384, latent_dim=256, gate_dim=32, context_dim=8, dropout=0.1):
        super().__init__()

        # Embedding towers (patient first, then doctor).
        self.patient_encoder = self._make_tower(embed_dim, latent_dim, dropout)
        self.doctor_encoder = self._make_tower(embed_dim, latent_dim, dropout)

        # Feature-branch encoders, then the gates, each in branch order.
        for branch, width in self._BRANCH_INPUT_DIMS:
            setattr(self, f'{branch}_encoder', self._make_encoder(width, gate_dim))

        gate_input_dim = latent_dim + context_dim
        for branch, _ in self._BRANCH_INPUT_DIMS:
            setattr(self, f'{branch}_gate', self._make_gate(gate_input_dim, gate_dim))
        self._init_gate_biases()

        # Predictor consumes the patient latent plus the four gated branches.
        hidden = 256
        self.predictor = nn.Sequential(
            nn.Linear(latent_dim + 4 * gate_dim, hidden), nn.LayerNorm(hidden), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(hidden, hidden), nn.LayerNorm(hidden), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(hidden, latent_dim)
        )
        self.predictor_proj = self._make_projection(latent_dim)
        self.doctor_proj = self._make_projection(latent_dim)
        # Temperature is stored in log space; exp() of it is returned by forward().
        self.log_temperature = nn.Parameter(torch.log(torch.tensor(0.07)))

    def _make_tower(self, input_dim, output_dim, dropout):
        """Two-layer embedding encoder: Linear-LayerNorm-GELU-Dropout-Linear-LayerNorm."""
        return nn.Sequential(
            nn.Linear(input_dim, 512), nn.LayerNorm(512), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(512, output_dim), nn.LayerNorm(output_dim)
        )

    def _make_projection(self, latent_dim):
        """Projection head mapping a latent vector to 128 dimensions."""
        return nn.Sequential(nn.Linear(latent_dim, latent_dim), nn.GELU(), nn.Linear(latent_dim, 128))

    def _make_encoder(self, input_dim, output_dim):
        """Feature-branch encoder: two Linear-BatchNorm1d-GELU stages."""
        stages = []
        for in_features in (input_dim, output_dim):
            stages += [nn.Linear(in_features, output_dim), nn.BatchNorm1d(output_dim), nn.GELU()]
        return nn.Sequential(*stages)

    def _make_gate(self, input_dim, output_dim):
        """Gate network: Linear(64)-GELU-Linear-Sigmoid, so outputs lie in (0, 1)."""
        return nn.Sequential(nn.Linear(input_dim, 64), nn.GELU(), nn.Linear(64, output_dim), nn.Sigmoid())

    def _init_gate_biases(self):
        """Set the pre-sigmoid bias of every gate to its initial value."""
        for branch, bias in self._GATE_BIAS_INIT:
            # Index -2 is the Linear layer directly before the Sigmoid.
            nn.init.constant_(getattr(self, f'{branch}_gate')[-2].bias, bias)

    def forward(self, patient_emb, doctor_emb, clinical, pastwork, logistics, trust, context):
        """Score doctors for patients and expose the gate activations.

        Returns a dict with keys ``score`` (rescaled cosine similarity, last
        dimension kept with size 1), ``predicted_ideal`` and
        ``doctor_embedding`` (the two projected vectors), ``gates`` (per-branch
        gate tensors), ``gate_means`` (each gate averaged over its last
        dimension) and ``temperature``.
        """
        patient_latent = self.patient_encoder(patient_emb)
        doctor_latent = self.doctor_encoder(doctor_emb)

        encoded = {
            'clinical': self.clinical_encoder(clinical),
            'pastwork': self.pastwork_encoder(pastwork),
            'logistics': self.logistics_encoder(logistics),
            'trust': self.trust_encoder(trust),
        }

        # Gates see only the patient latent and the context, never the doctor.
        gate_input = torch.cat([patient_latent, context], dim=-1)
        gates = {
            'clinical': self.clinical_gate(gate_input),
            'pastwork': self.pastwork_gate(gate_input),
            'logistics': self.logistics_gate(gate_input),
            'trust': self.trust_gate(gate_input),
        }

        gated_features = torch.cat(
            [gate * encoded[branch] for branch, gate in gates.items()], dim=-1
        )
        predicted_ideal = self.predictor(torch.cat([patient_latent, gated_features], dim=-1))

        pred_proj = self.predictor_proj(predicted_ideal)
        doc_proj = self.doctor_proj(doctor_latent)

        # Cosine similarity in [-1, 1], then mapped to [0, 1].
        cosine = (F.normalize(pred_proj, dim=-1) * F.normalize(doc_proj, dim=-1)).sum(dim=-1, keepdim=True)
        score = (cosine + 1) / 2

        return {
            'score': score,
            'predicted_ideal': pred_proj,
            'doctor_embedding': doc_proj,
            'gates': gates,
            'gate_means': {branch: gate.mean(dim=-1) for branch, gate in gates.items()},
            'temperature': self.log_temperature.exp()
        }

# Restore the best JEPA checkpoint on the active device and switch the
# network to inference mode (dropout off, batch-norm running statistics).
model = DocMatchNetJEPA()
model = model.to(device)
best_state = torch.load(
    '/kaggle/working/results/best_jepa_model.pt',
    map_location=device,
    weights_only=False,
)
model.load_state_dict(best_state)
model.eval()

# Wrap the held-out split so each item is one case plus its candidate doctors.
test_ds = DocMatchDatasetEval(
    indices=splits['test'],
    case_emb=case_embeddings,
    doc_emb=doctor_embeddings,
    doc_indices=doctor_indices,
    clinical=clinical_features,
    pastwork=pastwork_features,
    logistics=logistics_features,
    trust=trust_features,
    context=context_features,
    relevance=relevance_labels,
    metadata=case_metadata,
)

In [None]:
# ============================================================
# CELL 2: Collect Gate Activations for All Test Cases
# ============================================================
# One list of per-case mean gate activations for every gate.
all_gate_activations = {
    name: [] for name in ('clinical', 'pastwork', 'logistics', 'trust')
}
all_contexts = []
all_urgencies = []

with torch.no_grad():
    for position in range(len(test_ds)):
        record = test_ds[position]

        # The model expects a leading batch dimension, so build batches of one.
        patient = record['case_embedding'].unsqueeze(0).to(device)
        ctx_vec = record['context'].unsqueeze(0).to(device)

        # The gates are computed from the patient latent and the context only,
        # so the first candidate doctor's features are enough to obtain them.
        first_doctor = {
            key: record[key][0:1].to(device)
            for key in ('doctor_embeddings', 'clinical', 'pastwork',
                        'logistics', 'trust')
        }

        result = model(
            patient,
            first_doctor['doctor_embeddings'],
            first_doctor['clinical'],
            first_doctor['pastwork'],
            first_doctor['logistics'],
            first_doctor['trust'],
            ctx_vec,
        )

        for name, history in all_gate_activations.items():
            history.append(result['gate_means'][name].item())

        all_contexts.append(record['context_category'])
        # The first context feature is taken to be the urgency value.
        all_urgencies.append(ctx_vec[0, 0].item())

# Turn the collected lists into NumPy arrays so they can be boolean-masked.
all_gate_activations.update(
    {name: np.array(history) for name, history in all_gate_activations.items()}
)
all_contexts = np.array(all_contexts)

print(f"Collected gates for {len(all_contexts)} test cases")

In [None]:
# ============================================================
# CELL 3: Statistical Tests - Kruskal-Wallis
# ============================================================
print("\n" + "=" * 60)
print("KRUSKAL-WALLIS TEST: Do gates differ across contexts?")
print("=" * 60)

contexts_unique = ['routine', 'complex', 'rare_disease', 'emergency', 'pediatric']
statistical_results = {}

# Group membership depends only on the context labels, so the masks, the
# contexts large enough to test and the post-hoc pairs are computed once
# instead of once per gate.
context_masks = {ctx: all_contexts == ctx for ctx in contexts_unique}
# A context takes part in the tests only when it has more than 5 cases.
testable_contexts = [ctx for ctx in contexts_unique if context_masks[ctx].sum() > 5]
context_pairs = [
    (ctx_a, ctx_b)
    for pos, ctx_a in enumerate(testable_contexts)
    for ctx_b in testable_contexts[pos + 1:]
]
# Bonferroni correction: divide alpha by the number of pairwise comparisons
# that are actually run (10 only when all five contexts are testable).
bonferroni_alpha = 0.05 / max(len(context_pairs), 1)

for gate_name in ['clinical', 'pastwork', 'logistics', 'trust']:
    activations = all_gate_activations[gate_name]
    groups = [activations[context_masks[ctx]] for ctx in testable_contexts]

    # Kruskal-Wallis needs at least two groups.
    if len(groups) < 2:
        continue

    stat, p_value = stats.kruskal(*groups)

    n_total = sum(len(g) for g in groups)
    n_groups = len(groups)
    # Effect sizes for the H statistic:
    #   epsilon-squared  = H / (n - 1)
    #   eta-squared (H)  = (H - k + 1) / (n - k)
    # The eta-squared value is what this cell previously stored under the
    # name epsilon_squared; both are now reported under their own names.
    epsilon_sq = stat / (n_total - 1)
    eta_sq_h = (stat - n_groups + 1) / (n_total - n_groups)

    statistical_results[gate_name] = {
        'kruskal_stat': float(stat),
        'p_value': float(p_value),
        'epsilon_squared': float(epsilon_sq),
        'eta_squared_h': float(eta_sq_h),
        # bool(): the comparison yields a numpy.bool_, which json.dump rejects.
        'significant': bool(p_value < 0.05)
    }

    sig = "***" if p_value < 0.001 else "**" if p_value < 0.01 else "*" if p_value < 0.05 else "ns"
    print(f"\n{gate_name} gate:")
    print(f"  H = {stat:.4f}, p = {p_value:.2e} {sig}")
    print(f"  ε² = {epsilon_sq:.4f}")
    print(f"  η²_H = {eta_sq_h:.4f}")

    print("  Post-hoc pairwise comparisons:")
    for ctx_a, ctx_b in context_pairs:
        u_stat, u_p = stats.mannwhitneyu(
            activations[context_masks[ctx_a]],
            activations[context_masks[ctx_b]],
            alternative='two-sided'
        )
        sig_u = "*" if u_p < bonferroni_alpha else ""
        print(f"    {ctx_a} vs {ctx_b}: U={u_stat:.0f}, p={u_p:.4e} {sig_u}")

In [None]:
# ============================================================
# CELL 4: Correlation Analysis
# ============================================================
print("\n" + "=" * 60)
print("GATE CORRELATION ANALYSIS")
print("=" * 60)

# One row per test case, one column per gate (clinical, pastwork, logistics, trust).
gate_matrix = np.column_stack(
    [all_gate_activations[key] for key in ('clinical', 'pastwork', 'logistics', 'trust')]
)

# rowvar=False treats columns as variables, i.e. correlates the gates.
correlation_matrix = np.corrcoef(gate_matrix, rowvar=False)
gate_names = ['Clinical', 'PastWork', 'Logistics', 'Trust']

print("\nPearson Correlation Matrix:")
# Header row: an empty 12-character cell followed by the gate names.
print(f"{'':>12s}" + "".join(f"{label:>12s}" for label in gate_names))

# One line per gate: its name followed by its correlations with every gate.
for row_label, row_values in zip(gate_names, correlation_matrix):
    print(f"{row_label:>12s}" + "".join(f"{value:>12.3f}" for value in row_values))

print("\nInterpretation:")
print("Low correlations between gates = gates are learning DIFFERENT things (good!)")
print("High correlation would suggest redundancy (bad)")

In [None]:
# ============================================================
# CELL 5: Context-Specific Gate Statistics Table
# ============================================================
print("\n" + "=" * 60)
print("CONTEXT-SPECIFIC GATE MEANS")
print("=" * 60)

# context name -> {gate name -> mean activation over the cases of that context}
context_gate_table = {}
print(f"\n{'Context':<15} {'Clinical':>10} {'PastWork':>10} {'Logistics':>10} {'Trust':>10}")
print("-" * 60)

for ctx in contexts_unique:
    in_context = all_contexts == ctx
    # Contexts that do not occur in the test set get no row.
    if not in_context.any():
        continue

    gate_means = {
        gate: float(all_gate_activations[gate][in_context].mean())
        for gate in ('clinical', 'pastwork', 'logistics', 'trust')
    }
    print(f"{ctx:<15}" + "".join(f"{value:>10.4f}" for value in gate_means.values()))
    context_gate_table[ctx] = gate_means

In [None]:
# ============================================================
# CELL 6: Clinical Interpretation
# ============================================================
print("\n" + "=" * 60)
print("CLINICAL INTERPRETATION OF GATE PATTERNS")
print("=" * 60)

interpretations = []

# Rows of the context table (None when a context has no test cases).
# Comparisons against 'routine' are made only when that baseline row exists;
# comparing against a default of 0 would always report an increase because
# mean gate activations are sigmoid outputs and therefore positive.
routine = context_gate_table.get('routine')
emerg = context_gate_table.get('emergency')
rare = context_gate_table.get('rare_disease')
comp = context_gate_table.get('complex')

if emerg is not None and routine is not None:
    if emerg.get('clinical', 0) > routine.get('clinical', 0):
        interpretations.append(
            "Emergency cases show HIGHER clinical gate activation, "
            "indicating the model prioritizes clinical expertise match "
            "when urgency is high."
        )

    if emerg.get('logistics', 0) > routine.get('logistics', 0):
        interpretations.append(
            "Emergency cases also elevate logistics gate, suggesting "
            "proximity and availability become critical in emergencies."
        )
    elif emerg.get('logistics', 0) < routine.get('logistics', 0):
        interpretations.append(
            "Emergency cases LOWER logistics gate, suggesting that "
            "finding the RIGHT specialist matters more than convenience."
        )

if rare is not None and routine is not None:
    if rare.get('pastwork', 0) > routine.get('pastwork', 0):
        interpretations.append(
            "Rare disease cases increase pastwork gate activation, "
            "reflecting the importance of research publications and "
            "specialized experience for rare conditions."
        )

if comp is not None:
    # Absolute threshold: both gates are more than half open on average.
    if comp.get('clinical', 0) > 0.5 and comp.get('pastwork', 0) > 0.5:
        interpretations.append(
            "Complex multi-comorbidity cases activate BOTH clinical "
            "and pastwork gates highly, indicating the model recognizes "
            "that both expertise match and experience matter."
        )

print("\nKey Findings:")
if not interpretations:
    # The original string literal was split across two physical lines, which
    # is a syntax error; the line break is now an explicit "\n" escape.
    print("\n1. No strong directional differences were detected under current checkpoints.")
else:
    for i, interp in enumerate(interpretations, 1):
        print(f"\n{i}. {interp}")

In [None]:
# ============================================================
# CELL 7: Save Gate Analysis Results
# ============================================================
# Raw per-case activations, grouped first by gate and then by context, as
# plain lists so that they can be written to JSON.
gate_per_case = {}
for gate in ['clinical', 'pastwork', 'logistics', 'trust']:
    gate_per_case[gate] = {
        ctx: all_gate_activations[gate][all_contexts == ctx].tolist()
        for ctx in contexts_unique
    }

# Everything produced by the preceding analysis cells, in one document.
gate_analysis_results = {
    'statistical_tests': statistical_results,
    'correlation_matrix': correlation_matrix.tolist(),
    'context_gate_table': context_gate_table,
    'interpretations': interpretations,
    'gate_per_case': gate_per_case,
}

with open('/kaggle/working/results/gate_analysis_results.json', 'w') as f:
    json.dump(gate_analysis_results, f, indent=2)

print("\n✅ Gate analysis results saved!")

In [None]:
# ============================================================
# Clinical Case Studies
# ============================================================
"""
Clinical Case Studies
=====================
Detailed analysis of specific cases for paper Section V.G.
"""

# ============================================================
# Case study helper: analyzes any single test case (used for every context category below)
# ============================================================
def analyze_case_study(case_idx, model, device):
    """
    Print and return a detailed analysis of a single test-set case.

    Scores every candidate doctor of the case with ``model``, prints the mean
    gate activations, the model's top-5 recommendations next to the MCDA
    top-5, and compares NDCG@5 of both rankings.

    Besides its arguments, the function reads these notebook-level names,
    which are not defined in this cell: ``test_ds``, ``splits``,
    ``doctor_indices``, ``cases_df``, ``doctors_df``, ``mcda_scores`` and
    ``ndcg_at_k``.

    Args:
        case_idx: Position of the case inside the test split, i.e. an index
            into ``test_ds`` and ``splits['test']``.
        model: Network called as ``model(case_emb, doc_emb, clinical,
            pastwork, logistics, trust, context)`` that returns a dict with
            ``'score'`` and ``'gate_means'`` entries.
        device: Torch device the input tensors are moved to.

    Returns:
        A dict with ``case_idx``, ``context``, ``gate_activations`` (gate
        name -> float), ``jepa_ndcg5``, ``mcda_ndcg5`` and
        ``improvement_pct``.  Numeric values are plain Python floats so that
        the result can be written with ``json.dump``.

    Raises:
        ValueError: If the case has no candidate doctors.
    """
    sample = test_ds[case_idx]
    num_doctors = sample['doctor_embeddings'].shape[0]
    if num_doctors == 0:
        # Without a candidate there is nothing to score and no gate readout.
        raise ValueError(f"Case #{case_idx} has no candidate doctors to score")

    case_emb = sample['case_embedding'].unsqueeze(0).to(device)
    context = sample['context'].unsqueeze(0).to(device)
    ctx_category = sample['context_category']
    relevance = sample['relevance'].cpu().numpy()

    # Index of the case in the full (un-split) tables.  int() makes it usable
    # as a pandas .iloc position even if the split stores tensor/NumPy scalars.
    global_case_idx = int(splits['test'][case_idx])
    case_row = cases_df.iloc[global_case_idx]

    print(f"\n{'=' * 70}")
    print(f"CASE STUDY: Case #{case_idx}")
    print(f"Context Category: {ctx_category}")
    print(f"Case Description: {case_row['symptom_description']}")
    print(f"Target Specialty: {case_row['target_specialty']}")
    print(f"Urgency: {case_row['urgency_level']}")
    print(f"{'=' * 70}")

    # Score all doctors, one candidate per forward pass.
    gate_names = ('clinical', 'pastwork', 'logistics', 'trust')
    feature_keys = ('doctor_embeddings', 'clinical', 'pastwork', 'logistics', 'trust')
    scores = []
    gate_values = {}

    with torch.no_grad():
        for i in range(num_doctors):
            doc_emb, clinical, pastwork, logistics, trust = (
                sample[key][i:i + 1].to(device) for key in feature_keys
            )

            output = model(case_emb, doc_emb, clinical, pastwork,
                           logistics, trust, context)

            scores.append(output['score'].item())

            if i == 0:
                # The gate inputs are the same for every doctor of the case,
                # so the readout of the first pass is kept.
                gate_values = {
                    g_name: output['gate_means'][g_name].item()
                    for g_name in gate_names
                }

    scores = np.array(scores)

    # Gate analysis
    print(f"\nGate Activations:")
    for g_name, g_val in gate_values.items():
        bar = '█' * int(g_val * 30)
        print(f"  {g_name:>12s}: {g_val:.4f} |{bar}")

    # Top 5 recommendations
    top5_idx = np.argsort(-scores)[:5]
    print(f"\nTop 5 Recommendations:")
    print(f"{'Rank':>4s} {'Score':>8s} {'Relevance':>10s} {'Specialty':>25s} {'Experience':>12s}")
    print("-" * 65)

    for rank, doc_local_idx in enumerate(top5_idx, 1):
        # Candidate position within the case -> row of the doctor table.
        doc_global_idx = int(doctor_indices[global_case_idx, doc_local_idx])
        doctor = doctors_df.iloc[doc_global_idx]

        print(f"{rank:>4d} {scores[doc_local_idx]:>8.4f} "
              f"{relevance[doc_local_idx]:>10d} "
              f"{doctor['specialty']:>25s} "
              f"{doctor['years_experience']:>8d} yrs")

    # Compare with MCDA ranking
    mcda = mcda_scores[global_case_idx]
    mcda_top5 = np.argsort(-mcda)[:5]

    print(f"\nMCDA Top 5 for comparison:")
    for rank, doc_local_idx in enumerate(mcda_top5, 1):
        doc_global_idx = int(doctor_indices[global_case_idx, doc_local_idx])
        doctor = doctors_df.iloc[doc_global_idx]
        print(f"{rank:>4d} {mcda[doc_local_idx]:>8.4f} "
              f"{relevance[doc_local_idx]:>10d} "
              f"{doctor['specialty']:>25s}")

    # NDCG comparison
    jepa_ndcg5 = float(ndcg_at_k(scores, relevance, 5))
    mcda_ndcg5 = float(ndcg_at_k(mcda, relevance, 5))

    print(f"\nNDCG@5: JEPA = {jepa_ndcg5:.4f}, MCDA = {mcda_ndcg5:.4f}")
    # Relative gain is undefined for a zero baseline; report 0 in that case.
    if mcda_ndcg5 > 0:
        improvement = (jepa_ndcg5 - mcda_ndcg5) / mcda_ndcg5 * 100
    else:
        improvement = 0.0
    print(f"Improvement: {improvement:+.1f}%")

    return {
        'case_idx': case_idx,
        'context': ctx_category,
        'gate_activations': gate_values,
        'jepa_ndcg5': jepa_ndcg5,
        'mcda_ndcg5': mcda_ndcg5,
        'improvement_pct': float(improvement)
    }


# Find representative cases for each context
case_studies = {}

# Global case index -> first position of that case inside the test split.
# Built once so that membership and position lookups are O(1) instead of a
# list scan per case, and so that the split may be a list, array or tensor.
test_position = {}
for position, global_idx in enumerate(splits['test']):
    test_position.setdefault(int(global_idx), position)

for target_ctx in ['emergency', 'rare_disease', 'routine', 'complex']:
    # Take the test case with the lowest global index that has this context.
    # A context without any test case is simply left out of the results.
    for i, ctx in enumerate(case_metadata['context_category']):
        if ctx == target_ctx and i in test_position:
            case_studies[target_ctx] = analyze_case_study(test_position[i], model, device)
            break

# Save case studies
with open('/kaggle/working/results/case_studies.json', 'w') as f:
    json.dump(case_studies, f, indent=2)

print("\n✅ Case studies complete!")
