# XR2Text: Model Training with HAQT-ARR

## CLOUD GPU VERSION - Optimized for RunPod A100 PCIe 80GB VRAM

**Authors**: S. Nikhil, Dadhania Omkumar  
**Supervisor**: Dr. Damodar Panigrahy

---

This notebook implements the complete training pipeline for XR2Text:

### Architecture (NOVEL CONTRIBUTIONS):
1. **HAQT-ARR** - Hierarchical Anatomical Query Tokens with Adaptive Region Routing
2. **Uncertainty Quantification** - MC Dropout + Temperature Calibration
3. **Factual Grounding** - Knowledge Graph + Hallucination Detection
4. **Multi-Task Learning** - Region/Severity/Finding Classification

### Training Configuration (A100 80GB Optimized):
- **FULL DATASET**: 30,633 images (21,443 train / 3,063 val / 6,127 test)
- **BioBART-Large** decoder (406M params)
- **Image Size**: 512x512 (high resolution for detail)
- **Batch Size**: 48 (about 50-55 GB of VRAM, leaving roughly 25 GB of headroom)
- **Gradient Accumulation**: 1 step (effective batch = 48)
- **R-Drop Regularization**: disabled (it roughly doubles VRAM use and causes OOM at batch size 48)
- **All Encoder Layers Unfrozen** - Full fine-tuning
- **Curriculum Learning**: 5 stages over 50 epochs

### Expected Results (Full Dataset):
| Metric | Target | Published SOTA |
|--------|--------|----------------|
| BLEU-4 | 0.15+ | 0.128 (ORGAN, ACL 2023) |
| ROUGE-L | 0.35+ | 0.293 (ORGAN, ACL 2023) |
| Clinical F1 | 0.80-0.85 | Novel metric |

**Note**: The A100 delivers roughly 3x the FP16 throughput of typical consumer GPUs, so training should be correspondingly faster.
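
As a rough check on the schedule length, the number of optimizer steps follows from the training-set size, batch size, and gradient accumulation. The sketch below uses the values from the configuration cell in Section 1 (batch size 48, no accumulation) and assumes the last partial batch is kept (`drop_last=False`). `optimizer_steps` is a hypothetical helper for illustration, not part of the project code.

```python
import math

def optimizer_steps(num_samples, batch_size, accumulation_steps, epochs):
    """Return (batches_per_epoch, total_optimizer_steps).

    Assumes the final partial batch of each epoch is kept.
    """
    batches_per_epoch = math.ceil(num_samples / batch_size)
    # Same formula the training-setup cell uses for total_steps.
    total_steps = batches_per_epoch * epochs // accumulation_steps
    return batches_per_epoch, total_steps

batches, steps = optimizer_steps(21_443, 48, 1, 50)
print(batches, steps)  # 447 22350
```

Under these assumptions, the 1,500 warmup steps cover about 7% of the run.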

In [None]:
# ==============================================
# RUNPOD SETUP - Run this cell FIRST!
# ==============================================
import os
import sys
import subprocess

print("=" * 60)
print("RUNPOD AUTO-SETUP (No SSH Required!)")
print("=" * 60)

# 1. Fix Python path
sys.path.insert(0, '..')

# 2. Create directories with proper permissions
print("")
print("[1/4] Creating directories...")
dirs_to_fix = [
    '../checkpoints', 
    '../logs', 
    '../data', 
    '../data/figures', 
    '../data/statistics',
    '../data/human_evaluation',
    '../data/ablation_results',
]

for d in dirs_to_fix:
    os.makedirs(d, exist_ok=True)
    try:
        os.chmod(d, 0o777)
    except OSError:
        # Not fatal: some mounted volumes do not allow permission changes
        pass
print("   Directories created!")

# 3. Install missing packages (if any)
print("")
print("[2/4] Checking packages...")
required = ['timm', 'albumentations', 'loguru', 'rouge_score', 'bert_score']
for pkg in required:
    try:
        __import__(pkg.replace('-', '_'))
    except ImportError:
        print(f"   Installing {pkg}...")
        subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', pkg])
print("   Packages OK!")

# 4. Download NLTK data
print("")
print("[3/4] NLTK data...")
try:
    import nltk
    nltk.download('punkt', quiet=True)
    nltk.download('wordnet', quiet=True)
    nltk.download('omw-1.4', quiet=True)
    print("   NLTK data ready!")
except Exception as e:
    # Report the reason instead of silently swallowing it
    print(f"   NLTK download skipped ({e})")

# 5. GPU Check
print("")
print("[4/4] GPU Check...")
import torch
if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"   GPU: {gpu_name}")
    print(f"   VRAM: {gpu_mem:.1f} GB")
    if "A100" in gpu_name:
        print("   >>> A100 detected: the settings below are tuned for this GPU.")
    else:
        # Do not guess the model from VRAM size alone; report it as-is
        print("   >>> Non-A100 GPU: batch size and worker counts below may need adjusting.")
else:
    print("   WARNING: No GPU detected!")

print("")
print("=" * 60)
print("SETUP COMPLETE! Continue running cells below.")
print("=" * 60)

In [None]:
# ============================================
# VERIFY TRAINER CHECKPOINT CONFIGURATION
# Run this cell AFTER RunPod Setup, BEFORE Training
# Saves ONLY best_model.pt at end of training
# ============================================

trainer_path = '../src/training/trainer.py'

with open(trainer_path, 'r') as f:
    content = f.read()

# ============================================
# VERIFY: Check if all patches are correctly applied
# ============================================
checks_passed = 0
total_checks = 3

# CHECK 1: best_model_state initialization
if "self.best_model_state = None" in content:
    print("✓ Check 1 PASSED: best_model_state initialization present")
    checks_passed += 1
else:
    print("✗ Check 1 FAILED: best_model_state initialization missing")

# CHECK 2: Store best model in memory (not saving to disk during training)
if "Store best model state in memory (NO DISK SAVE during training)" in content:
    print("✓ Check 2 PASSED: Best model stored in memory (no disk save during training)")
    checks_passed += 1
else:
    print("✗ Check 2 FAILED: Best model memory storage not configured")

# CHECK 3: Save best model only at the very end
if "Training complete - NOW save the best model to disk" in content:
    print("✓ Check 3 PASSED: Best model saved only at training end")
    checks_passed += 1
else:
    print("✗ Check 3 FAILED: End-of-training save not configured")

print("")
print("=" * 60)
if checks_passed == total_checks:
    print(f"✅ TRAINER CORRECTLY CONFIGURED! ({checks_passed}/{total_checks} checks passed)")
else:
    print(f"⚠️ TRAINER MAY NEED UPDATES ({checks_passed}/{total_checks} checks passed)")
    print("   Please check trainer.py manually or re-download from repo")
print("=" * 60)
print("")
print("CHECKPOINT STRATEGY:")
print("  - Epoch 1-49:  NO checkpoints saved (training in memory)")
print("  - Epoch 50:    best_model.pt (final best by BLEU-4 + ROUGE-L)")
print("")
print("Best model selection: highest BLEU-4 + ROUGE-L combined score")
print("=" * 60)
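
The checkpoint strategy printed above (keep the best weights in memory, write them to disk once at the end) can be sketched as follows. This is a minimal illustration that assumes the selection score is the plain sum of BLEU-4 and ROUGE-L. `BestStateKeeper` and the dictionary standing in for model weights are hypothetical; the real logic lives in `trainer.py`, which is not shown here.

```python
import copy

class BestStateKeeper:
    """Keep an in-memory snapshot of the best-scoring state (hypothetical helper)."""

    def __init__(self):
        self.best_score = float("-inf")
        self.best_state = None
        self.best_epoch = None

    def update(self, epoch, metrics, state):
        # Assumed selection rule: BLEU-4 + ROUGE-L, higher is better.
        score = metrics["bleu_4"] + metrics["rouge_l"]
        if score > self.best_score:
            self.best_score = score
            # Deep copy so later training updates do not mutate the snapshot.
            self.best_state = copy.deepcopy(state)
            self.best_epoch = epoch
            return True
        return False

keeper = BestStateKeeper()
weights = {"w": [0.0]}  # stand-in for model weights
results = [(0.10, 0.30), (0.14, 0.33), (0.13, 0.32)]
for epoch, (bleu, rouge) in enumerate(results, start=1):
    weights["w"][0] = float(epoch)  # stand-in for a training update
    keeper.update(epoch, {"bleu_4": bleu, "rouge_l": rouge}, weights)

print(keeper.best_epoch, keeper.best_state)  # 2 {'w': [2.0]}
```

With PyTorch, the snapshot would typically be a CPU copy of `model.state_dict()`, written once with `torch.save` after the last epoch. The trade-off is that an interrupted run leaves nothing on disk to resume from.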

## 1. Configuration

In [None]:
# =============================================================================
# Training Configuration with HAQT-ARR + ALL NOVEL FEATURES
# OPTIMIZED FOR RUNPOD A100 PCIe 80GB VRAM - SAFE + FAST
# Target: BLEU-4 > 0.15, ROUGE-L > 0.35 (competitive with SOTA)
# =============================================================================
# OPTIMIZATIONS APPLIED:
# 1. Fixed curriculum learning criteria (was too restrictive)
# 2. Added min_lr_ratio to prevent LR dropping to 0
# 3. Batch size 48 (SAFE - leaves 25GB buffer for memory spikes)
# 4. No gradient accumulation (not needed)
# 5. Num workers: 12 (match vCPU count)
# 6. Persistent workers for faster epoch transitions
# 7. FIXED: Plain cosine scheduler (no restarts) for curriculum stability
# 8. FIXED: R-Drop disabled (causes OOM with batch=48)
# =============================================================================
from datetime import datetime
import torch

config = {
    # Model - High resolution for A100
    'image_size': 512,
    'encoder_name': 'base',                    # Swin-Base (88M params)
    'decoder_name': 'biobart-large',           # BioBART-Large (406M params)
    'use_anatomical_attention': True,          # Enable HAQT-ARR (Novel)
    
    # HAQT-ARR specific parameters (NOVEL) - DOUBLED for better performance
    'num_regions': 7,
    'num_global_queries': 16,
    'num_region_queries': 8,
    'use_spatial_priors': True,
    'use_adaptive_routing': True,
    'use_cross_region': True,
    
    # Enhancement Modules (10/10 Novelty)
    'use_uncertainty': True,
    'use_grounding': True,
    'use_explainability': True,
    'use_multitask': True,
    
    # Standard parameters
    'language_dim': 1024,
    
    # ==========================================================================
    # TRAINING - A100 80GB VRAM - SAFE BATCH SIZE
    # ==========================================================================
    'epochs': 50,
    'batch_size': 48,                          # SAFE: ~50-55GB, leaves 25GB buffer!
    'gradient_accumulation_steps': 1,          # Not needed
    
    # LEARNING RATES - REDUCED FOR STABILITY
    'learning_rate': 5e-5,                     # REDUCED from 1e-4
    'encoder_lr': 1e-5,                        # REDUCED - pretrained needs lower LR
    'decoder_lr': 5e-5,                        # REDUCED
    'projection_lr': 1e-4,                     # Higher for new HAQT-ARR layers
    
    'weight_decay': 0.05,                      # INCREASED for regularization
    'warmup_steps': 1500,                      # INCREASED for stability
    'max_grad_norm': 0.5,                      # REDUCED for stability
    
    # Label smoothing - INCREASED
    'label_smoothing': 0.15,
    
    # ==========================================================================
    # SCHEDULER - FIXED: Plain cosine (no restarts) for curriculum stability
    # ==========================================================================
    'scheduler': 'cosine',                     # FIXED: Plain cosine, no restarts
    'num_cycles': 1,                           # Not used for plain cosine
    'min_lr_ratio': 0.1,                       # CRITICAL: Don't let LR drop below 10%
    
    # ==========================================================================
    # NOVEL LOSS FUNCTIONS - MINIMAL WEIGHTS TO FOCUS ON MAIN TASK
    # ==========================================================================
    'use_novel_losses': True,
    'use_anatomical_consistency_loss': True,
    'use_clinical_entity_loss': False,
    'use_region_focal_loss': True,
    'use_cross_modal_loss': False,
    
    # MINIMAL auxiliary loss weights
    'anatomical_loss_weight': 0.005,
    'clinical_loss_weight': 0.0,
    'focal_loss_weight': 0.005,
    'alignment_loss_weight': 0.0,
    
    # R-Drop - DISABLED (causes OOM with batch=48)
    'use_rdrop': False,                        # FIXED: Disabled - causes 2x VRAM
    'rdrop_alpha': 0.1,                        # Only used if use_rdrop: true
    
    # CURRICULUM LEARNING - FIXED CRITERIA
    'use_curriculum_learning': True,
    'curriculum_stages': [
        {'name': 'warmup', 'epoch_start': 0, 'epoch_end': 5,
         'criteria': {'max_findings': 1, 'max_regions': 2}},  # FIXED: removed severity:normal
        {'name': 'easy', 'epoch_start': 5, 'epoch_end': 12,
         'criteria': {'max_findings': 2, 'max_regions': 3}},
        {'name': 'medium', 'epoch_start': 12, 'epoch_end': 25,
         'criteria': {'max_findings': 4, 'max_regions': 5}},
        {'name': 'hard', 'epoch_start': 25, 'epoch_end': 40,
         'criteria': {}},
        {'name': 'finetune', 'epoch_start': 40, 'epoch_end': 50,
         'criteria': {}},
    ],
    
    # Clinical Validation
    'use_clinical_validation': True,
    
    # Uncertainty Quantification
    'use_uncertainty_training': True,
    'uncertainty_dropout': 0.1,
    'mc_samples': 5,
    'use_calibration': True,
    
    # Multi-Task Learning - REDUCED weights
    'use_multi_task_learning': True,
    'auxiliary_task_weights': {
        'region_classification': 0.02,
        'severity_prediction': 0.02,
        'finding_detection': 0.05,
        'length_prediction': 0.01,
    },
    
    # Factual Grounding - REDUCED
    'use_factual_grounding': True,
    'grounding_loss_weight': 0.02,
    'grounding_threshold': 0.15,
    
    # OOD Detection
    'use_ood_detection': True,
    'ood_threshold': 0.5,
    
    # Scheduled Sampling - MORE TEACHER FORCING
    'use_scheduled_sampling': True,
    'scheduled_sampling_start': 1.0,
    'scheduled_sampling_end': 0.8,             # INCREASED - more teacher forcing
    'scheduled_sampling_warmup': 15,           # INCREASED - longer warmup
    
    # Region regularization
    'use_region_regularization': True,
    'region_regularization_weight': 0.001,
    
    # EMA (Exponential Moving Average) - for stable training
    'use_ema': True,
    'ema_decay': 0.9999,
    
    # Data - A100 80GB OPTIMIZED
    'max_length': 300,
    'num_workers': 12,                         # Matches 12 vCPU
    'pin_memory': True,
    'prefetch_factor': 4,
    'persistent_workers': True,                # Faster epoch transitions
    
    # Device
    'use_amp': True,
    'gradient_checkpointing': False,           # NOT NEEDED with 80GB VRAM
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    
    # Experiment
    'experiment_name': 'xr2text_a100_80gb_' + datetime.now().strftime("%Y%m%d_%H%M%S"),
    'checkpoint_dir': '../checkpoints',
    'validate_every': 1,
    'save_every': 999,
    'patience': 30,                            # INCREASED
    'log_dir': '../logs',
    
    # Validation
    'val_fraction': 0.5,
    
    # Generation parameters
    'generation': {
        'num_beams': 5,
        'val_num_beams': 3,
        'min_length': 30,
        'max_length': 300,
        'length_penalty': 1.2,
        'repetition_penalty': 1.2,
        'no_repeat_ngram_size': 3,
        'early_stopping': True,
    },
    
    # Error Recovery
    'cublas_retry_enabled': True,
    'cublas_max_retries': 3,
    'cublas_retry_delay': 3,
    'clear_cache_every_steps': 500,
    'max_oom_retries': 3,
    'enable_temp_monitoring': False,
}

# Create directories
os.makedirs(config['checkpoint_dir'], exist_ok=True)
os.makedirs(config['log_dir'], exist_ok=True)
os.makedirs('../data/figures', exist_ok=True)
os.makedirs('../data/statistics', exist_ok=True)

print("=" * 70)
print("XR2Text Training Config - A100 PCIe 80GB VRAM (FIXED)")
print("=" * 70)
print("")
print("MEMORY SAFE SETTINGS:")
print(f"  Batch size: {config['batch_size']} (~50-55GB VRAM usage)")
print(f"  Buffer: ~25GB FREE for memory spikes!")
print(f"  Gradient checkpointing: OFF (not needed)")
print(f"  R-Drop: DISABLED (prevents OOM)")
print("")
print("SCHEDULER FIX:")
print(f"  Scheduler: {config['scheduler']} (no restarts - stable with curriculum)")
print(f"  min_lr_ratio: {config['min_lr_ratio']} (LR won't drop below 10%)")
print("")
print("CONVERGENCE FIXES:")
print("  1. Fixed curriculum criteria (removed severity:normal)")
print("  2. Plain cosine scheduler (no restarts during curriculum)")
print("  3. EMA enabled for stable weights")
print("")
print(">>> STABLE TRAINING - No divergence at curriculum transitions!")
print("=" * 70)
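
The scheduler settings above (`warmup_steps`, plain cosine decay, `min_lr_ratio`) describe a learning-rate multiplier that ramps up linearly and then decays to a floor rather than to zero. One common way to write it is sketched below. The project's own scheduler in `src/training/scheduler.py` is not shown, so `lr_multiplier` is an illustrative assumption rather than that implementation, and the total of 22,350 steps assumes 447 batches per epoch.

```python
import math

def lr_multiplier(step, warmup_steps, total_steps, min_lr_ratio):
    """Linear warmup, then cosine decay from 1.0 down to min_lr_ratio."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(1.0, progress)  # clamp if training runs past total_steps
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    # Rescale the cosine from [0, 1] to [min_lr_ratio, 1].
    return min_lr_ratio + (1.0 - min_lr_ratio) * cosine

base_lr = 5e-5
for step in (0, 750, 1500, 12000, 22350):
    print(step, base_lr * lr_multiplier(step, 1500, 22350, 0.1))
```

A function of this shape could be passed to `torch.optim.lr_scheduler.LambdaLR` (for example through `functools.partial`). With `min_lr_ratio=0.1` the final learning rate is 10% of the base rate.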

## 2. Load Model and Data

In [None]:
from src.models.xr2text import XR2TextModel, DEFAULT_CONFIG
from src.models.anatomical_attention import ANATOMICAL_REGIONS
from src.data.dataloader import get_dataloaders
from src.utils.device import setup_cuda_optimizations

# Setup CUDA optimizations for A100
setup_cuda_optimizations()

# Create model with HAQT-ARR + ALL ENHANCEMENT MODULES (10/10 Novelty)
print("Creating XR2Text model with HAQT-ARR + Enhancement Modules...")
model_config = {
    'image_size': config['image_size'],                            # 512
    'use_anatomical_attention': config['use_anatomical_attention'],  # Enable HAQT-ARR
    'gradient_checkpointing': False,                                 # NOT NEEDED on 80GB
    
    # Enhancement Modules (10/10 Novelty)
    'use_uncertainty': config.get('use_uncertainty', True),
    'use_grounding': config.get('use_grounding', True),
    'use_explainability': config.get('use_explainability', True),
    'use_multitask': config.get('use_multitask', True),
    
    'encoder': {
        'model_name': config['encoder_name'],
        'pretrained': True,
        'freeze_layers': 0,  # UNFREEZE ALL LAYERS - A100 has massive compute!
        'output_dim': 1024,
        'drop_rate': 0.1,
        'attn_drop_rate': 0.1,
    },
    'projection': {
        # HAQT-ARR parameters (Novel) - DOUBLED for better performance
        'language_dim': config['language_dim'],
        'num_regions': config['num_regions'],
        'num_global_queries': config['num_global_queries'],          # 16 (doubled)
        'num_region_queries': config['num_region_queries'],          # 8 (doubled)
        'use_spatial_priors': config['use_spatial_priors'],
        'use_adaptive_routing': config['use_adaptive_routing'],
        'use_cross_region': config['use_cross_region'],
        'num_cross_region_layers': 3,
        'feature_size': 16,                                          # 512/32 = 16x16
        'dropout': 0.1,
        'num_projection_layers': 3,
        'num_queries': 64,                                           # Doubled from 32
        'use_cross_attention': True,
        'use_residual': True,
    },
    'decoder': {
        'model_name': config['decoder_name'],
        'max_length': config['max_length'],                          # 300
        'freeze_embeddings': False,
        'freeze_layers': 0,
        'use_cache': True,
        'dropout': 0.1,
    }
}

model = XR2TextModel.from_config(model_config)
model = model.to(config['device'])

# Model summary
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"\n{'='*60}")
print("XR2Text Model - A100 PCIe 80GB FULL TRAINING MODE")
print(f"{'='*60}")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Frozen parameters: {total_params - trainable_params:,}")
print(f"\nAnatomical regions: {model.get_anatomical_regions()}")
print(f"Total queries: {config['num_global_queries'] + config['num_regions'] * config['num_region_queries']}")
print(f"\nEnhancement Modules Enabled:")
print(f"  - Uncertainty Quantification: {config.get('use_uncertainty', True)}")
print(f"  - Factual Grounding: {config.get('use_grounding', True)}")
print(f"  - Explainability: {config.get('use_explainability', True)}")
print(f"  - Multi-Task Learning: {config.get('use_multitask', True)}")
print(f"\nA100 Optimizations:")
print(f"  - Image Size: 512x512")
print(f"  - All encoder layers unfrozen: YES")
print(f"  - Gradient checkpointing: OFF (not needed with 80GB)")
print(f"  - R-Drop regularization: {config.get('use_rdrop', True)}")
print(f"  - Num queries: 64 (doubled from 32)")

In [None]:
# Load data with A100 80GB SAFE settings
print("\nLoading datasets with A100 80GB VRAM SAFE settings...")
tokenizer = model.get_tokenizer()

train_loader, val_loader, test_loader = get_dataloaders(
    tokenizer=tokenizer,
    batch_size=config['batch_size'],          # 48 - SAFE with 25GB buffer
    num_workers=config['num_workers'],        # 12 workers (matches 12 vCPUs)
    image_size=config['image_size'],          # 512 for high resolution
    max_length=config['max_length'],          # 300 for longer reports
    train_subset=None,                        # Use full dataset
    pin_memory=config.get('pin_memory', True),
    prefetch_factor=config.get('prefetch_factor', 4),
    persistent_workers=config.get('persistent_workers', True),
)

print(f"\n{'='*70}")
print("DATALOADER CONFIGURATION - A100 PCIe 80GB VRAM (SAFE)")
print(f"{'='*70}")
print(f"Train samples: {len(train_loader.dataset)}")
print(f"Val samples: {len(val_loader.dataset)}")
print(f"Test samples: {len(test_loader.dataset)}")
print(f"\nMEMORY SAFE SETTINGS:")
print(f"  Batch size: {config['batch_size']} (~50-55GB VRAM)")
print(f"  Buffer: ~25GB FREE for spikes!")
print(f"  Train batches per epoch: {len(train_loader)}")
print(f"\nSPEED SETTINGS:")
print(f"  Num workers: {config['num_workers']} (matches 12 vCPU)")
print(f"  Persistent workers: {config.get('persistent_workers', True)}")
print(f"  Prefetch factor: {config.get('prefetch_factor', 4)}")
print(f"  Pin memory: {config.get('pin_memory', True)}")
print(f"\nImage size: {config['image_size']}x{config['image_size']}")
print(f"Max length: {config['max_length']}")
print(f"\nSteps per epoch: {len(train_loader)}")
print(f"Total steps (50 epochs): {len(train_loader) * 50}")
print(f"\n>>> NO OOM ERRORS EXPECTED - 25GB buffer maintained!")
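
The speed settings above interact: in PyTorch's `DataLoader`, `prefetch_factor` and `persistent_workers` are only valid when `num_workers > 0`. The sketch below builds the keyword arguments defensively so the same config also works with zero workers, for example when debugging. `loader_kwargs` is a hypothetical helper; how `get_dataloaders` handles this internally is not shown in this notebook.

```python
def loader_kwargs(batch_size, num_workers, pin_memory=True,
                  prefetch_factor=4, persistent_workers=True):
    """Build DataLoader keyword arguments that stay valid for num_workers == 0."""
    kwargs = {
        "batch_size": batch_size,
        "num_workers": num_workers,
        "pin_memory": pin_memory,
    }
    if num_workers > 0:
        # Both options require worker processes; DataLoader rejects them otherwise.
        kwargs["prefetch_factor"] = prefetch_factor
        kwargs["persistent_workers"] = persistent_workers
    return kwargs

print(loader_kwargs(48, 12))
print(loader_kwargs(48, 0))
```

The resulting dictionary can be unpacked into `torch.utils.data.DataLoader(dataset, **kwargs)`.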

## 3. Training Setup

In [None]:
from torch.optim import AdamW
from torch.cuda.amp import GradScaler, autocast
from src.training.scheduler import get_cosine_with_hard_restarts_schedule_with_warmup
from src.utils.metrics import compute_metrics

# NOVEL: Import novel training components
from src.training.losses import CombinedNovelLoss
from src.training.curriculum import AnatomicalCurriculumScheduler, create_curriculum_dataloader
from src.utils.clinical_validator import ClinicalValidator

# Optimizer with proper weight decay
no_decay = ['bias', 'LayerNorm.weight', 'layer_norm.weight']
optimizer_grouped_parameters = [
    {
        'params': [p for n, p in model.named_parameters() 
                   if p.requires_grad and not any(nd in n for nd in no_decay)],
        'weight_decay': config['weight_decay'],
    },
    {
        'params': [p for n, p in model.named_parameters() 
                   if p.requires_grad and any(nd in n for nd in no_decay)],
        'weight_decay': 0.0,
    },
]

optimizer = AdamW(optimizer_grouped_parameters, lr=config['learning_rate'])

# Scheduler - single cosine cycle (no restarts), consistent with config['scheduler'].
# With num_cycles=1 the hard-restarts schedule is expected to reduce to plain cosine decay.
# Note: min_lr_ratio is not passed to this scheduler.
total_steps = len(train_loader) * config['epochs'] // config['gradient_accumulation_steps']
scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
    optimizer,
    num_warmup_steps=config['warmup_steps'],
    num_training_steps=total_steps,
    num_cycles=config['num_cycles'],  # 1 cycle = no restarts
)

# Mixed precision scaler
scaler = GradScaler() if config['use_amp'] else None

# NOVEL: Initialize novel loss functions (with REDUCED weights)
if config.get('use_novel_losses', False):
    novel_loss = CombinedNovelLoss(
        use_anatomical_consistency=config.get('use_anatomical_consistency_loss', True),
        use_clinical_entity=config.get('use_clinical_entity_loss', False),  # DISABLED
        use_region_focal=config.get('use_region_focal_loss', True),
        use_cross_modal=config.get('use_cross_modal_loss', False),          # DISABLED
        anatomical_weight=config.get('anatomical_loss_weight', 0.01),       # Very small
        clinical_weight=config.get('clinical_loss_weight', 0.0),            # Disabled
        focal_weight=config.get('focal_loss_weight', 0.01),                 # Very small
        alignment_weight=config.get('alignment_loss_weight', 0.0),          # Disabled
    )
    print("Novel loss functions initialized (REDUCED weights for better main task focus)")
else:
    novel_loss = None

# NOVEL: Initialize curriculum learning scheduler
if config.get('use_curriculum_learning', False):
    curriculum_scheduler = AnatomicalCurriculumScheduler()
    print("Curriculum learning scheduler initialized")
else:
    curriculum_scheduler = None

# NOVEL: Initialize clinical validator
if config.get('use_clinical_validation', False):
    clinical_validator = ClinicalValidator()
    print("Clinical validator initialized")
else:
    clinical_validator = None

print(f"\n{'='*60}")
print("TRAINING SETUP - A100 PCIe 80GB OPTIMIZED")
print(f"{'='*60}")
print(f"Total optimization steps: {total_steps}")
print(f"Warmup steps: {config['warmup_steps']} (increased for stability)")
print(f"Scheduler (from config): {config['scheduler']}, num_cycles={config['num_cycles']}")
print(f"Mixed precision (AMP): {config['use_amp']}")
print(f"\nLearning Rates (stable for better convergence):")
print(f"  - Base: {config['learning_rate']} (reduced)")
print(f"  - Encoder: {config['encoder_lr']} (pretrained, lower LR)")
print(f"  - Decoder: {config['decoder_lr']}")
print(f"  - Projection: {config['projection_lr']} (new HAQT-ARR layers)")
print(f"\nNovel Components:")
print(f"  - Novel losses: {config.get('use_novel_losses', False)} (reduced weights)")
print(f"  - Curriculum learning: {config.get('use_curriculum_learning', False)}")
print(f"  - Clinical validation: {config.get('use_clinical_validation', False)}")
print(f"  - R-Drop regularization: {config.get('use_rdrop', False)} (alpha={config.get('rdrop_alpha', 0.1)})")
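
The cell above prints separate learning rates for the encoder, decoder, and projection layers, but the optimizer it builds uses only the base rate; the trainer may apply the per-component rates internally. A minimal sketch of how such groups could be built is given below. It assumes parameter names start with `encoder.`, `decoder.`, or `projection.`, which is a guess about the model's module names, and `build_param_groups` is a hypothetical helper.

```python
def build_param_groups(named_params, lrs, weight_decay, default_lr,
                       no_decay=("bias", "LayerNorm.weight", "layer_norm.weight")):
    """Group parameters by (learning rate, weight decay).

    named_params: iterable of (name, parameter) pairs.
    lrs: mapping from top-level module name to learning rate.
    """
    groups = {}
    for name, param in named_params:
        prefix = name.split(".", 1)[0]          # e.g. "encoder"
        lr = lrs.get(prefix, default_lr)
        decay = 0.0 if any(nd in name for nd in no_decay) else weight_decay
        groups.setdefault((lr, decay), []).append(param)
    return [
        {"params": params, "lr": lr, "weight_decay": decay}
        for (lr, decay), params in groups.items()
    ]

# Strings stand in for tensors so the example runs without a model.
named = [
    ("encoder.layers.0.weight", "p0"),
    ("encoder.layers.0.bias", "p1"),
    ("projection.query.weight", "p2"),
    ("decoder.embed.weight", "p3"),
]
lrs = {"encoder": 1e-5, "decoder": 5e-5, "projection": 1e-4}
for group in build_param_groups(named, lrs, weight_decay=0.05, default_lr=5e-5):
    print(group)
```

With a real model, `named_params` would be `(n, p) for n, p in model.named_parameters() if p.requires_grad`, and the returned list can be passed directly to `AdamW`.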

## 4. Training Loop

In [6]:
# Training history
history = {
    'train_loss': [],
    'val_loss': [],
    'bleu_1': [],
    'bleu_2': [],
    'bleu_3': [],
    'bleu_4': [],
    'rouge_1': [],
    'rouge_2': [],
    'rouge_l': [],
    'learning_rate': [],
}

best_metric = 0.0
patience_counter = 0
patience = config['patience']  # Keep in sync with the config cell (30)
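
The `best_metric` and `patience_counter` variables above follow the usual early-stopping pattern: reset the counter on improvement, stop once it reaches the patience limit. A minimal sketch is shown below, assuming a single score where higher is better; `run_early_stopping` is a hypothetical helper and the trainer's actual stopping rule is not shown in this notebook.

```python
def run_early_stopping(scores, patience):
    """Return (best_score, stop_epoch); stop_epoch is None if never triggered."""
    best_metric = 0.0
    patience_counter = 0
    for epoch, score in enumerate(scores, start=1):
        if score > best_metric:
            best_metric = score
            patience_counter = 0        # improvement: reset the counter
        else:
            patience_counter += 1       # no improvement this epoch
            if patience_counter >= patience:
                return best_metric, epoch
    return best_metric, None

print(run_early_stopping([0.40, 0.45, 0.44, 0.43, 0.42], patience=3))  # (0.45, 5)
```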

In [None]:
# =============================================================================
# MAIN TRAINING LOOP - A100 PCIe 80GB OPTIMIZED WITH AUTO-RESUME
# =============================================================================
from src.training.trainer import XR2TextTrainer
import os
import gc
from pathlib import Path

import pandas as pd  # needed below when saving the training history
import torch

# ============================================
# FINAL PERMISSION FIX (before training)
# ============================================
print("Ensuring all directories have write permissions...")
dirs_to_fix = [
    config['checkpoint_dir'],
    config['log_dir'],
    '../data',
    '../data/figures',
    '../data/statistics'
]
for d in dirs_to_fix:
    p = Path(d)
    p.mkdir(parents=True, exist_ok=True)
    try:
        os.chmod(p, 0o777)
    except:
        pass
print("Directories ready!")

# ============================================
# AUTO-RESUME FROM CHECKPOINT
# ============================================
checkpoint_dir = Path(config['checkpoint_dir'])

def find_best_checkpoint(checkpoint_dir):
    """Find the best checkpoint to resume from.
    
    PRIORITY ORDER:
    1. Latest checkpoint_epoch_*.pt (resume from most recent training progress)
    2. best_model.pt (fallback if no epoch checkpoints exist)
    """
    checkpoint_dir = Path(checkpoint_dir)
    if not checkpoint_dir.exists():
        return None, 0
    
    # PRIORITY 1: Find latest epoch checkpoint
    epoch_checkpoints = list(checkpoint_dir.glob("checkpoint_epoch_*.pt"))
    if epoch_checkpoints:
        def get_epoch(p):
            # Parse the epoch number from "checkpoint_epoch_<N>.pt"
            try:
                return int(p.stem.split('_')[-1])
            except ValueError:
                return 0
        latest = max(epoch_checkpoints, key=get_epoch)
        # Verify checkpoint is not corrupted (should be > 1GB)
        if latest.stat().st_size > 1e9:
            ckpt = torch.load(latest, map_location='cpu')
            latest_epoch = ckpt.get('epoch', get_epoch(latest))
            print(f"   Found epoch checkpoint: {latest.name} (epoch {latest_epoch})")
            return str(latest), latest_epoch + 1
        else:
            print(f"   Skipping corrupted checkpoint: {latest.name}")
            latest.unlink()  # Remove corrupted file
    
    # PRIORITY 2: Fallback to best_model.pt
    best_model = checkpoint_dir / "best_model.pt"
    if best_model.exists():
        # Verify checkpoint is not corrupted
        if best_model.stat().st_size > 1e9:
            ckpt = torch.load(best_model, map_location='cpu')
            best_epoch = ckpt.get('epoch', 0)
            print(f"   Found best_model.pt (from epoch {best_epoch})")
            return str(best_model), best_epoch + 1
        else:
            print(f"   Skipping corrupted best_model.pt")
            best_model.unlink()
    
    return None, 0

# Auto-detect checkpoint
checkpoint_path, resume_epoch = find_best_checkpoint(checkpoint_dir)

print("=" * 70)
print("XR2Text Training - RUNPOD A100 PCIe 80GB")
print("=" * 70)

# Memory cleanup before training
print("\nClearing GPU memory...")
gc.collect()
torch.cuda.empty_cache()
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory - Allocated: {torch.cuda.memory_allocated()/1024**3:.2f} GB")
    print(f"GPU Memory - Total: {torch.cuda.get_device_properties(0).total_memory/1024**3:.1f} GB")

# Create trainer with A100 optimized config
trainer = XR2TextTrainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    config=config,
)

# AUTO-RESUME: Load checkpoint if found
if checkpoint_path:
    print(f"\n>>> CHECKPOINT FOUND: {checkpoint_path}")
    print(f">>> Resuming from epoch {resume_epoch}")
    trainer.load_checkpoint(checkpoint_path)
else:
    print("\n>>> No checkpoint found. Starting fresh training from epoch 1")

print("\n" + "=" * 70)
print("A100 PCIe 80GB TRAINING CONFIGURATION:")
print(f"  Batch size: {config['batch_size']}")
print(f"  Gradient accumulation: {config['gradient_accumulation_steps']}")
print(f"  Effective batch: {config['batch_size'] * config['gradient_accumulation_steps']}")
print(f"  Image size: {config['image_size']}x{config['image_size']}")
print(f"  Learning rate: {config['learning_rate']} (stable)")
print(f"  R-Drop: {config.get('use_rdrop', False)} (alpha={config.get('rdrop_alpha', 0.1)})")
print(f"  Validate every: {config.get('validate_every', 1)} epochs")
print(f"  Val fraction: {config.get('val_fraction', 0.5)} (50% for accuracy)")
print("=" * 70 + "\n")

# Run training
final_metrics = trainer.train()

# Extract history from trainer for visualization
history = {
    'train_loss': trainer.metrics_tracker.get_history('train_loss'),
    'val_loss': trainer.metrics_tracker.get_history('val_loss'),
    'bleu_1': trainer.metrics_tracker.get_history('bleu_1'),
    'bleu_2': trainer.metrics_tracker.get_history('bleu_2'),
    'bleu_3': trainer.metrics_tracker.get_history('bleu_3'),
    'bleu_4': trainer.metrics_tracker.get_history('bleu_4'),
    'rouge_1': trainer.metrics_tracker.get_history('rouge_1'),
    'rouge_2': trainer.metrics_tracker.get_history('rouge_2'),
    'rouge_l': trainer.metrics_tracker.get_history('rouge_l'),
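    # Only the final LR is available here, so it is repeated for every epoch;
    # this is a placeholder, not a true per-epoch LR curve.
    'learning_rate': [trainer.scheduler.get_last_lr()[0]] * (trainer.current_epoch + 1),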
}

# Add clinical validation metrics if enabled
if config.get('use_clinical_validation', False):
    history['clinical_accuracy'] = trainer.metrics_tracker.get_history('clinical_accuracy')
    history['clinical_f1'] = trainer.metrics_tracker.get_history('clinical_f1')
    history['critical_errors'] = trainer.metrics_tracker.get_history('critical_errors')

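# Save training history
# Wrap each list in a Series so histories of different lengths are padded
# with NaN instead of raising a ValueError.
history_df = pd.DataFrame({k: pd.Series(v, dtype=float) for k, v in history.items()})
history_df['epoch'] = range(1, len(history_df) + 1)
os.makedirs('../data/statistics', exist_ok=True)
history_df.to_csv('../data/statistics/training_history.csv', index=False)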

# Store predictions and references for sample display
predictions = []
references = []

print("\n" + "=" * 70)
print("TRAINING COMPLETE!")
print("=" * 70)
print(f"\nFinal Metrics:")
for key, value in final_metrics.items():
    if isinstance(value, float):
        print(f"  {key}: {value:.4f}")
    else:
        print(f"  {key}: {value}")

# Final memory cleanup
gc.collect()
torch.cuda.empty_cache()
print(f"\nFinal GPU Memory - Allocated: {torch.cuda.memory_allocated()/1024**3:.2f} GB")

In [None]:
# ROBUST FIX - handles any array length mismatch
import pandas as pd
import os

# Get all the metrics
val_loss = trainer.metrics_tracker.get_history('val_loss')
bleu_1 = trainer.metrics_tracker.get_history('bleu_1')
bleu_2 = trainer.metrics_tracker.get_history('bleu_2')
bleu_3 = trainer.metrics_tracker.get_history('bleu_3')
bleu_4 = trainer.metrics_tracker.get_history('bleu_4')
rouge_1 = trainer.metrics_tracker.get_history('rouge_1')
rouge_2 = trainer.metrics_tracker.get_history('rouge_2')
rouge_l = trainer.metrics_tracker.get_history('rouge_l')
train_loss = trainer.metrics_tracker.get_history('train_loss')

# Debug: Print lengths
print("Array lengths:")
print(f"  val_loss: {len(val_loss)}")
print(f"  bleu_4: {len(bleu_4)}")
print(f"  rouge_l: {len(rouge_l)}")
print(f"  train_loss: {len(train_loss)}")

# Find minimum length among all validation metrics
min_len = min(len(m) for m in (val_loss, bleu_1, bleu_2, bleu_3, bleu_4,
                               rouge_1, rouge_2, rouge_l))
print(f"\nUsing {min_len} validation points")

# Validation runs every `validate_every` epochs (1 in this config), so the
# k-th validation entry corresponds to epoch k * validate_every.
step = config.get('validate_every', 1)

# Create DataFrame with matching lengths
history_df = pd.DataFrame({
    'epoch': [step * (i + 1) for i in range(min_len)],
    'val_loss': val_loss[:min_len],
    'bleu_1': bleu_1[:min_len],
    'bleu_2': bleu_2[:min_len],
    'bleu_3': bleu_3[:min_len],
    'bleu_4': bleu_4[:min_len],
    'rouge_1': rouge_1[:min_len],
    'rouge_2': rouge_2[:min_len],
    'rouge_l': rouge_l[:min_len],
})

# Add train_loss if available, sampled at the epochs where validation ran
if train_loss:
    sampled_train = train_loss[step - 1::step][:min_len]
    if len(sampled_train) == min_len:
        history_df['train_loss'] = sampled_train

# Save
os.makedirs('../data/statistics', exist_ok=True)
history_df.to_csv('../data/statistics/training_history.csv', index=False)

print(f"\n✅ Saved {len(history_df)} epochs!")
print("\nLast 5 rows:")
print(history_df.tail())
print(f"\nBest BLEU-4: {history_df['bleu_4'].max():.4f}")
print(f"Best ROUGE-L: {history_df['rouge_l'].max():.4f}")


## 5. Training Curves Visualization

## 4.5 NOVEL: Enhanced Curriculum Learning Analysis

This section provides detailed analysis of our curriculum learning strategy,
showing how it affects training dynamics and final performance.
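
As a minimal sketch of how the five-stage schedule can be read, the helper below maps an epoch to its stage. The boundaries are the ones listed in the analysis cell (warmup 0-5, easy 5-12, medium 12-25, hard 25-40, finetune 40-50). The names `CURRICULUM_STAGES` and `stage_for_epoch` are hypothetical, and treating each range as half-open is an assumption; `AnatomicalCurriculumScheduler` may resolve boundaries differently.

```python
# Hypothetical helper for illustration; not part of src.training.curriculum.
CURRICULUM_STAGES = [
    ("warmup", 0, 5),
    ("easy", 5, 12),
    ("medium", 12, 25),
    ("hard", 25, 40),
    ("finetune", 40, 50),
]


def stage_for_epoch(epoch):
    """Return the stage name for a 0-based epoch, assuming half-open [start, end) ranges."""
    for name, start, end in CURRICULUM_STAGES:
        if start <= epoch < end:
            return name
    # Epochs outside the schedule fall back to the final stage (assumption).
    return CURRICULUM_STAGES[-1][0]


for epoch in (0, 5, 24, 49):
    print(epoch, stage_for_epoch(epoch))
```

Under the half-open convention, a transition epoch such as 5 or 12 belongs to the stage that starts there, which matches the vertical markers drawn at epochs 5, 12, 25, and 40 in the plots.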

In [None]:
# ============================================
# NOVEL: ENHANCED CURRICULUM LEARNING ANALYSIS (5 STAGES, 50 EPOCHS)
# RTX 6000 Pro 96GB - Gentler Progression for Better Metrics
# ============================================
from src.training.curriculum import AnatomicalCurriculumScheduler
import os
import pandas as pd
import matplotlib.pyplot as plt

print("=" * 80)
print("NOVEL: CURRICULUM LEARNING ANALYSIS (5 STAGES, 50 EPOCHS)")
print("RTX 6000 Pro 96GB - Optimized for Better BLEU/ROUGE")
print("=" * 80)

# Initialize curriculum scheduler
curriculum = AnatomicalCurriculumScheduler()

# Display curriculum stages (RTX 6000 Pro - Gentler Progression)
print("\n1. CURRICULUM STAGES (5-Stage Progressive Training)")
print("-" * 60)
print(f"\n{'Stage':<20} {'Epochs':<15} {'Description':<40}")
print("-" * 80)

# RTX 6000 Pro optimized stages
stage_descriptions = {
    'warmup': 'Warmup with easy cases only (epochs 0-5)',
    'easy': 'Normal X-rays, simple findings (epochs 5-12)',
    'medium': 'Single anatomical region findings (epochs 12-25)',
    'hard': 'Multiple regions, moderate complexity (epochs 25-40)',
    'finetune': 'Full dataset fine-tuning (epochs 40-50)',
}

rtx6000_stages = [
    {'name': 'warmup', 'epoch_start': 0, 'epoch_end': 5},
    {'name': 'easy', 'epoch_start': 5, 'epoch_end': 12},
    {'name': 'medium', 'epoch_start': 12, 'epoch_end': 25},
    {'name': 'hard', 'epoch_start': 25, 'epoch_end': 40},
    {'name': 'finetune', 'epoch_start': 40, 'epoch_end': 50},
]

for stage in rtx6000_stages:
    name = stage['name']
    epoch_range = f"{stage['epoch_start']}-{stage['epoch_end']}"
    desc = stage_descriptions.get(name, 'Full dataset')
    print(f"{name:<20} {epoch_range:<15} {desc:<40}")

# Sample difficulty scoring demo
print("\n2. SAMPLE DIFFICULTY SCORING")
print("-" * 60)

sample_reports = [
    "Lungs are clear. Heart size is normal. No acute cardiopulmonary process.",
    "Mild cardiomegaly. Lungs are clear bilaterally.",
    "Bilateral pleural effusions. Cardiomegaly. Pulmonary edema.",
    "Large right pneumothorax. Left lung consolidation. Cardiomegaly.",
]

print("\nSample Reports with Difficulty Scores:")
for i, report in enumerate(sample_reports):
    scores = curriculum.difficulty_scorer(report)
    total_difficulty = scores.get('num_findings', 0) + scores.get('severity_score', 0)
    print(f"\n[Sample {i+1}] Difficulty: {total_difficulty:.1f}")
    print(f"   Report: {report[:60]}...")
    print(f"   Findings: {scores.get('num_findings', 0)}, Regions: {scores.get('num_regions', 0)}")

# Load and analyze training history
print("\n3. CURRICULUM LEARNING IMPACT")
print("-" * 60)

history_path = '../data/statistics/training_history.csv'
if os.path.exists(history_path):
    print("\n" + "=" * 60)
    print("CURRICULUM LEARNING RESULTS (Real Data)")
    print("=" * 60)

    df = pd.read_csv(history_path)

    print("\nPerformance at Curriculum Stage Transitions:")
    print("-" * 60)

    # RTX 6000 Pro 5-stage curriculum: warmup(0-5), easy(5-12), medium(12-25), hard(25-40), finetune(40-50)
    stage_info = [
        (5, 'End of Stage 1 (Warmup)'),
        (12, 'End of Stage 2 (Easy Cases)'),
        (25, 'End of Stage 3 (Medium Cases)'),
        (40, 'End of Stage 4 (Hard Cases)'),
        (50, 'End of Stage 5 (Fine-tuning)'),
    ]

    for target_epoch, stage_name in stage_info:
        mask = df['epoch'] <= target_epoch
        if mask.any():
            row = df[mask].iloc[-1]
            print(f"\nEpoch {int(row['epoch'])} - {stage_name}:")
            print(f"  BLEU-4:  {row['bleu_4']:.4f}")
            print(f"  ROUGE-L: {row['rouge_l']:.4f}")
            print(f"  Val Loss: {row['val_loss']:.4f}")

    # Plot curriculum impact
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # RTX 6000 Pro stage transitions at epochs 5, 12, 25, 40
    stage_transitions = [5, 12, 25, 40]

    # BLEU-4 progression with stage markers
    axes[0].plot(df['epoch'], df['bleu_4'], linewidth=2, color='blue', marker='o', markersize=3)
    for trans in stage_transitions:
        axes[0].axvline(x=trans, color='red', linestyle='--', alpha=0.7)
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('BLEU-4')
    axes[0].set_title('BLEU-4 Progression with 5-Stage Curriculum (RTX 6000 Pro)')
    axes[0].grid(True, alpha=0.3)

    # Loss progression with stage markers
    axes[1].plot(df['epoch'], df['val_loss'], linewidth=2, color='orange', marker='o', markersize=3)
    for trans in stage_transitions:
        axes[1].axvline(x=trans, color='red', linestyle='--', alpha=0.7)
    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('Validation Loss')
    axes[1].set_title('Loss Progression with 5-Stage Curriculum (RTX 6000 Pro)')
    axes[1].grid(True, alpha=0.3)

    plt.tight_layout()
    os.makedirs('../data/figures', exist_ok=True)
    plt.savefig('../data/figures/curriculum_impact.png', dpi=300, bbox_inches='tight')
    plt.show()

    print("\n" + "=" * 60)
    print("Curriculum learning analysis complete!")
    print("Figure saved: ../data/figures/curriculum_impact.png")
    print("=" * 60)
else:
    print("\nTraining history not found yet.")
    print("Run this cell again after training completes.")

In [None]:
# ============================================
# FIXED: TRAINING CURVES VISUALIZATION
# ============================================
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Load history from CSV
history_path = "../data/statistics/training_history.csv"

if os.path.exists(history_path):
    print("Loading training history from CSV...")
    df = pd.read_csv(history_path)
    print(f"Loaded {len(df)} epochs of data")

    # Check if we have data
    if len(df) > 0 and 'bleu_4' in df.columns:
        # Create figure with subplots
        fig, axes = plt.subplots(2, 2, figsize=(14, 10))
        axes = axes.flatten()

        # Plot 1: Validation Loss
        axes[0].plot(df['epoch'], df['val_loss'], label='Val Loss', color='orange', linewidth=2, marker='o', markersize=4)
        axes[0].set_xlabel('Epoch')
        axes[0].set_ylabel('Loss')
        axes[0].set_title('Validation Loss')
        axes[0].legend()
        axes[0].grid(True, alpha=0.3)

        # Plot 2: BLEU Scores
        axes[1].plot(df['epoch'], df['bleu_1'], label='BLEU-1', linewidth=2)
        axes[1].plot(df['epoch'], df['bleu_2'], label='BLEU-2', linewidth=2)
        axes[1].plot(df['epoch'], df['bleu_3'], label='BLEU-3', linewidth=2)
        axes[1].plot(df['epoch'], df['bleu_4'], label='BLEU-4', linewidth=2, marker='o', markersize=4)
        axes[1].set_xlabel('Epoch')
        axes[1].set_ylabel('Score')
        axes[1].set_title('BLEU Scores')
        axes[1].legend()
        axes[1].grid(True, alpha=0.3)

        # Plot 3: ROUGE Scores
        axes[2].plot(df['epoch'], df['rouge_1'], label='ROUGE-1', linewidth=2)
        axes[2].plot(df['epoch'], df['rouge_2'], label='ROUGE-2', linewidth=2)
        axes[2].plot(df['epoch'], df['rouge_l'], label='ROUGE-L', linewidth=2, marker='o', markersize=4)
        axes[2].set_xlabel('Epoch')
        axes[2].set_ylabel('Score')
        axes[2].set_title('ROUGE Scores')
        axes[2].legend()
        axes[2].grid(True, alpha=0.3)

        # Plot 4: Combined BLEU-4 and ROUGE-L
        axes[3].plot(df['epoch'], df['bleu_4'], label='BLEU-4', linewidth=2, color='blue', marker='o', markersize=4)
        axes[3].plot(df['epoch'], df['rouge_l'], label='ROUGE-L', linewidth=2, color='green', marker='s', markersize=4)
        axes[3].set_xlabel('Epoch')
        axes[3].set_ylabel('Score')
        axes[3].set_title('BLEU-4 vs ROUGE-L Comparison')
        axes[3].legend()
        axes[3].grid(True, alpha=0.3)

        plt.tight_layout()
        os.makedirs('../data/figures', exist_ok=True)
        plt.savefig('../data/figures/training_curves.png', dpi=300, bbox_inches='tight')
        plt.show()

        print("\nTraining curves saved to ../data/figures/training_curves.png")

        # Print summary
        print("\n" + "=" * 50)
        print("TRAINING SUMMARY")
        print("=" * 50)
        print(f"Best BLEU-4:  {df['bleu_4'].max():.4f} (Epoch {df.loc[df['bleu_4'].idxmax(), 'epoch']:.0f})")
        print(f"Best ROUGE-L: {df['rouge_l'].max():.4f} (Epoch {df.loc[df['rouge_l'].idxmax(), 'epoch']:.0f})")
        print(f"Final Val Loss: {df['val_loss'].iloc[-1]:.4f}")
    else:
        print("No valid data in CSV file")
else:
    print("Training history CSV not found!")
    print("Expected at:", history_path)


## 6. Sample Predictions

In [None]:
# Show sample predictions vs ground truth
print("Sample Predictions vs Ground Truth:")
print("=" * 80)

# Check if predictions and references exist
if 'predictions' not in dir() or not predictions:
    predictions = []
if 'references' not in dir() or not references:
    references = []

if len(predictions) > 0 and len(references) > 0:
    for i in range(min(5, len(predictions))):
        print(f"\n--- Sample {i+1} ---")
        print(f"\nGround Truth:")
        print(references[i][:500] + "..." if len(references[i]) > 500 else references[i])
        print(f"\nGenerated:")
        print(predictions[i][:500] + "..." if len(predictions[i]) > 500 else predictions[i])
        print("-" * 80)
else:
    print("\n⚠️ No predictions available yet!")
    print("   Predictions will be available after training completes (cell 11).")
    print("   Or run evaluation on test set in notebook 03_evaluation.ipynb.")

## 7. Final Results Summary

In [None]:
# ============================================
# FIXED: FINAL RESULTS SUMMARY
# ============================================
import os
import pandas as pd
import numpy as np

history_path = "../data/statistics/training_history.csv"

if os.path.exists(history_path):
    df = pd.read_csv(history_path)

    # Find best epoch by combined BLEU-4 + ROUGE-L score
    df['combined_score'] = df['bleu_4'] + df['rouge_l']
    best_idx = df['combined_score'].idxmax()
    best_row = df.loc[best_idx]
    final_row = df.iloc[-1]

    print("=" * 60)
    print("TRAINING RESULTS SUMMARY")
    print("=" * 60)

    # FIXED: Use actual epoch value from the dataframe
    print(f"\nBest Epoch: {int(best_row['epoch'])} (by BLEU-4 + ROUGE-L)")

    print(f"\nBest Metrics (Epoch {int(best_row['epoch'])}):")
    print(f"  BLEU-1:  {best_row['bleu_1']:.4f}")
    print(f"  BLEU-2:  {best_row['bleu_2']:.4f}")
    print(f"  BLEU-3:  {best_row['bleu_3']:.4f}")
    print(f"  BLEU-4:  {best_row['bleu_4']:.4f}")
    print(f"  ROUGE-1: {best_row['rouge_1']:.4f}")
    print(f"  ROUGE-2: {best_row['rouge_2']:.4f}")
    print(f"  ROUGE-L: {best_row['rouge_l']:.4f}")

    print(f"\nFinal Metrics (Epoch {int(final_row['epoch'])}):")
    print(f"  Val Loss: {final_row['val_loss']:.4f}")
    print(f"  BLEU-4:   {final_row['bleu_4']:.4f}")
    print(f"  ROUGE-L:  {final_row['rouge_l']:.4f}")

    # Save best results to CSV
    results_table = pd.DataFrame({
        'Metric': ['BLEU-1', 'BLEU-2', 'BLEU-3', 'BLEU-4', 'ROUGE-1', 'ROUGE-2', 'ROUGE-L'],
        'Best Score': [
            best_row['bleu_1'],
            best_row['bleu_2'],
            best_row['bleu_3'],
            best_row['bleu_4'],
            best_row['rouge_1'],
            best_row['rouge_2'],
            best_row['rouge_l'],
        ],
        'Final Score': [
            final_row['bleu_1'],
            final_row['bleu_2'],
            final_row['bleu_3'],
            final_row['bleu_4'],
            final_row['rouge_1'],
            final_row['rouge_2'],
            final_row['rouge_l'],
        ]
    })

    os.makedirs('../data/statistics', exist_ok=True)
    results_table.to_csv('../data/statistics/best_results.csv', index=False)

    print("\n" + "=" * 60)
    print("Results saved to ../data/statistics/best_results.csv")
    print("=" * 60)

    # Display table
    print("\nResults Table:")
    print(results_table.to_string(index=False))

else:
    print("=" * 60)
    print("TRAINING RESULTS SUMMARY")
    print("=" * 60)
    print("\nNo training results available yet!")
    print("Run training first (cell 11) to see results.")


## 8. NOVEL: Enhanced Analysis with New Features

This section demonstrates the new enhancement modules for comprehensive report analysis.
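
To make the uncertainty output concrete, here is a minimal sketch of how a Monte Carlo Dropout style confidence could be summarized: several stochastic forward passes each yield a confidence score, and their mean and spread are reported. The function name `summarize_mc_samples` and the toy scores are hypothetical. The actual computation inside `generate_with_analysis` is not shown in this notebook and may differ.

```python
from statistics import mean, pstdev


def summarize_mc_samples(scores):
    """Summarize confidence scores collected from repeated stochastic forward passes."""
    if not scores:
        raise ValueError("need at least one sample")
    # Mean acts as the point estimate; population std dev measures disagreement between passes.
    return {"confidence": mean(scores), "spread": pstdev(scores)}


# Toy values standing in for per-pass confidence scores (illustrative only).
toy_scores = [0.80, 0.90, 0.70, 0.80]
summary = summarize_mc_samples(toy_scores)
print(f"confidence={summary['confidence']:.2f}, spread={summary['spread']:.3f}")
```

A larger spread means the passes disagree more, which is one simple way to flag a report for closer review.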

In [None]:
# ============================================
# NOVEL: Enhanced Analysis Demo
# ============================================
# This demonstrates the new enhancement modules

print("=" * 70)
print("NOVEL ENHANCEMENT MODULES DEMO")
print("=" * 70)

# Check if model has enhancement modules
if hasattr(model, 'generate_with_analysis'):
    print("\n✅ Model has enhanced analysis capabilities!")
    print("\nAvailable analysis features:")
    print("  1. Uncertainty Quantification")
    print("     - Overall confidence score (0-1)")
    print("     - Per-finding confidence scores")
    print("     - Calibrated uncertainty estimates")
    print("\n  2. Factual Grounding")
    print("     - Detected medical findings")
    print("     - Potential hallucinations flagged")
    print("     - Knowledge graph validation")
    print("\n  3. Explainability")
    print("     - Evidence regions highlighted")
    print("     - Clinical reasoning chains")
    print("     - Attention visualizations")
    print("\n  4. Multi-Task Outputs")
    print("     - Region classification")
    print("     - Severity prediction")
    print("     - Finding detection")
    
    # Demo analysis on a sample if test data is available
    print("\n" + "-" * 50)
    print("Running Enhanced Analysis on Sample...")
    print("-" * 50)
    
    try:
        # Get a sample from test loader
        sample_batch = next(iter(test_loader))
        sample_image = sample_batch['images'][0:1].to(config['device'])
        
        # Run enhanced analysis
        with torch.no_grad():
            analysis = model.generate_with_analysis(
                sample_image,
                max_length=config['generation']['max_length'],
                num_beams=config['generation']['num_beams'],
            )
        
        print(f"\nüìù Generated Report:")
        print(f"   {analysis.get('report', 'N/A')[:200]}...")
        
        print(f"\nüìä Uncertainty Analysis:")
        print(f"   Overall Confidence: {analysis.get('confidence', 0):.2%}")
        if 'finding_confidences' in analysis:
            print(f"   Finding Confidences: {len(analysis['finding_confidences'])} findings analyzed")
        
        print(f"\nüîç Factual Grounding:")
        if 'detected_findings' in analysis:
            print(f"   Detected Findings: {analysis['detected_findings'][:5]}")
        if 'potential_hallucinations' in analysis:
            print(f"   Potential Hallucinations: {len(analysis.get('potential_hallucinations', []))}")
        
        print(f"\nüí° Explainability:")
        if 'evidence_regions' in analysis:
            print(f"   Evidence Regions: {len(analysis['evidence_regions'])} regions identified")
        if 'reasoning' in analysis:
            print(f"   Clinical Reasoning: Available")
            
    except Exception as e:
        print(f"   Demo skipped (requires trained model): {e}")
        
else:
    print("\n⚠️  Enhancement modules not loaded in current model.")
    print("   Ensure use_uncertainty, use_grounding, use_explainability, use_multitask are True.")
    print("   Re-initialize model with updated config to enable these features.")

print("\n" + "=" * 70)
print("Enhanced Analysis Demo Complete")
print("=" * 70)