In [None]:
# !pip install google.generativeai

In [1]:
import scipy
import numpy as np
import os
import pandas as pd
from sklearn.model_selection import train_test_split
# import google.generativeai as genai
import json
import tqdm
from google.api_core import exceptions
import time
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
import random
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from tqdm import tqdm

# Project 5: EEG-to-Text Summarization Using Foundation Models


# To DO

**GoalðŸŽ¯**

The goal of this project is to develop an intelligent system capable of summarizing EEG recordings into concise, human-readable text.
Students will explore how foundation models and deep learning can transform complex time-series brain signals into meaningful descriptions or clinical-style reports.

The ultimate objective is to bridge the gap between raw neurophysiological data and human interpretation, demonstrating how AI can support neurological diagnostics or cognitive research.

<!-- ![image.png](attachment:image.png) -->

**ðŸ§©Tasks for Students**

â€¢ Review recent research on EEG foundation models and EEG-to-text generation (e.g., EEGPT, NeuroLM, GLIM).

â€¢ Understand how signal summarization and multimodal modeling (EEG + text) are implemented.

â€¢ Identify and preprocess a suitable EEG dataset (e.g., TUH EEG Corpus, Sleep-EDF, or ZuCo).

â€¢ Apply filtering, normalization, and segmentation techniques using MNE-Python or similar tools.

â€¢ Implement or fine-tune an EEG encoder (using existing pretrained models or a small Transformer/CNN).

â€¢ Combine it with a text decoder or classifier (e.g., T5-small, LLaMA-3, or GPT-style model).

â€¢ Train the system to generate or summarize textual descriptions of EEG signals.

â€¢ Quantitatively assess summaries using NLP metrics (BLEU, ROUGE).

â€¢ Qualitatively evaluate interpretability (are the generated summaries meaningful?).

â€¢ Optionally visualize which EEG regions/time windows influence the summary.

â€¢ Build a simple web-based demo (Streamlit or Gradio) showing EEG input and generated text output.

â€¢ Write a final report summarizing methods, experiments, and findings.


**References:**

- EEG-To-text: https://arxiv.org/abs/2505.17099

- Enhancing EEG-to-Text Decoding through Transferable Representations from Pre-trained Contrastive EEG-Text Masked Autoencoder https://aclanthology.org/2024.acl-long.393

- EEGtoText: Learning to Write Medical Reports from EEG Recordings https://proceedings.mlr.press/v106/biswal19a.html

- NeuroLM: A Universal Multi-task Foundation Model for Bridging the Gap between Language and EEG Signals https://arxiv.org/abs/2409.00101

# Problem statement

[Learning Interpretable Representations Leads to Semantically Faithful EEG-to-Text Generation](https://arxiv.org/abs/2505.17099)

# Architecture - overview (for previous version with paraphrases):

<img src="arch_plan.png" alt="Arch" width="600">

# Architecture - Actual

<img src="arch_actual.png" alt="Arch" width="2000">

# Dataset: ZuCo (Zurich Cognitive Language Processing Corpus)

[The ZuCo benchmark on cross-subject reading task classification with EEG and eye-tracking data](https://www.frontiersin.org/journals/psychology/articles/10.3389/fpsyg.2022.1028824/full)

ZuCo dataset is a benchmark for sumultaneous eye-tracking and EEG recordings during natual reading.
The dataset contains 128-channel EEG recordings sampled at 500Hz during English sentence reading.

Key features:

- modality: high-density EEG signals

- reading paradigms:

    - Normal Reading (NR): passive, natural reading of sentences

    - Task Specific Reading (TSR): active reading where subjects must determine specific semantic relations or answers comprehension questions

- corpora: sentence are drawn from:
    - Wikipedia (relation extraction)
    - Stanford Sentiment Treebank (sentiment analysis)

[ZuCo 1.0](https://osf.io/q3zws/files/osfstorage) (for now I'm using only this one and only task1)

[ZuCo 2.0](https://osf.io/2urht/files/osfstorage)

## EEG Preprocessing

Each sample in `sentenceData.rawData` consists of a raw EEG matrix with dimensions: `n_channels` $\times$ `n_timepoints`.
The original signal is sampled at 500Hz.

We apply a standardized preprocessing pipeline to normalize these dimensions for the model:

- Downsampling `500Hz -> 128Hz`

    The sampling frequency is reduced from 500Hz to 128Hz

- Padding

    - Time: Sequences are padded or truncated to a fixed length of **1280 points** (~10s)

    - Channels: The spatial dimension is zero-padded **from 104 to 128 channels** to match th encoder architecture


In [None]:
CONFIG = {
    'target_frequency': 128,
    'sequence_length': 1280,
    'target_channels': 128,
    'excluded_channel': 105,
    'dataset_path': './dataset_zuco',
    'processed_path': './processed_zuco'
}

class ZuCoProcessor:
    def __init__(self, config):
        self.config = config

    def load_mat_file(self, file_path):
        '''
        Load a .mat file and return its contents.
        '''
        data = []
        # dataset_name = 'ZuCo1.0' if 'ZuCo1.0' in file_path else 'ZuCo2.0'
        if 'SR' in file_path:
            task = 'SR'
        elif 'NR' in file_path:
            task = 'NR'
        else:
            task = 'TSR'

        print(file_path)
        mat = scipy.io.loadmat(file_path,
                               squeeze_me=True,
                               struct_as_record=False,
                               variable_names=['sentenceData'])
        sentences = mat['sentenceData']

        for sentence in sentences:
            content = sentence.content
            eeg_data = sentence.rawData

            if not np.isnan(eeg_data).all() and eeg_data.shape[0] > 0:
                # final_eeg, normalized_eeg = self.process_eeg_signal(eeg_data)
                data.append({
                    'text': content,
                    'eeg_data': eeg_data,
                    'task': task,
                    'subject': os.path.basename(file_path).split('_')[0].replace('results', '')
                    # 'final_eeg': final_eeg,
                    # 'normalized_eeg': normalized_eeg
                })
        return data


    def process_eeg_signal(self, eeg_signal, original_freq=500):
        ''''
        Process the EEG signal to match target frequency and channels.
        '''
        # print(f'Original EEG shape: {eeg_signal.shape}')
        ###### 1. Nan Handling ######
        # Remove Channels with NaN Values
        eeg_signal = eeg_signal[~np.isnan(eeg_signal).any(axis=1)]
        # Input Nans with 0 (if any left)
        eeg_signal = np.nan_to_num(eeg_signal)

        ###### 2. Downsampling ###### 500Hz -> 128Hz
        downsample_factor = original_freq // self.config['target_frequency']
        eeg_resampled = scipy.signal.decimate(eeg_signal, downsample_factor, axis=1, zero_phase=True)
        # print(f'Downsampled EEG shape: {eeg_resampled.shape}')

        ###### 3. Time Padding ###### to fixed length 1280 samples
        c, t = eeg_resampled.shape
        padded_eeg = np.zeros((c, self.config['sequence_length']))
        if t >= self.config['sequence_length']:
            padded_eeg = eeg_resampled[:, :self.config['sequence_length']]
        else:
            padded_eeg[:, :t] = eeg_resampled
        # print(f'(Time) Padded EEG shape: {padded_eeg.shape}')

        ###### 4. Channel Padding ######
        c, t = padded_eeg.shape
        final_eeg = np.zeros((self.config['target_channels'], t))

        if c >= self.config['target_channels']:
            final_eeg = padded_eeg[:self.config['target_channels'], :]
        else:
            final_eeg[:c, :] = padded_eeg
        # print(f'(Channel) Padded EEG shape: {final_eeg.shape}')

        ##### 5. Normalization Z-score ######
        mean = np.mean(final_eeg, axis=1, keepdims=True)
        std = np.std(final_eeg, axis=1, keepdims=True)
        normalized_eeg = (final_eeg - mean) / (std + 1e-6)

        return final_eeg, normalized_eeg

    def create_dataset(self):
        eeg_dir = os.path.join(self.config["processed_path"], "eeg_files")

        all_samples = []
        idx = 0

        for root, dirs, files in os.walk(self.config['dataset_path']):
            for file in files:
                if file.endswith('.mat') and file.startswith('results'):
                    full_path = os.path.join(root, file)
                    print(f'Processing file: {full_path}')

                    samples = self.load_mat_file(full_path)
                    for sample in samples:
                        processed_eeg, normalized_eeg = self.process_eeg_signal(sample['eeg_data'])
                        file_name = f"{sample['task']}_{sample['subject']}_{idx}.npy"
                        file_path = os.path.join(eeg_dir, file_name)

                        # save files with normalized eeg data
                        np.save(file_path, normalized_eeg)

                        all_samples.append({
                            'text': sample['text'],
                            # 'eeg_data': sample['eeg_data'],
                            # 'eeg_processed': processed_eeg,
                            'eeg_path': os.path.join('eeg_files', file_name),
                            'task': sample['task'],
                            'subject': sample['subject']
                        })
                        idx+=1
        print(f"Preprocessed and normalized EEG data saved in {eeg_dir}")

        self.split_and_save(all_samples)

    def split_and_save(self, all_samples):
        '''
        Split the dataset into train, val, test and save as CSV files.
        '''

        # unique sentences
        sentences = list(set([sample['text'] for sample in all_samples]))

        # sentences split
        train_texts, test_texts = train_test_split(sentences, test_size=0.2, random_state=42)
        val_texts, test_texts = train_test_split(test_texts, test_size=0.5, random_state=42)

        splits = {
            'train': [],
            'val': [],
            'test': []
        }

        for sample in all_samples:
            if sample['text'] in set(train_texts):
                splits['train'].append(sample)
            elif sample['text'] in set(val_texts):
                splits['val'].append(sample)
            else:
                splits['test'].append(sample)

        print(f"Train samples: {len(splits['train'])}"
              f", Val samples: {len(splits['val'])}"
              f", Test samples: {len(splits['test'])}")

        # Save splits to CSV
        for split_name, split_data in splits.items():
            df_split = pd.DataFrame(split_data)
            save_path = os.path.join(self.config['processed_path'], f'{split_name}_dataset.csv')
            df_split.to_csv(save_path, index=False)
            print(f'Saved {split_name} dataset to {save_path}')

In [None]:
# processor = ZuCoProcessor(CONFIG)
# processor.create_dataset()

Processing file: ./dataset_zuco\task1-SR\resultsZAB_SR.mat
./dataset_zuco\task1-SR\resultsZAB_SR.mat
Processing file: ./dataset_zuco\task1-SR\resultsZDM_SR.mat
./dataset_zuco\task1-SR\resultsZDM_SR.mat
Processing file: ./dataset_zuco\task1-SR\resultsZDN_SR.mat
./dataset_zuco\task1-SR\resultsZDN_SR.mat
Processing file: ./dataset_zuco\task1-SR\resultsZGW_SR.mat
./dataset_zuco\task1-SR\resultsZGW_SR.mat
Processing file: ./dataset_zuco\task1-SR\resultsZJM_SR.mat
./dataset_zuco\task1-SR\resultsZJM_SR.mat
Processing file: ./dataset_zuco\task1-SR\resultsZJN_SR.mat
./dataset_zuco\task1-SR\resultsZJN_SR.mat
Processing file: ./dataset_zuco\task1-SR\resultsZJS_SR.mat
./dataset_zuco\task1-SR\resultsZJS_SR.mat
Processing file: ./dataset_zuco\task1-SR\resultsZKB_SR.mat
./dataset_zuco\task1-SR\resultsZKB_SR.mat
Processing file: ./dataset_zuco\task1-SR\resultsZKH_SR.mat
./dataset_zuco\task1-SR\resultsZKH_SR.mat
Processing file: ./dataset_zuco\task1-SR\resultsZKW_SR.mat
./dataset_zuco\task1-SR\resultsZ

## Data Augmentation Strategy - Multiple Text Variants (MTV) - exclude for now

MTV is the crucial component of the GLIM approach.

For every sentence, multiple paraphrased versions are generated using LLM.

These variants **preserve the core meaning** but **vary in syntax and lexical** choice.

During training, the model learns to map a single EEG signal to these diverse text variants, **forcing it to learn abstract semantics** rather than memorizing specific word sequences.

In [None]:
# with open('API_key.txt', 'r') as file:
#     API_KEY = file.read().strip()

# genai.configure(api_key=API_KEY)
# model = genai.GenerativeModel('gemini-2.5-flash',
#                               generation_config={
#                                   "response_mime_type": "application/json"
#                               })

In [None]:
# def get_paraphrase(sentence): # TODO: update prom -> processing in batches
#     """
#     Generate 5 distinct paraphrases for the given sentence using a language model.
#     """
#     prompt = f"""
#     Target sentence: "{sentence}"

#     Task: Generate 5 distinct paraphrases of the target sentence.
#     Requirements:
#     - Each paraphrase must convey the same meaning as the target sentence.
#     - Use varied vocabulary and sentence structures.
#     - Output MUST be in JSON format as a list of strings.

#     Example Output:
#     ["Paraphrase 1", "Paraphrase 2", "Paraphrase 3", "Paraphrase 4", "Paraphrase 5"]
#     """
#     while True:
#         try:
#             response = model.generate_content(prompt)
#             return json.loads(response.text)
#         except exceptions.ResourceExhausted as e:
#             print("Rate limit exceeded. Waiting 60 seconds...")
#             time.sleep(60)
#         except Exception as e:
#             print(f"Error generating paraphrases: {e}")
#             return []

# def generate_mtv(df):
#     """
#     Generate a dictionary mapping each unique sentence in the dataframe
#     to its list of 5 paraphrases.
#     """
#     # df = pd.read_csv(df_path)
#     mtv_dict = {}

#     unique_sentences = df['text'].unique()

#     if os.path.exists('mtv_dict.json'):
#         with open('mtv_dict.json', 'r') as f:
#             mtv_dict = json.load(f)

#     sentences_to_process = [s for s in unique_sentences if s not in mtv_dict]
#     for sentence in sentences_to_process:
#         paraphrases = get_paraphrase(sentence)
#         if paraphrases:
#             mtv_dict[sentence] = paraphrases
#         else:
#             print(f"Failed to generate paraphrases for: {sentence}")

#     with open('mtv_dict.json', 'w') as f:
#         json.dump(mtv_dict, f, indent=4)

In [None]:
# df_train = pd.read_csv('./processed_zuco/train_dataset.csv')
# df_val = pd.read_csv('./processed_zuco/val_dataset.csv')
# df_test = pd.read_csv('./processed_zuco/test_dataset.csv')

# for df in [df_train, df_val, df_test]:
#     generate_mtv(df)

# The GLIM (Generative Language Inspection Model) Architecture

To enable effective joint training, GLIM uses promptbased domain adaptation across three factors:

- reading paradigm (task),
- dataset version (dataset),
- subject identity (subject).

Among them, the task prompt is particularly important, motivated by the distinct cognitive processes in NR versus TSR, reflected in the consistent differences in reading time.

(However for now I have downloaded and processed only data for ZuCo1.0 for task1: SR and dont implement prompt injection)

In [2]:
class ZuCoDataset(Dataset):
    def __init__(self, csv_path, root_dir, mtv_path=None, tokenizer_name="google/flan-t5-small", phase="train", max_len=64):
        self.data = pd.read_csv(csv_path)
        self.root_dir = root_dir
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.phase = phase
        self.max_len = max_len

        if mtv_path and os.path.exists(mtv_path):
            with open(mtv_path, 'r') as f:
                self.mtv_dict = json.load(f)
        else:
            self.mtv_dict = None

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):

        # file for given index
        row = self.data.iloc[index]

        normalized_eeg_path = row['eeg_path'].replace('\\', '/')
        full_path = os.path.join(self.root_dir, normalized_eeg_path)
        eeg_array = np.load(full_path)
        eeg_tensor = torch.from_numpy(eeg_array).float()

        original_text = row["text"]

        # Data Augmentation - MTV - for now we're using only original text
        if self.phase == 'train' and self.mtv_dict and original_text in self.mtv_dict:
            # randomizing text variations prevents the model from overfitting to identical inputs across epochs. This forces the model to capture semantic meaning rather than memorizing syntax.
            variants = self.mtv_dict[original_text]
            target_text = random.choice(variants)
        else:
            target_text = original_text # test: deterministic

        # sentence tokenization for language model
        encoding = self.tokenizer(
            target_text,
            padding='max_length',
            truncation=True,
            max_length=self.max_len,
            return_tensors='pt'     # PyTorch
        )

        return {
            'eeg': eeg_tensor,                                          # EEG signal tensor
            'input_ids': encoding['input_ids'].squeeze(),               # tokenized text converted to IDs (target for the LLM Decoder), squeeze->to remove batch dimension added by tokenizer (1, 64)->(64)
            'attention_mask': encoding['attention_mask'].squeeze(),     # binary mask indicating real tokens (1) vs padding tokens (0) - essential for the model to ignore padding during attention calculation
            'subject': row['subject'],
            'task': row['task']
        }

In [None]:
# # Dataset intances
# train_dataset = ZuCoDataset(
#     csv_path='./processed_zuco/train_dataset.csv',
#     root_dir = './processed_zuco/',
#     mtv_path='mtv_dict.json',
#     phase='train'
# )

# val_dataset = ZuCoDataset(
#     csv_path='./processed_zuco/val_dataset.csv',
#     root_dir = './processed_zuco/',
#     mtv_path='mtv_dict.json',
#     phase='test'
# )

# # DataLoaders
# train_loader = DataLoader(
#     train_dataset,
#     batch_size=8,
#     shuffle=True
# )

# val_loader = DataLoader(
#     val_dataset,
#     batch_size=8,
#     shuffle=False
# )

# # TEST
# batch = next(iter(train_loader))
# print(f"EEG Shape: {batch['eeg'].shape}")
# print(f"Tokens Shape: {batch['input_ids'].shape}")

## EEG Enoder

**Role:**

The EEGEncoder serves as the promary feature extraction module within the architecture.

Objective: transform raw, high-dimensional EEG signals into a robust, dense latent representation.

This module acts as the interface betweeen the brain signals and the neural network's embedding space.

**Mechanism:**

The encoder employs a hybrid architecture combinig CNNs and Transformers:

- _Temporal Convolution_:

    The initial stage utilizes a 1D Conv layer to process the raw input along the temporal dimension.
    This layer performs 2 critical functions:
    - projects the discrete physical channels (128 electrodes) into a higher-dimensional feature space (d_model - for t5-small: 512)
    - captures local temporal patterns, reducing signal noise

- _Self-Attention Mechanism_:

    Following convolution, a standard Transformer Encoder stack processes the sequence. Utilizing the Multi-Head Self-Attention mechanism, this component captures long-range dependencies across the entire time series (1280 points), allowing the model to understand the global context of the brain activity beyond immediate local fluctuations.


In [3]:
class EEGEncoder(nn.Module):
    """
    Processes raw EEG signals into feature vectors compatible with the LLM
    Input Shape: (Batch, 128 channels, 1280 time points)
    """
    def __init__(self, d_model=512, nhead=8, num_layers=6):
        super().__init__()

        # Initial Layer Conv1D
        # In: 128 EEG channels
        # Out: d_model dimensions (512 for t5-small)
        self.conv = nn.Conv1d(in_channels=128, out_channels=d_model, kernel_size=3, padding=1)

        # Transformer Encoder
        # Captures long-range dependencies in the EEG signal
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

    def forward(self, x):
        # Input: [Batch, 128, 1280]
        x = self.conv(x)
        # Output: [Batch, d_model, 1280]

        # Swap dimensions for Transformer (Batch, Time, Features)
        x = x.permute(0, 2, 1)
        # Output: [Batch, 1280, d_model]

        output = self.transformer_encoder(x)
        return output

## Query Aligner

**Role:**

The QAligner functions as a semantic bridge between the EEG Encoder and the LLM by translating continous EEG features into a sequence of token-like embeddings that are compatible with LLM's input space.

**Mechanism:**
This module relies on a Cross-Attention mechanism driven by learnable parameters:

- _Learnable Queries:_

    - the model initialized a fixed set of latent vectors
    - these vectors act as "questions" that the network learns to ask the EEG signal to extract relevant information

- _Cross-Attention:_

    Multi-Head Cross-Attention layer:

    - Query [Q] - Learnable Queries

    - Key [K], Value [V] - output of the EEG Encoder

    This operation compress the long-temporal sequence of the EEG (1280 points) into a compact sequence of semantic vectors, effectively filtering out noise and retaining only the information necessary for text generation.

- _Dimension Alignment:_

    A linear projection layer nsures that the feature dimension matches the embeddingg size of the target LLM

In [4]:
class QAligner(nn.Module):
    """
    Cross-Attention mechanism to bridge the gap between the EEG Encoder and the LLM:
    Compress the long temporal sequence of EEG features into a short, fixed numbers of summary tokens.
    """
    def __init__(self, eeg_dim, llm_dim, num_queries=64):
        super().__init__()
        self.num_queries = num_queries

        # Input Projection
        # Projects EEG features to match the LLM dimension before attention
        self.input_proj = nn.Linear(eeg_dim, llm_dim)

        # LEarnable Queries
        # Vectors that network learns to optimzie - buckets that gather specific information from the EEG signal
        self.query_tokens = nn.Parameter(torch.randn(1, num_queries, llm_dim))

        # Cross-Attention Mechanism
        # query = Learnable Queries (what we want to know)
        # key/value = EEG Features (the source data)
        self.cross_attn = nn.MultiheadAttention(embed_dim=llm_dim, num_heads=8, batch_first=True)

        # Normalization
        self.norm = nn.LayerNorm(llm_dim)

    def forward(self, eeg_feats):
        """
        eeg_feats: Tensor of shape [Batch, 1280, eeg_dim]
        """
        batch_size = eeg_feats.shape[0]

        # Project EEG features to LLM dimension
        # Shape: [Batch, 1280, llm_dim]
        keys_values = self.input_proj(eeg_feats)

        # Expand Learnable Queries to match the batch size
        # Shape: [Batch, num_queries, llm_dim]
        queries = self.query_tokens.expand(batch_size, -1, -1)

        # Apply Cross-Attention
        # The queries attend to the EEG sequence to extract relevant info.
        # attn_output shape: [Batch, num_queries, llm_dim]
        attn_output, _ = self.cross_attn(
            query=queries,
            key=keys_values,
            value=keys_values
        )

        return self.norm(attn_output) # Tensor of shape [Batch, num_queries, llm_dim]

## GLIM Model

**Role:**

The GLIMModel represents the complete, end-to-end architecture for the brain-to-text decoding task. It orchestrates the flow of information from biological signals to natural language generation. The design follows a transfer learning paradigm, leveraging the pre-existing linguistic knowledge of LLM while training a domain-specific adapter for brain signals.

**Mechanism:**

The architecture integrates the previously described modules with a frozen Large Language Model backbone (Flan-T5):

- _Frozen Backbone Strategy:_ The parameters of the pre-trained Flan-T5 model are frozen - made non-trainable to preserve its generalized language modeling capabilities and reduce computational costs.

- _Soft Prompt Injection:_ Instead of discrete text tokens, the aligned EEG embeddings produced by the QAligner are injected directly into the LLM's embedding space via the **inputs_embeds** interface. This guides the LLM to generate the corresponding textual description.

- _Encoder-Decoder Framework_: The entire system functions as a composite Encoder-Decoder model, where the EEGEncoder and QAligner replace the standard text encoder, enabling the direct translation of neural activity into coherent English text

In [5]:
class GLIMModel(nn.Module):
    """
    The main architecture connecting the brain signal processing (EEG)
    with the language generation capabilities (LLM).

    Pipeline:
    Raw EEG -> EEGEncoder -> QAligner (Cross-Attn) -> Frozen Flan-T5 -> Text
    """
    def __init__(self, pretrained_model_name="google/flan-t5-small", num_queries=64):
        super().__init__()

        # Load the pre-trained Language Model (Decoder)
        print(f"Loading LLM: {pretrained_model_name}")
        self.t5 = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model_name)

        # Detect the embedding size of the loaded LLM
        llm_dim = self.t5.config.d_model
        print(f"LLM embedding dimension: {llm_dim}")

        # FREEZE the LLM parameters
        # Train the adapter layers: Encoder + QAligner
        # The LLM stays frozen to preserve its pre-trained knowledge
        for param in self.t5.parameters():
            param.requires_grad = False

        # Initialize EEG Encoder
        self.eeg_dim = 1024
        self.eeg_encoder = EEGEncoder(d_model=self.eeg_dim)

        # Initialize Q-Aligner
        self.q_aligner = QAligner(eeg_dim=self.eeg_dim, llm_dim=llm_dim, num_queries=num_queries)

    def forward(self, eeg_batch, input_ids, attention_mask=None):
        """
        eeg_batch: Raw EEG [Batch, 128, 1280]
        input_ids: Tokenized target text [Batch, Seq_Len]
        """

        # Encode raw EEG signals into high-level features
        # Input: [Batch, 128, 1280] -> Output: [Batch, 1280, 1024]
        raw_eeg_feats = self.eeg_encoder(eeg_batch)

        # Align and Compress features using Cross-Attention
        # Input: [Batch, 1280, 1024] -> Output: [Batch, 64, llm_dim]
        aligned_embeds = self.q_aligner(raw_eeg_feats)

        # Generate Text
        # Inject the brain embeddings directly into the LLM.
        outputs = self.t5(
            inputs_embeds=aligned_embeds,   # The "Prompt" derived from Brain Signals
            labels=input_ids,               # The Target Answer (Text)
        )

        return outputs # Seq2SeqLMOutput containing loss and logits

# Training

In [None]:
def train_epoch(model, dataloader, optimizer, device):
    model.train()
    total_loss = 0

    progress_bar = tqdm(dataloader, desc="Training")

    for batch in progress_bar:

        eeg = batch['eeg'].to(device)
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)

        # Reset gradients
        optimizer.zero_grad()

        # Forward Pass
        outputs = model(
            eeg_batch=eeg,
            input_ids=input_ids,
            attention_mask=attention_mask
        )

        loss = outputs.loss

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        progress_bar.set_postfix({'loss': loss.item()})

    return total_loss / len(dataloader)

def validate(model, dataloader, device):
    model.eval()
    total_loss = 0

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Validation"):
            eeg = batch['eeg'].to(device)
            input_ids = batch['input_ids'].to(device)

            outputs = model(eeg_batch=eeg, input_ids=input_ids)
            loss = outputs.loss
            total_loss += loss.item()

    return total_loss / len(dataloader)

In [6]:
# for colab and processed_zuco.zip on GoogleDrive
from google.colab import drive
drive.mount('/content/drive')

!cp "/content/drive/MyDrive/processed_zuco.zip" .

!unzip -q processed_zuco.zip

import os
if os.path.exists('./processed_zuco'):
    print("Dane gotowe")
else:
    print("CoÅ› poszÅ‚o nie tak")

Mounted at /content/drive
Dane gotowe


In [None]:
BATCH_SIZE = 4
LEARNING_RATE = 1e-4
EPOCHS = 10
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SAVE_DIR = "./results"

# 1. Create directory for checkpoints
if not os.path.exists(SAVE_DIR):
    os.makedirs(SAVE_DIR)

print(f"Using device: {DEVICE}")

# 2. Data Preparation
print("Loading data...")
train_dataset = ZuCoDataset(
    csv_path='./processed_zuco/train_dataset.csv',
    root_dir='./processed_zuco',
    mtv_path='./processed_zuco/mtv_dict.json',
    phase='train'
)
val_dataset = ZuCoDataset(
    csv_path='./processed_zuco/val_dataset.csv',
    root_dir='./processed_zuco',
    mtv_path='./processed_zuco/mtv_dict.json',
    phase='test'
)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

# 3. Model Initialization
model = GLIMModel(pretrained_model_name="google/flan-t5-small")
model.to(DEVICE)

# 4. Optimizer
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = AdamW(trainable_params, lr=LEARNING_RATE)

print(f"Number of trainable parameters: {sum(p.numel() for p in trainable_params)}")

best_val_loss = float('inf')

for epoch in range(EPOCHS):
    print(f"\n Epoch {epoch+1}/{EPOCHS}")

    # Training Phase
    train_loss = train_epoch(model, train_loader, optimizer, DEVICE)
    print(f"Average Train Loss: {train_loss:.4f}")

    # Validation Phase
    val_loss = validate(model, val_loader, DEVICE)
    print(f"Average Val Loss: {val_loss:.4f}")

    # Save the best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), os.path.join(SAVE_DIR, "best_model.pth"))
        print("Saved best_model.pth")

    torch.save(model.state_dict(), os.path.join(SAVE_DIR, "last_checkpoint.pth"))

Using device: cuda
Loading data...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json: 0.00B [00:00, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json: 0.00B [00:00, ?B/s]

Loading LLM: google/flan-t5-small


config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/308M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

LLM embedding dimension: 512
Number of trainable parameters: 52402688

 Epoch 1/10


Training: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 908/908 [12:54<00:00,  1.17it/s, loss=14]


Average Train Loss: 21.8043


Validation: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 114/114 [00:26<00:00,  4.28it/s]


Average Val Loss: 24.4876
Saved best_model.pth

 Epoch 2/10


Training: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 908/908 [13:05<00:00,  1.16it/s, loss=7.87]


Average Train Loss: 13.4059


Validation: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 114/114 [00:26<00:00,  4.31it/s]


Average Val Loss: 8.0976
Saved best_model.pth

 Epoch 3/10


Training: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 908/908 [13:05<00:00,  1.16it/s, loss=6.58]


Average Train Loss: 6.8513


Validation: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 114/114 [00:27<00:00,  4.22it/s]


Average Val Loss: 7.6728
Saved best_model.pth

 Epoch 4/10


Training: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 908/908 [13:18<00:00,  1.14it/s, loss=5.99]


Average Train Loss: 6.3920


Validation: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 114/114 [00:27<00:00,  4.20it/s]


Average Val Loss: 6.8974
Saved best_model.pth

 Epoch 5/10


Training: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 908/908 [13:10<00:00,  1.15it/s, loss=5.78]


Average Train Loss: 6.2329


Validation: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 114/114 [00:26<00:00,  4.27it/s]


Average Val Loss: 6.9373

 Epoch 6/10


Training: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 908/908 [13:03<00:00,  1.16it/s, loss=6.33]


Average Train Loss: 6.3317


Validation: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 114/114 [00:26<00:00,  4.23it/s]


Average Val Loss: 9.4587

 Epoch 7/10


Training: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 908/908 [12:58<00:00,  1.17it/s, loss=7.15]


Average Train Loss: 6.3133


Validation: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 114/114 [00:24<00:00,  4.61it/s]


Average Val Loss: 5.9023
Saved best_model.pth

 Epoch 8/10


Training: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 908/908 [13:00<00:00,  1.16it/s, loss=6.26]


Average Train Loss: 6.1161


Validation: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 114/114 [00:25<00:00,  4.46it/s]


Average Val Loss: 8.2332

 Epoch 9/10


Training: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 908/908 [12:57<00:00,  1.17it/s, loss=6.34]


Average Train Loss: 6.0484


Validation: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 114/114 [00:24<00:00,  4.62it/s]


Average Val Loss: 5.8593
Saved best_model.pth

 Epoch 10/10


Training: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 908/908 [12:58<00:00,  1.17it/s, loss=6.15]


Average Train Loss: 6.0852


Validation: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 114/114 [00:25<00:00,  4.48it/s]


Average Val Loss: 7.0485


# Results demonstration

In [26]:
drive.mount('/content/drive', force_remount=True)
!cp "/content/drive/MyDrive/eeg/best_model.pth" .

Mounted at /content/drive


In [27]:
BATCH_SIZE = 1
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
CHECKPOINT_PATH = "./results/best_model.pth"
MODEL_NAME = "google/flan-t5-small"

def predict(model, dataloader, tokenizer, device):
    model.eval()

    with torch.no_grad():
        for i, batch in enumerate(dataloader):
            if i >= 5:
                break

            eeg = batch['eeg'].to(device)
            target_text_ids = batch['input_ids']

            # GENERATION
            # Get EEG Features from Encoder
            # Shape: [1, 1280, 1024]
            raw_eeg_feats = model.eeg_encoder(eeg)

            # Align Features using Q-Aligner
            # Shape: [1, 64, 768]
            aligned_embeds = model.q_aligner(raw_eeg_feats)

            # Generate Text using T5
            output_ids = model.t5.generate(
                inputs_embeds=aligned_embeds,
                max_length=64,
                num_beams=4,
                early_stopping=True
            )

            # DECODING
            # Convert token IDs back to human-readable string
            predicted_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
            actual_text = tokenizer.decode(target_text_ids[0], skip_special_tokens=True)

            print(f"Example {i+1}:")
            print(f"PREDICTED: {predicted_text}")
            print(f"ACTUAL:    {actual_text}")


tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

test_dataset = ZuCoDataset(
    csv_path='./processed_zuco/test_dataset.csv',
    root_dir='./processed_zuco',
    phase='test'
)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

model = GLIMModel(pretrained_model_name=MODEL_NAME)
model.to(DEVICE)

state_dict = torch.load("best_model.pth", map_location=DEVICE, weights_only=False)
model.load_state_dict(state_dict)

predict(model, test_loader, tokenizer, DEVICE)

Loading LLM: google/flan-t5-small
LLM embedding dimension: 512
Example 1:
PREDICTED: secol confession confession confession confession confession confession confession confession confession confession confession confession confession confession confession confession confession confession confession confession confession confession confession confession confession confession confession confession confession confession confession confession confession confession confession confession confession confession confession confession confession confession confession confession confession confession confession confession confession confession confession confession confession confession confession confession confession confession confession confession confession
ACTUAL:    Ultimately feels emp11111ty and unsatisfying, like swallowing a Communion wafer without the wine.
Example 2:
PREDICTED: secol confession confession confession confession confession confession confession confession confession co