# XR2Text: Model Training with HAQT-ARR

This notebook implements the complete training pipeline for the XR2Text model featuring our novel **HAQT-ARR (Hierarchical Anatomical Query Tokens with Adaptive Region Routing)** projection layer.

## Novel Contribution: HAQT-ARR

Our key innovation is the HAQT-ARR projection layer that bridges vision and language with anatomical awareness:

1. **Hierarchical Anatomical Query Tokens**: Region-specific learnable queries for 7 anatomical regions
2. **Spatial Prior Injection**: Learnable 2D Gaussian priors for anatomical locations
3. **Adaptive Region Routing**: Dynamic weighting of anatomical region importance
4. **Cross-Region Interaction**: Transformer layers modeling inter-region dependencies

## Architecture
```
Input Image (384×384) → Swin Transformer → HAQT-ARR Projection → BioBART Decoder → Report
```

**Authors**: S. Nikhil, Dadhania Omkumar  
**Supervisor**: Dr. Damodar Panigrahy

In [1]:
# ============================================
# GPU/CUDA Check - Run this first!
# ============================================
import os
import sys
sys.path.insert(0, '..')

import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
from tqdm.notebook import tqdm
import warnings
warnings.filterwarnings('ignore')

# Plot defaults applied to every figure created in this notebook.
plt.style.use('seaborn-v0_8-whitegrid')
plt.rcParams.update({'figure.figsize': (12, 8), 'savefig.dpi': 300})

# Report the accelerator PyTorch can see and choose `device` to match.
banner = "=" * 50
print(banner, "SYSTEM CONFIGURATION", banner, sep="\n")

cuda_ready = torch.cuda.is_available()
if cuda_ready:
    gpu_name = torch.cuda.get_device_name(0)
    # total_memory is divided by 1024**3 and reported with a "GB" suffix.
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
    print(f"CUDA Available: {cuda_ready}")
    print(f"GPU Connected: {gpu_name}")
    print(f"GPU Memory: {gpu_memory:.1f} GB")
    print(f"CUDA Version: {torch.version.cuda}")
else:
    print(f"CUDA Available: {cuda_ready}")
    print("WARNING: Running on CPU (Training will be slow)")
print(f"PyTorch Version: {torch.__version__}")
device = torch.device("cuda" if cuda_ready else "cpu")

print(f"\nUsing Device: {device}", banner, sep="\n")

  import pynvml  # type: ignore[import]


SYSTEM CONFIGURATION
CUDA Available: True
GPU Connected: NVIDIA GeForce RTX 4060 Laptop GPU
GPU Memory: 8.0 GB
CUDA Version: 12.1
PyTorch Version: 2.5.1+cu121

Using Device: cuda


## 1. Configuration

In [None]:
# Training configuration for HAQT-ARR plus the novel training features.
# Tuned for speed, with validation scores every 2 epochs and every novel
# contribution switched on. Key order matches the section order below.
config = dict(
    # --- Model ---
    image_size=384,
    encoder_name='base',  # Swin-Base
    decoder_name='biobart',
    use_anatomical_attention=True,  # turns on HAQT-ARR (novel)

    # --- HAQT-ARR projection parameters (novel) ---
    num_regions=7,
    num_global_queries=8,
    num_region_queries=4,
    use_spatial_priors=True,
    use_adaptive_routing=True,
    use_cross_region=True,

    # --- Standard parameters ---
    language_dim=768,

    # --- Training: speed + OOM prevention ---
    epochs=50,
    batch_size=2,                    # kept small to prevent OOM
    gradient_accumulation_steps=16,  # 2 * 16 = effective batch of 32
    learning_rate=1e-4,
    weight_decay=0.01,
    warmup_steps=1000,
    max_grad_norm=1.0,

    # --- Label smoothing (for better BLEU) ---
    label_smoothing=0.05,

    # --- Novel loss functions (enabled) ---
    use_novel_losses=True,
    use_anatomical_consistency_loss=True,
    use_clinical_entity_loss=True,
    use_region_focal_loss=True,
    use_cross_modal_loss=False,
    anatomical_loss_weight=0.05,
    clinical_loss_weight=0.1,
    focal_loss_weight=0.1,
    alignment_loss_weight=0.1,

    # --- Curriculum learning (enabled) ---
    use_curriculum_learning=True,

    # --- Clinical validation (enabled) ---
    use_clinical_validation=True,

    # --- Scheduled sampling ---
    use_scheduled_sampling=True,
    scheduled_sampling_start=1.0,
    scheduled_sampling_end=0.6,
    scheduled_sampling_warmup=10,

    # --- Region regularization ---
    use_region_regularization=True,
    region_regularization_weight=0.005,

    # --- Data ---
    max_length=256,
    num_workers=4,

    # --- Device ---
    use_amp=True,
    device='cuda' if torch.cuda.is_available() else 'cpu',

    # --- Experiment bookkeeping ---
    # The name is timestamped at the moment this cell runs.
    experiment_name='xr2text_haqt_arr_novel_' + datetime.now().strftime("%Y%m%d_%H%M%S"),
    checkpoint_dir='../checkpoints',
    validate_every=2,                # scores every 2 epochs
    save_every=5,
    patience=999,
    log_dir='../logs',

    # --- Validation: fast (uses less data) ---
    val_fraction=0.15,               # only 15% of the validation data

    # --- Generation parameters: tuned for speed ---
    generation=dict(
        num_beams=2,                 # reduced from 3
        min_length=10,               # reduced
        max_length=150,              # reduced from 200
        length_penalty=1.0,
        repetition_penalty=1.1,
        no_repeat_ngram_size=3,
        early_stopping=True,
    ),
)

# Create the output directories that later cells write into.
for out_dir in (config['checkpoint_dir'], config['log_dir'], '../data/figures'):
    os.makedirs(out_dir, exist_ok=True)

# Print the summary from `config` itself rather than from hard-coded numbers,
# so the banner cannot drift away from the values that are actually used.
gen_cfg = config['generation']
rule = "=" * 60

print(f"FAST Training Config - Scores Every {config['validate_every']} Epochs:")
print(rule)

# Only features whose config flag is on are listed under this heading.
print("\nNOVEL CONTRIBUTIONS (ENABLED):")
for flag, label in (
    ('use_anatomical_attention', 'HAQT-ARR Projection Layer'),
    ('use_novel_losses', 'Novel Loss Functions'),
    ('use_clinical_validation', 'Clinical Validation'),
    ('use_curriculum_learning', 'Curriculum Learning'),
):
    if config.get(flag, False):
        print(f"  ‚úì {label}")

print("\nSPEED OPTIMIZATIONS:")
print(f"  - validate_every: {config['validate_every']} "
      f"(see BLEU/ROUGE every {config['validate_every']} epochs)")
print(f"  - val_fraction: {config['val_fraction']:.0%} (fast validation)")
print(f"  - num_beams: {gen_cfg['num_beams']} (fast generation)")
print(f"  - max_length: {gen_cfg['max_length']} (shorter generation)")
print(f"  - num_workers: {config['num_workers']}")

print("\nOOM PREVENTION:")
print(f"  - batch_size: {config['batch_size']}")
print(f"  - gradient_accumulation_steps: {config['gradient_accumulation_steps']}")
print(rule)

## 2. Load Model and Data

In [3]:
from src.models.xr2text import XR2TextModel, DEFAULT_CONFIG
from src.models.anatomical_attention import ANATOMICAL_REGIONS
from src.data.dataloader import get_dataloaders
from src.utils.device import setup_cuda_optimizations

# Setup CUDA optimizations for RTX 4060
setup_cuda_optimizations()

# Create model with HAQT-ARR (Novel Architecture)
print("Creating XR2Text model with HAQT-ARR projection layer...")

# Side length of the encoder feature map. The value used to be the literal 12,
# documented as 384/32; deriving it keeps it correct if `image_size` changes.
# NOTE(review): assumes the encoder's overall downsampling factor is 32 for
# every supported `encoder_name` - confirm against the Swin encoder wrapper.
ENCODER_DOWNSAMPLE = 32
feature_size = config['image_size'] // ENCODER_DOWNSAMPLE

model_config = {
    'image_size': config['image_size'],
    'use_anatomical_attention': config['use_anatomical_attention'],  # Enable HAQT-ARR
    'encoder': {
        'model_name': config['encoder_name'],
        'pretrained': True,
        'freeze_layers': 2,  # Freeze first 2 Swin layers
    },
    'projection': {
        # HAQT-ARR parameters (Novel)
        'language_dim': config['language_dim'],
        'num_regions': config['num_regions'],
        'num_global_queries': config['num_global_queries'],
        'num_region_queries': config['num_region_queries'],
        'use_spatial_priors': config['use_spatial_priors'],
        'use_adaptive_routing': config['use_adaptive_routing'],
        'use_cross_region': config['use_cross_region'],
        'feature_size': feature_size,  # 384 // 32 = 12 for the default config
    },
    'decoder': {
        'model_name': config['decoder_name'],
        'max_length': config['max_length'],
    }
}

model = XR2TextModel.from_config(model_config)
model = model.to(config['device'])

# Model summary: parameter counts split by whether they receive gradients.
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

# Global queries plus one group of region queries per anatomical region.
total_queries = (
    config['num_global_queries']
    + config['num_regions'] * config['num_region_queries']
)

print(f"\n{'='*50}")
print("XR2Text Model with HAQT-ARR (Novel)")
print(f"{'='*50}")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Frozen parameters: {total_params - trainable_params:,}")
print(f"\nAnatomical regions: {model.get_anatomical_regions()}")
print(f"Total queries: {total_queries}")

[32m2026-01-07 12:32:18.069[0m | [1mINFO    [0m | [36msrc.utils.device[0m:[36msetup_cuda_optimizations[0m:[36m86[0m - [1mEnabled cuDNN benchmark mode[0m
[32m2026-01-07 12:32:18.072[0m | [1mINFO    [0m | [36msrc.utils.device[0m:[36msetup_cuda_optimizations[0m:[36m92[0m - [1mEnabled TF32 for matrix operations[0m
[32m2026-01-07 12:32:18.072[0m | [1mINFO    [0m | [36msrc.utils.device[0m:[36msetup_cuda_optimizations[0m:[36m96[0m - [1mCleared CUDA cache[0m
[32m2026-01-07 12:32:18.076[0m | [1mINFO    [0m | [36msrc.models.xr2text[0m:[36m__init__[0m:[36m92[0m - [1mBuilding Swin Transformer Encoder...[0m
[32m2026-01-07 12:32:18.077[0m | [1mINFO    [0m | [36msrc.models.swin_encoder[0m:[36m__init__[0m:[36m81[0m - [1mInitializing Swin Encoder: swin_base_patch4_window7_224[0m
[32m2026-01-07 12:32:18.080[0m | [1mINFO    [0m | [36msrc.models.swin_encoder[0m:[36m__init__[0m:[36m82[0m - [1mPretrained: True, Image Size: 384[0m


Creating XR2Text model with HAQT-ARR projection layer...


[32m2026-01-07 12:32:21.478[0m | [1mINFO    [0m | [36msrc.models.swin_encoder[0m:[36m__init__[0m:[36m96[0m - [1mSwin feature dimension: 1024[0m
[32m2026-01-07 12:32:21.483[0m | [1mINFO    [0m | [36msrc.models.swin_encoder[0m:[36m_freeze_layers[0m:[36m136[0m - [1mFrozen 404,424 parameters in 2 layers[0m
[32m2026-01-07 12:32:21.485[0m | [1mINFO    [0m | [36msrc.models.swin_encoder[0m:[36m__init__[0m:[36m117[0m - [1mSwin Encoder initialized successfully[0m
[32m2026-01-07 12:32:21.487[0m | [1mINFO    [0m | [36msrc.models.xr2text[0m:[36m__init__[0m:[36m110[0m - [1mBuilding HAQT-ARR (Hierarchical Anatomical) Projection Layer...[0m
[32m2026-01-07 12:32:21.489[0m | [1mINFO    [0m | [36msrc.models.anatomical_attention[0m:[36m__init__[0m:[36m601[0m - [1mInitializing HAQT-ARR Projection Layer[0m
[32m2026-01-07 12:32:21.490[0m | [1mINFO    [0m | [36msrc.models.anatomical_attention[0m:[36m__init__[0m:[36m602[0m - [1m  Visual dim


XR2Text Model with HAQT-ARR (Novel)
Total parameters: 251,441,388
Trainable parameters: 251,036,964
Frozen parameters: 404,424

Anatomical regions: ['right_lung', 'left_lung', 'heart', 'mediastinum', 'spine', 'diaphragm', 'costophrenic_angles']
Total queries: 36


In [4]:
# Build the train / validation / test dataloaders using the model's tokenizer.
print("\nLoading datasets...")
tokenizer = model.get_tokenizer()

loader_kwargs = {
    'tokenizer': tokenizer,
    'batch_size': config['batch_size'],
    'num_workers': config['num_workers'],
    'image_size': config['image_size'],
    'max_length': config['max_length'],
    'train_subset': None,  # Use full dataset, or set to e.g., 1000 for testing
}
train_loader, val_loader, test_loader = get_dataloaders(**loader_kwargs)

# One line per split with its number of batches.
for split_label, split_loader in (
    ("Train", train_loader),
    ("Val", val_loader),
    ("Test", test_loader),
):
    print(f"{split_label} batches: {len(split_loader)}")

[32m2026-01-07 12:32:25.725[0m | [1mINFO    [0m | [36msrc.data.dataloader[0m:[36mget_dataloaders[0m:[36m47[0m - [1mCreating dataloaders...[0m
[32m2026-01-07 12:32:25.750[0m | [1mINFO    [0m | [36msrc.data.dataset[0m:[36m__init__[0m:[36m56[0m - [1mLoading MIMIC-CXR dataset (split: train)...[0m



Loading datasets...


[32m2026-01-07 12:32:29.593[0m | [1mINFO    [0m | [36msrc.data.dataset[0m:[36m__init__[0m:[36m76[0m - [1mLoaded 30633 samples[0m
[32m2026-01-07 12:32:29.597[0m | [1mINFO    [0m | [36msrc.data.dataset[0m:[36m__init__[0m:[36m56[0m - [1mLoading MIMIC-CXR dataset (split: validation)...[0m
[32m2026-01-07 12:32:31.272[0m | [1mINFO    [0m | [36msrc.data.dataset[0m:[36m__init__[0m:[36m76[0m - [1mLoaded 3063 samples[0m
[32m2026-01-07 12:32:31.275[0m | [1mINFO    [0m | [36msrc.data.dataset[0m:[36m__init__[0m:[36m56[0m - [1mLoading MIMIC-CXR dataset (split: test)...[0m
[32m2026-01-07 12:32:32.888[0m | [1mINFO    [0m | [36msrc.data.dataset[0m:[36m__init__[0m:[36m76[0m - [1mLoaded 3064 samples[0m
[32m2026-01-07 12:32:32.901[0m | [1mINFO    [0m | [36msrc.data.dataloader[0m:[36mget_dataloaders[0m:[36m117[0m - [1mTrain samples: 30633[0m
[32m2026-01-07 12:32:32.903[0m | [1mINFO    [0m | [36msrc.data.dataloader[0m:[36mget_dat

Train batches: 7658
Val batches: 766
Test batches: 766


## 3. Training Setup

In [5]:
from torch.optim import AdamW
from torch.cuda.amp import GradScaler, autocast
from src.training.scheduler import get_cosine_schedule_with_warmup
from src.utils.metrics import compute_metrics

# NOVEL: Import novel training components
from src.training.losses import CombinedNovelLoss
from src.training.curriculum import AnatomicalCurriculumScheduler, create_curriculum_dataloader
from src.utils.clinical_validator import ClinicalValidator

# Optimizer: trainable parameters whose name contains one of these markers get
# no weight decay; every other trainable parameter uses config['weight_decay'].
# NOTE(review): the XR2TextTrainer call later in this notebook receives only the
# model, the loaders and `config`, so the objects built here look unused by it
# - confirm before relying on them.
no_decay = ['bias', 'LayerNorm.weight', 'layer_norm.weight']
decay_params, no_decay_params = [], []
for param_name, param in model.named_parameters():
    if not param.requires_grad:
        continue
    is_exempt = any(marker in param_name for marker in no_decay)
    (no_decay_params if is_exempt else decay_params).append(param)

optimizer_grouped_parameters = [
    {'params': decay_params, 'weight_decay': config['weight_decay']},
    {'params': no_decay_params, 'weight_decay': 0.0},
]

optimizer = AdamW(optimizer_grouped_parameters, lr=config['learning_rate'])

# Scheduler: batches over all epochs, floor-divided by the accumulation factor.
batches_over_training = len(train_loader) * config['epochs']
total_steps = batches_over_training // config['gradient_accumulation_steps']
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=config['warmup_steps'],
    num_training_steps=total_steps,
)

# Mixed precision scaler (only when AMP is requested).
if config['use_amp']:
    scaler = GradScaler()
else:
    scaler = None

# NOVEL: combined loss; each component is toggled and weighted from config.
novel_loss = None
if config.get('use_novel_losses', False):
    novel_loss = CombinedNovelLoss(
        use_anatomical_consistency=config.get('use_anatomical_consistency_loss', True),
        use_clinical_entity=config.get('use_clinical_entity_loss', True),
        use_region_focal=config.get('use_region_focal_loss', True),
        use_cross_modal=config.get('use_cross_modal_loss', False),
        anatomical_weight=config.get('anatomical_loss_weight', 0.1),
        clinical_weight=config.get('clinical_loss_weight', 0.2),
        focal_weight=config.get('focal_loss_weight', 0.15),
        alignment_weight=config.get('alignment_loss_weight', 0.1),
    )
    print("‚úÖ Novel loss functions initialized")

# NOVEL: curriculum learning scheduler.
curriculum_scheduler = None
if config.get('use_curriculum_learning', False):
    curriculum_scheduler = AnatomicalCurriculumScheduler()
    print("‚úÖ Curriculum learning scheduler initialized")

# NOVEL: clinical validator.
clinical_validator = None
if config.get('use_clinical_validation', False):
    clinical_validator = ClinicalValidator()
    print("‚úÖ Clinical validator initialized")

# Closing summary of the step budget and of which optional features are on.
print(f"\nTotal optimization steps: {total_steps}")
print(f"Warmup steps: {config['warmup_steps']}")
for feature_label, flag_name in (
    ("Novel losses", 'use_novel_losses'),
    ("Curriculum learning", 'use_curriculum_learning'),
    ("Clinical validation", 'use_clinical_validation'),
):
    print(f"{feature_label}: {config.get(flag_name, False)}")

‚úÖ Novel loss functions initialized
‚úÖ Curriculum learning scheduler initialized
‚úÖ Clinical validator initialized

Total optimization steps: 47862
Warmup steps: 500
Novel losses: True
Curriculum learning: True
Clinical validation: True


## 4. Training Loop

In [6]:
# Metric history: one independent empty list per tracked quantity.
history = {
    metric_name: []
    for metric_name in (
        'train_loss',
        'val_loss',
        'bleu_1',
        'bleu_2',
        'bleu_3',
        'bleu_4',
        'rouge_1',
        'rouge_2',
        'rouge_l',
        'learning_rate',
    )
}

# Best-score and early-stopping bookkeeping.
# NOTE(review): `patience` here is 5 while config['patience'] is 999, and the
# trainer call later in the notebook is given `config`, not these variables -
# confirm which value is meant to apply.
best_metric, patience_counter, patience = 0.0, 0, 5

In [None]:
# Main training loop - Using XR2TextTrainer class
# OPTIMIZED FOR BETTER BLEU-4 AND ROUGE METRICS
from src.training.trainer import XR2TextTrainer
import torch
import gc

# ============================================
# TRAINING OPTIONS - SET THESE
# ============================================
RESUME_FROM_CHECKPOINT = True                          # Set to True to resume training
CHECKPOINT_PATH = "../checkpoints/checkpoint_epoch_10.pt"  # Checkpoint to resume from

print("=" * 70)
print("XR2Text Training with OPTIMIZED Configuration")
print("=" * 70)

# Fail fast, before any GPU work, when a resume is requested but the checkpoint
# file is not there (e.g. on a fresh clone where nothing has been trained yet).
if RESUME_FROM_CHECKPOINT and not os.path.isfile(CHECKPOINT_PATH):
    raise FileNotFoundError(
        f"Checkpoint not found: {CHECKPOINT_PATH}. "
        "Set RESUME_FROM_CHECKPOINT = False to train from scratch, "
        "or point CHECKPOINT_PATH at an existing checkpoint."
    )

# Memory cleanup before training
print("\nClearing GPU memory...")
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    print(f"GPU Memory - Allocated: {torch.cuda.memory_allocated()/1024**3:.2f} GB")
    print(f"GPU Memory - Cached: {torch.cuda.memory_reserved()/1024**3:.2f} GB")

# Create trainer with optimized config
trainer = XR2TextTrainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    config=config,
)

# Resume from checkpoint if specified
if RESUME_FROM_CHECKPOINT:
    print(f"\nLoading checkpoint: {CHECKPOINT_PATH}")
    trainer.load_checkpoint(CHECKPOINT_PATH)
    print(f"Resuming from epoch {trainer.current_epoch}")
else:
    print("\nStarting fresh training from epoch 1")

print("\n" + "=" * 70)
print("TRAINING CONFIGURATION SUMMARY:")
print(f"  Learning rate: {config['learning_rate']}")
print(f"  Label smoothing: {config.get('label_smoothing', 0.1)}")
print(f"  Validate every: {config.get('validate_every', 2)} epochs")
print(f"  Generation beams: {config.get('generation', {}).get('num_beams', 5)}")
print(f"  Min generation length: {config.get('generation', {}).get('min_length', 20)}")
print("=" * 70 + "\n")

# Run training
final_metrics = trainer.train()

# Extract history from trainer for visualization
tracked_metrics = [
    'train_loss', 'val_loss',
    'bleu_1', 'bleu_2', 'bleu_3', 'bleu_4',
    'rouge_1', 'rouge_2', 'rouge_l',
]
history = {
    name: trainer.metrics_tracker.get_history(name) for name in tracked_metrics
}
# NOTE: this is the scheduler's *last* learning rate repeated, not a recorded
# per-epoch schedule, so it plots as a flat line.
history['learning_rate'] = (
    [trainer.scheduler.get_last_lr()[0]] * (trainer.current_epoch + 1)
)

# Add clinical validation metrics if enabled
if config.get('use_clinical_validation', False):
    for name in ('clinical_accuracy', 'clinical_f1', 'critical_errors'):
        history[name] = trainer.metrics_tracker.get_history(name)

# Save training history.
# Each column is wrapped in a Series so that columns of different length are
# padded with NaN instead of raising ValueError: the learning-rate column has
# its own length, and a failed save here would lose the whole run's history.
# With equal-length columns the frame is the same as pd.DataFrame(history).
history_df = pd.DataFrame(
    {name: pd.Series(values) for name, values in history.items()}
)
history_df['epoch'] = range(1, len(history_df) + 1)
os.makedirs('../data/statistics', exist_ok=True)
history_df.to_csv('../data/statistics/training_history.csv', index=False)

# Store predictions and references for sample display
predictions = []
references = []

print("\n" + "=" * 70)
print("TRAINING COMPLETE!")
print("=" * 70)
print(f"\nFinal Metrics:")
for key, value in final_metrics.items():
    if isinstance(value, float):
        print(f"  {key}: {value:.4f}")
    else:
        print(f"  {key}: {value}")

# Final memory cleanup (CUDA reporting guarded the same way as the first one).
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    print(f"\nFinal GPU Memory - Allocated: {torch.cuda.memory_allocated()/1024**3:.2f} GB")

## 5. Training Curves Visualization

## 4.5 NOVEL: Enhanced Curriculum Learning Analysis

This section provides detailed analysis of our curriculum learning strategy,
showing how it affects training dynamics and final performance.

In [8]:
# ============================================
# ENHANCED CURRICULUM LEARNING ANALYSIS
# ============================================
import os

import matplotlib.pyplot as plt
import pandas as pd

from src.training.curriculum import AnatomicalCurriculumScheduler

print("=" * 80)
print("NOVEL: CURRICULUM LEARNING ANALYSIS")
print("=" * 80)

# Initialize curriculum scheduler
curriculum = AnatomicalCurriculumScheduler()

# Display curriculum stages
print("\n1. CURRICULUM STAGES")
print("-" * 60)
print(f"\n{'Stage':<20} {'Epochs':<15} {'Description':<40}")
print("-" * 80)

# Human-readable description per stage name; unknown names fall back to
# 'Full dataset' in the loop below.
stage_descriptions = {
    'normal_cases': 'Normal X-rays, simple findings (e.g., "lungs are clear")',
    'single_region': 'Single anatomical region findings (e.g., cardiomegaly)',
    'multi_region': 'Multiple regions, moderate complexity',
    'complex_cases': 'Complex cases with multiple severe findings',
}

for stage in curriculum.stages:
    name = stage['name']
    epoch_range = f"{stage['epoch_start']}-{stage['epoch_end']}"
    desc = stage_descriptions.get(name, 'Full dataset')
    print(f"{name:<20} {epoch_range:<15} {desc:<40}")

# Curriculum difficulty scoring
print("\n2. SAMPLE DIFFICULTY SCORING")
print("-" * 60)

sample_reports = [
    "Lungs are clear. Heart size is normal. No acute cardiopulmonary process.",
    "Mild cardiomegaly. Lungs are clear bilaterally.",
    "Bilateral pleural effusions. Cardiomegaly. Pulmonary edema.",
    "Large right pneumothorax. Left lung consolidation. Cardiomegaly. Bilateral effusions. ETT in place.",
]

print("\nSample Reports with Difficulty Scores:")
for i, report in enumerate(sample_reports):
    scores = curriculum.difficulty_scorer(report)
    # Displayed difficulty = number of findings + severity score.
    total_difficulty = scores.get('num_findings', 0) + scores.get('severity_score', 0)
    print(f"\n[Sample {i+1}] Difficulty: {total_difficulty:.1f}")
    print(f"   Report: {report[:70]}...")
    print(f"   Findings: {scores.get('num_findings', 0)}, Regions: {scores.get('num_regions', 0)}")

# Section header only: the numbers come from the post-training block below,
# which reads the saved training history when it exists.
print("\n3. CURRICULUM LEARNING IMPACT")
print("-" * 60)

# ============================================
# POST-TRAINING: Curriculum Learning Analysis
# Shows real results once training_history.csv has been written.
# ============================================

# Check if training history exists
history_path = '../data/statistics/training_history.csv'
if os.path.exists(history_path):
    print("\n" + "=" * 60)
    print("CURRICULUM LEARNING RESULTS (Real Data)")
    print("=" * 60)

    # Load training history
    df = pd.read_csv(history_path)

    # Take the stage boundaries from the scheduler instead of repeating them as
    # literals, so this report follows any change to the curriculum.
    stage_end_epochs = [stage['epoch_end'] for stage in curriculum.stages]
    stage_transitions = stage_end_epochs[:-1]  # Epochs where curriculum changes

    # (epoch, label) pairs to report: first epoch, each transition, last epoch.
    report_points = [(1, 'Stage 1 (Normal)')]
    report_points += [
        (epoch, f'Stage {k}‚Üí{k + 1}')
        for k, epoch in enumerate(stage_transitions, start=1)
    ]
    if stage_end_epochs:
        report_points.append((stage_end_epochs[-1], 'Final'))

    print("\nPerformance at Curriculum Stage Transitions:")
    print("-" * 60)

    for epoch, stage in report_points:
        # Rows are addressed by position, i.e. row k-1 is treated as epoch k.
        if epoch <= len(df):
            row = df.iloc[epoch-1]
            print(f"Epoch {epoch} ({stage}):")
            print(f"  BLEU-4: {row['bleu_4']:.4f}")
            print(f"  ROUGE-L: {row['rouge_l']:.4f}")
            print(f"  Loss: {row['val_loss']:.4f}")
            print()

    # Plot curriculum impact
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # BLEU-4 progression
    axes[0].plot(df['epoch'], df['bleu_4'], linewidth=2, color='blue')
    for trans in stage_transitions:
        if trans <= len(df):
            axes[0].axvline(x=trans, color='red', linestyle='--', alpha=0.5)
            axes[0].text(trans, axes[0].get_ylim()[1]*0.9, 'Stage\nChange',
                        ha='center', fontsize=8)
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('BLEU-4')
    axes[0].set_title('BLEU-4 Progression with Curriculum Stages')
    axes[0].grid(True, alpha=0.3)

    # Loss progression
    axes[1].plot(df['epoch'], df['val_loss'], linewidth=2, color='orange')
    for trans in stage_transitions:
        if trans <= len(df):
            axes[1].axvline(x=trans, color='red', linestyle='--', alpha=0.5)
    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('Validation Loss')
    axes[1].set_title('Loss Progression with Curriculum Stages')
    axes[1].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('../data/figures/curriculum_impact.png', dpi=300)
    plt.show()

    print("‚úÖ Curriculum learning analysis complete!")
    print("   Figure saved: ../data/figures/curriculum_impact.png")
else:
    print("\n‚ö†Ô∏è  Training history not found yet.")
    print("   Run this cell again after training completes.")


NOVEL: CURRICULUM LEARNING ANALYSIS

1. CURRICULUM STAGES
------------------------------------------------------------

Stage                Epochs          Description                             
--------------------------------------------------------------------------------
normal_cases         0-5             Normal X-rays, simple findings (e.g., "lungs are clear")
single_region        5-15            Single anatomical region findings (e.g., cardiomegaly)
multi_region         15-30           Multiple regions, moderate complexity   
complex_cases        30-50           Complex cases with multiple severe findings

2. SAMPLE DIFFICULTY SCORING
------------------------------------------------------------

Sample Reports with Difficulty Scores:

[Sample 1] Difficulty: 0.0
   Report: Lungs are clear. Heart size is normal. No acute cardiopulmonary proces...
   Findings: 0, Regions: 3

[Sample 2] Difficulty: 1.0
   Report: Mild cardiomegaly. Lungs are clear bilaterally....
   Findings: 1,

In [None]:
# ============================================
# TRAINING CURVES VISUALIZATION
# ============================================
# Works with a history produced by training in this session as well as one
# reloaded from the saved CSV.

import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Load history from CSV if not already in memory (i.e., training was skipped)
history_path = "../data/statistics/training_history.csv"


def _decorate_axis(ax, title, ylabel):
    """Apply the x-label, y-label, title and grid shared by the curve panels."""
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.grid(True, alpha=0.3)


# Initialize history if it doesn't exist
if 'history' not in dir() or not history:
    history = {}

if not history.get('train_loss'):
    if os.path.exists(history_path):
        print("üìÇ Loading training history from saved CSV...")
        history_df = pd.read_csv(history_path)
        history = history_df.to_dict(orient='list')
        # 'epoch' is a bookkeeping column added on save, not a metric.
        history.pop('epoch', None)
        print(f"   Loaded {len(history.get('train_loss', []))} epochs of history")
    else:
        print("‚ö†Ô∏è No training history found! Run training first (cell 11).")
        history = {}
else:
    # Save training history if it came from training. Wrapping each column in a
    # Series pads shorter columns with NaN instead of raising ValueError when
    # the metric lists differ in length.
    history_df = pd.DataFrame(
        {name: pd.Series(values) for name, values in history.items()}
    )
    history_df['epoch'] = range(1, len(history_df) + 1)
    os.makedirs('../data/statistics', exist_ok=True)
    history_df.to_csv(history_path, index=False)
    print("üíæ Training history saved to CSV")

# Check if we have history to plot
if history.get('train_loss'):
    # A 2x3 grid when any novel feature may add panels, otherwise 2x2.
    wants_extra_panels = (
        config.get('use_novel_losses', False)
        or config.get('use_clinical_validation', False)
    )
    if wants_extra_panels:
        fig, axes = plt.subplots(2, 3, figsize=(18, 10))
    else:
        fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    axes = axes.flatten()

    plot_idx = 0

    # Loss curves
    ax = axes[plot_idx]
    ax.plot(history['train_loss'], label='Train Loss', color='blue', linewidth=2)
    ax.plot(history['val_loss'], label='Val Loss', color='orange', linewidth=2)
    _decorate_axis(ax, 'Training and Validation Loss', 'Loss')
    ax.legend()
    plot_idx += 1

    # BLEU scores
    ax = axes[plot_idx]
    for order in (1, 2, 3, 4):
        ax.plot(history[f'bleu_{order}'], label=f'BLEU-{order}', linewidth=2)
    _decorate_axis(ax, 'BLEU Scores', 'Score')
    ax.legend()
    plot_idx += 1

    # ROUGE scores
    ax = axes[plot_idx]
    for key, label in (('rouge_1', 'ROUGE-1'), ('rouge_2', 'ROUGE-2'), ('rouge_l', 'ROUGE-L')):
        ax.plot(history[key], label=label, linewidth=2)
    _decorate_axis(ax, 'ROUGE Scores', 'Score')
    ax.legend()
    plot_idx += 1

    # NOVEL: Novel loss components. Only the components actually present in the
    # history are drawn, so a partially populated history cannot raise KeyError.
    novel_components = [
        ('anatomical_consistency_loss', 'Anatomical Consistency', 'purple'),
        ('clinical_entity_loss', 'Clinical Entity', 'red'),
        ('region_focal_loss', 'Region Focal', 'green'),
    ]
    available_components = [c for c in novel_components if c[0] in history]
    if config.get('use_novel_losses', False) and available_components:
        ax = axes[plot_idx]
        for key, label, color in available_components:
            ax.plot(history[key], label=label, linewidth=2, color=color)
        _decorate_axis(ax, 'Novel Loss Components (NOVEL)', 'Loss')
        ax.legend()
        plot_idx += 1

    # NOVEL: Clinical validation metrics. All three series are required because
    # the panel draws every one of them.
    clinical_keys = ('clinical_accuracy', 'clinical_f1', 'critical_errors')
    if config.get('use_clinical_validation', False) and all(k in history for k in clinical_keys):
        ax = axes[plot_idx]
        ax_twin = ax.twinx()
        ax.plot(history['clinical_accuracy'], label='Clinical Accuracy', linewidth=2, color='blue')
        ax.plot(history['clinical_f1'], label='Clinical F1', linewidth=2, color='orange')
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Score', color='black')
        ax.set_title('Clinical Validation Metrics (NOVEL)')
        ax.legend(loc='upper left')
        ax.grid(True, alpha=0.3)

        # Critical errors on secondary axis
        ax_twin.plot(history['critical_errors'], label='Critical Errors', linewidth=2, color='red', linestyle='--')
        ax_twin.set_ylabel('Critical Errors', color='red')
        ax_twin.legend(loc='upper right')
        ax_twin.tick_params(axis='y', labelcolor='red')
        plot_idx += 1

    # Learning rate
    if 'learning_rate' in history and len(history['learning_rate']) > 0:
        ax = axes[plot_idx]
        ax.plot(history['learning_rate'], color='green', linewidth=2)
        _decorate_axis(ax, 'Learning Rate Schedule', 'Learning Rate')
        ax.set_yscale('log')
        plot_idx += 1

    # Hide unused subplots
    for spare_ax in axes[plot_idx:]:
        spare_ax.axis('off')

    plt.tight_layout()
    os.makedirs('../data/figures', exist_ok=True)
    plt.savefig('../data/figures/training_curves_novel.png', dpi=300)
    plt.show()
    print("‚úÖ Training curves saved with NOVEL features visualization")
else:
    print("‚ö†Ô∏è No training history available to plot.")
    print("   Please run training first (cell 11) or ensure training_history.csv exists.")

## 6. Sample Predictions

In [None]:
# Print up to five generated reports next to their ground-truth references.
print("Sample Predictions vs Ground Truth:")
print("=" * 80)

# Fall back to empty lists when the names are undefined or hold nothing.
predictions = globals().get('predictions') or []
references = globals().get('references') or []


def _clip(text, limit=500):
    """Return `text` unchanged, or its first `limit` characters plus "..." if longer."""
    return text if len(text) <= limit else text[:limit] + "..."


if predictions and references:
    n_shown = min(5, len(predictions))
    for idx in range(n_shown):
        print(f"\n--- Sample {idx + 1} ---")
        print("\nGround Truth:")
        print(_clip(references[idx]))
        print("\nGenerated:")
        print(_clip(predictions[idx]))
        print("-" * 80)
else:
    print("\n‚ö†Ô∏è No predictions available yet!")
    print("   Predictions will be available after training completes (cell 11).")
    print("   Or run evaluation on test set in notebook 03_evaluation.ipynb.")

## 7. Final Results Summary

In [None]:
# Best results
import os
import pandas as pd
import numpy as np


def best_epoch_index(bleu_4, rouge_l):
    """Return the 0-based position with the highest BLEU-4 + ROUGE-L sum.

    Positions where either score is NaN are skipped; a plain ``argmax`` would
    return the first NaN position instead of the best score. Only the
    positions present in both sequences are compared.

    Args:
        bleu_4: Sequence of BLEU-4 scores.
        rouge_l: Sequence of ROUGE-L scores.

    Returns:
        The winning index as an ``int`` (the first one on ties), or ``None``
        when there is nothing to compare or every sum is NaN.
    """
    paired = min(len(bleu_4), len(rouge_l))
    combined = (
        np.asarray(bleu_4[:paired], dtype=float)
        + np.asarray(rouge_l[:paired], dtype=float)
    )
    if combined.size == 0 or np.isnan(combined).all():
        return None
    return int(np.nanargmax(combined))


# Initialize history if needed
if 'history' not in dir() or not history:
    history = {}

# Try to load from CSV if history is empty
history_path = "../data/statistics/training_history.csv"
if not history.get('train_loss'):
    if os.path.exists(history_path):
        history_df = pd.read_csv(history_path)
        history = history_df.to_dict(orient='list')
        # 'epoch' is a bookkeeping column added on save, not a metric.
        history.pop('epoch', None)

# Find best epoch (None when there is no usable BLEU-4 / ROUGE-L data).
best_epoch = best_epoch_index(
    history.get('bleu_4') or [],
    history.get('rouge_l') or [],
)

if best_epoch is not None:
    print("=" * 60)
    print("TRAINING RESULTS SUMMARY")
    print("=" * 60)
    print(f"\nBest Epoch: {best_epoch + 1}")
    print(f"\nBest Metrics:")

    # (history key, display label) in reporting order.
    reported_metrics = [
        ('bleu_1', 'BLEU-1'),
        ('bleu_2', 'BLEU-2'),
        ('bleu_3', 'BLEU-3'),
        ('bleu_4', 'BLEU-4'),
        ('rouge_1', 'ROUGE-1'),
        ('rouge_2', 'ROUGE-2'),
        ('rouge_l', 'ROUGE-L'),
    ]
    for key, label in reported_metrics:
        print(f"  {label}: {history[key][best_epoch]:.4f}")
    print(f"\nFinal Train Loss: {history['train_loss'][-1]:.4f}")
    print(f"Final Val Loss: {history['val_loss'][-1]:.4f}")

    # Save results table
    results_table = pd.DataFrame({
        'Metric': [label for _, label in reported_metrics],
        'Score': [history[key][best_epoch] for key, _ in reported_metrics],
    })
    os.makedirs('../data/statistics', exist_ok=True)
    results_table.to_csv('../data/statistics/best_results.csv', index=False)
    print("\n‚úÖ Results saved to ../data/statistics/best_results.csv")
else:
    print("=" * 60)
    print("TRAINING RESULTS SUMMARY")
    print("=" * 60)
    print("\n‚ö†Ô∏è No training results available yet!")
    print("   Run training first (cell 11) to see results.")
    print("   Or ensure training_history.csv exists in ../data/statistics/")
    print("   Or ensure training_history.csv exists in ../data/statistics/")