# 🧠 SymPanICH-Net v2 — Google Colab Training

**Text-Guided Symmetry-Aware Panoptic Segmentation for ICH Detection**

This notebook trains SymPanICH-Net v2 on Google Colab with GPU acceleration.

## Prerequisites
- Google Colab Pro (for A100/V100 GPUs)
- Dataset uploaded to Google Drive

## Steps
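1. Check the GPU and PyTorch setup
2. Mount Google Drive and configure the data path
3. Clone the repository and install dependencies
4. Run a quick smoke test
5. Configure and train the model
6. Back up checkpoints to Google Drive (for download)
7. Evaluate and visualize results
8. Generate an AI clinical report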

## 1️⃣ GPU Check & Setup

In [None]:
# Check GPU availability
!nvidia-smi
import torch
print(f"\n‚úÖ PyTorch: {torch.__version__}")
print(f"‚úÖ CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"‚úÖ GPU: {torch.cuda.get_device_name(0)}")
    print(f"‚úÖ VRAM: {torch.cuda.get_device_properties(0).total_mem / 1e9:.1f} GB")

## 2️⃣ Mount Google Drive

In [None]:
import os

from google.colab import drive

# Mount Google Drive so that DATA_DIR below is reachable from this runtime.
drive.mount('/content/drive')

# Set your data path (MODIFY THIS to your dataset location in Drive)
DATA_DIR = '/content/drive/MyDrive/FYP26/data'

# Use isdir rather than exists: a path that exists but is a regular file
# would pass an existence check and then make os.listdir() raise.
if os.path.isdir(DATA_DIR):
    print(f'‚úÖ Data directory found: {DATA_DIR}')
    # Show at most ten entries as a quick sanity check of the contents.
    print(f'   Contents: {os.listdir(DATA_DIR)[:10]}')
else:
    print(f'‚ùå Data directory not found: {DATA_DIR}')
    print('   Please update DATA_DIR to point to your dataset location in Google Drive')

## 3️⃣ Clone Repository & Install Dependencies

In [None]:
# Clone the repository
!git clone https://github.com/Babu2107/FYP26.git /content/FYP26
%cd /content/FYP26

# Install dependencies
!pip install -q -r requirements.txt

print('\n‚úÖ Repository cloned and dependencies installed!')

## 4️⃣ Quick Smoke Test

Test that the model builds and can do a forward pass.

In [None]:
import sys
sys.path.insert(0, '/content/FYP26')

import torch
from src.models.sympanich_net import SymPanICHNetV2

# Instantiate the network in its 3-channel configuration (no context slices),
# which is enough to verify that the model builds and runs.
model = SymPanICHNetV2(
    backbone_name='swinv2_tiny_window8_256',
    pretrained=True,
    use_context=False,
    num_queries=50,
    num_classes=7,
    num_decoder_layers=9,
)

# Parameter statistics, reported in millions.
param_sizes = [(p.numel(), p.requires_grad) for p in model.parameters()]
total_params = sum(size for size, _ in param_sizes)
trainable_params = sum(size for size, needs_grad in param_sizes if needs_grad)
print(f'Total parameters: {total_params / 1e6:.1f}M')
print(f'Trainable parameters: {trainable_params / 1e6:.1f}M')

# Run one forward pass in eval mode on the GPU when there is one, else on CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
model = model.eval().to(device)

with torch.no_grad():
    # Random image plus its mirror along the last (width) axis, matching the
    # two positional inputs the model is called with.
    dummy_img = torch.randn(1, 3, 256, 256).to(device)
    dummy_flip = dummy_img.flip(dims=[3])
    outputs = model(dummy_img, dummy_flip)

print('\n‚úÖ Forward pass successful!')
print(f'   pred_logits: {outputs["pred_logits"].shape}')
print(f'   pred_masks:  {outputs["pred_masks"].shape}')
print(f'   hv_maps:     {outputs["hv_maps"].shape}')
print(f'   text_emb:    {outputs["text_embeddings"].shape}')

# Drop the throwaway model and release cached GPU memory before training.
del model
torch.cuda.empty_cache()

## 5️⃣ Train the Model

Configure training parameters and launch training.
Checkpoints are saved to `/content/FYP26/checkpoints/` during training;
run the backup cell in the next section to copy them to Google Drive.

In [None]:
# ------------------------------------------------------------------
#  Training configuration - edit the values below as needed.
#  Every later cell reads its settings from this single dictionary.
# ------------------------------------------------------------------

CONFIG = {
    # --- data ---
    'data_dir': DATA_DIR,
    'image_size': 256,
    'context_slices': 2,  # 0 selects the standard 3-channel input

    # --- training loop ---
    'max_epochs': 100,
    'batch_size': 4,
    'gradient_accumulation': 4,  # batch_size x accumulation = effective batch of 16
    'num_workers': 2,
    'precision': '16-mixed',     # FP16 mixed precision for faster training

    # --- model ---
    'backbone': 'swinv2_tiny_window8_256',
    'num_queries': 50,
    'num_classes': 7,
    'num_decoder_layers': 9,

    # --- optimizer ---
    'lr': 1e-4,
    'weight_decay': 0.05,

    # --- checkpoint locations (local disk and Drive backup target) ---
    'checkpoint_dir': '/content/FYP26/checkpoints',
    'drive_backup_dir': '/content/drive/MyDrive/FYP26/checkpoints',
}

# Echo the effective settings, one "key: value" pair per line.
print('üìã Training Configuration:')
print('\n'.join(f'   {name}: {value}' for name, value in CONFIG.items()))

In [None]:
import sys
sys.path.insert(0, '/content/FYP26')

import torch
# Lightning is published under two package names; accept whichever is installed.
try:
    import pytorch_lightning as pl
    from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
except ImportError:
    import lightning as pl
    from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor

from src.training.trainer import SymPanICHNetModule
from src.data.datamodule import ICHDataModule

# DataModule: dataset location and loader settings come from CONFIG.
datamodule = ICHDataModule(
    data_dir=CONFIG['data_dir'],
    image_size=CONFIG['image_size'],
    batch_size=CONFIG['batch_size'],
    num_workers=CONFIG['num_workers'],
    context_slices=CONFIG['context_slices'],
)

# Model: context input is enabled only when context slices are requested.
model = SymPanICHNetModule(
    backbone_name=CONFIG['backbone'],
    pretrained=True,
    num_queries=CONFIG['num_queries'],
    num_classes=CONFIG['num_classes'],
    num_decoder_layers=CONFIG['num_decoder_layers'],
    use_context=CONFIG['context_slices'] > 0,
    base_lr=CONFIG['lr'],
    weight_decay=CONFIG['weight_decay'],
    max_epochs=CONFIG['max_epochs'],
)

# Checkpointing keyed on validation Dice (higher is better).
# The monitored metric is named 'val/dice'. With the default
# auto_insert_metric_name=True the metric name is written into the file name,
# so its '/' becomes a path separator and checkpoints end up in nested
# folders. Disable the auto-insert and spell the labels out instead.
checkpoint_callback = ModelCheckpoint(
    dirpath=CONFIG['checkpoint_dir'],
    filename='sympanich-epoch={epoch:02d}-val_dice={val/dice:.4f}',
    auto_insert_metric_name=False,
    monitor='val/dice', mode='max', save_top_k=3, save_last=True,
)

# Callbacks
callbacks = [
    checkpoint_callback,
    # Stop once validation Dice has not improved for 15 validation checks.
    EarlyStopping(monitor='val/dice', patience=15, mode='max'),
    LearningRateMonitor(logging_interval='step'),
]

# Trainer: single GPU, mixed precision, gradient accumulation and clipping.
trainer = pl.Trainer(
    max_epochs=CONFIG['max_epochs'],
    accelerator='gpu',
    devices=1,
    precision=CONFIG['precision'],
    callbacks=callbacks,
    accumulate_grad_batches=CONFIG['gradient_accumulation'],
    gradient_clip_val=1.0,
    log_every_n_steps=10,
)

print('üöÄ Starting training...')
trainer.fit(model, datamodule=datamodule)
print('\n‚úÖ Training complete!')
print(f'Best model: {checkpoint_callback.best_model_path}')

## 6️⃣ Backup Checkpoints to Google Drive

In [None]:
import os
import shutil

# Copy everything in the local checkpoint directory to the Drive backup
# directory, creating the backup directory first if it does not exist yet.
os.makedirs(CONFIG['drive_backup_dir'], exist_ok=True)

checkpoint_dir = CONFIG['checkpoint_dir']
if os.path.isdir(checkpoint_dir):
    for entry in os.listdir(checkpoint_dir):
        src = os.path.join(checkpoint_dir, entry)
        dst = os.path.join(CONFIG['drive_backup_dir'], entry)
        if os.path.isdir(src):
            # copy2 only handles regular files. Sub-directories (e.g. created
            # when a checkpoint file name contains '/') are copied
            # recursively; dirs_exist_ok lets a repeated backup overwrite
            # the previous one instead of failing.
            shutil.copytree(src, dst, dirs_exist_ok=True)
        else:
            shutil.copy2(src, dst)
        print(f'  Backed up: {entry}')
    print(f'\n‚úÖ Checkpoints backed up to: {CONFIG["drive_backup_dir"]}')
else:
    print('‚ùå No checkpoints found')

## 7️⃣ Evaluate & Visualize Results

In [None]:
# Run the test loop, asking the trainer to load its 'best' checkpoint first.
trainer.test(
    model,
    datamodule=datamodule,
    ckpt_path='best',
)

In [None]:
# Plot one test-set prediction next to its ground-truth mask.
import numpy as np
from src.utils.panoptic_fusion import panoptic_fusion
from src.utils.visualization import plot_prediction, overlay_mask

# Inference mode on the GPU.
model = model.to('cuda')
model.eval()

# Pull a single batch from the test dataloader.
datamodule.setup('test')
test_loader = datamodule.test_dataloader()
batch = next(iter(test_loader))

with torch.no_grad():
    images = batch['image'].cuda()
    images_flip = batch['image_flipped'].cuda()
    outputs = model(images, images_flip)

# First sample of the batch: keep its first three channels and move the
# leading axis to the end for plotting.
img = np.transpose(images[0].cpu().numpy()[:3], (1, 2, 0))
gt = batch['mask'][0].numpy()

# Fuse the first sample's query logits and masks into one prediction.
fusion = panoptic_fusion(outputs['pred_logits'][0], outputs['pred_masks'][0])
pred = fusion['semantic_map']

plot_prediction(
    img,
    gt_mask=gt,
    pred_mask=pred,
    title='SymPanICH-Net v2 ‚Äî Sample Prediction',
)

## 8️⃣ Generate AI Clinical Report

In [None]:
from src.models.report_generator import ReportGenerator

reporter = ReportGenerator(image_size=256)

# Build a report from the segments that panoptic fusion produced for the
# sample visualized in the previous cell.
segments = fusion['segments']
if not segments:
    print('No hemorrhage detected in this slice.')
else:
    # Collect the per-segment fields into arrays, one entry per segment.
    classes = np.array([seg['class_id'] for seg in segments])
    masks = np.stack([seg['mask'] for seg in segments])
    scores = np.array([seg['score'] for seg in segments])

    report = reporter.generate(
        pred_classes=classes,
        pred_masks=masks,
        pred_scores=scores,
        patient_id='Sample_001',
    )
    print(report)

---
## 📝 Notes

- **Colab Pro** recommended for A100 GPU access (40GB VRAM)
- **Training time**: ~4-6 hours for 100 epochs on A100
- **Checkpoints** are copied to Drive only when the backup cell (step 6) is run; they are not synced automatically during training
- **For inference on your laptop**: download the best checkpoint and run:
  ```bash
  python scripts/predict.py --checkpoint best_model.ckpt --input scan.nii
  ```