# XR2Text: Model Training with HAQT-ARR

## IMPROVED VERSION - Optimized for RTX 4060 8GB

**Authors**: S. Nikhil, Dadhania Omkumar  
**Supervisor**: Dr. Damodar Panigrahy

---

This notebook implements the complete training pipeline for XR2Text:

### Architecture (NOVEL CONTRIBUTIONS):
1. **HAQT-ARR** - Hierarchical Anatomical Query Tokens with Adaptive Region Routing
2. **Uncertainty Quantification** - MC Dropout + Temperature Calibration
3. **Factual Grounding** - Knowledge Graph + Hallucination Detection
4. **Multi-Task Learning** - Region/Severity/Finding Classification

### Training Configuration:
- **BioBART-Large** decoder (upgraded from base)
- **Gradient Accumulation**: 128 steps (~240 steps/epoch)
- **Curriculum Learning**: 5 stages over 50 epochs
- **Gradient Checkpointing** for RTX 4060 memory efficiency
- **Estimated Time**: ~65 hours (~2.7 days)

### Expected Results:
| Metric | Target | SOTA Reference |
|--------|--------|----------------|
| BLEU-4 | 0.12+ | 0.142 (ChestBioX-Gen) |
| ROUGE-L | 0.28+ | 0.312 (ChestBioX-Gen) |
| Clinical F1 | 0.70+ | Novel metric |

In [1]:
# ============================================
# GPU/CUDA Check - Run this first!
# ============================================
import os
import sys
sys.path.insert(0, '..')

import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
from tqdm.notebook import tqdm
import warnings
warnings.filterwarnings('ignore')

# Plot defaults shared by every figure produced in this notebook
plt.style.use('seaborn-v0_8-whitegrid')
plt.rcParams.update({
    'figure.figsize': (12, 8),
    'savefig.dpi': 300,
})

# Report the hardware/software stack and select the compute device
banner = "=" * 50
print(banner)
print("SYSTEM CONFIGURATION")
print(banner)

cuda_ready = torch.cuda.is_available()
device = torch.device("cuda" if cuda_ready else "cpu")

print(f"CUDA Available: {cuda_ready}")
if cuda_ready:
    gpu_name = torch.cuda.get_device_name(0)
    # total_memory is scaled by 1024**3 for the "GB" readout below
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
    print(f"GPU Connected: {gpu_name}")
    print(f"GPU Memory: {gpu_memory:.1f} GB")
    print(f"CUDA Version: {torch.version.cuda}")
else:
    print("WARNING: Running on CPU (Training will be slow)")
print(f"PyTorch Version: {torch.__version__}")

print(f"\nUsing Device: {device}")
print(banner)

  import pynvml  # type: ignore[import]


SYSTEM CONFIGURATION
CUDA Available: True
GPU Connected: NVIDIA GeForce RTX 4060 Laptop GPU
GPU Memory: 8.0 GB
CUDA Version: 12.1
PyTorch Version: 2.5.1+cu121

Using Device: cuda


## 1. Configuration

In [2]:
# Training configuration for the HAQT-ARR model with all enhancement modules.
# Sized for an RTX 4060 8GB; the notebook's own estimate is ~65 hours (2.7 days).
# This same dict is handed to XR2TextTrainer as `config=` further down, and
# selected keys are copied into `model_config` when the model is built.
config = {
    # Model
    'image_size': 384,
    'encoder_name': 'base',  # Swin-Base
    'decoder_name': 'biobart-large',  # BioBART-Large decoder (upgraded from base)
    'use_anatomical_attention': True,  # Enable HAQT-ARR

    # HAQT-ARR specific parameters
    # Total query count printed later is num_global_queries + num_regions * num_region_queries
    'num_regions': 7,
    'num_global_queries': 8,
    'num_region_queries': 4,
    'use_spatial_priors': True,
    'use_adaptive_routing': True,
    'use_cross_region': True,

    # Enhancement modules (forwarded to the model config)
    'use_uncertainty': True,           # Uncertainty quantification
    'use_grounding': True,             # Factual grounding & hallucination detection
    'use_explainability': True,        # Explainability & evidence regions
    'use_multitask': True,             # Multi-task learning heads

    # Standard parameters
    'language_dim': 1024,              # BioBART-Large uses 1024 hidden dim (per original note)

    # Training
    'epochs': 50,                      # 50 epochs
    'batch_size': 1,                   # Keep at 1 for memory
    'gradient_accumulation_steps': 128, # ~240 steps/epoch, ~65 hours total (2.7 days)
    'learning_rate': 1e-4,
    'weight_decay': 0.05,              # Matches default.yaml (per original note)
    'warmup_steps': 500,               # Warmup steps
    'max_grad_norm': 1.0,

    # Label smoothing - for better BLEU
    'label_smoothing': 0.05,

    # Auxiliary loss functions - enabled
    'use_novel_losses': True,
    'use_anatomical_consistency_loss': True,
    'use_clinical_entity_loss': True,
    'use_region_focal_loss': True,
    'use_cross_modal_loss': True,
    'anatomical_loss_weight': 0.1,
    'clinical_loss_weight': 0.2,
    'focal_loss_weight': 0.15,
    'alignment_loss_weight': 0.1,

    # R-Drop regularization - disabled for faster training
    'use_rdrop': False,
    'rdrop_alpha': 0.7,

    # Curriculum learning - enabled (5 stages over 50 epochs)
    'use_curriculum_learning': True,

    # Clinical validation - enabled
    'use_clinical_validation': True,

    # Uncertainty quantification
    'use_uncertainty_training': True,
    'uncertainty_dropout': 0.1,
    'mc_samples': 5,
    'use_calibration': True,

    # Multi-task learning
    'use_multi_task_learning': True,
    'auxiliary_task_weights': {
        'region_classification': 0.1,
        'severity_prediction': 0.1,
        'finding_detection': 0.15,
        'length_prediction': 0.05,
    },

    # Factual grounding
    'use_factual_grounding': True,
    'grounding_loss_weight': 0.1,
    'grounding_threshold': 0.15,

    # OOD detection
    'use_ood_detection': True,
    'ood_threshold': 0.5,

    # Scheduled sampling
    'use_scheduled_sampling': True,
    'scheduled_sampling_start': 1.0,
    'scheduled_sampling_end': 0.4,
    'scheduled_sampling_warmup': 10,

    # Region regularization
    'use_region_regularization': True,
    'region_regularization_weight': 0.01,

    # Data
    # NOTE(review): tokenization max_length is 256 while generation max_length below is 200 — confirm intended.
    'max_length': 256,
    'num_workers': 2,

    # Device
    'use_amp': True,
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',

    # Experiment
    # The timestamp is evaluated when this cell runs, so re-running the cell yields a new experiment name.
    'experiment_name': f'xr2text_haqt_arr_full_novel_{datetime.now().strftime("%Y%m%d_%H%M%S")}',
    'checkpoint_dir': '../checkpoints',
    'validate_every': 2,
    'save_every': 1,                   # Save every epoch
    'patience': 20,
    'log_dir': '../logs',

    # Validation - reduced for speed
    # NOTE(review): presumably the fraction of the validation split that is evaluated — confirm in the trainer.
    'val_fraction': 0.10,

    # Generation parameters - tuned for speed
    'generation': {
        'num_beams': 2,
        'min_length': 20,
        'max_length': 200,
        'length_penalty': 1.0,
        'repetition_penalty': 1.3,
        'no_repeat_ngram_size': 3,
        'early_stopping': True,
    }
}

# Create every output directory up front so later cells can write without checks
for out_dir in (config['checkpoint_dir'], config['log_dir'], '../data/figures'):
    os.makedirs(out_dir, exist_ok=True)

# Derive the effective batch size from the config instead of hard-coding it,
# so the banner stays correct when batch_size or accumulation steps change.
effective_batch = config['batch_size'] * config['gradient_accumulation_steps']

print("=" * 70)
print("XR2Text Training Config - RTX 4060 OPTIMIZED")
print("=" * 70)
print(f"\n  Epochs: {config['epochs']}")
print(f"  Gradient Accumulation: {config['gradient_accumulation_steps']} (effective batch={effective_batch})")
# NOTE: steps/epoch, the stage table and the time estimate below are static
# text; the dataloader is not built yet, so they cannot be computed here.
print(f"  Steps per Epoch: ~240")
print(f"  Warmup Steps: {config['warmup_steps']}")
print("\nCURRICULUM STAGES (5-stage, 50 epochs):")
print("  warmup:   0-5    (normal cases)")
print("  easy:     5-12   (≤2 findings)")
print("  medium:   12-25  (≤4 findings)")
print("  hard:     25-40  (all cases)")
print("  finetune: 40-50  (full dataset)")
print("\nESTIMATED TIME: ~65 hours (~2.7 days)")
print("=" * 70)

XR2Text Training Config - RTX 4060 OPTIMIZED

  Epochs: 50
  Gradient Accumulation: 128 (effective batch=128)
  Steps per Epoch: ~240
  Warmup Steps: 500

CURRICULUM STAGES (5-stage, 50 epochs):
  warmup:   0-5    (normal cases)
  easy:     5-12   (≤2 findings)
  medium:   12-25  (≤4 findings)
  hard:     25-40  (all cases)
  finetune: 40-50  (full dataset)

ESTIMATED TIME: ~65 hours (~2.7 days)


## 2. Load Model and Data

In [3]:
from src.models.xr2text import XR2TextModel, DEFAULT_CONFIG
from src.models.anatomical_attention import ANATOMICAL_REGIONS
from src.data.dataloader import get_dataloaders
from src.utils.device import setup_cuda_optimizations

# Apply the project's CUDA tuning before any model is built
setup_cuda_optimizations()

print("Creating XR2Text model with HAQT-ARR + Enhancement Modules...")

# Optional enhancement modules; each one defaults to enabled when missing from config
enhancement_flags = {
    flag: config.get(flag, True)
    for flag in ('use_uncertainty', 'use_grounding', 'use_explainability', 'use_multitask')
}

encoder_cfg = {
    'model_name': config['encoder_name'],
    'pretrained': True,
    'freeze_layers': 2,  # Freeze first 2 Swin layers
}

# HAQT-ARR projection settings are copied verbatim from the training config
projection_cfg = {
    key: config[key]
    for key in (
        'language_dim',
        'num_regions',
        'num_global_queries',
        'num_region_queries',
        'use_spatial_priors',
        'use_adaptive_routing',
        'use_cross_region',
    )
}
projection_cfg['feature_size'] = 12  # 384/32 = 12x12 patches

decoder_cfg = {
    'model_name': config['decoder_name'],
    'max_length': config['max_length'],
}

model_config = {
    'image_size': config['image_size'],
    'use_anatomical_attention': config['use_anatomical_attention'],  # Enable HAQT-ARR
    **enhancement_flags,
    'encoder': encoder_cfg,
    'projection': projection_cfg,
    'decoder': decoder_cfg,
}

model = XR2TextModel.from_config(model_config).to(config['device'])

# Parameter census: total vs. trainable (requires_grad) element counts
total_params = 0
trainable_params = 0
for param in model.parameters():
    count = param.numel()
    total_params += count
    if param.requires_grad:
        trainable_params += count

rule = '=' * 60
total_queries = config['num_global_queries'] + config['num_regions'] * config['num_region_queries']

print(f"\n{rule}")
print("XR2Text Model with HAQT-ARR + Enhancement Modules (10/10 Novelty)")
print(f"{rule}")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Frozen parameters: {total_params - trainable_params:,}")
print(f"\nAnatomical regions: {model.get_anatomical_regions()}")
print(f"Total queries: {total_queries}")
print("\nEnhancement Modules Enabled:")
for label, flag in (
    ("Uncertainty Quantification", 'use_uncertainty'),
    ("Factual Grounding", 'use_grounding'),
    ("Explainability", 'use_explainability'),
    ("Multi-Task Learning", 'use_multitask'),
):
    print(f"  - {label}: {enhancement_flags[flag]}")

[32m2026-01-15 20:19:46.456[0m | [1mINFO    [0m | [36msrc.utils.device[0m:[36msetup_cuda_optimizations[0m:[36m86[0m - [1mEnabled cuDNN benchmark mode[0m
[32m2026-01-15 20:19:46.457[0m | [1mINFO    [0m | [36msrc.utils.device[0m:[36msetup_cuda_optimizations[0m:[36m92[0m - [1mEnabled TF32 for matrix operations[0m
[32m2026-01-15 20:19:46.457[0m | [1mINFO    [0m | [36msrc.utils.device[0m:[36msetup_cuda_optimizations[0m:[36m96[0m - [1mCleared CUDA cache[0m
[32m2026-01-15 20:19:46.459[0m | [1mINFO    [0m | [36msrc.models.xr2text[0m:[36m__init__[0m:[36m109[0m - [1mBuilding Swin Transformer Encoder...[0m
[32m2026-01-15 20:19:46.459[0m | [1mINFO    [0m | [36msrc.models.swin_encoder[0m:[36m__init__[0m:[36m81[0m - [1mInitializing Swin Encoder: swin_base_patch4_window7_224[0m
[32m2026-01-15 20:19:46.459[0m | [1mINFO    [0m | [36msrc.models.swin_encoder[0m:[36m__init__[0m:[36m82[0m - [1mPretrained: True, Image Size: 384[0m


Creating XR2Text model with HAQT-ARR + Enhancement Modules...


[32m2026-01-15 20:19:48.085[0m | [1mINFO    [0m | [36msrc.models.swin_encoder[0m:[36m__init__[0m:[36m96[0m - [1mSwin feature dimension: 1024[0m
[32m2026-01-15 20:19:48.087[0m | [1mINFO    [0m | [36msrc.models.swin_encoder[0m:[36m_freeze_layers[0m:[36m139[0m - [1mFrozen 404,424 parameters in 2 layers[0m
[32m2026-01-15 20:19:48.088[0m | [1mINFO    [0m | [36msrc.models.swin_encoder[0m:[36m__init__[0m:[36m120[0m - [1mSwin Encoder initialized successfully[0m
[32m2026-01-15 20:19:48.088[0m | [1mINFO    [0m | [36msrc.models.xr2text[0m:[36m__init__[0m:[36m127[0m - [1mBuilding HAQT-ARR (Hierarchical Anatomical) Projection Layer...[0m
[32m2026-01-15 20:19:48.089[0m | [1mINFO    [0m | [36msrc.models.anatomical_attention[0m:[36m__init__[0m:[36m799[0m - [1mInitializing HAQT-ARR Projection Layer[0m
[32m2026-01-15 20:19:48.090[0m | [1mINFO    [0m | [36msrc.models.anatomical_attention[0m:[36m__init__[0m:[36m800[0m - [1m  Visual dim


XR2Text Model with HAQT-ARR + Enhancement Modules (10/10 Novelty)
Total parameters: 541,634,767
Trainable parameters: 541,230,343
Frozen parameters: 404,424

Anatomical regions: ['right_lung', 'left_lung', 'heart', 'mediastinum', 'spine', 'diaphragm', 'costophrenic_angles']
Total queries: 36

Enhancement Modules Enabled:
  - Uncertainty Quantification: True
  - Factual Grounding: True
  - Explainability: True
  - Multi-Task Learning: True


In [4]:
# Build the train/val/test loaders using the model's own tokenizer
print("\nLoading datasets...")
tokenizer = model.get_tokenizer()

loader_kwargs = {
    'tokenizer': tokenizer,
    'batch_size': config['batch_size'],
    'num_workers': config['num_workers'],
    'image_size': config['image_size'],
    'max_length': config['max_length'],
    'train_subset': None,  # Use full dataset, or set to e.g., 1000 for testing
}
train_loader, val_loader, test_loader = get_dataloaders(**loader_kwargs)

for split_label, loader in (("Train", train_loader), ("Val", val_loader), ("Test", test_loader)):
    print(f"{split_label} batches: {len(loader)}")

[32m2026-01-15 20:19:54.497[0m | [1mINFO    [0m | [36msrc.data.dataloader[0m:[36mget_dataloaders[0m:[36m47[0m - [1mCreating dataloaders...[0m
[32m2026-01-15 20:19:54.513[0m | [1mINFO    [0m | [36msrc.data.dataset[0m:[36m__init__[0m:[36m60[0m - [1mLoading MIMIC-CXR dataset (split: train)...[0m



Loading datasets...


[32m2026-01-15 20:19:59.274[0m | [1mINFO    [0m | [36msrc.data.dataset[0m:[36m__init__[0m:[36m80[0m - [1mLoaded 30633 samples[0m
[32m2026-01-15 20:19:59.275[0m | [1mINFO    [0m | [36msrc.data.dataset[0m:[36m__init__[0m:[36m60[0m - [1mLoading MIMIC-CXR dataset (split: validation)...[0m
[32m2026-01-15 20:20:00.716[0m | [1mINFO    [0m | [36msrc.data.dataset[0m:[36m__init__[0m:[36m80[0m - [1mLoaded 3063 samples[0m
[32m2026-01-15 20:20:00.716[0m | [1mINFO    [0m | [36msrc.data.dataset[0m:[36m__init__[0m:[36m60[0m - [1mLoading MIMIC-CXR dataset (split: test)...[0m
[32m2026-01-15 20:20:02.224[0m | [1mINFO    [0m | [36msrc.data.dataset[0m:[36m__init__[0m:[36m80[0m - [1mLoaded 3064 samples[0m
[32m2026-01-15 20:20:02.226[0m | [1mINFO    [0m | [36msrc.data.dataloader[0m:[36mget_dataloaders[0m:[36m117[0m - [1mTrain samples: 30633[0m
[32m2026-01-15 20:20:02.227[0m | [1mINFO    [0m | [36msrc.data.dataloader[0m:[36mget_dat

Train batches: 30633
Val batches: 3063
Test batches: 3064


## 3. Training Setup

In [5]:
from torch.optim import AdamW
from torch.cuda.amp import GradScaler, autocast
from src.training.scheduler import get_cosine_schedule_with_warmup
from src.utils.metrics import compute_metrics

# NOVEL: Import novel training components
from src.training.losses import CombinedNovelLoss
from src.training.curriculum import AnatomicalCurriculumScheduler, create_curriculum_dataloader
from src.utils.clinical_validator import ClinicalValidator

# Optimizer: partition trainable parameters by whether weight decay applies.
# A parameter is exempt when its name contains any of the `no_decay` markers.
no_decay = ['bias', 'LayerNorm.weight', 'layer_norm.weight']
decay_params = []
no_decay_params = []
for param_name, param in model.named_parameters():
    if not param.requires_grad:
        continue
    is_exempt = any(marker in param_name for marker in no_decay)
    (no_decay_params if is_exempt else decay_params).append(param)

optimizer_grouped_parameters = [
    {'params': decay_params, 'weight_decay': config['weight_decay']},
    {'params': no_decay_params, 'weight_decay': 0.0},
]

optimizer = AdamW(optimizer_grouped_parameters, lr=config['learning_rate'])

# Scheduler: one optimizer step per `gradient_accumulation_steps` batches
total_steps = (len(train_loader) * config['epochs']) // config['gradient_accumulation_steps']
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=config['warmup_steps'],
    num_training_steps=total_steps,
)

# Mixed precision scaler (only when AMP is requested)
if config['use_amp']:
    scaler = GradScaler()
else:
    scaler = None

# Auxiliary loss bundle
novel_loss = None
if config.get('use_novel_losses', False):
    novel_loss = CombinedNovelLoss(
        use_anatomical_consistency=config.get('use_anatomical_consistency_loss', True),
        use_clinical_entity=config.get('use_clinical_entity_loss', True),
        use_region_focal=config.get('use_region_focal_loss', True),
        use_cross_modal=config.get('use_cross_modal_loss', False),
        anatomical_weight=config.get('anatomical_loss_weight', 0.1),
        clinical_weight=config.get('clinical_loss_weight', 0.2),
        focal_weight=config.get('focal_loss_weight', 0.15),
        alignment_weight=config.get('alignment_loss_weight', 0.1),
    )
    print("✅ Novel loss functions initialized")

# Curriculum learning scheduler
curriculum_scheduler = None
if config.get('use_curriculum_learning', False):
    curriculum_scheduler = AnatomicalCurriculumScheduler()
    print("✅ Curriculum learning scheduler initialized")

# Clinical validator
clinical_validator = None
if config.get('use_clinical_validation', False):
    clinical_validator = ClinicalValidator()
    print("✅ Clinical validator initialized")

print(f"\nTotal optimization steps: {total_steps}")
print(f"Warmup steps: {config['warmup_steps']}")
for label, flag in (
    ("Novel losses", 'use_novel_losses'),
    ("Curriculum learning", 'use_curriculum_learning'),
    ("Clinical validation", 'use_clinical_validation'),
):
    print(f"{label}: {config.get(flag, False)}")

✅ Novel loss functions initialized
✅ Curriculum learning scheduler initialized
✅ Clinical validator initialized

Total optimization steps: 11966
Warmup steps: 500
Novel losses: True
Curriculum learning: True
Clinical validation: True


## 4. Training Loop

In [6]:
# Metric history: one independent, initially empty list per tracked series
history = {
    metric: []
    for metric in (
        'train_loss',
        'val_loss',
        'bleu_1',
        'bleu_2',
        'bleu_3',
        'bleu_4',
        'rouge_1',
        'rouge_2',
        'rouge_l',
        'learning_rate',
    )
}

# Early-stopping bookkeeping
# NOTE(review): config['patience'] is 20 while this local value is 5 — confirm which is intended.
best_metric, patience_counter, patience = 0.0, 0, 5

In [7]:
# Main training loop - Using XR2TextTrainer class
# AUTO-RESUME: Automatically detects and resumes from best checkpoint
from src.training.trainer import XR2TextTrainer
import torch
import gc
from pathlib import Path

# ============================================
# AUTO-RESUME FROM CHECKPOINT
# ============================================
# Directory searched for resumable checkpoints; taken from the training config.
checkpoint_dir = Path(config['checkpoint_dir'])

def find_best_checkpoint(checkpoint_dir):
    """Locate a checkpoint to resume from and the epoch to continue at.

    ``best_model.pt`` takes priority. If it is absent, the
    ``checkpoint_epoch_*.pt`` file with the highest numeric suffix is used.

    Args:
        checkpoint_dir: Directory (``str`` or ``Path``) to search.

    Returns:
        A ``(path, next_epoch)`` tuple. ``path`` is the chosen checkpoint as a
        string and ``next_epoch`` is the checkpoint's stored ``'epoch'`` value
        (0 when the key is missing) plus one. Returns ``(None, 0)`` when the
        directory does not exist or holds no matching checkpoint.
    """
    checkpoint_dir = Path(checkpoint_dir)
    if not checkpoint_dir.exists():
        return None, 0

    def epoch_from_name(path):
        # A non-numeric suffix (e.g. "checkpoint_epoch_final") ranks lowest.
        # Only ValueError is caught so that unrelated failures are not
        # hidden the way a bare ``except:`` would hide them.
        try:
            return int(path.stem.split('_')[-1])
        except ValueError:
            return 0

    # Priority: best_model.pt > latest checkpoint_epoch_*.pt
    # NOTE(review): best_model.pt can be older than the newest epoch
    # checkpoint, in which case resuming repeats epochs — confirm intended.
    chosen = checkpoint_dir / "best_model.pt"
    if not chosen.exists():
        epoch_checkpoints = list(checkpoint_dir.glob("checkpoint_epoch_*.pt"))
        if not epoch_checkpoints:
            return None, 0
        chosen = max(epoch_checkpoints, key=epoch_from_name)

    # NOTE: torch.load unpickles the file; only use on trusted checkpoint dirs.
    ckpt = torch.load(chosen, map_location='cpu')
    return str(chosen), ckpt.get('epoch', 0) + 1

# Look for an existing checkpoint before the trainer is built
checkpoint_path, resume_epoch = find_best_checkpoint(checkpoint_dir)

rule = "=" * 70
print(rule)
print("XR2Text Training with AUTO-RESUME")
print(rule)

# Drop Python garbage and ask PyTorch to release cached CUDA memory
print("\nClearing GPU memory...")
gc.collect()
torch.cuda.empty_cache()
if torch.cuda.is_available():
    gib = 1024**3
    print(f"GPU Memory - Allocated: {torch.cuda.memory_allocated()/gib:.2f} GB")
    print(f"GPU Memory - Cached: {torch.cuda.memory_reserved()/gib:.2f} GB")

# The trainer receives the full notebook config
trainer = XR2TextTrainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    config=config,
)

# AUTO-RESUME: restore trainer state when a checkpoint was found
if not checkpoint_path:
    print("\n>>> No checkpoint found. Starting fresh training from epoch 1")
else:
    print(f"\n>>> CHECKPOINT FOUND: {checkpoint_path}")
    print(f">>> Resuming from epoch {resume_epoch}")
    trainer.load_checkpoint(checkpoint_path)

generation_cfg = config.get('generation', {})
print(f"\n{rule}")
print("TRAINING CONFIGURATION SUMMARY:")
print(f"  Learning rate: {config['learning_rate']}")
print(f"  Label smoothing: {config.get('label_smoothing', 0.1)}")
print(f"  Validate every: {config.get('validate_every', 2)} epochs")
print(f"  Generation beams: {generation_cfg.get('num_beams', 5)}")
print(f"  Min generation length: {generation_cfg.get('min_length', 20)}")
print(f"{rule}\n")

# Run training
final_metrics = trainer.train()

# Extract history from trainer for visualization
tracker = trainer.metrics_tracker
history = {
    name: tracker.get_history(name)
    for name in (
        'train_loss',
        'val_loss',
        'bleu_1',
        'bleu_2',
        'bleu_3',
        'bleu_4',
        'rouge_1',
        'rouge_2',
        'rouge_l',
    )
}
# Only the final learning rate is available here; it is repeated once per epoch.
history['learning_rate'] = [trainer.scheduler.get_last_lr()[0]] * (trainer.current_epoch + 1)

# Add clinical validation metrics if enabled
if config.get('use_clinical_validation', False):
    for name in ('clinical_accuracy', 'clinical_f1', 'critical_errors'):
        history[name] = tracker.get_history(name)

# Save training history.
# The series can differ in length (validation runs only every
# `validate_every` epochs, and a resumed run may log fewer entries).
# pd.DataFrame(dict_of_lists) raises ValueError on unequal lengths, which
# would discard the history after a multi-day run, so pad with NaN instead.
series_lengths = {name: len(values) for name, values in history.items()}
if len(set(series_lengths.values())) <= 1:
    history_df = pd.DataFrame(history)
else:
    print(f"WARNING: history series differ in length and were NaN-padded: {series_lengths}")
    print("WARNING: rows are positional; shorter series are not aligned to the 'epoch' column")
    history_df = pd.DataFrame({name: pd.Series(values) for name, values in history.items()})
history_df['epoch'] = range(1, len(history_df) + 1)
os.makedirs('../data/statistics', exist_ok=True)
history_df.to_csv('../data/statistics/training_history.csv', index=False)

# Placeholders filled by later cells that display sample outputs
predictions, references = [], []

closing_rule = "=" * 70
print("\n" + closing_rule)
print("TRAINING COMPLETE!")
print(closing_rule)
print("\nFinal Metrics:")
for key, value in final_metrics.items():
    # Floats get four decimals; everything else is printed as-is
    shown = f"{value:.4f}" if isinstance(value, float) else f"{value}"
    print(f"  {key}: {shown}")

# Final memory cleanup
gc.collect()
torch.cuda.empty_cache()
allocated_gb = torch.cuda.memory_allocated() / 1024**3
print(f"\nFinal GPU Memory - Allocated: {allocated_gb:.2f} GB")

[32m2026-01-15 20:20:09.290[0m | [1mINFO    [0m | [36msrc.utils.device[0m:[36mget_device[0m:[36m27[0m - [1mUsing CUDA device: NVIDIA GeForce RTX 4060 Laptop GPU[0m
[32m2026-01-15 20:20:09.294[0m | [1mINFO    [0m | [36msrc.utils.device[0m:[36msetup_cuda_optimizations[0m:[36m86[0m - [1mEnabled cuDNN benchmark mode[0m
[32m2026-01-15 20:20:09.296[0m | [1mINFO    [0m | [36msrc.utils.device[0m:[36msetup_cuda_optimizations[0m:[36m92[0m - [1mEnabled TF32 for matrix operations[0m
[32m2026-01-15 20:20:09.297[0m | [1mINFO    [0m | [36msrc.utils.device[0m:[36msetup_cuda_optimizations[0m:[36m96[0m - [1mCleared CUDA cache[0m
[32m2026-01-15 20:20:09.319[0m | [1mINFO    [0m | [36msrc.training.trainer[0m:[36m__init__[0m:[36m146[0m - [1mNovel loss functions enabled[0m
[32m2026-01-15 20:20:09.320[0m | [1mINFO    [0m | [36msrc.training.trainer[0m:[36m__init__[0m:[36m154[0m - [1mCurriculum learning enabled[0m
[32m2026-01-15 20:20:09.3

XR2Text Training with AUTO-RESUME

Clearing GPU memory...
GPU Memory - Allocated: 2.03 GB
GPU Memory - Cached: 2.06 GB

>>> CHECKPOINT FOUND: ..\checkpoints\best_model.pt
>>> Resuming from epoch 4


[32m2026-01-15 20:20:13.877[0m | [1mINFO    [0m | [36msrc.training.trainer[0m:[36mload_checkpoint[0m:[36m814[0m - [1mLoaded checkpoint from ..\checkpoints\best_model.pt[0m
[32m2026-01-15 20:20:13.878[0m | [1mINFO    [0m | [36msrc.training.trainer[0m:[36mload_checkpoint[0m:[36m815[0m - [1mResuming from epoch 4[0m
[32m2026-01-15 20:20:13.882[0m | [1mINFO    [0m | [36msrc.training.trainer[0m:[36mtrain[0m:[36m380[0m - [1mStarting training...[0m
[32m2026-01-15 20:20:13.883[0m | [1mINFO    [0m | [36msrc.training.curriculum[0m:[36mprecompute_difficulty_scores[0m:[36m176[0m - [1mPre-computing difficulty scores for 30633 samples...[0m



TRAINING CONFIGURATION SUMMARY:
  Learning rate: 0.0001
  Label smoothing: 0.05
  Validate every: 2 epochs
  Generation beams: 2
  Min generation length: 20



[32m2026-01-15 20:23:20.885[0m | [1mINFO    [0m | [36msrc.training.curriculum[0m:[36mprecompute_difficulty_scores[0m:[36m194[0m - [1mPre-computed 30633 difficulty scores[0m
[32m2026-01-15 20:23:20.892[0m | [1mINFO    [0m | [36msrc.training.curriculum[0m:[36m__init__[0m:[36m328[0m - [1mCurriculum stage 'warmup': 3363/30633 samples[0m
[32m2026-01-15 20:23:20.895[0m | [1mINFO    [0m | [36msrc.training.trainer[0m:[36mtrain[0m:[36m407[0m - [1mCurriculum stage: warmup (3363/30633 samples)[0m
Epoch 5:  96%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋    | 26/27 [51:31<01:58, 118.89s/step, loss=8.6355]
[32m2026-01-15 21:14:52.101[0m | [1mINFO    [0m | [36msrc.training.trainer[0m:[36mtrain[0m:[36m435[0m - [1mEpoch 5/50 | Train: 7.5560 | Val: SKIPPED (validating every 2 epochs)[0m
[32m2026-01-15 21:14:59.153[0m | [1mINFO    [0m | [36msrc.utils.logger[0m:[36mlog_checkpoint[

RuntimeError: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [None]:
# ROBUST FIX - handles any array length mismatch
# Rebuilds training_history.csv from the trainer's metric tracker, keeping only
# the rows for which every validation metric is available.
import pandas as pd
import os

# Get all the metrics
val_loss = trainer.metrics_tracker.get_history('val_loss')
bleu_1 = trainer.metrics_tracker.get_history('bleu_1')
bleu_2 = trainer.metrics_tracker.get_history('bleu_2')
bleu_3 = trainer.metrics_tracker.get_history('bleu_3')
bleu_4 = trainer.metrics_tracker.get_history('bleu_4')
rouge_1 = trainer.metrics_tracker.get_history('rouge_1')
rouge_2 = trainer.metrics_tracker.get_history('rouge_2')
rouge_l = trainer.metrics_tracker.get_history('rouge_l')
train_loss = trainer.metrics_tracker.get_history('train_loss')

# Column order of the output frame follows this mapping
val_series = {
    'val_loss': val_loss,
    'bleu_1': bleu_1,
    'bleu_2': bleu_2,
    'bleu_3': bleu_3,
    'bleu_4': bleu_4,
    'rouge_1': rouge_1,
    'rouge_2': rouge_2,
    'rouge_l': rouge_l,
}

# Validation interval comes from the config rather than a hard-coded 2
validate_every = max(1, int(config.get('validate_every', 2)))

# Debug: Print lengths
print("Array lengths:")
for name, values in val_series.items():
    print(f"  {name}: {len(values)}")
print(f"  train_loss: {len(train_loss)}")

# Minimum over ALL validation series: every one of them is sliced into the
# frame, so a shorter bleu_1/2/3 or rouge_1/2 would otherwise raise.
min_len = min(len(values) for values in val_series.values())
print(f"\nUsing {min_len} epochs")

# Create DataFrame with matching lengths.
# Assumes the i-th validation entry belongs to epoch validate_every * (i + 1)
# — TODO confirm this still holds for runs resumed from a checkpoint.
history_df = pd.DataFrame({
    'epoch': [validate_every * (i + 1) for i in range(min_len)],
    **{name: values[:min_len] for name, values in val_series.items()},
})

# Add train_loss if available (take the entry of each validated epoch)
if train_loss:
    sampled_train = train_loss[validate_every - 1::validate_every][:min_len]
    if len(sampled_train) == min_len:
        history_df['train_loss'] = sampled_train

# Save
os.makedirs('../data/statistics', exist_ok=True)
history_df.to_csv('../data/statistics/training_history.csv', index=False)

print(f"\n✅ Saved {len(history_df)} epochs!")
print("\nLast 5 rows:")
print(history_df.tail())
print(f"\nBest BLEU-4: {history_df['bleu_4'].max():.4f}")
print(f"Best ROUGE-L: {history_df['rouge_l'].max():.4f}")


## 5. Training Curves Visualization

## 4.5 NOVEL: Enhanced Curriculum Learning Analysis

This section provides detailed analysis of our curriculum learning strategy,
showing how it affects training dynamics and final performance.

In [None]:
# ============================================
# NOVEL: ENHANCED CURRICULUM LEARNING ANALYSIS (5 STAGES, 50 EPOCHS)
# ============================================
from src.training.curriculum import AnatomicalCurriculumScheduler
import os
import pandas as pd
import matplotlib.pyplot as plt

wide_rule = "=" * 80
print(wide_rule)
print("NOVEL: CURRICULUM LEARNING ANALYSIS (5 STAGES, 50 EPOCHS)")
print(wide_rule)

# Fresh scheduler instance, used here only to read its stage definitions
curriculum = AnatomicalCurriculumScheduler()

# Table header for the stage listing
print("\n1. CURRICULUM STAGES (5-Stage Progressive Training)")
print("-" * 60)
print(f"\n{'Stage':<20} {'Epochs':<15} {'Description':<40}")
print("-" * 80)

# Human-readable blurb per stage name; unknown names fall back to 'Full dataset'
stage_descriptions = {
    'warmup': 'Warmup with easy cases only',
    'easy': 'Normal X-rays, simple findings',
    'medium': 'Single anatomical region findings',
    'hard': 'Multiple regions, moderate complexity',
    'finetune': 'Complex cases, full dataset fine-tuning',
}

for stage in curriculum.stages:
    stage_label = stage['name']
    span = "{}-{}".format(stage['epoch_start'], stage['epoch_end'])
    blurb = stage_descriptions.get(stage_label, 'Full dataset')
    print(f"{stage_label:<20} {span:<15} {blurb:<40}")

# Demonstrate the scheduler's difficulty scorer on a few hand-written reports
print("\n2. SAMPLE DIFFICULTY SCORING")
print("-" * 60)

sample_reports = [
    "Lungs are clear. Heart size is normal. No acute cardiopulmonary process.",
    "Mild cardiomegaly. Lungs are clear bilaterally.",
    "Bilateral pleural effusions. Cardiomegaly. Pulmonary edema.",
    "Large right pneumothorax. Left lung consolidation. Cardiomegaly.",
]

print("\nSample Reports with Difficulty Scores:")
for sample_no, report in enumerate(sample_reports, start=1):
    scores = curriculum.difficulty_scorer(report)
    finding_count = scores.get('num_findings', 0)
    region_count = scores.get('num_regions', 0)
    # Displayed difficulty = finding count + severity score (missing keys count as 0)
    total_difficulty = finding_count + scores.get('severity_score', 0)
    print(f"\n[Sample {sample_no}] Difficulty: {total_difficulty:.1f}")
    print(f"   Report: {report[:60]}...")
    print(f"   Findings: {finding_count}, Regions: {region_count}")

# Summarise and plot validation metrics at the curriculum stage boundaries
print("\n3. CURRICULUM LEARNING IMPACT")
print("-" * 60)

history_path = '../data/statistics/training_history.csv'
if not os.path.exists(history_path):
    print("\nTraining history not found yet.")
    print("Run this cell again after training completes.")
else:
    section_rule = "=" * 60
    print("\n" + section_rule)
    print("CURRICULUM LEARNING RESULTS (Real Data)")
    print(section_rule)

    df = pd.read_csv(history_path)

    print("\nPerformance at Curriculum Stage Transitions:")
    print("-" * 60)

    # 5-stage curriculum for 50 epochs: warmup(0-5), easy(5-12), medium(12-25), hard(25-40), finetune(40-50)
    stage_info = [
        (5, 'End of Stage 1 (Warmup)'),
        (12, 'End of Stage 2 (Easy Cases)'),
        (25, 'End of Stage 3 (Medium Cases)'),
        (40, 'End of Stage 4 (Hard Cases)'),
        (50, 'End of Stage 5 (Fine-tuning)'),
    ]

    for target_epoch, stage_name in stage_info:
        # Last recorded row at or before the stage boundary, if any
        reached = df[df['epoch'] <= target_epoch]
        if reached.empty:
            continue
        row = reached.iloc[-1]
        print(f"\nEpoch {int(row['epoch'])} - {stage_name}:")
        print(f"  BLEU-4:  {row['bleu_4']:.4f}")
        print(f"  ROUGE-L: {row['rouge_l']:.4f}")
        print(f"  Val Loss: {row['val_loss']:.4f}")

    # Two side-by-side panels sharing the same stage-transition markers
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Stage transitions at epochs 5, 12, 25, 40
    stage_transitions = [5, 12, 25, 40]

    panels = (
        (axes[0], 'bleu_4', 'blue', 'BLEU-4', 'BLEU-4 Progression with 5-Stage Curriculum'),
        (axes[1], 'val_loss', 'orange', 'Validation Loss', 'Loss Progression with 5-Stage Curriculum'),
    )
    for ax, column, colour, y_label, title in panels:
        ax.plot(df['epoch'], df[column], linewidth=2, color=colour, marker='o', markersize=3)
        for trans in stage_transitions:
            ax.axvline(x=trans, color='red', linestyle='--', alpha=0.7)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(y_label)
        ax.set_title(title)
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    os.makedirs('../data/figures', exist_ok=True)
    plt.savefig('../data/figures/curriculum_impact.png', dpi=300, bbox_inches='tight')
    plt.show()

    print("\n" + section_rule)
    print("Curriculum learning analysis complete!")
    print("Figure saved: ../data/figures/curriculum_impact.png")
    print(section_rule)

In [None]:
# ============================================
# FIXED: TRAINING CURVES VISUALIZATION
# ============================================
# Reads the per-epoch training history CSV, draws a 2x2 grid of curves
# (validation loss, BLEU-1..4, ROUGE-1/2/L, BLEU-4 vs ROUGE-L), saves the
# figure to ../data/figures/training_curves.png and prints a short summary.
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Load history from CSV
history_path = "../data/statistics/training_history.csv"

# Every column the plots and the summary below read from the CSV. All of them
# are validated up front so a partial CSV is reported instead of raising a
# KeyError halfway through plotting.
required_columns = [
    'epoch', 'val_loss',
    'bleu_1', 'bleu_2', 'bleu_3', 'bleu_4',
    'rouge_1', 'rouge_2', 'rouge_l',
]

# One entry per subplot, in grid order:
#   (title, y-axis label, [(CSV column, legend label, extra plot kwargs), ...])
# Every curve is drawn with linewidth=2; the extra kwargs add colour/markers.
curve_panels = [
    ('Validation Loss', 'Loss', [
        ('val_loss', 'Val Loss', dict(color='orange', marker='o', markersize=4)),
    ]),
    ('BLEU Scores', 'Score', [
        ('bleu_1', 'BLEU-1', {}),
        ('bleu_2', 'BLEU-2', {}),
        ('bleu_3', 'BLEU-3', {}),
        ('bleu_4', 'BLEU-4', dict(marker='o', markersize=4)),
    ]),
    ('ROUGE Scores', 'Score', [
        ('rouge_1', 'ROUGE-1', {}),
        ('rouge_2', 'ROUGE-2', {}),
        ('rouge_l', 'ROUGE-L', dict(marker='o', markersize=4)),
    ]),
    ('BLEU-4 vs ROUGE-L Comparison', 'Score', [
        ('bleu_4', 'BLEU-4', dict(color='blue', marker='o', markersize=4)),
        ('rouge_l', 'ROUGE-L', dict(color='green', marker='s', markersize=4)),
    ]),
]

if os.path.exists(history_path):
    print("Loading training history from CSV...")
    df = pd.read_csv(history_path)
    print(f"Loaded {len(df)} epochs of data")

    missing_columns = [name for name in required_columns if name not in df.columns]

    # Check if we have data (at least one row and every column we plot)
    if len(df) > 0 and not missing_columns:
        # Create figure with subplots
        fig, axes = plt.subplots(2, 2, figsize=(14, 10))
        axes = axes.flatten()

        for panel_ax, (panel_title, panel_ylabel, panel_curves) in zip(axes, curve_panels):
            for curve_column, curve_label, curve_style in panel_curves:
                panel_ax.plot(df['epoch'], df[curve_column], label=curve_label,
                              linewidth=2, **curve_style)
            panel_ax.set_xlabel('Epoch')
            panel_ax.set_ylabel(panel_ylabel)
            panel_ax.set_title(panel_title)
            panel_ax.legend()
            panel_ax.grid(True, alpha=0.3)

        plt.tight_layout()
        os.makedirs('../data/figures', exist_ok=True)
        plt.savefig('../data/figures/training_curves.png', dpi=300, bbox_inches='tight')
        plt.show()

        print("\nTraining curves saved to ../data/figures/training_curves.png")

        # Print summary
        print("\n" + "=" * 50)
        print("TRAINING SUMMARY")
        print("=" * 50)
        print(f"Best BLEU-4:  {df['bleu_4'].max():.4f} (Epoch {df.loc[df['bleu_4'].idxmax(), 'epoch']:.0f})")
        print(f"Best ROUGE-L: {df['rouge_l'].max():.4f} (Epoch {df.loc[df['rouge_l'].idxmax(), 'epoch']:.0f})")
        print(f"Final Val Loss: {df['val_loss'].iloc[-1]:.4f}")
    else:
        print("No valid data in CSV file")
        if missing_columns:
            # Tell the user exactly which columns the plots need but the CSV lacks.
            print("Missing columns:", ", ".join(missing_columns))
else:
    print("Training history CSV not found!")
    print("Expected at:", history_path)


## 6. Sample Predictions

In [None]:
# Show sample predictions vs ground truth
def preview_text(text, limit=500):
    """Return ``text`` unchanged, or its first ``limit`` characters plus "..." if longer."""
    return text[:limit] + "..." if len(text) > limit else text


def show_sample_pairs(predictions, references, max_samples=5):
    """Print up to ``max_samples`` reference/prediction pairs side by side.

    The two sequences are walked in lockstep, so a length mismatch shows only
    the samples that have both a prediction and a reference instead of
    raising IndexError.

    Returns:
        The number of samples printed.
    """
    shown = 0
    for number, generated, truth in zip(range(1, max_samples + 1), predictions, references):
        print(f"\n--- Sample {number} ---")
        print(f"\nGround Truth:")
        print(preview_text(truth))
        print(f"\nGenerated:")
        print(preview_text(generated))
        print("-" * 80)
        shown = number
    return shown


print("Sample Predictions vs Ground Truth:")
print("=" * 80)

# Check if predictions and references exist; undefined or empty values both
# normalise to an empty list so the cell can run before training.
predictions = globals().get('predictions') or []
references = globals().get('references') or []

if predictions and references:
    show_sample_pairs(predictions, references)
else:
    print("\n⚠️ No predictions available yet!")
    print("   Predictions will be available after training completes (cell 11).")
    print("   Or run evaluation on test set in notebook 03_evaluation.ipynb.")

## 7. Final Results Summary

In [None]:
# ============================================
# FIXED: FINAL RESULTS SUMMARY
# ============================================
# Summarises training_history.csv: picks the best epoch by BLEU-4 + ROUGE-L,
# prints best/final metrics, and writes a best-vs-final table to
# ../data/statistics/best_results.csv.
import os
import pandas as pd
import numpy as np

history_path = "../data/statistics/training_history.csv"

# (display label, CSV column) for each metric in the results table, in table order.
SUMMARY_METRICS = [
    ('BLEU-1', 'bleu_1'),
    ('BLEU-2', 'bleu_2'),
    ('BLEU-3', 'bleu_3'),
    ('BLEU-4', 'bleu_4'),
    ('ROUGE-1', 'rouge_1'),
    ('ROUGE-2', 'rouge_2'),
    ('ROUGE-L', 'rouge_l'),
]


def history_problem(history):
    """Explain why ``history`` cannot be summarised, or return None if it can.

    Checks, in order: that every column read by this cell is present, that at
    least one epoch row exists, and that BLEU-4 + ROUGE-L is not NaN for every
    row (otherwise no best epoch can be selected).
    """
    needed = ['epoch', 'val_loss'] + [column for _, column in SUMMARY_METRICS]
    absent = [column for column in needed if column not in history.columns]
    if absent:
        return "missing columns: " + ", ".join(absent)
    if history.empty:
        return "no epochs recorded"
    if (history['bleu_4'] + history['rouge_l']).isna().all():
        return "BLEU-4 / ROUGE-L contain no numeric values"
    return None


def build_results_table(best_row, final_row):
    """Tabulate best-epoch and final-epoch scores, one row per summary metric.

    ``best_row`` and ``final_row`` only need to support ``row[column]`` lookup
    for the columns listed in SUMMARY_METRICS.
    """
    return pd.DataFrame({
        'Metric': [label for label, _ in SUMMARY_METRICS],
        'Best Score': [best_row[column] for _, column in SUMMARY_METRICS],
        'Final Score': [final_row[column] for _, column in SUMMARY_METRICS],
    })


history_found = os.path.exists(history_path)
if history_found:
    df = pd.read_csv(history_path)
summary_problem = history_problem(df) if history_found else None

print("=" * 60)
print("TRAINING RESULTS SUMMARY")
print("=" * 60)

if not history_found:
    print("\nNo training results available yet!")
    print("Run training first (cell 11) to see results.")
elif summary_problem is not None:
    # A header-only or partial CSV would otherwise crash in idxmax()/iloc[-1].
    print(f"\nTraining history at {history_path} cannot be summarised ({summary_problem}).")
    print("Re-run training (cell 11) to regenerate it.")
else:
    # Find best epoch by combined BLEU-4 + ROUGE-L score
    df['combined_score'] = df['bleu_4'] + df['rouge_l']
    best_idx = df['combined_score'].idxmax()
    best_row = df.loc[best_idx]
    final_row = df.iloc[-1]

    # FIXED: Use actual epoch value from the dataframe
    print(f"\nBest Epoch: {int(best_row['epoch'])} (by BLEU-4 + ROUGE-L)")

    print(f"\nBest Metrics (Epoch {int(best_row['epoch'])}):")
    for metric_label, metric_column in SUMMARY_METRICS:
        # Pad "LABEL:" to 8 characters so the scores line up in one column.
        print(f"  {metric_label + ':':<8} {best_row[metric_column]:.4f}")

    print(f"\nFinal Metrics (Epoch {int(final_row['epoch'])}):")
    print(f"  Val Loss: {final_row['val_loss']:.4f}")
    print(f"  BLEU-4:   {final_row['bleu_4']:.4f}")
    print(f"  ROUGE-L:  {final_row['rouge_l']:.4f}")

    # Save best results to CSV
    results_table = build_results_table(best_row, final_row)

    os.makedirs('../data/statistics', exist_ok=True)
    results_table.to_csv('../data/statistics/best_results.csv', index=False)

    print("\n" + "=" * 60)
    print("Results saved to ../data/statistics/best_results.csv")
    print("=" * 60)

    # Display table
    print("\nResults Table:")
    print(results_table.to_string(index=False))


In [None]:
## 8. NOVEL: Enhanced Analysis with New Features

## This section demonstrates the new enhancement modules for comprehensive report analysis.

In [None]:
# ============================================
# NOVEL: Enhanced Analysis Demo
# ============================================
# Lists the analysis features, then (when the model exposes
# `generate_with_analysis`) runs it on the first image of one test batch and
# prints a digest of whichever fields the returned analysis contains.

demo_rule = "=" * 70
print(demo_rule)
print("NOVEL ENHANCEMENT MODULES DEMO")
print(demo_rule)

if not hasattr(model, 'generate_with_analysis'):
    # The model has no generate_with_analysis method: point at the config flags.
    for notice_line in (
        "\n⚠️  Enhancement modules not loaded in current model.",
        "   Ensure use_uncertainty, use_grounding, use_explainability, use_multitask are True.",
        "   Re-initialize model with updated config to enable these features.",
    ):
        print(notice_line)
else:
    # Static overview of what the enhanced analysis can report.
    for feature_line in (
        "\n✅ Model has enhanced analysis capabilities!",
        "\nAvailable analysis features:",
        "  1. Uncertainty Quantification",
        "     - Overall confidence score (0-1)",
        "     - Per-finding confidence scores",
        "     - Calibrated uncertainty estimates",
        "\n  2. Factual Grounding",
        "     - Detected medical findings",
        "     - Potential hallucinations flagged",
        "     - Knowledge graph validation",
        "\n  3. Explainability",
        "     - Evidence regions highlighted",
        "     - Clinical reasoning chains",
        "     - Attention visualizations",
        "\n  4. Multi-Task Outputs",
        "     - Region classification",
        "     - Severity prediction",
        "     - Finding detection",
    ):
        print(feature_line)

    # Demo analysis on a sample if test data is available
    section_rule = "-" * 50
    print("\n" + section_rule)
    print("Running Enhanced Analysis on Sample...")
    print(section_rule)

    try:
        # Take one batch from the test loader and keep only its first image
        # (the 0:1 slice preserves the leading batch dimension).
        sample_batch = next(iter(test_loader))
        sample_image = sample_batch['images'][0:1].to(config['device'])

        # Inference only, so gradient tracking is switched off.
        with torch.no_grad():
            generation_cfg = config['generation']
            analysis = model.generate_with_analysis(
                sample_image,
                max_length=generation_cfg['max_length'],
                num_beams=generation_cfg['num_beams'],
            )

        print(f"\n📝 Generated Report:")
        print(f"   {analysis.get('report', 'N/A')[:200]}...")

        print(f"\n📊 Uncertainty Analysis:")
        print(f"   Overall Confidence: {analysis.get('confidence', 0):.2%}")
        if 'finding_confidences' in analysis:
            print(f"   Finding Confidences: {len(analysis['finding_confidences'])} findings analyzed")

        print(f"\n🔍 Factual Grounding:")
        if 'detected_findings' in analysis:
            print(f"   Detected Findings: {analysis['detected_findings'][:5]}")
        if 'potential_hallucinations' in analysis:
            print(f"   Potential Hallucinations: {len(analysis.get('potential_hallucinations', []))}")

        print(f"\n💡 Explainability:")
        if 'evidence_regions' in analysis:
            print(f"   Evidence Regions: {len(analysis['evidence_regions'])} regions identified")
        if 'reasoning' in analysis:
            print(f"   Clinical Reasoning: Available")

    except Exception as err:
        # Deliberately best-effort: any failure in the demo is reported, not raised.
        print(f"   Demo skipped (requires trained model): {err}")

print("\n" + demo_rule)
print("Enhanced Analysis Demo Complete")
print(demo_rule)