# 03_model_training.ipynb
# Interactive model training with monitoring

## CELL 1: Setup and Imports


In [1]:
import os
import sys
import yaml
import torch
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime
import json

# Add src to path
sys.path.append('../src')

from data_loader import LUNA16Dataset, create_data_loaders
from preprocessing import create_augmentation
from model import create_model, MultiTaskLoss
from train import Trainer
from utils import set_seed, get_device, print_model_summary, log_system_info

# Reaching this line means every import above resolved.
print("‚úì Imports successful")

# Fixed seed handed to the project's set_seed helper for reproducibility.
seed_value = 42
set_seed(seed_value)
print("‚úì Random seed set to 42")



ImportError: cannot import name 'create_model' from 'model' (C:\Users\Administrator\Downloads\lung_cancer_detection\notebooks\../src\model.py)

## System and Configuration Setup

In [None]:
# Section banner for this cell's output.
banner = "=" * 60
print("\n" + banner)
print("LUNG CANCER DETECTION - MODEL TRAINING")
print(banner)

# Report the runtime environment through the project helper.
log_system_info()

# Parse the YAML experiment configuration; every later cell reads `config`.
with open('../configs/config.yaml', 'r') as config_file:
    config = yaml.safe_load(config_file)

# Echo the settings that matter most for this run.
print("\nüìã Configuration:")
training_cfg = config['training']
print(f"  Batch size: {training_cfg['batch_size']}")
print(f"  Epochs: {training_cfg['num_epochs']}")
print(f"  Learning rate: {training_cfg['learning_rate']}")
print(f"  ROI size: {config['preprocessing']['roi_size']}")
print(f"  Mixed precision: {training_cfg['mixed_precision']}")

# Device chosen by the project helper; later cells move tensors and the model to it.
device = get_device()

## Prepare Datasets

In [None]:
print("\n" + "="*60)
print("PREPARING DATASETS")
print("="*60)

# Augmentation pipeline built from the experiment config (training split only).
augmentation = create_augmentation(config)

# Both dataset objects read the same files with the same ROI size; only the
# transform and the mode flag differ, so the shared arguments are written once.
shared_dataset_args = dict(
    data_dir=config['data']['processed_dir'],
    annotations_file=config['data']['annotations_file'],
    roi_size=tuple(config['preprocessing']['roi_size']),
)
train_dataset = LUNA16Dataset(transform=augmentation, mode='train', **shared_dataset_args)
val_dataset = LUNA16Dataset(transform=None, mode='val', **shared_dataset_args)

# Split datasets
from sklearn.model_selection import train_test_split

# 80/20 split over sample positions, with a fixed random_state so it is repeatable.
all_indices = np.arange(len(train_dataset))
train_indices, val_indices = train_test_split(
    all_indices, test_size=0.2, random_state=42
)

# NOTE(review): the indices are drawn from len(train_dataset) but are also used
# to index val_dataset.samples. This assumes both objects list the same samples
# in the same order regardless of `mode` - confirm against LUNA16Dataset.
train_pool = train_dataset.samples
val_pool = val_dataset.samples
train_dataset.samples = [train_pool[idx] for idx in train_indices]
val_dataset.samples = [val_pool[idx] for idx in val_indices]

n_train = len(train_dataset)
n_val = len(val_dataset)
print(f"\n‚úì Dataset split:")
print(f"  Training samples: {n_train}")
print(f"  Validation samples: {n_val}")
print(f"  Split ratio: {n_train/n_val:.1f}:1")

# Wrap both splits in loaders via the project helper.
train_loader, val_loader = create_data_loaders(config, train_dataset, val_dataset)

print(f"\n‚úì Data loaders created:")
print(f"  Train batches: {len(train_loader)}")
print(f"  Val batches: {len(val_loader)}")
print(f"  Samples per epoch: {len(train_dataset)}")

## Inspect Sample Batch

In [None]:
print("\n" + "="*60)
print("SAMPLE BATCH INSPECTION")
print("="*60)

# Pull a single batch from the training loader; later cells reuse `sample_batch`
# for the forward-pass and loss smoke tests.
sample_batch = next(iter(train_loader))

print(f"\nBatch contents:")
print(f"  Images: {sample_batch['image'].shape} - {sample_batch['image'].dtype}")
print(f"  Labels: {sample_batch['label'].shape} - {sample_batch['label'].dtype}")
print(f"  Malignancy: {sample_batch['malignancy'].shape} - {sample_batch['malignancy'].dtype}")
print(f"  BBox: {sample_batch['bbox'].shape} - {sample_batch['bbox'].dtype}")

print(f"\nValue ranges:")
print(f"  Images: [{sample_batch['image'].min():.4f}, {sample_batch['image'].max():.4f}]")
print(f"  Labels: {sample_batch['label'].unique().tolist()}")
print(f"  Malignancy: [{sample_batch['malignancy'].min():.2f}, {sample_batch['malignancy'].max():.2f}]")

# Visualize batch: a 2 x 4 grid, one column per sample.
max_panels = 4
fig, axes = plt.subplots(2, max_panels, figsize=(20, 10))

# Switch every axis off up front so columns left unused by a batch smaller
# than `max_panels` render as blank space instead of empty framed plots.
for unused_ax in axes.ravel():
    unused_ax.axis('off')

batch_size = min(max_panels, sample_batch['image'].shape[0])

for i in range(batch_size):
    # Index [i, 0] selects sample i and the first entry of the second axis.
    # NOTE(review): presumably (batch, channel, depth, height, width) - confirm.
    volume = sample_batch['image'][i, 0].numpy()
    mid_slice = volume.shape[0] // 2

    # Top row: central slice along the first volume axis.
    axes[0, i].imshow(volume[mid_slice], cmap='gray')
    axes[0, i].set_title(f'Sample {i+1}\nLabel: {sample_batch["label"][i].item()}',
                        fontsize=11, fontweight='bold')
    axes[0, i].axis('off')

    # Bottom row: central slice along the second volume axis (one extra view,
    # not all three orthogonal planes).
    axes[1, i].imshow(volume[:, volume.shape[1]//2, :], cmap='gray')
    axes[1, i].set_title(f'Malignancy: {sample_batch["malignancy"][i].item():.0f}',
                        fontsize=11)
    axes[1, i].axis('off')

plt.suptitle('Training Batch Samples', fontsize=16, fontweight='bold', y=0.98)
plt.tight_layout()

# savefig does not create missing directories; make sure the target exists.
os.makedirs('../results', exist_ok=True)
plt.savefig('../results/training_batch_samples.png', dpi=150, bbox_inches='tight')
plt.show()

print("\n‚úì Batch visualization complete")


## Create Model

In [None]:
print("\n" + "="*60)
print("MODEL ARCHITECTURE")
print("="*60)

# Build the network from the config and move it to the selected device.
model = create_model(config).to(device)

# Print model summary
print_model_summary(model)

# Smoke test: push the first two samples of the inspected batch through the
# network with gradients disabled. `output` is reused by the loss cell.
# NOTE(review): model.eval() is not called here, so the pass runs in whatever
# mode create_model() left the module in - confirm this is intended.
print("\nTesting forward pass...")
with torch.no_grad():
    sample_input = sample_batch['image'][:2].to(device)
    output = model(sample_input)

    print(f"‚úì Forward pass successful")
    detection_out = output['detection']
    print(f"  Detection logits: {detection_out['class_logits'].shape}")
    print(f"  Detection bbox: {detection_out['bbox'].shape}")
    print(f"  Detection confidence: {detection_out['confidence'].shape}")
    print(f"  Malignancy score: {output['malignancy'].shape}")

## Setup Loss Function

In [None]:
print("\n" + "="*60)
print("LOSS FUNCTION")
print("="*60)

# Per-task weights passed to the combined loss.
loss_weights = {
    'detection_weight': 1.0,
    'malignancy_weight': 1.0,
    'bbox_weight': 0.5,
}
criterion = MultiTaskLoss(**loss_weights)

print("Multi-task loss initialized:")
print(f"  Detection weight: 1.0")
print(f"  Malignancy weight: 1.0")
print(f"  BBox weight: 0.5")

# Smoke test: score the forward-pass `output` from the model cell against the
# matching first two targets of the same sample batch.
print("\nTesting loss computation...")
with torch.no_grad():
    sample_targets = {
        target_key: sample_batch[target_key][:2].to(device)
        for target_key in ('label', 'malignancy', 'bbox')
    }

    # The criterion returns the total loss plus a dict of per-task components.
    loss, loss_dict = criterion(output, sample_targets)

    print(f"‚úì Loss computation successful")
    print(f"  Total loss: {loss.item():.4f}")
    print(f"  Detection loss: {loss_dict['detection_loss']:.4f}")
    print(f"  BBox loss: {loss_dict['bbox_loss']:.4f}")
    print(f"  Malignancy loss: {loss_dict['malignancy_loss']:.4f}")


## Initialize Trainer

In [None]:
print("\n" + "="*60)
print("TRAINER INITIALIZATION")
print("="*60)

# Everything built in the earlier cells is handed to the project Trainer.
trainer_kwargs = {
    'model': model,
    'train_loader': train_loader,
    'val_loader': val_loader,
    'criterion': criterion,
    'config': config,
    'device': device,
}
trainer = Trainer(**trainer_kwargs)

# Echo what the trainer configured for itself.
print("‚úì Trainer initialized")
optimizer_name = type(trainer.optimizer).__name__
print(f"  Optimizer: {optimizer_name}")
scheduler_name = type(trainer.scheduler).__name__
print(f"  Scheduler: {scheduler_name}")
print(f"  Mixed precision: {trainer.use_amp}")
print(f"  Gradient clipping: {trainer.grad_clip}")
print(f"  Early stopping patience: {trainer.early_stopping_patience}")

## Training Loop (Main Training)

In [None]:
import traceback

print("\n" + "="*60)
print("STARTING TRAINING")
print("="*60)
print(f"Training for {config['training']['num_epochs']} epochs")
print(f"Press Ctrl+C to interrupt\n")

# Run the full training loop. Whatever happens, `history` ends up bound so the
# analysis cells below can still run on the epochs recorded so far.
# NOTE(review): a failed run is reported but not re-raised, so "Run All"
# continues into the analysis cells with partial history.
try:
    history = trainer.train()
except KeyboardInterrupt:
    # Manual stop: keep the epochs the trainer has recorded.
    print("\n‚ö† Training interrupted by user")
    history = trainer.history
except Exception as e:
    # Any other failure: show the message and traceback, keep partial history.
    print(f"\n‚ùå Training error: {str(e)}")
    traceback.print_exc()
    history = trainer.history
else:
    print("\n‚úÖ Training completed successfully!")

## Training History Analysis

In [None]:
print("\n" + "="*60)
print("TRAINING HISTORY ANALYSIS")
print("="*60)

# How far the final validation loss may sit above the final training loss
# before the run is flagged as possibly overfitting.
OVERFIT_GAP_THRESHOLD = 0.5

train_losses = history['train_loss']
val_losses = history['val_loss']

if len(train_losses) == 0 or len(val_losses) == 0:
    # Training can be interrupted before the first epoch finishes; np.argmin
    # and [-1] would raise on the empty lists, so report and stop here.
    print("\nNo completed epochs in history - nothing to analyse")
else:
    # Print best results
    print(f"\nüèÜ Best Results:")
    print(f"  Best validation loss: {trainer.best_val_loss:.4f}")
    # argmin is 0-based; epochs are reported 1-based.
    print(f"  Achieved at epoch: {np.argmin(val_losses) + 1}")

    # Loss statistics
    final_train_loss = train_losses[-1]
    final_val_loss = val_losses[-1]
    print(f"\nüìä Loss Statistics:")
    print(f"  Final train loss: {final_train_loss:.4f}")
    print(f"  Final val loss: {final_val_loss:.4f}")
    print(f"  Min train loss: {min(train_losses):.4f}")
    print(f"  Min val loss: {min(val_losses):.4f}")

    # Check for overfitting. Overfitting means validation loss is well ABOVE
    # training loss, so only that direction triggers the warning. The previous
    # abs() test also fired when validation loss was far below training loss.
    # train_val_gap keeps its original sign (train - val) for the message.
    train_val_gap = final_train_loss - final_val_loss
    if final_val_loss - final_train_loss > OVERFIT_GAP_THRESHOLD:
        print(f"\n‚ö† Warning: Large train-val gap ({train_val_gap:.4f}) - possible overfitting")

## Plot Training Curves

In [None]:
print("\n" + "="*60)
print("TRAINING VISUALIZATION")
print("="*60)

fig, axes = plt.subplots(2, 2, figsize=(16, 12))

# 1-based epoch numbers for the x-axis; later cells reuse `epochs`.
epochs = range(1, len(history['train_loss']) + 1)

# Per-task loss series pulled out of the per-epoch metric dicts. The *_val
# lists are reused by the loss-component cell further down.
det_train = [m['detection_loss'] for m in history['train_metrics']]
det_val = [m['detection_loss'] for m in history['val_metrics']]
bbox_train = [m['bbox_loss'] for m in history['train_metrics']]
bbox_val = [m['bbox_loss'] for m in history['val_metrics']]
mal_train = [m['malignancy_loss'] for m in history['train_metrics']]
mal_val = [m['malignancy_loss'] for m in history['val_metrics']]

# One row per panel:
# (axis, train series, val series, y label, title, draw best-val line?)
panels = [
    (axes[0, 0], history['train_loss'], history['val_loss'],
     'Loss', 'Training and Validation Loss', True),
    (axes[0, 1], det_train, det_val,
     'Detection Loss', 'Detection Loss', False),
    (axes[1, 0], bbox_train, bbox_val,
     'BBox Loss', 'Bounding Box Loss', False),
    (axes[1, 1], mal_train, mal_val,
     'Malignancy Loss', 'Malignancy Loss', False),
]

for panel_ax, train_series, val_series, y_label, panel_title, show_best in panels:
    panel_ax.plot(epochs, train_series, 'b-', linewidth=2, label='Train', marker='o')
    panel_ax.plot(epochs, val_series, 'r-', linewidth=2, label='Validation', marker='s')
    if show_best:
        # Horizontal reference line at the best validation loss seen by the trainer.
        panel_ax.axhline(y=trainer.best_val_loss, color='g', linestyle='--',
                         label=f'Best Val: {trainer.best_val_loss:.4f}')
    panel_ax.set_xlabel('Epoch', fontsize=12)
    panel_ax.set_ylabel(y_label, fontsize=12)
    panel_ax.set_title(panel_title, fontsize=14, fontweight='bold')
    panel_ax.legend()
    panel_ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('../results/training_curves_detailed.png', dpi=150, bbox_inches='tight')
plt.show()

print("\n‚úì Training curves saved")

## Loss Component Analysis

In [None]:
print("\n" + "="*60)
print("LOSS COMPONENT ANALYSIS")
print("="*60)

# Metric dicts of the last recorded epoch; the report cell reuses both names.
final_train = history['train_metrics'][-1]
final_val = history['val_metrics'][-1]

print("\nFinal Epoch Loss Components:")
for heading, split_metrics in (("\nTrain:", final_train), ("\nValidation:", final_val)):
    print(heading)
    for key, value in split_metrics.items():
        print(f"  {key}: {value:.4f}")

# Overlay the three validation loss components on one axis. The series and
# `epochs` come from the training-curves cell above.
fig, ax = plt.subplots(figsize=(14, 6))

component_curves = (
    (det_val, 'Detection', 'o'),
    (bbox_val, 'BBox', 's'),
    (mal_val, 'Malignancy', '^'),
)
for series, series_label, marker_style in component_curves:
    ax.plot(epochs, series, label=series_label, linewidth=2, marker=marker_style)

ax.set_xlabel('Epoch', fontsize=13)
ax.set_ylabel('Validation Loss', fontsize=13)
ax.set_title('Loss Components Evolution (Validation)', fontsize=15, fontweight='bold')
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('../results/loss_components.png', dpi=150, bbox_inches='tight')
plt.show()

## Learning Rate Schedule

In [None]:
import math  # imported once for the cell instead of on every loop iteration

print("\n" + "="*60)
print("LEARNING RATE SCHEDULE")
print("="*60)

# The learning rate is not read from the training history; it is recomputed
# from the config with the cosine formula
#     lr(e) = base_lr * 0.5 * (1 + cos(pi * e / num_epochs)),  e = 0, 1, ...
# NOTE(review): this assumes the trainer uses plain cosine annealing over
# num_epochs with a minimum of 0 - confirm against the scheduler in train.py.
# float()/int() guard against YAML values that were loaded as strings
# (PyYAML reads a dot-less literal such as `1e-4` as a str).
base_lr = float(config['training']['learning_rate'])
total_epochs = int(config['training']['num_epochs'])
completed_epochs = len(history['train_loss'])

lrs = [
    base_lr * 0.5 * (1 + math.cos(math.pi * epoch / total_epochs))
    for epoch in range(completed_epochs)
]

if not lrs:
    # Nothing was trained (e.g. interrupted before the first epoch finished);
    # lrs[0] / min(lrs) would raise on the empty list.
    print("\nNo completed epochs in history - skipping learning rate plot")
else:
    # 1-based epoch axis computed here so this cell does not depend on the
    # `epochs` variable defined in the plotting cell.
    lr_epochs = range(1, completed_epochs + 1)

    plt.figure(figsize=(12, 5))
    plt.plot(lr_epochs, lrs, linewidth=2, color='purple', marker='o')
    plt.xlabel('Epoch', fontsize=12)
    plt.ylabel('Learning Rate', fontsize=12)
    plt.title('Learning Rate Schedule (Cosine Annealing)', fontsize=14, fontweight='bold')
    plt.grid(True, alpha=0.3)
    plt.yscale('log')
    plt.tight_layout()

    # savefig does not create missing directories; make sure the target exists.
    os.makedirs('../results', exist_ok=True)
    plt.savefig('../results/learning_rate_schedule.png', dpi=150, bbox_inches='tight')
    plt.show()

    print(f"\nLearning rate:")
    print(f"  Initial: {lrs[0]:.6f}")
    print(f"  Final: {lrs[-1]:.6f}")
    print(f"  Min: {min(lrs):.6f}")


## Save Training Report

In [None]:
print("\n" + "="*60)
print("GENERATING TRAINING REPORT")
print("="*60)


def _json_default(value):
    """Fallback encoder for objects the json module cannot serialise itself.

    numpy scalars/arrays and torch tensors expose ``tolist()``, which yields
    plain Python numbers or lists; anything else is stored as its string form
    so that writing the report never aborts half-way.
    """
    if hasattr(value, 'tolist'):
        return value.tolist()
    return str(value)


# Create comprehensive report
report = {
    'experiment_info': {
        'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
        'config': config,
        'device': str(device),
        'total_epochs': len(history['train_loss'])
    },
    'dataset_info': {
        'train_samples': len(train_dataset),
        'val_samples': len(val_dataset),
        'train_batches': len(train_loader),
        'val_batches': len(val_loader)
    },
    'model_info': {
        # NOTE(review): the architecture name is hard-coded, not read from
        # the model or config - confirm it matches what create_model() builds.
        'architecture': 'SwinTransformer3D',
        'total_parameters': sum(p.numel() for p in model.parameters()),
        'trainable_parameters': sum(p.numel() for p in model.parameters() if p.requires_grad)
    },
    'training_results': {
        'best_val_loss': float(trainer.best_val_loss),
        # argmin is 0-based; epochs are reported 1-based.
        'best_epoch': int(np.argmin(history['val_loss']) + 1),
        'final_train_loss': float(history['train_loss'][-1]),
        'final_val_loss': float(history['val_loss'][-1]),
        'min_train_loss': float(min(history['train_loss'])),
        'min_val_loss': float(min(history['val_loss']))
    },
    'final_metrics': {
        # Read straight from history so this cell does not depend on the
        # final_train / final_val variables of the loss-component cell.
        'train': history['train_metrics'][-1],
        'validation': history['val_metrics'][-1]
    }
}

# Save report next to the checkpoints, creating the directory if needed.
report_dir = config['logging']['checkpoint_dir']
if report_dir:
    os.makedirs(report_dir, exist_ok=True)
report_path = os.path.join(report_dir, 'training_report.json')
with open(report_path, 'w', encoding='utf-8') as f:
    json.dump(report, f, indent=2, default=_json_default)

print(f"‚úì Training report saved to {report_path}")

# Print summary
summary_text = f"""
{'='*60}
TRAINING SUMMARY
{'='*60}

üìÖ Date: {report['experiment_info']['timestamp']}
‚è±Ô∏è Total Epochs: {report['experiment_info']['total_epochs']}

üìä Dataset:
  - Training samples: {report['dataset_info']['train_samples']}
  - Validation samples: {report['dataset_info']['val_samples']}

üèóÔ∏è Model:
  - Architecture: {report['model_info']['architecture']}
  - Total parameters: {report['model_info']['total_parameters']:,}
  - Trainable parameters: {report['model_info']['trainable_parameters']:,}

üéØ Best Results:
  - Best validation loss: {report['training_results']['best_val_loss']:.4f}
  - Achieved at epoch: {report['training_results']['best_epoch']}

üìà Final Loss:
  - Train: {report['training_results']['final_train_loss']:.4f}
  - Validation: {report['training_results']['final_val_loss']:.4f}

üíæ Saved Artifacts:
  ‚úì Model checkpoint: {config['logging']['checkpoint_dir']}/best_model.pth
  ‚úì Training history: {config['logging']['checkpoint_dir']}/training_history.json
  ‚úì Training report: {report_path}
  ‚úì Visualizations: ../results/

{'='*60}
"""

print(summary_text)

# Save summary as text. The summary contains non-ASCII symbols, so the file is
# written as UTF-8 explicitly; the platform default (e.g. cp1252 on Windows)
# can raise UnicodeEncodeError on them.
summary_path = '../results/training_summary.txt'
os.makedirs('../results', exist_ok=True)
with open(summary_path, 'w', encoding='utf-8') as f:
    f.write(summary_text)

print(f"‚úì Summary saved to {summary_path}")

## Model Checkpoint Information

In [None]:
print("\n" + "="*60)
print("MODEL CHECKPOINTS")
print("="*60)

checkpoint_dir = config['logging']['checkpoint_dir']
bytes_per_mb = 1024 * 1024

if not os.path.exists(checkpoint_dir):
    print("\n‚ö† No checkpoint directory found")
else:
    # Every *.pth file in the directory, in name order, with its size on disk.
    checkpoints = sorted(
        name for name in os.listdir(checkpoint_dir) if name.endswith('.pth')
    )

    print(f"\nSaved checkpoints ({len(checkpoints)}):")
    for ckpt in checkpoints:
        ckpt_path = os.path.join(checkpoint_dir, ckpt)
        size_mb = os.path.getsize(ckpt_path) / bytes_per_mb
        print(f"  ‚Ä¢ {ckpt} ({size_mb:.2f} MB)")

    # Read the metadata stored inside the best checkpoint; mapped to CPU
    # because only the bookkeeping fields are needed here.
    best_path = os.path.join(checkpoint_dir, 'best_model.pth')
    if os.path.exists(best_path):
        checkpoint = torch.load(best_path, map_location='cpu')
        print(f"\nüèÜ Best Model:")
        print(f"  Epoch: {checkpoint['epoch']}")
        print(f"  Validation loss: {checkpoint['best_val_loss']:.4f}")
        print(f"  File size: {os.path.getsize(best_path) / bytes_per_mb:.2f} MB")

## Quick Model Test

In [None]:
print("\n" + "="*60)
print("QUICK MODEL TEST")
print("="*60)

# Load best model
best_model_path = os.path.join(config['logging']['checkpoint_dir'], 'best_model.pth')

if os.path.exists(best_model_path):
    # map_location remaps the stored tensors onto the device in use now, so a
    # checkpoint written on a GPU still loads on a CPU-only machine. This also
    # matches the checkpoint-listing cell, which passes map_location as well.
    checkpoint = torch.load(best_model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    print(f"‚úì Loaded best model from epoch {checkpoint['epoch']}")

    # Test on a validation batch
    val_batch = next(iter(val_loader))

    with torch.no_grad():
        images = val_batch['image'].to(device)
        outputs = model(images)

        # Softmax over the class axis; column 1 is taken as the positive class.
        det_probs = torch.softmax(outputs['detection']['class_logits'], dim=1)[:, 1]
        # NOTE(review): used as-is and labelled "probabilities" below -
        # presumably the malignancy head already ends in a sigmoid; confirm.
        mal_probs = outputs['malignancy']

        print(f"\nPredictions on validation batch:")
        print(f"  Batch size: {images.shape[0]}")
        print(f"  Detection probabilities: {det_probs.cpu().numpy()}")
        print(f"  Malignancy probabilities: {mal_probs.cpu().numpy().flatten()}")

        # Compare with ground truth
        print(f"\nGround truth:")
        print(f"  Labels: {val_batch['label'].numpy()}")
        print(f"  Malignancy: {val_batch['malignancy'].numpy()}")

    # Visualize predictions: a 2 x 4 grid, one column per sample.
    max_panels = 4
    fig, axes = plt.subplots(2, max_panels, figsize=(20, 10))

    # Switch every axis off up front so columns left unused by a batch smaller
    # than `max_panels` render as blank space instead of empty framed plots.
    for unused_ax in axes.ravel():
        unused_ax.axis('off')

    batch_size = min(max_panels, images.shape[0])

    for i in range(batch_size):
        volume = images[i, 0].cpu().numpy()
        mid_slice = volume.shape[0] // 2

        det_prob = det_probs[i].item()
        mal_prob = mal_probs[i].item()
        true_label = val_batch['label'][i].item()
        true_mal = val_batch['malignancy'][i].item()

        # Top row: images
        axes[0, i].imshow(volume[mid_slice], cmap='gray')
        axes[0, i].set_title(f'Sample {i+1}', fontsize=12, fontweight='bold')
        axes[0, i].axis('off')

        # Bottom row: predictions next to the ground truth
        pred_text = (
            f"Detection: {det_prob:.3f}\n"
            f"True: {true_label}\n"
            f"Malignancy: {mal_prob:.3f}\n"
            f"True: {true_mal:.0f}"
        )
        axes[1, i].text(0.5, 0.5, pred_text,
                       transform=axes[1, i].transAxes,
                       fontsize=11, ha='center', va='center',
                       bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.8))
        axes[1, i].axis('off')

    plt.suptitle('Model Predictions on Validation Set', fontsize=16, fontweight='bold')
    plt.tight_layout()

    # savefig does not create missing directories; make sure the target exists.
    os.makedirs('../results', exist_ok=True)
    plt.savefig('../results/quick_model_test.png', dpi=150, bbox_inches='tight')
    plt.show()

    print("\n‚úì Model test complete")
else:
    print("\n‚ö† Best model checkpoint not found")

## Training Completion Summary

In [None]:
print("\n" + "="*60)
print("‚úÖ MODEL TRAINING COMPLETE!")
print("="*60)

# Read the checkpoint location from the config here so this cell does not
# depend on the `checkpoint_dir` variable bound in the checkpoint-listing cell.
checkpoint_dir = config['logging']['checkpoint_dir']

# NOTE(review): this text is printed even when the training cell was
# interrupted or failed - consider making it conditional on the outcome.
completion_summary = f"""
üéâ Training successfully completed!

üìÅ Generated Artifacts:
  ‚úì Model checkpoints in {checkpoint_dir}
  ‚úì Training curves and visualizations in ../results/
  ‚úì Training history JSON
  ‚úì Comprehensive training report

üéØ Next Steps:
  1. Run evaluation notebook (04_evaluation_visualization.ipynb)
  2. Generate Grad-CAM explainability visualizations
  3. Test model on new CT scans using inference.py
  4. Fine-tune hyperparameters if needed

üìä Key Metrics:
  ‚Ä¢ Best validation loss: {trainer.best_val_loss:.4f}
  ‚Ä¢ Training epochs: {len(history['train_loss'])}
  ‚Ä¢ Model parameters: {sum(p.numel() for p in model.parameters()):,}

üí° Tips:
  - Check training curves for signs of overfitting
  - Compare train/val losses for generalization
  - Review loss components to identify bottlenecks
  - Use the best model checkpoint for inference
"""

print(completion_summary)

# Save completion summary. The text contains non-ASCII symbols, so the file is
# written as UTF-8 explicitly; the platform default (e.g. cp1252 on Windows)
# can raise UnicodeEncodeError on them.
os.makedirs('../results', exist_ok=True)
with open('../results/training_completion_summary.txt', 'w', encoding='utf-8') as f:
    f.write(completion_summary)

print("\nüìÑ All summaries and reports saved!")
print("üöÄ Ready for evaluation and deployment!\n")