<a href="https://colab.research.google.com/github/amelft81/ASDEEG/blob/main/ResNet_50_EEG_Image_Classifier_with_Oversampling_and_Early_Stopping.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import numpy as np
from collections import Counter
import os
import copy # For deep copying model state for early stopping
from sklearn.model_selection import train_test_split # Import train_test_split

# --- 1. Configuration Parameters ---
# Paths to your generated synthetic dataset
DATA_DIR = "synthetic_eeg_dataset"
X_RESAMPLED_PATH = os.path.join(DATA_DIR, 'X_resampled.npy')
Y_RESAMPLED_PATH = os.path.join(DATA_DIR, 'y_resampled.npy')

# Model and training parameters
NUM_CLASSES = 2 # ASD or NON-ASD
BATCH_SIZE = 100 # As per paper
LEARNING_RATE = 1e-3 # As per paper
NUM_EPOCHS = 100 # Maximum epochs, early stopping will likely stop sooner
PATIENCE = 10 # Number of epochs to wait for improvement before stopping (for Early Stopping)

# ImageNet normalization values for pre-trained models
# These are standard mean and std deviation for images trained on ImageNet
# ResNet-50 expects inputs normalized with these values.
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]

# Set device to GPU if available, otherwise CPU
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

# --- 2. Custom Dataset Class ---
class EEGImageDataset(Dataset):
    """
    Custom PyTorch Dataset for loading EEG images and their labels.
    """
    def __init__(self, images, labels, transform=None):
        """
        Args:
            images (np.ndarray): NumPy array of EEG images (H, W, C).
            labels (np.ndarray): NumPy array of corresponding labels.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Returns the total number of samples."""
        return len(self.images)

    def __getitem__(self, idx):
        """
        Retrieves an image and its label at the given index.
        Applies transformations if provided.
        """
        image = self.images[idx]
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label

# --- 3. Model Definition: Pre-trained ResNet-50 ---
def get_resnet_model(num_classes, freeze_features=True):
    """
    Loads a pre-trained ResNet-50 model and modifies its final layer
    for binary classification.

    Args:
        num_classes (int): The number of output classes (2 for ASD/NON-ASD).
        freeze_features (bool): If True, freezes all layers except the final
                                classification layer. This is common for transfer learning.

    Returns:
        torch.nn.Module: The modified ResNet-50 model.
    """
    # Load pre-trained ResNet-50
    # The paper mentions using PyTorch model zoo and ImageNet pre-training.
    model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
    print("Loaded pre-trained ResNet-50 model.")

    if freeze_features:
        # Freeze all parameters in the network
        for param in model.parameters():
            param.requires_grad = False
        print("Frozen all feature extractor layers.")

    # Get the number of features from the original fully connected layer
    num_ftrs = model.fc.in_features
    # Replace the final fully connected layer with a new one for our number of classes
    model.fc = nn.Linear(num_ftrs, num_classes)
    print(f"Modified final fully connected layer to output {num_classes} classes.")

    return model

# --- 4. Data Loading and Preprocessing ---
def load_and_prepare_data(x_path, y_path, batch_size, norm_mean, norm_std):
    """
    Loads the dataset, applies transformations, creates DataLoaders,
    and sets up WeightedRandomSampler for class imbalance.
    """
    try:
        images = np.load(x_path)
        labels = np.load(y_path)

        # Convert images from uint8 (0-255) to float32 (0-1) and then normalize
        # PyTorch expects (C, H, W) format, so we need to transpose.
        data_transforms = transforms.Compose([
            transforms.ToPILImage(), # Convert numpy array to PIL Image for torchvision transforms
            transforms.ToTensor(),   # Converts PIL Image to FloatTensor (0-1) and (C, H, W)
            transforms.Normalize(mean=norm_mean, std=norm_std) # Normalize with ImageNet stats
        ])

        # Split data into training and validation sets (e.g., 80/20 split as in paper's experiments)
        # The paper also mentions leave-one-participant-out, which is more complex to simulate here.
        # We'll use a standard train/val split for demonstration.
        train_images, val_images, train_labels, val_labels = train_test_split(
            images, labels, test_size=0.2, random_state=42, stratify=labels
        )

        print(f"\nTraining set distribution: {Counter(train_labels)}")
        print(f"Validation set distribution: {Counter(val_labels)}")

        # Create datasets
        train_dataset = EEGImageDataset(train_images, train_labels, transform=data_transforms)
        val_dataset = EEGImageDataset(val_images, val_labels, transform=data_transforms)

        # Implement Weighted Random Sampler for training data (as per paper)
        # This addresses the class imbalance by giving higher probability to minority class samples.
        class_counts = Counter(train_labels)
        num_samples = sum(class_counts.values())
        class_weights = {cls: num_samples / count for cls, count in class_counts.items()}
        sample_weights = [class_weights[label] for label in train_labels]
        sampler = WeightedRandomSampler(
            weights=sample_weights,
            num_samples=num_samples, # Draw 'num_samples' times (with replacement)
            replacement=True
        )
        print("WeightedRandomSampler initialized for training data.")

        # Create DataLoaders
        train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler, num_workers=0)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0) # No sampler for validation

        return train_loader, val_loader

    except FileNotFoundError:
        print(f"Error: Dataset files not found at {x_path} and {y_path}.")
        print("Please ensure you have run the 'eeg_dataset_generator' code first.")
        return None, None # Return None for both loaders if files are not found


# --- 5. Training Function with Early Stopping ---
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs, patience, device):
    """
    Trains the deep learning model with early stopping.

    Args:
        model (torch.nn.Module): The neural network model.
        train_loader (DataLoader): DataLoader for the training set.
        val_loader (DataLoader): DataLoader for the validation set.
        criterion (torch.nn.Module): Loss function.
        optimizer (torch.optim.Optimizer): Optimizer.
        num_epochs (int): Maximum number of training epochs.
        patience (int): Number of epochs to wait for validation loss improvement.
        device (torch.device): Device to train on (CPU or GPU).

    Returns:
        torch.nn.Module: The trained model (best version based on validation loss).
    """
    if train_loader is None or val_loader is None:
        print("Skipping training due to data loading error.")
        return None

    best_val_loss = float('inf')
    epochs_no_improve = 0
    best_model_wts = copy.deepcopy(model.state_dict())

    model.to(device) # Move model to the specified device

    print("\nStarting model training...")
    for epoch in range(num_epochs):
        model.train() # Set model to training mode
        running_train_loss = 0.0
        correct_train_predictions = 0
        total_train_samples = 0

        for inputs, labels in train_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad() # Zero the parameter gradients

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward() # Backpropagation
            optimizer.step() # Update weights

            running_train_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs.data, 1)
            total_train_samples += labels.size(0)
            correct_train_predictions += (predicted == labels).sum().item()

        epoch_train_loss = running_train_loss / total_train_samples
        epoch_train_accuracy = correct_train_predictions / total_train_samples * 100

        # --- Validation Phase ---
        model.eval() # Set model to evaluation mode
        running_val_loss = 0.0
        correct_val_predictions = 0
        total_val_samples = 0

        with torch.no_grad(): # Disable gradient calculation for validation
            for inputs, labels in val_loader:
                inputs = inputs.to(device)
                labels = labels.to(device)

                outputs = model(inputs)
                loss = criterion(outputs, labels)

                running_val_loss += loss.item() * inputs.size(0)
                _, predicted = torch.max(outputs.data, 1)
                total_val_samples += labels.size(0)
                correct_val_predictions += (predicted == labels).sum().item()

        epoch_val_loss = running_val_loss / total_val_samples
        epoch_val_accuracy = correct_val_predictions / total_val_samples * 100

        print(f"Epoch {epoch+1}/{num_epochs}: "
              f"Train Loss: {epoch_train_loss:.4f}, Train Acc: {epoch_train_accuracy:.2f}% | "
              f"Val Loss: {epoch_val_loss:.4f}, Val Acc: {epoch_val_accuracy:.2f}%")

        # --- Early Stopping Logic ---
        if epoch_val_loss < best_val_loss:
            best_val_loss = epoch_val_loss
            best_model_wts = copy.deepcopy(model.state_dict()) # Save the best model state
            epochs_no_improve = 0
            print(f"  Validation loss improved. Saving model state. Best Loss: {best_val_loss:.4f}")
        else:
            epochs_no_improve += 1
            print(f"  Validation loss did not improve. Patience: {epochs_no_improve}/{patience}")
            if epochs_no_improve >= patience:
                print(f"Early stopping triggered after {epoch+1} epochs due to no improvement in validation loss.")
                model.load_state_dict(best_model_wts) # Load the best model weights
                return model

    print("Training finished (max epochs reached).")
    model.load_state_dict(best_model_wts) # Load the best model weights even if max epochs reached
    return model

# --- Main Execution ---
if __name__ == "__main__":
    # 1. Load and prepare data
    train_loader, val_loader = load_and_prepare_data(
        X_RESAMPLED_PATH, Y_RESAMPLED_PATH, BATCH_SIZE, NORM_MEAN, NORM_STD
    )

    # 2. Get the ResNet-50 model
    model = get_resnet_model(NUM_CLASSES, freeze_features=True) # Freeze features for faster initial training

    # 3. Define Loss Function and Optimizer
    # CrossEntropyLoss is suitable for multi-class classification (even binary)
    # and implicitly applies softmax.
    criterion = nn.CrossEntropyLoss()
    # Adam optimizer as specified in the paper
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # 4. Train the model
    trained_model = train_model(
        model, train_loader, val_loader, criterion, optimizer, NUM_EPOCHS, PATIENCE, DEVICE
    )

    print("\nModel training complete. The 'trained_model' variable holds the best model weights.")
    # You can now save the trained model for future inference:
    # torch.save(trained_model.state_dict(), 'resnet50_eeg_classifier.pth')
    # To load:
    # model = get_resnet_model(NUM_CLASSES)
    # model.load_state_dict(torch.load('resnet50_eeg_classifier.pth'))
    # model.eval() # Set to evaluation mode before inference

Using device: cpu
Error: Dataset files not found at synthetic_eeg_dataset/X_resampled.npy and synthetic_eeg_dataset/y_resampled.npy.
Please ensure you have run the 'eeg_dataset_generator' code first.


Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 160MB/s]


Loaded pre-trained ResNet-50 model.
Frozen all feature extractor layers.
Modified final fully connected layer to output 2 classes.
Skipping training due to data loading error.

Model training complete. The 'trained_model' variable holds the best model weights.
