<a href="https://colab.research.google.com/github/Ajinkya-18/NeuroVision/blob/main/neurovision.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# NeuroVision

## Dataset Preparation

In [None]:
# Mount Google Drive so the dataset and mapping files under MyDrive are reachable.
# os/shutil are imported here because the notebook's main import cell comes later.
import os
import shutil

from google.colab import drive
drive.mount('/content/drive')

# Optionally copy the per-class CSV folders from Drive to the VM's local disk.
# Local reads are faster, but copying thousands of small files over the Drive mount
# is SLOW. The reorganize_dataset call below reads straight from Drive, so the copy
# is disabled by default; flip the flag to enable it.
COPY_DATASET_TO_LOCAL = False
DRIVE_DATASET_DIR = "/content/drive/MyDrive/NeuroVision/Segregated_Dataset"
LOCAL_DATASET_DIR = "/content/eeg_dataset"

if COPY_DATASET_TO_LOCAL:
    print("Starting to copy dataset folder from Drive to local storage...")
    print("This may take a significant amount of time, please be patient.")
    shutil.copytree(DRIVE_DATASET_DIR, LOCAL_DATASET_DIR, dirs_exist_ok=True)
    print("Copying complete!")
else:
    print(f"Skipping local copy; the dataset is read directly from {DRIVE_DATASET_DIR}")

Mounted at /content/drive
Starting to copy dataset folder from Drive to local storage...
This may take a significant amount of time, please be patient.
Copying complete!


In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
import shutil

In [None]:
# function to read the dir contents of dataset folder and segregate them
# into n separate classes.
def create_dataset_folders(metadata_file: str, csv_dir: str, output_dir: str, verbose: bool = False) -> int:
    """Copy every EEG CSV in ``csv_dir`` into a per-class sub-folder of ``output_dir``.

    A file's class id is the 4th ``_``-separated token of its name (index 3). The id
    is looked up in ``metadata_file`` to obtain a human-readable folder name.

    Args:
        metadata_file: Tab-separated text file. Each usable line has at least three
            fields: a comma-separated label list (the first label becomes the folder
            name), an ignored field, and the class id. Lines with fewer than three
            fields are skipped. Extra trailing fields are ignored.
        csv_dir: Directory containing the raw ``.csv`` recordings.
        output_dir: Destination root. One sub-folder per class is created on demand.
        verbose: If True, print every copied file with a running count.

    Returns:
        The number of files copied.
    """
    class_id_to_folder = {}

    with open(metadata_file, 'r', encoding='utf-8') as f:
        for line in f:
            parts = line.strip().split('\t')
            if len(parts) < 3:
                continue
            # Index rather than unpack, so lines with extra columns don't raise ValueError.
            label_str, class_id = parts[0], parts[2].strip()
            class_id_to_folder[class_id] = label_str.split(',')[0].strip()

    count = 0
    for filename in os.listdir(csv_dir):
        if not filename.endswith('.csv'):
            continue

        name_parts = filename.split('_')
        if len(name_parts) < 4:
            print(f'Unexpected file name (no class id): {filename}')
            continue
        class_id = name_parts[3]

        folder_name = class_id_to_folder.get(class_id)
        if not folder_name:
            print(f'Unknown class id: {class_id}')
            continue

        # Labels may contain path separators; neutralise them so each label stays one folder.
        safe_folder = folder_name.replace('/', '_').replace('\\', '_').strip()
        dest_folder = os.path.join(output_dir, safe_folder)
        os.makedirs(dest_folder, exist_ok=True)

        shutil.copy(os.path.join(csv_dir, filename), os.path.join(dest_folder, filename))
        count += 1
        if verbose:
            print(f'{count}: {filename} -> {safe_folder}')

    return count


In [None]:
# create_dataset_folders('../data/WordReport-v1.04.txt',
#                        '../data/MindBigData-Imagenet',
#                        '../data/Segregated_Dataset')

In [None]:
import shutil
import json
import os

def reorganize_dataset(mapping_file, src_root, dst_root, move=False):
    """Merge fine-grained class folders into super-class folders.

    Args:
        mapping_file: JSON file mapping ``super_class -> [sub_class, ...]``.
        src_root: Directory holding one folder per sub-class.
        dst_root: Directory that receives one folder per super-class.
        move: If True, move entries out of ``src_root`` (destructive). Otherwise copy
            files with metadata via ``shutil.copy2``.

    Missing sub-class folders are reported and skipped. In copy mode, nested
    directories are reported and skipped because ``copy2`` cannot copy them. Files
    with the same name from different sub-classes overwrite each other.
    """
    with open(mapping_file, 'r', encoding='utf-8') as f:
        mapping = json.load(f)

    # Create dst_root itself (not just its parent), so it exists even for an empty mapping.
    os.makedirs(dst_root, exist_ok=True)

    for super_class, sub_classes in mapping.items():
        super_cls_dir = os.path.join(dst_root, super_class)
        os.makedirs(super_cls_dir, exist_ok=True)

        for sub_class in sub_classes:
            sub_cls_dir = os.path.join(src_root, sub_class)
            if not os.path.exists(sub_cls_dir):
                print(f"[Warning] Sub-class folder not found: {sub_cls_dir}")
                continue

            for file_name in os.listdir(sub_cls_dir):
                src_file = os.path.join(sub_cls_dir, file_name)
                dst_file = os.path.join(super_cls_dir, file_name)

                if move:
                    shutil.move(src_file, dst_file)
                elif os.path.isfile(src_file):
                    shutil.copy2(src_file, dst_file)
                else:
                    print(f"[Warning] Skipping non-file entry: {src_file}")

            print(f"[OK] {'Moved' if move else 'Copied'} {sub_class} -> {super_class}")
    print("Dataset reorganization complete!")

In [None]:
# Group the per-label folders on Drive into super-class folders on local disk.
NEUROVISION_DRIVE_DIR = '/content/drive/MyDrive/NeuroVision'

reorganize_dataset(
    NEUROVISION_DRIVE_DIR + '/meta-learner-class-mapping-v2.json',
    NEUROVISION_DRIVE_DIR + '/Segregated_Dataset',
    '/content/meta_learner_dataset',
    move=False,
)

[OK] Copied affenpinscher -> dogs_n_cats
[OK] Copied Afghan hound -> dogs_n_cats
[OK] Copied Airedale -> dogs_n_cats
[OK] Copied American Staffordshire terrier -> dogs_n_cats
[OK] Copied Appenzeller -> dogs_n_cats
[OK] Copied Australian terrier -> dogs_n_cats
[OK] Copied basenji -> dogs_n_cats
[OK] Copied basset -> dogs_n_cats
[OK] Copied beagle -> dogs_n_cats
[OK] Copied Bedlington terrier -> dogs_n_cats
[OK] Copied Bernese mountain dog -> dogs_n_cats
[OK] Copied black-and-tan coonhound -> dogs_n_cats
[OK] Copied Blenheim spaniel -> dogs_n_cats
[OK] Copied bloodhound -> dogs_n_cats
[OK] Copied bluetick -> dogs_n_cats
[OK] Copied Border collie -> dogs_n_cats
[OK] Copied Border terrier -> dogs_n_cats
[OK] Copied borzoi -> dogs_n_cats
[OK] Copied Boston bull -> dogs_n_cats
[OK] Copied Bouvier des Flandres -> dogs_n_cats
[OK] Copied boxer -> dogs_n_cats
[OK] Copied Brabancon griffon -> dogs_n_cats
[OK] Copied briard -> dogs_n_cats
[OK] Copied Brittany spaniel -> dogs_n_cats
[OK] Copied bu

## Dataset Processing for PyTorch

In [None]:
import torch
import os
import pandas as pd
from torch.utils.data import DataLoader, Dataset, Subset
from sklearn.model_selection import train_test_split

In [None]:
def simple_collate_fn(batch):
  """Pad a batch of variable-length sequences and drop the labels.

  Args:
    batch: iterable of ``(sequence, label)`` pairs. Each sequence is a tensor whose
      first dimension is its (variable) length.

  Returns:
    ``(padded, None)``. ``padded`` has shape ``(batch, max_len, *rest)`` and is
    zero-padded at the end of each sequence.
  """
  # Imported locally: the cell-level import of pad_sequence lives further down the
  # notebook, so this function raised NameError if called before that cell ran.
  from torch.nn.utils.rnn import pad_sequence

  sequences = [sequence for sequence, _ in batch]
  return pad_sequence(sequences, batch_first=True, padding_value=0.0), None

In [None]:
def get_dataset_stats(dataset):
  """Per-channel mean and (population) std of every value produced by ``dataset``.

  Each item is expected to be ``(spectrogram, label)``, with ``spectrogram`` indexed
  as ``(channel, H, W)``. Statistics are pooled over all items and all H*W cells.

  Args:
    dataset: indexable dataset with ``__len__``, ``__getitem__`` and a
      ``num_channels`` attribute.

  Returns:
    ``(mean, std)``: float32 tensors of shape ``(num_channels,)``.

  Raises:
    ValueError: if the dataset is empty.
  """
  try:
    from tqdm import tqdm
  except ImportError:  # the progress bar is optional
    def tqdm(iterable, **kwargs):
      return iterable

  if len(dataset) == 0:
    raise ValueError("Cannot compute statistics of an empty dataset")

  num_channels = dataset.num_channels
  # Accumulate in float64: E[x^2] - E[x]^2 over millions of float32 values suffers
  # catastrophic cancellation and can even go negative (-> NaN std).
  sum_ = torch.zeros(num_channels, dtype=torch.float64)
  sum_sq = torch.zeros(num_channels, dtype=torch.float64)
  count = 0

  print("Calculating dataset stats...")

  for idx in tqdm(range(len(dataset))):
    spectrogram, _ = dataset[idx]
    values = spectrogram.to(torch.float64)
    sum_ += values.sum(dim=(1, 2))
    sum_sq += (values ** 2).sum(dim=(1, 2))
    count += values.shape[1] * values.shape[2]

  mean = sum_ / count
  var = (sum_sq / count - mean ** 2).clamp_min(0.0)
  # Return float32 so normalising float32 spectrograms does not upcast them to float64.
  return mean.float(), var.sqrt().float()


In [None]:
class EEGDataset(Dataset):
    """Raw EEG time-series dataset backed by one CSV file per sample.

    Each CSV is read with ``header=None, index_col=0``, so the first column becomes
    the row index and is dropped from the values. The remaining matrix is returned as
    a float32 tensor shaped ``(time, num_channels)``. Files stored as
    ``(num_channels, time)`` are transposed automatically.

    Args:
        root_dir: dataset root (kept for reference; paths in ``samples`` are used as-is).
        samples: list of ``(file_path, label)`` pairs.
        num_channels: expected number of EEG channels.
        mean, std: optional per-channel statistics (broadcast over time) used for
            z-score normalisation.
        transform: optional callable applied to the ``(time, num_channels)`` tensor.
    """

    def __init__(self, root_dir, samples, num_channels=5, mean=None, std=None, transform=None):
        self.root_dir = root_dir
        self.samples = samples
        self.transform = transform
        self.num_channels = num_channels
        # Sorted list of the distinct labels (a list, not a dict, despite the name).
        # Sorted so the order is deterministic across runs.
        self.class_to_idx = sorted({label for _, label in self.samples})
        self.mean = mean
        self.std = std

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        file_path, label = self.samples[idx]

        df = pd.read_csv(file_path, header=None, index_col=0)
        eeg_data = torch.tensor(df.values, dtype=torch.float32)

        if eeg_data.shape[1] != self.num_channels:
            if eeg_data.shape[0] != self.num_channels:
                raise ValueError(f"File {file_path} has invalid shape: {eeg_data.shape}")
            eeg_data = eeg_data.T

        if self.mean is not None and self.std is not None:
            # Epsilon guards against a flat (zero-variance) channel, matching
            # EEGSpectrogramDataset's normalisation.
            eeg_data = (eeg_data - self.mean) / (self.std + 1e-9)

        if self.transform:
            eeg_data = self.transform(eeg_data)

        return eeg_data, label


In [None]:
def make_datasets(root_dir, val_ratio=0.3, random_state=42):
    """Build a stratified train/validation split of ``(csv_path, label)`` samples.

    Every sub-directory of ``root_dir`` is one class. Labels are assigned in sorted
    order of the directory names, so the class->index mapping is the same on every
    machine and run (``os.listdir`` order is filesystem-dependent). Files are also
    visited in sorted order, so ``random_state`` fully determines the split.
    Non-directory entries in ``root_dir`` are ignored.

    Args:
        root_dir: directory containing one folder of ``.csv`` files per class.
        val_ratio: fraction of samples placed in the validation split.
        random_state: seed passed to ``train_test_split``.

    Returns:
        ``(train_samples, val_samples)``, each a list of ``(path, label)`` tuples.
    """
    class_names = sorted(
        entry for entry in os.listdir(root_dir)
        if os.path.isdir(os.path.join(root_dir, entry))
    )
    class_to_idx = {cls: idx for idx, cls in enumerate(class_names)}

    all_samples = []
    for cls in class_names:
        cls_dir = os.path.join(root_dir, cls)
        for fname in sorted(os.listdir(cls_dir)):
            if fname.endswith('.csv'):
                all_samples.append((os.path.join(cls_dir, fname), class_to_idx[cls]))

    all_labels = [label for _, label in all_samples]
    train_idx, val_idx = train_test_split(
        list(range(len(all_samples))),
        test_size=val_ratio,
        random_state=random_state,
        stratify=all_labels,
    )

    train_samples = [all_samples[i] for i in train_idx]
    val_samples = [all_samples[i] for i in val_idx]
    return train_samples, val_samples

In [None]:
import torch
import pandas as pd
import numpy as np
from torch.utils.data import Dataset
from scipy import signal

In [None]:
class EEGSpectrogramDataset(Dataset):
  """EEG dataset that turns each CSV recording into per-channel log-spectrograms.

  Each CSV is read with ``header=None, index_col=0`` and arranged as
  ``(time, num_channels)``; channel-major files are transposed. For every channel,
  ``scipy.signal.spectrogram`` is computed with ``fs``/``nperseg``/``noverlap`` and
  compressed with ``log1p``. The result is a float32 tensor
  ``(num_channels, freq_bins, time_bins)``.

  Args:
    root_dir: dataset root (kept for reference; sample paths are used as-is).
    samples: list of ``(file_path, label)`` pairs.
    is_train: stored flag. Transforms are applied in both modes, so pass the
      train- or val-specific pipeline via ``transforms``.
    fs, nperseg, noverlap: spectrogram parameters.
    mean, std: optional per-channel statistics of shape ``(num_channels,)`` used for
      z-score normalisation before the transforms.
    transforms: optional callable applied to each ``(freq_bins, time_bins)`` channel
      independently. Outputs are concatenated along a leading channel axis.
    num_channels: expected number of EEG channels (default 5).
  """

  def __init__(self, root_dir, samples, is_train=False, fs=128, nperseg=64, noverlap=32,
               mean=None, std=None, transforms=None, num_channels=5):
    self.root_dir = root_dir
    self.samples = samples
    self.is_train = is_train
    # Sorted list of the distinct labels (a list, not a dict, despite the name).
    self.class_to_idx = sorted({label for _, label in self.samples})
    self.transform_params = {'fs': fs, 'nperseg': nperseg, 'noverlap': noverlap}
    self.mean = mean
    self.std = std
    self.num_channels = num_channels
    # Kept for backward compatibility; no fixed-size cropping is currently applied.
    self.target_freq_bins = 33
    self.target_time_bins = 10
    self.transforms = transforms

  def __len__(self):
    return len(self.samples)

  def _load_signal(self, file_path):
    """Read one CSV and return a ``(time, num_channels)`` numpy array."""
    eeg = pd.read_csv(file_path, header=None, index_col=0).values
    if eeg.shape[1] != self.num_channels:
      if eeg.shape[0] != self.num_channels:
        raise ValueError(f"File {file_path} has an invalid shape: {eeg.shape}")
      eeg = eeg.T
    return eeg

  def __getitem__(self, idx):
    file_path, label = self.samples[idx]
    eeg_data_1d = self._load_signal(file_path)

    channel_spectrograms = []
    for i in range(self.num_channels):
      _, _, Sxx = signal.spectrogram(eeg_data_1d[:, i], **self.transform_params)
      channel_spectrograms.append(np.log1p(Sxx))

    spectrogram = torch.tensor(np.array(channel_spectrograms), dtype=torch.float32)

    if self.mean is not None and self.std is not None:
      mean = self.mean.view(self.num_channels, 1, 1)
      std = self.std.view(self.num_channels, 1, 1)
      spectrogram = (spectrogram - mean) / (std + 1e-9)

    # Train and val previously had two identical branches; a single one behaves the same.
    if self.transforms:
      transformed = []
      for channel in spectrogram:
        out = self.transforms(channel)
        # A transform returning a 2-D map would otherwise be concatenated along the
        # frequency axis; give it an explicit channel axis first.
        if out.dim() == 2:
          out = out.unsqueeze(0)
        transformed.append(out)
      spectrogram = torch.cat(transformed, dim=0)

    return spectrogram, label


In [None]:
from torch.nn.utils.rnn import pad_sequence

def collate_fn(batch, min_length=4):
    """Pad variable-length sequences, dropping those shorter than ``min_length``.

    Args:
        batch: list of ``(sequence, label)`` pairs. ``sequence`` is a tensor whose
            first dimension is its time length.
        min_length: sequences with fewer time steps are discarded. Default 4, the
            previously hard-coded value (presumably so the 4x temporal pooling in the
            1-D models keeps at least one step; TODO confirm).

    Returns:
        ``(padded, labels, lengths, mask)``:
        - ``padded`` has shape ``(batch, max_len, *rest)`` and is zero-padded.
        - ``labels`` is a long tensor.
        - ``lengths`` holds the true lengths.
        - ``mask[b, t]`` is True for real (non-padded) steps.
        Returns ``(None, None, None, None)`` when nothing survives the filter.
    """
    valid_batch = [item for item in batch if item[0].shape[0] >= min_length]

    if not valid_batch:
        return None, None, None, None

    sequences, labels = zip(*valid_batch)
    lengths = torch.tensor([seq.shape[0] for seq in sequences], dtype=torch.long)
    padded_seqs = pad_sequence(sequences, batch_first=True, padding_value=0.0)
    mask = torch.arange(padded_seqs.shape[1])[None, :] < lengths[:, None]

    return padded_seqs, torch.tensor(labels, dtype=torch.long), lengths, mask

In [None]:
def create_sampler(dataset):
  """Build a class-balanced ``WeightedRandomSampler`` for ``dataset``.

  Each sample is weighted by ``1 / count(label)``, so every class is drawn equally
  often in expectation. ``len(dataset.samples)`` draws are made per epoch, with
  replacement.

  Args:
    dataset: object with a ``samples`` list of ``(path, label)`` pairs and
      non-negative integer labels.

  Raises:
    ValueError: if ``dataset.samples`` is empty.
  """
  from collections import Counter
  from torch.utils.data import WeightedRandomSampler

  all_labels = [label for _, label in dataset.samples]
  if not all_labels:
    raise ValueError("Cannot build a sampler for an empty dataset")

  class_counts = Counter(all_labels)
  # Size the table by the largest label, not by the number of distinct labels:
  # if a class is absent from this split, indexing by label would go out of range.
  num_classes = max(all_labels) + 1
  class_weights = torch.zeros(num_classes, dtype=torch.double)
  for class_idx, count in class_counts.items():
    class_weights[class_idx] = 1.0 / count

  sample_weights = class_weights[torch.tensor(all_labels, dtype=torch.long)]

  return WeightedRandomSampler(
      weights=sample_weights,
      num_samples=len(all_labels),
      replacement=True
  )

In [None]:
# Super-class dataset written by the reorganize_dataset call above (one folder per class).
root_dir = os.path.join('/content', 'meta_learner_dataset')

In [None]:
# Number of class folders; stray files in root_dir are not counted, matching make_datasets.
sum(os.path.isdir(os.path.join(root_dir, entry)) for entry in os.listdir(root_dir))

10

In [None]:
train_samples, val_samples = make_datasets(root_dir, val_ratio=0.3, random_state=42)

In [None]:
# Plain datasets (no normalisation, no transforms). train_dataset feeds the statistics
# and sampler cells below. is_train keeps its default: the extra positional `5` that
# used to be here landed in is_train, not a channel count.
train_dataset = EEGSpectrogramDataset(root_dir, train_samples)
val_dataset = EEGSpectrogramDataset(root_dir, val_samples)

In [None]:
tuple(len(ds) for ds in (train_dataset, val_dataset))

(9424, 4039)

In [None]:
from torchvision import transforms
# Statistics from the training split only, so no validation information leaks in.
train_means, train_stds = get_dataset_stats(dataset=train_dataset)

Calculating dataset stats...


100%|██████████| 9424/9424 [01:32<00:00, 101.80it/s]


In [None]:
train_sampler = create_sampler(dataset=train_dataset)

In [None]:
from torchvision.transforms.v2 import GaussianNoise

In [None]:
# The spectrograms reaching these transforms are z-scored floats (negative and > 1).
# ToPILImage would multiply a float tensor by 255 and cast it to uint8, mangling those
# values, and ToTensor would then squash them into [0, 1]. Resize and RandomErasing
# accept tensors directly, so the whole pipeline stays in float.
def _add_channel_dim(x):
    """(freq, time) -> (1, freq, time), the layout the tensor transforms expect."""
    return x.unsqueeze(0) if x.dim() == 2 else x

train_transforms = transforms.Compose([
    transforms.Lambda(_add_channel_dim),
    transforms.Resize((128, 128), antialias=True),
    # clip=False: GaussianNoise clips to [0, 1] by default, which would destroy z-scores.
    GaussianNoise(clip=False),
    transforms.RandomErasing(p=0.5, scale=(0.02, 0.1), ratio=(0.1, 10.0), value=0)
])

val_transforms = transforms.Compose([
    transforms.Lambda(_add_channel_dim),
    transforms.Resize((128, 128), antialias=True),
])

In [None]:
# Both splits are normalised with the TRAIN statistics; only the transforms differ.
normalisation_kwargs = {'mean': train_means, 'std': train_stds}

train_dataset_scaled = EEGSpectrogramDataset(root_dir, train_samples, is_train=True,
                                             transforms=train_transforms, **normalisation_kwargs)

val_dataset_scaled = EEGSpectrogramDataset(root_dir, val_samples, is_train=False,
                                           transforms=val_transforms, **normalisation_kwargs)

In [None]:
# Shared loader settings. The training loader draws class-balanced batches through the
# sampler (shuffle must stay False when a sampler is supplied).
loader_kwargs = dict(batch_size=128, num_workers=2, pin_memory=True,
                     persistent_workers=True, prefetch_factor=3)

train_loader = DataLoader(train_dataset_scaled, shuffle=False, sampler=train_sampler, **loader_kwargs)
val_loader = DataLoader(val_dataset_scaled, shuffle=False, **loader_kwargs)

## Past Model Architectures

In [None]:
# class EegLstm(nn.Module):
#     def __init__(self, input_dims=5, hidden_dims=128, num_layers=3, dropout=0.3 , num_classes=len(os.listdir(root_dir))):
#         super(EegLstm, self).__init__()

#         self.lstm = nn.LSTM(
#             input_size=input_dims,
#             hidden_size=hidden_dims,
#             num_layers=num_layers,
#             batch_first=True,
#             dropout=dropout if num_layers >= 2 else 0,
#             bidirectional=True
#         )

#         self.conv_stack = nn.Sequential(
#             nn.Conv2d(hidden_dims*2)
#         )

#         self.fc = nn.Sequential(
#             nn.Linear(hidden_dims*2, hidden_dims),
#             nn.BatchNorm1d(hidden_dims),
#             nn.SELU(),
#             nn.Dropout(dropout),
#             nn.Linear(hidden_dims, hidden_dims),
#             nn.BatchNorm1d(hidden_dims),
#             nn.SELU(),
#             nn.Dropout(dropout),
#             nn.Linear(hidden_dims, hidden_dims//2),
#             nn.BatchNorm1d(hidden_dims//2),
#             nn.SELU(),
#             nn.Linear(hidden_dims, num_classes)
#         )

#     def forward(self, x, lengths=None):
#         if lengths is not None:
#             packed = nn.utils.rnn.pack_padded_sequence(
#                 x, lengths.cpu(), batch_first=True, enforce_sorted=False
#             )

#             packed_out, (h_n, c_n) = self.lstm(packed)

#         else:
#             out, (h_n, c_n) = self.lstm(x)

#         last_hidden_backward, last_hidden_forward = h_n[-1], h_n[-2]
#         logits=self.fc(torch.cat((last_hidden_backward, last_hidden_forward), dim=1))

#         return logits

In [None]:
# class HybridExtractor(nn.Module):
#     def __init__(self, input_dims=5, cnn_out_channels=64, kernel_size=50, lstm_hidden_dims=128, num_layers=3, dropout=0.3 , num_classes=len(os.listdir(root_dir))):
#         super(HybridExtractor, self).__init__()

#         # CNN block
#         self.cnn_stack = nn.Sequential(
#             nn.Conv1d(input_dims, 32, kernel_size=kernel_size, stride=1, padding='same'),
#             nn.BatchNorm1d(32),
#             nn.ELU(),
#             nn.AvgPool1d(kernel_size=2, stride=2),

#             nn.Conv1d(32, cnn_out_channels, kernel_size=kernel_size//2, stride=1, padding='same'),
#             nn.BatchNorm1d(cnn_out_channels),
#             nn.ELU(),
#             nn.AvgPool1d(kernel_size=2, stride=2)
#         )

#         self.lstm = nn.LSTM(
#             input_size=cnn_out_channels,
#             hidden_size=lstm_hidden_dims,
#             num_layers=num_layers,
#             batch_first=True,
#             dropout=dropout if num_layers > 1 else 0,
#             bidirectional=True
#         )

#         self.fc = nn.Sequential(
#             nn.Linear(lstm_hidden_dims*2, lstm_hidden_dims),
#             nn.BatchNorm1d(lstm_hidden_dims),
#             nn.SELU(),
#             nn.Dropout(dropout),
#             nn.Linear(lstm_hidden_dims, lstm_hidden_dims),
#             nn.BatchNorm1d(lstm_hidden_dims),
#             nn.SELU(),
#             nn.Dropout(dropout),
#             nn.Linear(lstm_hidden_dims, num_classes)
#         )

#     def forward(self, x, lengths=None):
#       x = x.permute(0, 2, 1)

#       cnn_out = self.cnn_stack(x)

#       lstm_input = cnn_out.permute(0, 2, 1)


#       if lengths is not None:
#         new_lengths = (lengths//4).long()

#         packed = nn.utils.rnn.pack_padded_sequence(
#                 lstm_input, new_lengths.cpu(), batch_first=True, enforce_sorted=False
#             )

#         packed_out, (h_n, c_n) = self.lstm(packed)

#       else:
#           out, (h_n, c_n) = self.lstm(lstm_input)

#       last_hidden_backward, last_hidden_forward = h_n[-1, :, :], h_n[-2, :, :]
#       logits=self.fc(torch.cat((last_hidden_backward, last_hidden_forward), dim=1))

#       return logits

## Newer Architectures

In [None]:
import torch
import torch.nn as nn

In [None]:
class DepthwiseSeparableConv(nn.Module):
  """1-D depthwise-separable convolution followed by BatchNorm and ELU.

  The first convolution is per-channel (``groups=in_channels``) with ``kernel_size``.
  A 1x1 pointwise convolution then maps ``in_channels -> out_channels``. Neither has
  a bias because BatchNorm follows. Input and output are ``(batch, channels, length)``.
  """

  def __init__(self, in_channels, out_channels, kernel_size, padding):
    super().__init__()
    self.depthwise = nn.Conv1d(in_channels, in_channels, kernel_size,
                               padding=padding, groups=in_channels, bias=False)
    self.pointwise = nn.Conv1d(in_channels, out_channels, 1, bias=False)
    self.bn = nn.BatchNorm1d(out_channels)
    self.elu = nn.ELU()

  def forward(self, x):
    return self.elu(self.bn(self.pointwise(self.depthwise(x))))

class MesoHybridNet(nn.Module):
  """Multi-scale 1-D CNN front-end plus a bidirectional GRU classifier for raw EEG.

  Three DepthwiseSeparableConv branches (kernel sizes 10 / 50 / 150, 'same' padding)
  see the input at different temporal scales. Their outputs are concatenated on the
  channel axis, average-pooled by ``pool_size`` along time, and fed to a
  bidirectional GRU. The last GRU layer's final forward and backward hidden states
  are concatenated and classified by an MLP head.

  Args:
    input_dims: number of EEG channels.
    num_classes: number of output logits.
    gru_hidden_dims, gru_num_layers: GRU size. Inter-layer dropout is only used with
      more than one layer.
    dropout: dropout inside the GRU stack and before the final linear layer.
    branch_channels: output channels of each convolution branch (default 24).
    pool_size: temporal average-pooling factor (default 4). Sequence lengths passed
      to ``forward`` are divided by the same factor.
  """

  def __init__(self, input_dims=5, num_classes=118, gru_hidden_dims=128, gru_num_layers=2, dropout=0.4,
               branch_channels=24, pool_size=4):
    super(MesoHybridNet, self).__init__()

    self.branch_fine = DepthwiseSeparableConv(input_dims, branch_channels, kernel_size=10, padding='same')
    self.branch_medium = DepthwiseSeparableConv(input_dims, branch_channels, kernel_size=50, padding='same')
    self.branch_coarse = DepthwiseSeparableConv(input_dims, branch_channels, kernel_size=150, padding='same')

    combined_channels = 3 * branch_channels
    # One constant drives both the pooling and the length bookkeeping in forward().
    self.pool_size = pool_size
    self.pool = nn.AvgPool1d(pool_size)

    self.gru = nn.GRU(
        input_size=combined_channels,
        hidden_size=gru_hidden_dims,
        num_layers=gru_num_layers,
        batch_first=True,
        bidirectional=True,
        dropout=dropout if gru_num_layers > 1 else 0
    )

    gru_output_dim = gru_hidden_dims * 2
    self.fc = nn.Sequential(
        nn.Linear(gru_output_dim, 256),
        nn.BatchNorm1d(256),
        nn.ELU(),
        nn.Linear(256, 256),
        nn.BatchNorm1d(256),
        nn.ELU(),
        nn.Dropout(dropout),
        nn.Linear(256, num_classes)
    )

  def forward(self, x, lengths=None, mask=None):
    """Classify a batch of EEG sequences.

    Args:
      x: tensor laid out as ``(batch, time, input_dims)``; permuted to channel-first
        for the convolutions.
      lengths: optional true (pre-pooling) length of each sequence. When given, the
        GRU runs on a packed sequence, so padding does not affect the final state.
      mask: accepted for call compatibility with ``collate_fn`` outputs; unused.

    Returns:
      ``(batch, num_classes)`` logits.
    """
    x = x.permute(0, 2, 1)

    combined_features = torch.cat(
        [self.branch_fine(x), self.branch_medium(x), self.branch_coarse(x)], dim=1)
    gru_input = self.pool(combined_features).permute(0, 2, 1)

    if lengths is not None:
      # Pooling shrinks the time axis by pool_size; keep at least one step per sequence.
      new_lengths = (lengths // self.pool_size).clamp(min=1).long()
      packed = nn.utils.rnn.pack_padded_sequence(
          gru_input, new_lengths.cpu(), batch_first=True, enforce_sorted=False
      )
      _, h_n = self.gru(packed)
    else:
      _, h_n = self.gru(gru_input)

    # h_n[-2] / h_n[-1] are the last layer's forward / backward final hidden states.
    features = torch.cat((h_n[-2, :, :], h_n[-1, :, :]), dim=1)

    return self.fc(features)


In [None]:
class ConvBranch(nn.Module):
  """A single bias-free 2-D convolution used as one scale branch of EEGMesoNet.

  No normalisation, pooling or activation is applied here; EEGMesoNet pools the
  branch outputs itself. ``pooling_type`` and ``pool_size`` are accepted for
  signature compatibility but unused.
  """

  def __init__(self, input_dims, output_channels, kernel_size, stride, padding, pooling_type, pool_size=4):
    super().__init__()
    self.conv = nn.Conv2d(in_channels=input_dims, out_channels=output_channels,
                          kernel_size=kernel_size, stride=stride, padding=padding, bias=False)

  def forward(self, x):
    conv_out = self.conv(x)
    return conv_out

class EEGMesoNet(nn.Module):
  """Three-scale 2-D CNN over multi-channel spectrogram images.

  The branches (kernel/stride 4/1, 16/4, 64/16) produce 24, 16 and 8 feature maps.
  Each map is reduced to one value by global max-pooling (fine branch) or global
  average-pooling (medium and coarse branches). The 48 resulting features are
  concatenated and classified by an MLP head.

  Input is ``(batch, input_dims, H, W)``; output is ``(batch, num_classes)`` logits.
  """

  def __init__(self, input_dims=5, num_classes=10, dropout=0.5):
    super(EEGMesoNet, self).__init__()

    base = 8
    fine_ch, medium_ch, coarse_ch = 3 * base, 2 * base, base
    self.branch_fine = ConvBranch(input_dims, fine_ch, kernel_size=4, stride=1, padding=2, pooling_type='max')
    self.branch_medium = ConvBranch(input_dims, medium_ch, kernel_size=16, stride=4, padding=4, pooling_type='avg')
    self.branch_coarse = ConvBranch(input_dims, coarse_ch, kernel_size=64, stride=16, padding=8, pooling_type='avg')

    self.fc = nn.Sequential(
        nn.Linear(fine_ch + medium_ch + coarse_ch, 128),
        nn.BatchNorm1d(128),
        nn.LeakyReLU(),
        nn.Linear(128, 64),
        nn.BatchNorm1d(64),
        nn.LeakyReLU(),
        nn.Dropout(dropout),
        nn.Linear(64, num_classes)
    )

  def forward(self, x):
    pooled = [
        nn.functional.adaptive_max_pool2d(self.branch_fine(x), (1, 1)),
        nn.functional.adaptive_avg_pool2d(self.branch_medium(x), (1, 1)),
        nn.functional.adaptive_avg_pool2d(self.branch_coarse(x), (1, 1)),
    ]
    combined_features = torch.cat([torch.flatten(p, 1) for p in pooled], dim=1)
    return self.fc(combined_features)

## Model Training

In [None]:
def weights_init(m):
  """Kaiming-normal (fan_in, ReLU gain) init for Conv2d/Linear weights; zero their biases.

  Meant for ``module.apply(weights_init)``. Any other module type is left untouched.
  """
  if not isinstance(m, (nn.Conv2d, nn.Linear)):
    return

  nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
  if m.bias is not None:
    nn.init.zeros_(m.bias)

In [None]:
from tqdm import tqdm
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter

In [None]:
class EarlyStopping(object):
    """Stop training once validation loss stops improving; checkpoint the best model.

    A call counts as an improvement when ``val_loss < best_val_loss - tol``. On an
    improvement the model's ``state_dict`` is saved to ``save_path`` and the counter
    resets. After ``patience`` consecutive non-improving calls, ``early_stop`` becomes
    True.

    Args:
        model: module whose ``state_dict`` is checkpointed.
        save_path: checkpoint path; its parent directory is created if needed.
        patience: number of non-improving calls tolerated before stopping.
        tol: minimum decrease that counts as an improvement.
    """

    def __init__(self, model, save_path='../content/drive/MyDrive/eeg_classifier.pt', patience=4, tol=1e-3):
        self.model = model
        self.save_path = save_path
        self.patience = patience
        self.counter = 0
        self.tol = tol
        self.best_val_loss = float('inf')
        self.early_stop = False

    def __call__(self, batch_val_loss):
        if batch_val_loss < self.best_val_loss - self.tol:
            save_dir = os.path.dirname(self.save_path)
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
            torch.save(self.model.state_dict(), self.save_path)
            self.best_val_loss = batch_val_loss
            self.counter = 0
            print(f'Validation Loss improved -> model saved to {self.save_path}')
            return

        self.counter += 1
        print(f'No improvement in Val Loss. Counter: {self.counter}/{self.patience}')
        # Stop as soon as `patience` non-improving epochs have been seen. The old
        # check only stopped on the (patience + 1)-th one.
        if self.counter >= self.patience:
            self.early_stop = True
            print("Early Stopping triggered!")


In [None]:
def _train_one_epoch(model, loader, criterion, optimizer, scheduler, epoch, epochs, device):
    """Run one training pass and return ``(mean loss, accuracy)`` over the samples seen."""
    model.train()
    total_loss, correct, total = 0.0, 0, 0
    iters = len(loader)
    train_bar = tqdm(loader, desc=f'Epoch {epoch+1}/{epochs} [Train Pass]', leave=True)

    for i, (batch_x, batch_y) in enumerate(train_bar):
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)

        optimizer.zero_grad()
        y_preds = model(batch_x)
        loss = criterion(y_preds, batch_y)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        # A fractional epoch lets CosineAnnealingWarmRestarts anneal smoothly per batch.
        scheduler.step(epoch + i / iters)

        total_loss += loss.item() * batch_x.size(0)
        _, preds = torch.max(y_preds, 1)
        correct += (preds == batch_y).sum().item()
        total += batch_y.size(0)
        train_bar.set_postfix(loss=loss.item())

    return total_loss / total, correct / total


def _evaluate(model, loader, criterion, epoch, epochs, device):
    """Run one gradient-free validation pass and return ``(mean loss, accuracy)``."""
    model.eval()
    total_loss, correct, total = 0.0, 0, 0
    val_bar = tqdm(loader, desc=f"Epoch{epoch+1}/{epochs} [Val Pass]", leave=True)

    with torch.no_grad():
        for batch_x, batch_y in val_bar:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)

            y_preds = model(batch_x)
            loss = criterion(y_preds, batch_y)

            total_loss += loss.item() * batch_x.size(0)
            _, preds = torch.max(y_preds, 1)
            correct += (preds == batch_y).sum().item()
            total += batch_y.size(0)
            val_bar.set_postfix(loss=loss.item())

    return total_loss / total, correct / total


def train_model(model, model_name, train_loader, val_loader, epochs=50, lr=1e-3, device='cpu'):
    """Train ``model`` with AdamW + cosine warm restarts, early stopping and TensorBoard logs.

    Each epoch logs train/val loss and accuracy to
    ``../content/drive/MyDrive/NeuroVision/runs/{model_name}_v1``. The best model by
    validation loss is saved to
    ``../content/drive/MyDrive/NeuroVision/models/{model_name}_v2_best.pth`` by
    ``EarlyStopping`` (patience 10). On return, ``model`` holds the weights from the
    last epoch run, not the best checkpoint.

    Args:
        model: classifier producing logits for batches of ``(inputs, int labels)``.
        model_name: prefix for the log and checkpoint paths.
        train_loader, val_loader: loaders yielding ``(batch_x, batch_y)`` pairs.
        epochs: maximum number of epochs.
        lr: AdamW learning rate; also the peak of the cosine schedule.
        device: device the model and batches are moved to.
    """
    log_dir = f'../content/drive/MyDrive/NeuroVision/runs/{model_name}_v1'
    save_path = f'../content/drive/MyDrive/NeuroVision/models/{model_name}_v2_best.pth'
    os.makedirs(os.path.dirname(log_dir), exist_ok=True)
    os.makedirs(os.path.dirname(save_path), exist_ok=True)

    criterion = nn.CrossEntropyLoss()
    # Pass lr explicitly: it used to be ignored, so AdamW's default was always used.
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=0.05)
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer=optimizer,
                                                               T_0=5,
                                                               T_mult=1,
                                                               eta_min=5e-5)

    early_stopping = EarlyStopping(model, save_path=save_path, patience=10)
    model.to(device)

    writer = SummaryWriter(log_dir=log_dir)
    try:
        for epoch in range(epochs):
            train_loss, train_acc = _train_one_epoch(model, train_loader, criterion, optimizer,
                                                     scheduler, epoch, epochs, device)
            val_loss, val_acc = _evaluate(model, val_loader, criterion, epoch, epochs, device)

            # Log before the early-stopping check so the final epoch is recorded too.
            writer.add_scalar('Loss/train', train_loss, epoch)
            writer.add_scalar('Loss/val', val_loss, epoch)
            writer.add_scalar('Accuracy/train', train_acc, epoch)
            writer.add_scalar('Accuracy/val', val_acc, epoch)

            print(f"Epoch {epoch+1}/{epochs}:\nTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f} %\nVal Loss: {val_loss:.3f} | Val Acc: {val_acc*100:.2f} %")

            # No scheduler.step(val_loss) here: CosineAnnealingWarmRestarts.step() takes an
            # epoch, so passing the loss corrupted the schedule (a ReduceLROnPlateau leftover).
            early_stopping(val_loss)
            if early_stopping.early_stop:
                break
    finally:
        writer.close()

In [None]:
def model_summary(model):
    print('========================================= Model Summary ==============================================\n')
    print(f"\n{'='*55}")
    print(f"{'| Parameter Name':31}|| Number of Parameters|")
    print(f"{'='*55}")

    total_params = 0

    for name, param in model.named_parameters():
        print(f'| {name:30}|{param.numel():20} |')
        print(f"{'-'*55}")
        total_params += param.numel()

    print(f"\nTotal Parameters: {total_params:,}")

## Spectrogram Approach

In [None]:
import torch.nn as nn

In [None]:
class BasicBlock(nn.Module):
  """ResNet basic residual block with a CBAM attention module before the residual add.

  Main path: conv3x3(stride) -> BN -> ReLU -> conv3x3 -> BN -> CBAM.
  Shortcut: identity, or a strided 1x1 conv + BN when the stride or channel count
  changes. The sum goes through a final ReLU.
  """
  expansion = 1

  def __init__(self, in_planes, planes, stride=1):
    super(BasicBlock, self).__init__()
    self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
    self.bn1 = nn.BatchNorm2d(planes)
    self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
    self.bn2 = nn.BatchNorm2d(planes)

    self.cbam = CBAM(planes)

    self.shortcut = nn.Sequential()
    if stride != 1 or in_planes != self.expansion*planes:
      self.shortcut = nn.Sequential(
          nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
          nn.BatchNorm2d(self.expansion*planes)
      )

  # Defined at class level, not nested inside __init__, so nn.Module dispatches to it.
  # As a nested local function it was never a method, and every call raised
  # NotImplementedError.
  def forward(self, x):
    out = nn.functional.relu(self.bn1(self.conv1(x)))
    out = self.bn2(self.conv2(out))

    out = self.cbam(out)

    out = out + self.shortcut(x)
    return nn.functional.relu(out)

In [None]:
import torch

class ChannelAttention(nn.Module):
  """CBAM channel-attention map.

  Global average- and max-pooled descriptors pass through a shared 1x1-conv
  bottleneck (``in_planes -> in_planes // ratio -> in_planes``). Their sum is squashed
  with a sigmoid, giving weights of shape ``(batch, in_planes, 1, 1)``.
  """

  def __init__(self, in_planes, ratio=16):
    super(ChannelAttention, self).__init__()
    self.avg_pool = nn.AdaptiveAvgPool2d(1)
    self.max_pool = nn.AdaptiveMaxPool2d(1)

    # Keep at least one hidden channel. With in_planes < ratio the bottleneck would
    # otherwise be zero-wide and the attention map would carry no information.
    hidden = max(1, in_planes // ratio)
    self.fc = nn.Sequential(nn.Conv2d(in_planes, hidden, 1, bias=False),
                            nn.ReLU(),
                            nn.Conv2d(hidden, in_planes, 1, bias=False)
                            )

    self.sigmoid = nn.Sigmoid()

  def forward(self, x):
    avg_out = self.fc(self.avg_pool(x))
    max_out = self.fc(self.max_pool(x))
    return self.sigmoid(avg_out + max_out)

class SpatialAttention(nn.Module):
  """CBAM spatial-attention gate.

  The channel-wise mean map and max map are stacked into a 2-channel tensor. A
  single ``kernel_size`` x ``kernel_size`` conv reduces it to one channel, and a
  sigmoid turns that into one weight per spatial position.

  Args:
    kernel_size: side of the square conv kernel. Padding is ``kernel_size // 2``,
      so the spatial size is preserved only for odd kernel sizes.
  """
  def __init__(self, kernel_size=7):
    super(SpatialAttention, self).__init__()

    pad = kernel_size // 2
    self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=pad, bias=False)
    self.sigmoid = nn.Sigmoid()

  def forward(self, x):
    mean_map = x.mean(dim=1, keepdim=True)
    max_map = x.max(dim=1, keepdim=True).values
    stacked = torch.cat((mean_map, max_map), dim=1)
    return self.sigmoid(self.conv1(stacked))


class CBAM(nn.Module):
  """Convolutional Block Attention Module: a channel gate followed by a spatial gate.

  Args:
    in_planes: number of input channels (passed to ChannelAttention).
    ratio: channel-attention bottleneck reduction ratio.
    kernel_size: spatial-attention conv kernel size.

  forward(x) first multiplies x by its channel-attention map. It then multiplies
  that refined tensor by a spatial-attention map computed from the refined tensor.
  """
  def __init__(self, in_planes, ratio=16, kernel_size=7):
    super(CBAM, self).__init__()
    self.ca = ChannelAttention(in_planes, ratio)
    self.sa = SpatialAttention(kernel_size)

  def forward(self, x):
    refined = x * self.ca(x)
    return refined * self.sa(refined)


In [None]:
# Import the factory under an alias. Assigning the model to `resnet18` would
# otherwise shadow the factory and make this cell fail when re-run.
from torchvision.models import resnet18 as build_resnet18, ResNet18_Weights

resnet18 = build_resnet18(weights=ResNet18_Weights.DEFAULT, progress=True)

# Freeze every pretrained parameter.
# NOTE(review): freezing parameters does not freeze BatchNorm running statistics;
# they still update whenever the model is in train() mode.
resnet18.requires_grad_(False)

In [None]:
# Adapt the pretrained ResNet-18 to 5-channel input and a new MLP classification head.
# This cell mutates `resnet18`, so refuse to run on an already-adapted model. Otherwise
# the weight copy / `in_features` lookup below would fail with a confusing error.
if resnet18.conv1.in_channels != 3 or not isinstance(resnet18.fc, nn.Linear):
  raise RuntimeError("resnet18 is already adapted; re-run the model-loading cell first.")

old_conv1 = resnet18.conv1
old_fc = resnet18.fc

new_conv1 = nn.Conv2d(5, old_conv1.out_channels, old_conv1.kernel_size,
                      old_conv1.stride, old_conv1.padding, bias=False)

# Reuse the pretrained RGB filters for input channels 0-2. Initialise the two extra
# channels (3-4) with each filter's mean over its RGB weights.
with torch.no_grad():
  new_conv1.weight[:, :3, :, :] = old_conv1.weight.clone()
  mean_weights = torch.mean(old_conv1.weight, dim=1, keepdim=True)
  new_conv1.weight[:, 3:5, :, :] = mean_weights.repeat(1, 2, 1, 1)

# NOTE(review): 118 output classes here vs 15 for EEGMesoNet below; confirm which
# label set this model is meant for.
new_fc = nn.Sequential(nn.Linear(old_fc.in_features, 256),
                       nn.ReLU(inplace=True),
                       nn.BatchNorm1d(256),
                       nn.Dropout(0.5),
                       nn.Linear(in_features=256, out_features=118))
new_fc.apply(weights_init)

resnet18.fc = new_fc
resnet18.conv1 = new_conv1

# Assigning `module.requires_grad = True` only creates a plain attribute on the
# Module. requires_grad_() actually sets the flag on every parameter.
resnet18.fc.requires_grad_(True)
resnet18.conv1.requires_grad_(True)


In [None]:
import types
from torchvision.models import resnet

def new_forward(self, x):
  """Replacement forward for a torchvision ResNet BasicBlock, with optional CBAM.

  Computes conv1 -> bn1 -> relu -> conv2 -> bn2. If the block has a ``cbam``
  attribute, that module is applied to the result. The skip path (``x``, or
  ``self.downsample(x)`` when a downsample module exists) is then added, followed
  by a final relu.
  """
  skip = x if self.downsample is None else self.downsample(x)

  out = self.relu(self.bn1(self.conv1(x)))
  out = self.bn2(self.conv2(out))

  if hasattr(self, 'cbam'):
    out = self.cbam(out)

  out += skip
  return self.relu(out)

# Give every torchvision BasicBlock in the four residual stages a CBAM module, and
# rebind its forward to `new_forward`, which applies it.
for layer_name in ('layer1', 'layer2', 'layer3', 'layer4'):
  for block in getattr(resnet18, layer_name):
    if not isinstance(block, resnet.BasicBlock):
      continue

    # Only create CBAM for blocks that don't have one yet. Re-running this cell
    # then does not silently discard already-trained attention weights. The new
    # module is placed on the block's device in case the model was already moved.
    if not hasattr(block, 'cbam'):
      block.cbam = CBAM(block.conv2.out_channels).to(block.conv2.weight.device)

    block.forward = types.MethodType(new_forward, block)



In [None]:
# Sanity check of the model surgery: a 5-channel stem conv and a 118-way head.
adapted_io = (resnet18.conv1.in_channels, resnet18.fc[-1].out_features)
assert adapted_io == (5, 118), f"unexpected (in_channels, out_features): {adapted_io}"
adapted_io

(5, 118)

In [None]:
# Per-parameter element counts and the total over all parameters (frozen and
# trainable) of the adapted ResNet-18.
model_summary(resnet18)

## Model Training

In [None]:
# Judging by the printed repr, the positional args appear to be 5 input channels,
# 15 output classes and dropout p=0.6 before the final Linear layer.
mesonet = EEGMesoNet(5, 15, 0.6)
# NOTE(review): this commented load points at the *v1* checkpoint, while the training
# log shows saves to mesonet-meta-learner_v2_best.pth. Also, the next cell applies
# `weights_init`, which presumably re-initialises the weights; load after that cell.
# mesonet.load_state_dict(torch.load('../content/drive/MyDrive/NeuroVision/models/mesonet-meta-learner_v1_best.pth'))

In [None]:
# Apply `weights_init` (defined elsewhere in the notebook) to every submodule.
# `Module.apply` returns the module itself, so the cell also displays its repr.
mesonet.apply(weights_init)

EEGMesoNet(
  (branch_fine): ConvBranch(
    (conv): Conv2d(5, 24, kernel_size=(4, 4), stride=(1, 1), padding=(2, 2), bias=False)
  )
  (branch_medium): ConvBranch(
    (conv): Conv2d(5, 16, kernel_size=(16, 16), stride=(4, 4), padding=(4, 4), bias=False)
  )
  (branch_coarse): ConvBranch(
    (conv): Conv2d(5, 8, kernel_size=(64, 64), stride=(16, 16), padding=(8, 8), bias=False)
  )
  (fc): Sequential(
    (0): Linear(in_features=48, out_features=128, bias=True)
    (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.01)
    (3): Linear(in_features=128, out_features=64, bias=True)
    (4): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): LeakyReLU(negative_slope=0.01)
    (6): Dropout(p=0.6, inplace=False)
    (7): Linear(in_features=64, out_features=15, bias=True)
  )
)

In [None]:
# Per-parameter element counts and the total parameter count of EEGMesoNet.
model_summary(mesonet)



| Parameter Name               || Number of Parameters|
| branch_fine.conv.weight       |                1920 |
-------------------------------------------------------
| branch_medium.conv.weight     |               20480 |
-------------------------------------------------------
| branch_coarse.conv.weight     |              163840 |
-------------------------------------------------------
| fc.0.weight                   |                6144 |
-------------------------------------------------------
| fc.0.bias                     |                 128 |
-------------------------------------------------------
| fc.1.weight                   |                 128 |
-------------------------------------------------------
| fc.1.bias                     |                 128 |
-------------------------------------------------------
| fc.3.weight                   |                8192 |
-------------------------------------------------------
| fc.3.bias                     |             

In [None]:
# TODO: before (re)training, load the best meta-learner checkpoint previously saved to Drive.

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Positional args after the loaders: 200 is the epoch budget (the log prints
# "Epoch n/200"); 5e-3 is presumably the learning rate (confirm against
# train_model's signature). Early stopping can end the run sooner.
train_model(mesonet, 'mesonet-meta-learner', train_loader, val_loader, 200, 5e-3, device)

# Alternative run for the CBAM-augmented ResNet-18. It uses the same positional order
# as the call above (epochs, then learning rate); an extra positional arg would shift them.
# train_model(resnet18, 'resnet-18', train_loader, val_loader, 200, 1e-3, device)

Epoch 1/200 [Train Pass]: 100%|██████████| 74/74 [01:59<00:00,  1.62s/it, loss=2.7]
Epoch1/200 [Val Pass]: 100%|██████████| 32/32 [00:43<00:00,  1.37s/it, loss=3.91]


Validation Loss improved -> model saved to ../content/drive/MyDrive/NeuroVision/models/mesonet-meta-learner_v2_best.pth
Epoch 1/200:
Train Loss: 3.038 | Train Acc: 9.52 %
Val Loss: 3.615 | Val Acc: 5.05 %


Epoch 2/200 [Train Pass]: 100%|██████████| 74/74 [01:59<00:00,  1.61s/it, loss=2.52]
Epoch2/200 [Val Pass]: 100%|██████████| 32/32 [00:44<00:00,  1.39s/it, loss=2.31]


Validation Loss improved -> model saved to ../content/drive/MyDrive/NeuroVision/models/mesonet-meta-learner_v2_best.pth
Epoch 2/200:
Train Loss: 2.637 | Train Acc: 10.66 %
Val Loss: 2.340 | Val Acc: 13.54 %


Epoch 3/200 [Train Pass]: 100%|██████████| 74/74 [01:59<00:00,  1.62s/it, loss=2.58]
Epoch3/200 [Val Pass]: 100%|██████████| 32/32 [00:43<00:00,  1.37s/it, loss=2.34]


No improvement in Val Loss. Counter: 1/10
Epoch 3/200:
Train Loss: 2.516 | Train Acc: 10.79 %
Val Loss: 2.342 | Val Acc: 10.89 %


Epoch 4/200 [Train Pass]: 100%|██████████| 74/74 [01:59<00:00,  1.61s/it, loss=2.47]
Epoch4/200 [Val Pass]: 100%|██████████| 32/32 [00:44<00:00,  1.38s/it, loss=2.29]


Validation Loss improved -> model saved to ../content/drive/MyDrive/NeuroVision/models/mesonet-meta-learner_v2_best.pth
Epoch 4/200:
Train Loss: 2.464 | Train Acc: 10.80 %
Val Loss: 2.326 | Val Acc: 10.35 %


Epoch 5/200 [Train Pass]:  46%|████▌     | 34/74 [00:55<00:51,  1.28s/it, loss=2.42]Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f48691ad260>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f48691ad260>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  Fi

Validation Loss improved -> model saved to ../content/drive/MyDrive/NeuroVision/models/mesonet-meta-learner_v2_best.pth
Epoch 5/200:
Train Loss: 2.455 | Train Acc: 10.91 %
Val Loss: 2.303 | Val Acc: 10.75 %


Epoch 6/200 [Train Pass]: 100%|██████████| 74/74 [01:59<00:00,  1.62s/it, loss=2.37]
Epoch6/200 [Val Pass]: 100%|██████████| 32/32 [00:44<00:00,  1.39s/it, loss=2.48]


No improvement in Val Loss. Counter: 1/10
Epoch 6/200:
Train Loss: 2.419 | Train Acc: 11.01 %
Val Loss: 2.440 | Val Acc: 4.26 %


Epoch 7/200 [Train Pass]: 100%|██████████| 74/74 [02:01<00:00,  1.64s/it, loss=2.4]
Epoch7/200 [Val Pass]: 100%|██████████| 32/32 [00:44<00:00,  1.40s/it, loss=2.32]


No improvement in Val Loss. Counter: 2/10
Epoch 7/200:
Train Loss: 2.382 | Train Acc: 11.69 %
Val Loss: 2.323 | Val Acc: 9.63 %


Epoch 8/200 [Train Pass]: 100%|██████████| 74/74 [02:00<00:00,  1.63s/it, loss=2.37]
Epoch8/200 [Val Pass]: 100%|██████████| 32/32 [00:45<00:00,  1.41s/it, loss=2.29]


Validation Loss improved -> model saved to ../content/drive/MyDrive/NeuroVision/models/mesonet-meta-learner_v2_best.pth
Epoch 8/200:
Train Loss: 2.355 | Train Acc: 12.19 %
Val Loss: 2.297 | Val Acc: 9.28 %


Epoch 9/200 [Train Pass]: 100%|██████████| 74/74 [02:00<00:00,  1.62s/it, loss=2.39]
Epoch9/200 [Val Pass]: 100%|██████████| 32/32 [00:45<00:00,  1.41s/it, loss=2.32]


No improvement in Val Loss. Counter: 1/10
Epoch 9/200:
Train Loss: 2.351 | Train Acc: 12.19 %
Val Loss: 2.331 | Val Acc: 6.76 %


Epoch 10/200 [Train Pass]: 100%|██████████| 74/74 [01:58<00:00,  1.60s/it, loss=2.41]
Epoch10/200 [Val Pass]: 100%|██████████| 32/32 [00:43<00:00,  1.35s/it, loss=2.29]


No improvement in Val Loss. Counter: 2/10
Epoch 10/200:
Train Loss: 2.342 | Train Acc: 12.63 %
Val Loss: 2.299 | Val Acc: 9.75 %


Epoch 11/200 [Train Pass]: 100%|██████████| 74/74 [01:57<00:00,  1.59s/it, loss=2.27]
Epoch11/200 [Val Pass]: 100%|██████████| 32/32 [00:43<00:00,  1.36s/it, loss=2.37]


No improvement in Val Loss. Counter: 3/10
Epoch 11/200:
Train Loss: 2.326 | Train Acc: 12.73 %
Val Loss: 2.460 | Val Acc: 5.08 %


Epoch 12/200 [Train Pass]: 100%|██████████| 74/74 [01:58<00:00,  1.60s/it, loss=2.24]
Epoch12/200 [Val Pass]: 100%|██████████| 32/32 [00:43<00:00,  1.37s/it, loss=2.94]


No improvement in Val Loss. Counter: 4/10
Epoch 12/200:
Train Loss: 2.317 | Train Acc: 13.04 %
Val Loss: 2.862 | Val Acc: 1.36 %


Epoch 13/200 [Train Pass]: 100%|██████████| 74/74 [01:59<00:00,  1.61s/it, loss=2.36]
Epoch13/200 [Val Pass]: 100%|██████████| 32/32 [00:44<00:00,  1.38s/it, loss=3.57]


No improvement in Val Loss. Counter: 5/10
Epoch 13/200:
Train Loss: 2.298 | Train Acc: 13.82 %
Val Loss: 3.626 | Val Acc: 1.26 %


Epoch 14/200 [Train Pass]: 100%|██████████| 74/74 [02:00<00:00,  1.63s/it, loss=2.22]
Epoch14/200 [Val Pass]: 100%|██████████| 32/32 [00:43<00:00,  1.36s/it, loss=2.82]


No improvement in Val Loss. Counter: 6/10
Epoch 14/200:
Train Loss: 2.281 | Train Acc: 14.90 %
Val Loss: 2.844 | Val Acc: 1.51 %


Epoch 15/200 [Train Pass]: 100%|██████████| 74/74 [01:58<00:00,  1.60s/it, loss=2.27]
Epoch15/200 [Val Pass]: 100%|██████████| 32/32 [00:43<00:00,  1.36s/it, loss=2.28]


No improvement in Val Loss. Counter: 7/10
Epoch 15/200:
Train Loss: 2.278 | Train Acc: 14.66 %
Val Loss: 2.300 | Val Acc: 9.16 %


Epoch 16/200 [Train Pass]: 100%|██████████| 74/74 [01:58<00:00,  1.60s/it, loss=2.34]
Epoch16/200 [Val Pass]: 100%|██████████| 32/32 [00:44<00:00,  1.38s/it, loss=2.25]


No improvement in Val Loss. Counter: 8/10
Epoch 16/200:
Train Loss: 2.279 | Train Acc: 14.98 %
Val Loss: 2.337 | Val Acc: 9.21 %


Epoch 17/200 [Train Pass]: 100%|██████████| 74/74 [01:59<00:00,  1.61s/it, loss=2.24]
Epoch17/200 [Val Pass]: 100%|██████████| 32/32 [00:43<00:00,  1.35s/it, loss=2.64]


No improvement in Val Loss. Counter: 9/10
Epoch 17/200:
Train Loss: 2.261 | Train Acc: 15.46 %
Val Loss: 2.682 | Val Acc: 2.15 %


Epoch 18/200 [Train Pass]: 100%|██████████| 74/74 [01:58<00:00,  1.61s/it, loss=2.31]
Epoch18/200 [Val Pass]: 100%|██████████| 32/32 [00:43<00:00,  1.37s/it, loss=2.52]


No improvement in Val Loss. Counter: 10/10
Epoch 18/200:
Train Loss: 2.243 | Train Acc: 16.78 %
Val Loss: 2.517 | Val Acc: 5.03 %


Epoch 19/200 [Train Pass]: 100%|██████████| 74/74 [01:59<00:00,  1.61s/it, loss=2.18]
Epoch19/200 [Val Pass]: 100%|██████████| 32/32 [00:44<00:00,  1.40s/it, loss=3.92]

Early Stopping triggered!





## Model Testing