<a href="https://colab.research.google.com/github/amelft81/ASDEEG/blob/main/Complete_EEG_to_ASD_Prediction_Pipeline_(Simulated_Data).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import numpy as np
from collections import Counter
import os
import copy # For deep copying model state for early stopping
import matplotlib.pyplot as plt
from scipy.interpolate import griddata
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix

# --- 1. Configuration Parameters ---
# Simulated-EEG settings (carried over from eeg_dataset_generator).
NUM_ELECTRODES = 116       # electrodes remaining after exclusion, per the paper
SAMPLING_RATE = 500        # Hz, per the paper
SEGMENT_LENGTH_SEC = 1     # segment duration in seconds, per the paper
NUM_SAMPLES_PER_SEGMENT = SEGMENT_LENGTH_SEC * SAMPLING_RATE  # 500 samples per segment
IMAGE_SIZE = 224           # images are 224x224 pixels (ResNet-50 input size)
# Frequency bands in Hz, stored as (low, high); insertion order is theta, alpha, beta.
FREQ_BANDS = dict(theta=(4, 7), alpha=(8, 13), beta=(13, 30))
# Target share of ASD samples before oversampling (paper reports roughly 81% NON-ASD / 19% ASD).
ASD_RATIO = 0.19
# Number of samples generated before oversampling (kept small for demonstration).
INITIAL_TOTAL_SAMPLES = 1000

# Classifier and optimisation settings (carried over from resnet_eeg_classifier).
NUM_CLASSES = 2            # label 1 = ASD, label 0 = NON-ASD
BATCH_SIZE = 100           # per the paper
LEARNING_RATE = 0.001      # per the paper
NUM_EPOCHS = 100           # upper bound; early stopping usually ends training sooner
PATIENCE = 10              # epochs without validation improvement tolerated by early stopping

# Per-channel ImageNet statistics expected by the pre-trained backbone.
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]

# Prefer the GPU when one is visible to PyTorch; fall back to the CPU otherwise.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {DEVICE}")

# --- 2. Simulated EEG Data Generation and Image Transformation Functions ---

def generate_electrode_positions(num_electrodes, radius=0.5):
    """
    Build simulated 2D electrode coordinates: a jittered ring plus one centre point.

    Args:
        num_electrodes: Number of electrodes spread evenly around the ring.
        radius: Ring radius before jitter is applied.

    Returns:
        Float array of shape (num_electrodes + 1, 2) holding (x, y) pairs;
        the final row is the extra central electrode at (0, 0).
    """
    ring_angles = np.linspace(0, 2 * np.pi, num_electrodes, endpoint=False)
    # Draw the x jitter before the y jitter so the random stream is consumed in a fixed order.
    jitter_x = np.random.normal(0, 0.05, num_electrodes)
    jitter_y = np.random.normal(0, 0.05, num_electrodes)
    ring = np.column_stack((
        radius * np.cos(ring_angles) + jitter_x,
        radius * np.sin(ring_angles) + jitter_y,
    ))
    centre = np.zeros((1, 2))
    return np.vstack((ring, centre))

def generate_eeg_signal(num_samples, sampling_rate, freq_bands, is_asd=False):
    """
    Simulate one electrode's EEG trace as a sum of band sinusoids plus Gaussian noise.

    For every band a sinusoid at the band's centre frequency is added with a
    random amplitude in [0.5, 1.5) and a random phase. When ``is_asd`` is true
    the 'theta' amplitude is scaled by 1.2, the 'beta' amplitude by 0.8 and the
    noise level by 1.2 (a deliberately simplified ASD model).

    Args:
        num_samples: Number of time samples to produce.
        sampling_rate: Sampling frequency in Hz.
        freq_bands: Mapping of band name to (low_hz, high_hz).
        is_asd: Whether to apply the simulated ASD alterations.

    Returns:
        1D float array of length ``num_samples``.
    """
    duration_sec = num_samples / sampling_rate
    t = np.linspace(0, duration_sec, num_samples, endpoint=False)

    # Per-band amplitude gains applied only to simulated ASD signals.
    asd_gain = {'theta': 1.2, 'beta': 0.8} if is_asd else {}

    waveform = np.zeros(num_samples)
    for name, (f_low, f_high) in freq_bands.items():
        center_f = (f_low + f_high) / 2
        # Amplitude is drawn before phase; the order matters for seeded reproducibility.
        amp = np.random.uniform(0.5, 1.5) * asd_gain.get(name, 1.0)
        phase = np.random.uniform(0, 2 * np.pi)
        waveform = waveform + amp * np.sin(2 * np.pi * center_f * t + phase)

    # Additive white noise, slightly stronger for simulated ASD signals.
    noise_scale = np.random.uniform(0.1, 0.5) * (1.2 if is_asd else 1.0)
    return waveform + noise_scale * np.random.randn(num_samples)

def get_band_power(signal, sampling_rate, freq_range):
    """
    Compute the mean spectral power of ``signal`` inside a frequency band.

    Power is the squared magnitude of the FFT coefficients, averaged over all
    bins whose frequency lies in the inclusive range ``[min_freq, max_freq]``.

    Args:
        signal: 1D sequence of time samples.
        sampling_rate: Sampling frequency in Hz.
        freq_range: ``(min_freq, max_freq)`` band limits in Hz.

    Returns:
        The mean power in the band, or ``0.0`` when no FFT bin falls inside the
        band (e.g. the signal is too short to resolve it).
    """
    spectrum = np.fft.fft(signal)
    freqs = np.fft.fftfreq(len(signal), 1 / sampling_rate)

    min_freq, max_freq = freq_range
    in_band = (freqs >= min_freq) & (freqs <= max_freq)

    # Averaging an empty selection would yield NaN (plus a RuntimeWarning) and
    # poison every downstream computation; report zero power instead.
    if not in_band.any():
        return 0.0

    return np.mean(np.abs(spectrum[in_band]) ** 2)

def normalize_band_to_uint8(arr):
    """
    Min-max rescale an array to 0-255 and cast it to uint8.

    Values are truncated (not rounded) by the integer cast. A constant array
    has no range to stretch, so it is mapped to mid-grey (128) everywhere.

    Args:
        arr: Numeric numpy array.

    Returns:
        uint8 array with the same shape as ``arr``.
    """
    lowest, highest = np.min(arr), np.max(arr)
    if highest != lowest:
        stretched = (arr - lowest) / (highest - lowest) * 255
        return stretched.astype(np.uint8)
    # Flat input: avoid dividing by zero and return uniform grey.
    return np.full(arr.shape, 128, dtype=np.uint8)

def create_eeg_image(electrode_positions, band_powers, image_size):
    """
    Render per-electrode band powers as an RGB image via 2D interpolation.

    Each band is interpolated (cubic) onto an ``image_size`` x ``image_size``
    grid spanning [-1, 1] on both axes; grid points outside the convex hull of
    the electrodes receive 0. Every band is then independently rescaled to
    uint8 and used as one colour channel: theta -> red, alpha -> green,
    beta -> blue.

    Args:
        electrode_positions: Array of (x, y) electrode coordinates.
        band_powers: Mapping with keys 'theta', 'alpha' and 'beta', each holding
            one power value per electrode.
        image_size: Side length of the output image in pixels.

    Returns:
        uint8 array of shape (image_size, image_size, 3).
    """
    n_points = complex(0, image_size)  # complex step => "this many points" in np.mgrid
    grid_x, grid_y = np.mgrid[-1:1:n_points, -1:1:n_points]

    # Channel order is fixed: R = theta, G = alpha, B = beta.
    channels = []
    for band in ('theta', 'alpha', 'beta'):
        surface = griddata(
            electrode_positions,
            band_powers[band],
            (grid_x, grid_y),
            method='cubic',
            fill_value=0,
        )
        channels.append(normalize_band_to_uint8(surface))

    # Stack the channels last to obtain an (H, W, C) image.
    return np.stack(channels, axis=-1)

def generate_and_oversample_dataset(num_samples_total, asd_ratio, image_size, num_electrodes, sampling_rate, freq_bands):
    """
    Generate a synthetic EEG image dataset, then oversample the minority class.

    For each sample a label is drawn (ASD with probability ``asd_ratio``), one
    signal of NUM_SAMPLES_PER_SEGMENT samples is simulated per electrode
    position, band powers are computed per electrode, and the powers are
    rendered into an RGB image. The classes are then balanced with imblearn's
    RandomOverSampler.

    NOTE(review): oversampling duplicates minority-class images here, before any
    train/validation split. Splitting the returned arrays afterwards (as
    prepare_data_loaders does) can place copies of the same image in both
    splits, which may inflate validation metrics.

    Args:
        num_samples_total: Number of samples to generate before oversampling
            (must be at least 1).
        asd_ratio: Probability that a generated sample is labelled ASD (1).
        image_size: Side length in pixels of each generated image.
        num_electrodes: Number of ring electrodes passed to
            generate_electrode_positions.
        sampling_rate: Sampling frequency in Hz.
        freq_bands: Mapping of band name to (low_hz, high_hz).

    Returns:
        Tuple ``(X_resampled, y_resampled)``: images shaped (N, H, W, C) and
        their labels, as numpy arrays.

    Raises:
        ValueError: If ``num_samples_total`` is less than 1.
    """
    # Imported first so a missing imbalanced-learn install fails before the
    # (slow) image generation rather than after it.
    from imblearn.over_sampling import RandomOverSampler

    if num_samples_total < 1:
        raise ValueError(f"num_samples_total must be at least 1, got {num_samples_total}")

    electrode_positions = generate_electrode_positions(num_electrodes)
    # One signal per returned position (the generator may add extra electrodes,
    # e.g. a central one), so derive the count from the positions themselves.
    num_positions = len(electrode_positions)

    # Report progress about every 10%. Clamp to 1 so totals below 10 do not
    # trigger a modulo-by-zero.
    progress_step = max(1, num_samples_total // 10)

    images = []
    labels = []

    print(f"Generating {num_samples_total} initial synthetic EEG images...")
    for i in range(num_samples_total):
        is_asd = np.random.rand() < asd_ratio
        labels.append(1 if is_asd else 0)  # 1 for ASD, 0 for NON-ASD

        # Simulate one signal per electrode position.
        electrode_signals = [
            generate_eeg_signal(NUM_SAMPLES_PER_SEGMENT, sampling_rate, freq_bands, is_asd)
            for _ in range(num_positions)
        ]

        # Per-band list of powers, ordered like electrode_positions.
        band_powers_per_electrode = {
            band_name: [get_band_power(signal, sampling_rate, freq_range) for signal in electrode_signals]
            for band_name, freq_range in freq_bands.items()
        }

        images.append(create_eeg_image(electrode_positions, band_powers_per_electrode, image_size))

        if (i + 1) % progress_step == 0:
            print(f"  Generated {i + 1}/{num_samples_total} samples...")

    print("Initial dataset generation complete.")

    X = np.array(images)
    y = np.array(labels)

    print(f"Original dataset shape: {Counter(y)}")

    # Apply RandomOverSampler to balance the dataset
    print("Applying RandomOverSampler to balance the dataset...")
    n_samples, h, w, c = X.shape
    X_flat = X.reshape(n_samples, -1)  # the sampler expects 2D (samples, features)

    ros = RandomOverSampler(random_state=42)
    X_resampled_flat, y_resampled = ros.fit_resample(X_flat, y)

    X_resampled = X_resampled_flat.reshape(-1, h, w, c)  # restore image layout
    print(f"Resampled dataset shape: {Counter(y_resampled)}")
    print("Oversampling complete.")

    return X_resampled, y_resampled

# --- 3. Custom PyTorch Dataset Class ---
class EEGImageDataset(Dataset):
    """
    PyTorch Dataset pairing EEG images with their class labels.

    Items are returned as ``(image, label)``. When a transform is supplied it
    is applied to the image only; the label is passed through untouched.
    """

    def __init__(self, images, labels, transform=None):
        # Indexable containers; images[i] corresponds to labels[i].
        self.images, self.labels = images, labels
        # Optional callable applied to each image on access.
        self.transform = transform

    def __len__(self):
        # The dataset size is defined by the number of images.
        return len(self.images)

    def __getitem__(self, idx):
        sample, target = self.images[idx], self.labels[idx]
        return (self.transform(sample) if self.transform else sample), target

# --- 4. Model Definition: Pre-trained ResNet-50 ---
def get_resnet_model(num_classes, freeze_features=True):
    """
    Build a ResNet-50 initialised with ImageNet weights and a new output layer.

    Args:
        num_classes: Number of outputs of the replacement fully connected layer.
        freeze_features: When true, gradients are disabled for every pre-trained
            parameter. The replacement layer is created afterwards, so it
            remains trainable.

    Returns:
        The modified torchvision ResNet-50 model.
    """
    pretrained = models.ResNet50_Weights.IMAGENET1K_V1
    model = models.resnet50(weights=pretrained)
    print("Loaded pre-trained ResNet-50 model.")

    if freeze_features:
        # Turns off requires_grad on all parameters currently in the network.
        model.requires_grad_(False)
        print("Frozen all feature extractor layers.")

    # Swap the ImageNet head for one sized to our classes.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    print(f"Modified final fully connected layer to output {num_classes} classes.")

    return model

# --- 5. Data Loading and Preprocessing (using generated data) ---
def prepare_data_loaders(images, labels, batch_size, norm_mean, norm_std):
    """
    Split the dataset 80/20 and wrap both parts in PyTorch DataLoaders.

    The split is stratified on ``labels`` with a fixed random state. Training
    batches are drawn with a WeightedRandomSampler (with replacement) whose
    per-sample weight is inversely proportional to its class frequency in the
    training split; validation batches are served in order.

    Args:
        images: Array of images accepted by ``transforms.ToPILImage``.
        labels: Array of class labels aligned with ``images``.
        batch_size: Batch size for both loaders.
        norm_mean: Per-channel mean used for normalisation.
        norm_std: Per-channel standard deviation used for normalisation.

    Returns:
        Tuple ``(train_loader, val_loader)``.
    """
    # numpy image -> PIL -> float tensor in [0, 1] laid out (C, H, W) -> normalised.
    pipeline = transforms.Compose([
        transforms.ToPILImage(),
        transforms.ToTensor(),
        transforms.Normalize(norm_mean, norm_std),
    ])

    # 80/20 stratified split, as in the paper's experiments.
    split_parts = train_test_split(
        images, labels, test_size=0.2, random_state=42, stratify=labels
    )
    train_images, val_images, train_labels, val_labels = split_parts

    train_label_counts = Counter(train_labels)
    print(f"\nTraining set distribution after split: {train_label_counts}")
    print(f"Validation set distribution after split: {Counter(val_labels)}")

    train_dataset = EEGImageDataset(train_images, train_labels, transform=pipeline)
    val_dataset = EEGImageDataset(val_images, val_labels, transform=pipeline)

    # The data is already oversampled, so these weights end up (near) uniform;
    # the sampler is kept because the paper explicitly uses weighted sampling
    # with replacement to obtain balanced mini-batches.
    total_train = sum(train_label_counts.values())
    per_sample_weights = [total_train / train_label_counts[lbl] for lbl in train_labels]
    balanced_sampler = WeightedRandomSampler(per_sample_weights, total_train, replacement=True)
    print("WeightedRandomSampler initialized for training data loaders.")

    train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=balanced_sampler, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

    return train_loader, val_loader

# --- 6. Training Function with Early Stopping ---
def _run_epoch(model, data_loader, criterion, device, optimizer=None):
    """
    Run one full pass over ``data_loader`` and return ``(mean_loss, accuracy_pct)``.

    With an ``optimizer`` the model is put in training mode and updated after
    every batch; without one the model is put in eval mode and gradients are
    disabled. The loss is averaged per sample; accuracy is in percent.

    Raises:
        ValueError: If the loader yields no samples.
    """
    is_training = optimizer is not None
    model.train(is_training)

    loss_sum = 0.0
    num_correct = 0
    num_seen = 0

    with torch.set_grad_enabled(is_training):
        for inputs, labels in data_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            if is_training:
                optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            if is_training:
                loss.backward()
                optimizer.step()

            # criterion returns a batch mean; weight it by batch size so the
            # epoch figure is a true per-sample average even with a ragged last batch.
            loss_sum += loss.item() * inputs.size(0)
            num_seen += labels.size(0)
            num_correct += (outputs.argmax(dim=1) == labels).sum().item()

    if num_seen == 0:
        phase = "training" if is_training else "validation"
        raise ValueError(f"The {phase} data loader yielded no samples.")

    return loss_sum / num_seen, num_correct / num_seen * 100


def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs, patience, device):
    """
    Train ``model`` with early stopping on the validation loss.

    After every epoch the validation loss is compared with the best seen so
    far. A strictly lower loss snapshots the weights and resets the patience
    counter; otherwise the counter is incremented and training stops once it
    reaches ``patience``. The best snapshot is always restored before returning.

    Args:
        model: Network to train; moved to ``device`` in place.
        train_loader: Iterable of ``(inputs, labels)`` training batches.
        val_loader: Iterable of ``(inputs, labels)`` validation batches.
        criterion: Loss callable taking ``(outputs, labels)``.
        optimizer: Optimizer updating the model's parameters.
        num_epochs: Maximum number of epochs.
        patience: Consecutive non-improving epochs allowed before stopping.
        device: Device on which to run the computation.

    Returns:
        The same ``model`` object, loaded with the best-validation weights.

    Raises:
        ValueError: If either loader yields no samples.
    """
    best_val_loss = float('inf')
    epochs_no_improve = 0
    best_model_wts = copy.deepcopy(model.state_dict())

    model.to(device)

    print("\nStarting model training...")
    for epoch in range(num_epochs):
        epoch_train_loss, epoch_train_accuracy = _run_epoch(
            model, train_loader, criterion, device, optimizer
        )
        epoch_val_loss, epoch_val_accuracy = _run_epoch(
            model, val_loader, criterion, device
        )

        print(f"Epoch {epoch+1}/{num_epochs}: "
              f"Train Loss: {epoch_train_loss:.4f}, Train Acc: {epoch_train_accuracy:.2f}% | "
              f"Val Loss: {epoch_val_loss:.4f}, Val Acc: {epoch_val_accuracy:.2f}%")

        # --- Early Stopping Logic ---
        if epoch_val_loss < best_val_loss:
            best_val_loss = epoch_val_loss
            # Deep copy: state_dict() tensors alias the live parameters.
            best_model_wts = copy.deepcopy(model.state_dict())
            epochs_no_improve = 0
            continue

        epochs_no_improve += 1
        if epochs_no_improve >= patience:
            print(f"Early stopping triggered after {epoch+1} epochs due to no improvement in validation loss.")
            break
    else:
        # Loop ran to completion without an early stop.
        print("Training finished (max epochs reached).")

    model.load_state_dict(best_model_wts)
    return model

# --- 7. Evaluation Function ---
def evaluate_model(model, data_loader, device, dataset_name="Test"):
    """
    Evaluate a binary classifier on ``data_loader`` and print a metrics report.

    Args:
        model: Trained network; switched to eval mode.
        data_loader: Iterable of ``(inputs, labels)`` batches with labels in {0, 1}.
        device: Device on which inference is run.
        dataset_name: Name shown in the report header.

    Returns:
        Dict with keys 'accuracy', 'precision', 'recall', 'f1' and
        'confusion_matrix' holding the printed values.
    """
    model.eval()  # Set model to evaluation mode
    all_labels = []
    all_predictions = []

    with torch.no_grad():
        for inputs, labels in data_loader:
            outputs = model(inputs.to(device))
            predicted = outputs.argmax(dim=1)

            # Labels are only compared on the host, so they never need to
            # visit the device.
            all_labels.extend(labels.cpu().numpy())
            all_predictions.extend(predicted.cpu().numpy())

    accuracy = accuracy_score(all_labels, all_predictions)
    # zero_division=0 keeps the library's fallback value for undefined metrics
    # (e.g. no positive predictions) without emitting a warning.
    precision = precision_score(all_labels, all_predictions, average='binary', zero_division=0)
    recall = recall_score(all_labels, all_predictions, average='binary', zero_division=0)
    f1 = f1_score(all_labels, all_predictions, average='binary', zero_division=0)
    # Pin both classes so the matrix is always 2x2 and matches the legend
    # below, even if one class is absent from this evaluation set.
    cm = confusion_matrix(all_labels, all_predictions, labels=[0, 1])

    print(f"\n--- {dataset_name} Set Evaluation ---")
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1-Score: {f1:.4f}")
    print("Confusion Matrix:")
    print(cm)
    print("  (Rows: True Labels, Columns: Predicted Labels)")
    print("  [[True Negative (NON-ASD predicted NON-ASD), False Positive (NON-ASD predicted ASD)]")
    print("   [False Negative (ASD predicted NON-ASD), True Positive (ASD predicted ASD)]]")

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'confusion_matrix': cm,
    }


# --- Main Execution ---
if __name__ == "__main__":
    # Step 1: synthesise the EEG images and balance the two classes.
    X_resampled, y_resampled = generate_and_oversample_dataset(
        INITIAL_TOTAL_SAMPLES,
        ASD_RATIO,
        IMAGE_SIZE,
        NUM_ELECTRODES,
        SAMPLING_RATE,
        FREQ_BANDS,
    )

    # Step 2: split into training/validation sets and wrap them in DataLoaders.
    train_loader, val_loader = prepare_data_loaders(
        images=X_resampled,
        labels=y_resampled,
        batch_size=BATCH_SIZE,
        norm_mean=NORM_MEAN,
        norm_std=NORM_STD,
    )

    # Step 3: pre-trained ResNet-50 with a frozen backbone and a fresh output layer.
    model = get_resnet_model(num_classes=NUM_CLASSES, freeze_features=True)

    # Step 4: loss function and optimizer.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # Step 5: train with early stopping.
    trained_model = train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        criterion=criterion,
        optimizer=optimizer,
        num_epochs=NUM_EPOCHS,
        patience=PATIENCE,
        device=DEVICE,
    )

    print("\n--- Training Complete ---")

    # Step 6: report metrics on the validation split.
    evaluate_model(trained_model, data_loader=val_loader, device=DEVICE, dataset_name="Validation")

    # Optional sanity check: uncomment to preview a few of the generated images.
    # (The same preview is part of the eeg_dataset_generator artifact.)
    # print("\nDisplaying a few generated EEG images from the resampled dataset...")
    # fig, axes = plt.subplots(2, 4, figsize=(12, 6))
    # axes = axes.flatten()
    # for i in range(min(8, len(X_resampled))):
    #     ax = axes[i]
    #     ax.imshow(X_resampled[i])
    #     ax.set_title(f"Label: {'ASD' if y_resampled[i] == 1 else 'NON-ASD'}")
    #     ax.axis('off')
    # plt.tight_layout()
    # plt.show()