In [1]:
import os

import torch

import random

import numpy as np

import pandas as pd

import mne

In [2]:
participants_file = 'dataset/participants.tsv'
participants_df = pd.read_csv(participants_file, sep='\t')

subject_ids = participants_df['participant_id'].tolist()

In [3]:
import torch.nn.functional as F

from scipy.signal import resample

from torch.utils.data import DataLoader

from sklearn.preprocessing import StandardScaler

class EEGDataLoader(torch.utils.data.Dataset):
    def __init__(self, dir, list_IDs, standardize=True):
        """
        Args:
            dir (str): Directory where the processed EEG chunks are stored.
            list_IDs (list): List of subject paths relative to the base directory.
            sampling_rate (int): Target sampling rate for resampling the data.
        """
        self.dir = dir
        self.list_IDs = list_IDs  # List of subject directories
        self.standardize = standardize
        self.label_map = {
            "A": 0,
            "C": 1,
            "F": 2
        }
        # Gather all chunk file paths from the subject directories
        self.chunk_paths = []
        for subject_id in list_IDs:
            subject_dir = os.path.join(dir, subject_id)
            self.chunk_paths.extend(
                [
                    os.path.join(subject_id, file)
                    for file in os.listdir(subject_dir)
                    if file.endswith(".pt")
                ]
            )
        self.chunk_paths.sort()
        # If we are standardizing, initialize a StandardScaler
        if self.standardize:
            self.scaler = StandardScaler()

    def __len__(self):
        """Returns the total number of chunks available."""
        return len(self.chunk_paths)

    def __getitem__(self, idx):
        """
        Args:
            idx (int): Index to retrieve a chunk and its label.

        Returns:
            X (torch.Tensor): EEG data as a tensor.
            y (int): Label corresponding to the chunk.
        """
        # Load the .pt file for the given chunk
        chunk_path = os.path.join(self.dir, self.chunk_paths[idx])
        sample = torch.load(chunk_path)

        # Extract data and labels
        eeg_data = sample["data"]
        label = sample["label"]
        # Resample if necessary
        # eegsample = resample(eeg_data, int(eeg_data.shape[-1] * self.sampling_rate / self.default_rates), axis=-1)


        # Apply standard scaling to each channel independently
        if self.standardize:
            # Standardize each channel independently (along the time points)
            eeg_data = self.scaler.fit_transform(eeg_data.T).T

        # # Normalize the data
        # eegsample = eegsample / (
        #     np.quantile(
        #         np.abs(eegsample), q=0.95, interpolation="linear", axis=-1, keepdims=True
        #     )
        #     + 1e-8
        # )
        
        # Convert to PyTorch tensor
        eegsample = torch.from_numpy(eeg_data).float()

        # Extract the target label (e.g., Group)
        X = eegsample
        y = self.label_map[label]  # Adjust this if the target label changes
        return X, y

# Example usage
# Path to dataset and labels
data_dir = "dataset/derivatives/processed_dataset"

# Create dataset and DataLoader
dataset = EEGDataLoader(data_dir, subject_ids)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

# Iterate through the DataLoader
# for batch in dataloader:
#     print(f"Batch Shape: {batch[0].shape}")
#     break


def collate_fn(batch):

    samples, labels = [], []

    for i, l in batch:

        samples.append(i)

        labels.append(l)

    samples = torch.stack(samples, dim = 0)

    labels = torch.tensor(labels)

    batch = samples

    return batch, labels

In [4]:
for batch in dataloader:
    print(f"Batch Shape: {batch[1]}")
    break

  sample = torch.load(chunk_path)
  sample = torch.load(chunk_path)
  sample = torch.load(chunk_path)
  sample = torch.load(chunk_path)
  sample = torch.load(chunk_path)
  sample = torch.load(chunk_path)
  sample = torch.load(chunk_path)
  sample = torch.load(chunk_path)
  sample = torch.load(chunk_path)
  sample = torch.load(chunk_path)
  sample = torch.load(chunk_path)
  sample = torch.load(chunk_path)
  sample = torch.load(chunk_path)
  sample = torch.load(chunk_path)
  sample = torch.load(chunk_path)
  sample = torch.load(chunk_path)
  sample = torch.load(chunk_path)
  sample = torch.load(chunk_path)
  sample = torch.load(chunk_path)
  sample = torch.load(chunk_path)
  sample = torch.load(chunk_path)
  sample = torch.load(chunk_path)
  sample = torch.load(chunk_path)
  sample = torch.load(chunk_path)
  sample = torch.load(chunk_path)
  sample = torch.load(chunk_path)
  sample = torch.load(chunk_path)
  sample = torch.load(chunk_path)
  sample = torch.load(chunk_path)
  sample = tor

  sample = torch.load(chunk_path)
  sample = torch.load(chunk_path)
  sample = torch.load(chunk_path)
  sample = torch.load(chunk_path)
  sample = torch.load(chunk_path)
  sample = torch.load(chunk_path)
  sample = torch.load(chunk_path)
  sample = torch.load(chunk_path)
  sample = torch.load(chunk_path)
  sample = torch.load(chunk_path)
  sample = torch.load(chunk_path)
  sample = torch.load(chunk_path)
  sample = torch.load(chunk_path)
  sample = torch.load(chunk_path)
  sample = torch.load(chunk_path)


Batch Shape: tensor([1, 1, 1, 1, 0, 2, 0, 1, 1, 1, 2, 1, 0, 2, 2, 1, 1, 1, 0, 2, 1, 0, 2, 1,
        1, 0, 1, 0, 0, 2, 1, 0, 1, 2, 2, 1, 1, 2, 0, 2, 0, 1, 2, 2, 2, 0, 0, 1,
        0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 2])


  sample = torch.load(chunk_path)
  sample = torch.load(chunk_path)
  sample = torch.load(chunk_path)
  sample = torch.load(chunk_path)
  sample = torch.load(chunk_path)
  sample = torch.load(chunk_path)
  sample = torch.load(chunk_path)
  sample = torch.load(chunk_path)
  sample = torch.load(chunk_path)
  sample = torch.load(chunk_path)
  sample = torch.load(chunk_path)
  sample = torch.load(chunk_path)
  sample = torch.load(chunk_path)
  sample = torch.load(chunk_path)
  sample = torch.load(chunk_path)
  sample = torch.load(chunk_path)
  sample = torch.load(chunk_path)
  sample = torch.load(chunk_path)


In [5]:
raw = sample = torch.load('dataset/derivatives/processed_dataset/sub-001/sub-001_task-eyesclosed_eeg_chunk_037.pt')
raw['data'].shape

  raw = sample = torch.load('dataset/derivatives/processed_dataset/sub-001/sub-001_task-eyesclosed_eeg_chunk_037.pt')


(19, 3000)

In [82]:
import math
import torch
from momentfm import MOMENTPipeline
import torch.nn as nn
from linear_attention_transformer import LinearAttentionTransformer


class ClassificationHead(nn.Sequential):
    def __init__(self, emb_size, n_classes):
        super().__init__()
        self.clshead = nn.Sequential(
            nn.ELU(),
            nn.Linear(emb_size, n_classes),
        )

    def forward(self, x):
        out = self.clshead(x)
        return out


class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 1000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        """
        Args:
            x: `embeddings`, shape (batch, max_len, d_model)
        Returns:
            `encoder input`, shape (batch, max_len, d_model)
        """
        x = x + self.pe[:, : x.size(1)]
        return self.dropout(x)

def create_patch(xb, patch_len, stride, max_sequence_length):
    """
    xb: [bs x n_vars x seq_len ]
    """

    seq_len = max_sequence_length if max_sequence_length is not None else xb.shape[2]

    mask = torch.ones(xb.shape)

    num_patch = math.ceil((max(seq_len, patch_len) - patch_len) / stride) + 1

    tgt_len = patch_len + stride * (num_patch -1)

    pd = tgt_len - seq_len

    pad1 = (0, pd)

    xb = F.pad(xb, pad1, "constant", 0)

    mask = F.pad(mask, pad1, "constant", 0)

    xb = xb.unfold(dimension=-1, size=patch_len, step=stride)                 # xb: [bs x n_vars x num_patch x patch_len]

    mask = mask.unfold(dimension=-1, size=patch_len, step=stride)                 # xb: [bs x n_vars x num_patch x patch_len]

    return xb, mask

    

class MomentEEG(nn.Module):
    def __init__(
        self,
        emb_size=512,
        n_channels=16,
        patch_len=512,
        stride=512,
        heads=8,
        depth=1,
        max_sequence_length=2560,
        **kwargs
    ):
        super().__init__()

        self.tokenizer = MOMENTPipeline.from_pretrained(
            "AutonLab/MOMENT-1-small",
            model_kwargs={'task_name': 'embedding', 'reduction': 'mean'},
        )
        self.tokenizer.init()

        self.tokenizer.eval()

        self.transformer = LinearAttentionTransformer(
            dim=emb_size,
            heads=heads,
            depth=depth,
            max_seq_len=1024,
            attn_layer_dropout=0.2,
            attn_dropout=0.2,
        )

        self.patch_len = patch_len
        self.stride = stride
        self.max_sequence_length = max_sequence_length

    def forward(self, x, perturb=False):
        """
        x: [batch_size, channel, num_patch, ts]
        output: [batch_size, emb_size]
        """
        x, m = create_patch(x, patch_len=self.patch_len, stride=self.stride, max_sequence_length=self.max_sequence_length)
        emb_seq = []

        for i in range(x.shape[1]):
            xb = x[:, i: i + 1, :, :].squeeze(1)
            mb = m[:, i: i + 1, :, :].squeeze(1)
            bs, num_patch, patch_len = xb.shape
            xb = torch.reshape(xb, (bs * num_patch, 1, patch_len))
            mb = torch.reshape(mb, (bs * num_patch, 1, patch_len))

            with torch.no_grad():
                xb = self.tokenizer(x_enc=xb.detach(), input_mask=mb.squeeze(1).detach(), mask=mb.detach()).embeddings
            # print(xb.shape)
            xb = torch.reshape(xb, (bs, num_patch, -1))

            # Save embeddings before applying channel embedding and positional encoding
            emb_seq.append(xb)

        # Return the raw embeddings
        emb = torch.cat(emb_seq, dim=1)  # (batch_size, 16 * ts, emb)
        # emb = self.transformer(emb)  # (batch_size, 16 * ts, emb)
        # emb = emb.mean(dim=1)  # (batch_size, emb)
        # print(emb.shape)
        return emb  # Return embeddings for reuse in MomentClassifier




In [86]:
# Set random seed for reproducibility
seed = 2024
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)

# Define root directories
data_dir = "dataset/derivatives/processed_dataset"
labels_file = "dataset/participants.tsv"

all_data_loader = DataLoader(
        EEGDataLoader(data_dir, subject_ids),
        batch_size=1,
        shuffle=False,
        drop_last=False,
        num_workers=2,
        persistent_workers=True,
        collate_fn = collate_fn
    )

print(len(all_data_loader))

4521


In [98]:
output_dir = 'dataset/derivatives/embeddings'
if not os.path.exists(output_dir):
    os.makedirs(output_dir)

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")

print(device)
model = MomentEEG(512, n_channels=19, max_sequence_length = 3000).to(device)

i = 0
for X, y in all_data_loader:
    batch_embedd_eeg=model(X.to(device))
    sample_emb = {
                "data": batch_embedd_eeg,
                "label": y
            }
    output_file_name = all_data_loader.dataset.chunk_paths[i]
    subject_id = output_file_name.split('/')[0]
    output_file_path = os.path.join(output_dir, output_file_name)
    if not os.path.exists(os.path.join(output_dir, subject_id)):
        os.makedirs(os.path.join(output_dir, subject_id))
    print(output_file_path)
    print(batch_embedd_eeg.shape)
    torch.save(sample_emb, output_file_path)
    i += 1

cpu




torch.Size([1, 114, 512])
dataset/derivatives/embeddings/sub-001/sub-001_task-eyesclosed_eeg_chunk_000.pt
torch.Size([1, 114, 512])
torch.Size([1, 114, 512])
dataset/derivatives/embeddings/sub-001/sub-001_task-eyesclosed_eeg_chunk_001.pt
torch.Size([1, 114, 512])
torch.Size([1, 114, 512])
dataset/derivatives/embeddings/sub-001/sub-001_task-eyesclosed_eeg_chunk_002.pt
torch.Size([1, 114, 512])
torch.Size([1, 114, 512])
dataset/derivatives/embeddings/sub-001/sub-001_task-eyesclosed_eeg_chunk_003.pt
torch.Size([1, 114, 512])
torch.Size([1, 114, 512])
dataset/derivatives/embeddings/sub-001/sub-001_task-eyesclosed_eeg_chunk_004.pt
torch.Size([1, 114, 512])
torch.Size([1, 114, 512])
dataset/derivatives/embeddings/sub-001/sub-001_task-eyesclosed_eeg_chunk_005.pt
torch.Size([1, 114, 512])
torch.Size([1, 114, 512])
dataset/derivatives/embeddings/sub-001/sub-001_task-eyesclosed_eeg_chunk_006.pt
torch.Size([1, 114, 512])
torch.Size([1, 114, 512])
dataset/derivatives/embeddings/sub-001/sub-001_tas

In [97]:
raw = sample = torch.load('dataset/derivatives/embeddings/sub-001/sub-001_task-eyesclosed_eeg_chunk_000.pt')
raw['data'].shape

  raw = sample = torch.load('dataset/derivatives/embeddings/sub-001/sub-001_task-eyesclosed_eeg_chunk_000.pt')


torch.Size([1, 114, 512])

In [91]:
class MomentClassifier(nn.Module):
    def __init__(self, emb_size=512, n_channels=16, n_classes=6, **kwargs):
        super().__init__()
        self.n_channels = n_channels
        self.emb_size = emb_size

        # Channel embedding and positional encoding will now be handled here
        self.positional_encoding = PositionalEncoding(emb_size)

        # Initialize learnable channel embeddings
        self.channel_tokens = nn.Embedding(n_channels, emb_size)
        self.index = nn.Parameter(torch.LongTensor(range(n_channels)), requires_grad=False)

        # If you want to add a classification head, you can do that here
        self.classifier = nn.Linear(emb_size, n_classes)

    def forward(self, x, perturb=False, saved_embeddings=None):
        """
        x: [batch_size, channel, num_patch, ts]
        saved_embeddings: Pre-computed embeddings that bypass the `MomentEEG` forward pass.
        """

        # Apply channel embedding and positional encoding here
        batch_size, ts, _ = x.shape
        channel_emb = []

        for i in range(self.n_channels):
            # Channel token embedding (repeat across time steps)
            channel_token_emb = (
                self.channel_tokens(self.index[i])
                .unsqueeze(0)
                .unsqueeze(0)
                .repeat(batch_size, ts, 1)
            )

            # Add positional encoding to the embeddings
            emb_with_channel_pos = self.positional_encoding(x + channel_token_emb)
            channel_emb.append(emb_with_channel_pos)

        # Stack embeddings from all channels and average them
        emb = torch.cat(channel_emb, dim=1)  # (batch_size, 16 * ts, emb)


        # (batch_size, emb)
        emb = emb.mean(dim=1)
        
        # Optionally, add a classification head here (not added in this code snippet)
        emb = self.classifier(emb)
        return emb


In [92]:
model = MomentClassifier(n_channels=19, n_classes=3)

In [93]:
out = model(raw['data'])

In [94]:
out

tensor([[ 0.3287,  0.1476, -0.3343]], grad_fn=<AddmmBackward0>)

In [102]:
participants_file = 'dataset/participants.tsv'
participants_df = pd.read_csv(participants_file, sep='\t')

subject_ids = participants_df['participant_id'].tolist()

# Perform LOSO cross-validation splits
loso_splits = []
for test_subject in subject_ids:
    # Test set is the current subject
    test_set = [test_subject]

    # Randomly select 6 subjects for validation
    validation_subjects = random.sample(subject_ids, 6)
    while test_subject in validation_subjects:
        validation_subjects = random.sample(subject_ids, 6)
        
    train_set = []
    # Training set is all other subjects
    for subject in subject_ids:
        if subject != test_subject and subject not in validation_subjects:# and int(subject[-3:]) <=3:
            train_set.append(subject)
    # train_set = [subject for subject in subject_ids if subject != test_subject]

    # Append the split to the list
    loso_splits.append({'train': train_set, 
                        'val': validation_subjects,
                        'test': test_set})
# for i, split in enumerate(loso_splits):
#     print(f"Fold {i + 1}:")
#     print(f"  Train: {split['train']}")
#     print(f"  Test: {split['test']}")
#     print()

In [103]:
# Set random seed for reproducibility
seed = 2024
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)

# Define root directories
data_dir = "dataset/derivatives/embeddings"
labels_file = "dataset/participants.tsv"

train_loader = DataLoader(
        EEGDataLoader(data_dir, loso_splits[0]['train']),
        batch_size=64,
        shuffle=True,
        drop_last=False,
        num_workers=2,
        persistent_workers=True,
        collate_fn = collate_fn
    )

val_loader = DataLoader(
        EEGDataLoader(data_dir, loso_splits[0]['val']),
        batch_size=64,
        shuffle=True,
        drop_last=False,
        num_workers=2,
        persistent_workers=True,
        collate_fn = collate_fn
    )

test_loader = DataLoader(
    EEGDataLoader(data_dir, loso_splits[0]['test']),
    batch_size=64,
    shuffle=True,
    drop_last=False,
    num_workers=2,
    persistent_workers=True,
    collate_fn = collate_fn
)

print(len(train_loader))

print(len(train_loader))

66
66


# TEST

In [104]:
# define the model
model_name = "Moment"
if model_name == "Moment":
    model = MomentClassifier(512, n_classes = 3,n_channels=19)
else:
    raise NotImiplementedError
lightning_model = LitModel_finetune(model,args={
        "lr":0.01,
        "weight_decay":1e-5,
        "gamma":0.1,
        "n_steps":100,
    })

# logger and callbacks
version = "moment"
logfolder = "log"
logger = TensorBoardLogger(
    save_dir="./",
    version=version,
    name=logfolder,
)
# early_stop_callback = EarlyStopping(
#     monitor="val_f1", patience= args.patience, verbose=False, mode="max"
# )

checkpoint_callback = ModelCheckpoint(save_top_k = 0,
                                      monitor = "epoch",
                                      mode = "max",
                                      save_last = True
                                        )

tqdm_progress_bar = TQDMProgressBar(refresh_rate= 20, process_position=0)

trainer = pl.Trainer(
    # devices=1,  # Set devices to an integer instead of a list when using CPU
    devices=[0],  # Use list format only when specifying GPUs
    accelerator="auto",
    strategy='ddp_notebook',
    benchmark=True,
    enable_checkpointing=True,
    logger=logger,
    max_epochs=10,
    callbacks= [checkpoint_callback, tqdm_progress_bar], # [early_stop_callback, tqdm_progress_bar],
    log_every_n_steps = 1,
)

# train the model
trainer.fit(
    lightning_model, train_dataloaders=train_loader, val_dataloaders=val_loader
)

# test the model
pretrain_result = trainer.test(
    model=lightning_model, ckpt_path="last", dataloaders=test_loader
)[0]
print(pretrain_result)

NameError: name 'LitModel_finetune' is not defined