In [1]:
# Mount Google Drive so the dataset and checkpoint paths under /content/drive/MyDrive resolve.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [2]:
!pip install gensim

Collecting gensim
  Downloading gensim-4.4.0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.metadata (8.4 kB)
Downloading gensim-4.4.0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (27.9 MB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/27.9 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.4/27.9 MB[0m [31m103.0 MB/s[0m eta [36m0:00:01[0m[2K   [91m━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━[0m [32m13.6/27.9 MB[0m [31m310.7 MB/s[0m eta [36m0:00:01[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━[0m [32m15.8/27.9 MB[0m [31m176.8 MB/s[0m eta [36m0:00:01[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━[0m [32m26.3/27.9 MB[0m [31m326.2 MB/s[0m eta [36m0:00:01[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m27.9/27.9 MB[0m [31m317.3 MB/s[0m eta [36m0:

In [3]:
!pip install tensorboard
import gensim.downloader as api
import os
import pandas as pd
from torch.utils.data import Dataset
import json
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from pathlib import Path
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

Collecting tensorboard
  Downloading tensorboard-2.20.0-py3-none-any.whl.metadata (1.8 kB)
Collecting tensorboard-data-server<0.8.0,>=0.7.0 (from tensorboard)
  Downloading tensorboard_data_server-0.7.2-py3-none-manylinux_2_31_x86_64.whl.metadata (1.1 kB)
Collecting werkzeug>=1.0.1 (from tensorboard)
  Downloading werkzeug-3.1.4-py3-none-any.whl.metadata (4.0 kB)
Downloading tensorboard-2.20.0-py3-none-any.whl (5.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.5/5.5 MB[0m [31m106.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading tensorboard_data_server-0.7.2-py3-none-manylinux_2_31_x86_64.whl (6.6 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.6/6.6 MB[0m [31m72.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading werkzeug-3.1.4-py3-none-any.whl (224 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m225.0/225.0 kB[0m [31m24.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: werkzeug, tensorboard-data-server

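# Dataset

Each sample pairs one EEG recording (`<UniqueID>.csv`, padded or cropped to 5500 time steps) with its sentence text from `sentence_mapping.csv`.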

In [4]:
class CustomEEGDataset(Dataset):
    """
    Map-style dataset pairing each EEG recording with the sentence that was read.

    ``sentence_mapping`` is a CSV with at least the columns ``UniqueID`` and ``Content``.
    For row ``i`` the EEG is loaded from ``<eeg_path>/<UniqueID>.csv``, oriented to
    ``[n_channels, T]`` and zero-padded / cropped along time to ``pad_len`` samples.

    Args:
        sentence_mapping: path to the mapping CSV.
        eeg_path: directory holding one EEG CSV per ``UniqueID``.
        pad_len: fixed number of time steps per sample (default 5500).
        dtype: dtype of the returned EEG tensor.
        n_channels: expected EEG channel count (default 105).
    """

    def __init__(self, sentence_mapping, eeg_path, pad_len=5500, dtype=torch.float32, n_channels=105):
        self.records = pd.read_csv(sentence_mapping).reset_index(drop=True)
        self.eeg_path = Path(eeg_path)
        self.pad_len = int(pad_len)
        self.dtype = dtype
        self.n_channels = int(n_channels)

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx: int):
        """Return ``{"uuid", "sentence", "eeg"}`` for row ``idx``; ``eeg`` is ``[n_channels, pad_len]``."""
        row = self.records.iloc[idx]
        uuid = row["UniqueID"]
        sentence = row["Content"]

        eeg = self._load_and_pad_eeg(uuid)
        # No batch dimension here; the DataLoader's default collate stacks "eeg"
        # and gathers the "sentence" strings into a list.
        return {
            "uuid": uuid,
            "sentence": sentence,
            "eeg": eeg,
        }

    def _load_and_pad_eeg(self, uuid: str) -> torch.Tensor:
        """
        Load ``<eeg_path>/<uuid>.csv`` as a ``[n_channels, pad_len]`` tensor.

        The CSV may be stored as ``[n_channels, T]`` or ``[T, n_channels]``; the latter is
        transposed. ``pd.read_csv`` consumes the first CSV line as a header row.

        Raises:
            FileNotFoundError: the CSV does not exist.
            ValueError: neither axis of the CSV matches ``n_channels``.
        """
        path = self.eeg_path / f"{uuid}.csv"
        if not path.exists():
            raise FileNotFoundError(f"EEG file not found: {path}")

        arr = pd.read_csv(path).to_numpy(dtype=np.float32)

        # Orient to [channels, time]. Keying on the channel axis (instead of an exact
        # T == 5500) accepts [T, channels] files of any length.
        if arr.shape[0] != self.n_channels and arr.shape[1] == self.n_channels:
            arr = arr.T

        c, t = arr.shape
        if c != self.n_channels:
            raise ValueError(f"Expected {self.n_channels} channels, got {c} in {path}")

        if t < self.pad_len:
            arr = np.pad(arr, ((0, 0), (0, self.pad_len - t)))  # zero-pad at the end of the time axis
        else:
            arr = arr[:, :self.pad_len]  # crop (no-op when t == pad_len)

        return torch.from_numpy(np.ascontiguousarray(arr)).to(self.dtype)  # [n_channels, pad_len]


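# Encoder: EEG → codebook ("codex") tokens

A strided convolution stack turns each 105-channel × 5500-sample recording into 57 tokens of size 512. Self-attention then mixes these tokens. Finally, each token is snapped to its nearest entry in a learned 2048-entry codebook (vector quantization).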

In [5]:
class ConvolutionModel(nn.Module):
    """
    Strided 1-D convolution stack that turns one EEG recording into a token sequence.

    Input : [batch_size, 105, 5500]  (EEG channels, timestamps)
    Output: [batch_size, 57, 512]    (tokens, d_model)

    Time length per layer (floor((L - kernel) / stride) + 1):
    5500 -> 1831 -> 915 -> 457 -> 228 -> 114 -> 57.
    NOTE(review): there is no activation between layers, so the whole stack is a single
    affine map; consider adding non-linearities (would change trained behaviour).
    """

    def __init__(self):
        super().__init__()
        # (in_channels, out_channels, kernel_size, stride)
        downsampling_plan = (
            (105, 64, 10, 3),
            (64, 128, 3, 2),
            (128, 256, 3, 2),
            (256, 512, 3, 2),
            (512, 512, 2, 2),
            (512, 512, 2, 2),
        )
        self.convolutional_model = nn.Sequential(*[
            nn.Conv1d(in_channels=c_in, out_channels=c_out, kernel_size=k, stride=s)
            for c_in, c_out, k, s in downsampling_plan
        ])

    def forward(self, x):
        # conv output is channels-first [B, d_model, num_tokens]; swap to [B, num_tokens, d_model]
        features = self.convolutional_model(x)
        return features.transpose(1, 2)

In [6]:
class AttentionBlock(nn.Module):
    """
    One self-attention head: scaled dot-product attention with learned Q/K/V projections.

    The last dimension of ``x`` (size ``d_model``) is projected to ``d_k`` features by three
    bias-free Linear layers, and the head returns softmax(Q K^T / sqrt(d_k)) V.
    No masking is applied.
    """

    def __init__(self, d_model, d_k):
        super().__init__()
        self.d_k = d_k

        def projection():
            return nn.Linear(d_model, d_k, bias=False)

        # created in Q, K, V order so parameter init and state_dict keys stay stable
        self.W_Q = projection()
        self.W_K = projection()
        self.W_V = projection()

    def forward(self, x):
        queries, keys, values = self.W_Q(x), self.W_K(x), self.W_V(x)
        similarity = torch.matmul(queries, keys.transpose(-2, -1)) / math.sqrt(self.d_k)
        weights = torch.softmax(similarity, dim=-1)  # normalise over the key axis
        return torch.matmul(weights, values)

In [7]:
class MultiHeadAttentionBlock(nn.Module):
    """
    Multi-head self-attention assembled from independent single-head AttentionBlocks.

    Each of the ``heads`` heads projects the input to ``d_k = d_model // heads`` features.
    The head outputs are concatenated back to ``d_model`` features and mixed by a final Linear.

    Args:
        d_model: size of the last input/output dimension.
        heads: number of attention heads.

    Raises:
        ValueError: if ``heads`` is not positive or ``d_model`` is not divisible by
            ``heads``. Without this check, the concatenated heads would not be ``d_model``
            wide and ``output_linear`` would fail on the first forward pass.
    """

    def __init__(self, d_model, heads=8):
        super().__init__()
        if heads <= 0:
            raise ValueError(f"heads must be a positive integer, got {heads}")
        if d_model % heads != 0:
            raise ValueError(f"d_model ({d_model}) must be divisible by heads ({heads})")
        self.heads = heads
        self.d_model = d_model
        # per-head width; d_model must be divisible by heads (checked above)
        self.d_k = d_model // heads

        # nn.ModuleList registers each head so its weights are trained and checkpointed
        self.attentionBlocks = nn.ModuleList([AttentionBlock(d_model, self.d_k) for _ in range(heads)])

        # projects the concatenated heads back to d_model
        self.output_linear = nn.Linear(d_model, d_model)

    def forward(self, x):
        """x: [batch_size, num_tokens, d_model] -> [batch_size, num_tokens, d_model]."""
        head_outputs = [head(x) for head in self.attentionBlocks]  # each [batch_size, num_tokens, d_k]
        return self.output_linear(torch.cat(head_outputs, dim=-1))

In [8]:
class Encoder(nn.Module):
    """
    EEG -> quantized token sequence (VQ-VAE style encoder).

    Pipeline: ConvolutionModel -> MultiHeadAttentionBlock -> nearest-neighbour lookup in a
    learned codebook (``self.codex``).

    Args:
        d_model: token size; also used as the codebook vector size.
        heads: number of attention heads.
        beta: weight of the commitment loss.
        num_codes: number of codebook entries (default 2048).
    """

    def __init__(self, d_model=512, heads=8, beta=0.2, num_codes=2048):
        super().__init__()
        self.conv = ConvolutionModel()
        self.mha = MultiHeadAttentionBlock(d_model=d_model, heads=heads)
        # codebook vectors must have the same width as the attention output they are compared to
        self.codex = nn.Embedding(num_embeddings=num_codes, embedding_dim=d_model)
        # Alias of the codebook weight. Assigning a Parameter attribute registers it again
        # under the name "words", so state_dict contains both "codex.weight" and "words";
        # keep this so previously saved checkpoints still load strictly.
        self.words = self.codex.weight
        self.beta = beta  # commitment weight

    def forward(self, x):
        """
        x: [batch_size, 105, 5500]

        returns:
            z_q_st: [batch_size, 57, d_model]  quantized tokens with straight-through gradients
            vq_loss: scalar, codebook loss + beta * commitment loss
            indices: [batch_size, 57]  chosen codebook indices
        """
        # z_e: continuous, context-aware tokens [B, 57, d_model]
        z_e = self.mha(self.conv(x))

        # Nearest codebook entry per token. argmin has no gradient, so there is no point
        # recording the distance computation in the autograd graph.
        with torch.no_grad():
            distances = torch.cdist(z_e, self.words.unsqueeze(0))  # [B, 57, num_codes]
            indices = distances.argmin(dim=-1)                    # [B, 57]

        z_q = self.words[indices]  # [B, 57, d_model]; gradient flows into the codebook

        # codebook loss: || sg[z_e] - z_q ||^2  (moves codebook entries towards encoder outputs)
        codebook_loss = F.mse_loss(z_q, z_e.detach())
        # commitment loss: || z_e - sg[z_q] ||^2  (keeps encoder outputs near their codes)
        commitment_loss = F.mse_loss(z_e, z_q.detach())
        vq_loss = codebook_loss + self.beta * commitment_loss

        # straight-through estimator: forward value is z_q, gradient w.r.t. z_e is identity
        z_q_st = z_e + (z_q - z_e).detach()

        return z_q_st, vq_loss, indices

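# Decoder: reconstruct the EEG from quantized tokens

The decoder mirrors the encoder. It applies self-attention over the 57 tokens, then transposed convolutions upsample them back to 105 channels × 5500 samples.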

In [9]:
class DeConvolutionModel(nn.Module):
    """
    Transposed-convolution stack that upsamples a token sequence back to raw EEG.

    Input : [B, 512, 57]    (d_model as channels, tokens as length)
    Output: [B, 105, 5500]  (EEG channels, timestamps)

    Each layer's output length is (L_in - 1) * stride + kernel_size:
    57 -> 114 -> 228 -> 457 -> 915 -> 1831 -> 5500.
    NOTE(review): no activation between layers, so the stack is a single affine map overall.
    """

    def __init__(self):
        super().__init__()
        # (in_channels, out_channels, kernel_size, stride)
        upsampling_plan = (
            (512, 512, 2, 2),
            (512, 512, 2, 2),
            (512, 256, 3, 2),
            (256, 128, 3, 2),
            (128, 64, 3, 2),
            (64, 105, 10, 3),
        )
        layers = []
        for c_in, c_out, k, s in upsampling_plan:
            layers.append(nn.ConvTranspose1d(in_channels=c_in, out_channels=c_out, kernel_size=k, stride=s))
        self.deconv = nn.Sequential(*layers)

    def forward(self, x):
        reconstruction = self.deconv(x)
        return reconstruction

class Decoder(nn.Module):
    """
    Reconstructs EEG from a token sequence.

    Self-attention over the tokens, then DeConvolutionModel upsamples to the raw signal:
    [B, 57, 512] -> [B, 105, 5500].
    """

    def __init__(self, d_model=512, heads=8):
        super().__init__()
        self.mha = MultiHeadAttentionBlock(d_model=d_model, heads=heads)  # shape-preserving
        self.deconv = DeConvolutionModel()

    def forward(self, x):
        # transposed convs want channels first: [B, tokens, d_model] -> [B, d_model, tokens]
        channels_first = self.mha(x).transpose(1, 2)
        return self.deconv(channels_first)


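# Word2Vec: sentence → text embeddings

Each sentence goes through four steps:
1. It is split into words.
2. Each word is embedded with pretrained 300-d word2vec vectors; unknown words get a zero vector.
3. The vectors are projected to 512-d.
4. The result is linearly interpolated to 57 rows so it lines up with the EEG token sequence.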

In [10]:
class Word2VecModel(nn.Module):
    """
    Sentence tokens -> fixed-length sequence of projected word2vec embeddings.

    Processing steps:
    1. Look up each token in the pretrained 300-d ``word2vec-google-news-300`` vectors.
       Out-of-vocabulary tokens map to a zero vector.
    2. Project with a trainable ``Linear(300, d_model)``.
    3. Linearly interpolate along the token axis to ``target_len`` rows, so the text
       sequence is as long as the EEG token sequence.

    Args:
        d_model: output embedding size (default 512).
        target_len: number of output rows (default 57).
    """

    def __init__(self, d_model=512, target_len=57):
        super().__init__()
        self.w2v = api.load('word2vec-google-news-300')  # gensim KeyedVectors (not an nn.Module)
        self.linear = nn.Linear(300, d_model)
        # vector for unknown words; plain attribute, so not part of state_dict
        self.unk_vector = torch.zeros(300)
        self.target_len = target_len

    def forward(self, sentence_tokens):
        """
        sentence_tokens: list of tokens (strings); an empty list is treated as one unknown word.
        returns: [target_len, d_model] tensor on the device of ``self.linear``.
        """
        device = self.linear.weight.device
        unk = self.unk_vector.cpu().numpy()

        rows = [self.w2v[word] if word in self.w2v else unk for word in sentence_tokens]
        if not rows:
            # torch.stack / np.stack reject empty input; fall back to a single unknown token
            rows = [unk]

        # stack on the host and transfer once instead of one device copy per word
        emb = torch.as_tensor(np.stack(rows), dtype=torch.float32, device=device)  # [num_words, 300]
        enh_emb = self.linear(emb)                                                   # [num_words, d_model]

        # Interpolate along the word axis to match the EEG token count.
        z_t_interpolated = F.interpolate(
            enh_emb.transpose(0, 1).unsqueeze(0),  # [1, d_model, num_words]
            size=self.target_len,
            mode='linear',
            align_corners=True,
        ).squeeze(0).transpose(0, 1)  # [target_len, d_model]

        return z_t_interpolated


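# Training loop

The per-batch loss has three parts:
- the EEG reconstruction MSE;
- the VQ loss (codebook + commitment);
- a contrastive loss between mean-pooled EEG tokens and mean-pooled text embeddings.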

In [11]:
def nt_xent_loss(out, target, temperature=0.07):
    """
    InfoNCE-style contrastive loss between two batches of paired embeddings.

    Row i of ``out`` is the positive match for row i of ``target``. Every other row of
    ``target`` in the batch acts as a negative. Both inputs are L2-normalised along dim 1,
    so the logits are cosine similarities divided by ``temperature``. Only the
    out -> target direction is scored, as a cross-entropy over each row of the
    similarity matrix.

    Args:
        out: expected [B, D] embeddings.
        target: expected [B, D] embeddings, row-aligned with ``out``.
        temperature: smaller values sharpen the softmax and strengthen gradients.

    Returns:
        Scalar mean cross-entropy loss.
    """
    similarity = F.normalize(out, dim=1) @ F.normalize(target, dim=1).T
    logits = similarity / temperature
    # the positive for row i sits on the diagonal, i.e. column i
    positives = torch.arange(logits.size(0), device=logits.device)
    return F.cross_entropy(logits, positives)

In [12]:
def save_checkpoint(epoch, encoder, decoder, w2v, optimizer, loss, path):
    """
    Save model/optimizer state for ``epoch`` to ``path``.

    Parent directories are created if missing. The file is first written to ``<path>.tmp``
    and then moved into place with ``os.replace``, so an interrupted save (e.g. a dropped
    Colab session) does not leave a truncated checkpoint under the final name.

    The saved dict has keys: epoch, encoder_state, decoder_state, w2v_state,
    optimizer_state, loss.
    """
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)

    state = {
        "epoch": epoch,
        "encoder_state": encoder.state_dict(),
        "decoder_state": decoder.state_dict(),
        "w2v_state": w2v.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "loss": loss,
    }
    tmp_path = target.with_name(target.name + ".tmp")
    torch.save(state, tmp_path)
    os.replace(tmp_path, target)
    print(f"✔ Saved checkpoint at epoch {epoch} to {path}")

In [13]:
def load_checkpoint(path, encoder, decoder, w2v, optimizer=None):
    """
    Restore encoder/decoder/w2v (and optionally optimizer) state from a checkpoint file.

    Tensors are mapped to CPU on load; ``load_state_dict`` then copies them into the
    modules' existing parameters.

    Returns:
        (epoch, loss) as stored in the checkpoint.
    """
    state = torch.load(path, map_location="cpu")

    for module, key in ((encoder, "encoder_state"), (decoder, "decoder_state"), (w2v, "w2v_state")):
        module.load_state_dict(state[key])

    if optimizer is not None:
        optimizer.load_state_dict(state["optimizer_state"])

    print(f"✔ Loaded checkpoint from {path}, epoch {state['epoch']}")
    return state["epoch"], state["loss"]


In [14]:
def train_one_epoch(epoch_index, tb_writer, alpha=None):
    """
    Run one optimisation pass over ``dataloader`` and return the mean total loss.

    Uses the notebook-level ``encoder``, ``decoder``, ``w2v``, ``optimizer``,
    ``dataloader`` and ``device`` from the configuration cell.

    Per-batch loss:
        MSE(decoder(z_q), eeg) + vq_loss + alpha * nt_xent(mean(z_q), mean(text_embeddings))

    Args:
        epoch_index: epoch number, used to build the TensorBoard global step.
        tb_writer: SummaryWriter (or None) that receives scalars every 50 batches.
        alpha: weight of the contrastive term; ``None`` uses the notebook-level ``ALPHA``.

    Returns:
        Mean of the per-batch total loss (0.0 if the dataloader is empty).
    """
    if alpha is None:
        alpha = ALPHA

    encoder.train()
    decoder.train()

    num_batches = len(dataloader)
    running_loss = 0.0

    for i, data in enumerate(dataloader):
        # batch dict from CustomEEGDataset: "sentence" is a list of B strings,
        # "eeg" is a stacked FloatTensor [B, 105, pad_len]
        sentences = data["sentence"]
        eeg = data["eeg"].to(device)

        z_q, vq_loss, _ = encoder(eeg)  # z_q: [B, 57, 512]
        predictions = decoder(z_q)      # [B, 105, 5500]; assumes pad_len == 5500
        reconstruction_loss = F.mse_loss(predictions, eeg)

        # Whitespace tokenisation; each w2v call returns [57, 512]
        text_embeddings = torch.stack([w2v(s.split()) for s in sentences], dim=0)  # [B, 57, 512]

        # mean-pool tokens to one vector per sample before the contrastive loss
        contrastive_loss = nt_xent_loss(z_q.mean(dim=1), text_embeddings.mean(dim=1))

        loss = reconstruction_loss + vq_loss + alpha * contrastive_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        running_loss += loss_value
        if tb_writer is not None and i % 50 == 0:
            step = epoch_index * num_batches + i
            tb_writer.add_scalar("train/loss_total", loss_value, step)
            tb_writer.add_scalar("train/loss_recon", reconstruction_loss.item(), step)
            tb_writer.add_scalar("train/loss_contrast", contrastive_loss.item(), step)
            tb_writer.add_scalar("train/loss_vq", vq_loss.item(), step)

    return running_loss / max(1, num_batches)

In [None]:
# --- paths ---
sentence_mapping = '/content/drive/MyDrive/EEG_dataset/dataset/sentence_mapping.csv'
eeg_path = '/content/drive/MyDrive/EEG_dataset/dataset'
CHECKPOINT_DIR = Path('/content/drive/MyDrive/EEG_dataset/checkpoints')
writer = SummaryWriter('runs/eeg_training_experiment_1')

# --- hyperparams ---
BATCH_SIZE = 32
EPOCHS = 50   # total epochs; epoch indices run 0 .. EPOCHS - 1
LR = 1e-3
ALPHA = 1.0   # weight for contrastive loss
SEED = 42

torch.manual_seed(SEED)  # reproducible weight init and DataLoader shuffling
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# models
encoder = Encoder().to(device)
decoder = Decoder().to(device)
w2v = Word2VecModel().to(device)

# dataset & dataloader
dataset = CustomEEGDataset(
    sentence_mapping=sentence_mapping,
    eeg_path=eeg_path,
)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# optimizer
optimizer = torch.optim.Adam(
    list(encoder.parameters()) +
    list(decoder.parameters()) +
    list(w2v.parameters()),
    lr=LR
)


# OPTIONAL: resume from an existing checkpoint
resume = True
resume_path = str(CHECKPOINT_DIR / "epoch_26.pt")  # example

# First epoch index this run trains. epoch_N.pt is written after epoch N finishes,
# so a resumed run continues at N + 1 and a fresh run starts at 0.
start_epoch = 0
if resume:
    last_epoch, last_loss = load_checkpoint(
        path=resume_path,
        encoder=encoder,
        decoder=decoder,
        w2v=w2v,
        optimizer=optimizer,
    )
    start_epoch = last_epoch + 1
    print(f"Resuming from epoch {last_epoch} with loss = {last_loss:.4f}")

CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
for epoch in range(start_epoch, EPOCHS):
    avg_loss = train_one_epoch(epoch_index=epoch, tb_writer=writer)

    print(f"Epoch [{epoch + 1}/{EPOCHS}] - loss: {avg_loss:.4f}")
    save_checkpoint(
        epoch=epoch,
        encoder=encoder,
        decoder=decoder,
        w2v=w2v,
        optimizer=optimizer,
        loss=avg_loss,
        path=str(CHECKPOINT_DIR / f"epoch_{epoch}.pt"),
    )

writer.close()  # flush pending TensorBoard events


✔ Loaded checkpoint from /content/drive/MyDrive/EEG_dataset/checkpoints/epoch_26.pt, epoch 26
Resuming from epoch 26 with loss = 23462915.9986
