# Training SE3GCNN on NIfTI Medical Data
This notebook demonstrates how to train the SE3GCNN model on NIfTI medical imaging data. The model classifies each 3D MRI volume into one of five brain-tumor categories and is designed to be equivariant to rotations and translations.

## 1. Import Required Libraries
First, let's import all the necessary libraries and modules.

In [None]:
import os
import numpy as np
import nibabel as nib
from tqdm import tqdm
import torch
from torch.utils.data import Dataset, DataLoader
import wandb
from focal_loss import focal_loss

# Import local modules
import model_util
import mesh_util
from datagen import get_spatial_blocks
from train_drivedata import run_steerable_gcnn, run_training

## 2. Configuration
Define the training parameters and model configuration.

In [None]:
CLASSES = ['Schwannoma', 'Pituitary', 'Metastases', 'Meningioma', 'AVM']

class Config:
    """Hyper-parameters and paths for the brain-tumour training run.

    Every attribute has a default. Keyword arguments override selected
    defaults without mutating a shared instance, e.g.
    ``Config(lr=1e-3, b_size=8)``.

    Raises:
        TypeError: If an override names an attribute that has no default
            (catches typos such as ``Config(b_sise=8)``).
    """

    def __init__(self, **overrides):
        # Data parameters
        self.path = "data/"  # Root directory with one sub-folder of NIfTI files per class
        self.grid_size = 7
        self.num_classes = len(CLASSES)
        self.num_shells = 1  # Single channel MRI data

        # Model parameters
        self.interpolate = True
        self.num_rays = 5
        self.samples_per_ray = 2
        self.ray_len = None  # Radius of the spherical kernel, None uses default arc length
        self.watson_param = 10
        self.model_capacity = "small"  # or "big"

        # Training parameters
        self.b_size = 16
        self.iter = 200         # Number of training epochs
        self.lr = 0.0001
        # Focal-loss alpha. NOTE: the training set-up in this notebook passes
        # inverse-frequency class weights as alpha instead, so this value is
        # currently not used there.
        self.alpha = 0.25
        self.gamma = 2.0        # Focal-loss focusing parameter
        self.cuda = 0           # GPU device index
        self.train_split = 0.8  # 80% training, 20% validation

        # Other parameters
        self.bias = True
        self.lin_bias = True
        self.spatial_bias = True
        self.lin_bn = True
        self.pooling = 'max'
        self.exp_name = 'brain_tumor_classification'
        self.run_path = 'results'
        self.data_aug = True  # Random rotation/flip augmentation of training volumes
        self.spatial_kernel_size = (7, 7, 7)  # 3-D kernel size as a tuple of ints

        # Apply caller overrides last; only known attributes may be overridden.
        for name, value in overrides.items():
            if not hasattr(self, name):
                raise TypeError(f"Config() got an unexpected keyword argument {name!r}")
            setattr(self, name, value)

    def __repr__(self):
        fields = ", ".join(f"{key}={value!r}" for key, value in vars(self).items())
        return f"{type(self).__name__}({fields})"

args = Config()

## 3. Data Loading and Preprocessing
Create a custom `Dataset` class that loads each NIfTI volume, normalizes its intensities, and resizes (and optionally augments) it to a fixed shape.

In [None]:
from torch.nn.functional import interpolate
from torchvision import transforms
from sklearn.model_selection import train_test_split

class BrainTumorDataset(Dataset):
    """Brain-tumour MRI classification dataset backed by per-class NIfTI folders.

    Expected layout: ``base_dir/<class_name>/*.nii`` or ``*.nii.gz``. The
    position of ``class_name`` in ``classes`` becomes the integer label. All
    files are split into a stratified train/test partition and this instance
    keeps only the part selected by ``train``.

    Args:
        base_dir: Root directory containing one sub-directory per class.
        classes: Ordered class names; the position determines the label.
        transform: Optional callable applied to the preprocessed tensor.
        train: Keep the training split when True, the held-out split otherwise.
        train_ratio: Fraction of files assigned to the training split.
        random_state: Seed forwarded to ``train_test_split``.

    Raises:
        ValueError: If no NIfTI files are found for the given classes.
    """

    _NIFTI_SUFFIXES = ('.nii', '.nii.gz')

    def __init__(self, base_dir, classes=CLASSES, transform=None, train=True, train_ratio=0.8, random_state=42):
        self.transform = transform
        self.classes = classes
        self.train = train

        # Collect all file paths and labels. Listings are sorted because
        # os.listdir order is filesystem-dependent. Without sorting, the same
        # random_state could give different splits on different machines, and
        # the separately constructed train/test instances are only disjoint
        # if both see the files in the same order.
        self.data = []
        for class_idx, class_name in enumerate(classes):
            class_dir = os.path.join(base_dir, class_name)
            for file in sorted(os.listdir(class_dir)):
                if file.endswith(self._NIFTI_SUFFIXES):
                    self.data.append({
                        'path': os.path.join(class_dir, file),
                        'label': class_idx,
                        'class': class_name
                    })

        if not self.data:
            raise ValueError(
                f"No NIfTI files (.nii/.nii.gz) found under {base_dir!r} for classes {list(classes)}"
            )

        # Split into train/test, stratified by label.
        train_data, test_data = train_test_split(
            self.data,
            train_size=train_ratio,
            random_state=random_state,
            stratify=[d['label'] for d in self.data]
        )

        self.data = train_data if train else test_data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Load, preprocess and return ``(image_tensor, label_tensor)`` for ``idx``."""
        item = self.data[idx]

        # Load NIfTI file
        nifti_img = nib.load(item['path'])
        image_data = nifti_img.get_fdata()

        # Preprocess data
        processed_data = self.preprocess_data(image_data)

        # Convert label to tensor
        label = torch.tensor(item['label'], dtype=torch.long)

        return processed_data, label

    def preprocess_data(self, image_data):
        """Normalise a raw NIfTI array and convert it to a float tensor.

        Args:
            image_data: NumPy array as returned by ``get_fdata()``. Assumed to
                be spatially 3-D, possibly with trailing singleton axes.

        Returns:
            A float32 tensor; 3-D volumes gain a leading channel axis. The
            configured ``transform`` is applied last when set.
        """
        # Some NIfTI files store 3-D volumes as (X, Y, Z, 1). Drop trailing
        # singleton axes so they receive a channel dim like true 3-D volumes,
        # instead of having X misread as the channel count.
        while image_data.ndim > 3 and image_data.shape[-1] == 1:
            image_data = image_data[..., 0]

        # Normalize to [0, 1]. (The z-scoring below is invariant to this
        # affine rescaling apart from the epsilons, so it mostly guards scale.)
        image_data = (image_data - image_data.min()) / (image_data.max() - image_data.min() + 1e-8)

        # Standardize to zero mean / unit variance.
        image_data = (image_data - image_data.mean()) / (image_data.std() + 1e-8)

        # Convert to torch tensor
        image_tensor = torch.from_numpy(image_data).float()

        # Add channel dimension for single-channel 3-D volumes.
        if len(image_tensor.shape) == 3:
            image_tensor = image_tensor.unsqueeze(0)

        # Resize / augment via the configured transform, if any.
        if self.transform:
            image_tensor = self.transform(image_tensor)

        return image_tensor

class Transform3D:
    """Resize a channel-first 3-D volume and optionally augment it.

    Input is expected as ``(C, D, H, W)`` (the channel axis is added by the
    dataset before this transform runs). The output has shape
    ``(C, *output_size)``.

    Args:
        output_size: Target spatial size passed to trilinear interpolation.
        data_aug: When True, with probability 0.5 apply a random rotation
            in the (D, H) plane followed by random flips along D and H.
    """

    def __init__(self, output_size=(128, 128, 128), data_aug=False):
        self.output_size = output_size
        self.data_aug = data_aug

    def __call__(self, x):
        # Resize to standard size (interpolate needs a batch axis).
        x = interpolate(x.unsqueeze(0), size=self.output_size, mode='trilinear', align_corners=True).squeeze(0)

        if self.data_aug and torch.rand(1).item() > 0.5:
            # Random rotation (90 degree increments) in the (D, H) plane.
            k = torch.randint(4, (1,)).item()
            if x.shape[1] != x.shape[2]:
                # A quarter turn would swap D and H. That changes the sample
                # shape for non-cubic output sizes and breaks batch
                # collation, so only half turns (0 or 180 degrees) are used.
                k = (k // 2) * 2
            x = torch.rot90(x, k, dims=[1, 2])

            # Random flips
            if torch.rand(1).item() > 0.5:
                x = x.flip(1)
            if torch.rand(1).item() > 0.5:
                x = x.flip(2)

        return x

## 4. Initialize Model and Training
Set up the data loaders, model, loss function, optimizer, and learning-rate scheduler used by the training loop.

In [None]:
def initialize_training():
    """Build everything the training loop needs from the global ``args`` config.

    Starts a wandb run and builds the train/test datasets and loaders.
    Augmentation is applied to the training data only. Then it creates the
    model via ``run_steerable_gcnn``, a focal loss weighted by inverse class
    frequency, an AdamW optimizer, and a ``ReduceLROnPlateau`` scheduler in
    ``'max'`` mode.

    Returns:
        Tuple ``(model, train_loader, test_loader, criterion, optimizer,
        scheduler, device)``.
    """
    # Initialize wandb with config
    wandb.init(
        project=args.exp_name,
        config={
            "model_capacity": args.model_capacity,
            "learning_rate": args.lr,
            "batch_size": args.b_size,
            "epochs": args.iter,
            "data_augmentation": args.data_aug,
            "num_classes": args.num_classes,
            "focal_gamma": args.gamma,
            "train_split": args.train_split,
        }
    )

    use_cuda = torch.cuda.is_available()
    device = f"cuda:{args.cuda}" if use_cuda else "cpu"

    # Augment only the training data. Evaluating on randomly rotated/flipped
    # volumes would make validation metrics noisy and non-reproducible.
    train_transform = Transform3D(output_size=(128, 128, 128), data_aug=args.data_aug)
    eval_transform = Transform3D(output_size=(128, 128, 128), data_aug=False)

    # Create datasets (same base_dir / ratio / default seed => complementary splits)
    train_dataset = BrainTumorDataset(
        base_dir=args.path,
        transform=train_transform,
        train=True,
        train_ratio=args.train_split
    )

    test_dataset = BrainTumorDataset(
        base_dir=args.path,
        transform=eval_transform,
        train=False,
        train_ratio=args.train_split
    )

    # Create data loaders. Pinned memory only helps host->GPU copies, and
    # requesting it without CUDA just triggers a warning.
    loader_kwargs = dict(batch_size=args.b_size, num_workers=4, pin_memory=use_cuda)
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

    # Initialize model and move to device
    model, wandb_name = run_steerable_gcnn(args, device, True)
    model = model.to(device)

    # Class weights = normalised inverse class frequency in the training split.
    # minlength keeps one entry per class even if a class is missing from the
    # split. clamp(min=1) avoids 1/0 = inf, which would otherwise turn every
    # weight into NaN/0 after normalisation.
    train_labels = torch.tensor([item['label'] for item in train_dataset.data], dtype=torch.long)
    class_counts = torch.bincount(train_labels, minlength=args.num_classes).clamp(min=1)
    class_weights = 1. / class_counts.float()
    class_weights = class_weights / class_weights.sum()

    criterion = focal_loss(
        alpha=class_weights.to(device),
        gamma=args.gamma,
        device=device
    )

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.lr,
        weight_decay=0.01
    )

    # Learning rate scheduler, stepped on validation accuracy (hence 'max').
    # `verbose` is omitted: it is deprecated in recent PyTorch releases.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='max',
        factor=0.1,
        patience=10
    )

    return model, train_loader, test_loader, criterion, optimizer, scheduler, device

## 5. Training Loop
Run the training loop with validation.

In [None]:
from sklearn.metrics import confusion_matrix, classification_report
import matplotlib.pyplot as plt
import seaborn as sns

def _run_train_epoch(model, loader, criterion, optimizer, device, epoch):
    """Run one optimisation pass over ``loader``.

    Logs batch loss/accuracy to wandb every 10 batches.

    Returns:
        ``(mean_batch_loss, accuracy_percent)`` for the epoch.
    """
    model.train()
    total_loss = 0.0
    correct = 0
    seen = 0

    for batch_idx, (inputs, targets) in enumerate(tqdm(loader, desc=f'Epoch {epoch+1}')):
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        _, predicted = outputs.max(1)
        batch_correct = predicted.eq(targets).sum().item()
        total_loss += batch_loss
        correct += batch_correct
        seen += targets.size(0)

        if batch_idx % 10 == 0:
            wandb.log({
                'train_batch_loss': batch_loss,
                'train_batch_acc': 100. * batch_correct / targets.size(0)
            })

    # max(..., 1) avoids ZeroDivisionError when the loader is empty.
    return total_loss / max(len(loader), 1), 100. * correct / max(seen, 1)


def _evaluate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Returns:
        ``(mean_batch_loss, accuracy_percent, predictions, targets)``. The last
        two are flat lists of integer class indices.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    seen = 0
    all_preds = []
    all_targets = []

    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            total_loss += criterion(outputs, targets).item()

            _, predicted = outputs.max(1)
            correct += predicted.eq(targets).sum().item()
            seen += targets.size(0)

            all_preds.extend(predicted.cpu().tolist())
            all_targets.extend(targets.cpu().tolist())

    return total_loss / max(len(loader), 1), 100. * correct / max(seen, 1), all_preds, all_targets


def _log_best_artifacts(all_targets, all_preds, epoch):
    """Log a confusion-matrix heatmap and a per-class report to wandb."""
    # Pin the label set so the matrix/report always have one row per entry in
    # CLASSES. Without it, a class absent from this epoch's targets and
    # predictions shrinks the matrix (mismatched tick labels), and
    # classification_report raises ValueError because target_names is longer
    # than the observed label set.
    labels = list(range(len(CLASSES)))

    cm = confusion_matrix(all_targets, all_preds, labels=labels)
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=CLASSES,
                yticklabels=CLASSES,
                ax=ax)
    ax.set_title(f'Confusion Matrix - Epoch {epoch+1}')
    ax.set_ylabel('True Label')
    ax.set_xlabel('Predicted Label')
    fig.tight_layout()
    wandb.log({'confusion_matrix': wandb.Image(fig)})
    plt.close(fig)

    report = classification_report(all_targets, all_preds,
                                   labels=labels,
                                   target_names=CLASSES,
                                   output_dict=True,
                                   zero_division=0)
    wandb.log({'classification_report': report})


def train():
    """Train for ``args.iter`` epochs, validating after every epoch.

    Each epoch logs mean losses, accuracies and the current learning rate to
    wandb, then steps the plateau scheduler on validation accuracy. Whenever
    validation accuracy improves, it saves ``best_model.pth`` under
    ``args.run_path`` and logs a confusion matrix and a classification report.
    The wandb run is finished even if training raises.

    Returns:
        The best validation accuracy (in percent) observed.
    """
    model, train_loader, test_loader, criterion, optimizer, scheduler, device = initialize_training()

    # Create output directory
    os.makedirs(args.run_path, exist_ok=True)

    best_val_acc = 0.0
    try:
        for epoch in range(args.iter):
            train_loss, train_acc = _run_train_epoch(
                model, train_loader, criterion, optimizer, device, epoch
            )
            val_loss, val_acc, all_preds, all_targets = _evaluate(
                model, test_loader, criterion, device
            )

            # Log metrics
            wandb.log({
                'epoch': epoch,
                'train_loss': train_loss,
                'train_acc': train_acc,
                'val_loss': val_loss,
                'val_acc': val_acc,
                'lr': optimizer.param_groups[0]['lr'],
            })

            # Update learning rate (scheduler is in 'max' mode on val accuracy)
            scheduler.step(val_acc)

            # Save best model
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'val_acc': val_acc,
                }, os.path.join(args.run_path, 'best_model.pth'))

                _log_best_artifacts(all_targets, all_preds, epoch)
    finally:
        # Close the run so metrics are flushed and a re-run in the same
        # kernel starts a fresh wandb run.
        wandb.finish()

    return best_val_acc

## 6. Run Training
Execute the training process and monitor with wandb.

In [None]:
# Launch training only when this code runs as the top-level script or notebook
# kernel, not when it is imported as a module.
if __name__ == "__main__":
    train()

## 7. Visualization and Evaluation
After training, visualize results and evaluate model performance.