## 1. Setup & Configuration

In [None]:
import os
import sys
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

# Put the project checkout at the front of the module search path so the
# project's own top-level modules are found first by the imports below.
project_root = Path("/storage2/CV_Irradiance/VMamba/BRTM")
sys.path[:0] = [str(project_root)]

# Echo both locations so a wrong checkout or kernel start directory is obvious.
print(f"Project Root: {project_root}")
print(f"Working Directory: {os.getcwd()}")

In [None]:
# Import all necessary modules
import torch
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime

# Custom modules
from config import Config
from models import SegMamba
from data import create_dataloaders
from utils import ExperimentManager, visualize_batch
from train import SegMambaTrainer

# Report the runtime environment so version or hardware problems are visible
# before any expensive work starts.
print("✓ All imports successful")
print(f"PyTorch Version: {torch.__version__}")
print(f"CUDA Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    # Only device index 0 is described here.
    print(f"CUDA Device: {torch.cuda.get_device_name(0)}")
    # total_memory is in bytes; divide by 1e9 to show decimal gigabytes.
    print(f"CUDA Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

### 🎯 Configuration

**IMPORTANT**: Change `RUN_NAME` for each training run to avoid overwriting results!

In [None]:
# =============================================================================
# CONFIGURATION - MODIFY THIS SECTION FOR EACH RUN
# =============================================================================
# Every assignment below overrides an attribute on the project's Config class
# for this notebook session; the cells that follow read their settings from it.

# Experiment name (CHANGE THIS FOR EACH NEW RUN!)
# RUN_NAME is handed to ExperimentManager as the run identifier further down.
Config.RUN_NAME = "SegMamba_Run01"
Config.DESCRIPTION = "Pure Mamba SegMamba baseline"

# Model architecture
# PATCH_SIZE supplies the three spatial dimensions of the input patch; it is
# also used to build the dummy tensor for the forward-pass check.
Config.PATCH_SIZE = (128, 128, 64)  # Adjust based on GPU VRAM
Config.BASE_CHANNELS = 32

# Training hyperparameters
# NOTE(review): ACCUMULATION_STEPS presumably means gradient accumulation, i.e.
# an effective batch of BATCH_SIZE * ACCUMULATION_STEPS - confirm in the trainer.
Config.BATCH_SIZE = 2
Config.ACCUMULATION_STEPS = 2
Config.NUM_EPOCHS = 300
Config.INITIAL_LR = 1e-4

# Data paths (UPDATE TO YOUR DATASET LOCATION)
# The train/val directories are derived from DATA_ROOT, so changing DATA_ROOT
# alone is enough when the dataset keeps the "train"/"val" layout.
Config.DATA_ROOT = Path("/storage2/CV_Irradiance/datasets/CVMD/BraTS")
Config.TRAIN_DATA_PATH = Config.DATA_ROOT / "train"
Config.VAL_DATA_PATH = Config.DATA_ROOT / "val"

# Results path
# Placed under the project root defined in the setup cell.
Config.RESULTS_BASE_PATH = project_root / "results"

# =============================================================================

# Print configuration
Config.print_config()

## 2. Data Verification

Verify dataset structure and accessibility before training.

In [None]:
# Check that every configured dataset directory is reachable before training.
print("=" * 70)
print("Data Path Verification")
print("=" * 70)

# Same two-line report for each configured path.
for path_label, data_path in (
    ("DATA_ROOT", Config.DATA_ROOT),
    ("TRAIN_DATA_PATH", Config.TRAIN_DATA_PATH),
    ("VAL_DATA_PATH", Config.VAL_DATA_PATH),
):
    print(f"\n{path_label}: {data_path}")
    print(f"Exists: {data_path.exists()}")

# List sample patients
if Config.TRAIN_DATA_PATH.exists():
    # Every immediate sub-directory is counted as one patient. iterdir() yields
    # entries in arbitrary order, so sort to make the printed sample repeatable.
    train_patients = sorted(d for d in Config.TRAIN_DATA_PATH.iterdir() if d.is_dir())
    print(f"\nNumber of training patients: {len(train_patients)}")
    if train_patients:
        print(f"Sample patients: {[p.name for p in train_patients[:5]]}")
else:
    print("\n⚠️ WARNING: Training data path does not exist!")
    print("Please update Config.DATA_ROOT in the configuration cell above.")

## 3. Initialize Experiment Manager

This ensures no overwriting between training runs.

In [None]:
# Create the experiment manager for this run and persist the configuration.
try:
    exp_manager = ExperimentManager(
        run_name=Config.RUN_NAME,
        base_path=str(Config.RESULTS_BASE_PATH),
        overwrite=False  # Set True to overwrite existing run
    )

    # Save configuration
    exp_manager.save_config(Config.get_config_dict())

    print(f"✓ Experiment initialized: {exp_manager.run_path}")

except ValueError as e:
    # NOTE(review): presumably raised when a run with this name already exists
    # and overwrite is False - confirm against ExperimentManager.
    print(f"❌ Error: {e}")
    print("\nOptions:")
    print("1. Change Config.RUN_NAME to a new name")
    print("2. Set overwrite=True (WARNING: deletes existing results)")
    # Re-raise so "Run All" stops here. Otherwise later cells would either fail
    # with a NameError or silently reuse an exp_manager left over from an
    # earlier execution of this cell.
    raise

## 4. Create Data Loaders

Load BraTS dataset with MONAI transforms.

In [None]:
# Build the training and validation loaders from the configured paths.
print("=" * 70)
print("Creating Data Loaders")
print("=" * 70)

try:
    train_loader, val_loader = create_dataloaders(
        train_data_path=str(Config.TRAIN_DATA_PATH),
        val_data_path=str(Config.VAL_DATA_PATH),
        batch_size=Config.BATCH_SIZE,
        patch_size=Config.PATCH_SIZE,
        num_workers=Config.NUM_WORKERS,
        pin_memory=Config.PIN_MEMORY
    )

    print("\n✓ Data loaders created successfully")
    print(f"Training batches: {len(train_loader)}")
    print(f"Validation batches: {len(val_loader)}")

except Exception as e:
    print(f"\n❌ Error creating data loaders: {e}")
    print("\nPossible solutions:")
    print("1. Verify dataset structure matches BraTS format")
    print("2. Check file permissions")
    print("3. Install MONAI: pip install monai")
    # Re-raise after the hints: every later cell needs train_loader/val_loader,
    # so continuing would only fail later with a less helpful error, or run on
    # loaders left over from an earlier execution.
    raise

## 5. Visualize Sample Batch

Sanity check: visualize a sample batch before training.

In [None]:
# Draw a single batch from the training loader for a pre-training sanity check.
sample_batch = next(iter(train_loader))
batch_images = sample_batch['image']
batch_labels = sample_batch['label']

# Summarise shapes, dtypes, intensity range and the label values present.
print("Sample Batch Information:")
print(f"Image shape: {batch_images.shape}")
print(f"Label shape: {batch_labels.shape}")
print(f"Image dtype: {batch_images.dtype}")
print(f"Label dtype: {batch_labels.dtype}")
print(f"Image range: [{batch_images.min():.3f}, {batch_images.max():.3f}]")
print(f"Label unique values: {torch.unique(batch_labels).tolist()}")

# Render the batch to a PNG inside this run's plot directory.
save_path = exp_manager.get_plot_path("notebook_sample_batch.png")
visualize_batch(
    images=batch_images,
    labels=batch_labels,
    save_path=save_path,
    title="Sample Training Batch",
)

# Show the saved figure inline.
from IPython.display import Image, display
display(Image(filename=str(save_path)))

## 6. Initialize Model

Create SegMamba architecture and verify forward pass.

In [None]:
# Build the network from the configured hyperparameters and move it to the
# configured device.
print("=" * 70)
print("Initializing SegMamba Model")
print("=" * 70)

model = SegMamba(
    in_channels=Config.IN_CHANNELS,
    num_classes=Config.NUM_CLASSES,
    base_channels=Config.BASE_CHANNELS,
    encoder_depths=Config.ENCODER_DEPTHS,
    use_checkpoint=Config.USE_CHECKPOINT
).to(Config.DEVICE)

print("\n✓ Model initialized")
print(f"Total parameters: {model.count_parameters():,}")
print(f"Model on device: {next(model.parameters()).device}")

# Test forward pass
print("\nTesting forward pass...")
# no_grad: this is only a shape/memory smoke test, so no autograd graph is kept.
with torch.no_grad():
    # One random sample with the configured channel count and patch size.
    test_input = torch.randn(1, Config.IN_CHANNELS, *Config.PATCH_SIZE).to(Config.DEVICE)
    test_output = model(test_input)
    print("✓ Forward pass successful")
    print(f"  Input shape: {test_input.shape}")
    print(f"  Output shape: {test_output.shape}")

    # Memory usage (device index 0 only; bytes shown as decimal gigabytes)
    if torch.cuda.is_available():
        memory_allocated = torch.cuda.memory_allocated(0) / 1e9
        memory_reserved = torch.cuda.memory_reserved(0) / 1e9
        print(f"  GPU Memory Allocated: {memory_allocated:.2f} GB")
        print(f"  GPU Memory Reserved: {memory_reserved:.2f} GB")

## 7. Initialize Trainer

Create trainer with all components: optimizer, scheduler, loss, metrics.

In [None]:
# Create the trainer from the shared configuration and this run's experiment
# manager.
print("=" * 70)
print("Initializing Trainer")
print("=" * 70)

# NOTE(review): the `model` built in the previous cell is not passed in here,
# so the trainer presumably constructs its own - confirm in SegMambaTrainer.
trainer = SegMambaTrainer(
    config=Config,
    experiment_manager=exp_manager
)

# Echo the settings taken from Config.
print("\n✓ Trainer initialized")
print(f"Optimizer: {Config.OPTIMIZER}")
print(f"Learning rate: {Config.INITIAL_LR}")
print(f"Scheduler: {Config.LR_SCHEDULER}")
# The loss description is a fixed string; it is not read from Config.
print("Loss function: Dice + Cross Entropy")
print(f"AMP enabled: {Config.USE_AMP}")

## 8. Start Training 🚀

**This will take several hours to days depending on your GPU.**

Monitor progress via:
- Progress bars in this notebook
- Training curves: `results/{RUN_NAME}/plots/training_curves.png`
- Validation predictions: `results/{RUN_NAME}/plots/val_predictions_epoch_*.png`
- Metrics: `results/{RUN_NAME}/metrics/final_metrics.json`

In [None]:
# Record start time
training_start_time = datetime.now()
print(f"Training started at: {training_start_time.strftime('%Y-%m-%d %H:%M:%S')}")
print("=" * 70)

# Start training. Only the training call is guarded, so a problem while
# printing the summary is not misreported as a training failure.
try:
    trainer.train(train_loader, val_loader)

except KeyboardInterrupt:
    # Manual stop (e.g. the notebook's interrupt button).
    print("\n⚠️ Training interrupted by user")
    print(f"Results saved to: {exp_manager.run_path}")

except Exception as e:
    # Top-level boundary of the notebook: report with the full traceback and
    # let the result cells below still run on whatever was written to disk.
    print(f"\n❌ Training failed with error: {e}")
    import traceback
    traceback.print_exc()

else:
    # Record end time
    training_end_time = datetime.now()
    training_duration = training_end_time - training_start_time

    print("\n" + "=" * 70)
    print("✓ Training Completed Successfully")
    print("=" * 70)
    print(f"Started: {training_start_time.strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"Ended: {training_end_time.strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"Duration: {training_duration}")
    print(f"\nBest {Config.METRIC_NAME}: {trainer.best_metric:.4f} at epoch {trainer.best_epoch}")
    print(f"\nResults saved to: {exp_manager.run_path}")

## 9. Visualize Results

Display training curves and final metrics.

In [None]:
# Show the training-curve image for this run inline, if it has been written.
from IPython.display import Image, display

curves_path = exp_manager.get_plot_path("training_curves.png")
if not curves_path.exists():
    print("Training curves not available yet.")
else:
    print("Training Curves:")
    display(Image(filename=str(curves_path)))

In [None]:
# Summarise the run's final metrics file and plot the validation history.
import json

metrics_path = exp_manager.get_metrics_path("final_metrics.json")
if not metrics_path.exists():
    print("Final metrics not available yet.")
else:
    with open(metrics_path, 'r') as f:
        final_metrics = json.load(f)

    best_value = final_metrics['best_metric']
    val_history = final_metrics['val_metrics']

    banner = "=" * 70
    print(banner)
    print("Final Metrics")
    print(banner)
    print(f"Best Metric: {best_value:.4f}")
    print(f"Best Epoch: {final_metrics['best_epoch']}")
    print(f"Total Epochs: {final_metrics['total_epochs']}")
    print(f"Training Time: {final_metrics['training_time_hours']:.2f} hours")

    # Validation metric per recorded entry, numbered from 1, with a dashed
    # reference line at the best value.
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots(figsize=(10, 6))
    epochs = range(1, len(val_history) + 1)
    ax.plot(epochs, val_history, 'o-', linewidth=2, markersize=6)
    ax.set_xlabel('Epoch', fontsize=12)
    ax.set_ylabel(Config.METRIC_NAME, fontsize=12)
    ax.set_title('Validation Metric Progress', fontsize=14, fontweight='bold')
    ax.grid(True, alpha=0.3)
    ax.axhline(
        y=best_value,
        color='r',
        linestyle='--',
        alpha=0.5,
        label=f'Best: {best_value:.4f}',
    )
    ax.legend()
    plt.tight_layout()
    plt.show()

## 10. Load Best Model

Load the best checkpoint for inference or further evaluation.

In [None]:
# Load best model checkpoint
best_checkpoint_path = exp_manager.get_checkpoint_path("best_metric_model.pth")

if best_checkpoint_path.exists():
    # map_location remaps the stored tensors onto the configured device, so a
    # checkpoint written on one GPU still loads on another GPU or on CPU.
    # NOTE(review): the checkpoint also carries non-tensor entries (metrics,
    # timestamp); if a newer PyTorch rejects them under its weights-only
    # default, pass weights_only=False - only for checkpoints you created.
    checkpoint = torch.load(best_checkpoint_path, map_location=Config.DEVICE)

    print("=" * 70)
    print("Best Model Checkpoint")
    print("=" * 70)
    print(f"Epoch: {checkpoint['epoch']}")
    print(f"Metrics: {checkpoint['metrics']}")
    print(f"Timestamp: {checkpoint['timestamp']}")

    # Load into model
    model.load_state_dict(checkpoint['model_state_dict'])
    print("\n✓ Model weights loaded successfully")

    # Model is now ready for inference
    model.eval()
    print("✓ Model set to evaluation mode")

else:
    print("Best model checkpoint not found.")

## 11. Next Steps

### For Inference:
```python
# Load test data
test_image = ...  # Load your test NIfTI file

# Preprocess (same as training)
test_image = preprocess(test_image)

# Inference
with torch.no_grad():
    prediction = model(test_image.unsqueeze(0).to(Config.DEVICE))
    prediction = torch.argmax(prediction, dim=1)

# Save prediction
save_nifti(prediction, "output_segmentation.nii.gz")
```

### For Ensemble:
1. Train multiple models with different seeds/configurations
2. Average predictions for better performance
3. See documentation for details

### For Competition Submission:
1. Test on validation set
2. Compute final metrics
3. Create submission file
4. Document approach in `docs/`

---

## 📚 Documentation

For detailed explanations of architecture, mathematics, and methodology, see:
- `docs/SegMamba_Documentation.md`

---

**Training Complete! 🎉**