# Training SE3GCNN on NIfTI Medical Data (No WandB)
This notebook demonstrates how to train the SE3GCNN model on NIfTI medical imaging data without a wandb (Weights & Biases) dependency. The model is designed to perform brain tumor classification while being equivariant to rotations and translations.

## 1. Import Required Libraries
First, let's import all the necessary libraries and modules.

In [None]:
import os
import numpy as np
import nibabel as nib
from tqdm import tqdm
import torch
from torch.utils.data import Dataset, DataLoader
from focal_loss import focal_loss
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
from sklearn.model_selection import train_test_split
from torch.nn.functional import interpolate
import time
from datetime import datetime

# Import local modules
import model_util
import mesh_util
from datagen import get_spatial_blocks
from train_drivedata import run_steerable_gcnn

## 2. Configuration
Define the training parameters and model configuration.

In [None]:
# Tumor categories. BrainTumorDataset enumerates this list, so a sample's
# integer label is the position of its class name here.
CLASSES = [
    'Schwannoma',
    'Pituitary',
    'Metastases',
    'Meningioma',
    'AVM',
]


class Config:
    """Plain attribute bag holding every setting for one training run.

    The attributes are written to ``config.txt`` in definition order by
    ``initialize_training``, so the order below is part of the output format.

    Settings read by the notebook itself: ``path``, ``b_size``, ``iter``,
    ``lr``, ``gamma``, ``cuda``, ``train_split``, ``data_aug`` and
    ``run_path``. The remaining attributes are handed to
    ``run_steerable_gcnn`` as part of this object (presumably consumed
    there -- confirm against ``train_drivedata``).
    """

    def __init__(self):
        # Every instance gets its own timestamped output folder under results/.
        stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_root = 'results'

        # ---- data ----
        self.path = "data/"  # root folder holding one sub-folder per class
        self.grid_size = 7
        self.num_classes = len(CLASSES)
        self.num_shells = 1  # single-channel MRI input

        # ---- model ----
        self.interpolate = True
        self.num_rays = 5
        self.samples_per_ray = 2
        self.ray_len = None  # spherical kernel radius; None selects the default arc length
        self.watson_param = 10
        self.model_capacity = "small"  # alternative: "big"

        # ---- optimisation ----
        self.b_size = 16      # DataLoader batch size
        self.iter = 200       # number of epochs
        self.lr = 1e-4        # AdamW learning rate
        self.alpha = 0.25     # focal-loss alpha
        self.gamma = 2.0      # focal-loss gamma
        self.cuda = 0         # index of the CUDA device to use
        self.train_split = 0.8  # fraction of samples used for training

        # ---- miscellaneous ----
        self.bias = True
        self.lin_bias = True
        self.spatial_bias = True
        self.lin_bn = True
        self.pooling = 'max'
        self.run_path = os.path.join(output_root, f"run_{stamp}")
        self.data_aug = True  # random rotations/flips in Transform3D
        self.spatial_kernel_size = (7, 7, 7)  # three integers, one per spatial axis


# Shared configuration instance read by the functions below.
args = Config()

## 3. Data Loading and Preprocessing
Create a custom Dataset class to load and preprocess NIfTI data.

In [None]:
class BrainTumorDataset(Dataset):
    """NIfTI volumes stored as ``base_dir/<class_name>/*.nii`` or ``*.nii.gz``.

    Every instance scans the class folders, performs a stratified
    train/validation split and keeps only the requested side. The train and
    validation views are built by two separate instances, so the scan order
    must be identical each time; file names are therefore sorted before
    splitting.

    Args:
        base_dir: Root folder containing one sub-folder per class.
        classes: Class names; a sample's label is the index of its class in
            this sequence. ``None`` (default) uses the module-level ``CLASSES``.
        transform: Optional callable applied to the preprocessed tensor.
        train: ``True`` keeps the training side of the split, ``False`` the
            held-out side.
        train_ratio: Passed to ``train_test_split`` as ``train_size``.
        random_state: Seed for ``train_test_split``.

    Each item is a ``(tensor, label)`` pair where ``label`` is a scalar
    ``torch.long`` tensor.
    """

    def __init__(self, base_dir, classes=None, transform=None, train=True, train_ratio=0.8, random_state=42):
        if classes is None:
            classes = CLASSES
        self.transform = transform
        self.classes = classes
        self.train = train

        # Collect all file paths and labels
        self.data = []
        for class_idx, class_name in enumerate(classes):
            class_dir = os.path.join(base_dir, class_name)
            if not os.path.exists(class_dir):
                print(f"Warning: Directory {class_dir} does not exist!")
                continue

            # os.listdir returns entries in arbitrary order. Sorting makes the
            # seeded split reproducible, so the training and validation
            # instances can never disagree about which file is on which side.
            files = sorted(
                f for f in os.listdir(class_dir)
                if f.endswith(('.nii', '.nii.gz'))
            )
            print(f"Found {len(files)} files in class {class_name}")

            self.data.extend(
                {
                    'path': os.path.join(class_dir, file),
                    'label': class_idx,
                    'class': class_name,
                }
                for file in files
            )

        if not self.data:
            print("No data found! Please check the data directory.")
            return

        # Stratified split so both sides keep the class proportions.
        train_data, test_data = train_test_split(
            self.data,
            train_size=train_ratio,
            random_state=random_state,
            stratify=[d['label'] for d in self.data]
        )
        self.data = train_data if train else test_data

        # Print class distribution of the retained side
        class_dist = {}
        for d in self.data:
            class_dist[d['class']] = class_dist.get(d['class'], 0) + 1

        print(f"{'Training' if train else 'Validation'} set class distribution:")
        for cls, count in class_dist.items():
            print(f"  {cls}: {count} images")

    def __len__(self):
        """Return the number of samples on the retained side of the split."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load, preprocess and return the ``(tensor, label)`` pair at ``idx``."""
        item = self.data[idx]

        # Load NIfTI file
        nifti_img = nib.load(item['path'])
        image_data = nifti_img.get_fdata()

        processed_data = self.preprocess_data(image_data)
        label = torch.tensor(item['label'], dtype=torch.long)

        return processed_data, label

    def preprocess_data(self, image_data):
        """Normalise a NumPy volume and convert it to a float32 tensor.

        Steps: min-max scale to [0, 1], then standardise to zero mean and unit
        variance (each step is skipped for a constant volume), convert to
        ``torch.float32``, prepend a channel axis to 3-D input, and finally
        apply ``self.transform`` if one was given.
        """
        # Normalize to [0, 1]
        min_val = image_data.min()
        max_val = image_data.max()
        if max_val > min_val:  # Avoid division by zero
            image_data = (image_data - min_val) / (max_val - min_val + 1e-8)

        # Standardize
        mean_val = image_data.mean()
        std_val = image_data.std()
        if std_val > 0:  # Avoid division by zero
            image_data = (image_data - mean_val) / (std_val + 1e-8)

        image_tensor = torch.from_numpy(image_data).float()

        # Add channel dimension if needed
        if len(image_tensor.shape) == 3:
            image_tensor = image_tensor.unsqueeze(0)

        if self.transform:
            image_tensor = self.transform(image_tensor)

        return image_tensor

class Transform3D:
    """Resize a channel-first volume and optionally apply random rotations/flips.

    Args:
        output_size: Target spatial size handed to ``interpolate``.
        data_aug: When ``True``, each call augments the volume with
            probability 0.5.

    The callable takes a 4-D tensor (one leading channel axis followed by
    three spatial axes) and returns a tensor whose spatial axes always equal
    ``output_size``.
    """

    def __init__(self, output_size=(128, 128, 128), data_aug=False):
        self.output_size = output_size
        self.data_aug = data_aug

    def __call__(self, x):
        # Resize to standard size (interpolate needs a batch axis, added and removed here)
        x = interpolate(x.unsqueeze(0), size=self.output_size, mode='trilinear', align_corners=True).squeeze(0)

        if self.data_aug and torch.rand(1).item() > 0.5:
            # Random rotation in the plane of axes 1 and 2. A quarter turn
            # swaps those two axes, which would change the tensor shape when
            # they differ and break batching, so only 0/180 degrees are used
            # in that case.
            if x.shape[1] == x.shape[2]:
                k = torch.randint(4, (1,)).item()
            else:
                k = 2 * torch.randint(2, (1,)).item()
            x = torch.rot90(x, k, dims=[1, 2])

            # Random flips
            if torch.rand(1).item() > 0.5:
                x = x.flip(1)
            if torch.rand(1).item() > 0.5:
                x = x.flip(2)

        return x

## 4. Initialize Model and Training
Set up the model, optimizer, and training loop.

In [None]:
def initialize_training():
    """Build everything ``train`` needs from the global ``args``.

    Creates the run directory, writes ``config.txt``, builds the train and
    validation loaders, constructs the model via ``run_steerable_gcnn`` and
    sets up the focal loss, AdamW optimizer and plateau scheduler.

    Returns:
        Tuple ``(model, train_loader, test_loader, criterion, optimizer,
        scheduler, device)``.
    """
    # Create output directory
    os.makedirs(args.run_path, exist_ok=True)

    # Save configuration
    with open(os.path.join(args.run_path, 'config.txt'), 'w') as f:
        for key, value in vars(args).items():
            f.write(f"{key}: {value}\n")

    # Augmentation belongs on the training data only: validation metrics drive
    # the scheduler and best-model selection, so they must be computed on
    # deterministic inputs.
    train_transform = Transform3D(output_size=(128, 128, 128), data_aug=args.data_aug)
    eval_transform = Transform3D(output_size=(128, 128, 128), data_aug=False)

    # Create datasets
    train_dataset = BrainTumorDataset(
        base_dir=args.path,
        transform=train_transform,
        train=True,
        train_ratio=args.train_split
    )

    test_dataset = BrainTumorDataset(
        base_dir=args.path,
        transform=eval_transform,
        train=False,
        train_ratio=args.train_split
    )

    # Create data loaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.b_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True
    )

    test_loader = DataLoader(
        test_dataset,
        batch_size=args.b_size,
        shuffle=False,
        num_workers=4,
        pin_memory=True
    )

    # Initialize model and move to device
    device = f"cuda:{args.cuda}" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")
    model, model_name = run_steerable_gcnn(args, device, True)
    model = model.to(device)

    print(f"Model: {model_name}")
    print(f"Number of parameters: {sum(p.numel() for p in model.parameters())}")

    # Inverse-frequency class weights, normalised to sum to 1.
    if len(train_dataset.data) > 0:
        labels = torch.tensor([data['label'] for data in train_dataset.data])
        # minlength keeps one entry per class even when the highest-indexed
        # classes have no samples.
        class_counts = torch.bincount(labels, minlength=len(CLASSES)).float()
        # A class with no samples gets weight 0 instead of 1/0 = inf, which
        # would turn every weight into NaN after normalisation.
        class_weights = torch.where(
            class_counts > 0,
            1. / class_counts.clamp(min=1),
            torch.zeros_like(class_counts),
        )
        class_weights = class_weights / class_weights.sum()
        print(f"Class weights: {class_weights}")
    else:
        class_weights = torch.ones(len(CLASSES)) / len(CLASSES)

    criterion = focal_loss(
        alpha=class_weights.to(device),
        gamma=args.gamma,
        device=device
    )

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.lr,
        weight_decay=0.01
    )

    # Learning rate scheduler; stepped with validation accuracy, hence mode='max'.
    scheduler_kwargs = dict(mode='max', factor=0.1, patience=10)
    try:
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            verbose=True,
            **scheduler_kwargs
        )
    except TypeError:
        # `verbose` is deprecated in recent PyTorch releases; fall back to the
        # same scheduler without it if the installed version rejects it.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            **scheduler_kwargs
        )

    return model, train_loader, test_loader, criterion, optimizer, scheduler, device

## 5. Training Loop
Run the training loop with validation.

In [None]:
def _save_best_artifacts(epoch, val_acc, all_targets, all_preds):
    """Write the confusion-matrix plot and classification report for a new best epoch.

    ``labels`` is pinned to every index of ``CLASSES`` so the matrix is always
    ``len(CLASSES)`` square and the report rows match ``target_names``, even
    when a class appears in neither the targets nor the predictions.
    """
    label_ids = list(range(len(CLASSES)))

    cm = confusion_matrix(all_targets, all_preds, labels=label_ids)
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=CLASSES,
                yticklabels=CLASSES)
    plt.title(f'Confusion Matrix - Epoch {epoch+1} - Accuracy: {val_acc:.2f}%')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.tight_layout()
    plt.savefig(os.path.join(args.run_path, f'confusion_matrix_epoch_{epoch+1}.png'))
    plt.close()

    report = classification_report(all_targets, all_preds,
                                   labels=label_ids,
                                   target_names=CLASSES,
                                   output_dict=False,
                                   zero_division=0)
    with open(os.path.join(args.run_path, f'classification_report_epoch_{epoch+1}.txt'), 'w') as f:
        f.write(report)


def _save_learning_curves(epochs, train_losses, val_losses, train_accuracies, val_accuracies):
    """Plot loss and accuracy per epoch side by side and save ``learning_curves.png``."""
    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.plot(epochs, train_losses, label='Train Loss')
    plt.plot(epochs, val_losses, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Loss Curves')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(epochs, train_accuracies, label='Train Accuracy')
    plt.plot(epochs, val_accuracies, label='Validation Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.title('Accuracy Curves')
    plt.legend()

    plt.tight_layout()
    plt.savefig(os.path.join(args.run_path, 'learning_curves.png'))
    plt.close()


def train():
    """Train for ``args.iter`` epochs, validating after each one.

    Per epoch: optimise over the training loader, evaluate on the validation
    loader, append a row to ``metrics.csv``, step the scheduler with the
    validation accuracy and, on a new best accuracy, save ``best_model.pth``
    together with a confusion matrix and classification report. Learning
    curves are saved once training finishes.

    Returns:
        ``(model, history)`` where ``history`` maps ``'train_losses'``,
        ``'train_accuracies'``, ``'val_losses'`` and ``'val_accuracies'`` to
        per-epoch lists.
    """
    model, train_loader, test_loader, criterion, optimizer, scheduler, device = initialize_training()

    # Lists to store metrics for plotting
    train_losses = []
    train_accuracies = []
    val_losses = []
    val_accuracies = []
    epochs = []

    best_val_acc = 0.0
    start_time = time.time()

    for epoch in range(args.iter):
        # ---- Training phase ----
        model.train()
        train_loss = 0
        train_correct = 0
        train_total = 0

        batch_times = []
        epoch_start = time.time()

        for batch_idx, (inputs, targets) in enumerate(tqdm(train_loader, desc=f'Epoch {epoch+1}/{args.iter}')):
            batch_start = time.time()
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            _, predicted = outputs.max(1)
            batch_correct = predicted.eq(targets).sum().item()
            train_total += targets.size(0)
            train_correct += batch_correct

            batch_end = time.time()
            batch_times.append(batch_end - batch_start)

            # Print intermediate results every 10 batches
            if (batch_idx + 1) % 10 == 0:
                print(f'Batch {batch_idx+1}/{len(train_loader)}, Loss: {loss.item():.4f}, '
                      f'Acc: {100.*batch_correct/targets.size(0):.2f}%, '
                      f'Time: {batch_end-batch_start:.2f}s')

        # ---- Validation phase ----
        model.eval()
        val_loss = 0
        val_correct = 0
        val_total = 0
        all_preds = []
        all_targets = []

        with torch.no_grad():
            for inputs, targets in tqdm(test_loader, desc='Validation'):
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)

                val_loss += loss.item()
                _, predicted = outputs.max(1)
                val_total += targets.size(0)
                val_correct += predicted.eq(targets).sum().item()

                all_preds.extend(predicted.cpu().numpy())
                all_targets.extend(targets.cpu().numpy())

        # Calculate metrics. Every denominator is guarded so an empty loader
        # reports 0 instead of raising ZeroDivisionError.
        train_loss = train_loss / max(len(train_loader), 1)
        train_acc = 100. * train_correct / (train_total or 1)
        val_loss = val_loss / max(len(test_loader), 1)
        val_acc = 100. * val_correct / (val_total or 1)
        avg_batch_time = np.mean(batch_times) if batch_times else 0.0

        # Store metrics for plotting
        train_losses.append(train_loss)
        train_accuracies.append(train_acc)
        val_losses.append(val_loss)
        val_accuracies.append(val_acc)
        epochs.append(epoch + 1)

        # Print epoch results
        epoch_end = time.time()
        print(f"Epoch {epoch+1}/{args.iter} completed in {(epoch_end - epoch_start)/60:.2f} minutes")
        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")
        print(f"Avg batch time: {avg_batch_time:.4f}s")

        # Update learning rate
        scheduler.step(val_acc)

        # Append this epoch to the CSV log; the header is written once.
        with open(os.path.join(args.run_path, 'metrics.csv'), 'a') as f:
            if epoch == 0:
                f.write('epoch,train_loss,train_acc,val_loss,val_acc\n')
            f.write(f'{epoch+1},{train_loss:.6f},{train_acc:.6f},{val_loss:.6f},{val_acc:.6f}\n')

        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': val_acc,
            }, os.path.join(args.run_path, 'best_model.pth'))

            _save_best_artifacts(epoch, val_acc, all_targets, all_preds)

            print(f"New best model saved! Validation accuracy: {val_acc:.2f}%")

    _save_learning_curves(epochs, train_losses, val_losses, train_accuracies, val_accuracies)

    # Calculate and print total training time
    total_time = time.time() - start_time
    print(f"Training completed in {total_time/3600:.2f} hours")
    print(f"Best validation accuracy: {best_val_acc:.2f}%")
    print(f"Results saved to {args.run_path}")

    # Return the trained model and metrics
    return model, {
        'train_losses': train_losses,
        'train_accuracies': train_accuracies,
        'val_losses': val_losses,
        'val_accuracies': val_accuracies
    }

## 6. Run Training
Execute the training process.

In [None]:
# Start the full training run when this cell/script is executed directly
# (not when the file is imported as a module).
if __name__ == "__main__":
    train()

## 7. Visualization and Evaluation
After training, load the best model and evaluate its performance.

In [None]:
def evaluate_best_model():
    """Reload ``best_model.pth`` from ``args.run_path`` and score it on the validation split.

    Prints the accuracy and classification report, then saves and shows the
    confusion matrix as ``final_confusion_matrix.png``.

    Returns:
        Tuple ``(accuracy, report, cm)``: the accuracy as a fraction in
        [0, 1], the text classification report, and the confusion matrix with
        one row/column per entry of ``CLASSES``.
    """
    # Held-out split, without augmentation
    transform = Transform3D(output_size=(128, 128, 128), data_aug=False)

    test_dataset = BrainTumorDataset(
        base_dir=args.path,
        transform=transform,
        train=False,
        train_ratio=args.train_split
    )

    test_loader = DataLoader(
        test_dataset,
        batch_size=args.b_size,
        shuffle=False,
        num_workers=4,
        pin_memory=True
    )

    # Initialize model
    device = f"cuda:{args.cuda}" if torch.cuda.is_available() else "cpu"
    model, _ = run_steerable_gcnn(args, device, True)

    # map_location lets a checkpoint saved on a GPU be restored on whatever
    # device is available now (e.g. a CPU-only machine).
    checkpoint = torch.load(
        os.path.join(args.run_path, 'best_model.pth'),
        map_location=device
    )
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    model.eval()

    # Evaluate model
    all_preds = []
    all_targets = []

    with torch.no_grad():
        for inputs, targets in tqdm(test_loader, desc='Evaluating'):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            _, predicted = outputs.max(1)

            all_preds.extend(predicted.cpu().numpy())
            all_targets.extend(targets.cpu().numpy())

    # Pin the label set to every index of CLASSES so the matrix is always
    # len(CLASSES) square and the report rows match target_names, even when a
    # class appears in neither the targets nor the predictions.
    label_ids = list(range(len(CLASSES)))
    accuracy = accuracy_score(all_targets, all_preds)
    cm = confusion_matrix(all_targets, all_preds, labels=label_ids)
    report = classification_report(all_targets, all_preds,
                                   labels=label_ids,
                                   target_names=CLASSES,
                                   zero_division=0)

    # Print results
    print(f"Evaluation Results (Epoch {checkpoint['epoch']+1}):")
    print(f"Accuracy: {accuracy*100:.2f}%")
    print("\nClassification Report:")
    print(report)

    # Plot confusion matrix
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=CLASSES,
                yticklabels=CLASSES)
    plt.title(f'Confusion Matrix (Final) - Accuracy: {accuracy*100:.2f}%')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.tight_layout()
    plt.savefig(os.path.join(args.run_path, 'final_confusion_matrix.png'))
    plt.show()

    return accuracy, report, cm