<a href="https://colab.research.google.com/github/asma-walha/Machine-Learning/blob/main/MINI_eegpt.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Data Loading and Preprocessing

In [17]:
!pip install mne



In [2]:
import mne
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR


In [3]:
# Load a sample EEG dataset (e.g., the MNE sample dataset)
sample_data_folder = mne.datasets.sample.data_path()
sample_data_raw_file = sample_data_folder / 'MEG' / 'sample' / 'sample_audvis_raw.fif'
raw = mne.io.read_raw_fif(sample_data_raw_file, preload=True)

Opening raw data file /root/mne_data/MNE-sample-data/MEG/sample/sample_audvis_raw.fif...
    Read a total of 3 projection items:
        PCA-v1 (1 x 102)  idle
        PCA-v2 (1 x 102)  idle
        PCA-v3 (1 x 102)  idle
    Range : 25800 ... 192599 =     42.956 ...   320.670 secs
Ready.
Reading 0 ... 166799  =      0.000 ...   277.714 secs...


In [4]:
print(raw.info)
print(f"Shape of raw data: {raw.get_data().shape}")

<Info | 21 non-empty values
 acq_pars: ACQch001 110113 ACQch002 110112 ACQch003 110111 ACQch004 110122 ...
 bads: 2 items (MEG 2443, EEG 053)
 ch_names: MEG 0113, MEG 0112, MEG 0111, MEG 0122, MEG 0123, MEG 0121, MEG ...
 chs: 204 Gradiometers, 102 Magnetometers, 9 Stimulus, 60 EEG, 1 EOG
 custom_ref_applied: False
 description: acquisition (megacq) VectorView system at NMR-MGH
 dev_head_t: MEG device -> head transform
 dig: 146 items (3 Cardinal, 4 HPI, 61 EEG, 78 Extra)
 events: 1 item (list)
 experimenter: MEG
 file_id: 4 items (dict)
 highpass: 0.1 Hz
 hpi_meas: 1 item (list)
 hpi_results: 1 item (list)
 lowpass: 172.2 Hz
 meas_date: 2002-12-03 19:01:10 UTC
 meas_id: 4 items (dict)
 nchan: 376
 proj_id: 1 item (ndarray)
 proj_name: test
 projs: PCA-v1: off, PCA-v2: off, PCA-v3: off
 sfreq: 600.6 Hz
>
Shape of raw data: (376, 166800)


In [5]:

# Preprocess the EEG data
def preprocess_eeg(raw, l_freq=0.1, h_freq=38, resample_rate=256):
    # Filter the data
    raw.filter(l_freq=l_freq, h_freq=h_freq)

    raw.pick_types(eeg=True)  # Ne garde que les canaux EEG

    # Resample the data
    raw.resample(resample_rate)

    # Re-reference to average
    raw.set_eeg_reference(ref_channels='average')

    # Segment into epochs (e.g., 4 seconds)
    epochs = mne.make_fixed_length_epochs(raw, duration=4.0)
    return epochs.get_data()

In [18]:
epochs_data = preprocess_eeg(raw)
print(f"Shape of epochs data: {epochs_data.shape}")
epochs_tensor = torch.tensor(epochs_data, dtype=torch.float32)
print(f"Shape of epochs tensor: {epochs_tensor.shape}")

dataset = TensorDataset(epochs_tensor)

from torch.utils.data import Subset
# Assume `dataset` is your full dataset
dataset_size = len(dataset)
subset_size = int(0.1 * dataset_size)  # Use 10% of the dataset
indices = torch.randperm(dataset_size)[:subset_size]  # Randomly sample indices
subset = Subset(dataset, indices)

# Create a DataLoader for the subset
batch_size = 32
dataloader = DataLoader(subset, batch_size=batch_size, shuffle=True)

from torch.utils.data import random_split

train_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

batch_size = 32
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.1 - 38 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 0.10
- Lower transition bandwidth: 0.10 Hz (-6 dB cutoff frequency: 0.05 Hz)
- Upper passband edge: 38.00 Hz
- Upper transition bandwidth: 9.50 Hz (-6 dB cutoff frequency: 42.75 Hz)
- Filter length: 8449 samples (33.004 s)

NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Sampling frequency of the instance is already 256.0, returning unmodified.
EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.


[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.1s


Not setting metadata
69 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 69 events and 1024 original time points ...
0 bad epochs dropped
Shape of epochs data: (69, 59, 1024)
Shape of epochs tensor: torch.Size([69, 59, 1024])


EEGPT Model

In [7]:
class EEGPT(nn.Module):
    def __init__(self, num_channels, num_timepoints, embed_dim, num_layers):
        super(EEGPT, self).__init__()
        self.num_channels = num_channels
        self.num_timepoints = num_timepoints
        self.embed_dim = embed_dim

        # Local spatio-temporal embedding
        self.embedding = nn.Linear(num_timepoints, embed_dim)

        # Transformer encoder with batch_first=True
        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=8, batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Momentum encoder
        self.momentum_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Predictor and reconstructor
        self.predictor = nn.Linear(embed_dim, embed_dim)
        self.reconstructor = nn.Linear(embed_dim, num_timepoints)

    def forward(self, x, mask):
        batch_size, num_channels, num_timepoints = x.shape

        # Apply masking to the input tensor
        masked_x = x * mask  # Shape: (batch_size, num_channels, num_timepoints)

        # Reshape for embedding: (batch_size * num_channels, num_timepoints)
        masked_x = masked_x.view(-1, num_timepoints)

        # Apply local spatio-temporal embedding
        embedded_x = self.embedding(masked_x)  # Shape: (batch_size * num_channels, embed_dim)

        # Reshape back: (batch_size, num_channels, embed_dim)
        embedded_x = embedded_x.view(batch_size, num_channels, self.embed_dim)

        # Encode the masked input
        enc_output = self.encoder(embedded_x)  # Shape: (batch_size, num_channels, embed_dim)

        # Predict the full features
        pred_output = self.predictor(enc_output)  # Shape: (batch_size, num_channels, embed_dim)

        # Reshape for reconstruction: (batch_size * num_channels, embed_dim)
        pred_output_reshaped = pred_output.view(-1, self.embed_dim)

        # Reconstruct the masked parts
        recon_output = self.reconstructor(pred_output_reshaped)  # Shape: (batch_size * num_channels, num_timepoints)

        # Reshape back: (batch_size, num_channels, num_timepoints)
        recon_output = recon_output.view(batch_size, num_channels, self.num_timepoints)

        # Apply embedding to the input for the momentum encoder
        full_x = x.view(-1, num_timepoints)  # Shape: (batch_size * num_channels, num_timepoints)
        embedded_full_x = self.embedding(full_x)  # Shape: (batch_size * num_channels, embed_dim)
        embedded_full_x = embedded_full_x.view(batch_size, num_channels, self.embed_dim)  # Reshape back
        momentum_output = self.momentum_encoder(embedded_full_x)  # Shape: (batch_size, num_channels, embed_dim)

        return enc_output, pred_output, recon_output, momentum_output

Dual Self-Supervised Loss

In [15]:
import torch.nn.functional as F

def dual_self_supervised_loss(enc_output, pred_output, recon_output, momentum_output, mask):
    # Alignment loss
    alignment_loss = nn.MSELoss()(pred_output, momentum_output)

    # Resize recon_output along the time dimension (from 1024 to 512)
    # recon_output shape: [batch_size, num_channels, num_timepoints] -> [32, 59, 1024]
    # We want to resize the time dimension (1024) to 512
    recon_output_resized = F.interpolate(
        recon_output,  # Shape: [32, 59, 1024]
        size=(512,),  # Resize the last dimension (time) to 512
        mode='linear',  # Use linear interpolation for time series
        align_corners=False
    )
    #print(f"Shape of resized recon_output: {recon_output_resized.shape}")  # Should be [32, 59, 512]

    # Resize mask along the time dimension (from 1024 to 512)
    # mask shape: [batch_size, num_channels, num_timepoints] -> [32, 59, 1024]
    # We want to resize the time dimension (1024) to 512
    mask_resized = F.interpolate(
        mask,  # Shape: [32, 59, 1024]
        size=(512,),  # Resize the last dimension (time) to 512
        mode='nearest'  # Use nearest-neighbor interpolation for binary masks
    )
    #print(f"Shape of resized mask: {mask_resized.shape}")  # Should be [32, 59, 512]

    # Reconstruction loss
    reconstruction_loss = nn.MSELoss()(recon_output_resized * mask_resized, enc_output * mask_resized)

    # Total loss
    total_loss = alignment_loss + reconstruction_loss
    return total_loss

Mask Generation

In [9]:
# Generate random masks
def generate_mask(x, mask_time_ratio=0.5, mask_channel_ratio=0.8):
    batch_size, num_channels, num_timepoints = x.shape
    time_mask = torch.rand(batch_size, num_timepoints) > mask_time_ratio  # Shape: (batch_size, num_timepoints)
    channel_mask = torch.rand(batch_size, num_channels) > mask_channel_ratio  # Shape: (batch_size, num_channels)
    mask = time_mask.unsqueeze(1) * channel_mask.unsqueeze(2)  # Shape: (batch_size, num_channels, num_timepoints)
    return mask.float()

Training

In [19]:
# Initialize model, optimizer, and scheduler
model = EEGPT(num_channels=376, num_timepoints=1024, embed_dim=512, num_layers=8)
optimizer = AdamW(model.parameters(), lr=2.5e-4)
scheduler = OneCycleLR(optimizer, max_lr=5e-4, total_steps=200)

# Training loop
for epoch in range(200):
    print(epoch)
    for batch in dataloader:
        x = batch[0]
        mask = generate_mask(x)

        enc_output, pred_output, recon_output, momentum_output = model(x, mask)

        loss = dual_self_supervised_loss(enc_output, pred_output, recon_output, momentum_output, mask)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        print(f"Epoch {epoch+1}, Loss: {loss.item()}")

0
Epoch 1, Loss: 1.4348863363265991
1
Epoch 2, Loss: 1.2060030698776245
2
Epoch 3, Loss: 0.994652509689331
3
Epoch 4, Loss: 0.8511300086975098
4
Epoch 5, Loss: 0.760947585105896
5
Epoch 6, Loss: 0.6519553661346436
6
Epoch 7, Loss: 0.6000574827194214
7
Epoch 8, Loss: 0.5378994345664978
8
Epoch 9, Loss: 0.49014610052108765
9
Epoch 10, Loss: 0.435608834028244
10
Epoch 11, Loss: 0.4086250960826874
11
Epoch 12, Loss: 0.3465929925441742
12
Epoch 13, Loss: 0.3181648254394531
13
Epoch 14, Loss: 0.28415828943252563
14
Epoch 15, Loss: 0.2756674587726593
15
Epoch 16, Loss: 0.24407249689102173
16
Epoch 17, Loss: 0.22507092356681824
17
Epoch 18, Loss: 0.2121245414018631
18
Epoch 19, Loss: 0.18708649277687073
19
Epoch 20, Loss: 0.15436141192913055
20
Epoch 21, Loss: 0.14754872024059296
21
Epoch 22, Loss: 0.1400475949048996
22
Epoch 23, Loss: 0.1400957703590393
23
Epoch 24, Loss: 0.11368207633495331
24
Epoch 25, Loss: 0.10511444509029388
25
Epoch 26, Loss: 0.09851361811161041
26
Epoch 27, Loss: 0.092

Evaluation

In [None]:
# Evaluate the model on the test set
model.eval()  # Set the model to evaluation mode
all_preds = []
all_labels = []

with torch.no_grad():  # Disable gradient computation
    for batch in test_dataloader:
        x = batch[0]  # Get the EEG data
        x = x.to(device)  # Move data to the GPU (if available)

        # Forward pass
        enc_output, pred_output, recon_output, momentum_output = model(x, mask=None)  # No masking for evaluation

        # Get predictions (e.g., for classification)
        logits = pred_output  # Use the predictor's output
        probs = torch.softmax(logits, dim=-1)
        preds = torch.argmax(probs, dim=-1)

        # Store predictions and labels
        all_preds.append(preds.cpu())
        all_labels.append(batch[1].cpu())  # Assuming labels are in batch[1]

# Concatenate all predictions and labels
all_preds = torch.cat(all_preds)
all_labels = torch.cat(all_labels)

# Compute evaluation metrics
from sklearn.metrics import confusion_matrix, f1_score, balanced_accuracy_score
import seaborn as sns
import matplotlib.pyplot as plt

# Confusion Matrix
cm = confusion_matrix(all_labels, all_preds)
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')
plt.show()

# Weighted F1 Score
f1 = f1_score(all_labels, all_preds, average='weighted')
print(f"Weighted F1 Score: {f1}")

# Balanced Accuracy (BAC)
bac = balanced_accuracy_score(all_labels, all_preds)
print(f"Balanced Accuracy (BAC): {bac}")