# Generate LaTeX Tables for Paper
===============================
Load all result JSONs and create publication-ready tables.

In [None]:
import os
import json
import numpy as np

# Directory that load_json() searches first for the result JSON files.
RESULTS_DIR = '/kaggle/working/results'
# Directory intended for the generated tables; created here if it does not
# already exist (exist_ok=True makes re-running this cell harmless).
TABLE_DIR = '/kaggle/working/results/tables'
os.makedirs(TABLE_DIR, exist_ok=True)

def load_json(name):
    """Load a JSON file, looking in RESULTS_DIR first and then at ``name`` itself.

    Args:
        name: File name (or path) of the JSON document to load.

    Returns:
        The decoded JSON document.

    Raises:
        FileNotFoundError: If neither candidate path is an existing file.
    """
    candidates = [os.path.join(RESULTS_DIR, name), name]
    for path in candidates:
        # isfile (not exists): a directory with the same name is skipped
        # instead of failing later inside open().
        if os.path.isfile(path):
            # Decode explicitly as UTF-8 rather than relying on the locale's
            # default encoding.
            with open(path, 'r', encoding='utf-8') as f:
                return json.load(f)
    raise FileNotFoundError(f'Could not find {name} in {candidates}')

# Load all results. Each call raises FileNotFoundError (from load_json) when
# the file is found neither under RESULTS_DIR nor at the bare name.
jepa_results = load_json('jepa_results.json')
original_results = load_json('original_results.json')
baseline_results = load_json('baseline_results.json')
ablation_results = load_json('ablation_results.json')
sample_eff_results = load_json('sample_efficiency_results.json')

def normalize_baseline(br):
    """Re-key baseline results under canonical method names.

    Every canonical key ('mcda', 'mlp', 'neural_ranker', 'din') receives the
    entry stored in ``br`` under the first of its accepted names that is
    present. When none of the accepted names is present, the canonical key
    maps to an empty dict.

    Args:
        br: Mapping of baseline names to their results.

    Returns:
        dict: Mapping with exactly the four canonical keys.
    """
    accepted_names = {
        'mcda': ('mcda', 'StaticMCDA'),
        'mlp': ('mlp', 'SimpleMLP'),
        'neural_ranker': ('neural_ranker', 'NeuralRanker'),
        'din': ('din', 'DINModel'),
    }
    return {
        canonical: next((br[name] for name in names if name in br), {})
        for canonical, names in accepted_names.items()
    }

# Re-key the loaded baseline results so later cells can index them by the
# canonical names 'mcda', 'mlp', 'neural_ranker' and 'din'.
baseline_results = normalize_baseline(baseline_results)

print('Loaded all JSON files.')

In [None]:
# ============================================================
# TABLE I: Main Results
# ============================================================
def _metric_mean(source, metric):
    if source is None:
        return np.nan
    if 'overall' in source:
        source = source['overall']
    if metric in source and isinstance(source[metric], dict):
        return float(source[metric].get('mean', np.nan))
    if metric in source and isinstance(source[metric], (list, tuple)):
        return float(source[metric][0])
    return np.nan

def generate_table_1():
    """Build the LaTeX source for Table I (main results).

    One row per method and one column per ranking metric. In each column the
    largest finite value is set in bold and the second largest is underlined;
    values that are missing or non-finite are rendered as '--'.

    Reads the module-level ``baseline_results``, ``original_results``,
    ``ablation_results`` and ``jepa_results``.

    Returns:
        str: LaTeX source of the table environment.
    """
    methods = [
        ('Rule-Based MCDA', baseline_results['mcda']),
        ('Simple MLP', baseline_results['mlp']),
        ('Neural Ranker', baseline_results['neural_ranker']),
        ('DIN', baseline_results['din']),
        ('DocMatchNet-Original', original_results.get('overall', original_results)),
        ('DocMatchNet-JEPA-NoGate', ablation_results.get('no_gates', {})),
        ('DocMatchNet-JEPA', jepa_results.get('overall', jepa_results)),
    ]

    metrics = ['ndcg@1', 'ndcg@5', 'ndcg@10', 'map', 'mrr', 'hr@5']
    # col_vals[j][i] is the value of metric j for method i.
    col_vals = [[_metric_mean(mr, m) for _, mr in methods] for m in metrics]

    # Rank every column once. sorted() is stable, so on ties the method listed
    # first keeps the better rank.
    ranks = []
    for vals in col_vals:
        order = sorted(
            (idx for idx, v in enumerate(vals) if np.isfinite(v)),
            key=vals.__getitem__,
            reverse=True,
        )
        best = order[0] if order else None
        second = order[1] if len(order) > 1 else None
        ranks.append((best, second))

    def fmt(i, j):
        """Format the cell of method ``i`` / metric ``j``."""
        v = col_vals[j][i]
        if not np.isfinite(v):
            return '--'
        s = f"{v:.3f}"
        best, second = ranks[j]
        if i == best:
            return f"\\textbf{{{s}}}"
        if i == second:
            return f"\\underline{{{s}}}"
        return s

    # LaTeX row terminator: a space, two backslashes, then a newline.
    row_end = " \\\\\n"

    # Raw string: in a normal literal Python would turn \b, \t, ... inside the
    # LaTeX macros into control characters.
    latex = r"""
\begin{table}[t]
\centering
\caption{Main Results: Doctor-Patient Matching Performance}
\label{tab:main_results}
\begin{tabular}{@{}lcccccc@{}}
\toprule
\textbf{Method} & \textbf{NDCG@1} & \textbf{NDCG@5} & \textbf{NDCG@10} & \textbf{MAP} & \textbf{MRR} & \textbf{HR@5} \\
\midrule
"""

    for i, (method_name, _) in enumerate(methods):
        cells = "".join(f" & {fmt(i, j)}" for j in range(len(metrics)))
        latex += method_name + cells + row_end

    latex += r"""\bottomrule
\end{tabular}
\end{table}
"""
    return latex

In [None]:
# ============================================================
# TABLE II: Ablation Study
# ============================================================
def generate_table_2():
    """
    Build the LaTeX source for Table II (ablation study).

    9 variants × 3 metrics (NDCG@5, MAP, MRR), each shown as mean±std, plus a
    final column with the NDCG@5 difference to the full model. Cells whose
    mean or std is missing or non-finite are rendered as '--'.

    Reads the module-level ``ablation_results``. The full model is looked up
    under 'full' and, failing that, under 'full_docmatchnet_jepa'.

    Returns:
        str: LaTeX source of the table environment.
    """
    variants = [
        ('DocMatchNet-JEPA (Full)', 'full'),
        (r'\quad w/o Context-Aware Gates', 'no_gates'),
        (r'\quad w/o Two-Stage Training', 'no_twostage'),
        (r'\quad w/ MSE Loss', 'mse_loss'),
        (r'\quad w/ Pairwise Ranking Loss', 'ranking_loss'),
        (r'\quad w/ Symmetric LR', 'symmetric_lr'),
        (r'\quad w/o VICReg Gate Reg', 'no_vicreg'),
        (r'\quad gate\_dim = 16', 'gate_dim_16'),
        (r'\quad gate\_dim = 64', 'gate_dim_64'),
    ]

    metrics = ['ndcg@5', 'map', 'mrr']
    full_keys = ('full', 'full_docmatchnet_jepa')

    # Support both naming schemes for full model
    full_results = ablation_results.get('full', ablation_results.get('full_docmatchnet_jepa', {}))

    def stat(result, metric, name):
        """Return result[metric][name], or NaN when absent or null."""
        value = result.get(metric, {}).get(name, np.nan)
        return np.nan if value is None else value

    # Loop-invariant reference value for the delta column.
    full_ndcg = stat(full_results, 'ndcg@5', 'mean')

    # LaTeX row terminator: a space, two backslashes, then a newline.
    row_end = " \\\\\n"

    # Raw string: in a normal literal Python would turn \b, \t, ... inside the
    # LaTeX macros into control characters.
    latex = r"""
\begin{table}[t]
\centering
\caption{Ablation Study: Component Contribution Analysis}
\label{tab:ablation}
\begin{tabular}{@{}lccc|c@{}}
\toprule
\textbf{Variant} & \textbf{NDCG@5} & \textbf{MAP} & \textbf{MRR} & \textbf{$\Delta$ NDCG@5} \\
\midrule
"""

    for variant_name, variant_key in variants:
        is_full = variant_key in full_keys
        # The full-model row uses the already resolved results so that the
        # alternative key name is honoured for the row as well.
        result = full_results if is_full else ablation_results.get(variant_key, {})

        row = variant_name
        for metric in metrics:
            mean = stat(result, metric, 'mean')
            std = stat(result, metric, 'std')
            if np.isfinite(mean) and np.isfinite(std):
                row += f" & {mean:.3f}$\\pm${std:.3f}"
            else:
                row += " & --"

        cur_ndcg = stat(result, 'ndcg@5', 'mean')
        if is_full or not (np.isfinite(full_ndcg) and np.isfinite(cur_ndcg)):
            row += " & --"
        else:
            delta = cur_ndcg - full_ndcg
            sign = "+" if delta > 0 else ""
            row += f" & {sign}{delta:.3f}"

        row += row_end
        if is_full:
            # Rule separating the full model from its ablated variants.
            row += "\\midrule\n"
        latex += row

    latex += r"""\bottomrule
\end{tabular}
\vspace{2mm}
\footnotesize{Results averaged over 3 runs with different seeds. $\Delta$ shows change from full model.}
\end{table}
"""
    return latex

In [None]:
# ============================================================
# TABLE III: Stratified Performance by Clinical Context
# ============================================================
def generate_table_3():
    """
    Build the LaTeX source for Table III (NDCG@5 stratified by context).

    5 contexts × 3 methods (MCDA, Original, JEPA). The best value of each row
    is set in bold when it is greater than zero; ties are all set in bold.
    Values that are missing or unusable are rendered as '--' instead of being
    reported as 0.000.

    Reads the module-level ``baseline_results``, ``original_results`` and
    ``jepa_results`` (their 'stratified' entries).

    Returns:
        str: LaTeX source of the table environment.
    """
    contexts = ['routine', 'complex', 'rare_disease', 'emergency', 'pediatric']
    context_labels = ['Routine', 'Complex', 'Rare Disease', 'Emergency', 'Pediatric']

    methods = [
        ('MCDA', baseline_results.get('mcda', {}).get('stratified', {})),
        ('Original', original_results.get('stratified', {})),
        ('JEPA', jepa_results.get('stratified', {})),
    ]

    def ndcg_at_5(entry):
        """Extract NDCG@5 from one stratum entry; NaN when unavailable."""
        if not isinstance(entry, dict):
            return np.nan
        value = entry.get('ndcg@5')
        if value is None:
            # Flat layout: the stratum entry itself carries the mean.
            value = entry.get('mean')
        elif isinstance(value, dict):
            value = value.get('mean')
        elif isinstance(value, (list, tuple)):
            # An empty sequence carries no value; avoid an IndexError.
            value = value[0] if value else None
        try:
            return float(value)
        except (TypeError, ValueError):
            return np.nan

    # LaTeX row terminator: a space, two backslashes, then a newline.
    row_end = " \\\\\n"

    # Raw string: in a normal literal Python would turn \b, \t, ... inside the
    # LaTeX macros into control characters.
    latex = r"""
\begin{table}[t]
\centering
\caption{Stratified Performance by Clinical Context (NDCG@5)}
\label{tab:stratified}
\begin{tabular}{@{}lccc@{}}
\toprule
\textbf{Context} & \textbf{Rule-Based} & \textbf{DocMatchNet-Orig} & \textbf{DocMatchNet-JEPA} \\
\midrule
"""

    for ctx, ctx_label in zip(contexts, context_labels):
        vals = [ndcg_at_5(method_strat.get(ctx)) for _, method_strat in methods]
        best_val = max((v for v in vals if np.isfinite(v)), default=-1)

        row = ctx_label
        for val in vals:
            if not np.isfinite(val):
                row += " & --"
            elif val == best_val and best_val > 0:
                row += f" & \\textbf{{{val:.3f}}}"
            else:
                row += f" & {val:.3f}"
        latex += row + row_end

    latex += r"""\bottomrule
\end{tabular}
\end{table}
"""
    return latex

In [None]:
# ============================================================
# TABLE IV: Gate Activation Patterns
# ============================================================
def generate_table_4():
    """
    Build the LaTeX source for Table IV (mean gate activations by context).

    5 contexts × 4 gates, followed by a row of Kruskal-Wallis p-values with
    significance stars. A gate value or p-value that is missing or not a
    finite number is rendered as '--'; no placeholder number is substituted.

    Reads ``context_gate_stats`` and ``gate_analysis`` from the module-level
    ``jepa_results``.

    Returns:
        str: LaTeX source of the table environment.
    """
    contexts = ['routine', 'complex', 'rare_disease', 'emergency', 'pediatric']
    context_labels = ['Routine', 'Complex', 'Rare Disease', 'Emergency', 'Pediatric']
    gates = ['clinical', 'pastwork', 'logistics', 'trust']

    def as_float(value):
        """Convert to float; NaN when the value is absent or unusable."""
        try:
            return float(value)
        except (TypeError, ValueError):
            return np.nan

    # LaTeX row terminator: a space, two backslashes, then a newline.
    row_end = " \\\\\n"

    # Raw string: in a normal literal Python would turn \b, \t, ... inside the
    # LaTeX macros into control characters.
    latex = r"""
\begin{table}[t]
\centering
\caption{Mean Gate Activations by Clinical Context}
\label{tab:gates}
\begin{tabular}{@{}lcccc@{}}
\toprule
\textbf{Context} & $G_{\text{clinical}}$ & $G_{\text{pastwork}}$ & $G_{\text{logistics}}$ & $G_{\text{trust}}$ \\
\midrule
"""

    context_gate_stats = jepa_results.get('context_gate_stats', {})

    for ctx, ctx_label in zip(contexts, context_labels):
        row = ctx_label
        if ctx in context_gate_stats:
            for gate in gates:
                val = as_float(context_gate_stats[ctx].get(gate))
                row += f" & {val:.3f}" if np.isfinite(val) else " & --"
        else:
            row += " & --" * len(gates)
        latex += row + row_end

    latex += "\\midrule\n\\textit{Kruskal-Wallis $p$} "

    gate_analysis = jepa_results.get('gate_analysis', {})
    for gate in gates:
        p_val = as_float(gate_analysis.get(f'gate_{gate}_kruskal_pvalue'))
        if not np.isfinite(p_val):
            latex += "& -- "
        elif p_val < 0.001:
            latex += "& $<$0.001*** "
        elif p_val < 0.01:
            latex += f"& {p_val:.3f}** "
        elif p_val < 0.05:
            latex += f"& {p_val:.3f}* "
        else:
            latex += f"& {p_val:.3f} "

    # Terminate the p-value row (each cell above already ends with a space).
    latex += "\\\\\n"
    latex += r"""\bottomrule
\end{tabular}
\vspace{2mm}
\footnotesize{*$p<0.05$, **$p<0.01$, ***$p<0.001$. Kruskal-Wallis test for gate differences across contexts.}
\end{table}
"""
    return latex

In [None]:
# ============================================================
# TABLE V: Sample Efficiency
# ============================================================
def generate_table_5():
    """
    Build the LaTeX source for Table V (sample efficiency).

    7 data sizes × 4 methods. The best value of each row is set in bold when
    it is greater than zero; ties are all set in bold. Values that are
    missing or unusable are rendered as '--' instead of being reported as
    0.000.

    Reads the module-level ``sample_eff_results``, indexed by method name and
    then by the training-set size as a string.

    Returns:
        str: LaTeX source of the table environment.
    """
    data_sizes = [100, 250, 500, 1000, 2500, 5000, 10500]
    methods = ['StaticMCDA', 'NeuralRanker', 'DocMatchNet-Original', 'DocMatchNet-JEPA']
    method_labels = ['MCDA', 'Neural Ranker', 'DocMatch-Orig', 'DocMatch-JEPA']

    def mean_for(method, size):
        """Return the stored mean for one method/size pair, or NaN."""
        value = sample_eff_results.get(method, {}).get(str(size), {}).get('mean')
        try:
            return float(value)
        except (TypeError, ValueError):
            return np.nan

    # LaTeX row terminator: a space, two backslashes, then a newline.
    row_end = " \\\\\n"

    # Raw strings: in a normal literal Python would turn \b, \t, ... inside
    # the LaTeX macros into control characters.
    latex = r"""
\begin{table}[t]
\centering
\caption{Sample Efficiency: NDCG@5 vs Training Data Size}
\label{tab:sample_efficiency}
\begin{tabular}{@{}r"""

    # One centred column per method after the right-aligned size column.
    latex += "c" * len(methods)
    latex += r"""@{}}
\toprule
\textbf{\# Samples} """

    for label in method_labels:
        latex += f"& \\textbf{{{label}}} "
    latex += "\\\\\n\\midrule\n"

    for size in data_sizes:
        vals = [mean_for(method, size) for method in methods]
        best_val = max((v for v in vals if np.isfinite(v)), default=-1)

        row = f"{size:,}"
        for val in vals:
            if not np.isfinite(val):
                row += " & --"
            elif val == best_val and best_val > 0:
                row += f" & \\textbf{{{val:.3f}}}"
            else:
                row += f" & {val:.3f}"
        latex += row + row_end

    latex += r"""\bottomrule
\end{tabular}
\end{table}
"""
    return latex

In [None]:
# ============================================================
# TABLE VI: Efficiency Comparison
# ============================================================
def generate_table_6():
    """
    Return the LaTeX source for Table VI (computational efficiency).

    The table is a fixed template; every number in it is hard-coded here
    rather than read from the result files.

    Returns:
        str: LaTeX source of the table environment.
    """
    # Raw string: in a normal literal Python would turn \b, \t, \f, \v inside
    # the LaTeX macros into control characters, and the row breaks would not
    # come out as two backslashes followed by a newline.
    latex = r"""
\begin{table}[t]
\centering
\caption{Computational Efficiency Comparison}
\label{tab:efficiency}
\begin{tabular}{@{}lrrrr@{}}
\toprule
\textbf{Method} & \textbf{Params} & \textbf{GPU (ms)} & \textbf{CPU (ms)} & \textbf{Train (h)} \\
\midrule
Rule-Based MCDA     & 0       & $<$1   & $<$1   & 0     \\
Simple MLP          & 50K     & $<$1   & $<$1   & 0.2   \\
Neural Ranker       & 350K    & 2      & 8      & 1.0   \\
DIN                 & 120K    & 1      & 5      & 0.8   \\
DocMatchNet-Orig    & 450K    & 3      & 12     & 1.5   \\
DocMatchNet-JEPA    & 520K    & 3      & 14     & 2.5   \\
\bottomrule
\end{tabular}
\vspace{2mm}
\footnotesize{Inference latency measured per doctor-patient pair. Training time on single NVIDIA T4 GPU.}
\end{table}
"""
    return latex

In [None]:
# ============================================================
# Statistical Significance Table
# ============================================================
def generate_significance_table():
    """
    Print pairwise significance tests of JEPA against each comparison method.

    For every comparison with per-case NDCG@5 values, a one-sided Wilcoxon
    signed-rank test (alternative='greater') is run on the paired values.
    The mean difference and a standardized effect size are printed: the mean
    difference divided by the standard deviation of the paired differences
    (np.std with its default ddof=0). A Bonferroni-corrected verdict follows
    for the comparisons that produced a usable p-value.

    Reads the module-level ``jepa_results``, ``baseline_results`` and
    ``original_results``. Only prints; always returns None.
    """
    from scipy import stats

    jepa_per_case = jepa_results.get('per_case', {}).get('ndcg@5', [])

    comparisons = {
        'vs MCDA': baseline_results.get('mcda', {}).get('per_case', {}).get('ndcg@5', []),
        'vs MLP': baseline_results.get('mlp', {}).get('per_case', {}).get('ndcg@5', []),
        'vs Neural Ranker': baseline_results.get('neural_ranker', {}).get('per_case', {}).get('ndcg@5', []),
        'vs DIN': baseline_results.get('din', {}).get('per_case', {}).get('ndcg@5', []),
        'vs DocMatch-Orig': original_results.get('per_case', {}).get('ndcg@5', []),
    }

    if len(jepa_per_case) == 0:
        print('No per-case JEPA metrics available; skipping significance tests.')
        return

    print('\nStatistical Significance (Wilcoxon signed-rank test):')
    print('-' * 60)

    p_values = []
    names = []
    for comp_name, comp_values in comparisons.items():
        if len(comp_values) == 0:
            print(f'  JEPA {comp_name}: missing per-case values')
            continue

        n = min(len(jepa_per_case), len(comp_values))
        if len(comp_values) != len(jepa_per_case):
            # Pairing by position is only meaningful when both lists cover the
            # same cases; make the truncation visible instead of silent.
            print(f'  JEPA {comp_name}: length mismatch '
                  f'({len(jepa_per_case)} vs {len(comp_values)}); using first {n} cases')
        a = np.asarray(jepa_per_case[:n], dtype=float)
        b = np.asarray(comp_values[:n], dtype=float)

        try:
            _, p_val = stats.wilcoxon(a, b, alternative='greater')
        except ValueError as err:
            # Raised e.g. when every paired difference is zero.
            print(f'  JEPA {comp_name}: Wilcoxon test not applicable ({err})')
            continue
        if not np.isfinite(p_val):
            # Keep undefined p-values out of the Bonferroni family size.
            print(f'  JEPA {comp_name}: Wilcoxon test returned an undefined p-value')
            continue

        p_values.append(p_val)
        names.append(comp_name)

        improvement = np.mean(a) - np.mean(b)
        denom = np.std(a - b)
        effect_size = improvement / denom if denom > 0 else 0.0

        sig = '***' if p_val < 0.001 else '**' if p_val < 0.01 else '*' if p_val < 0.05 else 'ns'

        print(f"  JEPA {comp_name}: Δ={improvement:+.4f}, d={effect_size:.3f}, p={p_val:.4e} {sig}")

    if len(p_values) == 0:
        print('No valid comparisons for correction.')
        return

    # Bonferroni: divide the family-wise alpha by the number of tests run.
    corrected_alpha = 0.05 / len(p_values)
    print(f"\nBonferroni-corrected α = {corrected_alpha:.4f}")
    for comp_name, p_val in zip(names, p_values):
        sig = 'significant' if p_val < corrected_alpha else 'not significant'
        print(f"  JEPA {comp_name}: {sig} after correction")

In [None]:
# ============================================================
# Print and Save All Tables
# ============================================================
# Generate every table once, keyed by the file stem it is saved under.
tables = {
    'table1': generate_table_1(),
    'table2': generate_table_2(),
    'table3': generate_table_3(),
    'table4': generate_table_4(),
    'table5': generate_table_5(),
    'table6': generate_table_6(),
}

# Use the shared TABLE_DIR constant instead of repeating the path literal.
os.makedirs(TABLE_DIR, exist_ok=True)

for table_name, latex_code in tables.items():
    print(f"\n{'=' * 60}")
    print(f"{table_name.upper()}")
    print('=' * 60)
    print(latex_code)

    # Explicit encoding so the output does not depend on the locale.
    with open(os.path.join(TABLE_DIR, f'{table_name}.tex'), 'w', encoding='utf-8') as f:
        f.write(latex_code)

# Prints the significance tests; it produces no file.
generate_significance_table()

print(f'\n✅ All tables saved to {TABLE_DIR}/')