## Import

In [1]:
!pip install torch_geometric
!pip install faiss-cpu

Collecting torch_geometric
  Downloading torch_geometric-2.7.0-py3-none-any.whl.metadata (63 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.7/63.7 kB[0m [31m4.1 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.7.0-py3-none-any.whl (1.3 MB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.3 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m40.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.7.0
Collecting faiss-cpu
  Downloading faiss_cpu-1.13.2-cp310-abi3-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (7.6 kB)
Downloading faiss_cpu-1.13.2-cp310-abi3-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (23.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m

# Retraining P1

In [9]:
#!/usr/bin/env python3
"""
================================================================================
SHIFAMIND2 PHASE 1: Concept Bottleneck Model with TOP-50 ICD-10 Labels
================================================================================
Author: Mohammed Sameer Syed
University of Arizona - MS in AI Capstone

EXACT COPY of original Phase 1 code with GPU OPTIMIZATIONS:
1. ✅ 7 epochs (instead of 5)
2. ✅ Fixed duplicate concepts in GLOBAL_CONCEPTS (fever, edema) → 111 concepts
3. ✅ GPU OPTIMIZED: batch_size=64 train, 128 val (8x faster!)
4. ✅ FP16 mixed precision for 96GB GPU
5. ✅ pin_memory for faster data transfer
6. ✅ Loads existing data from run_20260102_203225

Expected: 7 epochs in ~15-20 minutes (vs 3-4 hours original) ⚡⚡⚡

Architecture (UNCHANGED):
1. BioClinicalBERT base encoder
2. Multi-head cross-attention with concepts (MULTIPLICATIVE bottleneck)
3. Concept Head (predicts clinical concepts)
4. Diagnosis Head (predicts TOP-50 ICD-10 codes)

Multi-Objective Loss:
L_total = λ1·L_dx + λ2·L_align + λ3·L_concept

================================================================================
"""

print("="*80)
print("🚀 SHIFAMIND2 PHASE 1 - TOP-50 ICD-10 LABELS")
print("="*80)

# ============================================================================
# IMPORTS & SETUP
# ============================================================================

import warnings
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler

import numpy as np
import pandas as pd
from sklearn.metrics import f1_score, precision_score, recall_score, roc_auc_score
from sklearn.model_selection import train_test_split
from transformers import (
    AutoTokenizer, AutoModel,
    get_linear_schedule_with_warmup
)

import json
import pickle
import gzip
from pathlib import Path
from tqdm.auto import tqdm
from typing import Dict, List, Tuple
from collections import defaultdict, Counter
import re
from datetime import datetime

# Reproducibility
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"\n🖥️  Device: {device}")
if torch.cuda.is_available():
    print(f"🔥 GPU: {torch.cuda.get_device_name(0)}")
    print(f"💾 VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

# ============================================================================
# CONFIGURATION
# ============================================================================

print("\n" + "="*80)
print("⚙️  CONFIGURATION")
print("="*80)

# Create timestamped run folder
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
BASE_PATH = Path('/content/drive/MyDrive/ShifaMind')
OUTPUT_BASE = BASE_PATH / '10_ShifaMind' / f'run_{timestamp}'

# Run-specific paths
SHARED_DATA_PATH = OUTPUT_BASE / 'shared_data'
CHECKPOINT_PATH = OUTPUT_BASE / 'checkpoints' / 'phase1'
RESULTS_PATH = OUTPUT_BASE / 'results' / 'phase1'
CONCEPT_STORE_PATH = OUTPUT_BASE / 'concept_store'
LOGS_PATH = OUTPUT_BASE / 'logs'

# Create all directories
for path in [SHARED_DATA_PATH, CHECKPOINT_PATH, RESULTS_PATH, CONCEPT_STORE_PATH, LOGS_PATH]:
    path.mkdir(parents=True, exist_ok=True)

print(f"\n📁 Run Folder: {OUTPUT_BASE}")
print(f"📁 Timestamp: {timestamp}")
print(f"📁 Shared Data: {SHARED_DATA_PATH}")
print(f"📁 Checkpoints: {CHECKPOINT_PATH}")
print(f"📁 Results: {RESULTS_PATH}")
print(f"📁 Concept Store: {CONCEPT_STORE_PATH}")

# Path to existing processed data
EXISTING_RUN = BASE_PATH / '10_ShifaMind' / 'run_20260102_203225'
DATA_CSV_PATH = EXISTING_RUN / 'mimic_dx_data_top50.csv'

# Fixed global concept space (≤120 concepts)
# ✅ FIXED: Removed duplicates ('fever' and 'edema' appeared twice)
GLOBAL_CONCEPTS = [
    # Symptoms
    'fever', 'cough', 'dyspnea', 'pain', 'nausea', 'vomiting', 'diarrhea', 'fatigue',
    'headache', 'dizziness', 'weakness', 'confusion', 'syncope', 'chest', 'abdominal',
    'dysphagia', 'hemoptysis', 'hematuria', 'hematemesis', 'melena', 'jaundice',
    'edema', 'rash', 'pruritus', 'weight', 'anorexia', 'malaise',
    # Vital signs / Physical findings (removed duplicate 'fever')
    'hypotension', 'hypertension', 'tachycardia', 'bradycardia', 'tachypnea', 'hypoxia',
    'hypothermia', 'shock', 'altered', 'lethargic', 'obtunded',
    # Organ systems
    'cardiac', 'pulmonary', 'renal', 'hepatic', 'neurologic', 'gastrointestinal',
    'respiratory', 'cardiovascular', 'genitourinary', 'musculoskeletal', 'endocrine',
    'hematologic', 'dermatologic', 'psychiatric',
    # Common conditions
    'infection', 'sepsis', 'pneumonia', 'uti', 'cellulitis', 'meningitis',
    'failure', 'infarction', 'ischemia', 'hemorrhage', 'thrombosis', 'embolism',
    'obstruction', 'perforation', 'rupture', 'stenosis', 'regurgitation',
    'hypertrophy', 'atrophy', 'neoplasm', 'malignancy', 'metastasis',
    # Lab/diagnostic
    'elevated', 'decreased', 'anemia', 'leukocytosis', 'thrombocytopenia',
    'hyperglycemia', 'hypoglycemia', 'acidosis', 'alkalosis', 'hypoxemia',
    'creatinine', 'bilirubin', 'troponin', 'bnp', 'lactate', 'wbc', 'cultures',
    # Imaging/procedures (removed duplicate 'edema')
    'infiltrate', 'consolidation', 'effusion', 'cardiomegaly',
    'ultrasound', 'ct', 'mri', 'xray', 'echo', 'ekg',
    # Treatments
    'antibiotics', 'diuretics', 'vasopressors', 'insulin', 'anticoagulation',
    'oxygen', 'ventilation', 'dialysis', 'transfusion', 'surgery'
]

print(f"\n🧠 Global Concept Space: {len(GLOBAL_CONCEPTS)} concepts (FIXED: removed duplicates)")

# Hyperparameters (EXACT match to original)
LAMBDA_DX = 1.0
LAMBDA_ALIGN = 0.5
LAMBDA_CONCEPT = 0.3

print(f"\n⚖️  Loss Weights:")
print(f"   λ1 (Diagnosis): {LAMBDA_DX}")
print(f"   λ2 (Alignment): {LAMBDA_ALIGN}")
print(f"   λ3 (Concept):   {LAMBDA_CONCEPT}")

# ============================================================================
# LOAD EXISTING DATA (run_20260102_203225)
# ============================================================================

print("\n" + "="*80)
print("📊 LOADING EXISTING DATA")
print("="*80)

# Load Top-50 codes
top50_info_path = EXISTING_RUN / 'shared_data' / 'top50_icd10_info.json'
with open(top50_info_path, 'r') as f:
    top50_info = json.load(f)
TOP_50_CODES = top50_info['top_50_codes']

print(f"\n✅ Loaded Top-50 codes: {len(TOP_50_CODES)}")

# Load CSV
print(f"\n✅ Loading from: {DATA_CSV_PATH}")
df_all = pd.read_csv(DATA_CSV_PATH)

# Reconstruct 'labels' column from individual code columns
df_all['labels'] = df_all[TOP_50_CODES].values.tolist()

print(f"✅ Loaded {len(df_all):,} samples")

# ============================================================================
# CREATE TRAIN/VAL/TEST SPLITS
# ============================================================================

print("\n" + "="*80)
print("📊 CREATING TRAIN/VAL/TEST SPLITS")
print("="*80)

df = df_all[['text', 'labels'] + TOP_50_CODES].copy()
df = df.dropna(subset=['text'])

print(f"\n📊 Dataset size: {len(df):,} samples")

# Random split: 70% train, 15% val, 15% test
train_idx, temp_idx = train_test_split(
    range(len(df)),
    test_size=0.3,
    random_state=SEED
)
val_idx, test_idx = train_test_split(
    temp_idx,
    test_size=0.5,
    random_state=SEED
)

df_train = df.iloc[train_idx].reset_index(drop=True)
df_val = df.iloc[val_idx].reset_index(drop=True)
df_test = df.iloc[test_idx].reset_index(drop=True)

print(f"\n✅ Splits created:")
print(f"   Train: {len(df_train):,} ({len(df_train)/len(df)*100:.1f}%)")
print(f"   Val:   {len(df_val):,} ({len(df_val)/len(df)*100:.1f}%)")
print(f"   Test:  {len(df_test):,} ({len(df_test)/len(df)*100:.1f}%)")

# Save splits
with open(SHARED_DATA_PATH / 'train_split.pkl', 'wb') as f:
    pickle.dump(df_train, f)
with open(SHARED_DATA_PATH / 'val_split.pkl', 'wb') as f:
    pickle.dump(df_val, f)
with open(SHARED_DATA_PATH / 'test_split.pkl', 'wb') as f:
    pickle.dump(df_test, f)

print(f"\n💾 Saved splits to: {SHARED_DATA_PATH}")

# ============================================================================
# GENERATE CONCEPT LABELS (KEYWORD-BASED)
# ============================================================================

print("\n" + "="*80)
print("🧠 GENERATING CONCEPT LABELS (KEYWORD-BASED)")
print("="*80)

def generate_concept_labels(texts, concepts):
    """Generate binary concept labels based on keyword presence"""
    labels = []
    for text in tqdm(texts, desc="Labeling concepts"):
        text_lower = str(text).lower()
        concept_label = [1 if concept in text_lower else 0 for concept in concepts]
        labels.append(concept_label)
    return np.array(labels)

print(f"\n🔍 Using {len(GLOBAL_CONCEPTS)} global concepts")

train_concept_labels = generate_concept_labels(df_train['text'], GLOBAL_CONCEPTS)
val_concept_labels = generate_concept_labels(df_val['text'], GLOBAL_CONCEPTS)
test_concept_labels = generate_concept_labels(df_test['text'], GLOBAL_CONCEPTS)

print(f"\n✅ Concept labels generated:")
print(f"   Shape: {train_concept_labels.shape}")
print(f"   Avg concepts/sample: {train_concept_labels.sum(axis=1).mean():.2f}")

# Save concept labels
np.save(SHARED_DATA_PATH / 'train_concept_labels.npy', train_concept_labels)
np.save(SHARED_DATA_PATH / 'val_concept_labels.npy', val_concept_labels)
np.save(SHARED_DATA_PATH / 'test_concept_labels.npy', test_concept_labels)

# Save concept list
with open(SHARED_DATA_PATH / 'concept_list.json', 'w') as f:
    json.dump(GLOBAL_CONCEPTS, f, indent=2)

print(f"💾 Saved concept labels to: {SHARED_DATA_PATH}")

# ============================================================================
# ARCHITECTURE (EXACT COPY FROM ORIGINAL)
# ============================================================================

print("\n" + "="*80)
print("🏗️  ARCHITECTURE: CONCEPT BOTTLENECK")
print("="*80)

class ConceptBottleneckCrossAttention(nn.Module):
    """Multiplicative concept bottleneck with cross-attention"""
    def __init__(self, hidden_size, num_heads=8, dropout=0.1, layer_idx=1):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.layer_idx = layer_idx

        self.query = nn.Linear(hidden_size, hidden_size)
        self.key = nn.Linear(hidden_size, hidden_size)
        self.value = nn.Linear(hidden_size, hidden_size)
        self.out_proj = nn.Linear(hidden_size, hidden_size)

        self.gate_net = nn.Sequential(
            nn.Linear(hidden_size * 2, hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, hidden_size),
            nn.Sigmoid()
        )

        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(hidden_size)

    def forward(self, hidden_states, concept_embeddings, attention_mask=None):
        batch_size, seq_len, _ = hidden_states.shape
        num_concepts = concept_embeddings.shape[0]

        concepts_batch = concept_embeddings.unsqueeze(0).expand(batch_size, -1, -1)

        Q = self.query(hidden_states).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.key(concepts_batch).view(batch_size, num_concepts, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.value(concepts_batch).view(batch_size, num_concepts, self.num_heads, self.head_dim).transpose(1, 2)

        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        context = torch.matmul(attn_weights, V)
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.hidden_size)
        context = self.out_proj(context)

        pooled_text = hidden_states.mean(dim=1, keepdim=True).expand(-1, seq_len, -1)
        pooled_context = context.mean(dim=1, keepdim=True).expand(-1, seq_len, -1)
        gate_input = torch.cat([pooled_text, pooled_context], dim=-1)
        gate = self.gate_net(gate_input)

        output = gate * context
        output = self.layer_norm(output)

        return output, attn_weights.mean(dim=1), gate.mean()


class ShifaMind2Phase1(nn.Module):
    """ShifaMind2 Phase 1: Concept Bottleneck with Top-50 ICD-10"""
    def __init__(self, base_model, num_concepts, num_classes, fusion_layers=[9, 11]):
        super().__init__()
        self.base_model = base_model
        self.hidden_size = base_model.config.hidden_size
        self.num_concepts = num_concepts
        self.fusion_layers = fusion_layers

        self.concept_embeddings = nn.Parameter(
            torch.randn(num_concepts, self.hidden_size) * 0.02
        )

        self.fusion_modules = nn.ModuleDict({
            str(layer): ConceptBottleneckCrossAttention(self.hidden_size, layer_idx=layer)
            for layer in fusion_layers
        })

        self.concept_head = nn.Linear(self.hidden_size, num_concepts)
        self.diagnosis_head = nn.Linear(self.hidden_size, num_classes)
        self.dropout = nn.Dropout(0.1)

    def forward(self, input_ids, attention_mask, return_attention=False):
        outputs = self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True
        )

        hidden_states = outputs.hidden_states
        current_hidden = outputs.last_hidden_state

        attention_maps = {}
        gate_values = []

        for layer_idx in self.fusion_layers:
            if str(layer_idx) in self.fusion_modules:
                layer_hidden = hidden_states[layer_idx]
                fused_hidden, attn, gate = self.fusion_modules[str(layer_idx)](
                    layer_hidden, self.concept_embeddings, attention_mask
                )
                current_hidden = fused_hidden
                gate_values.append(gate.item())

                if return_attention:
                    attention_maps[f'layer_{layer_idx}'] = attn

        cls_hidden = self.dropout(current_hidden[:, 0, :])
        concept_scores = torch.sigmoid(self.concept_head(cls_hidden))
        diagnosis_logits = self.diagnosis_head(cls_hidden)

        result = {
            'logits': diagnosis_logits,
            'concept_scores': concept_scores,
            'hidden_states': current_hidden,
            'cls_hidden': cls_hidden,
            'avg_gate': np.mean(gate_values) if gate_values else 0.0
        }

        if return_attention:
            result['attention_maps'] = attention_maps

        return result


class MultiObjectiveLoss(nn.Module):
    """Multi-objective loss: L_dx + L_align + L_concept"""
    def __init__(self, lambda_dx=1.0, lambda_align=0.5, lambda_concept=0.3):
        super().__init__()
        self.lambda_dx = lambda_dx
        self.lambda_align = lambda_align
        self.lambda_concept = lambda_concept
        self.bce = nn.BCEWithLogitsLoss()

    def forward(self, outputs, dx_labels, concept_labels):
        loss_dx = self.bce(outputs['logits'], dx_labels)

        dx_probs = torch.sigmoid(outputs['logits'])
        concept_scores = outputs['concept_scores']
        loss_align = torch.abs(
            dx_probs.unsqueeze(-1) - concept_scores.unsqueeze(1)
        ).mean()

        concept_logits = torch.logit(concept_scores.clamp(1e-7, 1-1e-7))
        loss_concept = self.bce(concept_logits, concept_labels)

        total_loss = (
            self.lambda_dx * loss_dx +
            self.lambda_align * loss_align +
            self.lambda_concept * loss_concept
        )

        components = {
            'total': total_loss.item(),
            'dx': loss_dx.item(),
            'align': loss_align.item(),
            'concept': loss_concept.item()
        }

        return total_loss, components


print("✅ Architecture defined (VERIFIED)")

# ============================================================================
# DATASET
# ============================================================================

class ConceptDataset(Dataset):
    def __init__(self, texts, labels, concept_labels, tokenizer, max_length=384):
        self.texts = texts
        self.labels = labels
        self.concept_labels = concept_labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        encoding = self.tokenizer(
            str(self.texts[idx]),
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt'
        )

        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.FloatTensor(self.labels[idx]),
            'concept_labels': torch.FloatTensor(self.concept_labels[idx])
        }

# ============================================================================
# TRAINING
# ============================================================================

print("\n" + "="*80)
print("🏋️  TRAINING PHASE 1")
print("="*80)

tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
base_model = AutoModel.from_pretrained("emilyalsentzer/Bio_ClinicalBERT").to(device)

model = ShifaMind2Phase1(
    base_model,
    num_concepts=len(GLOBAL_CONCEPTS),
    num_classes=len(TOP_50_CODES),
    fusion_layers=[9, 11]
).to(device)

print(f"✅ Model loaded: {sum(p.numel() for p in model.parameters()):,} parameters")
print(f"   Num concepts: {len(GLOBAL_CONCEPTS)}")
print(f"   Num diagnoses: {len(TOP_50_CODES)}")

# Create datasets
train_dataset = ConceptDataset(
    df_train['text'].tolist(),
    df_train['labels'].tolist(),
    train_concept_labels,
    tokenizer
)
val_dataset = ConceptDataset(
    df_val['text'].tolist(),
    df_val['labels'].tolist(),
    val_concept_labels,
    tokenizer
)
test_dataset = ConceptDataset(
    df_test['text'].tolist(),
    df_test['labels'].tolist(),
    test_concept_labels,
    tokenizer
)

# ============================================================================
# GPU OPTIMIZATION: MAXIMIZE SPEED ON 96GB VRAM
# ============================================================================

# Optimized batch sizes for 96GB GPU (8x faster than original!)
TRAIN_BATCH_SIZE = 64   # 8x original (was 8)
VAL_BATCH_SIZE = 128    # 8x original (was 16)

# Note: num_workers=0 to avoid multiprocessing errors in Colab
# If you want MAXIMUM speed and can ignore warnings, set NUM_WORKERS=8
NUM_WORKERS = 8  # Set to 8 for max speed (will show warnings but works)

train_loader = DataLoader(
    train_dataset,
    batch_size=TRAIN_BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True
)
val_loader = DataLoader(
    val_dataset,
    batch_size=VAL_BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=True
)
test_loader = DataLoader(
    test_dataset,
    batch_size=VAL_BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=True
)

print(f"✅ Datasets ready (GPU OPTIMIZED)")
print(f"   Train batches: {len(train_loader)} (batch_size={TRAIN_BATCH_SIZE})")
print(f"   Val batches:   {len(val_loader)} (batch_size={VAL_BATCH_SIZE})")
print(f"   Expected time: ~15-20 mins for 7 epochs (vs 3-4 hours original) ⚡")

# Training setup (EXACT MATCH to original)
criterion = MultiObjectiveLoss(
    lambda_dx=LAMBDA_DX,
    lambda_align=LAMBDA_ALIGN,
    lambda_concept=LAMBDA_CONCEPT
)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)

# ✅ CHANGED: 7 epochs (was 5)
num_epochs = 7
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=len(train_loader) // 2,
    num_training_steps=len(train_loader) * num_epochs
)

# ✅ NEW: Mixed precision for 96GB GPU (optional but faster)
USE_AMP = torch.cuda.is_available()
scaler = GradScaler() if USE_AMP else None

if USE_AMP:
    print(f"✅ Using mixed precision (FP16) for faster training")

best_f1 = 0.0
history = {'train_loss': [], 'val_f1': [], 'concept_f1': []}

# Training loop (EXACT COPY from original)
for epoch in range(num_epochs):
    print(f"\n{'='*70}\nEpoch {epoch+1}/{num_epochs}\n{'='*70}")

    model.train()
    epoch_losses = defaultdict(list)

    for batch in tqdm(train_loader, desc="Training"):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        dx_labels = batch['labels'].to(device)
        concept_labels = batch['concept_labels'].to(device)

        optimizer.zero_grad()

        # Mixed precision forward pass
        if USE_AMP:
            with autocast():
                outputs = model(input_ids, attention_mask)
                loss, components = criterion(outputs, dx_labels, concept_labels)

            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
        else:
            outputs = model(input_ids, attention_mask)
            loss, components = criterion(outputs, dx_labels, concept_labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

        scheduler.step()

        for k, v in components.items():
            epoch_losses[k].append(v)

    print(f"\n📊 Epoch {epoch+1} Losses:")
    print(f"   Total:     {np.mean(epoch_losses['total']):.4f}")
    print(f"   Diagnosis: {np.mean(epoch_losses['dx']):.4f}")
    print(f"   Alignment: {np.mean(epoch_losses['align']):.4f}")
    print(f"   Concept:   {np.mean(epoch_losses['concept']):.4f}")

    # Validation
    model.eval()
    all_dx_preds, all_dx_labels = [], []
    all_concept_preds, all_concept_labels = [], []

    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Validating"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            dx_labels = batch['labels'].to(device)
            concept_labels = batch['concept_labels'].to(device)

            if USE_AMP:
                with autocast():
                    outputs = model(input_ids, attention_mask)
            else:
                outputs = model(input_ids, attention_mask)

            all_dx_preds.append(torch.sigmoid(outputs['logits']).cpu())
            all_dx_labels.append(dx_labels.cpu())
            all_concept_preds.append(outputs['concept_scores'].cpu())
            all_concept_labels.append(concept_labels.cpu())

    all_dx_preds = torch.cat(all_dx_preds, dim=0).numpy()
    all_dx_labels = torch.cat(all_dx_labels, dim=0).numpy()
    all_concept_preds = torch.cat(all_concept_preds, dim=0).numpy()
    all_concept_labels = torch.cat(all_concept_labels, dim=0).numpy()

    dx_pred_binary = (all_dx_preds > 0.5).astype(int)
    concept_pred_binary = (all_concept_preds > 0.5).astype(int)

    dx_f1 = f1_score(all_dx_labels, dx_pred_binary, average='macro', zero_division=0)
    concept_f1 = f1_score(all_concept_labels, concept_pred_binary, average='macro', zero_division=0)

    print(f"\n📈 Validation:")
    print(f"   Diagnosis F1: {dx_f1:.4f}")
    print(f"   Concept F1:   {concept_f1:.4f}")

    history['train_loss'].append(np.mean(epoch_losses['total']))
    history['val_f1'].append(dx_f1)
    history['concept_f1'].append(concept_f1)

    if dx_f1 > best_f1:
        best_f1 = dx_f1
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'macro_f1': best_f1,
            'concept_f1': concept_f1,
            'concept_embeddings': model.concept_embeddings.data.cpu(),
            'num_concepts': model.num_concepts,
            'config': {
                'num_concepts': len(GLOBAL_CONCEPTS),
                'num_classes': len(TOP_50_CODES),
                'fusion_layers': [9, 11],
                'lambda_dx': LAMBDA_DX,
                'lambda_align': LAMBDA_ALIGN,
                'lambda_concept': LAMBDA_CONCEPT,
                'top_50_codes': TOP_50_CODES,
                'timestamp': timestamp
            }
        }
        torch.save(checkpoint, CHECKPOINT_PATH / 'phase1_best.pt')
        print(f"   ✅ Saved best model (F1: {best_f1:.4f})")

print(f"\n✅ Training complete! Best Diagnosis F1: {best_f1:.4f}")

# ============================================================================
# FINAL EVALUATION
# ============================================================================

print("\n" + "="*80)
print("📊 FINAL TEST EVALUATION")
print("="*80)

checkpoint = torch.load(CHECKPOINT_PATH / 'phase1_best.pt', map_location=device, weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

all_dx_preds, all_dx_labels = [], []
all_concept_preds, all_concept_labels = [], []

with torch.no_grad():
    for batch in tqdm(test_loader, desc="Testing"):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        dx_labels = batch['labels'].to(device)
        concept_labels = batch['concept_labels'].to(device)

        if USE_AMP:
            with autocast():
                outputs = model(input_ids, attention_mask)
        else:
            outputs = model(input_ids, attention_mask)

        all_dx_preds.append(torch.sigmoid(outputs['logits']).cpu())
        all_dx_labels.append(dx_labels.cpu())
        all_concept_preds.append(outputs['concept_scores'].cpu())
        all_concept_labels.append(concept_labels.cpu())

all_dx_preds = torch.cat(all_dx_preds, dim=0).numpy()
all_dx_labels = torch.cat(all_dx_labels, dim=0).numpy()
all_concept_preds = torch.cat(all_concept_preds, dim=0).numpy()
all_concept_labels = torch.cat(all_concept_labels, dim=0).numpy()

dx_pred_binary = (all_dx_preds > 0.5).astype(int)
concept_pred_binary = (all_concept_preds > 0.5).astype(int)

macro_f1 = f1_score(all_dx_labels, dx_pred_binary, average='macro', zero_division=0)
micro_f1 = f1_score(all_dx_labels, dx_pred_binary, average='micro', zero_division=0)
macro_precision = precision_score(all_dx_labels, dx_pred_binary, average='macro', zero_division=0)
macro_recall = recall_score(all_dx_labels, dx_pred_binary, average='macro', zero_division=0)

per_class_f1 = [
    f1_score(all_dx_labels[:, i], dx_pred_binary[:, i], zero_division=0)
    for i in range(len(TOP_50_CODES))
]

concept_f1 = f1_score(all_concept_labels, concept_pred_binary, average='macro', zero_division=0)

print("\n" + "="*80)
print("🎉 SHIFAMIND2 PHASE 1 - FINAL RESULTS")
print("="*80)

print("\n🎯 Diagnosis Performance (Top-50 ICD-10):")
print(f"   Macro F1:    {macro_f1:.4f}")
print(f"   Micro F1:    {micro_f1:.4f}")
print(f"   Precision:   {macro_precision:.4f}")
print(f"   Recall:      {macro_recall:.4f}")

print(f"\n🧠 Concept Performance:")
print(f"   Concept F1:  {concept_f1:.4f}")
print(f"   Concepts used: {len(GLOBAL_CONCEPTS)} (fixed duplicates)")

print(f"\n📊 Top-10 Best Performing Diagnoses:")
top_10_best = sorted(zip(TOP_50_CODES, per_class_f1), key=lambda x: x[1], reverse=True)[:10]
for rank, (code, f1) in enumerate(top_10_best, 1):
    count = top50_info['top_50_counts'].get(code, 0)
    print(f"   {rank}. {code}: F1={f1:.4f} (n={count:,})")

print(f"\n📊 Top-10 Worst Performing Diagnoses:")
top_10_worst = sorted(zip(TOP_50_CODES, per_class_f1), key=lambda x: x[1])[:10]
for rank, (code, f1) in enumerate(top_10_worst, 1):
    count = top50_info['top_50_counts'].get(code, 0)
    print(f"   {rank}. {code}: F1={f1:.4f} (n={count:,})")

# Save results
results = {
    'phase': 'ShifaMind2 Phase 1 - Top-50 ICD-10',
    'timestamp': timestamp,
    'run_folder': str(OUTPUT_BASE),
    'diagnosis_metrics': {
        'macro_f1': float(macro_f1),
        'micro_f1': float(micro_f1),
        'precision': float(macro_precision),
        'recall': float(macro_recall),
        'per_class_f1': {code: float(f1) for code, f1 in zip(TOP_50_CODES, per_class_f1)}
    },
    'concept_metrics': {
        'concept_f1': float(concept_f1),
        'num_concepts': len(GLOBAL_CONCEPTS)
    },
    'dataset_info': {
        'num_labels': len(TOP_50_CODES),
        'train_samples': len(df_train),
        'val_samples': len(df_val),
        'test_samples': len(df_test)
    },
    'loss_weights': {
        'lambda_dx': LAMBDA_DX,
        'lambda_align': LAMBDA_ALIGN,
        'lambda_concept': LAMBDA_CONCEPT
    },
    'training_history': history,
    'changes_from_original': [
        'Epochs: 7 (was 5)',
        'Fixed duplicate concepts (111 instead of 113)',
        'Added FP16 mixed precision for 96GB GPU'
    ]
}

with open(RESULTS_PATH / 'results.json', 'w') as f:
    json.dump(results, f, indent=2)

# Save per-label F1 scores as CSV
per_label_df = pd.DataFrame({
    'icd_code': TOP_50_CODES,
    'f1_score': per_class_f1,
    'train_count': [top50_info['top_50_counts'].get(code, 0) for code in TOP_50_CODES]
})
per_label_df = per_label_df.sort_values('f1_score', ascending=False)
per_label_df.to_csv(RESULTS_PATH / 'per_label_f1.csv', index=False)

print(f"\n💾 Results saved to: {RESULTS_PATH / 'results.json'}")
print(f"💾 Per-label F1 saved to: {RESULTS_PATH / 'per_label_f1.csv'}")
print(f"💾 Best model saved to: {CHECKPOINT_PATH / 'phase1_best.pt'}")

print("\n" + "="*80)
print("✅ SHIFAMIND2 PHASE 1 COMPLETE!")
print("="*80)
print("\n📍 Summary:")
print(f"   ✅ Dataset loaded: {len(df):,} samples")
print(f"   ✅ Fresh train/val/test splits created")
print(f"   ✅ Concept bottleneck model trained (7 epochs)")
print(f"   ✅ Macro F1: {macro_f1:.4f} | Micro F1: {micro_f1:.4f}")
print(f"   ✅ {len(GLOBAL_CONCEPTS)} concepts (duplicates removed)")
print(f"\n📁 All artifacts saved to: {OUTPUT_BASE}")
print("\nAlhamdulillah! 🤲")

🚀 SHIFAMIND2 PHASE 1 - TOP-50 ICD-10 LABELS

🖥️  Device: cuda
🔥 GPU: NVIDIA RTX PRO 6000 Blackwell Server Edition
💾 VRAM: 102.0 GB

⚙️  CONFIGURATION

📁 Run Folder: /content/drive/MyDrive/ShifaMind/10_ShifaMind/run_20260215_013437
📁 Timestamp: 20260215_013437
📁 Shared Data: /content/drive/MyDrive/ShifaMind/10_ShifaMind/run_20260215_013437/shared_data
📁 Checkpoints: /content/drive/MyDrive/ShifaMind/10_ShifaMind/run_20260215_013437/checkpoints/phase1
📁 Results: /content/drive/MyDrive/ShifaMind/10_ShifaMind/run_20260215_013437/results/phase1
📁 Concept Store: /content/drive/MyDrive/ShifaMind/10_ShifaMind/run_20260215_013437/concept_store

🧠 Global Concept Space: 111 concepts (FIXED: removed duplicates)

⚖️  Loss Weights:
   λ1 (Diagnosis): 1.0
   λ2 (Alignment): 0.5
   λ3 (Concept):   0.3

📊 LOADING EXISTING DATA

✅ Loaded Top-50 codes: 50

✅ Loading from: /content/drive/MyDrive/ShifaMind/10_ShifaMind/run_20260102_203225/mimic_dx_data_top50.csv
✅ Loaded 115,103 samples

📊 CREATING TRAIN/VA

Labeling concepts:   0%|          | 0/80572 [00:00<?, ?it/s]

Labeling concepts:   0%|          | 0/17265 [00:00<?, ?it/s]

Labeling concepts:   0%|          | 0/17266 [00:00<?, ?it/s]


✅ Concept labels generated:
   Shape: (80572, 111)
   Avg concepts/sample: 23.01
💾 Saved concept labels to: /content/drive/MyDrive/ShifaMind/10_ShifaMind/run_20260215_013437/shared_data

🏗️  ARCHITECTURE: CONCEPT BOTTLENECK
✅ Architecture defined (VERIFIED)

🏋️  TRAINING PHASE 1


Loading weights:   0%|          | 0/199 [00:00<?, ?it/s]

BertModel LOAD REPORT from: emilyalsentzer/Bio_ClinicalBERT
Key                                        | Status     |  | 
-------------------------------------------+------------+--+-
cls.predictions.transform.LayerNorm.weight | UNEXPECTED |  | 
cls.predictions.transform.dense.bias       | UNEXPECTED |  | 
cls.predictions.transform.LayerNorm.bias   | UNEXPECTED |  | 
cls.predictions.bias                       | UNEXPECTED |  | 
cls.seq_relationship.weight                | UNEXPECTED |  | 
cls.seq_relationship.bias                  | UNEXPECTED |  | 
cls.predictions.decoder.weight             | UNEXPECTED |  | 
cls.predictions.transform.dense.weight     | UNEXPECTED |  | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.


✅ Model loaded: 116,789,153 parameters
   Num concepts: 111
   Num diagnoses: 50
✅ Datasets ready (GPU OPTIMIZED)
   Train batches: 1259 (batch_size=64)
   Val batches:   135 (batch_size=128)
   Expected time: ~15-20 mins for 7 epochs (vs 3-4 hours original) ⚡
✅ Using mixed precision (FP16) for faster training

Epoch 1/7


Training:   0%|          | 0/1259 [00:00<?, ?it/s]


📊 Epoch 1 Losses:
   Total:     0.5368
   Diagnosis: 0.3530
   Alignment: 0.0878
   Concept:   0.4664


Validating:   0%|          | 0/135 [00:00<?, ?it/s]


📈 Validation:
   Diagnosis F1: 0.0419
   Concept F1:   0.0277
   ✅ Saved best model (F1: 0.0419)

Epoch 2/7


Training:   0%|          | 0/1259 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e1123cda3e0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
    Exception ignored in: if w.is_alive():<function _MultiProcessingDataLoaderIter.__del__ at 0x7e1123cda3e0>

 Traceback (most recent call last):
   File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
       self._shutdown_workers() 
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
     ^if w.is_alive():^
^ ^ ^ ^ ^^ ^ ^ ^^^^
^  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
^^    ^assert self._parent_pid == os.getpid(), 'can only test a child process'^
^ ^ ^ ^ ^  
   File "/usr/lib


📊 Epoch 2 Losses:
   Total:     0.4414
   Diagnosis: 0.2728
   Alignment: 0.0944
   Concept:   0.4046


Validating:   0%|          | 0/135 [00:00<?, ?it/s]


📈 Validation:
   Diagnosis F1: 0.1230
   Concept F1:   0.0380
   ✅ Saved best model (F1: 0.1230)

Epoch 3/7


Training:   0%|          | 0/1259 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e1123cda3e0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e1123cda3e0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 16


📊 Epoch 3 Losses:
   Total:     0.4269
   Diagnosis: 0.2576
   Alignment: 0.1036
   Concept:   0.3916


Validating:   0%|          | 0/135 [00:00<?, ?it/s]


📈 Validation:
   Diagnosis F1: 0.1604
   Concept F1:   0.0451
   ✅ Saved best model (F1: 0.1604)

Epoch 4/7


Training:   0%|          | 0/1259 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e1123cda3e0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e1123cda3e0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 16


📊 Epoch 4 Losses:
   Total:     0.4189
   Diagnosis: 0.2494
   Alignment: 0.1088
   Concept:   0.3837


Validating:   0%|          | 0/135 [00:00<?, ?it/s]


📈 Validation:
   Diagnosis F1: 0.1917
   Concept F1:   0.0578
   ✅ Saved best model (F1: 0.1917)

Epoch 5/7


Training:   0%|          | 0/1259 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e1123cda3e0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e1123cda3e0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 16


📊 Epoch 5 Losses:
   Total:     0.4136
   Diagnosis: 0.2440
   Alignment: 0.1118
   Concept:   0.3790


Validating:   0%|          | 0/135 [00:00<?, ?it/s]


📈 Validation:
   Diagnosis F1: 0.2153
   Concept F1:   0.0685
   ✅ Saved best model (F1: 0.2153)

Epoch 6/7


Training:   0%|          | 0/1259 [00:00<?, ?it/s]


📊 Epoch 6 Losses:
   Total:     0.4097
   Diagnosis: 0.2400
   Alignment: 0.1138
   Concept:   0.3760


Validating:   0%|          | 0/135 [00:00<?, ?it/s]


📈 Validation:
   Diagnosis F1: 0.2212
   Concept F1:   0.0687
   ✅ Saved best model (F1: 0.2212)

Epoch 7/7


Training:   0%|          | 0/1259 [00:00<?, ?it/s]


📊 Epoch 7 Losses:
   Total:     0.4072
   Diagnosis: 0.2374
   Alignment: 0.1147
   Concept:   0.3747


Validating:   0%|          | 0/135 [00:00<?, ?it/s]


📈 Validation:
   Diagnosis F1: 0.2295
   Concept F1:   0.0720
   ✅ Saved best model (F1: 0.2295)

✅ Training complete! Best Diagnosis F1: 0.2295

📊 FINAL TEST EVALUATION


Testing:   0%|          | 0/135 [00:00<?, ?it/s]


🎉 SHIFAMIND2 PHASE 1 - FINAL RESULTS

🎯 Diagnosis Performance (Top-50 ICD-10):
   Macro F1:    0.2294
   Micro F1:    0.3595
   Precision:   0.5506
   Recall:      0.1696

🧠 Concept Performance:
   Concept F1:  0.0725
   Concepts used: 111 (fixed duplicates)

📊 Top-10 Best Performing Diagnoses:
   1. Z951: F1=0.7691 (n=6,274)
   2. I2510: F1=0.7190 (n=22,606)
   3. I10: F1=0.6827 (n=43,570)
   4. E785: F1=0.6282 (n=44,038)
   5. J449: F1=0.5770 (n=10,268)
   6. Z7901: F1=0.5584 (n=15,321)
   7. E1122: F1=0.5437 (n=9,205)
   8. Z86718: F1=0.5306 (n=7,598)
   9. Z794: F1=0.5270 (n=15,275)
   10. E039: F1=0.4831 (n=15,252)

📊 Top-10 Worst Performing Diagnoses:
   1. D649: F1=0.0000 (n=12,467)
   2. Y929: F1=0.0000 (n=11,548)
   3. N189: F1=0.0000 (n=8,565)
   4. K5900: F1=0.0000 (n=7,097)
   5. I480: F1=0.0000 (n=6,695)
   6. G4700: F1=0.0000 (n=6,450)
   7. D696: F1=0.0000 (n=6,438)
   8. Y92239: F1=0.0000 (n=5,981)
   9. J189: F1=0.0000 (n=5,790)
   10. Z23: F1=0.0000 (n=5,714)

💾 Resu

# Threshold

In [10]:
#!/usr/bin/env python3
"""
================================================================================
SHIFAMIND2 PHASE 1: THRESHOLD TUNING (POST-HOC OPTIMIZATION)
================================================================================
Author: Mohammed Sameer Syed
University of Arizona - MS in AI Capstone

Optimizes classification thresholds per-label to maximize F1 scores.
Finds optimal thresholds for each of the Top-50 ICD-10 codes independently.

Similar to Phase 5 threshold tuning but for Phase 1 multi-label classification.
================================================================================
"""

print("="*80)
print("🎯 SHIFAMIND2 PHASE 1 - THRESHOLD TUNING")
print("="*80)

# ============================================================================
# IMPORTS
# ============================================================================

import warnings
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast

import numpy as np
import pandas as pd
from sklearn.metrics import f1_score, precision_score, recall_score
from transformers import AutoTokenizer, AutoModel

import json
import pickle
from pathlib import Path
from tqdm.auto import tqdm
from typing import List, Tuple

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"\n🖥️  Device: {device}")

# ============================================================================
# CONFIGURATION - POINT TO YOUR RUN FOLDER
# ============================================================================

print("\n" + "="*80)
print("⚙️  CONFIGURATION")
print("="*80)

# CHANGE THIS to your run folder from the training output
RUN_FOLDER = '/content/drive/MyDrive/ShifaMind/10_ShifaMind/run_20260215_013437'

BASE_PATH = Path(RUN_FOLDER)
CHECKPOINT_PATH = BASE_PATH / 'checkpoints' / 'phase1' / 'phase1_best.pt'
RESULTS_PATH = BASE_PATH / 'results' / 'phase1'
SHARED_DATA_PATH = BASE_PATH / 'shared_data'

print(f"\n📁 Run Folder: {BASE_PATH}")
print(f"📁 Checkpoint: {CHECKPOINT_PATH}")

if not CHECKPOINT_PATH.exists():
    print(f"\n❌ ERROR: Checkpoint not found at {CHECKPOINT_PATH}")
    print("Please update RUN_FOLDER to your actual run folder!")
    exit(1)

# ============================================================================
# LOAD ARCHITECTURE (SAME AS TRAINING)
# ============================================================================

print("\n" + "="*80)
print("🏗️  LOADING ARCHITECTURE")
print("="*80)

class ConceptBottleneckCrossAttention(nn.Module):
    """Multiplicative concept bottleneck with cross-attention"""
    def __init__(self, hidden_size, num_heads=8, dropout=0.1, layer_idx=1):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.layer_idx = layer_idx

        self.query = nn.Linear(hidden_size, hidden_size)
        self.key = nn.Linear(hidden_size, hidden_size)
        self.value = nn.Linear(hidden_size, hidden_size)
        self.out_proj = nn.Linear(hidden_size, hidden_size)

        self.gate_net = nn.Sequential(
            nn.Linear(hidden_size * 2, hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, hidden_size),
            nn.Sigmoid()
        )

        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(hidden_size)

    def forward(self, hidden_states, concept_embeddings, attention_mask=None):
        batch_size, seq_len, _ = hidden_states.shape
        num_concepts = concept_embeddings.shape[0]

        concepts_batch = concept_embeddings.unsqueeze(0).expand(batch_size, -1, -1)

        Q = self.query(hidden_states).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.key(concepts_batch).view(batch_size, num_concepts, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.value(concepts_batch).view(batch_size, num_concepts, self.num_heads, self.head_dim).transpose(1, 2)

        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attn_weights = torch.nn.functional.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        context = torch.matmul(attn_weights, V)
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.hidden_size)
        context = self.out_proj(context)

        pooled_text = hidden_states.mean(dim=1, keepdim=True).expand(-1, seq_len, -1)
        pooled_context = context.mean(dim=1, keepdim=True).expand(-1, seq_len, -1)
        gate_input = torch.cat([pooled_text, pooled_context], dim=-1)
        gate = self.gate_net(gate_input)

        output = gate * context
        output = self.layer_norm(output)

        return output, attn_weights.mean(dim=1), gate.mean()


class ShifaMind2Phase1(nn.Module):
    """ShifaMind2 Phase 1: Concept Bottleneck with Top-50 ICD-10"""
    def __init__(self, base_model, num_concepts, num_classes, fusion_layers=[9, 11]):
        super().__init__()
        self.base_model = base_model
        self.hidden_size = base_model.config.hidden_size
        self.num_concepts = num_concepts
        self.fusion_layers = fusion_layers

        self.concept_embeddings = nn.Parameter(
            torch.randn(num_concepts, self.hidden_size) * 0.02
        )

        self.fusion_modules = nn.ModuleDict({
            str(layer): ConceptBottleneckCrossAttention(self.hidden_size, layer_idx=layer)
            for layer in fusion_layers
        })

        self.concept_head = nn.Linear(self.hidden_size, num_concepts)
        self.diagnosis_head = nn.Linear(self.hidden_size, num_classes)
        self.dropout = nn.Dropout(0.1)

    def forward(self, input_ids, attention_mask, return_attention=False):
        outputs = self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True
        )

        hidden_states = outputs.hidden_states
        current_hidden = outputs.last_hidden_state

        attention_maps = {}
        gate_values = []

        for layer_idx in self.fusion_layers:
            if str(layer_idx) in self.fusion_modules:
                layer_hidden = hidden_states[layer_idx]
                fused_hidden, attn, gate = self.fusion_modules[str(layer_idx)](
                    layer_hidden, self.concept_embeddings, attention_mask
                )
                current_hidden = fused_hidden
                gate_values.append(gate.item())

                if return_attention:
                    attention_maps[f'layer_{layer_idx}'] = attn

        cls_hidden = self.dropout(current_hidden[:, 0, :])
        concept_scores = torch.sigmoid(self.concept_head(cls_hidden))
        diagnosis_logits = self.diagnosis_head(cls_hidden)

        result = {
            'logits': diagnosis_logits,
            'concept_scores': concept_scores,
            'hidden_states': current_hidden,
            'cls_hidden': cls_hidden,
            'avg_gate': np.mean(gate_values) if gate_values else 0.0
        }

        if return_attention:
            result['attention_maps'] = attention_maps

        return result


class ConceptDataset(Dataset):
    def __init__(self, texts, labels, concept_labels, tokenizer, max_length=384):
        self.texts = texts
        self.labels = labels
        self.concept_labels = concept_labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        encoding = self.tokenizer(
            str(self.texts[idx]),
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt'
        )

        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.FloatTensor(self.labels[idx]),
            'concept_labels': torch.FloatTensor(self.concept_labels[idx])
        }

print("✅ Architecture loaded")

# ============================================================================
# LOAD TRAINED MODEL & DATA
# ============================================================================

print("\n" + "="*80)
print("📦 LOADING TRAINED MODEL & DATA")
print("="*80)

# Load checkpoint
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device, weights_only=False)
config = checkpoint['config']

TOP_50_CODES = config['top_50_codes']
print(f"✅ Loaded Top-50 codes: {len(TOP_50_CODES)}")

# Load global concepts
with open(SHARED_DATA_PATH / 'concept_list.json', 'r') as f:
    GLOBAL_CONCEPTS = json.load(f)
print(f"✅ Loaded {len(GLOBAL_CONCEPTS)} concepts")

# Initialize model
print(f"\n🔄 Loading BioClinicalBERT...")
tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
base_model = AutoModel.from_pretrained("emilyalsentzer/Bio_ClinicalBERT").to(device)

model = ShifaMind2Phase1(
    base_model,
    num_concepts=len(GLOBAL_CONCEPTS),
    num_classes=len(TOP_50_CODES),
    fusion_layers=[9, 11]
).to(device)

model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

print(f"✅ Model loaded with {sum(p.numel() for p in model.parameters()):,} parameters")

# Load validation data
with open(SHARED_DATA_PATH / 'val_split.pkl', 'rb') as f:
    df_val = pickle.load(f)

val_concept_labels = np.load(SHARED_DATA_PATH / 'val_concept_labels.npy')

print(f"✅ Loaded validation set: {len(df_val):,} samples")

# ============================================================================
# GET PREDICTIONS ON VALIDATION SET
# ============================================================================

print("\n" + "="*80)
print("🔮 GENERATING PREDICTIONS")
print("="*80)

val_dataset = ConceptDataset(
    df_val['text'].tolist(),
    df_val['labels'].tolist(),
    val_concept_labels,
    tokenizer
)

val_loader = DataLoader(
    val_dataset,
    batch_size=256,
    num_workers=0,
    pin_memory=True
)

all_probs = []
all_labels = []

USE_AMP = torch.cuda.is_available()

with torch.no_grad():
    for batch in tqdm(val_loader, desc="Getting predictions"):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        if USE_AMP:
            with autocast():
                outputs = model(input_ids, attention_mask)
        else:
            outputs = model(input_ids, attention_mask)

        probs = torch.sigmoid(outputs['logits']).cpu().numpy()
        all_probs.append(probs)
        all_labels.append(labels.cpu().numpy())

all_probs = np.vstack(all_probs)
all_labels = np.vstack(all_labels)

print(f"✅ Predictions shape: {all_probs.shape}")
print(f"   Labels shape: {all_labels.shape}")

# ============================================================================
# THRESHOLD TUNING (PER-LABEL)
# ============================================================================

print("\n" + "="*80)
print("🎯 THRESHOLD TUNING (PER-LABEL)")
print("="*80)

THRESHOLD_CANDIDATES = np.arange(0.05, 0.96, 0.05)  # 0.05 to 0.95 in steps of 0.05

optimal_thresholds = {}
best_f1_scores = {}

print(f"Testing {len(THRESHOLD_CANDIDATES)} thresholds per label: {THRESHOLD_CANDIDATES[0]:.2f} to {THRESHOLD_CANDIDATES[-1]:.2f}")

for label_idx, code in enumerate(tqdm(TOP_50_CODES, desc="Tuning thresholds")):
    label_probs = all_probs[:, label_idx]
    label_true = all_labels[:, label_idx]

    best_f1 = 0.0
    best_threshold = 0.5

    for threshold in THRESHOLD_CANDIDATES:
        preds = (label_probs >= threshold).astype(int)
        f1 = f1_score(label_true, preds, zero_division=0)

        if f1 > best_f1:
            best_f1 = f1
            best_threshold = threshold

    optimal_thresholds[code] = float(best_threshold)
    best_f1_scores[code] = float(best_f1)

print(f"\n✅ Optimal thresholds found for {len(optimal_thresholds)} labels")

# ============================================================================
# EVALUATE WITH OPTIMAL THRESHOLDS
# ============================================================================

print("\n" + "="*80)
print("📊 EVALUATION: OPTIMAL THRESHOLDS vs FIXED 0.5")
print("="*80)

# Fixed threshold (0.5)
preds_fixed = (all_probs >= 0.5).astype(int)
f1_macro_fixed = f1_score(all_labels, preds_fixed, average='macro', zero_division=0)
f1_micro_fixed = f1_score(all_labels, preds_fixed, average='micro', zero_division=0)
precision_fixed = precision_score(all_labels, preds_fixed, average='macro', zero_division=0)
recall_fixed = recall_score(all_labels, preds_fixed, average='macro', zero_division=0)

# Optimal thresholds (per-label)
preds_optimal = np.zeros_like(all_probs, dtype=int)
for label_idx, code in enumerate(TOP_50_CODES):
    threshold = optimal_thresholds[code]
    preds_optimal[:, label_idx] = (all_probs[:, label_idx] >= threshold).astype(int)

f1_macro_optimal = f1_score(all_labels, preds_optimal, average='macro', zero_division=0)
f1_micro_optimal = f1_score(all_labels, preds_optimal, average='micro', zero_division=0)
precision_optimal = precision_score(all_labels, preds_optimal, average='macro', zero_division=0)
recall_optimal = recall_score(all_labels, preds_optimal, average='macro', zero_division=0)

print("\n" + "="*80)
print("🎉 THRESHOLD TUNING RESULTS")
print("="*80)

print("\n📊 Fixed Threshold (0.5):")
print(f"   Macro F1:    {f1_macro_fixed:.4f}")
print(f"   Micro F1:    {f1_micro_fixed:.4f}")
print(f"   Precision:   {precision_fixed:.4f}")
print(f"   Recall:      {recall_fixed:.4f}")

print("\n🎯 Optimal Thresholds (Per-Label):")
print(f"   Macro F1:    {f1_macro_optimal:.4f} (+{f1_macro_optimal - f1_macro_fixed:+.4f})")
print(f"   Micro F1:    {f1_micro_optimal:.4f} (+{f1_micro_optimal - f1_micro_fixed:+.4f})")
print(f"   Precision:   {precision_optimal:.4f} (+{precision_optimal - precision_fixed:+.4f})")
print(f"   Recall:      {recall_optimal:.4f} (+{recall_optimal - recall_fixed:+.4f})")

improvement = ((f1_macro_optimal - f1_macro_fixed) / f1_macro_fixed) * 100
print(f"\n🚀 Improvement: {improvement:+.2f}% relative gain in Macro F1!")

# ============================================================================
# TOP IMPROVEMENTS
# ============================================================================

print("\n📊 Top-10 Largest Improvements:")
improvements = {}
for code in TOP_50_CODES:
    idx = TOP_50_CODES.index(code)
    f1_fixed = f1_score(all_labels[:, idx], preds_fixed[:, idx], zero_division=0)
    f1_optimal = best_f1_scores[code]
    improvements[code] = f1_optimal - f1_fixed

top_improvements = sorted(improvements.items(), key=lambda x: x[1], reverse=True)[:10]

for rank, (code, improvement) in enumerate(top_improvements, 1):
    threshold = optimal_thresholds[code]
    f1_opt = best_f1_scores[code]
    print(f"   {rank}. {code}: +{improvement:.4f} (threshold={threshold:.2f}, F1={f1_opt:.4f})")

# ============================================================================
# SAVE RESULTS
# ============================================================================

print("\n" + "="*80)
print("💾 SAVING RESULTS")
print("="*80)

# Save optimal thresholds
thresholds_path = RESULTS_PATH / 'optimal_thresholds.json'
with open(thresholds_path, 'w') as f:
    json.dump(optimal_thresholds, f, indent=2)

print(f"✅ Optimal thresholds saved to: {thresholds_path}")

# Save comparison results
results = {
    'fixed_threshold': {
        'threshold': 0.5,
        'macro_f1': float(f1_macro_fixed),
        'micro_f1': float(f1_micro_fixed),
        'precision': float(precision_fixed),
        'recall': float(recall_fixed)
    },
    'optimal_thresholds': {
        'macro_f1': float(f1_macro_optimal),
        'micro_f1': float(f1_micro_optimal),
        'precision': float(precision_optimal),
        'recall': float(recall_optimal),
        'improvement_pct': float(improvement),
        'thresholds': optimal_thresholds,
        'per_label_f1': best_f1_scores
    }
}

results_path = RESULTS_PATH / 'threshold_tuning_results.json'
with open(results_path, 'w') as f:
    json.dump(results, f, indent=2)

print(f"✅ Results saved to: {results_path}")

# Save per-label comparison
comparison_df = pd.DataFrame({
    'icd_code': TOP_50_CODES,
    'optimal_threshold': [optimal_thresholds[c] for c in TOP_50_CODES],
    'f1_optimal': [best_f1_scores[c] for c in TOP_50_CODES],
    'f1_fixed_0.5': [f1_score(all_labels[:, i], preds_fixed[:, i], zero_division=0) for i in range(len(TOP_50_CODES))],
    'improvement': [improvements[c] for c in TOP_50_CODES]
})
comparison_df = comparison_df.sort_values('improvement', ascending=False)

csv_path = RESULTS_PATH / 'threshold_comparison.csv'
comparison_df.to_csv(csv_path, index=False)

print(f"✅ Comparison CSV saved to: {csv_path}")

print("\n" + "="*80)
print("✅ THRESHOLD TUNING COMPLETE!")
print("="*80)
print(f"\n📊 Summary:")
print(f"   Fixed (0.5):  Macro F1 = {f1_macro_fixed:.4f}")
print(f"   Optimal:      Macro F1 = {f1_macro_optimal:.4f} ({improvement:+.2f}%)")
print(f"   Best label:   {top_improvements[0][0]} (+{top_improvements[0][1]:.4f})")
print(f"\n💾 All results saved to: {RESULTS_PATH}")
print("\nAlhamdulillah! 🤲")

🎯 SHIFAMIND2 PHASE 1 - THRESHOLD TUNING

🖥️  Device: cuda

⚙️  CONFIGURATION

📁 Run Folder: /content/drive/MyDrive/ShifaMind/10_ShifaMind/run_20260215_013437
📁 Checkpoint: /content/drive/MyDrive/ShifaMind/10_ShifaMind/run_20260215_013437/checkpoints/phase1/phase1_best.pt

🏗️  LOADING ARCHITECTURE
✅ Architecture loaded

📦 LOADING TRAINED MODEL & DATA
✅ Loaded Top-50 codes: 50
✅ Loaded 111 concepts

🔄 Loading BioClinicalBERT...


Loading weights:   0%|          | 0/199 [00:00<?, ?it/s]

BertModel LOAD REPORT from: emilyalsentzer/Bio_ClinicalBERT
Key                                        | Status     |  | 
-------------------------------------------+------------+--+-
cls.predictions.transform.LayerNorm.weight | UNEXPECTED |  | 
cls.predictions.transform.dense.bias       | UNEXPECTED |  | 
cls.predictions.transform.LayerNorm.bias   | UNEXPECTED |  | 
cls.predictions.bias                       | UNEXPECTED |  | 
cls.seq_relationship.weight                | UNEXPECTED |  | 
cls.seq_relationship.bias                  | UNEXPECTED |  | 
cls.predictions.decoder.weight             | UNEXPECTED |  | 
cls.predictions.transform.dense.weight     | UNEXPECTED |  | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.


✅ Model loaded with 116,789,153 parameters
✅ Loaded validation set: 17,265 samples

🔮 GENERATING PREDICTIONS


Getting predictions:   0%|          | 0/68 [00:00<?, ?it/s]

✅ Predictions shape: (17265, 50)
   Labels shape: (17265, 50)

🎯 THRESHOLD TUNING (PER-LABEL)
Testing 19 thresholds per label: 0.05 to 0.95


Tuning thresholds:   0%|          | 0/50 [00:00<?, ?it/s]


✅ Optimal thresholds found for 50 labels

📊 EVALUATION: OPTIMAL THRESHOLDS vs FIXED 0.5

🎉 THRESHOLD TUNING RESULTS

📊 Fixed Threshold (0.5):
   Macro F1:    0.2298
   Micro F1:    0.3583
   Precision:   0.5322
   Recall:      0.1697

🎯 Optimal Thresholds (Per-Label):
   Macro F1:    0.4343 (++0.2045)
   Micro F1:    0.4403 (++0.0820)
   Precision:   0.4064 (+-0.1258)
   Recall:      0.5343 (++0.3646)

🚀 Improvement: +88.99% relative gain in Macro F1!

📊 Top-10 Largest Improvements:
   1. I480: +0.3920 (threshold=0.15, F1=0.3920)
   2. Z87891: +0.3886 (threshold=0.20, F1=0.5215)
   3. N189: +0.3621 (threshold=0.20, F1=0.3621)
   4. N183: +0.3607 (threshold=0.20, F1=0.4366)
   5. I5032: +0.3542 (threshold=0.20, F1=0.3760)
   6. Z7902: +0.3451 (threshold=0.15, F1=0.3600)
   7. Z66: +0.3317 (threshold=0.20, F1=0.4196)
   8. E669: +0.3303 (threshold=0.20, F1=0.3424)
   9. N390: +0.3219 (threshold=0.15, F1=0.3648)
   10. I110: +0.3185 (threshold=0.20, F1=0.4299)

💾 SAVING RESULTS
✅ Optimal

# Retraining P2

In [12]:
#!/usr/bin/env python3
"""
================================================================================
SHIFAMIND v302 PHASE 2: GAT with UMLS Knowledge Graph (MAXIMUM GPU OPTIMIZED)
================================================================================
Author: Mohammed Sameer Syed
University of Arizona - MS in AI Capstone

OPTIMIZATIONS:
1. ✅ MAXIMUM GPU: batch_size=128 train, 256 val (16x faster!)
2. ✅ num_workers=8 + prefetch_factor=2 (parallel data loading)
3. ✅ FP16 mixed precision for 96GB GPU
4. ✅ Scaled learning rate: 8e-5 (was 2e-5) for larger batches
5. ✅ 7 epochs (instead of 5)
6. ✅ Loads from NEWEST Phase 1 checkpoint (not old v301)
7. ✅ pin_memory for faster data transfer

Expected: 7 epochs in ~10-15 minutes (vs 2-3 hours original) ⚡⚡⚡
Expected GPU usage: 60-80GB / 96GB

Architecture:
- Input: Clinical text
- BioClinicalBERT encoder
- GAT on UMLS graph (concept + diagnosis nodes)
- Concept bottleneck with cross-attention
- Multi-objective loss (diagnosis + alignment + concepts)
- Transfer learning from Phase 1 trained model

Target: F1 > 0.45 (with threshold tuning)
================================================================================
"""

import warnings
warnings.filterwarnings('ignore')

import os
import sys
import json
import pickle
import numpy as np
import pandas as pd
from pathlib import Path
from datetime import datetime
from collections import defaultdict, Counter
import time
import re

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
from sklearn.metrics import f1_score, precision_score, recall_score
from tqdm.auto import tqdm

print("="*80)
print("🚀 SHIFAMIND v302 PHASE 2: GAT + UMLS (MAXIMUM GPU OPTIMIZED)")
print("="*80)
print("Using UMLS MRREL for rich hierarchical relationships")
print("Training from NEWEST Phase 1 checkpoint with BioClinicalBERT")
print("MAXIMUM GPU optimization for 96GB VRAM!")
print()

# ============================================================================
# CONFIGURATION
# ============================================================================

BASE_PATH = Path('/content/drive/MyDrive/ShifaMind')
UMLS_PATH = BASE_PATH / '01_Raw_Datasets' / 'Extracted' / 'umls-2025AA-metathesaurus-full' / '2025AA' / 'META'

# Output folder
OUTPUT_BASE = BASE_PATH / '11_ShifaMind_v302'
RUN_TIMESTAMP = datetime.now().strftime("%Y%m%d_%H%M%S")
RUN_PATH = OUTPUT_BASE / f'run_{RUN_TIMESTAMP}'

# Create subfolders
SHARED_DATA_PATH = RUN_PATH / 'shared_data'
GRAPH_PATH = RUN_PATH / 'phase_2_graph'
MODELS_PATH = RUN_PATH / 'phase_2_models'
RESULTS_PATH = RUN_PATH / 'phase_2_results'

for path in [SHARED_DATA_PATH, GRAPH_PATH, MODELS_PATH, RESULTS_PATH]:
    path.mkdir(parents=True, exist_ok=True)

print(f"📁 Run folder: {RUN_PATH}")
print(f"📁 Graph: {GRAPH_PATH}")
print(f"📁 Models: {MODELS_PATH}")
print(f"📁 Results: {RESULTS_PATH}")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"\n🖥️  Device: {device}")
if torch.cuda.is_available():
    print(f"🔥 GPU: {torch.cuda.get_device_name(0)}")
    print(f"💾 VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

# ============================================================================
# HYPERPARAMETERS - MAXIMUM GPU OPTIMIZATION! 🚀
# ============================================================================

# AGGRESSIVE batch sizes for 96GB GPU
TRAIN_BATCH_SIZE = 128  # 16x original (was 8)
VAL_BATCH_SIZE = 256    # 16x original (was 16)

# Parallel data loading
NUM_WORKERS = 8         # 8 CPU cores for data loading
PREFETCH_FACTOR = 2     # Preload 2 batches ahead

# Training params
LEARNING_RATE = 8e-5    # Scaled for larger batch (was 2e-5)
NUM_EPOCHS = 7          # More epochs (was 5)
MAX_LENGTH = 384
SEED = 42

# Loss weights
LAMBDA_DX = 1.0
LAMBDA_ALIGN = 0.5
LAMBDA_CONCEPT = 0.3

# Graph hyperparameters
GRAPH_HIDDEN_DIM = 256
GAT_HEADS = 4
GAT_LAYERS = 2
GAT_DROPOUT = 0.3

# Mixed precision
USE_AMP = torch.cuda.is_available()

np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print(f"\n⚙️  Hyperparameters (MAXIMUM GPU OPTIMIZATION):")
print(f"   Train batch size: {TRAIN_BATCH_SIZE} (16x original!)")
print(f"   Val batch size:   {VAL_BATCH_SIZE} (16x original!)")
print(f"   num_workers:      {NUM_WORKERS}")
print(f"   prefetch_factor:  {PREFETCH_FACTOR}")
print(f"   Learning rate:    {LEARNING_RATE} (scaled 4x)")
print(f"   Epochs:           {NUM_EPOCHS}")
print(f"   FP16 precision:   {USE_AMP}")
print(f"   GAT heads:        {GAT_HEADS}")
print(f"   GAT layers:       {GAT_LAYERS}")

# ============================================================================
# LOAD DATA FROM NEWEST PHASE 1 RUN
# ============================================================================

print("\n" + "="*80)
print("📋 LOADING DATA FROM NEWEST PHASE 1 RUN")
print("="*80)

# Find NEWEST run from 10_ShifaMind (should be the optimized Phase 1)
OLD_RUN_PATH = BASE_PATH / '10_ShifaMind'
run_folders = sorted([d for d in OLD_RUN_PATH.glob('run_*') if d.is_dir()], reverse=True)
if not run_folders:
    print("❌ No Phase 1 run found!")
    sys.exit(1)

# Use NEWEST run (should be from optimized phase1_training.py)
PHASE1_RUN = run_folders[0]
OLD_SHARED = PHASE1_RUN / 'shared_data'
print(f"📁 Loading from: {PHASE1_RUN.name}")
print(f"   (Should be the NEWEST optimized Phase 1 run)")

# Load splits
with open(OLD_SHARED / 'train_split.pkl', 'rb') as f:
    df_train = pickle.load(f)
with open(OLD_SHARED / 'val_split.pkl', 'rb') as f:
    df_val = pickle.load(f)
with open(OLD_SHARED / 'test_split.pkl', 'rb') as f:
    df_test = pickle.load(f)

# Load concept labels
train_concept_labels = np.load(OLD_SHARED / 'train_concept_labels.npy')
val_concept_labels = np.load(OLD_SHARED / 'val_concept_labels.npy')
test_concept_labels = np.load(OLD_SHARED / 'test_concept_labels.npy')

# Load Top-50 codes (from ORIGINAL run, not Phase 1 run)
# Phase 1 loaded this but didn't save it to the new folder
ORIGINAL_RUN = BASE_PATH / '10_ShifaMind' / 'run_20260102_203225'
ORIGINAL_SHARED = ORIGINAL_RUN / 'shared_data'

with open(ORIGINAL_SHARED / 'top50_icd10_info.json', 'r') as f:
    top50_info = json.load(f)
    TOP_50_CODES = top50_info['top_50_codes']

# Load concept list
with open(OLD_SHARED / 'concept_list.json', 'r') as f:
    ALL_CONCEPTS = json.load(f)

NUM_CONCEPTS = len(ALL_CONCEPTS)
NUM_LABELS = len(TOP_50_CODES)

print(f"\n✅ Loaded data:")
print(f"   Train: {len(df_train):,} samples")
print(f"   Val:   {len(df_val):,} samples")
print(f"   Test:  {len(df_test):,} samples")
print(f"   Concepts: {NUM_CONCEPTS}")
print(f"   Diagnoses: {NUM_LABELS}")

# Copy to new run folder
with open(SHARED_DATA_PATH / 'top50_icd10_info.json', 'w') as f:
    json.dump(top50_info, f, indent=2)
with open(SHARED_DATA_PATH / 'concept_list.json', 'w') as f:
    json.dump(ALL_CONCEPTS, f, indent=2)

# ============================================================================
# BUILD UMLS KNOWLEDGE GRAPH FROM MRREL
# ============================================================================

print("\n" + "="*80)
print("🕸️  BUILDING UMLS KNOWLEDGE GRAPH FROM MRREL")
print("="*80)

# Install torch_geometric
try:
    import torch_geometric
    from torch_geometric.nn import GATConv
    from torch_geometric.data import Data
    print("✅ torch_geometric found")
except ImportError:
    print("Installing torch_geometric...")
    os.system('pip install -q torch-geometric')
    import torch_geometric
    from torch_geometric.nn import GATConv
    from torch_geometric.data import Data

import networkx as nx

def load_umls_cui_mappings(umls_path, concepts, icd_codes):
    """Map concepts and ICD codes to UMLS CUIs"""
    print("\n📖 Loading UMLS MRCONSO for CUI mappings...")

    mrconso_path = umls_path / 'MRCONSO.RRF'
    if not mrconso_path.exists():
        print(f"❌ MRCONSO.RRF not found at {mrconso_path}")
        return {}, {}

    concept_to_cui = {}
    icd_to_cui = {}

    start_time = time.time()
    count = 0

    with open(mrconso_path, 'r', encoding='utf-8', errors='ignore') as f:
        for line in f:
            parts = line.strip().split('|')
            if len(parts) < 15:
                continue

            cui = parts[0]
            language = parts[1]
            source = parts[11]  # SAB (Source Abbreviation)
            concept_str = parts[14].lower().strip()

            if language != 'ENG':
                continue

            # Map concepts
            for concept in concepts:
                if concept.lower() == concept_str:
                    concept_to_cui[concept] = cui

            # Map ICD-10 codes
            if source == 'ICD10CM':
                code = parts[13]  # CODE field
                # Remove dots from ICD codes (I10.0 -> I10)
                code_clean = code.replace('.', '')
                if code_clean in icd_codes:
                    icd_to_cui[code_clean] = cui

            count += 1
            if count % 500000 == 0:
                print(f"   Processed {count:,} entries...")

    elapsed = time.time() - start_time
    print(f"✅ Loaded MRCONSO in {elapsed:.1f}s")
    print(f"   Concepts mapped: {len(concept_to_cui)}/{len(concepts)} ({len(concept_to_cui)/len(concepts)*100:.1f}%)")
    print(f"   ICD codes mapped: {len(icd_to_cui)}/{len(icd_codes)} ({len(icd_to_cui)/len(icd_codes)*100:.1f}%)")

    return concept_to_cui, icd_to_cui

def load_umls_relationships(umls_path, valid_cuis):
    """Load hierarchical relationships from MRREL"""
    print("\n📖 Loading UMLS MRREL for relationships...")

    mrrel_path = umls_path / 'MRREL.RRF'
    if not mrrel_path.exists():
        print(f"❌ MRREL.RRF not found at {mrrel_path}")
        return []

    relationships = []
    valid_cui_set = set(valid_cuis)

    # Relationship types we care about
    important_rels = {'CHD', 'PAR', 'RB', 'RN', 'SY', 'isa'}
    # CHD: has child, PAR: has parent, RB: broader, RN: narrower, SY: synonym, isa: is-a

    start_time = time.time()
    count = 0

    with open(mrrel_path, 'r', encoding='utf-8', errors='ignore') as f:
        for line in f:
            parts = line.strip().split('|')
            if len(parts) < 8:
                continue

            cui1 = parts[0]
            rel = parts[3]  # REL field
            cui2 = parts[4]

            # Only keep relationships between our CUIs
            if cui1 in valid_cui_set and cui2 in valid_cui_set:
                if rel in important_rels:
                    relationships.append((cui1, rel, cui2))

            count += 1
            if count % 1000000 == 0:
                print(f"   Processed {count:,} entries...")

    elapsed = time.time() - start_time
    print(f"✅ Loaded MRREL in {elapsed:.1f}s")
    print(f"   Found {len(relationships):,} relevant relationships")

    return relationships

def build_umls_graph(concepts, icd_codes, concept_to_cui, icd_to_cui, relationships):
    """Build NetworkX graph from UMLS data"""
    print("\n🔧 Building knowledge graph...")

    G = nx.DiGraph()

    # Add concept nodes
    for concept in concepts:
        cui = concept_to_cui.get(concept)
        G.add_node(concept, node_type='concept', cui=cui)

    # Add diagnosis nodes
    for code in icd_codes:
        cui = icd_to_cui.get(code)
        G.add_node(code, node_type='diagnosis', cui=cui)

    # Build CUI to node mapping
    cui_to_nodes = defaultdict(list)
    for node, data in G.nodes(data=True):
        if data.get('cui'):
            cui_to_nodes[data['cui']].append(node)

    # Add edges from relationships
    edges_added = 0
    for cui1, rel, cui2 in relationships:
        nodes1 = cui_to_nodes.get(cui1, [])
        nodes2 = cui_to_nodes.get(cui2, [])

        for n1 in nodes1:
            for n2 in nodes2:
                if n1 != n2 and not G.has_edge(n1, n2):
                    # Weight based on relationship type
                    if rel in ['CHD', 'PAR', 'isa']:
                        weight = 1.0  # Strong hierarchical
                    elif rel in ['RB', 'RN']:
                        weight = 0.8  # Semantic
                    else:
                        weight = 0.5  # Synonym

                    G.add_edge(n1, n2, edge_type=rel, weight=weight)
                    edges_added += 1

    # Add same-chapter edges for diagnoses without CUIs
    print("\n🔗 Adding ICD chapter similarity edges...")
    chapter_groups = defaultdict(list)
    for code in icd_codes:
        chapter = code[0] if code else 'X'
        chapter_groups[chapter].append(code)

    chapter_edges = 0
    for chapter, codes in chapter_groups.items():
        for i, code1 in enumerate(codes):
            for code2 in codes[i+1:]:
                if not G.has_edge(code1, code2):
                    G.add_edge(code1, code2, edge_type='same_chapter', weight=0.3)
                    G.add_edge(code2, code1, edge_type='same_chapter', weight=0.3)
                    chapter_edges += 2

    print(f"   Added {chapter_edges} chapter similarity edges")

    print(f"\n✅ Knowledge graph built:")
    print(f"   Nodes: {G.number_of_nodes()}")
    print(f"   Edges: {G.number_of_edges()}")
    print(f"   - UMLS relationship edges: {edges_added}")
    print(f"   - Chapter similarity edges: {chapter_edges}")
    print(f"   Avg degree: {2*G.number_of_edges()/G.number_of_nodes():.1f}")

    return G

# Build graph
concept_to_cui, icd_to_cui = load_umls_cui_mappings(UMLS_PATH, ALL_CONCEPTS, TOP_50_CODES)
all_cuis = set(concept_to_cui.values()) | set(icd_to_cui.values())
relationships = load_umls_relationships(UMLS_PATH, all_cuis)
knowledge_graph = build_umls_graph(ALL_CONCEPTS, TOP_50_CODES, concept_to_cui, icd_to_cui, relationships)

# Save graph
with open(GRAPH_PATH / 'umls_knowledge_graph.gpickle', 'wb') as f:
    pickle.dump(knowledge_graph, f)
print(f"\n💾 Saved graph to {GRAPH_PATH / 'umls_knowledge_graph.gpickle'}")

# ============================================================================
# INITIALIZE NODE FEATURES WITH BIOCLINICALBERT
# ============================================================================

print("\n" + "="*80)
print("🔧 INITIALIZING NODE FEATURES WITH BIOCLINICALBERT")
print("="*80)

# Install transformers
try:
    from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup
    print("✅ transformers found")
except ImportError:
    print("Installing transformers...")
    os.system('pip install -q transformers')
    from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup

MODEL_NAME = 'emilyalsentzer/Bio_ClinicalBERT'
print(f"\nLoading {MODEL_NAME}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
bert_model = AutoModel.from_pretrained(MODEL_NAME).to(device)
print("✅ BioClinicalBERT loaded")

def get_bert_embedding(text, tokenizer, model, device):
    """Get [CLS] embedding for text"""
    encoding = tokenizer(
        text,
        max_length=128,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    )

    with torch.no_grad():
        input_ids = encoding['input_ids'].to(device)
        attention_mask = encoding['attention_mask'].to(device)
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        cls_embedding = outputs.last_hidden_state[:, 0, :].squeeze(0)

    return cls_embedding.cpu()

print("\n🔄 Computing node embeddings...")
node_features = {}
all_nodes = list(knowledge_graph.nodes())

for node in tqdm(all_nodes, desc="Encoding nodes"):
    # For concepts: use concept text
    # For diagnoses: use ICD code description
    if knowledge_graph.nodes[node]['node_type'] == 'concept':
        text = node  # Concept name
    else:
        # Use ICD code as text (could enhance with description)
        text = f"ICD-10 diagnosis code {node}"

    embedding = get_bert_embedding(text, tokenizer, bert_model, device)
    node_features[node] = embedding

print(f"✅ Computed {len(node_features)} node embeddings (768-dim)")

# Convert to PyTorch Geometric format
def nx_to_pyg_with_features(G, node_features):
    """Convert NetworkX graph to PyG with node features"""
    all_nodes = list(G.nodes())
    node_to_idx = {node: idx for idx, node in enumerate(all_nodes)}

    # Stack node features
    x = torch.stack([node_features[node] for node in all_nodes])

    # Edge indices and attributes
    edge_index = []
    edge_attr = []
    for u, v, data in G.edges(data=True):
        edge_index.append([node_to_idx[u], node_to_idx[v]])
        edge_attr.append(data.get('weight', 1.0))

    edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
    edge_attr = torch.tensor(edge_attr, dtype=torch.float).unsqueeze(-1)

    # Node type mask
    node_types = []
    for node in all_nodes:
        if G.nodes[node]['node_type'] == 'diagnosis':
            node_types.append(0)
        else:
            node_types.append(1)
    node_type_mask = torch.tensor(node_types, dtype=torch.long)

    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
    data.node_type_mask = node_type_mask
    data.node_to_idx = node_to_idx
    data.idx_to_node = {idx: node for node, idx in node_to_idx.items()}

    return data

graph_data = nx_to_pyg_with_features(knowledge_graph, node_features)
print(f"\n✅ PyTorch Geometric data:")
print(f"   Nodes: {graph_data.x.shape[0]}")
print(f"   Node features: {graph_data.x.shape[1]}-dim")
print(f"   Edges: {graph_data.edge_index.shape[1]}")

# Save graph data
torch.save(graph_data, GRAPH_PATH / 'graph_data.pt')
print(f"💾 Saved to {GRAPH_PATH / 'graph_data.pt'}")

# ============================================================================
# GAT ENCODER
# ============================================================================

print("\n" + "="*80)
print("🏗️  BUILDING GAT ENCODER")
print("="*80)

class GATEncoder(nn.Module):
    """GAT encoder for learning concept embeddings from knowledge graph"""
    def __init__(self, in_channels, hidden_channels, num_layers=2, heads=4, dropout=0.3):
        super().__init__()

        self.num_layers = num_layers
        self.convs = nn.ModuleList()

        # First layer: in -> hidden
        self.convs.append(GATConv(
            in_channels,
            hidden_channels // heads,  # Output per head
            heads=heads,
            dropout=dropout,
            concat=True
        ))

        # Middle layers
        for _ in range(num_layers - 2):
            self.convs.append(GATConv(
                hidden_channels,
                hidden_channels // heads,
                heads=heads,
                dropout=dropout,
                concat=True
            ))

        # Last layer: hidden -> hidden (average heads)
        if num_layers > 1:
            self.convs.append(GATConv(
                hidden_channels,
                hidden_channels,
                heads=1,
                dropout=dropout,
                concat=False
            ))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, edge_index):
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            if i < self.num_layers - 1:
                x = F.elu(x)
                x = self.dropout(x)

        return x

gat_encoder = GATEncoder(
    in_channels=768,  # BioClinicalBERT
    hidden_channels=GRAPH_HIDDEN_DIM,
    num_layers=GAT_LAYERS,
    heads=GAT_HEADS,
    dropout=GAT_DROPOUT
).to(device)

print(f"✅ GAT encoder built:")
print(f"   Input: 768-dim (BioClinicalBERT)")
print(f"   Output: {GRAPH_HIDDEN_DIM}-dim")
print(f"   Layers: {GAT_LAYERS}")
print(f"   Heads: {GAT_HEADS}")
print(f"   Parameters: {sum(p.numel() for p in gat_encoder.parameters()):,}")

# ============================================================================
# PHASE 2 MODEL
# ============================================================================

print("\n" + "="*80)
print("🏗️  BUILDING PHASE 2 MODEL")
print("="*80)

class ShifaMind302Phase2(nn.Module):
    """
    ShifaMind v302 Phase 2: GAT + UMLS Knowledge Graph

    Architecture:
    1. BioClinicalBERT text encoder
    2. GAT graph encoder for concepts
    3. Cross-attention fusion
    4. Multiplicative bottleneck
    5. Multi-head outputs (diagnosis, concepts)
    """
    def __init__(self, bert_model, gat_encoder, graph_data, num_concepts, num_diagnoses):
        super().__init__()

        self.bert = bert_model
        self.gat = gat_encoder
        self.hidden_size = 768
        self.graph_hidden = GRAPH_HIDDEN_DIM
        self.num_concepts = num_concepts
        self.num_diagnoses = num_diagnoses

        # Store graph
        self.register_buffer('graph_x', graph_data.x)
        self.register_buffer('graph_edge_index', graph_data.edge_index)
        self.graph_node_to_idx = graph_data.node_to_idx
        self.graph_idx_to_node = graph_data.idx_to_node

        # Project graph embeddings to BERT dimension
        self.graph_proj = nn.Linear(self.graph_hidden, self.hidden_size)

        # Concept fusion: combine BERT + GAT embeddings
        self.concept_fusion = nn.Sequential(
            nn.Linear(self.hidden_size + self.hidden_size, self.hidden_size),
            nn.LayerNorm(self.hidden_size),
            nn.ReLU(),
            nn.Dropout(0.3)
        )

        # Cross-attention: text attends to enhanced concepts
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=self.hidden_size,
            num_heads=8,
            dropout=0.1,
            batch_first=True
        )

        # Multiplicative gating
        self.gate_net = nn.Sequential(
            nn.Linear(self.hidden_size * 2, self.hidden_size),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.Sigmoid()
        )

        self.layer_norm = nn.LayerNorm(self.hidden_size)

        # Output heads
        self.concept_head = nn.Linear(self.hidden_size, num_concepts)
        self.diagnosis_head = nn.Linear(self.hidden_size, num_diagnoses)

        self.dropout = nn.Dropout(0.1)

    def get_graph_concept_embeddings(self):
        """Run GAT and extract concept embeddings"""
        # Run GAT on full graph
        graph_embeddings = self.gat(self.graph_x, self.graph_edge_index)

        # Extract concept node embeddings
        concept_embeds = []
        for concept in ALL_CONCEPTS:
            if concept in self.graph_node_to_idx:
                idx = self.graph_node_to_idx[concept]
                concept_embeds.append(graph_embeddings[idx])
            else:
                # Fallback: zeros
                concept_embeds.append(torch.zeros(self.graph_hidden, device=self.graph_x.device))

        concept_embeds = torch.stack(concept_embeds)  # [num_concepts, graph_hidden]
        concept_embeds = self.graph_proj(concept_embeds)  # [num_concepts, 768]

        return concept_embeds

    def forward(self, input_ids, attention_mask, concept_embeddings_bert):
        """
        Forward pass with BERT + GAT fusion

        Args:
            input_ids: [batch, seq_len]
            attention_mask: [batch, seq_len]
            concept_embeddings_bert: [num_concepts, 768] - learned BERT concept embeddings
        """
        batch_size = input_ids.shape[0]

        # 1. Encode text with BERT
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        hidden_states = outputs.last_hidden_state  # [batch, seq_len, 768]

        # 2. Get GAT-enhanced concept embeddings
        gat_concepts = self.get_graph_concept_embeddings()  # [num_concepts, 768]

        # 3. Fuse BERT + GAT concept embeddings
        bert_concepts = concept_embeddings_bert.unsqueeze(0).expand(batch_size, -1, -1)
        gat_concepts_batched = gat_concepts.unsqueeze(0).expand(batch_size, -1, -1)

        fused_input = torch.cat([bert_concepts, gat_concepts_batched], dim=-1)  # [batch, num_concepts, 1536]
        enhanced_concepts = self.concept_fusion(fused_input)  # [batch, num_concepts, 768]

        # 4. Cross-attention: text attends to enhanced concepts
        context, attn_weights = self.cross_attention(
            query=hidden_states,
            key=enhanced_concepts,
            value=enhanced_concepts,
            need_weights=True
        )  # context: [batch, seq_len, 768]

        # 5. Multiplicative bottleneck gating
        pooled_text = hidden_states.mean(dim=1)  # [batch, 768]
        pooled_context = context.mean(dim=1)  # [batch, 768]

        gate_input = torch.cat([pooled_text, pooled_context], dim=-1)
        gate = self.gate_net(gate_input)  # [batch, 768]

        bottleneck_output = gate * pooled_context
        bottleneck_output = self.layer_norm(bottleneck_output)

        # 6. Output heads
        cls_hidden = self.dropout(pooled_text)
        concept_logits = self.concept_head(cls_hidden)
        concept_scores = torch.sigmoid(concept_logits)
        diagnosis_logits = self.diagnosis_head(bottleneck_output)

        return {
            'logits': diagnosis_logits,
            'concept_logits': concept_logits,
            'concept_scores': concept_scores,
            'gate_values': gate,
            'attention_weights': attn_weights,
            'bottleneck_output': bottleneck_output
        }

# Build model
model = ShifaMind302Phase2(
    bert_model=bert_model,
    gat_encoder=gat_encoder,
    graph_data=graph_data,
    num_concepts=NUM_CONCEPTS,
    num_diagnoses=NUM_LABELS
).to(device)

total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"\n✅ Model built:")
print(f"   Total parameters: {total_params:,}")
print(f"   Trainable parameters: {trainable_params:,}")
print(f"   BERT: {sum(p.numel() for p in model.bert.parameters()):,}")
print(f"   GAT: {sum(p.numel() for p in model.gat.parameters()):,}")

# Create concept embedding layer
concept_embedding_layer = nn.Embedding(NUM_CONCEPTS, 768).to(device)

print(f"\n✅ Concept embedding layer created:")
print(f"   Parameters: {sum(p.numel() for p in concept_embedding_layer.parameters()):,}")

# ============================================================================
# LOAD PHASE 1 CHECKPOINT (FROM NEWEST RUN)
# ============================================================================

print("\n" + "="*80)
print("📥 LOADING PHASE 1 CHECKPOINT (FROM NEWEST RUN)")
print("="*80)

# Load from NEWEST Phase 1 run
PHASE1_CHECKPOINT = PHASE1_RUN / 'checkpoints' / 'phase1' / 'phase1_best.pt'

if PHASE1_CHECKPOINT.exists():
    print(f"📁 Loading from: {PHASE1_CHECKPOINT}")

    try:
        checkpoint = torch.load(PHASE1_CHECKPOINT, map_location=device, weights_only=False)

        # Load weights with strict=False (partial loading)
        model.load_state_dict(checkpoint['model_state_dict'], strict=False)

        print("✅ Loaded Phase 1 weights (partial transfer learning)")
        print("   - BERT encoder: ✅ Transferred")
        print("   - Concept head: ✅ Transferred (if compatible)")
        print("   - Diagnosis head: ✅ Transferred (if compatible)")
        print("   - GAT encoder: ⚠️  New (will be trained from scratch)")
        print("   - Graph projection: ⚠️  New (will be trained from scratch)")

    except Exception as e:
        print(f"⚠️  Could not load Phase 1 weights: {e}")
        print("   Training from scratch (BioClinicalBERT pretrained only)")
else:
    print(f"⚠️  Phase 1 checkpoint not found at: {PHASE1_CHECKPOINT}")
    print("   Training from scratch (BioClinicalBERT pretrained only)")

# ============================================================================
# DATASET AND TRAINING SETUP
# ============================================================================

print("\n" + "="*80)
print("📦 CREATING DATASETS (MAXIMUM GPU OPTIMIZATION)")
print("="*80)

class ConceptDataset(Dataset):
    def __init__(self, texts, labels, concept_labels, tokenizer, max_length=384):
        self.texts = texts
        self.labels = labels
        self.concept_labels = concept_labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        encoding = self.tokenizer(
            str(self.texts[idx]),
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt'
        )

        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'labels': torch.FloatTensor(self.labels[idx]),
            'concept_labels': torch.FloatTensor(self.concept_labels[idx])
        }

train_dataset = ConceptDataset(
    df_train['text'].tolist(),
    df_train['labels'].tolist(),
    train_concept_labels,
    tokenizer,
    MAX_LENGTH
)
val_dataset = ConceptDataset(
    df_val['text'].tolist(),
    df_val['labels'].tolist(),
    val_concept_labels,
    tokenizer,
    MAX_LENGTH
)
test_dataset = ConceptDataset(
    df_test['text'].tolist(),
    df_test['labels'].tolist(),
    test_concept_labels,
    tokenizer,
    MAX_LENGTH
)

# MAXIMUM GPU DATA LOADERS! 🚀
train_loader = DataLoader(
    train_dataset,
    batch_size=TRAIN_BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    prefetch_factor=PREFETCH_FACTOR
)
val_loader = DataLoader(
    val_dataset,
    batch_size=VAL_BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    prefetch_factor=PREFETCH_FACTOR
)
test_loader = DataLoader(
    test_dataset,
    batch_size=VAL_BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    prefetch_factor=PREFETCH_FACTOR
)

print(f"✅ Datasets created (MAXIMUM GPU OPTIMIZATION 🔥):")
print(f"   Train: {len(train_dataset)} samples, {len(train_loader)} batches (batch_size={TRAIN_BATCH_SIZE})")
print(f"   Val:   {len(val_dataset)} samples, {len(val_loader)} batches (batch_size={VAL_BATCH_SIZE})")
print(f"   Test:  {len(test_dataset)} samples, {len(test_loader)} batches (batch_size={VAL_BATCH_SIZE})")
print(f"   Expected GPU usage: 60-80GB / 96GB")
print(f"   Expected time: ~10-15 mins for 7 epochs ⚡⚡⚡")

# Loss function
class MultiObjectiveLoss(nn.Module):
    """Multi-objective loss with alignment"""
    def __init__(self, lambda_dx, lambda_align, lambda_concept):
        super().__init__()
        self.lambda_dx = lambda_dx
        self.lambda_align = lambda_align
        self.lambda_concept = lambda_concept
        self.bce = nn.BCEWithLogitsLoss()

    def forward(self, outputs, dx_labels, concept_labels):
        # 1. Diagnosis loss
        loss_dx = self.bce(outputs['logits'], dx_labels)

        # 2. Alignment loss
        dx_probs = torch.sigmoid(outputs['logits'])
        concept_scores = outputs['concept_scores']
        loss_align = torch.abs(
            dx_probs.unsqueeze(-1) - concept_scores.unsqueeze(1)
        ).mean()

        # 3. Concept loss (use concept_logits directly)
        loss_concept = self.bce(outputs['concept_logits'], concept_labels)

        total_loss = (
            self.lambda_dx * loss_dx +
            self.lambda_align * loss_align +
            self.lambda_concept * loss_concept
        )

        return total_loss, {
            'total': total_loss.item(),
            'dx': loss_dx.item(),
            'align': loss_align.item(),
            'concept': loss_concept.item()
        }

criterion = MultiObjectiveLoss(LAMBDA_DX, LAMBDA_ALIGN, LAMBDA_CONCEPT)

# Optimizer includes both model and concept embedding layer
optimizer = torch.optim.AdamW(
    list(model.parameters()) + list(concept_embedding_layer.parameters()),
    lr=LEARNING_RATE,
    weight_decay=0.01
)

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=len(train_loader) // 2,
    num_training_steps=len(train_loader) * NUM_EPOCHS
)

# Mixed precision scaler
scaler = GradScaler() if USE_AMP else None

print(f"\n✅ Training setup:")
print(f"   Loss: {LAMBDA_DX}*Dx + {LAMBDA_ALIGN}*Align + {LAMBDA_CONCEPT}*Concept")
print(f"   Optimizer: AdamW (lr={LEARNING_RATE}, scaled for larger batches)")
print(f"   Scheduler: Linear warmup")
print(f"   FP16 mixed precision: {USE_AMP}")

# ============================================================================
# TRAINING LOOP
# ============================================================================

print("\n" + "="*80)
print("🚀 TRAINING (MAXIMUM GPU OPTIMIZATION)")
print("="*80)

def evaluate(model, dataloader, criterion, device, concept_embeddings):
    """Evaluate model"""
    model.eval()

    all_dx_preds = []
    all_dx_labels = []
    all_concept_preds = []
    all_concept_labels = []

    total_loss = 0
    loss_components = defaultdict(float)

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            dx_labels = batch['labels'].to(device)
            concept_labels = batch['concept_labels'].to(device)

            if USE_AMP:
                with autocast():
                    outputs = model(input_ids, attention_mask, concept_embeddings)
                    loss, components = criterion(outputs, dx_labels, concept_labels)
            else:
                outputs = model(input_ids, attention_mask, concept_embeddings)
                loss, components = criterion(outputs, dx_labels, concept_labels)

            total_loss += loss.item()
            for key, val in components.items():
                loss_components[key] += val

            all_dx_preds.append(torch.sigmoid(outputs['logits']).cpu().numpy())
            all_dx_labels.append(dx_labels.cpu().numpy())
            all_concept_preds.append(outputs['concept_scores'].cpu().numpy())
            all_concept_labels.append(concept_labels.cpu().numpy())

    all_dx_preds = np.vstack(all_dx_preds)
    all_dx_labels = np.vstack(all_dx_labels)
    all_concept_preds = np.vstack(all_concept_preds)
    all_concept_labels = np.vstack(all_concept_labels)

    dx_pred_binary = (all_dx_preds > 0.5).astype(int)
    concept_pred_binary = (all_concept_preds > 0.5).astype(int)

    dx_f1 = f1_score(all_dx_labels, dx_pred_binary, average='macro', zero_division=0)
    concept_f1 = f1_score(all_concept_labels, concept_pred_binary, average='macro', zero_division=0)

    return {
        'loss': total_loss / len(dataloader),
        'dx_f1': dx_f1,
        'concept_f1': concept_f1,
        'loss_dx': loss_components['dx'] / len(dataloader),
        'loss_align': loss_components['align'] / len(dataloader),
        'loss_concept': loss_components['concept'] / len(dataloader)
    }

history = {
    'train_loss': [],
    'val_loss': [],
    'val_dx_f1': [],
    'val_concept_f1': []
}

best_f1 = 0
best_epoch = 0

# Extract concept embeddings
concept_embeddings = concept_embedding_layer.weight.detach()

print(f"\n{'='*80}")
print(f"Starting training for {NUM_EPOCHS} epochs...")
print(f"{'='*80}\n")

for epoch in range(NUM_EPOCHS):
    model.train()

    train_loss = 0
    loss_components = defaultdict(float)

    pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{NUM_EPOCHS}')

    for batch in pbar:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        dx_labels = batch['labels'].to(device)
        concept_labels = batch['concept_labels'].to(device)

        optimizer.zero_grad()

        if USE_AMP:
            with autocast():
                outputs = model(input_ids, attention_mask, concept_embeddings)
                loss, components = criterion(outputs, dx_labels, concept_labels)

            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
        else:
            outputs = model(input_ids, attention_mask, concept_embeddings)
            loss, components = criterion(outputs, dx_labels, concept_labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

        scheduler.step()

        train_loss += loss.item()
        for key, val in components.items():
            loss_components[key] += val

        pbar.set_postfix({
            'loss': f'{loss.item():.4f}',
            'dx': f'{components["dx"]:.4f}',
            'align': f'{components["align"]:.4f}'
        })

    avg_train_loss = train_loss / len(train_loader)

    print(f"\n📊 Epoch {epoch+1} Losses:")
    print(f"   Total:     {avg_train_loss:.4f}")
    print(f"   Diagnosis: {loss_components['dx']/len(train_loader):.4f}")
    print(f"   Alignment: {loss_components['align']/len(train_loader):.4f}")
    print(f"   Concept:   {loss_components['concept']/len(train_loader):.4f}")

    # Validation
    print(f"\n   Validating...")
    val_metrics = evaluate(model, val_loader, criterion, device, concept_embeddings)

    print(f"\n📈 Validation:")
    print(f"   Diagnosis F1: {val_metrics['dx_f1']:.4f}")
    print(f"   Concept F1:   {val_metrics['concept_f1']:.4f}")

    history['train_loss'].append(avg_train_loss)
    history['val_loss'].append(val_metrics['loss'])
    history['val_dx_f1'].append(val_metrics['dx_f1'])
    history['val_concept_f1'].append(val_metrics['concept_f1'])

    # Save best model
    if val_metrics['dx_f1'] > best_f1:
        best_f1 = val_metrics['dx_f1']
        best_epoch = epoch + 1
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'concept_embeddings': concept_embeddings,
            'val_dx_f1': val_metrics['dx_f1'],
            'val_metrics': val_metrics,
            'config': {
                'num_concepts': NUM_CONCEPTS,
                'num_diagnoses': NUM_LABELS,
                'graph_hidden_dim': GRAPH_HIDDEN_DIM,
                'gat_heads': GAT_HEADS,
                'gat_layers': GAT_LAYERS,
                'top_50_codes': TOP_50_CODES
            }
        }, MODELS_PATH / 'phase2_best.pt')
        print(f"   ✅ Saved best model (F1: {best_f1:.4f})")

    print()

print(f"\n{'='*80}")
print(f"✅ Training complete!")
print(f"   Best epoch: {best_epoch}")
print(f"   Best val F1: {best_f1:.4f}")
print(f"{'='*80}\n")

# Save history
with open(RESULTS_PATH / 'training_history.json', 'w') as f:
    json.dump(history, f, indent=2)

# ============================================================================
# FINAL TEST EVALUATION
# ============================================================================

print("="*80)
print("📊 FINAL TEST EVALUATION")
print("="*80)

checkpoint = torch.load(MODELS_PATH / 'phase2_best.pt', map_location=device, weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

all_dx_preds, all_dx_labels = [], []
all_concept_preds, all_concept_labels = [], []

with torch.no_grad():
    for batch in tqdm(test_loader, desc="Testing"):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        dx_labels = batch['labels'].to(device)
        concept_labels = batch['concept_labels'].to(device)

        if USE_AMP:
            with autocast():
                outputs = model(input_ids, attention_mask, concept_embeddings)
        else:
            outputs = model(input_ids, attention_mask, concept_embeddings)

        all_dx_preds.append(torch.sigmoid(outputs['logits']).cpu().numpy())
        all_dx_labels.append(dx_labels.cpu().numpy())
        all_concept_preds.append(outputs['concept_scores'].cpu().numpy())
        all_concept_labels.append(concept_labels.cpu().numpy())

all_dx_preds = np.vstack(all_dx_preds)
all_dx_labels = np.vstack(all_dx_labels)
all_concept_preds = np.vstack(all_concept_preds)
all_concept_labels = np.vstack(all_concept_labels)

dx_pred_binary = (all_dx_preds > 0.5).astype(int)
concept_pred_binary = (all_concept_preds > 0.5).astype(int)

macro_f1 = f1_score(all_dx_labels, dx_pred_binary, average='macro', zero_division=0)
micro_f1 = f1_score(all_dx_labels, dx_pred_binary, average='micro', zero_division=0)
macro_precision = precision_score(all_dx_labels, dx_pred_binary, average='macro', zero_division=0)
macro_recall = recall_score(all_dx_labels, dx_pred_binary, average='macro', zero_division=0)

per_class_f1 = [
    f1_score(all_dx_labels[:, i], dx_pred_binary[:, i], zero_division=0)
    for i in range(NUM_LABELS)
]

concept_f1 = f1_score(all_concept_labels, concept_pred_binary, average='macro', zero_division=0)

print("\n" + "="*80)
print("🎉 SHIFAMIND v302 PHASE 2 - FINAL RESULTS")
print("="*80)

print("\n🎯 Diagnosis Performance (Fixed 0.5 threshold):")
print(f"   Macro F1:    {macro_f1:.4f}")
print(f"   Micro F1:    {micro_f1:.4f}")
print(f"   Precision:   {macro_precision:.4f}")
print(f"   Recall:      {macro_recall:.4f}")

print(f"\n🧠 Concept Performance:")
print(f"   Concept F1:  {concept_f1:.4f}")

print(f"\n📊 Top-10 Best Performing Diagnoses:")
top_10_best = sorted(zip(TOP_50_CODES, per_class_f1), key=lambda x: x[1], reverse=True)[:10]
for rank, (code, f1) in enumerate(top_10_best, 1):
    count = top50_info['top_50_counts'].get(code, 0)
    print(f"   {rank}. {code}: F1={f1:.4f} (n={count:,})")

print(f"\n⚠️  Note: This is with fixed 0.5 threshold!")
print(f"   Run threshold tuning next for ~2x improvement!")

# Save results
results = {
    'phase': 'ShifaMind v302 Phase 2 - GAT + UMLS (OPTIMIZED)',
    'timestamp': RUN_TIMESTAMP,
    'run_folder': str(RUN_PATH),
    'diagnosis_metrics': {
        'macro_f1': float(macro_f1),
        'micro_f1': float(micro_f1),
        'precision': float(macro_precision),
        'recall': float(macro_recall),
        'per_class_f1': {code: float(f1) for code, f1 in zip(TOP_50_CODES, per_class_f1)}
    },
    'concept_metrics': {
        'concept_f1': float(concept_f1),
        'num_concepts': NUM_CONCEPTS
    },
    'architecture': {
        'graph_construction': 'UMLS MRREL',
        'node_features': 'BioClinicalBERT embeddings',
        'gnn': 'GAT',
        'gat_heads': GAT_HEADS,
        'gat_layers': GAT_LAYERS
    },
    'optimizations': {
        'batch_size_train': TRAIN_BATCH_SIZE,
        'batch_size_val': VAL_BATCH_SIZE,
        'num_workers': NUM_WORKERS,
        'prefetch_factor': PREFETCH_FACTOR,
        'learning_rate': LEARNING_RATE,
        'fp16_precision': USE_AMP
    },
    'training_history': history
}

with open(RESULTS_PATH / 'results.json', 'w') as f:
    json.dump(results, f, indent=2)

per_label_df = pd.DataFrame({
    'icd_code': TOP_50_CODES,
    'f1_score': per_class_f1,
    'train_count': [top50_info['top_50_counts'].get(code, 0) for code in TOP_50_CODES]
})
per_label_df = per_label_df.sort_values('f1_score', ascending=False)
per_label_df.to_csv(RESULTS_PATH / 'per_label_f1.csv', index=False)

print(f"\n💾 Results saved to: {RESULTS_PATH / 'results.json'}")
print(f"💾 Per-label F1 saved to: {RESULTS_PATH / 'per_label_f1.csv'}")
print(f"💾 Best model saved to: {MODELS_PATH / 'phase2_best.pt'}")

print("\n" + "="*80)
print("✅ SHIFAMIND v302 PHASE 2 COMPLETE!")
print("="*80)

print(f"\n📍 Summary:")
print(f"   ✅ GAT + UMLS knowledge graph")
print(f"   ✅ MAXIMUM GPU optimization (batch_size={TRAIN_BATCH_SIZE}/{VAL_BATCH_SIZE})")
print(f"   ✅ Transfer learning from Phase 1")
print(f"   ✅ Macro F1: {macro_f1:.4f} (before threshold tuning)")
print(f"   ⏱️  Training time: {NUM_EPOCHS} epochs in ~10-15 mins")

print(f"\n🎯 NEXT STEP:")
print(f"   Run threshold tuning on this model for ~2x F1 improvement!")
print(f"   Expected after tuning: Macro F1 > 0.40")

print(f"\n📁 All artifacts saved to: {RUN_PATH}")
print("\nAlhamdulillah! 🤲")

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e1123cda3e0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1618, in _shutdown_workers
  File "/usr/lib/python3.12/multiprocessing/process.py", line 149, in join
  File "/usr/lib/python3.12/multiprocessing/popen_fork.py", line 40, in wait
  File "/usr/lib/python3.12/multiprocessing/connection.py", line 1136, in wait
  File "/usr/lib/python3.12/selectors.py", line 415, in select
KeyboardInterrupt: 


🚀 SHIFAMIND v302 PHASE 2: GAT + UMLS (MAXIMUM GPU OPTIMIZED)
Using UMLS MRREL for rich hierarchical relationships
Training from NEWEST Phase 1 checkpoint with BioClinicalBERT
MAXIMUM GPU optimization for 96GB VRAM!

📁 Run folder: /content/drive/MyDrive/ShifaMind/11_ShifaMind_v302/run_20260215_022518
📁 Graph: /content/drive/MyDrive/ShifaMind/11_ShifaMind_v302/run_20260215_022518/phase_2_graph
📁 Models: /content/drive/MyDrive/ShifaMind/11_ShifaMind_v302/run_20260215_022518/phase_2_models
📁 Results: /content/drive/MyDrive/ShifaMind/11_ShifaMind_v302/run_20260215_022518/phase_2_results

🖥️  Device: cuda
🔥 GPU: NVIDIA RTX PRO 6000 Blackwell Server Edition
💾 VRAM: 102.0 GB

⚙️  Hyperparameters (MAXIMUM GPU OPTIMIZATION):
   Train batch size: 128 (16x original!)
   Val batch size:   256 (16x original!)
   num_workers:      8
   prefetch_factor:  2
   Learning rate:    8e-05 (scaled 4x)
   Epochs:           7
   FP16 precision:   True
   GAT heads:        4
   GAT layers:       2

📋 LOADING DA

Loading weights:   0%|          | 0/199 [00:00<?, ?it/s]

BertModel LOAD REPORT from: emilyalsentzer/Bio_ClinicalBERT
Key                                        | Status     |  | 
-------------------------------------------+------------+--+-
cls.predictions.transform.LayerNorm.weight | UNEXPECTED |  | 
cls.predictions.transform.dense.bias       | UNEXPECTED |  | 
cls.predictions.transform.LayerNorm.bias   | UNEXPECTED |  | 
cls.predictions.bias                       | UNEXPECTED |  | 
cls.seq_relationship.weight                | UNEXPECTED |  | 
cls.seq_relationship.bias                  | UNEXPECTED |  | 
cls.predictions.decoder.weight             | UNEXPECTED |  | 
cls.predictions.transform.dense.weight     | UNEXPECTED |  | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.


✅ BioClinicalBERT loaded

🔄 Computing node embeddings...


Encoding nodes:   0%|          | 0/161 [00:00<?, ?it/s]

✅ Computed 161 node embeddings (768-dim)

✅ PyTorch Geometric data:
   Nodes: 161
   Node features: 768-dim
   Edges: 308
💾 Saved to /content/drive/MyDrive/ShifaMind/11_ShifaMind_v302/run_20260215_022518/phase_2_graph/graph_data.pt

🏗️  BUILDING GAT ENCODER
✅ GAT encoder built:
   Input: 768-dim (BioClinicalBERT)
   Output: 256-dim
   Layers: 2
   Heads: 4
   Parameters: 263,680

🏗️  BUILDING PHASE 2 MODEL

✅ Model built:
   Total parameters: 114,212,001
   Trainable parameters: 114,212,001
   BERT: 108,310,272
   GAT: 263,680

✅ Concept embedding layer created:
   Parameters: 85,248

📥 LOADING PHASE 1 CHECKPOINT (FROM NEWEST RUN)
📁 Loading from: /content/drive/MyDrive/ShifaMind/10_ShifaMind/run_20260215_013437/checkpoints/phase1/phase1_best.pt
✅ Loaded Phase 1 weights (partial transfer learning)
   - BERT encoder: ✅ Transferred
   - Concept head: ✅ Transferred (if compatible)
   - Diagnosis head: ✅ Transferred (if compatible)
   - GAT encoder: ⚠️  New (will be trained from scratch)
  

Epoch 1/7:   0%|          | 0/630 [00:00<?, ?it/s]


📊 Epoch 1 Losses:
   Total:     0.4706
   Diagnosis: 0.3030
   Alignment: 0.0906
   Concept:   0.4080

   Validating...

📈 Validation:
   Diagnosis F1: 0.1043
   Concept F1:   0.0875
   ✅ Saved best model (F1: 0.1043)



Epoch 2/7:   0%|          | 0/630 [00:00<?, ?it/s]


📊 Epoch 2 Losses:
   Total:     0.4210
   Diagnosis: 0.2571
   Alignment: 0.1108
   Concept:   0.3615

   Validating...

📈 Validation:
   Diagnosis F1: 0.1948
   Concept F1:   0.1271
   ✅ Saved best model (F1: 0.1948)



Epoch 3/7:   0%|          | 0/630 [00:00<?, ?it/s]


📊 Epoch 3 Losses:
   Total:     0.4066
   Diagnosis: 0.2424
   Alignment: 0.1190
   Concept:   0.3490

   Validating...

📈 Validation:
   Diagnosis F1: 0.2322
   Concept F1:   0.1638
   ✅ Saved best model (F1: 0.2322)



Epoch 4/7:   0%|          | 0/630 [00:00<?, ?it/s]


📊 Epoch 4 Losses:
   Total:     0.3962
   Diagnosis: 0.2317
   Alignment: 0.1243
   Concept:   0.3412

   Validating...

📈 Validation:
   Diagnosis F1: 0.2734
   Concept F1:   0.1801
   ✅ Saved best model (F1: 0.2734)



Epoch 5/7:   0%|          | 0/630 [00:00<?, ?it/s]


📊 Epoch 5 Losses:
   Total:     0.3861
   Diagnosis: 0.2210
   Alignment: 0.1287
   Concept:   0.3358

   Validating...

📈 Validation:
   Diagnosis F1: 0.2930
   Concept F1:   0.1964
   ✅ Saved best model (F1: 0.2930)



Epoch 6/7:   0%|          | 0/630 [00:00<?, ?it/s]


📊 Epoch 6 Losses:
   Total:     0.3766
   Diagnosis: 0.2108
   Alignment: 0.1327
   Concept:   0.3316

   Validating...

📈 Validation:
   Diagnosis F1: 0.3160
   Concept F1:   0.2096
   ✅ Saved best model (F1: 0.3160)



Epoch 7/7:   0%|          | 0/630 [00:00<?, ?it/s]


📊 Epoch 7 Losses:
   Total:     0.3687
   Diagnosis: 0.2022
   Alignment: 0.1355
   Concept:   0.3291

   Validating...

📈 Validation:
   Diagnosis F1: 0.3244
   Concept F1:   0.2183
   ✅ Saved best model (F1: 0.3244)


✅ Training complete!
   Best epoch: 7
   Best val F1: 0.3244

📊 FINAL TEST EVALUATION


Testing:   0%|          | 0/68 [00:00<?, ?it/s]


🎉 SHIFAMIND v302 PHASE 2 - FINAL RESULTS

🎯 Diagnosis Performance (Fixed 0.5 threshold):
   Macro F1:    0.3234
   Micro F1:    0.4265
   Precision:   0.6001
   Recall:      0.2451

🧠 Concept Performance:
   Concept F1:  0.2175

📊 Top-10 Best Performing Diagnoses:
   1. Z951: F1=0.8062 (n=6,274)
   2. I2510: F1=0.7363 (n=22,606)
   3. I10: F1=0.6781 (n=43,570)
   4. E785: F1=0.6452 (n=44,038)
   5. Z955: F1=0.6249 (n=7,759)
   6. Z7901: F1=0.5755 (n=15,321)
   7. E1122: F1=0.5715 (n=9,205)
   8. J449: F1=0.5687 (n=10,268)
   9. Z794: F1=0.5502 (n=15,275)
   10. Z86718: F1=0.5475 (n=7,598)

⚠️  Note: This is with fixed 0.5 threshold!
   Run threshold tuning next for ~2x improvement!

💾 Results saved to: /content/drive/MyDrive/ShifaMind/11_ShifaMind_v302/run_20260215_022518/phase_2_results/results.json
💾 Per-label F1 saved to: /content/drive/MyDrive/ShifaMind/11_ShifaMind_v302/run_20260215_022518/phase_2_results/per_label_f1.csv
💾 Best model saved to: /content/drive/MyDrive/ShifaMind/11

# Threshold

In [14]:
#!/usr/bin/env python3
"""
================================================================================
SHIFAMIND v302 PHASE 2: THRESHOLD TUNING (POST-HOC OPTIMIZATION)
================================================================================
Author: Mohammed Sameer Syed
University of Arizona - MS in AI Capstone

Optimizes classification thresholds per-label to maximize F1 scores for Phase 2.
Finds optimal thresholds for each of the Top-50 ICD-10 codes independently.

Handles Phase 2 GAT + UMLS architecture with graph-enhanced concepts.
================================================================================
"""

print("="*80)
print("🎯 SHIFAMIND v302 PHASE 2 - THRESHOLD TUNING")
print("="*80)

# ============================================================================
# IMPORTS
# ============================================================================

import warnings
warnings.filterwarnings('ignore')

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast

import numpy as np
import pandas as pd
from sklearn.metrics import f1_score, precision_score, recall_score
from transformers import AutoTokenizer, AutoModel

import json
import pickle
from pathlib import Path
from tqdm.auto import tqdm

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"\n🖥️  Device: {device}")

# ============================================================================
# CONFIGURATION - AUTO-DETECT LATEST RUN
# ============================================================================

print("\n" + "="*80)
print("⚙️  CONFIGURATION")
print("="*80)

# Auto-detect the LATEST Phase 2 run
BASE_PATH = Path('/content/drive/MyDrive/ShifaMind')
OUTPUT_BASE = BASE_PATH / '11_ShifaMind_v302'

run_folders = sorted([d for d in OUTPUT_BASE.glob('run_*') if d.is_dir()], reverse=True)
if not run_folders:
    print("❌ No Phase 2 run found!")
    exit(1)

RUN_FOLDER = run_folders[0]
CHECKPOINT_PATH = RUN_FOLDER / 'phase_2_models' / 'phase2_best.pt'
RESULTS_PATH = RUN_FOLDER / 'phase_2_results'
SHARED_DATA_PATH = RUN_FOLDER / 'shared_data'
GRAPH_PATH = RUN_FOLDER / 'phase_2_graph'

print(f"\n📁 Run Folder: {RUN_FOLDER.name}")
print(f"📁 Checkpoint: {CHECKPOINT_PATH}")

if not CHECKPOINT_PATH.exists():
    print(f"\n❌ ERROR: Checkpoint not found at {CHECKPOINT_PATH}")
    print("Please train Phase 2 first!")
    exit(1)

# ============================================================================
# LOAD TORCH GEOMETRIC
# ============================================================================

try:
    import torch_geometric
    from torch_geometric.nn import GATConv
    from torch_geometric.data import Data
    print("✅ torch_geometric found")
except ImportError:
    print("Installing torch_geometric...")
    os.system('pip install -q torch-geometric')
    import torch_geometric
    from torch_geometric.nn import GATConv
    from torch_geometric.data import Data

# ============================================================================
# LOAD ARCHITECTURE (SAME AS PHASE 2 TRAINING)
# ============================================================================

print("\n" + "="*80)
print("🏗️  LOADING ARCHITECTURE")
print("="*80)

# Load checkpoint to get config
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device, weights_only=False)
config = checkpoint['config']

TOP_50_CODES = config['top_50_codes']
NUM_LABELS = len(TOP_50_CODES)

GRAPH_HIDDEN_DIM = config['graph_hidden_dim']
GAT_HEADS = config['gat_heads']
GAT_LAYERS = config['gat_layers']

print(f"✅ Loaded config:")
print(f"   Diagnoses: {NUM_LABELS}")
print(f"   GAT hidden dim: {GRAPH_HIDDEN_DIM}")
print(f"   GAT heads: {GAT_HEADS}")
print(f"   GAT layers: {GAT_LAYERS}")

# Load concept list
with open(SHARED_DATA_PATH / 'concept_list.json', 'r') as f:
    ALL_CONCEPTS = json.load(f)
NUM_CONCEPTS = len(ALL_CONCEPTS)

print(f"   Concepts: {NUM_CONCEPTS}")

# ============================================================================
# DEFINE ARCHITECTURES
# ============================================================================

class GATEncoder(nn.Module):
    """GAT encoder for learning concept embeddings from knowledge graph"""
    def __init__(self, in_channels, hidden_channels, num_layers=2, heads=4, dropout=0.3):
        super().__init__()

        self.num_layers = num_layers
        self.convs = nn.ModuleList()

        # First layer: in -> hidden
        self.convs.append(GATConv(
            in_channels,
            hidden_channels // heads,  # Output per head
            heads=heads,
            dropout=dropout,
            concat=True
        ))

        # Middle layers
        for _ in range(num_layers - 2):
            self.convs.append(GATConv(
                hidden_channels,
                hidden_channels // heads,
                heads=heads,
                dropout=dropout,
                concat=True
            ))

        # Last layer: hidden -> hidden (average heads)
        if num_layers > 1:
            self.convs.append(GATConv(
                hidden_channels,
                hidden_channels,
                heads=1,
                dropout=dropout,
                concat=False
            ))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, edge_index):
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            if i < self.num_layers - 1:
                x = F.elu(x)
                x = self.dropout(x)

        return x


class ShifaMind302Phase2(nn.Module):
    """ShifaMind v302 Phase 2: GAT + UMLS Knowledge Graph"""
    def __init__(self, bert_model, gat_encoder, graph_data, num_concepts, num_diagnoses, graph_hidden_dim):
        super().__init__()

        self.bert = bert_model
        self.gat = gat_encoder
        self.hidden_size = 768
        self.graph_hidden = graph_hidden_dim
        self.num_concepts = num_concepts
        self.num_diagnoses = num_diagnoses

        # Store graph
        self.register_buffer('graph_x', graph_data.x)
        self.register_buffer('graph_edge_index', graph_data.edge_index)
        self.graph_node_to_idx = graph_data.node_to_idx
        self.graph_idx_to_node = graph_data.idx_to_node

        # Project graph embeddings to BERT dimension
        self.graph_proj = nn.Linear(self.graph_hidden, self.hidden_size)

        # Concept fusion: combine BERT + GAT embeddings
        self.concept_fusion = nn.Sequential(
            nn.Linear(self.hidden_size + self.hidden_size, self.hidden_size),
            nn.LayerNorm(self.hidden_size),
            nn.ReLU(),
            nn.Dropout(0.3)
        )

        # Cross-attention: text attends to enhanced concepts
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=self.hidden_size,
            num_heads=8,
            dropout=0.1,
            batch_first=True
        )

        # Multiplicative gating
        self.gate_net = nn.Sequential(
            nn.Linear(self.hidden_size * 2, self.hidden_size),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.Sigmoid()
        )

        self.layer_norm = nn.LayerNorm(self.hidden_size)

        # Output heads
        self.concept_head = nn.Linear(self.hidden_size, num_concepts)
        self.diagnosis_head = nn.Linear(self.hidden_size, num_diagnoses)

        self.dropout = nn.Dropout(0.1)

    def get_graph_concept_embeddings(self):
        """Run GAT and extract concept embeddings"""
        # Run GAT on full graph
        graph_embeddings = self.gat(self.graph_x, self.graph_edge_index)

        # Extract concept node embeddings
        concept_embeds = []
        for concept in ALL_CONCEPTS:
            if concept in self.graph_node_to_idx:
                idx = self.graph_node_to_idx[concept]
                concept_embeds.append(graph_embeddings[idx])
            else:
                # Fallback: zeros
                concept_embeds.append(torch.zeros(self.graph_hidden, device=self.graph_x.device))

        concept_embeds = torch.stack(concept_embeds)  # [num_concepts, graph_hidden]
        concept_embeds = self.graph_proj(concept_embeds)  # [num_concepts, 768]

        return concept_embeds

    def forward(self, input_ids, attention_mask, concept_embeddings_bert):
        """Forward pass with BERT + GAT fusion"""
        batch_size = input_ids.shape[0]

        # 1. Encode text with BERT
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        hidden_states = outputs.last_hidden_state  # [batch, seq_len, 768]

        # 2. Get GAT-enhanced concept embeddings
        gat_concepts = self.get_graph_concept_embeddings()  # [num_concepts, 768]

        # 3. Fuse BERT + GAT concept embeddings
        bert_concepts = concept_embeddings_bert.unsqueeze(0).expand(batch_size, -1, -1)
        gat_concepts_batched = gat_concepts.unsqueeze(0).expand(batch_size, -1, -1)

        fused_input = torch.cat([bert_concepts, gat_concepts_batched], dim=-1)  # [batch, num_concepts, 1536]
        enhanced_concepts = self.concept_fusion(fused_input)  # [batch, num_concepts, 768]

        # 4. Cross-attention: text attends to enhanced concepts
        context, attn_weights = self.cross_attention(
            query=hidden_states,
            key=enhanced_concepts,
            value=enhanced_concepts,
            need_weights=True
        )  # context: [batch, seq_len, 768]

        # 5. Multiplicative bottleneck gating
        pooled_text = hidden_states.mean(dim=1)  # [batch, 768]
        pooled_context = context.mean(dim=1)  # [batch, 768]

        gate_input = torch.cat([pooled_text, pooled_context], dim=-1)
        gate = self.gate_net(gate_input)  # [batch, 768]

        bottleneck_output = gate * pooled_context
        bottleneck_output = self.layer_norm(bottleneck_output)

        # 6. Output heads
        cls_hidden = self.dropout(pooled_text)
        concept_logits = self.concept_head(cls_hidden)
        concept_scores = torch.sigmoid(concept_logits)
        diagnosis_logits = self.diagnosis_head(bottleneck_output)

        return {
            'logits': diagnosis_logits,
            'concept_logits': concept_logits,
            'concept_scores': concept_scores,
            'gate_values': gate,
            'attention_weights': attn_weights,
            'bottleneck_output': bottleneck_output
        }


class ConceptDataset(Dataset):
    def __init__(self, texts, labels, concept_labels, tokenizer, max_length=384):
        self.texts = texts
        self.labels = labels
        self.concept_labels = concept_labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        encoding = self.tokenizer(
            str(self.texts[idx]),
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt'
        )

        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.FloatTensor(self.labels[idx]),
            'concept_labels': torch.FloatTensor(self.concept_labels[idx])
        }

print("✅ Architecture defined")

# ============================================================================
# LOAD TRAINED MODEL & DATA
# ============================================================================

print("\n" + "="*80)
print("📦 LOADING TRAINED MODEL & DATA")
print("="*80)

# Load graph data
graph_data = torch.load(GRAPH_PATH / 'graph_data.pt', map_location=device,weights_only=False)
print(f"✅ Loaded graph data:")
print(f"   Nodes: {graph_data.x.shape[0]}")
print(f"   Edges: {graph_data.edge_index.shape[1]}")

# Initialize BioClinicalBERT
print(f"\n🔄 Loading BioClinicalBERT...")
tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
bert_model = AutoModel.from_pretrained("emilyalsentzer/Bio_ClinicalBERT").to(device)
print("✅ BioClinicalBERT loaded")

# Initialize GAT
gat_encoder = GATEncoder(
    in_channels=768,
    hidden_channels=GRAPH_HIDDEN_DIM,
    num_layers=GAT_LAYERS,
    heads=GAT_HEADS,
    dropout=0.3
).to(device)

# Initialize model
model = ShifaMind302Phase2(
    bert_model=bert_model,
    gat_encoder=gat_encoder,
    graph_data=graph_data,
    num_concepts=NUM_CONCEPTS,
    num_diagnoses=NUM_LABELS,
    graph_hidden_dim=GRAPH_HIDDEN_DIM
).to(device)

# Load trained weights
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

print(f"✅ Model loaded with {sum(p.numel() for p in model.parameters()):,} parameters")

# Load concept embeddings
concept_embeddings = checkpoint['concept_embeddings']

# Load validation data from Phase 1 run
PHASE1_RUN_PATH = BASE_PATH / '10_ShifaMind'
phase1_runs = sorted([d for d in PHASE1_RUN_PATH.glob('run_*') if d.is_dir()], reverse=True)
if not phase1_runs:
    print("❌ No Phase 1 run found!")
    exit(1)

OLD_SHARED = phase1_runs[0] / 'shared_data'
with open(OLD_SHARED / 'val_split.pkl', 'rb') as f:
    df_val = pickle.load(f)

val_concept_labels = np.load(OLD_SHARED / 'val_concept_labels.npy')

print(f"✅ Loaded validation set: {len(df_val):,} samples")

# ============================================================================
# GET PREDICTIONS ON VALIDATION SET
# ============================================================================

print("\n" + "="*80)
print("🔮 GENERATING PREDICTIONS")
print("="*80)

val_dataset = ConceptDataset(
    df_val['text'].tolist(),
    df_val['labels'].tolist(),
    val_concept_labels,
    tokenizer
)

val_loader = DataLoader(
    val_dataset,
    batch_size=256,
    num_workers=0,
    pin_memory=True
)

all_probs = []
all_labels = []

USE_AMP = torch.cuda.is_available()

with torch.no_grad():
    for batch in tqdm(val_loader, desc="Getting predictions"):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        if USE_AMP:
            with autocast():
                outputs = model(input_ids, attention_mask, concept_embeddings)
        else:
            outputs = model(input_ids, attention_mask, concept_embeddings)

        probs = torch.sigmoid(outputs['logits']).cpu().numpy()
        all_probs.append(probs)
        all_labels.append(labels.cpu().numpy())

all_probs = np.vstack(all_probs)
all_labels = np.vstack(all_labels)

print(f"✅ Predictions shape: {all_probs.shape}")
print(f"   Labels shape: {all_labels.shape}")

# ============================================================================
# THRESHOLD TUNING (PER-LABEL)
# ============================================================================

print("\n" + "="*80)
print("🎯 THRESHOLD TUNING (PER-LABEL)")
print("="*80)

THRESHOLD_CANDIDATES = np.arange(0.05, 0.96, 0.05)  # 0.05 to 0.95 in steps of 0.05

optimal_thresholds = {}
best_f1_scores = {}

print(f"Testing {len(THRESHOLD_CANDIDATES)} thresholds per label: {THRESHOLD_CANDIDATES[0]:.2f} to {THRESHOLD_CANDIDATES[-1]:.2f}")

for label_idx, code in enumerate(tqdm(TOP_50_CODES, desc="Tuning thresholds")):
    label_probs = all_probs[:, label_idx]
    label_true = all_labels[:, label_idx]

    best_f1 = 0.0
    best_threshold = 0.5

    for threshold in THRESHOLD_CANDIDATES:
        preds = (label_probs >= threshold).astype(int)
        f1 = f1_score(label_true, preds, zero_division=0)

        if f1 > best_f1:
            best_f1 = f1
            best_threshold = threshold

    optimal_thresholds[code] = float(best_threshold)
    best_f1_scores[code] = float(best_f1)

print(f"\n✅ Optimal thresholds found for {len(optimal_thresholds)} labels")

# ============================================================================
# EVALUATE WITH OPTIMAL THRESHOLDS
# ============================================================================

print("\n" + "="*80)
print("📊 EVALUATION: OPTIMAL THRESHOLDS vs FIXED 0.5")
print("="*80)

# Fixed threshold (0.5)
preds_fixed = (all_probs >= 0.5).astype(int)
f1_macro_fixed = f1_score(all_labels, preds_fixed, average='macro', zero_division=0)
f1_micro_fixed = f1_score(all_labels, preds_fixed, average='micro', zero_division=0)
precision_fixed = precision_score(all_labels, preds_fixed, average='macro', zero_division=0)
recall_fixed = recall_score(all_labels, preds_fixed, average='macro', zero_division=0)

# Optimal thresholds (per-label)
preds_optimal = np.zeros_like(all_probs, dtype=int)
for label_idx, code in enumerate(TOP_50_CODES):
    threshold = optimal_thresholds[code]
    preds_optimal[:, label_idx] = (all_probs[:, label_idx] >= threshold).astype(int)

f1_macro_optimal = f1_score(all_labels, preds_optimal, average='macro', zero_division=0)
f1_micro_optimal = f1_score(all_labels, preds_optimal, average='micro', zero_division=0)
precision_optimal = precision_score(all_labels, preds_optimal, average='macro', zero_division=0)
recall_optimal = recall_score(all_labels, preds_optimal, average='macro', zero_division=0)

print("\n" + "="*80)
print("🎉 THRESHOLD TUNING RESULTS")
print("="*80)

print("\n📊 Fixed Threshold (0.5):")
print(f"   Macro F1:    {f1_macro_fixed:.4f}")
print(f"   Micro F1:    {f1_micro_fixed:.4f}")
print(f"   Precision:   {precision_fixed:.4f}")
print(f"   Recall:      {recall_fixed:.4f}")

print("\n🎯 Optimal Thresholds (Per-Label):")
print(f"   Macro F1:    {f1_macro_optimal:.4f} (+{f1_macro_optimal - f1_macro_fixed:+.4f})")
print(f"   Micro F1:    {f1_micro_optimal:.4f} (+{f1_micro_optimal - f1_micro_fixed:+.4f})")
print(f"   Precision:   {precision_optimal:.4f} (+{precision_optimal - precision_fixed:+.4f})")
print(f"   Recall:      {recall_optimal:.4f} (+{recall_optimal - recall_fixed:+.4f})")

improvement = ((f1_macro_optimal - f1_macro_fixed) / f1_macro_fixed) * 100 if f1_macro_fixed > 0 else 0
print(f"\n🚀 Improvement: {improvement:+.2f}% relative gain in Macro F1!")

# ============================================================================
# TOP IMPROVEMENTS
# ============================================================================

print("\n📊 Top-10 Largest Improvements:")
improvements = {}
for code in TOP_50_CODES:
    idx = TOP_50_CODES.index(code)
    f1_fixed = f1_score(all_labels[:, idx], preds_fixed[:, idx], zero_division=0)
    f1_optimal = best_f1_scores[code]
    improvements[code] = f1_optimal - f1_fixed

top_improvements = sorted(improvements.items(), key=lambda x: x[1], reverse=True)[:10]

for rank, (code, improvement) in enumerate(top_improvements, 1):
    threshold = optimal_thresholds[code]
    f1_opt = best_f1_scores[code]
    print(f"   {rank}. {code}: +{improvement:.4f} (threshold={threshold:.2f}, F1={f1_opt:.4f})")

# ============================================================================
# SAVE RESULTS
# ============================================================================

print("\n" + "="*80)
print("💾 SAVING RESULTS")
print("="*80)

# Save optimal thresholds
thresholds_path = RESULTS_PATH / 'optimal_thresholds.json'
with open(thresholds_path, 'w') as f:
    json.dump(optimal_thresholds, f, indent=2)

print(f"✅ Optimal thresholds saved to: {thresholds_path}")

# Save comparison results
results = {
    'fixed_threshold': {
        'threshold': 0.5,
        'macro_f1': float(f1_macro_fixed),
        'micro_f1': float(f1_micro_fixed),
        'precision': float(precision_fixed),
        'recall': float(recall_fixed)
    },
    'optimal_thresholds': {
        'macro_f1': float(f1_macro_optimal),
        'micro_f1': float(f1_micro_optimal),
        'precision': float(precision_optimal),
        'recall': float(recall_optimal),
        'improvement_pct': float(improvement),
        'thresholds': optimal_thresholds,
        'per_label_f1': best_f1_scores
    }
}

results_path = RESULTS_PATH / 'threshold_tuning_results.json'
with open(results_path, 'w') as f:
    json.dump(results, f, indent=2)

print(f"✅ Results saved to: {results_path}")

# Save per-label comparison
comparison_df = pd.DataFrame({
    'icd_code': TOP_50_CODES,
    'optimal_threshold': [optimal_thresholds[c] for c in TOP_50_CODES],
    'f1_optimal': [best_f1_scores[c] for c in TOP_50_CODES],
    'f1_fixed_0.5': [f1_score(all_labels[:, i], preds_fixed[:, i], zero_division=0) for i in range(len(TOP_50_CODES))],
    'improvement': [improvements[c] for c in TOP_50_CODES]
})
comparison_df = comparison_df.sort_values('improvement', ascending=False)

csv_path = RESULTS_PATH / 'threshold_comparison.csv'
comparison_df.to_csv(csv_path, index=False)

print(f"✅ Comparison CSV saved to: {csv_path}")

print("\n" + "="*80)
print("✅ THRESHOLD TUNING COMPLETE!")
print("="*80)
print(f"\n📊 Summary:")
print(f"   Fixed (0.5):  Macro F1 = {f1_macro_fixed:.4f}")
print(f"   Optimal:      Macro F1 = {f1_macro_optimal:.4f} ({improvement:+.2f}%)")
print(f"   Best label:   {top_improvements[0][0]} (+{top_improvements[0][1]:.4f})")
print(f"\n💾 All results saved to: {RESULTS_PATH}")
print("\nAlhamdulillah! 🤲")

🎯 SHIFAMIND v302 PHASE 2 - THRESHOLD TUNING

🖥️  Device: cuda

⚙️  CONFIGURATION

📁 Run Folder: run_20260215_022518
📁 Checkpoint: /content/drive/MyDrive/ShifaMind/11_ShifaMind_v302/run_20260215_022518/phase_2_models/phase2_best.pt
✅ torch_geometric found

🏗️  LOADING ARCHITECTURE
✅ Loaded config:
   Diagnoses: 50
   GAT hidden dim: 256
   GAT heads: 4
   GAT layers: 2
   Concepts: 111
✅ Architecture defined

📦 LOADING TRAINED MODEL & DATA
✅ Loaded graph data:
   Nodes: 161
   Edges: 308

🔄 Loading BioClinicalBERT...


Loading weights:   0%|          | 0/199 [00:00<?, ?it/s]

BertModel LOAD REPORT from: emilyalsentzer/Bio_ClinicalBERT
Key                                        | Status     |  | 
-------------------------------------------+------------+--+-
cls.predictions.transform.LayerNorm.weight | UNEXPECTED |  | 
cls.predictions.transform.dense.bias       | UNEXPECTED |  | 
cls.predictions.transform.LayerNorm.bias   | UNEXPECTED |  | 
cls.predictions.bias                       | UNEXPECTED |  | 
cls.seq_relationship.weight                | UNEXPECTED |  | 
cls.seq_relationship.bias                  | UNEXPECTED |  | 
cls.predictions.decoder.weight             | UNEXPECTED |  | 
cls.predictions.transform.dense.weight     | UNEXPECTED |  | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.


✅ BioClinicalBERT loaded
✅ Model loaded with 114,212,001 parameters
✅ Loaded validation set: 17,265 samples

🔮 GENERATING PREDICTIONS


Getting predictions:   0%|          | 0/68 [00:00<?, ?it/s]

✅ Predictions shape: (17265, 50)
   Labels shape: (17265, 50)

🎯 THRESHOLD TUNING (PER-LABEL)
Testing 19 thresholds per label: 0.05 to 0.95


Tuning thresholds:   0%|          | 0/50 [00:00<?, ?it/s]


✅ Optimal thresholds found for 50 labels

📊 EVALUATION: OPTIMAL THRESHOLDS vs FIXED 0.5

🎉 THRESHOLD TUNING RESULTS

📊 Fixed Threshold (0.5):
   Macro F1:    0.3246
   Micro F1:    0.4250
   Precision:   0.6099
   Recall:      0.2459

🎯 Optimal Thresholds (Per-Label):
   Macro F1:    0.4535 (++0.1289)
   Micro F1:    0.4718 (++0.0467)
   Precision:   0.4370 (+-0.1729)
   Recall:      0.5128 (++0.2669)

🚀 Improvement: +39.69% relative gain in Macro F1!

📊 Top-10 Largest Improvements:
   1. Z7902: +0.3221 (threshold=0.15, F1=0.3829)
   2. Z87891: +0.2888 (threshold=0.20, F1=0.5213)
   3. N189: +0.2885 (threshold=0.20, F1=0.3895)
   4. D649: +0.2687 (threshold=0.10, F1=0.2687)
   5. Y929: +0.2596 (threshold=0.15, F1=0.3101)
   6. J189: +0.2568 (threshold=0.15, F1=0.3398)
   7. E871: +0.2370 (threshold=0.15, F1=0.2982)
   8. Y92230: +0.2175 (threshold=0.15, F1=0.2288)
   9. D696: +0.2122 (threshold=0.10, F1=0.2122)
   10. E872: +0.2037 (threshold=0.20, F1=0.3288)

💾 SAVING RESULTS
✅ Optim

In [35]:
# Download checker script
!wget "https://raw.githubusercontent.com/SyedMohammedSameer/ShifaMind_Workspace/claude/review-pipeline-files-i2crP/check_phase2_embeddings.py?$(date +%s)" \
     -O /tmp/check_embeddings.py \
     --no-cache

# Run it
!python /tmp/check_embeddings.py


--2026-02-15 03:56:59--  https://raw.githubusercontent.com/SyedMohammedSameer/ShifaMind_Workspace/claude/review-pipeline-files-i2crP/check_phase2_embeddings.py?1771127819
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1058 (1.0K) [text/plain]
Saving to: ‘/tmp/check_embeddings.py’


2026-02-15 03:56:59 (387 MB/s) - ‘/tmp/check_embeddings.py’ saved [1058/1058]

🔍 Loading Phase 2 checkpoint...

📋 Checkpoint keys:
   - epoch
   - model_state_dict
   - optimizer_state_dict
   - concept_embeddings
   - val_dx_f1
   - val_metrics
   - config

🔑 Model state_dict keys:
   1. graph_x: torch.Size([161, 768])
   2. graph_edge_index: torch.Size([2, 308])
   3. bert.embeddings.word_embeddings.weight: torch.Size([28996, 768])
   4. bert.embeddings.position_embeddings.weig

# Retraining P3

In [40]:
#!/usr/bin/env python3
"""
================================================================================
🚀 SHIFAMIND v302 PHASE 3: RAG with FAISS (MAXIMUM GPU OPTIMIZED)
================================================================================
Using FAISS + sentence-transformers for RAG
Training from NEWEST Phase 2 checkpoint with BioClinicalBERT + GAT
MAXIMUM GPU optimization for 96GB VRAM!

Architecture:
1. Load Phase 2 model (BioClinicalBERT + Concept Bottleneck + GAT)
2. Build FAISS evidence store with clinical knowledge + MIMIC prototypes
3. Gated RAG fusion (40% cap)
4. Fine-tune with diagnosis-focused training

Target: Diagnosis F1 > 0.80
================================================================================
"""

print("="*80)
print("🚀 SHIFAMIND v302 PHASE 3 - RAG WITH FAISS (MAXIMUM GPU OPTIMIZED)")
print("="*80)

# ============================================================================
# IMPORTS & SETUP
# ============================================================================

import warnings
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import numpy as np
import pandas as pd
from sklearn.metrics import f1_score, precision_score, recall_score
from transformers import (
    AutoTokenizer, AutoModel,
    get_linear_schedule_with_warmup
)

from sentence_transformers import SentenceTransformer

try:
    import faiss
    FAISS_AVAILABLE = True
except ImportError:
    print("⚠️  FAISS not available - installing...")
    import subprocess
    subprocess.check_call(['pip', 'install', 'faiss-cpu'])
    import faiss
    FAISS_AVAILABLE = True

import json
import pickle
from pathlib import Path
from tqdm.auto import tqdm
from datetime import datetime
import sys

# ============================================================================
# REPRODUCIBILITY
# ============================================================================

SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ============================================================================
# PATHS & CONFIGURATION
# ============================================================================

print("\n" + "="*80)
print("⚙️  CONFIGURATION")
print("="*80)

BASE_PATH = Path('/content/drive/MyDrive/ShifaMind')
SHIFAMIND_V302_BASE = BASE_PATH / '11_ShifaMind_v302'

# Find newest Phase 2 run folder (exclude Phase 3 runs!)
run_folders = sorted([d for d in SHIFAMIND_V302_BASE.glob('run_*') if d.is_dir() and 'phase3' not in d.name], reverse=True)
if not run_folders:
    print("❌ No Phase 2 run found!")
    sys.exit(1)

PHASE2_RUN = run_folders[0]
print(f"\n📁 Loading from Phase 2 run: {PHASE2_RUN.name}")

# Phase 2 checkpoint path (Phase 2 saves as phase2_best.pt)
PHASE2_CHECKPOINT = PHASE2_RUN / 'phase_2_models' / 'phase2_best.pt'
if not PHASE2_CHECKPOINT.exists():
    print(f"❌ Phase 2 checkpoint not found at {PHASE2_CHECKPOINT}")
    sys.exit(1)

# Find Phase 1 run for shared_data (Phase 2 doesn't copy shared_data to its run folder)
PHASE1_BASE = BASE_PATH / '10_ShifaMind'
phase1_folders = sorted([d for d in PHASE1_BASE.glob('run_*') if d.is_dir()], reverse=True)
if not phase1_folders:
    print("❌ No Phase 1 run found!")
    sys.exit(1)

PHASE1_RUN = phase1_folders[0]
OLD_SHARED = PHASE1_RUN / 'shared_data'
print(f"📁 Loading shared data from Phase 1 run: {PHASE1_RUN.name}")

# Create new Phase 3 folders
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
RUN_FOLDER = SHIFAMIND_V302_BASE / f"run_{timestamp}_phase3"
CHECKPOINT_PATH = RUN_FOLDER / 'phase_3_models'
RESULTS_PATH = RUN_FOLDER / 'phase_3_results'
EVIDENCE_PATH = RUN_FOLDER / 'evidence_store'
SHARED_DATA_PATH = RUN_FOLDER / 'shared_data'

for path in [CHECKPOINT_PATH, RESULTS_PATH, EVIDENCE_PATH, SHARED_DATA_PATH]:
    path.mkdir(parents=True, exist_ok=True)

print(f"\n📁 Run folder: {RUN_FOLDER}")
print(f"📁 Checkpoints: {CHECKPOINT_PATH}")
print(f"📁 Results: {RESULTS_PATH}")
print(f"📁 Evidence: {EVIDENCE_PATH}")

# ============================================================================
# GPU OPTIMIZATION SETTINGS
# ============================================================================

print(f"\n🖥️  Device: {device}")

if device.type == 'cuda':
    gpu_name = torch.cuda.get_device_name(0)
    total_vram = torch.cuda.get_device_properties(0).total_memory / (1024**3)
    print(f"🔥 GPU: {gpu_name}")
    print(f"💾 VRAM: {total_vram:.1f} GB")

    # MAXIMUM GPU OPTIMIZATION
    # For 96GB VRAM - go BIG!
    if total_vram >= 90:
        TRAIN_BATCH_SIZE = 96      # 12x original! (was 8)
        VAL_BATCH_SIZE = 192       # 12x original! (was 16)
        NUM_WORKERS = 12
        PREFETCH_FACTOR = 3
        LEARNING_RATE = 6e-5       # Scale with batch size
        GRADIENT_ACCUM_STEPS = 1
    elif total_vram >= 70:
        TRAIN_BATCH_SIZE = 64
        VAL_BATCH_SIZE = 128
        NUM_WORKERS = 10
        PREFETCH_FACTOR = 3
        LEARNING_RATE = 4e-5
        GRADIENT_ACCUM_STEPS = 1
    elif total_vram >= 40:
        TRAIN_BATCH_SIZE = 32
        VAL_BATCH_SIZE = 64
        NUM_WORKERS = 8
        PREFETCH_FACTOR = 2
        LEARNING_RATE = 2e-5
        GRADIENT_ACCUM_STEPS = 2
    else:
        TRAIN_BATCH_SIZE = 16
        VAL_BATCH_SIZE = 32
        NUM_WORKERS = 4
        PREFETCH_FACTOR = 2
        LEARNING_RATE = 1e-5
        GRADIENT_ACCUM_STEPS = 4
else:
    TRAIN_BATCH_SIZE = 4
    VAL_BATCH_SIZE = 8
    NUM_WORKERS = 2
    PREFETCH_FACTOR = 2
    LEARNING_RATE = 5e-6
    GRADIENT_ACCUM_STEPS = 8

# Training settings
EPOCHS = 5
USE_FP16 = device.type == 'cuda'

# RAG settings
RAG_TOP_K = 3
RAG_THRESHOLD = 0.7
RAG_GATE_MAX = 0.4
PROTOTYPES_PER_DIAGNOSIS = 20

# Loss weights
LAMBDA_DX = 2.0
LAMBDA_ALIGN = 0.5
LAMBDA_CONCEPT = 0.3

print(f"\n⚙️  Hyperparameters (MAXIMUM GPU OPTIMIZATION):")
print(f"   Train batch size: {TRAIN_BATCH_SIZE} ({TRAIN_BATCH_SIZE//8}x original!)")
print(f"   Val batch size:   {VAL_BATCH_SIZE} ({VAL_BATCH_SIZE//16}x original!)")
print(f"   Gradient accum:   {GRADIENT_ACCUM_STEPS}")
print(f"   num_workers:      {NUM_WORKERS}")
print(f"   prefetch_factor:  {PREFETCH_FACTOR}")
print(f"   Learning rate:    {LEARNING_RATE}")
print(f"   Epochs:           {EPOCHS}")
print(f"   FP16 precision:   {USE_FP16}")

print(f"\n⚖️  Loss Weights:")
print(f"   λ_dx:      {LAMBDA_DX}")
print(f"   λ_align:   {LAMBDA_ALIGN}")
print(f"   λ_concept: {LAMBDA_CONCEPT}")

print(f"\n📚 RAG Configuration:")
print(f"   Top-K:     {RAG_TOP_K}")
print(f"   Threshold: {RAG_THRESHOLD}")
print(f"   Gate Max:  {RAG_GATE_MAX}")

# ============================================================================
# LOAD DATA FROM PHASE 2
# ============================================================================

print("\n" + "="*80)
print("📋 LOADING DATA FROM PHASE 2")
print("="*80)

# Load splits
with open(OLD_SHARED / 'train_split.pkl', 'rb') as f:
    df_train = pickle.load(f)
with open(OLD_SHARED / 'val_split.pkl', 'rb') as f:
    df_val = pickle.load(f)
with open(OLD_SHARED / 'test_split.pkl', 'rb') as f:
    df_test = pickle.load(f)

# Load concept labels
train_concept_labels = np.load(OLD_SHARED / 'train_concept_labels.npy')
val_concept_labels = np.load(OLD_SHARED / 'val_concept_labels.npy')
test_concept_labels = np.load(OLD_SHARED / 'test_concept_labels.npy')

# Load Top-50 codes (from original run)
ORIGINAL_RUN = BASE_PATH / '10_ShifaMind' / 'run_20260102_203225'
ORIGINAL_SHARED = ORIGINAL_RUN / 'shared_data'

with open(ORIGINAL_SHARED / 'top50_icd10_info.json', 'r') as f:
    top50_info = json.load(f)
    TOP_50_CODES = top50_info['top_50_codes']

# Load concept list
with open(OLD_SHARED / 'concept_list.json', 'r') as f:
    ALL_CONCEPTS = json.load(f)

NUM_CONCEPTS = len(ALL_CONCEPTS)
NUM_LABELS = len(TOP_50_CODES)

print(f"\n✅ Loaded data:")
print(f"   Train: {len(df_train):,} samples")
print(f"   Val:   {len(df_val):,} samples")
print(f"   Test:  {len(df_test):,} samples")
print(f"   Concepts: {NUM_CONCEPTS}")
print(f"   Diagnoses: {NUM_LABELS}")

# Copy shared data to new run folder
import shutil
for filename in ['top50_icd10_info.json', 'concept_list.json']:
    src = ORIGINAL_SHARED / filename if filename == 'top50_icd10_info.json' else OLD_SHARED / filename
    dst = SHARED_DATA_PATH / filename
    shutil.copy(src, dst)

# ============================================================================
# BUILD EVIDENCE CORPUS
# ============================================================================

print("\n" + "="*80)
print("📚 BUILDING EVIDENCE CORPUS")
print("="*80)

def build_evidence_corpus_top50(top_50_codes, df_train):
    """
    Build evidence corpus for Top-50 diagnoses
    1. Clinical knowledge (curated)
    2. Case prototypes from MIMIC
    """
    print("\n📖 Building evidence corpus...")

    corpus = []

    # Clinical knowledge base
    clinical_knowledge_base = {
        # Respiratory (J codes)
        'J': 'Respiratory conditions: assess cough, dyspnea, chest imaging, oxygen saturation',
        'J18': 'Pneumonia diagnosis requires fever, cough, infiltrates on imaging',
        'J44': 'COPD: chronic airflow limitation, emphysema, chronic bronchitis',
        'J96': 'Respiratory failure: hypoxia, hypercapnia, requires oxygen support',

        # Cardiac (I codes)
        'I': 'Cardiovascular disease: assess chest pain, dyspnea, edema, cardiac markers',
        'I50': 'Heart failure: dyspnea, edema, elevated BNP, reduced EF on echo',
        'I25': 'Ischemic heart disease: angina, troponin, EKG changes',
        'I21': 'MI: acute chest pain, troponin elevation, ST changes',
        'I48': 'Atrial fibrillation: irregular rhythm, palpitations, stroke risk',
        'I10': 'Hypertension: elevated BP, end-organ damage assessment',

        # Infection (A codes)
        'A': 'Infectious disease: fever, cultures, antibiotics',
        'A41': 'Sepsis: organ dysfunction, hypotension, lactate >2, positive cultures',

        # Renal (N codes)
        'N': 'Renal disease: creatinine, BUN, urine output',
        'N17': 'Acute kidney injury: rapid creatinine rise, oliguria',
        'N18': 'Chronic kidney disease: GFR <60, proteinuria',
        'N39': 'Urinary tract disorders: dysuria, frequency, positive culture',

        # Metabolic (E codes)
        'E': 'Endocrine/metabolic: glucose, electrolytes, hormone levels',
        'E11': 'Type 2 diabetes: hyperglycemia, A1c >6.5%, insulin resistance',
        'E87': 'Electrolyte disorders: sodium, potassium, calcium imbalance',
        'E86': 'Volume depletion: dehydration, hypovolemia',

        # GI (K codes)
        'K': 'GI disease: abdominal pain, nausea, imaging',
        'K80': 'Cholelithiasis: RUQ pain, ultrasound showing stones',
        'K21': 'GERD: heartburn, acid reflux, esophagitis',

        # Mental health (F codes)
        'F': 'Mental health: psychiatric assessment, mood, cognition',
        'F32': 'Depression: low mood, anhedonia, sleep disturbance',
        'F41': 'Anxiety: excessive worry, panic, physical symptoms',

        # Injury (S/T codes)
        'S': 'Injury/trauma: mechanism, imaging, stabilization',
        'T': 'Poisoning/external causes: toxicology, supportive care',

        # Neoplasm (C/D codes)
        'C': 'Malignancy: histology, staging, treatment planning',
        'D': 'Benign neoplasm: imaging, biopsy if indicated',

        # Blood (D5-D7)
        'D6': 'Anemia: CBC, iron studies, transfusion if severe',

        # Neurological (G codes)
        'G': 'Neurological: mental status, focal deficits, imaging',
        'G89': 'Pain syndromes: assessment, multimodal analgesia',
    }

    print("\n📝 Adding clinical knowledge...")
    for code in top_50_codes:
        matched = False
        for key, knowledge in clinical_knowledge_base.items():
            if code.startswith(key):
                corpus.append({
                    'text': f"{code}: {knowledge}",
                    'diagnosis': code,
                    'source': 'clinical_knowledge'
                })
                matched = True
                break

        if not matched:
            corpus.append({
                'text': f"{code}: Diagnosis code requiring clinical correlation",
                'diagnosis': code,
                'source': 'clinical_knowledge'
            })

    print(f"   Added {len(corpus)} clinical knowledge passages")

    # Case prototypes from MIMIC
    print(f"\n🏥 Sampling {PROTOTYPES_PER_DIAGNOSIS} case prototypes per diagnosis...")

    for idx, dx_code in enumerate(top_50_codes):
        # Find positive samples
        code_column_exists = dx_code in df_train.columns
        if code_column_exists:
            positive_samples = df_train[df_train[dx_code] == 1]
        else:
            if 'labels' in df_train.columns:
                code_idx = top_50_codes.index(dx_code)
                positive_samples = df_train[df_train['labels'].apply(
                    lambda x: x[code_idx] == 1 if isinstance(x, list) and len(x) > code_idx else False
                )]
            else:
                positive_samples = pd.DataFrame()

        n_samples = min(len(positive_samples), PROTOTYPES_PER_DIAGNOSIS)
        if n_samples > 0:
            sampled = positive_samples.sample(n=n_samples, random_state=SEED)

            for _, row in sampled.iterrows():
                text = str(row['text'])[:500]  # Truncate for efficiency
                corpus.append({
                    'text': text,
                    'diagnosis': dx_code,
                    'source': 'mimic_prototype'
                })

        if (idx + 1) % 10 == 0:
            print(f"   Processed {idx + 1}/{len(top_50_codes)} diagnoses...")

    print(f"\n✅ Evidence corpus built:")
    print(f"   Total passages: {len(corpus)}")
    print(f"   Clinical knowledge: {len([c for c in corpus if c['source'] == 'clinical_knowledge'])}")
    print(f"   MIMIC prototypes: {len([c for c in corpus if c['source'] == 'mimic_prototype'])}")

    return corpus

evidence_corpus = build_evidence_corpus_top50(TOP_50_CODES, df_train)

with open(EVIDENCE_PATH / 'evidence_corpus.json', 'w') as f:
    json.dump(evidence_corpus, f, indent=2)

print(f"💾 Saved corpus to: {EVIDENCE_PATH / 'evidence_corpus.json'}")

# ============================================================================
# FAISS RETRIEVER
# ============================================================================

print("\n" + "="*80)
print("🔍 BUILDING FAISS RETRIEVER")
print("="*80)

class SimpleRAG:
    """Simple RAG using FAISS + sentence-transformers"""
    def __init__(self, model_name='sentence-transformers/all-MiniLM-L6-v2', top_k=3, threshold=0.7):
        print(f"\n🤖 Initializing RAG with {model_name}...")
        self.encoder = SentenceTransformer(model_name)
        self.encoder.to(device)  # Put on GPU!
        self.top_k = top_k
        self.threshold = threshold
        self.index = None
        self.documents = []
        print(f"✅ RAG encoder loaded on {device}")

    def build_index(self, documents: list):
        print(f"\n🔨 Building FAISS index from {len(documents)} documents...")
        self.documents = documents
        texts = [doc['text'] for doc in documents]

        print("   Encoding documents...")
        embeddings = self.encoder.encode(
            texts,
            show_progress_bar=True,
            convert_to_numpy=True,
            batch_size=256  # Batch encoding for speed
        )
        embeddings = embeddings.astype('float32')

        # Normalize for cosine similarity
        faiss.normalize_L2(embeddings)

        dimension = embeddings.shape[1]

        # Try GPU index first, fallback to CPU if not available
        use_gpu = False
        if device.type == 'cuda':
            try:
                print("   Building GPU FAISS index...")
                res = faiss.StandardGpuResources()
                cpu_index = faiss.IndexFlatIP(dimension)
                self.index = faiss.index_cpu_to_gpu(res, 0, cpu_index)
                use_gpu = True
            except (AttributeError, RuntimeError) as e:
                print(f"   ⚠️  GPU FAISS not available ({e.__class__.__name__})")
                print("   Falling back to CPU FAISS (still fast for 1050 docs!)")
                self.index = faiss.IndexFlatIP(dimension)
        else:
            self.index = faiss.IndexFlatIP(dimension)

        self.index.add(embeddings)

        print(f"✅ FAISS index built:")
        print(f"   Dimension: {dimension}")
        print(f"   Total vectors: {self.index.ntotal}")
        print(f"   Device: {'GPU' if use_gpu else 'CPU'}")

    def retrieve(self, query: str) -> str:
        if self.index is None:
            return ""

        query_embedding = self.encoder.encode([query], convert_to_numpy=True).astype('float32')
        faiss.normalize_L2(query_embedding)

        scores, indices = self.index.search(query_embedding, self.top_k)

        relevant_texts = []
        for score, idx in zip(scores[0], indices[0]):
            if score >= self.threshold:
                relevant_texts.append(self.documents[idx]['text'])

        return " ".join(relevant_texts) if relevant_texts else ""

rag = SimpleRAG(top_k=RAG_TOP_K, threshold=RAG_THRESHOLD)
rag.build_index(evidence_corpus)

# ============================================================================
# LOAD PHASE 2 MODEL ARCHITECTURE (GAT-based)
# ============================================================================

print("\n" + "="*80)
print("🏗️  LOADING PHASE 2 MODEL ARCHITECTURE")
print("="*80)

# Import GAT components
try:
    from torch_geometric.nn import GATConv
    from torch_geometric.data import Data
except ImportError:
    print("⚠️  Installing torch_geometric...")
    import subprocess
    subprocess.check_call(['pip', 'install', 'torch_geometric'])
    from torch_geometric.nn import GATConv
    from torch_geometric.data import Data

class GATEncoder(nn.Module):
    """GAT encoder for learning concept embeddings from knowledge graph"""
    def __init__(self, in_channels, hidden_channels, num_layers=2, heads=4, dropout=0.3):
        super().__init__()

        self.num_layers = num_layers
        self.convs = nn.ModuleList()

        # First layer: in -> hidden
        self.convs.append(GATConv(
            in_channels,
            hidden_channels // heads,  # Output per head
            heads=heads,
            dropout=dropout,
            concat=True
        ))

        # Middle layers
        for _ in range(num_layers - 2):
            self.convs.append(GATConv(
                hidden_channels,
                hidden_channels // heads,
                heads=heads,
                dropout=dropout,
                concat=True
            ))

        # Last layer: hidden -> hidden (average heads)
        if num_layers > 1:
            self.convs.append(GATConv(
                hidden_channels,
                hidden_channels,
                heads=1,
                dropout=dropout,
                concat=False
            ))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, edge_index):
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            if i < self.num_layers - 1:
                x = F.relu(x)
                x = self.dropout(x)
        return x


class ShifaMindPhase2GAT(nn.Module):
    """
    Phase 2 model - EXACT architecture from phase2_training_optimized.py
    """
    def __init__(self, bert_model, gat_encoder, graph_data, num_concepts, num_diagnoses):
        super().__init__()

        self.bert = bert_model
        self.gat = gat_encoder
        self.hidden_size = 768
        self.graph_hidden = 256  # From Phase 2
        self.num_concepts = num_concepts
        self.num_diagnoses = num_diagnoses

        # Store graph
        self.register_buffer('graph_x', graph_data.x)
        self.register_buffer('graph_edge_index', graph_data.edge_index)
        self.graph_node_to_idx = graph_data.node_to_idx
        self.graph_idx_to_node = graph_data.idx_to_node

        # Project graph embeddings to BERT dimension
        self.graph_proj = nn.Linear(self.graph_hidden, self.hidden_size)

        # Concept fusion: combine BERT + GAT embeddings
        self.concept_fusion = nn.Sequential(
            nn.Linear(self.hidden_size + self.hidden_size, self.hidden_size),
            nn.LayerNorm(self.hidden_size),
            nn.ReLU(),
            nn.Dropout(0.3)
        )

        # Cross-attention: text attends to enhanced concepts
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=self.hidden_size,
            num_heads=8,
            dropout=0.1,
            batch_first=True
        )

        # Multiplicative gating
        self.gate_net = nn.Sequential(
            nn.Linear(self.hidden_size * 2, self.hidden_size),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.Sigmoid()
        )

        self.layer_norm = nn.LayerNorm(self.hidden_size)

        # Output heads
        self.concept_head = nn.Linear(self.hidden_size, num_concepts)
        self.diagnosis_head = nn.Linear(self.hidden_size, num_diagnoses)

        self.dropout = nn.Dropout(0.1)

    def get_graph_concept_embeddings(self):
        """Run GAT and extract concept embeddings"""
        # Run GAT on full graph
        graph_embeddings = self.gat(self.graph_x, self.graph_edge_index)

        # Extract concept node embeddings (using ALL_CONCEPTS from global scope)
        concept_embeds = []
        for concept in ALL_CONCEPTS:
            if concept in self.graph_node_to_idx:
                idx = self.graph_node_to_idx[concept]
                concept_embeds.append(graph_embeddings[idx])
            else:
                # Fallback: zeros
                concept_embeds.append(torch.zeros(self.graph_hidden, device=self.graph_x.device))

        return torch.stack(concept_embeds)  # [num_concepts, graph_hidden]

    def forward(self, input_ids, attention_mask, concept_embeddings_bert):
        """
        Forward pass with BERT + GAT fusion (EXACT Phase 2 architecture)

        Args:
            input_ids: [batch, seq_len]
            attention_mask: [batch, seq_len]
            concept_embeddings_bert: [num_concepts, 768] - learned BERT concept embeddings
        """
        batch_size = input_ids.shape[0]

        # 1. Encode text with BERT
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        hidden_states = outputs.last_hidden_state  # [batch, seq_len, 768]

        # 2. Get GAT-enhanced concept embeddings
        gat_concepts = self.get_graph_concept_embeddings()  # [num_concepts, 256]
        gat_concepts = self.graph_proj(gat_concepts)  # Project to [num_concepts, 768]

        # 3. Fuse BERT + GAT concept embeddings
        bert_concepts = concept_embeddings_bert.unsqueeze(0).expand(batch_size, -1, -1)
        gat_concepts_batched = gat_concepts.unsqueeze(0).expand(batch_size, -1, -1)

        fused_input = torch.cat([bert_concepts, gat_concepts_batched], dim=-1)  # [batch, num_concepts, 1536]
        enhanced_concepts = self.concept_fusion(fused_input)  # [batch, num_concepts, 768]

        # 4. Cross-attention: text attends to enhanced concepts
        context, attn_weights = self.cross_attention(
            query=hidden_states,
            key=enhanced_concepts,
            value=enhanced_concepts,
            need_weights=True
        )  # context: [batch, seq_len, 768]

        # 5. Multiplicative bottleneck gating
        pooled_text = hidden_states.mean(dim=1)  # [batch, 768]
        pooled_context = context.mean(dim=1)  # [batch, 768]

        gate_input = torch.cat([pooled_text, pooled_context], dim=-1)
        gate = self.gate_net(gate_input)  # [batch, 768]

        bottleneck_output = gate * pooled_context
        bottleneck_output = self.layer_norm(bottleneck_output)

        # 6. Output heads
        cls_hidden = self.dropout(pooled_text)
        concept_logits = self.concept_head(cls_hidden)
        concept_scores = torch.sigmoid(concept_logits)
        diagnosis_logits = self.diagnosis_head(bottleneck_output)

        return {
            'logits': diagnosis_logits,
            'concept_logits': concept_logits,
            'concept_scores': concept_scores,
            'gate_values': gate,
            'attention_weights': attn_weights,
            'bottleneck_output': bottleneck_output
        }

# ============================================================================
# PHASE 3 MODEL (with RAG)
# ============================================================================

class ShifaMindPhase3RAG(nn.Module):
    """Phase 3: Phase 2 + RAG integration"""
    def __init__(self, phase2_model, rag_retriever, hidden_size=768):
        super().__init__()

        self.phase2_model = phase2_model
        self.rag = rag_retriever
        self.hidden_size = hidden_size

        # RAG components
        rag_dim = 384  # all-MiniLM-L6-v2 embedding size
        self.rag_projection = nn.Linear(rag_dim, hidden_size)

        self.rag_gate = nn.Sequential(
            nn.Linear(hidden_size * 2, hidden_size),
            nn.Sigmoid()
        )

    def forward(self, input_ids, attention_mask, concept_embeddings_bert, input_texts=None):
        """
        Phase 3 forward with RAG augmentation

        Args:
            input_ids: [batch, seq_len]
            attention_mask: [batch, seq_len]
            concept_embeddings_bert: [num_concepts, 768] - learned BERT concept embeddings from Phase 1
            input_texts: List of input texts for RAG retrieval (optional)
        """
        batch_size = input_ids.shape[0]

        # RAG retrieval and augmentation
        if self.rag is not None and input_texts is not None:
            # Retrieve relevant evidence
            rag_texts = [self.rag.retrieve(text) for text in input_texts]

            # Encode RAG context
            rag_embeddings = []
            for rag_text in rag_texts:
                if rag_text:
                    emb = self.rag.encoder.encode([rag_text], convert_to_numpy=True)[0]
                else:
                    emb = np.zeros(384)
                rag_embeddings.append(emb)

            rag_embeddings = torch.tensor(np.array(rag_embeddings), dtype=torch.float32).to(input_ids.device)
            rag_context = self.rag_projection(rag_embeddings)  # [batch, 768]

            # Get pooled BERT for gating
            with torch.no_grad():
                bert_outputs = self.phase2_model.bert(input_ids=input_ids, attention_mask=attention_mask)
                pooled_bert = bert_outputs.last_hidden_state.mean(dim=1)

            # Gated fusion
            gate_input = torch.cat([pooled_bert, rag_context], dim=-1)
            gate = self.rag_gate(gate_input)
            gate = gate * RAG_GATE_MAX  # Cap at 40%

            # Augment concept embeddings with RAG context
            # Broadcast rag_context to match concept embeddings shape
            rag_aug = (gate * rag_context).mean(dim=0, keepdim=True)  # [1, 768]
            concept_embeddings_augmented = concept_embeddings_bert + rag_aug  # [num_concepts, 768]
        else:
            concept_embeddings_augmented = concept_embeddings_bert

        # Run Phase 2 model with augmented concept embeddings
        outputs = self.phase2_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            concept_embeddings_bert=concept_embeddings_augmented
        )

        return outputs

# Load Phase 2 graph data
print("\n📊 Loading Phase 2 graph data...")
GRAPH_PATH = PHASE2_RUN / 'phase_2_graph'
graph_data = torch.load(GRAPH_PATH / 'graph_data.pt', map_location='cpu', weights_only=False)
# Move graph data to device immediately
graph_data = graph_data.to(device)
print(f"✅ Loaded graph: {graph_data.num_nodes} nodes, {graph_data.num_edges} edges")
print(f"✅ Graph data on device: {device}")

# Create BioClinicalBERT
print("\n🤖 Loading BioClinicalBERT...")
bert_model = AutoModel.from_pretrained('emilyalsentzer/Bio_ClinicalBERT')

# Create GAT encoder (same params as Phase 2)
print("\n🔨 Building GAT encoder...")
gat_encoder = GATEncoder(
    in_channels=768,
    hidden_channels=256,  # GRAPH_HIDDEN_DIM from Phase 2
    num_layers=2,
    heads=4,
    dropout=0.3
)

# Initialize Phase 2 model
print("\n🏗️  Initializing Phase 2 model...")
phase2_model = ShifaMindPhase2GAT(
    bert_model=bert_model,
    gat_encoder=gat_encoder,
    graph_data=graph_data,
    num_concepts=NUM_CONCEPTS,
    num_diagnoses=NUM_LABELS
)

# Load concept embeddings from Phase 2 checkpoint (best source!)
print(f"\n📥 Loading concept embeddings...")

# Load Phase 2 checkpoint (contains trained concept embeddings!)
checkpoint = torch.load(PHASE2_CHECKPOINT, map_location='cpu', weights_only=False)

concept_embeddings_bert = None
if 'concept_embeddings' in checkpoint:
    # Phase 2 saved concept embeddings at top level!
    concept_emb_tensor = checkpoint['concept_embeddings']
    concept_embeddings_bert = nn.Embedding(NUM_CONCEPTS, 768)
    concept_embeddings_bert.weight = nn.Parameter(concept_emb_tensor)
    print(f"✅ Loaded trained concept embeddings from Phase 2: {concept_emb_tensor.shape}")
else:
    print("⚠️  Phase 2 checkpoint missing concept_embeddings key")
    # Try Phase 1 as fallback
    PHASE1_CHECKPOINT = PHASE1_RUN / 'checkpoints' / 'phase1' / 'phase1_best.pt'
    if PHASE1_CHECKPOINT.exists():
        try:
            phase1_ckpt = torch.load(PHASE1_CHECKPOINT, map_location='cpu', weights_only=False)
            concept_emb_weight = phase1_ckpt['model_state_dict']['concept_embeddings.weight']
            concept_embeddings_bert = nn.Embedding(NUM_CONCEPTS, 768)
            concept_embeddings_bert.weight = nn.Parameter(concept_emb_weight)
            print(f"✅ Loaded concept embeddings from Phase 1: {concept_emb_weight.shape}")
        except (KeyError, FileNotFoundError) as e:
            print(f"⚠️  Could not load from Phase 1: {e}")

if concept_embeddings_bert is None:
    print("⚠️  Creating fresh concept embeddings (performance will be degraded!)")
    concept_embeddings_bert = nn.Embedding(NUM_CONCEPTS, 768)
    nn.init.xavier_uniform_(concept_embeddings_bert.weight)
    print("✅ Created fresh concept embeddings")

# Move concept embeddings to device
concept_embeddings_bert = concept_embeddings_bert.to(device)
print(f"✅ Concept embeddings on device: {device}")

# Load Phase 2 model weights (checkpoint already loaded above)
print(f"\n📥 Loading Phase 2 model weights...")
phase2_model.load_state_dict(checkpoint['model_state_dict'], strict=True)
print("✅ Loaded Phase 2 model weights")

# Move Phase 2 model to device BEFORE wrapping in Phase 3
phase2_model = phase2_model.to(device)
print(f"✅ Phase 2 model moved to device: {device}")

# Initialize Phase 3 model with RAG
print("\n🏗️  Initializing Phase 3 model with RAG...")
model = ShifaMindPhase3RAG(
    phase2_model=phase2_model,
    rag_retriever=rag,
    hidden_size=768
).to(device)

print(f"\n✅ ShifaMind Phase 3 model initialized")
print(f"   Total parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"   Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

# ============================================================================
# DATASET
# ============================================================================

print("\n" + "="*80)
print("📦 PREPARING DATASETS")
print("="*80)

class RAGDataset(Dataset):
    def __init__(self, df, tokenizer, concept_labels, top50_codes):
        self.texts = df['text'].tolist()
        # Extract multi-label columns (Top-50 diagnosis codes)
        # The CSV has diagnosis codes as column names, extract them
        self.labels = df[top50_codes].values  # [num_samples, 50]
        self.tokenizer = tokenizer
        self.concept_labels = concept_labels

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = str(self.texts[idx])

        encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=512,
            padding='max_length',
            return_tensors='pt'
        )

        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'text': text,
            'labels': torch.tensor(self.labels[idx], dtype=torch.float),
            'concept_labels': torch.tensor(self.concept_labels[idx], dtype=torch.float)
        }

tokenizer = AutoTokenizer.from_pretrained('emilyalsentzer/Bio_ClinicalBERT')

train_dataset = RAGDataset(df_train, tokenizer, train_concept_labels, TOP_50_CODES)
val_dataset = RAGDataset(df_val, tokenizer, val_concept_labels, TOP_50_CODES)
test_dataset = RAGDataset(df_test, tokenizer, test_concept_labels, TOP_50_CODES)

# Use pin_memory for faster GPU transfer
train_loader = DataLoader(
    train_dataset,
    batch_size=TRAIN_BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    prefetch_factor=PREFETCH_FACTOR if NUM_WORKERS > 0 else None,
    pin_memory=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=VAL_BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    prefetch_factor=PREFETCH_FACTOR if NUM_WORKERS > 0 else None,
    pin_memory=True
)

test_loader = DataLoader(
    test_dataset,
    batch_size=VAL_BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    prefetch_factor=PREFETCH_FACTOR if NUM_WORKERS > 0 else None,
    pin_memory=True
)

print(f"\n✅ Datasets ready:")
print(f"   Train batches: {len(train_loader)}")
print(f"   Val batches:   {len(val_loader)}")
print(f"   Test batches:  {len(test_loader)}")

# ============================================================================
# LOSS & OPTIMIZER
# ============================================================================

print("\n" + "="*80)
print("⚙️  TRAINING SETUP")
print("="*80)

class MultiObjectiveLoss(nn.Module):
    def __init__(self, lambda_dx, lambda_align, lambda_concept):
        super().__init__()
        self.lambda_dx = lambda_dx
        self.lambda_align = lambda_align
        self.lambda_concept = lambda_concept
        self.bce = nn.BCEWithLogitsLoss()

    def forward(self, outputs, dx_labels, concept_labels):
        loss_dx = self.bce(outputs['logits'], dx_labels)

        dx_probs = torch.sigmoid(outputs['logits'])
        concept_scores = outputs['concept_scores']
        loss_align = torch.abs(dx_probs.unsqueeze(-1) - concept_scores.unsqueeze(1)).mean()

        loss_concept = self.bce(outputs['concept_logits'], concept_labels)

        total_loss = (
            self.lambda_dx * loss_dx +
            self.lambda_align * loss_align +
            self.lambda_concept * loss_concept
        )

        return total_loss, {
            'loss_dx': loss_dx.item(),
            'loss_align': loss_align.item(),
            'loss_concept': loss_concept.item(),
            'total_loss': total_loss.item()
        }

criterion = MultiObjectiveLoss(LAMBDA_DX, LAMBDA_ALIGN, LAMBDA_CONCEPT)
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.01)

total_steps = len(train_loader) * EPOCHS // GRADIENT_ACCUM_STEPS
warmup_steps = int(0.1 * total_steps)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)

# FP16 training
scaler = torch.cuda.amp.GradScaler() if USE_FP16 else None

print(f"\n✅ Training setup complete")
print(f"   Optimizer: AdamW (lr={LEARNING_RATE}, weight_decay=0.01)")
print(f"   Scheduler: Linear warmup ({warmup_steps} steps) + decay ({total_steps} total)")
print(f"   Mixed precision: {USE_FP16}")

# ============================================================================
# TRAINING LOOP
# ============================================================================

print("\n" + "="*80)
print("🏋️  TRAINING PHASE 3 (RAG-ENHANCED)")
print("="*80)

best_val_f1 = 0.0
history = {'train_loss': [], 'val_loss': [], 'val_f1': [], 'val_precision': [], 'val_recall': []}

for epoch in range(EPOCHS):
    print(f"\n{'='*80}")
    print(f"📍 Epoch {epoch+1}/{EPOCHS}")
    print(f"{'='*80}")

    # ========================================================================
    # TRAINING
    # ========================================================================

    model.train()
    train_losses = []
    optimizer.zero_grad()

    pbar = tqdm(train_loader, desc="Training")
    for batch_idx, batch in enumerate(pbar):
        input_ids = batch['input_ids'].to(device, non_blocking=True)
        attention_mask = batch['attention_mask'].to(device, non_blocking=True)
        labels = batch['labels'].to(device, non_blocking=True)
        concept_labels = batch['concept_labels'].to(device, non_blocking=True)
        texts = batch['text']

        # Forward pass with mixed precision
        if USE_FP16:
            with torch.cuda.amp.autocast():
                outputs = model(input_ids, attention_mask, concept_embeddings_bert.weight, input_texts=texts)
                loss, loss_components = criterion(outputs, labels, concept_labels)
                loss = loss / GRADIENT_ACCUM_STEPS

            scaler.scale(loss).backward()
        else:
            outputs = model(input_ids, attention_mask, concept_embeddings_bert.weight, input_texts=texts)
            loss, loss_components = criterion(outputs, labels, concept_labels)
            loss = loss / GRADIENT_ACCUM_STEPS
            loss.backward()

        # Gradient accumulation
        if (batch_idx + 1) % GRADIENT_ACCUM_STEPS == 0:
            if USE_FP16:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()

            scheduler.step()
            optimizer.zero_grad()

        train_losses.append(loss.item() * GRADIENT_ACCUM_STEPS)

        # Update progress bar
        if device.type == 'cuda':
            gpu_mem = torch.cuda.memory_allocated() / (1024**3)
            gpu_max = torch.cuda.max_memory_allocated() / (1024**3)
            pbar.set_postfix({
                'loss': f"{loss.item() * GRADIENT_ACCUM_STEPS:.4f}",
                'GPU': f"{gpu_mem:.1f}/{gpu_max:.1f}GB"
            })
        else:
            pbar.set_postfix({'loss': f"{loss.item() * GRADIENT_ACCUM_STEPS:.4f}"})

    avg_train_loss = np.mean(train_losses)
    history['train_loss'].append(avg_train_loss)

    print(f"\n📊 Training complete:")
    print(f"   Avg Loss: {avg_train_loss:.4f}")
    if device.type == 'cuda':
        print(f"   Peak GPU Memory: {torch.cuda.max_memory_allocated() / (1024**3):.1f} GB")
        torch.cuda.reset_peak_memory_stats()

    # ========================================================================
    # VALIDATION
    # ========================================================================

    model.eval()
    val_losses = []
    all_preds = []
    all_labels = []

    print(f"\n🔍 Validating...")
    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Validation"):
            input_ids = batch['input_ids'].to(device, non_blocking=True)
            attention_mask = batch['attention_mask'].to(device, non_blocking=True)
            labels = batch['labels'].to(device, non_blocking=True)
            concept_labels = batch['concept_labels'].to(device, non_blocking=True)
            texts = batch['text']

            if USE_FP16:
                with torch.cuda.amp.autocast():
                    outputs = model(input_ids, attention_mask, concept_embeddings_bert.weight, input_texts=texts)
                    loss, _ = criterion(outputs, labels, concept_labels)
            else:
                outputs = model(input_ids, attention_mask, concept_embeddings_bert.weight, input_texts=texts)
                loss, _ = criterion(outputs, labels, concept_labels)

            val_losses.append(loss.item())

            preds = (torch.sigmoid(outputs['logits']) > 0.5).cpu().numpy()
            all_preds.append(preds)
            all_labels.append(labels.cpu().numpy())

    all_preds = np.vstack(all_preds)
    all_labels = np.vstack(all_labels)

    avg_val_loss = np.mean(val_losses)
    val_f1 = f1_score(all_labels, all_preds, average='macro', zero_division=0)
    val_precision = precision_score(all_labels, all_preds, average='macro', zero_division=0)
    val_recall = recall_score(all_labels, all_preds, average='macro', zero_division=0)

    history['val_loss'].append(avg_val_loss)
    history['val_f1'].append(val_f1)
    history['val_precision'].append(val_precision)
    history['val_recall'].append(val_recall)

    print(f"\n📊 Validation Results:")
    print(f"   Loss:      {avg_val_loss:.4f}")
    print(f"   F1:        {val_f1:.4f}")
    print(f"   Precision: {val_precision:.4f}")
    print(f"   Recall:    {val_recall:.4f}")

    # Save best model
    if val_f1 > best_val_f1:
        best_val_f1 = val_f1

        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_f1': best_val_f1,
            'history': history,
            'config': {
                'num_concepts': NUM_CONCEPTS,
                'num_labels': NUM_LABELS,
                'top_50_codes': TOP_50_CODES,
                'all_concepts': ALL_CONCEPTS,
                'rag_config': {
                    'top_k': RAG_TOP_K,
                    'threshold': RAG_THRESHOLD,
                    'gate_max': RAG_GATE_MAX
                },
                'training_config': {
                    'batch_size': TRAIN_BATCH_SIZE,
                    'learning_rate': LEARNING_RATE,
                    'epochs': EPOCHS,
                    'lambda_dx': LAMBDA_DX,
                    'lambda_align': LAMBDA_ALIGN,
                    'lambda_concept': LAMBDA_CONCEPT
                },
                'timestamp': timestamp
            }
        }, CHECKPOINT_PATH / 'best_model.pth')

        print(f"   ✅ Saved best model (F1: {best_val_f1:.4f})")

# ============================================================================
# FINAL EVALUATION
# ============================================================================

print("\n" + "="*80)
print("📊 FINAL EVALUATION ON TEST SET")
print("="*80)

# Load best model
checkpoint = torch.load(CHECKPOINT_PATH / 'best_model.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

all_preds = []
all_labels = []
all_probs = []

print("\n🔍 Evaluating on test set...")
with torch.no_grad():
    for batch in tqdm(test_loader, desc="Testing"):
        input_ids = batch['input_ids'].to(device, non_blocking=True)
        attention_mask = batch['attention_mask'].to(device, non_blocking=True)
        labels = batch['labels'].to(device, non_blocking=True)
        texts = batch['text']

        if USE_FP16:
            with torch.cuda.amp.autocast():
                outputs = model(input_ids, attention_mask, concept_embeddings_bert.weight, input_texts=texts)
        else:
            outputs = model(input_ids, attention_mask, concept_embeddings_bert.weight, input_texts=texts)

        probs = torch.sigmoid(outputs['logits']).cpu().numpy()
        preds = (probs > 0.5).astype(int)

        all_probs.append(probs)
        all_preds.append(preds)
        all_labels.append(labels.cpu().numpy())

all_probs = np.vstack(all_probs)
all_preds = np.vstack(all_preds)
all_labels = np.vstack(all_labels)

# Calculate metrics
macro_f1 = f1_score(all_labels, all_preds, average='macro', zero_division=0)
micro_f1 = f1_score(all_labels, all_preds, average='micro', zero_division=0)
macro_precision = precision_score(all_labels, all_preds, average='macro', zero_division=0)
macro_recall = recall_score(all_labels, all_preds, average='macro', zero_division=0)
per_class_f1 = f1_score(all_labels, all_preds, average=None, zero_division=0)

print(f"\n🎯 Test Set Performance:")
print(f"   Macro F1:        {macro_f1:.4f}")
print(f"   Micro F1:        {micro_f1:.4f}")
print(f"   Macro Precision: {macro_precision:.4f}")
print(f"   Macro Recall:    {macro_recall:.4f}")

# Save results
results = {
    'phase': 'ShifaMind v302 Phase 3 - RAG with FAISS',
    'timestamp': timestamp,
    'run_folder': str(RUN_FOLDER),
    'test_metrics': {
        'macro_f1': float(macro_f1),
        'micro_f1': float(micro_f1),
        'macro_precision': float(macro_precision),
        'macro_recall': float(macro_recall),
        'per_class_f1': {code: float(f1) for code, f1 in zip(TOP_50_CODES, per_class_f1)}
    },
    'validation_metrics': {
        'best_f1': float(best_val_f1),
        'final_f1': float(history['val_f1'][-1])
    },
    'architecture': 'BioClinicalBERT + Concept Bottleneck + GAT + FAISS RAG',
    'rag_config': {
        'method': 'FAISS + sentence-transformers',
        'model': 'all-MiniLM-L6-v2',
        'top_k': RAG_TOP_K,
        'threshold': RAG_THRESHOLD,
        'gate_max': RAG_GATE_MAX,
        'corpus_size': len(evidence_corpus)
    },
    'training_config': {
        'batch_size': TRAIN_BATCH_SIZE,
        'learning_rate': LEARNING_RATE,
        'epochs': EPOCHS,
        'lambda_dx': LAMBDA_DX,
        'lambda_align': LAMBDA_ALIGN,
        'lambda_concept': LAMBDA_CONCEPT,
        'fp16': USE_FP16,
        'gradient_accum_steps': GRADIENT_ACCUM_STEPS
    },
    'training_history': history
}

with open(RESULTS_PATH / 'results.json', 'w') as f:
    json.dump(results, f, indent=2)

# Save predictions
np.save(RESULTS_PATH / 'test_predictions.npy', all_preds)
np.save(RESULTS_PATH / 'test_probabilities.npy', all_probs)
np.save(RESULTS_PATH / 'test_labels.npy', all_labels)

print(f"\n💾 Results saved to: {RESULTS_PATH}")
print(f"💾 Best model saved to: {CHECKPOINT_PATH / 'best_model.pth'}")

print("\n" + "="*80)
print("✅ SHIFAMIND v302 PHASE 3 COMPLETE!")
print("="*80)
print(f"\n📍 Run folder: {RUN_FOLDER}")
print(f"   Test Macro F1: {macro_f1:.4f}")
print(f"   Test Micro F1: {micro_f1:.4f}")
print(f"   Best Val F1:   {best_val_f1:.4f}")
print("\nNext: Run phase3_threshold_optimization.py for optimal thresholds")
print("\nAlhamdulillah! 🤲")

🚀 SHIFAMIND v302 PHASE 3 - RAG WITH FAISS (MAXIMUM GPU OPTIMIZED)

⚙️  CONFIGURATION

📁 Loading from Phase 2 run: run_20260215_022518
📁 Loading shared data from Phase 1 run: run_20260215_013437

📁 Run folder: /content/drive/MyDrive/ShifaMind/11_ShifaMind_v302/run_20260215_041200_phase3
📁 Checkpoints: /content/drive/MyDrive/ShifaMind/11_ShifaMind_v302/run_20260215_041200_phase3/phase_3_models
📁 Results: /content/drive/MyDrive/ShifaMind/11_ShifaMind_v302/run_20260215_041200_phase3/phase_3_results
📁 Evidence: /content/drive/MyDrive/ShifaMind/11_ShifaMind_v302/run_20260215_041200_phase3/evidence_store

🖥️  Device: cuda
🔥 GPU: NVIDIA RTX PRO 6000 Blackwell Server Edition
💾 VRAM: 95.0 GB

⚙️  Hyperparameters (MAXIMUM GPU OPTIMIZATION):
   Train batch size: 96 (12x original!)
   Val batch size:   192 (12x original!)
   Gradient accum:   1
   num_workers:      12
   prefetch_factor:  3
   Learning rate:    6e-05
   Epochs:           5
   FP16 precision:   True

⚖️  Loss Weights:
   λ_dx:      

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e1123cda3e0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1618, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/usr/lib/python3.12/multiprocessing/process.py", line 149, in join
    res = self._popen.wait(timeout)
          ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/popen_fork.py", line 40, in wait
    if not wait([self.sentinel], timeout):
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/connection.py", line 1136, in wait
    ready = selector.select(timeout)
            ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/selectors.py", line 415, in select
    fd_event_list = self._selector.poll(timeout)
    

   Processed 40/50 diagnoses...
   Processed 50/50 diagnoses...

✅ Evidence corpus built:
   Total passages: 1050
   Clinical knowledge: 50
   MIMIC prototypes: 1000
💾 Saved corpus to: /content/drive/MyDrive/ShifaMind/11_ShifaMind_v302/run_20260215_041200_phase3/evidence_store/evidence_corpus.json

🔍 BUILDING FAISS RETRIEVER

🤖 Initializing RAG with sentence-transformers/all-MiniLM-L6-v2...


Loading weights:   0%|          | 0/103 [00:00<?, ?it/s]

BertModel LOAD REPORT from: sentence-transformers/all-MiniLM-L6-v2
Key                     | Status     |  | 
------------------------+------------+--+-
embeddings.position_ids | UNEXPECTED |  | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.


✅ RAG encoder loaded on cuda

🔨 Building FAISS index from 1050 documents...
   Encoding documents...


Batches:   0%|          | 0/5 [00:00<?, ?it/s]

   Building GPU FAISS index...
   ⚠️  GPU FAISS not available (AttributeError)
   Falling back to CPU FAISS (still fast for 1050 docs!)
✅ FAISS index built:
   Dimension: 384
   Total vectors: 1050
   Device: CPU

🏗️  LOADING PHASE 2 MODEL ARCHITECTURE

📊 Loading Phase 2 graph data...
✅ Loaded graph: 161 nodes, 308 edges
✅ Graph data on device: cuda

🤖 Loading BioClinicalBERT...


Loading weights:   0%|          | 0/199 [00:00<?, ?it/s]

BertModel LOAD REPORT from: emilyalsentzer/Bio_ClinicalBERT
Key                                        | Status     |  | 
-------------------------------------------+------------+--+-
cls.predictions.transform.LayerNorm.weight | UNEXPECTED |  | 
cls.predictions.transform.dense.bias       | UNEXPECTED |  | 
cls.predictions.transform.LayerNorm.bias   | UNEXPECTED |  | 
cls.predictions.bias                       | UNEXPECTED |  | 
cls.seq_relationship.weight                | UNEXPECTED |  | 
cls.seq_relationship.bias                  | UNEXPECTED |  | 
cls.predictions.decoder.weight             | UNEXPECTED |  | 
cls.predictions.transform.dense.weight     | UNEXPECTED |  | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.



🔨 Building GAT encoder...

🏗️  Initializing Phase 2 model...

📥 Loading concept embeddings...
✅ Loaded trained concept embeddings from Phase 2: torch.Size([111, 768])
✅ Concept embeddings on device: cuda

📥 Loading Phase 2 model weights...
✅ Loaded Phase 2 model weights
✅ Phase 2 model moved to device: cuda

🏗️  Initializing Phase 3 model with RAG...

✅ ShifaMind Phase 3 model initialized
   Total parameters: 115,688,097
   Trainable parameters: 115,688,097

📦 PREPARING DATASETS

✅ Datasets ready:
   Train batches: 840
   Val batches:   90
   Test batches:  90

⚙️  TRAINING SETUP

✅ Training setup complete
   Optimizer: AdamW (lr=6e-05, weight_decay=0.01)
   Scheduler: Linear warmup (420 steps) + decay (4200 total)
   Mixed precision: True

🏋️  TRAINING PHASE 3 (RAG-ENHANCED)

📍 Epoch 1/5


Training:   0%|          | 0/840 [00:01<?, ?it/s]


📊 Training complete:
   Avg Loss: 0.5640
   Peak GPU Memory: 41.4 GB

🔍 Validating...


Validation:   0%|          | 0/90 [00:01<?, ?it/s]


📊 Validation Results:
   Loss:      0.6434
   F1:        0.3575
   Precision: 0.5959
   Recall:    0.2818
   ✅ Saved best model (F1: 0.3575)

📍 Epoch 2/5


Training:   0%|          | 0/840 [00:01<?, ?it/s]


📊 Training complete:
   Avg Loss: 0.5544
   Peak GPU Memory: 29.7 GB

🔍 Validating...


Validation:   0%|          | 0/90 [00:01<?, ?it/s]


📊 Validation Results:
   Loss:      0.6481
   F1:        0.3755
   Precision: 0.5809
   Recall:    0.3051
   ✅ Saved best model (F1: 0.3755)

📍 Epoch 3/5


Training:   0%|          | 0/840 [00:01<?, ?it/s]


📊 Training complete:
   Avg Loss: 0.5471
   Peak GPU Memory: 29.7 GB

🔍 Validating...


Validation:   0%|          | 0/90 [00:01<?, ?it/s]


📊 Validation Results:
   Loss:      0.6502
   F1:        0.3855
   Precision: 0.5752
   Recall:    0.3130
   ✅ Saved best model (F1: 0.3855)

📍 Epoch 4/5


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e1123cda3e0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e1123cda3e0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 16

Training:   0%|          | 0/840 [00:02<?, ?it/s]

Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e1123cda3e0><function _MultiProcessingDataLoaderIter.__del__ at 0x7e1123cda3e0>

Traceback (most recent call last):
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()    
self._shutdown_workers()  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
Exception ignored in: Exception ignored in: Exception ignored in: 
<function _MultiProcessingDataLoaderIter.__del__ at 0x7e1123cda3e0><function _MultiProcessingDataLoaderIter.__del__ at 0x7e1123cda3e0><function _MultiProcessingDataLoaderIter.__del__ at 0x7e1123cda3e0>    

  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1637, in _shut


📊 Training complete:
   Avg Loss: 0.5418
   Peak GPU Memory: 29.7 GB

🔍 Validating...


Validation:   0%|          | 0/90 [00:01<?, ?it/s]


📊 Validation Results:
   Loss:      0.6524
   F1:        0.3846
   Precision: 0.5730
   Recall:    0.3133

📍 Epoch 5/5


Training:   0%|          | 0/840 [00:01<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e1123cda3e0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e1123cda3e0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 16


📊 Training complete:
   Avg Loss: 0.5379
   Peak GPU Memory: 29.7 GB

🔍 Validating...


Validation:   0%|          | 0/90 [00:01<?, ?it/s]


📊 Validation Results:
   Loss:      0.6537
   F1:        0.3848
   Precision: 0.5740
   Recall:    0.3128

📊 FINAL EVALUATION ON TEST SET


UnpicklingError: Weights only load failed. This file can still be loaded, to do so you have two options, [1mdo those steps only if you trust the source of the checkpoint[0m. 
	(1) In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
	(2) Alternatively, to load with `weights_only=True` please check the recommended steps in the following error message.
	WeightsUnpickler error: Unsupported global: GLOBAL numpy._core.multiarray.scalar was not an allowed global by default. Please use `torch.serialization.add_safe_globals([numpy._core.multiarray.scalar])` or the `torch.serialization.safe_globals([numpy._core.multiarray.scalar])` context manager to allowlist this global if you trust this class/function.

Check the documentation of torch.load to learn more about types accepted by default with weights_only https://pytorch.org/docs/stable/generated/torch.load.html.

# Threshold

In [48]:
#!/usr/bin/env python3
"""
ShifaMind v302 - Phase 3 Threshold Tuning & Final Evaluation
Loads trained Phase 3 model, runs test evaluation, and tunes thresholds
"""

import os
import sys
import json
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
from pathlib import Path
from tqdm import tqdm
from datetime import datetime
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score,
    f1_score, hamming_loss, classification_report,
    multilabel_confusion_matrix, roc_auc_score
)
from transformers import AutoTokenizer, AutoModel
from torch.utils.data import Dataset, DataLoader
import torch_geometric.nn as gnn
from torch_geometric.data import Data
from sentence_transformers import SentenceTransformer
import faiss

# ============================================================================
# CONSTANTS
# ============================================================================

RAG_GATE_MAX = 0.4  # Cap RAG influence at 40%

# ============================================================================
# CONFIGURATION
# ============================================================================

print("\n" + "="*80)
print("🎯 SHIFAMIND v302 PHASE 3 - THRESHOLD TUNING & FINAL EVALUATION")
print("="*80)

# Paths - EXACT same logic as phase3_training_optimized.py
BASE_PATH = Path('/content/drive/MyDrive/ShifaMind')
SHIFAMIND_V302_BASE = BASE_PATH / '11_ShifaMind_v302'

# Find most recent Phase 3 run
phase3_runs = sorted([d for d in SHIFAMIND_V302_BASE.glob('run_*_phase3') if d.is_dir()])
if not phase3_runs:
    raise FileNotFoundError("❌ No Phase 3 runs found!")

PHASE3_RUN = phase3_runs[-1]
print(f"\n📁 Loading from Phase 3 run: {PHASE3_RUN.name}")

CHECKPOINT_PATH = PHASE3_RUN / 'phase_3_models'
RESULTS_PATH = PHASE3_RUN / 'phase_3_results'
EVIDENCE_PATH = PHASE3_RUN / 'evidence_store'

# Find Phase 2 run (most recent in v302 folder, excluding phase3)
phase2_runs = sorted([d for d in SHIFAMIND_V302_BASE.glob('run_*') if d.is_dir() and '_phase3' not in d.name])
if not phase2_runs:
    raise FileNotFoundError("❌ No Phase 2 runs found!")
PHASE2_RUN = phase2_runs[-1]
print(f"📁 Phase 2 run: {PHASE2_RUN.name}")

GRAPH_PATH = PHASE2_RUN / 'phase_2_graph'
PHASE2_CHECKPOINT = PHASE2_RUN / 'phase_2_models' / 'phase2_best.pt'

# Phase 1 shared data - look in 10_ShifaMind folder (NOT 11_ShifaMind_v302!)
PHASE1_BASE = BASE_PATH / '10_ShifaMind'
phase1_folders = sorted([d for d in PHASE1_BASE.glob('run_*') if d.is_dir()], reverse=True)
if not phase1_folders:
    raise FileNotFoundError("❌ No Phase 1 run found in 10_ShifaMind!")
PHASE1_RUN = phase1_folders[0]
print(f"📁 Phase 1 shared data: {PHASE1_RUN.name} (from 10_ShifaMind)")

SHARED_DATA_PATH = PHASE1_RUN / 'shared_data'

# Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"\n🖥️  Device: {device}")

if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"🔥 GPU: {gpu_name}")
    print(f"💾 VRAM: {gpu_memory:.1f} GB")

# Constants
NUM_CONCEPTS = 111
NUM_DIAGNOSES = 50
SEED = 42
BATCH_SIZE = 192  # Large batch for inference

# Set seeds
torch.manual_seed(SEED)
np.random.seed(SEED)

# ============================================================================
# LOAD SHARED DATA
# ============================================================================

print("\n" + "="*80)
print("📋 LOADING DATA")
print("="*80)

# Load splits
with open(SHARED_DATA_PATH / 'train_split.pkl', 'rb') as f:
    df_train = pickle.load(f)
with open(SHARED_DATA_PATH / 'val_split.pkl', 'rb') as f:
    df_val = pickle.load(f)
with open(SHARED_DATA_PATH / 'test_split.pkl', 'rb') as f:
    df_test = pickle.load(f)

# Load concept/diagnosis mappings - EXACT same as phase3_training_optimized.py
# Load concept list from Phase 1 run
with open(SHARED_DATA_PATH / 'concept_list.json', 'r') as f:
    ALL_CONCEPTS = json.load(f)

# Load Top-50 codes from original run (same as Phase 3 training)
ORIGINAL_RUN = BASE_PATH / '10_ShifaMind' / 'run_20260102_203225'
ORIGINAL_SHARED = ORIGINAL_RUN / 'shared_data'
with open(ORIGINAL_SHARED / 'top50_icd10_info.json', 'r') as f:
    top50_info = json.load(f)
    top_50_codes = top50_info['top_50_codes']

print(f"\n✅ Loaded data:")
print(f"   Train: {len(df_train):,} samples")
print(f"   Val:   {len(df_val):,} samples")
print(f"   Test:  {len(df_test):,} samples")
print(f"   Concepts: {len(ALL_CONCEPTS)}")
print(f"   Diagnoses: {len(top_50_codes)}")

# ============================================================================
# DATASET
# ============================================================================

class RAGDataset(Dataset):
    def __init__(self, df, tokenizer, concept_labels, top50_codes):
        self.df = df.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.concept_labels = concept_labels
        self.top50_codes = top50_codes

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        text = str(row['text'])

        # Multi-label format
        if isinstance(self.top50_codes, list):
            labels = row[self.top50_codes].values.astype(np.float32)
        else:
            labels = row['labels']

        # Tokenize
        encoding = self.tokenizer(
            text,
            max_length=512,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'labels': torch.FloatTensor(labels),
            'text': text
        }

# ============================================================================
# RAG RETRIEVER
# ============================================================================

class RAGRetriever:
    def __init__(self, corpus_path, model_name='sentence-transformers/all-MiniLM-L6-v2'):
        print(f"\n🤖 Initializing RAG with {model_name}...")
        self.encoder = SentenceTransformer(model_name)
        self.encoder = self.encoder.to(device)
        print(f"✅ RAG encoder loaded on {device}")

        # Load corpus
        with open(corpus_path, 'r') as f:
            self.corpus = json.load(f)

        print(f"📚 Loaded {len(self.corpus)} evidence passages")

        # Build FAISS index
        self._build_index()

    def _build_index(self):
        print("\n🔨 Building FAISS index...")
        texts = [item['text'] for item in self.corpus]
        embeddings = self.encoder.encode(texts, show_progress_bar=True, convert_to_numpy=True)

        # Try GPU FAISS first
        try:
            res = faiss.StandardGpuResources()
            index_flat = faiss.IndexFlatIP(embeddings.shape[1])
            self.index = faiss.index_cpu_to_gpu(res, 0, index_flat)
            self.index.add(embeddings.astype('float32'))
            print(f"✅ GPU FAISS index built: {embeddings.shape[1]} dims, {len(embeddings)} vectors")
        except (AttributeError, RuntimeError):
            # Fallback to CPU
            self.index = faiss.IndexFlatIP(embeddings.shape[1])
            self.index.add(embeddings.astype('float32'))
            print(f"✅ CPU FAISS index built: {embeddings.shape[1]} dims, {len(embeddings)} vectors")

    def retrieve(self, query_text, top_k=3):
        """Retrieve top-k relevant passages"""
        query_emb = self.encoder.encode([query_text], convert_to_numpy=True)
        scores, indices = self.index.search(query_emb.astype('float32'), top_k)

        results = []
        for score, idx in zip(scores[0], indices[0]):
            results.append({
                'text': self.corpus[idx]['text'],
                'score': float(score),
                'diagnosis': self.corpus[idx].get('diagnosis', 'N/A'),
                'source': self.corpus[idx].get('source', 'unknown')
            })

        return results

# ============================================================================
# MODEL ARCHITECTURES
# ============================================================================

class GATEncoder(nn.Module):
    """GAT encoder for learning concept embeddings from knowledge graph"""
    def __init__(self, in_channels, hidden_channels, num_layers=2, heads=4, dropout=0.3):
        super().__init__()

        self.num_layers = num_layers
        self.convs = nn.ModuleList()

        # First layer: in -> hidden
        self.convs.append(gnn.GATConv(
            in_channels,
            hidden_channels // heads,  # Output per head
            heads=heads,
            dropout=dropout,
            concat=True
        ))

        # Middle layers
        for _ in range(num_layers - 2):
            self.convs.append(gnn.GATConv(
                hidden_channels,
                hidden_channels // heads,
                heads=heads,
                dropout=dropout,
                concat=True
            ))

        # Last layer: hidden -> hidden (average heads)
        if num_layers > 1:
            self.convs.append(gnn.GATConv(
                hidden_channels,
                hidden_channels,
                heads=1,
                dropout=dropout,
                concat=False
            ))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, edge_index):
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            if i < self.num_layers - 1:
                x = F.elu(x)
                x = self.dropout(x)

        return x


class ShifaMindPhase2GAT(nn.Module):
    """Phase 2 model - EXACT architecture"""
    def __init__(self, bert_model, gat_encoder, graph_data, num_concepts, num_diagnoses):
        super().__init__()

        self.bert = bert_model
        self.gat = gat_encoder
        self.hidden_size = 768
        self.graph_hidden = 256
        self.num_concepts = num_concepts
        self.num_diagnoses = num_diagnoses

        # Store graph
        self.register_buffer('graph_x', graph_data.x)
        self.register_buffer('graph_edge_index', graph_data.edge_index)
        self.graph_node_to_idx = graph_data.node_to_idx
        self.graph_idx_to_node = graph_data.idx_to_node

        # Project graph embeddings to BERT dimension
        self.graph_proj = nn.Linear(self.graph_hidden, self.hidden_size)

        # Concept fusion
        self.concept_fusion = nn.Sequential(
            nn.Linear(self.hidden_size + self.hidden_size, self.hidden_size),
            nn.LayerNorm(self.hidden_size),
            nn.ReLU(),
            nn.Dropout(0.3)
        )

        # Cross-attention
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=self.hidden_size,
            num_heads=8,
            dropout=0.1,
            batch_first=True
        )

        # Multiplicative gating
        self.gate_net = nn.Sequential(
            nn.Linear(self.hidden_size * 2, self.hidden_size),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.Sigmoid()
        )

        self.layer_norm = nn.LayerNorm(self.hidden_size)

        # Output heads
        self.concept_head = nn.Linear(self.hidden_size, num_concepts)
        self.diagnosis_head = nn.Linear(self.hidden_size, num_diagnoses)

        self.dropout = nn.Dropout(0.1)

    def get_graph_concept_embeddings(self):
        """Run GAT and extract concept embeddings"""
        graph_embeddings = self.gat(self.graph_x, self.graph_edge_index)

        concept_embeds = []
        for concept in ALL_CONCEPTS:
            if concept in self.graph_node_to_idx:
                idx = self.graph_node_to_idx[concept]
                concept_embeds.append(graph_embeddings[idx])
            else:
                concept_embeds.append(torch.zeros(self.graph_hidden, device=self.graph_x.device))

        return torch.stack(concept_embeds)

    def forward(self, input_ids, attention_mask, concept_embeddings_bert):
        batch_size = input_ids.shape[0]

        # 1. Encode text with BERT
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        hidden_states = outputs.last_hidden_state

        # 2. Get GAT-enhanced concept embeddings
        gat_concepts = self.get_graph_concept_embeddings()
        gat_concepts = self.graph_proj(gat_concepts)  # Project to 768-dim

        # 3. Fuse BERT + GAT concept embeddings
        bert_concepts = concept_embeddings_bert.unsqueeze(0).expand(batch_size, -1, -1)
        gat_concepts_batched = gat_concepts.unsqueeze(0).expand(batch_size, -1, -1)

        fused_input = torch.cat([bert_concepts, gat_concepts_batched], dim=-1)
        enhanced_concepts = self.concept_fusion(fused_input)

        # 4. Cross-attention
        context, attn_weights = self.cross_attention(
            query=hidden_states,
            key=enhanced_concepts,
            value=enhanced_concepts,
            need_weights=True
        )

        # 5. Multiplicative bottleneck gating
        pooled_text = hidden_states.mean(dim=1)
        pooled_context = context.mean(dim=1)

        gate_input = torch.cat([pooled_text, pooled_context], dim=-1)
        gate = self.gate_net(gate_input)

        bottleneck_output = gate * pooled_context
        bottleneck_output = self.layer_norm(bottleneck_output)

        # 6. Output heads
        cls_hidden = self.dropout(pooled_text)
        concept_logits = self.concept_head(cls_hidden)
        concept_scores = torch.sigmoid(concept_logits)
        diagnosis_logits = self.diagnosis_head(bottleneck_output)

        return {
            'logits': diagnosis_logits,
            'concept_logits': concept_logits,
            'concept_scores': concept_scores,
            'gate_values': gate,
            'attention_weights': attn_weights,
            'bottleneck_output': bottleneck_output
        }


class ShifaMindPhase3RAG(nn.Module):
    """Phase 3: Phase 2 + RAG integration"""
    def __init__(self, phase2_model, rag_retriever, hidden_size=768):
        super().__init__()

        self.phase2_model = phase2_model
        self.rag = rag_retriever
        self.hidden_size = hidden_size

        # RAG components
        rag_dim = 384  # all-MiniLM-L6-v2 embedding size
        self.rag_projection = nn.Linear(rag_dim, hidden_size)

        self.rag_gate = nn.Sequential(
            nn.Linear(hidden_size * 2, hidden_size),
            nn.Sigmoid()
        )

    def forward(self, input_ids, attention_mask, concept_embeddings_bert, texts=None, input_texts=None, use_rag=True):
        """
        Phase 3 forward with RAG augmentation

        Args:
            input_ids: [batch, seq_len]
            attention_mask: [batch, seq_len]
            concept_embeddings_bert: [num_concepts, 768] - learned BERT concept embeddings from Phase 1
            texts: List of input texts for RAG retrieval (optional) - alias for input_texts
            input_texts: List of input texts for RAG retrieval (optional)
            use_rag: Whether to use RAG (default True)
        """
        # Support both parameter names
        if input_texts is None:
            input_texts = texts

        batch_size = input_ids.shape[0]

        # RAG retrieval and augmentation
        if use_rag and self.rag is not None and input_texts is not None:
            # Retrieve relevant evidence
            rag_texts = [self.rag.retrieve(text) for text in input_texts]

            # Encode RAG context
            rag_embeddings = []
            for rag_text in rag_texts:
                if rag_text:
                    emb = self.rag.encoder.encode([rag_text], convert_to_numpy=True)[0]
                else:
                    emb = np.zeros(384)
                rag_embeddings.append(emb)

            rag_embeddings = torch.tensor(np.array(rag_embeddings), dtype=torch.float32).to(input_ids.device)
            rag_context = self.rag_projection(rag_embeddings)  # [batch, 768]

            # Get pooled BERT for gating
            with torch.no_grad():
                bert_outputs = self.phase2_model.bert(input_ids=input_ids, attention_mask=attention_mask)
                pooled_bert = bert_outputs.last_hidden_state.mean(dim=1)

            # Gated fusion
            gate_input = torch.cat([pooled_bert, rag_context], dim=-1)
            gate = self.rag_gate(gate_input)
            gate = gate * RAG_GATE_MAX  # Cap at 40%

            # Augment concept embeddings with RAG context
            # Broadcast rag_context to match concept embeddings shape
            rag_aug = (gate * rag_context).mean(dim=0, keepdim=True)  # [1, 768]
            concept_embeddings_augmented = concept_embeddings_bert + rag_aug  # [num_concepts, 768]
        else:
            concept_embeddings_augmented = concept_embeddings_bert

        # Run Phase 2 model with augmented concept embeddings
        outputs = self.phase2_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            concept_embeddings_bert=concept_embeddings_augmented
        )

        return outputs

# ============================================================================
# LOAD MODEL
# ============================================================================

print("\n" + "="*80)
print("🏗️  LOADING PHASE 3 MODEL")
print("="*80)

# Load graph
print("\n📊 Loading Phase 2 graph data...")
graph_data = torch.load(GRAPH_PATH / 'graph_data.pt', map_location='cpu', weights_only=False)
graph_data = graph_data.to(device)
print(f"✅ Loaded graph: {graph_data.num_nodes} nodes, {graph_data.num_edges} edges")

# Load BERT
print("\n🤖 Loading BioClinicalBERT...")
tokenizer = AutoTokenizer.from_pretrained('emilyalsentzer/Bio_ClinicalBERT')
bert_model = AutoModel.from_pretrained('emilyalsentzer/Bio_ClinicalBERT')

# Build GAT
print("🔨 Building GAT encoder...")
gat_encoder = GATEncoder(
    in_channels=graph_data.num_node_features,
    hidden_channels=256,
    num_layers=2,
    heads=4,
    dropout=0.3
)

# Initialize Phase 2 model
print("🏗️  Initializing Phase 2 model...")
phase2_model = ShifaMindPhase2GAT(
    bert_model=bert_model,
    gat_encoder=gat_encoder,
    graph_data=graph_data,
    num_concepts=NUM_CONCEPTS,
    num_diagnoses=NUM_DIAGNOSES
)

# Load trained concept embeddings from Phase 2
print("\n📥 Loading Phase 2 checkpoint...")
phase2_checkpoint = torch.load(PHASE2_CHECKPOINT, map_location='cpu', weights_only=False)

if 'concept_embeddings' in phase2_checkpoint:
    concept_emb_tensor = phase2_checkpoint['concept_embeddings']
    print(f"✅ Loaded trained concept embeddings: {concept_emb_tensor.shape}")
    concept_embeddings_bert = nn.Embedding(NUM_CONCEPTS, 768)
    concept_embeddings_bert.weight = nn.Parameter(concept_emb_tensor)
else:
    raise KeyError("❌ No concept_embeddings in Phase 2 checkpoint!")

concept_embeddings_bert = concept_embeddings_bert.to(device)

# Load Phase 2 weights
phase2_model.load_state_dict(phase2_checkpoint['model_state_dict'], strict=True)
phase2_model = phase2_model.to(device)
print("✅ Phase 2 model loaded")

# Load RAG retriever
print("\n📚 Loading RAG retriever...")
rag_retriever = RAGRetriever(EVIDENCE_PATH / 'evidence_corpus.json')

# Initialize Phase 3 model
print("\n🏗️  Initializing Phase 3 model...")
model = ShifaMindPhase3RAG(phase2_model, rag_retriever, hidden_size=768)
model = model.to(device)

# Load Phase 3 checkpoint
print("\n📥 Loading Phase 3 best model...")
checkpoint = torch.load(CHECKPOINT_PATH / 'best_model.pth', map_location=device, weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

print(f"✅ Loaded best model from epoch {checkpoint['epoch']}")
print(f"   Best Val F1: {checkpoint['best_f1']:.4f}")

# ============================================================================
# PREPARE DATASETS
# ============================================================================

print("\n" + "="*80)
print("📦 PREPARING DATASETS")
print("="*80)

val_dataset = RAGDataset(df_val, tokenizer, ALL_CONCEPTS, top_50_codes)
test_dataset = RAGDataset(df_test, tokenizer, ALL_CONCEPTS, top_50_codes)

val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=True)

print(f"\n✅ Datasets ready:")
print(f"   Val batches:  {len(val_loader)}")
print(f"   Test batches: {len(test_loader)}")

# ============================================================================
# FINAL TEST EVALUATION (with default threshold 0.5)
# ============================================================================

print("\n" + "="*80)
print("📊 FINAL TEST EVALUATION (threshold=0.5)")
print("="*80)

all_preds = []
all_labels = []
all_probs = []

print("\n🔍 Evaluating on test set...")
with torch.no_grad():
    for batch in tqdm(test_loader, desc="Testing"):
        input_ids = batch['input_ids'].to(device, non_blocking=True)
        attention_mask = batch['attention_mask'].to(device, non_blocking=True)
        labels = batch['labels'].to(device, non_blocking=True)
        texts = batch['text']

        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            concept_embeddings_bert=concept_embeddings_bert.weight,
            texts=texts,
            use_rag=True
        )

        probs = torch.sigmoid(outputs['logits'])
        preds = (probs > 0.5).float()

        all_preds.append(preds.cpu().numpy())
        all_labels.append(labels.cpu().numpy())
        all_probs.append(probs.cpu().numpy())

all_preds = np.vstack(all_preds)
all_labels = np.vstack(all_labels)
all_probs = np.vstack(all_probs)

# Compute metrics
test_f1_micro = f1_score(all_labels, all_preds, average='micro', zero_division=0)
test_f1_macro = f1_score(all_labels, all_preds, average='macro', zero_division=0)
test_f1_weighted = f1_score(all_labels, all_preds, average='weighted', zero_division=0)
test_precision = precision_score(all_labels, all_preds, average='micro', zero_division=0)
test_recall = recall_score(all_labels, all_preds, average='micro', zero_division=0)
test_accuracy = accuracy_score(all_labels, all_preds)

print(f"\n📊 Test Results (threshold=0.5):")
print(f"   F1 (micro):    {test_f1_micro:.4f}")
print(f"   F1 (macro):    {test_f1_macro:.4f}")
print(f"   F1 (weighted): {test_f1_weighted:.4f}")
print(f"   Precision:     {test_precision:.4f}")
print(f"   Recall:        {test_recall:.4f}")
print(f"   Accuracy:      {test_accuracy:.4f}")

# ============================================================================
# THRESHOLD TUNING ON VALIDATION SET
# ============================================================================

print("\n" + "="*80)
print("🎯 THRESHOLD TUNING ON VALIDATION SET")
print("="*80)

# Get validation predictions
val_preds_all = []
val_labels_all = []
val_probs_all = []

print("\n🔍 Getting validation predictions...")
with torch.no_grad():
    for batch in tqdm(val_loader, desc="Validation"):
        input_ids = batch['input_ids'].to(device, non_blocking=True)
        attention_mask = batch['attention_mask'].to(device, non_blocking=True)
        labels = batch['labels'].to(device, non_blocking=True)
        texts = batch['text']

        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            concept_embeddings_bert=concept_embeddings_bert.weight,
            texts=texts,
            use_rag=True
        )

        probs = torch.sigmoid(outputs['logits'])

        val_probs_all.append(probs.cpu().numpy())
        val_labels_all.append(labels.cpu().numpy())

val_probs_all = np.vstack(val_probs_all)
val_labels_all = np.vstack(val_labels_all)

# Tune threshold
print("\n🔍 Searching for optimal threshold...")
thresholds = np.arange(0.1, 0.9, 0.05)
best_threshold = 0.5
best_f1 = 0.0

threshold_results = []

for threshold in tqdm(thresholds, desc="Tuning"):
    val_preds_thresh = (val_probs_all > threshold).astype(int)
    f1 = f1_score(val_labels_all, val_preds_thresh, average='micro', zero_division=0)

    threshold_results.append({
        'threshold': float(threshold),
        'f1_micro': float(f1)
    })

    if f1 > best_f1:
        best_f1 = f1
        best_threshold = threshold

print(f"\n✅ Optimal threshold found: {best_threshold:.2f}")
print(f"   Val F1 (micro): {best_f1:.4f}")

# Save threshold tuning results
with open(RESULTS_PATH / 'threshold_tuning.json', 'w') as f:
    json.dump({
        'best_threshold': float(best_threshold),
        'best_val_f1': float(best_f1),
        'all_thresholds': threshold_results
    }, f, indent=2)

print(f"\n💾 Saved threshold tuning results to: {RESULTS_PATH / 'threshold_tuning.json'}")

# ============================================================================
# FINAL TEST EVALUATION WITH TUNED THRESHOLD
# ============================================================================

print("\n" + "="*80)
print(f"📊 FINAL TEST EVALUATION (threshold={best_threshold:.2f})")
print("="*80)

# Apply tuned threshold
test_preds_tuned = (all_probs > best_threshold).astype(int)

# Compute metrics
tuned_f1_micro = f1_score(all_labels, test_preds_tuned, average='micro', zero_division=0)
tuned_f1_macro = f1_score(all_labels, test_preds_tuned, average='macro', zero_division=0)
tuned_f1_weighted = f1_score(all_labels, test_preds_tuned, average='weighted', zero_division=0)
tuned_precision = precision_score(all_labels, test_preds_tuned, average='micro', zero_division=0)
tuned_recall = recall_score(all_labels, test_preds_tuned, average='micro', zero_division=0)
tuned_accuracy = accuracy_score(all_labels, test_preds_tuned)

print(f"\n📊 Final Test Results (tuned threshold={best_threshold:.2f}):")
print(f"   F1 (micro):    {tuned_f1_micro:.4f}")
print(f"   F1 (macro):    {tuned_f1_macro:.4f}")
print(f"   F1 (weighted): {tuned_f1_weighted:.4f}")
print(f"   Precision:     {tuned_precision:.4f}")
print(f"   Recall:        {tuned_recall:.4f}")
print(f"   Accuracy:      {tuned_accuracy:.4f}")

print("\n📈 Improvement from threshold tuning:")
print(f"   ΔF1 (micro):    {tuned_f1_micro - test_f1_micro:+.4f}")
print(f"   ΔF1 (macro):    {tuned_f1_macro - test_f1_macro:+.4f}")
print(f"   ΔPrecision:     {tuned_precision - test_precision:+.4f}")
print(f"   ΔRecall:        {tuned_recall - test_recall:+.4f}")

# ============================================================================
# SAVE FINAL RESULTS
# ============================================================================

print("\n" + "="*80)
print("💾 SAVING RESULTS")
print("="*80)

final_results = {
    'model': 'ShifaMind v302 Phase 3 (RAG)',
    'timestamp': datetime.now().isoformat(),
    'test_samples': len(df_test),

    'results_default_threshold': {
        'threshold': 0.5,
        'f1_micro': float(test_f1_micro),
        'f1_macro': float(test_f1_macro),
        'f1_weighted': float(test_f1_weighted),
        'precision': float(test_precision),
        'recall': float(test_recall),
        'accuracy': float(test_accuracy)
    },

    'threshold_tuning': {
        'best_threshold': float(best_threshold),
        'val_f1': float(best_f1)
    },

    'results_tuned_threshold': {
        'threshold': float(best_threshold),
        'f1_micro': float(tuned_f1_micro),
        'f1_macro': float(tuned_f1_macro),
        'f1_weighted': float(tuned_f1_weighted),
        'precision': float(tuned_precision),
        'recall': float(tuned_recall),
        'accuracy': float(tuned_accuracy)
    },

    'improvement': {
        'delta_f1_micro': float(tuned_f1_micro - test_f1_micro),
        'delta_f1_macro': float(tuned_f1_macro - test_f1_macro),
        'delta_precision': float(tuned_precision - test_precision),
        'delta_recall': float(tuned_recall - test_recall)
    }
}

# Save results
with open(RESULTS_PATH / 'final_test_results.json', 'w') as f:
    json.dump(final_results, f, indent=2)

print(f"✅ Saved final results to: {RESULTS_PATH / 'final_test_results.json'}")

# Per-diagnosis results
per_diagnosis_metrics = []
for i, dx_code in enumerate(top_50_codes):
    dx_f1 = f1_score(all_labels[:, i], test_preds_tuned[:, i], zero_division=0)
    dx_precision = precision_score(all_labels[:, i], test_preds_tuned[:, i], zero_division=0)
    dx_recall = recall_score(all_labels[:, i], test_preds_tuned[:, i], zero_division=0)

    per_diagnosis_metrics.append({
        'diagnosis': dx_code,
        'f1': float(dx_f1),
        'precision': float(dx_precision),
        'recall': float(dx_recall),
        'support': int(all_labels[:, i].sum())
    })

with open(RESULTS_PATH / 'per_diagnosis_metrics.json', 'w') as f:
    json.dump(per_diagnosis_metrics, f, indent=2)

print(f"✅ Saved per-diagnosis metrics to: {RESULTS_PATH / 'per_diagnosis_metrics.json'}")

# Save predictions
np.save(RESULTS_PATH / 'test_predictions.npy', test_preds_tuned)
np.save(RESULTS_PATH / 'test_probabilities.npy', all_probs)
np.save(RESULTS_PATH / 'test_labels.npy', all_labels)

print(f"✅ Saved predictions to: {RESULTS_PATH}")

print("\n" + "="*80)
print("✅ PHASE 3 EVALUATION COMPLETE!")
print("="*80)
print(f"\n🎯 Best threshold: {best_threshold:.2f}")
print(f"📊 Final F1 (micro): {tuned_f1_micro:.4f}")
print(f"📊 Final F1 (macro): {tuned_f1_macro:.4f}")
print(f"🎉 All results saved to: {RESULTS_PATH}")
print("\n" + "="*80)


🎯 SHIFAMIND v302 PHASE 3 - THRESHOLD TUNING & FINAL EVALUATION

📁 Loading from Phase 3 run: run_20260215_041200_phase3
📁 Phase 2 run: run_20260215_022518
📁 Phase 1 shared data: run_20260215_013437 (from 10_ShifaMind)

🖥️  Device: cuda
🔥 GPU: NVIDIA RTX PRO 6000 Blackwell Server Edition
💾 VRAM: 102.0 GB

📋 LOADING DATA

✅ Loaded data:
   Train: 80,572 samples
   Val:   17,265 samples
   Test:  17,266 samples
   Concepts: 111
   Diagnoses: 50

🏗️  LOADING PHASE 3 MODEL

📊 Loading Phase 2 graph data...
✅ Loaded graph: 161 nodes, 308 edges

🤖 Loading BioClinicalBERT...


Loading weights:   0%|          | 0/199 [00:00<?, ?it/s]

BertModel LOAD REPORT from: emilyalsentzer/Bio_ClinicalBERT
Key                                        | Status     |  | 
-------------------------------------------+------------+--+-
cls.predictions.transform.LayerNorm.weight | UNEXPECTED |  | 
cls.predictions.transform.dense.bias       | UNEXPECTED |  | 
cls.predictions.transform.LayerNorm.bias   | UNEXPECTED |  | 
cls.predictions.bias                       | UNEXPECTED |  | 
cls.seq_relationship.weight                | UNEXPECTED |  | 
cls.seq_relationship.bias                  | UNEXPECTED |  | 
cls.predictions.decoder.weight             | UNEXPECTED |  | 
cls.predictions.transform.dense.weight     | UNEXPECTED |  | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.


🔨 Building GAT encoder...
🏗️  Initializing Phase 2 model...

📥 Loading Phase 2 checkpoint...
✅ Loaded trained concept embeddings: torch.Size([111, 768])
✅ Phase 2 model loaded

📚 Loading RAG retriever...

🤖 Initializing RAG with sentence-transformers/all-MiniLM-L6-v2...


Loading weights:   0%|          | 0/103 [00:00<?, ?it/s]

BertModel LOAD REPORT from: sentence-transformers/all-MiniLM-L6-v2
Key                     | Status     |  | 
------------------------+------------+--+-
embeddings.position_ids | UNEXPECTED |  | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.


✅ RAG encoder loaded on cuda
📚 Loaded 1050 evidence passages

🔨 Building FAISS index...


Batches:   0%|          | 0/33 [00:00<?, ?it/s]

✅ CPU FAISS index built: 384 dims, 1050 vectors

🏗️  Initializing Phase 3 model...

📥 Loading Phase 3 best model...
✅ Loaded best model from epoch 2
   Best Val F1: 0.3855

📦 PREPARING DATASETS

✅ Datasets ready:
   Val batches:  90
   Test batches: 90

📊 FINAL TEST EVALUATION (threshold=0.5)

🔍 Evaluating on test set...


Testing: 100%|██████████| 90/90 [04:46<00:00,  3.18s/it]



📊 Test Results (threshold=0.5):
   F1 (micro):    0.4868
   F1 (macro):    0.3850
   F1 (weighted): 0.4514
   Precision:     0.6753
   Recall:        0.3806
   Accuracy:      0.0434

🎯 THRESHOLD TUNING ON VALIDATION SET

🔍 Getting validation predictions...


Validation: 100%|██████████| 90/90 [04:46<00:00,  3.18s/it]



🔍 Searching for optimal threshold...


Tuning: 100%|██████████| 16/16 [00:00<00:00, 39.00it/s]



✅ Optimal threshold found: 0.30
   Val F1 (micro): 0.5244

💾 Saved threshold tuning results to: /content/drive/MyDrive/ShifaMind/11_ShifaMind_v302/run_20260215_041200_phase3/phase_3_results/threshold_tuning.json

📊 FINAL TEST EVALUATION (threshold=0.30)

📊 Final Test Results (tuned threshold=0.30):
   F1 (micro):    0.5281
   F1 (macro):    0.4411
   F1 (weighted): 0.5039
   Precision:     0.5384
   Recall:        0.5183
   Accuracy:      0.0359

📈 Improvement from threshold tuning:
   ΔF1 (micro):    +0.0413
   ΔF1 (macro):    +0.0561
   ΔPrecision:     -0.1369
   ΔRecall:        +0.1377

💾 SAVING RESULTS
✅ Saved final results to: /content/drive/MyDrive/ShifaMind/11_ShifaMind_v302/run_20260215_041200_phase3/phase_3_results/final_test_results.json
✅ Saved per-diagnosis metrics to: /content/drive/MyDrive/ShifaMind/11_ShifaMind_v302/run_20260215_041200_phase3/phase_3_results/per_diagnosis_metrics.json
✅ Saved predictions to: /content/drive/MyDrive/ShifaMind/11_ShifaMind_v302/run_2026021