In [None]:
!pip install torch torchvision opencv-python numpy pandas matplotlib scikit-learn torchsummary kagglehub

In [1]:
import os
import h5py
import cv2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import time
import kagglehub
import glob
import random
from sklearn.metrics import accuracy_score, confusion_matrix, roc_curve, auc, classification_report
from sklearn.preprocessing import label_binarize
from itertools import cycle

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from torch.optim.lr_scheduler import CosineAnnealingLR, OneCycleLR

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Set random seeds for reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Configuration parameters
CONFIG = {
    'batch_size': 16,
    'num_epochs': 15,  # Increased from 10 to 15
    'base_lr': 5e-4,   # Changed from 1e-4
    'weight_decay': 2e-5,  # Changed from 1e-4
    'scheduler': 'onecycle',  # Options: 'cosine', 'onecycle'
    'mixup_alpha': 0.2,  # Added mixup augmentation
    'label_smoothing': 0.1,  # Added label smoothing
    'dropout_rate': 0.4,  # Increased dropout
    'model_variant': 'efficient_cbam',  # Options: 'efficient_basic', 'efficient_cbam', 'efficient_dual'
    'num_classes': 3
}

print("Training configuration:")
for key, value in CONFIG.items():
    print(f"  {key}: {value}")

# Download the dataset using kagglehub
print("Downloading dataset...")
dataset_path = kagglehub.dataset_download("ashkhagan/figshare-brain-tumor-dataset")
print("Path to dataset files:", dataset_path)

# Data loading functions
def load_data():
    """Load and preprocess data from the downloaded directory"""
    # First find all .mat files in the dataset directory and its subdirectories
    mat_files = []
    for root, dirs, files in os.walk(dataset_path):
        for file in files:
            if file.endswith('.mat') and file != 'cvind.mat':
                mat_files.append(os.path.join(root, file))
   
    print(f"Found {len(mat_files)} .mat files")
   
    # If we don't have enough files, exit
    if len(mat_files) < 3000:
        raise ValueError(f"Expected ~3064 .mat files but found only {len(mat_files)}")
   
    # Sort the files to ensure consistent order
    mat_files.sort()
   
    # Prepare arrays for images and labels
    img = np.zeros((len(mat_files), 224, 224))
    lbl = []
   
    # Load each file
    for i, file_path in enumerate(mat_files):
        try:
            with h5py.File(file_path, 'r') as f:
                images = f['cjdata']
                resized = cv2.resize(images['image'][:,:], (224, 224), interpolation=cv2.INTER_CUBIC)
                x = np.asarray(resized)
                x = (x - np.min(x)) / (np.max(x) - np.min(x))  # Normalization
                x = x.reshape((1, 224, 224))
                img[i] = x
                lbl.append(int(images['label'][0]))
               
                if i % 500 == 0:
                    print(f"Processed {i} images")
        except Exception as e:
            print(f"Failed to load image at {file_path}: {e}")
   
    # Find cvind.mat file
    cvind_files = []
    for root, dirs, files in os.walk(dataset_path):
        for file in files:
            if file == 'cvind.mat':
                cvind_files.append(os.path.join(root, file))
   
    if not cvind_files:
        raise ValueError("Could not find cvind.mat file")
   
    cvind_path = cvind_files[0]
    print(f"Found cvind.mat at: {cvind_path}")
   
    # Load fold indices
    with h5py.File(cvind_path, 'r') as f:
        idx = np.array(f['cvind']).astype(np.int16).squeeze()
   
    return img, np.array(lbl), idx

# Custom Dataset
class BrainTumorDataset(Dataset):
    def __init__(self, images, labels, transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform
       
    def __len__(self):
        return len(self.labels)
   
    def __getitem__(self, idx):
        # Convert grayscale to RGB by repeating channel
        image = self.images[idx]
        image = np.repeat(image.reshape(224, 224, 1), 3, axis=2)
        label = self.labels[idx] - 1  # Convert to 0-indexed
       
        if self.transform:
            image = self.transform(image)
        else:
            image = torch.from_numpy(image.transpose(2, 0, 1)).float()
           
        return image, label

# Define transforms for training and validation with more aggressive augmentation
def get_transforms():
    train_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.RandomRotation(30),  # Increased from 15 to 30
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),  # Added scale
        transforms.ColorJitter(brightness=0.2, contrast=0.2),  # Added color jitter
        transforms.RandomPerspective(distortion_scale=0.2, p=0.5),  # Added perspective transform
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        transforms.RandomErasing(p=0.2)  # Added random erasing
    ])
   
    val_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
   
    return train_transform, val_transform

# Get train and test splits
def get_train_test_data(images, labels, fold_indices, test_fold):
    train_mask = fold_indices != test_fold
    test_mask = fold_indices == test_fold
   
    train_images = images[train_mask]
    train_labels = labels[train_mask]
    test_images = images[test_mask]
    test_labels = labels[test_mask]
   
    return (train_images, train_labels), (test_images, test_labels)

# MixUp augmentation
def mixup_data(x, y, alpha=1.0):
    '''Compute the mixup data. Return mixed inputs, pairs of targets, and lambda'''
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    batch_size = x.size()[0]
    index = torch.randperm(batch_size).to(x.device)

    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)

# Squeeze and Excitation Block
class SEBlock(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SEBlock, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)

# CBAM: Convolutional Block Attention Module
class ChannelAttention(nn.Module):
    def __init__(self, in_channels, reduction_ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
       
        self.fc = nn.Sequential(
            nn.Conv2d(in_channels, in_channels // reduction_ratio, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels // reduction_ratio, in_channels, 1, bias=False)
        )
       
    def forward(self, x):
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        out = avg_out + max_out
        return torch.sigmoid(out)

class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1
       
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
       
    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        out = torch.cat([avg_out, max_out], dim=1)
        out = self.conv(out)
        return torch.sigmoid(out)

class CBAM(nn.Module):
    def __init__(self, in_channels, reduction_ratio=16, kernel_size=7):
        super(CBAM, self).__init__()
        self.channel_attention = ChannelAttention(in_channels, reduction_ratio)
        self.spatial_attention = SpatialAttention(kernel_size)
       
    def forward(self, x):
        x = x * self.channel_attention(x)
        x = x * self.spatial_attention(x)
        return x

# Self-Attention Block
class SelfAttention(nn.Module):
    def __init__(self, in_channels):
        super(SelfAttention, self).__init__()
        self.query = nn.Conv2d(in_channels, in_channels//8, kernel_size=1)
        self.key = nn.Conv2d(in_channels, in_channels//8, kernel_size=1)
        self.value = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)
       
    def forward(self, x):
        batch_size, C, width, height = x.size()
       
        # Reshape for matrix multiplication
        proj_query = self.query(x).view(batch_size, -1, width*height).permute(0, 2, 1)  # B X (W*H) X C
        proj_key = self.key(x).view(batch_size, -1, width*height)  # B X C X (W*H)
       
        # Calculate attention map
        attention = torch.bmm(proj_query, proj_key)  # B X (W*H) X (W*H)
        attention = self.softmax(attention)
       
        # Apply attention to values
        proj_value = self.value(x).view(batch_size, -1, width*height)  # B X C X (W*H)
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))  # B X C X (W*H)
        out = out.view(batch_size, C, width, height)  # B X C X W X H
       
        # Add residual connection with learnable parameter gamma
        out = self.gamma * out + x
       
        return out

# Dual Path Block - combines features from multiple paths
class DualPathBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(DualPathBlock, self).__init__()
       
        # First path - standard convolution
        self.conv_path = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
       
        # Second path - depthwise separable convolution
        self.dw_path = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=stride, padding=1, groups=in_channels, bias=False),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
       
        # Attention module
        self.cbam = CBAM(out_channels)
       
        # Residual connection if dimensions change
        self.downsample = None
        if stride != 1 or in_channels != out_channels:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )
   
    def forward(self, x):
        residual = x
       
        # Process through both paths
        out1 = self.conv_path(x)
        out2 = self.dw_path(x)
       
        # Combine paths with element-wise addition
        out = out1 + out2
       
        # Apply attention
        out = self.cbam(out)
       
        # Apply residual connection
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
       
        return out

# Feature Pyramid Network (FPN) module
class FPN(nn.Module):
    def __init__(self, channels):
        super(FPN, self).__init__()
        self.lateral_conv = nn.Conv2d(channels, 256, kernel_size=1)
        self.output_conv = nn.Conv2d(256, 256, kernel_size=3, padding=1)
       
    def forward(self, x):
        lateral = self.lateral_conv(x)
        output = self.output_conv(lateral)
        return output

# Main model with EfficientNet backbone and attention - Modified to support different variants
class EnhancedEfficientNetClassifier(nn.Module):
    def __init__(self, model_variant='efficient_basic', num_classes=3, dropout_rate=0.3):
        super(EnhancedEfficientNetClassifier, self).__init__()
       
        # Load pretrained EfficientNet and remove final classifier
        # Using B4 instead of B3 for higher capacity
        efficient_net = models.efficientnet_b4(weights="IMAGENET1K_V1")
        self.features = nn.Sequential(*list(efficient_net.children())[:-1])
       
        # Get the output feature dimension
        feature_dim = efficient_net._modules['classifier'][1].in_features
        split_dim = feature_dim // 2
       
        self.model_variant = model_variant
       
        if model_variant == 'efficient_basic':
            # Basic variant with Self-Attention and SE blocks
            self.attention1 = SelfAttention(split_dim)
            self.attention2 = SelfAttention(split_dim)
            self.se1 = SEBlock(split_dim)
            self.se2 = SEBlock(split_dim)
           
        elif model_variant == 'efficient_cbam':
            # Enhanced variant with CBAM
            self.cbam1 = CBAM(split_dim)
            self.cbam2 = CBAM(split_dim)
           
        elif model_variant == 'efficient_dual':
            # Dual path variant
            self.dual_path1 = DualPathBlock(split_dim, split_dim)
            self.dual_path2 = DualPathBlock(split_dim, split_dim)
           
        # FPN modules for multi-scale feature extraction
        self.fpn1 = FPN(split_dim)
        self.fpn2 = FPN(split_dim)
       
        # Global pooling
        self.global_pool = nn.AdaptiveAvgPool2d(1)
       
        # Additional features: Global Context Block
        self.gcb = nn.Sequential(
            nn.Conv2d(512, 64, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(64, 512, kernel_size=1),
            nn.Sigmoid()
        )
       
        # Fully connected layers with improved regularization
        self.classifier = nn.Sequential(
            nn.Linear(512, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),  # Increased dropout
            nn.Linear(512, 256),  # Added an extra FC layer
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate/2),  # Lower dropout in the last layer
            nn.Linear(256, num_classes)
        )
   
    def forward(self, x):
        # Extract features using EfficientNet backbone
        features = self.features(x)
       
        # Split features into two parts
        features1, features2 = torch.split(features, features.size(1)//2, dim=1)
       
        # Apply attention based on model variant
        if self.model_variant == 'efficient_basic':
            features1 = self.attention1(features1)
            features1 = self.se1(features1)
           
            features2 = self.attention2(features2)
            features2 = self.se2(features2)
           
        elif self.model_variant == 'efficient_cbam':
            features1 = self.cbam1(features1)
            features2 = self.cbam2(features2)
           
        elif self.model_variant == 'efficient_dual':
            features1 = self.dual_path1(features1)
            features2 = self.dual_path2(features2)
       
        # Apply FPN for multi-scale feature enhancement
        features1 = self.fpn1(features1)
        features2 = self.fpn2(features2)
       
        # Concatenate features
        combined_features = torch.cat([features1, features2], dim=1)
       
        # Apply global context
        context = self.gcb(combined_features)
        combined_features = combined_features * context
       
        # Global pooling
        pooled = self.global_pool(combined_features)
        pooled = pooled.view(pooled.size(0), -1)
       
        # Classification
        output = self.classifier(pooled)
       
        return output

# Training function with mixup and label smoothing
def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs=10, mixup_alpha=0.0):
    best_val_acc = 0.0
    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
   
    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        print('-' * 10)
       
        # Training phase
        model.train()
        running_loss = 0.0
        running_corrects = 0
        total_samples = 0
       
        for inputs, labels in train_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
           
            # Apply mixup if alpha > 0
            if mixup_alpha > 0:
                inputs, labels_a, labels_b, lam = mixup_data(inputs, labels, mixup_alpha)
                use_mixup = True
            else:
                use_mixup = False
           
            # Zero the parameter gradients
            optimizer.zero_grad()
           
            # Forward pass
            outputs = model(inputs)
           
            # Calculate loss with or without mixup
            if use_mixup:
                loss = mixup_criterion(criterion, outputs, labels_a, labels_b, lam)
                # For accuracy calculation with mixup, we use the dominant label
                _, preds = torch.max(outputs, 1)
                running_corrects += (lam * torch.sum(preds == labels_a) +
                                   (1 - lam) * torch.sum(preds == labels_b)).float()
            else:
                loss = criterion(outputs, labels)
                _, preds = torch.max(outputs, 1)
                running_corrects += torch.sum(preds == labels.data)
           
            # Backward pass and optimize
            loss.backward()
            optimizer.step()
           
            # Statistics
            batch_size = inputs.size(0)
            running_loss += loss.item() * batch_size
            total_samples += batch_size
       
        # Step scheduler if it's not OneCycleLR (which steps per batch)
        if CONFIG['scheduler'] != 'onecycle':
            scheduler.step()
       
        epoch_loss = running_loss / total_samples
        epoch_acc = running_corrects.double() / total_samples
       
        history['train_loss'].append(epoch_loss)
        history['train_acc'].append(epoch_acc.cpu().numpy())
       
        print(f'Train Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
       
        # Validation phase
        model.eval()
        val_running_loss = 0.0
        val_running_corrects = 0
        val_total_samples = 0
       
        # No gradient calculation needed for validation
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs = inputs.to(device)
                labels = labels.to(device)
               
                # Forward pass
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)
               
                # Statistics
                batch_size = inputs.size(0)
                val_running_loss += loss.item() * batch_size
                val_running_corrects += torch.sum(preds == labels.data)
                val_total_samples += batch_size
       
        val_epoch_loss = val_running_loss / val_total_samples
        val_epoch_acc = val_running_corrects.double() / val_total_samples
       
        history['val_loss'].append(val_epoch_loss)
        history['val_acc'].append(val_epoch_acc.cpu().numpy())
       
        print(f'Val Loss: {val_epoch_loss:.4f} Acc: {val_epoch_acc:.4f}')
       
        # Save best model
        if val_epoch_acc > best_val_acc:
            best_val_acc = val_epoch_acc
            torch.save(model.state_dict(), f'best_{CONFIG["model_variant"]}_model.pth')
            print(f'New best model saved with accuracy: {val_epoch_acc:.4f}')
       
        print()
   
    # Load best model weights
    model.load_state_dict(torch.load(f'best_{CONFIG["model_variant"]}_model.pth'))
    return model, history

# Evaluation function with advanced metrics
def evaluate_model(model, test_loader, num_classes=3):
    model.eval()
    all_preds = []
    all_labels = []
    all_probs = []
   
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            outputs = model(inputs)
            probs = F.softmax(outputs, dim=1)
            _, preds = torch.max(outputs, 1)
           
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.numpy())
            all_probs.extend(probs.cpu().numpy())
   
    # Convert to numpy arrays
    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)
    all_probs = np.array(all_probs)
   
    # Calculate metrics
    accuracy = accuracy_score(all_labels, all_preds)
   
    # Confusion Matrix
    cm = confusion_matrix(all_labels, all_preds)
   
    # Classification Report
    report = classification_report(all_labels, all_preds)
   
    # ROC Curve and AUC
    # Binarize the labels for ROC calculation
    labels_bin = label_binarize(all_labels, classes=range(num_classes))
   
    # Compute ROC curve and ROC area for each class
    fpr = dict()
    tpr = dict()
    roc_auc = dict()
    for i in range(num_classes):
        fpr[i], tpr[i], _ = roc_curve(labels_bin[:, i], all_probs[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])
   
    # Compute micro-average ROC curve and ROC area
    fpr["micro"], tpr["micro"], _ = roc_curve(labels_bin.ravel(), all_probs.ravel())
    roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])
   
    # Compute macro-average ROC curve and ROC area
    # First aggregate all false positive rates
    all_fpr = np.unique(np.concatenate([fpr[i] for i in range(num_classes)]))
   
    # Then interpolate all ROC curves at these points
    mean_tpr = np.zeros_like(all_fpr)
    for i in range(num_classes):
        mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])
   
    # Finally average it and compute AUC
    mean_tpr /= num_classes
   
    fpr["macro"] = all_fpr
    tpr["macro"] = mean_tpr
    roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])
   
    # Plot ROC Curves
    plt.figure(figsize=(12, 8))
   
    # Plot micro-average ROC curve
    plt.plot(fpr["micro"], tpr["micro"],
             label=f'Micro-average ROC curve (area = {roc_auc["micro"]:.2f})',
             color='deeppink', linestyle=':', linewidth=4)
   
    # Plot macro-average ROC curve
    plt.plot(fpr["macro"], tpr["macro"],
             label=f'Macro-average ROC curve (area = {roc_auc["macro"]:.2f})',
             color='navy', linestyle=':', linewidth=4)
   
    # Plot ROC curves for all classes
    colors = cycle(['aqua', 'darkorange', 'cornflowerblue'])
    class_names = ['Meningioma', 'Glioma', 'Pituitary']
   
    for i, color, name in zip(range(num_classes), colors, class_names):
        plt.plot(fpr[i], tpr[i], color=color, lw=2,
                 label=f'ROC curve of {name} (area = {roc_auc[i]:.2f})')
   
    plt.plot([0, 1], [0, 1], 'k--', lw=2)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title(f'ROC Curve - {CONFIG["model_variant"]}')
    plt.legend(loc="lower right")
    plt.savefig(f'{CONFIG["model_variant"]}_roc_curve.png')
    plt.close()
   
    # Plot Confusion Matrix
    plt.figure(figsize=(10, 8))
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title(f'Confusion Matrix - {CONFIG["model_variant"]}')
    plt.colorbar()
   
    tick_marks = np.arange(num_classes)
    plt.xticks(tick_marks, class_names, rotation=45)
    plt.yticks(tick_marks, class_names)
   
    # Add text annotations to the confusion matrix
    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            plt.text(j, i, format(cm[i, j], 'd'),
                     horizontalalignment="center",
                     color="white" if cm[i, j] > thresh else "black")
   
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()
    plt.savefig(f'{CONFIG["model_variant"]}_confusion_matrix.png')
    plt.close()
   
    # Return metrics and predictions
    return {
        'accuracy': accuracy,
        'confusion_matrix': cm,
        'classification_report': report,
        'roc_auc': roc_auc,
        'predictions': all_preds,
        'labels': all_labels,
        'probabilities': all_probs
    }

# Function to test with a random image
def test_random_image(model, images, labels, transform=None):
    """
    Test the model with a random image from the dataset
   
    Args:
        model: Trained model
        images: Image data
        labels: Labels
        transform: Image transform function
       
    Returns:
        prediction: The model's prediction
    """
    # Choose a random image
    idx = random.randint(0, len(images) - 1)
    img = images[idx]
    true_label = labels[idx] - 1  # Convert totrue_label = labels[idx] - 1  # Convert to 0-indexed
   
    # Preprocess the image
    img = np.repeat(img.reshape(224, 224, 1), 3, axis=2)
   
    # Apply transform if provided
    if transform:
        img_tensor = transform(img).unsqueeze(0).to(device)  # Add batch dimension
    else:
        img_tensor = torch.from_numpy(img.transpose(2, 0, 1)).float().unsqueeze(0).to(device)
   
    # Get model prediction
    model.eval()
    with torch.no_grad():
        output = model(img_tensor)
        probs = F.softmax(output, dim=1)
        _, pred = torch.max(output, 1)
   
    # Display the image and prediction
    class_names = ['Meningioma', 'Glioma', 'Pituitary']
    plt.figure(figsize=(6, 6))
    plt.imshow(img, cmap='gray')
    plt.title(f'True: {class_names[true_label]}, Pred: {class_names[pred.item()]}\nConfidence: {probs[0][pred.item()]:.4f}')
    plt.axis('off')
    plt.savefig(f'{CONFIG["model_variant"]}_random_test.png')
    plt.close()
   
    print(f"Random image test - True: {class_names[true_label]}, Predicted: {class_names[pred.item()]}")
    print(f"Confidence scores: {probs[0].cpu().numpy()}")
   
    return {
        'image_idx': idx,
        'true_label': true_label,
        'predicted_label': pred.item(),
        'confidence': probs[0][pred.item()].item(),
        'all_probs': probs[0].cpu().numpy()
    }

# Main execution function
def run_experiment(fold_index):
    print(f"\n{'='*20} RUNNING FOLD {fold_index} {'='*20}\n")
   
    # Load data
    images, labels, fold_indices = load_data()
   
    # Get train and test data for this fold
    (train_images, train_labels), (test_images, test_labels) = get_train_test_data(
        images, labels, fold_indices, fold_index)
   
    # Create datasets with transforms
    train_transform, val_transform = get_transforms()
   
    train_dataset = BrainTumorDataset(train_images, train_labels, transform=train_transform)
    test_dataset = BrainTumorDataset(test_images, test_labels, transform=val_transform)
   
    # Create data loaders with larger batch size
    train_loader = DataLoader(
        train_dataset,
        batch_size=CONFIG['batch_size'],
        shuffle=True,
        num_workers=4,
        pin_memory=True  # Add pin_memory for faster data transfer to GPU
    )
   
    test_loader = DataLoader(
        test_dataset,
        batch_size=CONFIG['batch_size'],
        shuffle=False,
        num_workers=4,
        pin_memory=True
    )
   
    # Initialize model based on selected variant
    model = EnhancedEfficientNetClassifier(
        model_variant=CONFIG['model_variant'],
        num_classes=CONFIG['num_classes'],
        dropout_rate=CONFIG['dropout_rate']
    )
    model = model.to(device)
   
    # Loss function with label smoothing
    criterion = nn.CrossEntropyLoss(label_smoothing=CONFIG['label_smoothing'])
   
    # Optimizer with weight decay
    optimizer = optim.AdamW(
        model.parameters(),
        lr=CONFIG['base_lr'],
        weight_decay=CONFIG['weight_decay']
    )
   
    # Set scheduler based on configuration
    if CONFIG['scheduler'] == 'cosine':
        scheduler = CosineAnnealingLR(
            optimizer,
            T_max=CONFIG['num_epochs'],
            eta_min=CONFIG['base_lr'] / 100
        )
    elif CONFIG['scheduler'] == 'onecycle':
        # OneCycle learning rate scheduler
        scheduler = OneCycleLR(
            optimizer,
            max_lr=CONFIG['base_lr'] * 10,  # Peak LR is 10x the base LR
            steps_per_epoch=len(train_loader),
            epochs=CONFIG['num_epochs'],
            pct_start=0.3,  # Spend 30% of time increasing LR
            div_factor=25.0,  # Initial LR = max_lr/25
            final_div_factor=10000.0  # Final LR = initial_lr/10000
        )
   
    # Train model with specified epochs
    start_time = time.time()
    model, history = train_model(
        model,
        train_loader,
        test_loader,
        criterion,
        optimizer,
        scheduler,
        num_epochs=CONFIG['num_epochs'],
        mixup_alpha=CONFIG['mixup_alpha']
    )
    end_time = time.time()
   
    # Evaluate model with advanced metrics
    metrics = evaluate_model(model, test_loader, num_classes=CONFIG['num_classes'])
    accuracy = metrics['accuracy']
   
    print(f"Fold {fold_index} accuracy: {accuracy:.4f}")
    print(f"Training time: {end_time - start_time:.2f} seconds")
    print("\nClassification Report:")
    print(metrics['classification_report'])
   
    # Get ROC AUC scores
    print("\nROC AUC Scores:")
    for i in range(CONFIG['num_classes']):
        print(f"Class {i}: {metrics['roc_auc'][i]:.4f}")
    print(f"Micro-average: {metrics['roc_auc']['micro']:.4f}")
    print(f"Macro-average: {metrics['roc_auc']['macro']:.4f}")
   
    # Plot loss curve
    plt.figure(figsize=(10, 5))
    plt.plot(history['train_loss'], label='Train Loss')
    plt.plot(history['val_loss'], label='Val Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.title(f'Loss Curves for Fold {fold_index} - {CONFIG["model_variant"]}')
    plt.savefig(f'{CONFIG["model_variant"]}_loss_curve_fold_{fold_index}.png')
    plt.close()
   
    # Plot accuracy curve
    plt.figure(figsize=(10, 5))
    plt.plot(history['train_acc'], label='Train Acc')
    plt.plot(history['val_acc'], label='Val Acc')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.title(f'Accuracy Curves for Fold {fold_index} - {CONFIG["model_variant"]}')
    plt.savefig(f'{CONFIG["model_variant"]}_acc_curve_fold_{fold_index}.png')
    plt.close()
   
    # Test with a random image
    random_test_results = test_random_image(model, images, labels, transform=val_transform)
   
    # Save model and results
    torch.save(model.state_dict(), f'{CONFIG["model_variant"]}_model_fold_{fold_index}.pth')
    np.save(f'{CONFIG["model_variant"]}_results_fold_{fold_index}.npy', metrics)
   
    return metrics

# Run for all folds
if __name__ == "__main__":
    results_all = {}
   
    # Run for all 5 folds
    for fold in range(1, 6):
        try:
            metrics = run_experiment(fold)
            results_all[fold] = metrics
        except Exception as e:
            print(f"Error in fold {fold}: {e}")
   
    # Calculate and print average accuracy across all folds
    accuracies = [results_all[fold]['accuracy'] for fold in results_all if fold in results_all]
    if accuracies:
        avg_accuracy = np.mean(accuracies)
        print(f"\nAverage accuracy across all folds: {avg_accuracy:.4f}")
       
        # Calculate average AUC across all folds
        avg_auc_micro = np.mean([results_all[fold]['roc_auc']['micro'] for fold in results_all if fold in results_all])
        avg_auc_macro = np.mean([results_all[fold]['roc_auc']['macro'] for fold in results_all if fold in results_all])
        print(f"Average micro-average AUC: {avg_auc_micro:.4f}")
        print(f"Average macro-average AUC: {avg_auc_macro:.4f}")
       
        # Model and configuration summary
        print("\nModel Configuration Summary:")
        print(f"  Model variant: {CONFIG['model_variant']}")
        print(f"  Batch size: {CONFIG['batch_size']}")
        print(f"  Epochs: {CONFIG['num_epochs']}")
        print(f"  Learning rate: {CONFIG['base_lr']}")
        print(f"  Weight decay: {CONFIG['weight_decay']}")
        print(f"  Scheduler: {CONFIG['scheduler']}")
        print(f"  MixUp alpha: {CONFIG['mixup_alpha']}")
        print(f"  Label smoothing: {CONFIG['label_smoothing']}")
        print(f"  Dropout rate: {CONFIG['dropout_rate']}")
       
    # Visualize learning curves across all folds
    if results_all:
        # Plot average accuracy across folds
        plt.figure(figsize=(12, 6))
        for fold in results_all:
            plt.plot(results_all[fold]['val_acc'], label=f'Fold {fold}')
        plt.xlabel('Epoch')
        plt.ylabel('Validation Accuracy')
        plt.title(f'Validation Accuracy Across Folds - {CONFIG["model_variant"]}')
        plt.legend()
        plt.grid(True, linestyle='--', alpha=0.7)
        plt.savefig(f'{CONFIG["model_variant"]}_all_folds_accuracy.png')
        plt.close()
       
        # Generate final summary visualization of class-wise metrics
        class_names = ['Meningioma', 'Glioma', 'Pituitary']
        class_aucs = []
        for i in range(CONFIG['num_classes']):
            class_auc = np.mean([results_all[fold]['roc_auc'][i] for fold in results_all if fold in results_all])
            class_aucs.append(class_auc)
       
        plt.figure(figsize=(10, 6))
        plt.bar(class_names, class_aucs, color=['skyblue', 'lightgreen', 'salmon'])
        plt.ylabel('Average AUC')
        plt.title(f'Average AUC by Class - {CONFIG["model_variant"]}')
        plt.ylim(0.8, 1.0)  # Adjust as needed
        for i, v in enumerate(class_aucs):
            plt.text(i, v + 0.01, f"{v:.4f}", ha='center')
        plt.savefig(f'{CONFIG["model_variant"]}_class_aucs.png')
        plt.close()

Using device: cuda
Training configuration:
  batch_size: 16
  num_epochs: 15
  base_lr: 0.0005
  weight_decay: 2e-05
  scheduler: onecycle
  mixup_alpha: 0.2
  label_smoothing: 0.1
  dropout_rate: 0.4
  model_variant: efficient_cbam
  num_classes: 3
Downloading dataset...
Path to dataset files: /kaggle/input/figshare-brain-tumor-dataset


Found 3064 .mat files
Processed 0 images


  lbl.append(int(images['label'][0]))


Processed 500 images
Processed 1000 images
Processed 1500 images
Processed 2000 images
Processed 2500 images
Processed 3000 images
Found cvind.mat at: /kaggle/input/figshare-brain-tumor-dataset/dataset/cvind.mat


Downloading: "https://download.pytorch.org/models/efficientnet_b4_rwightman-23ab8bcd.pth" to /root/.cache/torch/hub/checkpoints/efficientnet_b4_rwightman-23ab8bcd.pth
100%|██████████| 74.5M/74.5M [00:00<00:00, 187MB/s]


Epoch 1/15
----------
Train Loss: 0.8524 Acc: 0.6502
Val Loss: 0.5665 Acc: 0.8450
New best model saved with accuracy: 0.8450

Epoch 2/15
----------
Train Loss: 0.7178 Acc: 0.7592
Val Loss: 0.4905 Acc: 0.8985
New best model saved with accuracy: 0.8985

Epoch 3/15
----------
Train Loss: 0.6517 Acc: 0.8084
Val Loss: 0.4429 Acc: 0.9207
New best model saved with accuracy: 0.9207

Epoch 4/15
----------
Train Loss: 0.6410 Acc: 0.8180
Val Loss: 0.4472 Acc: 0.9225
New best model saved with accuracy: 0.9225

Epoch 5/15
----------
Train Loss: 0.6327 Acc: 0.8201
Val Loss: 0.4410 Acc: 0.9207

Epoch 6/15
----------
Train Loss: 0.6011 Acc: 0.8440
Val Loss: 0.4235 Acc: 0.9207

Epoch 7/15
----------
Train Loss: 0.5898 Acc: 0.8516
Val Loss: 0.4014 Acc: 0.9483
New best model saved with accuracy: 0.9483

Epoch 8/15
----------
Train Loss: 0.6075 Acc: 0.8425
Val Loss: 0.3908 Acc: 0.9465

Epoch 9/15
----------
Train Loss: 0.5618 Acc: 0.8664
Val Loss: 0.4145 Acc: 0.9373

Epoch 10/15
----------
Train Loss: 0.5

  model.load_state_dict(torch.load(f'best_{CONFIG["model_variant"]}_model.pth'))


Fold 1 accuracy: 0.9686
Training time: 610.73 seconds

Classification Report:
              precision    recall  f1-score   support

           0       0.91      0.94      0.93       113
           1       0.99      0.97      0.98       288
           2       0.97      1.00      0.99       141

    accuracy                           0.97       542
   macro avg       0.96      0.97      0.96       542
weighted avg       0.97      0.97      0.97       542


ROC AUC Scores:
Class 0: 0.9946
Class 1: 0.9986
Class 2: 0.9994
Micro-average: 0.9979
Macro-average: 0.9979
Random image test - True: Meningioma, Predicted: Meningioma
Confidence scores: [0.72648036 0.2397099  0.0338097 ]


Found 3064 .mat files
Processed 0 images


  lbl.append(int(images['label'][0]))


Processed 500 images
Processed 1000 images
Processed 1500 images
Processed 2000 images
Processed 2500 images
Processed 3000 images
Found cvind.mat at: /kaggle/input/figshare-brain-tumor-dataset/dataset/cvind.mat
Epoch 1/15
----------
Error in fold 2: Expected more than 1 value per channel when training, got input size torch.Size([1, 512])


Found 3064 .mat files
Processed 0 images
Processed 500 images
Processed 1000 images
Processed 1500 images
Processed 2000 images
Processed 2500 images
Processed 3000 images
Found cvind.mat at: /kaggle/input/figshare-brain-tumor-dataset/dataset/cvind.mat
Epoch 1/15
----------
Train Loss: 0.8472 Acc: 0.6590
Val Loss: 0.5645 Acc: 0.8671
New best model saved with accuracy: 0.8671

Epoch 2/15
----------
Train Loss: 0.7127 Acc: 0.7597
Val Loss: 0.4813 Acc: 0.9091
New best model saved with accuracy: 0.9091

Epoch 3/15
----------
Train Loss: 0.6680 Acc: 0.7980
Val Loss: 0.4512 Acc: 0.9406
New best model saved with accuracy: 0.9406

Epoch 4/15
----------
Trai

  model.load_state_dict(torch.load(f'best_{CONFIG["model_variant"]}_model.pth'))


Fold 3 accuracy: 0.9615
Training time: 611.12 seconds

Classification Report:
              precision    recall  f1-score   support

           0       0.90      0.91      0.90        99
           1       0.97      0.98      0.97       267
           2       0.99      0.96      0.97       206

    accuracy                           0.96       572
   macro avg       0.95      0.95      0.95       572
weighted avg       0.96      0.96      0.96       572


ROC AUC Scores:
Class 0: 0.9913
Class 1: 0.9964
Class 2: 0.9990
Micro-average: 0.9964
Macro-average: 0.9960
Random image test - True: Pituitary, Predicted: Pituitary
Confidence scores: [0.08611514 0.05901101 0.8548739 ]


Found 3064 .mat files
Processed 0 images


  lbl.append(int(images['label'][0]))


Processed 500 images
Processed 1000 images
Processed 1500 images
Processed 2000 images
Processed 2500 images
Processed 3000 images
Found cvind.mat at: /kaggle/input/figshare-brain-tumor-dataset/dataset/cvind.mat
Epoch 1/15
----------
Train Loss: 0.8509 Acc: 0.6553
Val Loss: 0.6091 Acc: 0.8535
New best model saved with accuracy: 0.8535

Epoch 2/15
----------
Train Loss: 0.7163 Acc: 0.7583
Val Loss: 0.5229 Acc: 0.8822
New best model saved with accuracy: 0.8822

Epoch 3/15
----------
Train Loss: 0.6532 Acc: 0.8063
Val Loss: 0.4652 Acc: 0.9204
New best model saved with accuracy: 0.9204

Epoch 4/15
----------
Train Loss: 0.6149 Acc: 0.8352
Val Loss: 0.4361 Acc: 0.9299
New best model saved with accuracy: 0.9299

Epoch 5/15
----------
Train Loss: 0.6193 Acc: 0.8314
Val Loss: 0.4259 Acc: 0.9363
New best model saved with accuracy: 0.9363

Epoch 6/15
----------
Train Loss: 0.6026 Acc: 0.8459
Val Loss: 0.4210 Acc: 0.9411
New best model saved with accuracy: 0.9411

Epoch 7/15
----------
Train Loss

  model.load_state_dict(torch.load(f'best_{CONFIG["model_variant"]}_model.pth'))


Fold 4 accuracy: 0.9729
Training time: 607.98 seconds

Classification Report:
              precision    recall  f1-score   support

           0       0.96      0.96      0.96       168
           1       0.98      0.98      0.98       287
           2       0.97      0.97      0.97       173

    accuracy                           0.97       628
   macro avg       0.97      0.97      0.97       628
weighted avg       0.97      0.97      0.97       628


ROC AUC Scores:
Class 0: 0.9975
Class 1: 0.9988
Class 2: 0.9988
Micro-average: 0.9984
Macro-average: 0.9985
Random image test - True: Pituitary, Predicted: Pituitary
Confidence scores: [0.07824273 0.04280282 0.8789544 ]


Found 3064 .mat files
Processed 0 images


  lbl.append(int(images['label'][0]))


Processed 500 images
Processed 1000 images
Processed 1500 images
Processed 2000 images
Processed 2500 images
Processed 3000 images
Found cvind.mat at: /kaggle/input/figshare-brain-tumor-dataset/dataset/cvind.mat
Epoch 1/15
----------
Train Loss: 0.8281 Acc: 0.6729
Val Loss: 0.5819 Acc: 0.8554
New best model saved with accuracy: 0.8554

Epoch 2/15
----------
Train Loss: 0.7144 Acc: 0.7590
Val Loss: 0.5155 Acc: 0.8787
New best model saved with accuracy: 0.8787

Epoch 3/15
----------
Train Loss: 0.6549 Acc: 0.8035
Val Loss: 0.4451 Acc: 0.9160
New best model saved with accuracy: 0.9160

Epoch 4/15
----------
Train Loss: 0.6380 Acc: 0.8211
Val Loss: 0.4333 Acc: 0.9300
New best model saved with accuracy: 0.9300

Epoch 5/15
----------
Train Loss: 0.6243 Acc: 0.8340
Val Loss: 0.4299 Acc: 0.9300

Epoch 6/15
----------
Train Loss: 0.6443 Acc: 0.8251
Val Loss: 0.4390 Acc: 0.9347
New best model saved with accuracy: 0.9347

Epoch 7/15
----------
Train Loss: 0.6060 Acc: 0.8417
Val Loss: 0.4356 Acc: 

  model.load_state_dict(torch.load(f'best_{CONFIG["model_variant"]}_model.pth'))


Fold 5 accuracy: 0.9596
Training time: 603.78 seconds

Classification Report:
              precision    recall  f1-score   support

           0       0.89      0.94      0.91       144
           1       0.99      0.96      0.97       296
           2       0.97      0.98      0.98       203

    accuracy                           0.96       643
   macro avg       0.95      0.96      0.95       643
weighted avg       0.96      0.96      0.96       643


ROC AUC Scores:
Class 0: 0.9888
Class 1: 0.9985
Class 2: 0.9942
Micro-average: 0.9945
Macro-average: 0.9945
Random image test - True: Pituitary, Predicted: Pituitary
Confidence scores: [0.03249943 0.03405317 0.9334475 ]

Average accuracy across all folds: 0.9657
Average micro-average AUC: 0.9968
Average macro-average AUC: 0.9967

Model Configuration Summary:
  Model variant: efficient_cbam
  Batch size: 16
  Epochs: 15
  Learning rate: 0.0005
  Weight decay: 2e-05
  Scheduler: onecycle
  MixUp alpha: 0.2
  Label smoothing: 0.1
  Dropo

KeyError: 'val_acc'

<Figure size 1200x600 with 0 Axes>