<a href="https://colab.research.google.com/github/asma-walha/Machine-Learning/blob/main/eeg.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Set Up the Environment

In [1]:
!nvidia-smi

Mon Feb 17 15:29:24 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   63C    P8             11W /   70W |       0MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [None]:
!free -h

               total        used        free      shared  buff/cache   available
Mem:            12Gi       1.6Gi       9.1Gi       2.0Mi       1.9Gi        10Gi
Swap:             0B          0B          0B


In [2]:
!pip install mne

Collecting mne
  Downloading mne-1.9.0-py3-none-any.whl.metadata (20 kB)
Downloading mne-1.9.0-py3-none-any.whl (7.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.4/7.4 MB[0m [31m90.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: mne
Successfully installed mne-1.9.0


In [3]:
import mne
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR

Preprocess EEG Data


In [4]:
# Load a sample EEG dataset (e.g., the MNE sample dataset)
sample_data_folder = mne.datasets.sample.data_path()
sample_data_raw_file = sample_data_folder / 'MEG' / 'sample' / 'sample_audvis_raw.fif'
raw = mne.io.read_raw_fif(sample_data_raw_file, preload=True)

# Print information about the dataset
print(raw.info)

# Après le chargement des données
print(f"Shape of raw data: {raw.get_data().shape}")

Using default location ~/mne_data for sample...
Creating /root/mne_data


Downloading file 'MNE-sample-data-processed.tar.gz' from 'https://osf.io/86qa2/download?version=6' to '/root/mne_data'.
100%|█████████████████████████████████████| 1.65G/1.65G [00:00<00:00, 1.69TB/s]
Untarring contents of '/root/mne_data/MNE-sample-data-processed.tar.gz' to '/root/mne_data'


Attempting to create new mne-python configuration file:
/root/.mne/mne-python.json
Download complete in 04m37s (1576.2 MB)
Opening raw data file /root/mne_data/MNE-sample-data/MEG/sample/sample_audvis_raw.fif...
    Read a total of 3 projection items:
        PCA-v1 (1 x 102)  idle
        PCA-v2 (1 x 102)  idle
        PCA-v3 (1 x 102)  idle
    Range : 25800 ... 192599 =     42.956 ...   320.670 secs
Ready.
Reading 0 ... 166799  =      0.000 ...   277.714 secs...
<Info | 21 non-empty values
 acq_pars: ACQch001 110113 ACQch002 110112 ACQch003 110111 ACQch004 110122 ...
 bads: 2 items (MEG 2443, EEG 053)
 ch_names: MEG 0113, MEG 0112, MEG 0111, MEG 0122, MEG 0123, MEG 0121, MEG ...
 chs: 204 Gradiometers, 102 Magnetometers, 9 Stimulus, 60 EEG, 1 EOG
 custom_ref_applied: False
 description: acquisition (megacq) VectorView system at NMR-MGH
 dev_head_t: MEG device -> head transform
 dig: 146 items (3 Cardinal, 4 HPI, 61 EEG, 78 Extra)
 events: 1 item (list)
 experimenter: MEG
 file_id: 4

In [5]:
# Preprocess the EEG data
def preprocess_eeg(raw, l_freq=0.1, h_freq=38, resample_rate=256):
    # Filter the data
    raw.filter(l_freq=l_freq, h_freq=h_freq)

    raw.pick_types(eeg=True)  # Ne garde que les canaux EEG

    # Resample the data
    raw.resample(resample_rate)

    # Re-reference to average
    raw.set_eeg_reference(ref_channels='average')

    # Segment into epochs (e.g., 4 seconds)
    epochs = mne.make_fixed_length_epochs(raw, duration=4.0)
    return epochs.get_data()

epochs_data = preprocess_eeg(raw)
print(f"Shape of epochs data: {epochs_data.shape}")
epochs_tensor = torch.tensor(epochs_data, dtype=torch.float32)
print(f"Shape of epochs tensor: {epochs_tensor.shape}")

# Create a dataset and DataLoader
dataset = TensorDataset(epochs_tensor)
batch_size = 32
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.1 - 38 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 0.10
- Lower transition bandwidth: 0.10 Hz (-6 dB cutoff frequency: 0.05 Hz)
- Upper passband edge: 38.00 Hz
- Upper transition bandwidth: 9.50 Hz (-6 dB cutoff frequency: 42.75 Hz)
- Filter length: 19821 samples (33.001 s)



[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.2s
[Parallel(n_jobs=1)]: Done  71 tasks      | elapsed:    0.8s
[Parallel(n_jobs=1)]: Done 161 tasks      | elapsed:    2.0s
[Parallel(n_jobs=1)]: Done 287 tasks      | elapsed:    3.9s


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.
Not setting metadata
69 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 69 events and 1024 original time points ...
0 bad epochs dropped
Shape of epochs data: (69, 59, 1024)
Shape of epochs tensor: torch.Size([69, 59, 1024])


Implement the EEGPT Architecture

In [6]:
class EEGPT(nn.Module):
    def __init__(self, num_channels, num_timepoints, embed_dim, num_layers):
        super(EEGPT, self).__init__()
        self.num_channels = num_channels
        self.num_timepoints = num_timepoints
        self.embed_dim = embed_dim

        # Local spatio-temporal embedding
        self.embedding = nn.Linear(num_timepoints, embed_dim)

        # Transformer encoder with batch_first=True
        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=8, batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Momentum encoder
        self.momentum_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Predictor and reconstructor
        self.predictor = nn.Linear(embed_dim, embed_dim)
        self.reconstructor = nn.Linear(embed_dim, num_timepoints)

    def forward(self, x, mask):
        batch_size, num_channels, num_timepoints = x.shape

        # Apply masking to the input tensor
        masked_x = x * mask  # Shape: (batch_size, num_channels, num_timepoints)

        # Reshape for embedding: (batch_size * num_channels, num_timepoints)
        masked_x = masked_x.view(-1, num_timepoints)

        # Apply local spatio-temporal embedding
        embedded_x = self.embedding(masked_x)  # Shape: (batch_size * num_channels, embed_dim)

        # Reshape back: (batch_size, num_channels, embed_dim)
        embedded_x = embedded_x.view(batch_size, num_channels, self.embed_dim)

        # Encode the masked input
        enc_output = self.encoder(embedded_x)  # Shape: (batch_size, num_channels, embed_dim)

        # Predict the full features
        pred_output = self.predictor(enc_output)  # Shape: (batch_size, num_channels, embed_dim)

        # Reshape for reconstruction: (batch_size * num_channels, embed_dim)
        pred_output_reshaped = pred_output.view(-1, self.embed_dim)

        # Reconstruct the masked parts
        recon_output = self.reconstructor(pred_output_reshaped)  # Shape: (batch_size * num_channels, num_timepoints)

        # Reshape back: (batch_size, num_channels, num_timepoints)
        recon_output = recon_output.view(batch_size, num_channels, self.num_timepoints)

        # Apply embedding to the input for the momentum encoder
        full_x = x.view(-1, num_timepoints)  # Shape: (batch_size * num_channels, num_timepoints)
        embedded_full_x = self.embedding(full_x)  # Shape: (batch_size * num_channels, embed_dim)
        embedded_full_x = embedded_full_x.view(batch_size, num_channels, self.embed_dim)  # Reshape back
        momentum_output = self.momentum_encoder(embedded_full_x)  # Shape: (batch_size, num_channels, embed_dim)

        return enc_output, pred_output, recon_output, momentum_output

In [7]:
# Test avec des données factices
batch_size = 32
num_channels = 59
num_timepoints = 1024
embed_dim = 512

x = torch.randn(batch_size, num_channels, num_timepoints)
mask = torch.rand(batch_size, num_channels, num_timepoints) > 0.5
print(f"Shape of x: {x.shape}")
print(f"Shape of mask: {mask.shape}")

model = EEGPT(num_channels, num_timepoints, embed_dim, num_layers=8)
enc_output, pred_output, recon_output, momentum_output = model(x, mask)

print(f"Shape of enc_output: {enc_output.shape}")  # Doit être (batch_size, num_channels, embed_dim)
print(f"Shape of pred_output: {pred_output.shape}")  # Doit être (batch_size, num_channels, embed_dim)
print(f"Shape of recon_output: {recon_output.shape}")  # Doit être (batch_size, num_channels, num_timepoints)
print(f"Shape of momentum_output: {momentum_output.shape}")  # Doit être (batch_size, num_channels, embed_dim)

Shape of x: torch.Size([32, 59, 1024])
Shape of mask: torch.Size([32, 59, 1024])
Shape of enc_output: torch.Size([32, 59, 512])
Shape of pred_output: torch.Size([32, 59, 512])
Shape of recon_output: torch.Size([32, 59, 1024])
Shape of momentum_output: torch.Size([32, 59, 512])


Implement the Dual Self-Supervised Loss

In [11]:
import torch.nn.functional as F

def dual_self_supervised_loss(enc_output, pred_output, recon_output, momentum_output, mask):
    # Alignment loss
    alignment_loss = nn.MSELoss()(pred_output, momentum_output)

    # Resize recon_output to match enc_output's dimensions
    recon_output_resized = F.interpolate(recon_output.permute(0, 2, 1), size=512, mode='linear', align_corners=False).permute(0, 2, 1)

    # Resize mask to match recon_output_resized's dimensions
    mask_resized = F.interpolate(mask.permute(0, 2, 1), size=512, mode='nearest').permute(0, 2, 1)

    # Reconstruction loss
    reconstruction_loss = nn.MSELoss()(recon_output_resized * mask_resized, enc_output * mask_resized)

    # Total loss
    total_loss = alignment_loss + reconstruction_loss
    return total_loss

In [9]:
# Generate random masks
def generate_mask(x, mask_time_ratio=0.5, mask_channel_ratio=0.8):
    batch_size, num_channels, num_timepoints = x.shape
    time_mask = torch.rand(batch_size, num_timepoints) > mask_time_ratio  # Shape: (batch_size, num_timepoints)
    channel_mask = torch.rand(batch_size, num_channels) > mask_channel_ratio  # Shape: (batch_size, num_channels)
    mask = time_mask.unsqueeze(1) * channel_mask.unsqueeze(2)  # Shape: (batch_size, num_channels, num_timepoints)
    return mask.float()

Train the Model


In [12]:
# Initialize model, optimizer, and scheduler
model = EEGPT(num_channels=376, num_timepoints=1024, embed_dim=512, num_layers=8)
optimizer = AdamW(model.parameters(), lr=2.5e-4)
scheduler = OneCycleLR(optimizer, max_lr=5e-4, total_steps=200)

# Training loop
for epoch in range(200):
    for batch in dataloader:
        x = batch[0]  # Get the EEG data
        mask = generate_mask(x)  # Generate masks with the correct shape
        print(f"Shape of x: {x.shape}")
        print(f"Shape of mask: {mask.shape}")

        # Forward pass
        enc_output, pred_output, recon_output, momentum_output = model(x, mask)
        print(f"Shape of enc_output: {enc_output.shape}")  # Doit être (batch_size, num_channels, embed_dim)
        print(f"Shape of pred_output: {pred_output.shape}")  # Doit être (batch_size, num_channels, embed_dim)
        print(f"Shape of recon_output: {recon_output.shape}")  # Doit être (batch_size, num_channels, num_timepoints)
        print(f"Shape of momentum_output: {momentum_output.shape}")  # Doit être (batch_size, num_channels, embed_dim)

        # Compute loss
        loss = dual_self_supervised_loss(enc_output, pred_output, recon_output, momentum_output, mask)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        print(f"Epoch {epoch+1}, Loss: {loss.item()}")

Shape of x: torch.Size([32, 59, 1024])
Shape of mask: torch.Size([32, 59, 1024])
Shape of enc_output: torch.Size([32, 59, 512])
Shape of pred_output: torch.Size([32, 59, 512])
Shape of recon_output: torch.Size([32, 59, 1024])
Shape of momentum_output: torch.Size([32, 59, 512])


RuntimeError: The size of tensor a (512) must match the size of tensor b (59) at non-singleton dimension 1