In [None]:
import os
import pickle
from functools import partial
from os import path
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from IPython.display import Markdown, clear_output, display
from joblib import Parallel, delayed
from matplotlib import pyplot as plt
from peft import AdaptionPromptConfig, PeftModel, get_peft_model
from torch import nn
from torch.amp import GradScaler, autocast
from torch.utils.data import DataLoader
from toto.inference.embedding import embed
from tqdm.auto import tqdm, trange
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    DataCollatorWithPadding,
)

# Enable the flash backend of PyTorch's scaled-dot-product attention on CUDA.
torch.backends.cuda.enable_flash_sdp(True)
# torch.set_float32_matmul_precision('high')

# Configuration

In [None]:
# Recording session to load and the tags a session must carry to be included.
SESSION_ROOT = "../adalog-sessions/phil"
TAGS = ["eeg", "writing", "intuitive"]
# SESSION_ROOT = "../adalog-sessions/Antoine"
# TAGS = ["AutomaticWriting"]

# EEG epoching: size of the window cut around each word onset.
SFREQ = 256  # Hz
PRE_ONSET_DURATION = 0.25  # seconds
POST_ONSET_DURATION = 0.75  # seconds

# Language model and training hyper-parameters.
MODEL_ID = "meta-llama/Llama-3.2-3B"
DEVICE = "cuda"
BATCH_SIZE = 7
ADAPTER_LEN = 16  # passed as adapter_len to AdaptionPromptConfig
ADAPTER_LAYERS = 8  # passed as adapter_layers to AdaptionPromptConfig
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 0.01
NUM_TRAIN_CONTEXT_WORDS = 20  # words per training window (ctx_words of EEGLMDataset)

# Name of the session folder; used to label the cache and checkpoint files.
uname = Path(SESSION_ROOT).name

# Data loading
### Load EEG + text data

In [None]:
def get_paths(
    root: str, modalities: List[str], tags: Optional[List[str]] = None, file_glob: str = "*"
) -> Dict[str, List[str]]:
    """Collect the files of every session under *root* that has all tags and modalities.

    A session is any directory (at any depth below *root*) that contains a
    ``tags.csv`` file with a ``tags`` column. Each cell of that column may hold
    several tags separated by ``", "``.

    Args:
        root: Directory that is searched recursively for ``tags.csv`` files.
        modalities: Names of the sub-directories a session must contain
            (a single string is accepted as well).
        tags: Tags that must all be present in the session's ``tags.csv``
            (a single string is accepted as well). ``None`` or an empty list
            disables tag filtering.
        file_glob: Glob pattern applied inside each modality directory.

    Returns:
        Mapping from modality name to the sorted list of matching file paths
        (as strings). The mapping is empty when no session qualifies.
    """
    if not isinstance(modalities, list):
        modalities = [modalities]
    if tags is None:
        tags = []
    elif not isinstance(tags, list):
        tags = [tags]
    required_tags = set(tags)
    required_modalities = set(modalities)

    paths: Dict[str, list] = {}
    # search below the *root* argument (not the module-level SESSION_ROOT)
    for tag_path in Path(root).glob("**/tags.csv"):
        session_dir = tag_path.parent

        # skip sessions that do not carry every requested tag
        present_tags = {
            tag
            for cell in pd.read_csv(tag_path)["tags"].tolist()
            if isinstance(cell, str)  # ignore empty (NaN) cells
            for tag in cell.split(", ")
        }
        if not required_tags.issubset(present_tags):
            continue

        # skip sessions that lack one of the requested modality folders
        present_modalities = {p.name for p in session_dir.glob("*") if p.is_dir()}
        if not required_modalities.issubset(present_modalities):
            continue

        # gather the files of each modality
        for modality in modalities:
            paths.setdefault(modality, []).extend(session_dir.glob(f"{modality}/{file_glob}"))
    return {modality: sorted(str(p) for p in found) for modality, found in paths.items()}


# Collect the CSV files of both modalities for every session that carries all TAGS.
paths = get_paths(SESSION_ROOT, ["Text", "Eeg"], TAGS, file_glob="*.csv")
txt_paths = paths["Text"]
eeg_paths = paths["Eeg"]

In [None]:
def timestamp2index(ts: pd.Timestamp, df: pd.DataFrame, max_diff: float = 0.2) -> int:
    """Find the position of the closest timestamp in a DataFrame, within max_diff seconds.

    Args:
        ts: Timestamp to look up.
        df: Frame with a ``timestamp`` column.
        max_diff: Largest accepted distance (in seconds) to the closest sample.

    Returns:
        Positional (``iloc``) index of the row whose timestamp is closest to *ts*.

    Raises:
        ValueError: If *df* holds no valid timestamp or the closest one is
            further away than *max_diff* seconds.
    """
    diffs = (df["timestamp"] - ts).abs()
    closest = diffs.min()
    if pd.isna(closest):
        raise ValueError("No valid timestamps to match against")
    min_diff = closest.total_seconds()
    if min_diff > max_diff:
        raise ValueError(f"No timestamp within {max_diff} seconds (closest: {min_diff:.3f}s)")
    # argmin is positional, so the result is correct for any index labels
    # (a label lookup on the argsort result is not) and needs no full sort
    return int(diffs.argmin())


def process_row(tdf_row: pd.Series, edf: pd.DataFrame) -> Optional[Tuple[str, int, np.ndarray]]:
    """Process a single row of text data and extract the corresponding EEG segment.

    Args:
        tdf_row: Row of the text frame; its ``timestamp`` and ``content``
            fields are read.
        edf: EEG frame with a ``timestamp`` column; all remaining columns are
            returned as signal values.

    Returns:
        ``(word, onset_offset, segment)`` where *onset_offset* is the position
        of the word onset inside the segment and *segment* is a float32 array
        covering ``PRE_ONSET_DURATION`` seconds before to ``POST_ONSET_DURATION``
        seconds after the onset. ``None`` when no segment could be extracted
        (no EEG sample close to the timestamp, window out of bounds, or any
        other error, which is printed instead of raised).
    """
    try:
        # get a word and its timestamp
        ts, word = tdf_row[["timestamp", "content"]]

        # find the corresponding sample indices in the EEG data
        try:
            onset_idx = timestamp2index(ts, edf)
        except ValueError as e:
            print(f"Warning: {e} for word '{word}' at {ts}. Skipping this segment.")
            return None
        pre_idx = onset_idx - int(PRE_ONSET_DURATION * SFREQ)
        post_idx = onset_idx + int(POST_ONSET_DURATION * SFREQ)

        # the slice below is end-exclusive, so post_idx == len(edf) is still a
        # complete window; only reject windows that reach past the recording
        if pre_idx < 0 or post_idx > len(edf):
            print(
                f"Warning: Out of bounds for word '{word}' at {ts}. "
                f"Pre-index: {pre_idx}, Post-index: {post_idx}, Length: {len(edf)}. Skipping this segment."
            )
            return None

        # extract the EEG data around the word onset (without the time column)
        eeg_segment = edf.iloc[pre_idx:post_idx].drop(columns=["timestamp"]).to_numpy(dtype=np.float32)

        expected_length = int((PRE_ONSET_DURATION + POST_ONSET_DURATION) * SFREQ)
        assert (
            len(eeg_segment) == expected_length
        ), f"Expected EEG segment length {expected_length}, but got {len(eeg_segment)} for word '{word}' at {ts}."
        return word, onset_idx - pre_idx, eeg_segment
    except Exception as e:
        # best effort: report and skip the row; use .get so that a missing
        # column cannot raise a second error from inside the handler
        print(f"Error processing word '{tdf_row.get('content')}' at {tdf_row.get('timestamp')}: {e}")
        return None


# Accumulates one (word, onset offset, EEG segment) tuple per usable word.
data = []
print(f"Processing {len(txt_paths)} text + EEG files...")
for txt_file, eeg_file in tqdm(zip(txt_paths, eeg_paths), total=len(txt_paths)):
    # read the word log and the EEG recording of one session
    tdf = pd.read_csv(txt_file, parse_dates=["timestamp"])
    edf = pd.read_csv(eeg_file, parse_dates=["timestamp"])

    # cut one EEG window per word, spread over all available cores
    row_iter = tqdm(tdf.iterrows(), total=len(tdf), leave=False)
    results = Parallel(n_jobs=-1)(delayed(process_row)(row, edf) for _, row in row_iter)

    # keep only the rows for which a segment could be extracted
    valid = [res for res in results if res]
    print(f"Processed {len(results)} words, discarded {len(results) - len(valid)} due to out-of-bounds or errors.")
    data.extend(valid)

print(f"\nTotal number of valid word+EEG pairs: {len(data)}")

### Embed EEG data

In [None]:
# The cache file is keyed by session name and number of word+EEG pairs only.
eeg_segments_path = f"eeg_segments-{uname}-n={len(data)}.pkl"
if not path.exists(eeg_segments_path):
    # no cache yet: embed every segment and store the result next to the raw data
    df = pd.DataFrame(data, columns=["word", "onset_idx", "eeg"])
    df["embedding"] = None
    for i in trange(len(df), desc="Embedding EEG segments"):
        segment = df.at[i, "eeg"]
        # embed the transposed segment and average over the first three dimensions
        embedding = embed(segment.T).mean(dim=(0, 1, 2))
        df.at[i, "embedding"] = embedding.numpy()

    print(f"Saving EEG segments to '{eeg_segments_path}'...")
    with open(eeg_segments_path, "wb") as f:
        pickle.dump(df, f)
else:
    # NOTE(review): pickle.load executes arbitrary code; only load cache files
    # that were written by this notebook.
    print(f"Loading existing EEG segments from '{eeg_segments_path}'...")
    with open(eeg_segments_path, "rb") as f:
        df = pickle.load(f)

# Training

In [None]:
class EEGAdapter(nn.Module):
    """Wraps a causal-LM with an AdaptionPrompt PEFT adapter and a learnable
    linear projection that turns an EEG embedding (768-d by default) into
    *adapter_len* extra prefix tokens.

    The class handles forward logic **and** self-contained save / load routines
    so that the adapter + EEG projection can be checkpointed and restored with
    a single folder.
    """

    def __init__(self, base_model, adapter_cfg: AdaptionPromptConfig, eeg_embed_dim: int = 768):
        """Attach the PEFT adapter to *base_model* and create the EEG projection.

        Args:
            base_model: Causal LM exposing ``config.hidden_size``.
            adapter_cfg: PEFT configuration; its ``adapter_len`` sets the
                number of prefix tokens produced per EEG embedding.
            eeg_embed_dim: Dimensionality of the incoming EEG embedding.
        """
        super().__init__()
        self.base_model = get_peft_model(base_model, adapter_cfg)
        self.hidden_size = base_model.config.hidden_size
        self.adapter_len = adapter_cfg.adapter_len

        # learnable projection from EEG embedding to adapter prefix tokens
        self.eeg_proj = nn.Linear(eeg_embed_dim, self.hidden_size * self.adapter_len)

    def forward(self, input_ids=None, attention_mask=None, eeg_embed=None, past_key_values=None, use_cache=False):
        """If *past_key_values* is *None* (first time step) we inject the EEG
        prefix. On subsequent calls we just pass the tokens to the base
        model together with *past_key_values*.

        Args:
            input_ids: Token ids; the first dimension is the batch size.
            attention_mask: Optional mask for *input_ids*; on the first step
                it defaults to all ones and is extended to cover the prefix.
            eeg_embed: EEG embedding per batch element; required on the first
                step, ignored afterwards.
            past_key_values: Cache returned by a previous call, or ``None``.
            use_cache: Forwarded to the base model.

        Returns:
            The output of the wrapped PEFT model.
        """
        if past_key_values is not None:
            # TODO: we probably always want to pass the eeg_embed, not only on the first step
            return self.base_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                use_cache=use_cache,
            )

        # TODO: we probably always want to pass the eeg_embed, not only on the first step
        assert eeg_embed is not None, "Missing EEG embedding for the first step"
        bs = input_ids.size(0)

        # project EEG embedding to prefix dimension
        prefix = self.eeg_proj(eeg_embed).view(bs, self.adapter_len, self.hidden_size)

        # prepend the EEG prefix to the token embeddings; cast first so that the
        # concatenation does not silently promote to a wider dtype
        tok_emb = self.base_model.base_model.get_input_embeddings()(input_ids)
        inputs_embeds = torch.cat([prefix.to(tok_emb.dtype), tok_emb], dim=1)

        # update attention mask to account for EEG prefix
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        pref_mask = torch.ones(bs, self.adapter_len, dtype=attention_mask.dtype, device=attention_mask.device)
        attention_mask = torch.cat([pref_mask, attention_mask], dim=1)

        # base model forward pass with EEG prefix
        return self.base_model(inputs_embeds=inputs_embeds, attention_mask=attention_mask, use_cache=use_cache)

    def save_pretrained(self, save_directory: str):
        """Write the PEFT adapter and the EEG projection into *save_directory*."""
        os.makedirs(save_directory, exist_ok=True)
        # adapter weights + config
        self.base_model.save_pretrained(save_directory)
        # EEG projection
        torch.save(self.eeg_proj.state_dict(), os.path.join(save_directory, "eeg_proj.bin"))

    @classmethod
    def from_pretrained(cls, base_model, load_directory: str, eeg_embed_dim: int = 768):
        """Restore an adapter written by :meth:`save_pretrained`.

        Args:
            base_model: Causal LM the adapter was trained on.
            load_directory: Folder holding the PEFT adapter and ``eeg_proj.bin``.
            eeg_embed_dim: Dimensionality of the EEG embedding used in training.

        Returns:
            An ``EEGAdapter`` whose EEG projection weights are loaded on the CPU.
        """
        # attach PEFT adapter from load_directory to the base model
        peft_model = PeftModel.from_pretrained(base_model, load_directory)
        peft_cfg = peft_model.peft_config["default"]

        # build the instance without __init__, which would wrap the base model
        # in a second, freshly initialised adapter
        instance = cls.__new__(cls)
        nn.Module.__init__(instance)
        instance.base_model = peft_model
        instance.hidden_size = base_model.config.hidden_size
        instance.adapter_len = peft_cfg.adapter_len
        instance.eeg_proj = nn.Linear(eeg_embed_dim, instance.hidden_size * instance.adapter_len)

        # the file only holds a tensor state dict, so restrict unpickling to
        # tensors instead of allowing arbitrary objects
        eeg_state = torch.load(
            os.path.join(load_directory, "eeg_proj.bin"), map_location="cpu", weights_only=True
        )
        instance.eeg_proj.load_state_dict(eeg_state)
        return instance


# Quantization settings used when loading the base model: 4-bit NF4 weights
# with double quantization, computing in bfloat16.
bnb_cfg = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

In [None]:
class EEGLMDataset(torch.utils.data.Dataset):
    """
    Sliding windows of *ctx_words* consecutive words with one EEG embedding per token.

    Returns only input_ids + per-token EEG; labels will be built in the collator.
    `df` must expose columns 'word' and 'embedding' (one embedding vector per word).
    """

    def __init__(self, df, model_id, ctx_words=30):
        """Tokenize every window up front.

        Args:
            df: Frame with the columns 'word' and 'embedding'.
            model_id: Identifier passed to ``AutoTokenizer.from_pretrained``.
            ctx_words: Number of consecutive words per item.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = "left"

        self.words = df["word"].tolist()
        self.eegs = df["embedding"].tolist()
        self.ctx = ctx_words
        self.items = []

        # NOTE(review): the range stops one window short of the last full
        # window (start index len(words) - ctx_words) - confirm this is intended.
        for st in trange(0, len(self.words) - ctx_words, desc="Building dataset"):
            words = self.words[st : st + ctx_words]
            e_chunk = self.eegs[st : st + ctx_words]

            # Pass the words individually so that word_ids() maps every token to
            # the word it came from. Joining them into one string first makes the
            # tokenizer see a single "word", and every token would get the EEG
            # embedding of the first word. The leading space on all but the first
            # word is meant to keep the tokens those of the space-joined text.
            w_chunk = [w if i == 0 else " " + w for i, w in enumerate(words)]

            enc = self.tokenizer(w_chunk, is_split_into_words=True, add_special_tokens=False, return_attention_mask=False)
            w_ids = enc.word_ids()
            eeg_seq = np.array([e_chunk[w] for w in w_ids])  # repeat per sub-token

            self.items.append(
                {
                    "input_ids": torch.tensor(enc["input_ids"], dtype=torch.long),
                    "eeg_embed": torch.from_numpy(eeg_seq.astype(np.float32)),
                }
            )

    def __len__(self):
        """Number of windows."""
        return len(self.items)

    def __getitem__(self, i):
        """Return the dict with 'input_ids' and 'eeg_embed' of window *i*."""
        return self.items[i]


class EEGCollator:
    """Pads a list of dataset items into one batch and builds the labels."""

    def __init__(self, tokenizer):
        """Create the text padder from *tokenizer* (which decides pad token and side)."""
        self.pad_text = DataCollatorWithPadding(tokenizer, return_tensors="pt")

    def __call__(self, batch):
        """Collate items holding 'input_ids' and 'eeg_embed'.

        Returns:
            The padded text batch extended by
            'labels' - a copy of the input ids with padded positions set to
            -100 (the default ``ignore_index`` of cross-entropy), and
            'eeg_embed' - a float tensor (batch, padded length, embedding dim)
            with the per-token EEG embeddings right-aligned and zeros elsewhere.
        """
        txt = self.pad_text([{"input_ids": b["input_ids"]} for b in batch])
        L = txt["input_ids"].shape[1]
        D = batch[0]["eeg_embed"].shape[-1]

        # build labels; mask padding via the attention mask rather than the pad
        # id, because the pad token may coincide with a real token such as EOS
        labels = txt["input_ids"].clone()
        attention_mask = txt.get("attention_mask")
        if attention_mask is not None:
            labels[attention_mask == 0] = -100
        txt["labels"] = labels

        eeg = torch.zeros(len(batch), L, D)
        for i, b_ in enumerate(batch):
            n = b_["eeg_embed"].shape[0]
            # left-pad; "L - n" (unlike "-n") is also correct for n == 0
            eeg[i, L - n :, :] = b_["eeg_embed"]
        txt["eeg_embed"] = eeg
        return txt

In [None]:
# dataset: sliding word windows, padded and labelled by the collator
ds = EEGLMDataset(df, model_id=MODEL_ID, ctx_words=NUM_TRAIN_CONTEXT_WORDS)
dl = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True, collate_fn=EEGCollator(ds.tokenizer))

# base LLM (quantized according to bnb_cfg)
base_model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map=DEVICE, quantization_config=bnb_cfg)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# reuse EOS as the padding token (the attribute is `pad_token`; a misspelt
# name only creates an unused attribute and leaves the pad token unset)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

# EEG adapter
adapter_cfg = AdaptionPromptConfig(task_type="CAUSAL_LM", adapter_len=ADAPTER_LEN, adapter_layers=ADAPTER_LAYERS)
eeg_adapter = EEGAdapter(base_model, adapter_cfg).to(DEVICE)
# eeg_adapter.compile() # TODO: currently we hit the recompile limit due to the growing past_key_values, maybe increasing the recompile limit is enough

# optimizer, loss scaler for mixed precision, and the per-batch loss history
optimizer = torch.optim.AdamW(eeg_adapter.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
scaler = GradScaler()
losses = []

In [None]:
eeg_adapter.train()

try:
    for epoch in range(1):
        epoch_pbar = tqdm(dl, desc=f"epoch {epoch}")
        for batch in epoch_pbar:
            # next-token prediction: token t is the input, token t + 1 the target
            input_ids = batch["input_ids"].to(DEVICE)
            targets = batch["labels"][:, 1:].to(DEVICE)
            eeg_seq = batch.pop("eeg_embed")[:, :-1].to(DEVICE)
            n_steps = input_ids.size(1) - 1

            eeg_adapter.zero_grad(set_to_none=True)
            kv_cache, loss = None, 0.0

            # feed one token at a time and carry the key/value cache along
            for t in trange(n_steps, leave=False, desc="steps"):
                # TODO: verify whether we only want to pass the eeg_vec on the first step
                first_step = kv_cache is None
                with autocast(device_type=DEVICE):
                    out = eeg_adapter(
                        input_ids=input_ids[:, t : t + 1],
                        eeg_embed=eeg_seq[:, t, :] if first_step else None,
                        past_key_values=kv_cache,
                        use_cache=True,
                    )
                    kv_cache = out.past_key_values

                    # average the per-step cross-entropy over the sequence
                    step_loss = nn.functional.cross_entropy(out.logits[:, -1, :], targets[:, t])
                    loss = loss + step_loss / n_steps

            loss_value = loss.item()
            epoch_pbar.set_postfix({"loss": loss_value})
            losses.append(loss_value)

            # one optimizer update per batch, through the gradient scaler
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
except KeyboardInterrupt:
    print("Training interrupted.")


# loss per optimizer step
fig, ax = plt.subplots()
ax.plot(losses)
ax.set_xlabel("Training Steps")
ax.set_ylabel("Loss")
plt.show()

### Save model checkpoint

In [None]:
# Write the adapter (PEFT weights + EEG projection) and the tokenizer into one folder.
ckpt_dir = f"eeg_llama_adapter-{uname}"
eeg_adapter.save_pretrained(ckpt_dir)
tokenizer.save_pretrained(ckpt_dir)

# Inference

### Load model checkpoint

In [None]:
# Reload the quantized base model, then attach the saved adapter and load the
# tokenizer from the checkpoint folder.
ckpt_dir = f"eeg_llama_adapter-{uname}"
base_model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map=DEVICE, quantization_config=bnb_cfg)
eeg_adapter = EEGAdapter.from_pretrained(base_model, ckpt_dir).eval().to(DEVICE)
tokenizer = AutoTokenizer.from_pretrained(ckpt_dir, padding_side="left")
tokenizer.pad_token = tokenizer.eos_token  # reuse EOS as the padding token

In [None]:
# Token ids that end the generation of a word (see generate_word_from_eeg).
eos_id = tokenizer.eos_token_id
# First token id the tokenizer produces for a single space.
# NOTE(review): this stop criterion assumes the model emits a standalone space
# token between words; tokenizers that attach the space to the following word
# may rarely produce it - confirm for this tokenizer.
space_id = tokenizer(" ", add_special_tokens=False)["input_ids"][0]
eeg_adapter.eval()


def nucleus_sampling(logits, top_p=0.95, temperature=1.0):
    """Apply nucleus (top-p) sampling to logits.

    The nucleus is the smallest set of most probable tokens whose cumulative
    probability reaches *top_p*; one token is drawn from the renormalised
    nucleus.

    Args:
        logits: Tensor whose last dimension runs over the vocabulary
            (1-D, or 2-D with a leading batch dimension).
        top_p: Probability mass the nucleus has to cover.
        temperature: Divides the logits before the softmax. Values <= 0 select
            the most likely token deterministically.

    Returns:
        Sampled token ids, i.e. *logits* without its last dimension.
    """
    if temperature <= 0:
        # limit of temperature -> 0; avoids a division by zero
        return logits.argmax(dim=-1)

    probs = torch.softmax(logits / temperature, dim=-1)
    # Sort probabilities in descending order
    sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
    # Calculate cumulative probabilities
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

    # A token is dropped when the tokens *before* it already cover top_p. The
    # token that crosses the threshold therefore stays in the nucleus; dropping
    # it would leave less than top_p of the mass. Shifting the mask by one
    # position also guarantees that the top token is always kept.
    nucleus_complete = cumulative_probs >= top_p
    to_remove = torch.zeros_like(nucleus_complete)
    to_remove[..., 1:] = nucleus_complete[..., :-1]

    # Zero out removed tokens and renormalize
    kept_probs = sorted_probs.masked_fill(to_remove, 0.0)
    kept_probs = kept_probs / kept_probs.sum(dim=-1, keepdim=True)

    # Sample from the filtered distribution and map back to vocabulary ids
    sampled_sorted_indices = torch.multinomial(kept_probs, 1)
    next_token = sorted_indices.gather(-1, sampled_sorted_indices)
    return next_token.squeeze(-1)


@torch.no_grad()
def generate_word_from_eeg(
    eeg_vec: torch.Tensor,
    max_tokens_per_word: int = 20,
    temperature: float = 1.0,
    top_p: float = 0.95,
    context_ids: list = None,
):
    """
    Generate one word conditioned on an EEG latent vector.

    Generation starts from the last id in *context_ids* together with the EEG
    prefix and a fresh key/value cache, then continues token by token until a
    space or EOS token is sampled or *max_tokens_per_word* is reached.

    NOTE(review): only the last context token is fed to the model; earlier
    context tokens do not influence the generation - confirm this is intended.

    Args:
        eeg_vec: EEG embedding tensor of shape (dim,) or (1, dim)
        max_tokens_per_word: Maximum tokens to generate before stopping
        temperature: Sampling temperature
        top_p: Nucleus sampling threshold
        context_ids: Previous token IDs to continue from (for multi-word
            generation). ``None`` or an empty list starts from the EOS token.
            A non-empty list is extended in place.

    Returns:
        list: *context_ids* followed by the newly generated token ids
    """
    # ensure proper tensor shape and device
    if eeg_vec.ndim == 1:
        eeg_vec = eeg_vec.unsqueeze(0)  # (1, dim)
    eeg_vec = eeg_vec.to(DEVICE)

    # initialize context; an empty context has no last token to start from
    if not context_ids:
        context_ids = [eos_id]
    past_key_values = None

    for _ in range(max_tokens_per_word):
        last_token = torch.tensor([[context_ids[-1]]], dtype=torch.long, device=DEVICE)

        # forward pass
        with autocast(device_type=DEVICE):
            outputs = eeg_adapter(
                input_ids=last_token,
                eeg_embed=(
                    eeg_vec if past_key_values is None else None
                ),  # TODO: verify whether we only want to pass the eeg_vec on the first step
                use_cache=True,
                past_key_values=past_key_values,  # TODO: potentially limit the context size
            )
        past_key_values = outputs.past_key_values

        # sample next token
        logits = outputs.logits[:, -1, :]  # (batch_size, vocab_size)
        next_token_id = nucleus_sampling(logits, top_p=top_p, temperature=temperature).item()
        context_ids.append(next_token_id)

        # stop if we hit a space (end of word) or EOS
        if next_token_id == space_id or next_token_id == eos_id:
            break
    return context_ids


def generate_text_from_eeg_sequence(embeddings, preprompt=None, max_words=None, **generation_kwargs):
    """
    Generate text from a sequence of EEG embeddings, one word per embedding.

    The text generated so far is re-rendered as Markdown after every word.

    Args:
        embeddings: List or array of EEG embedding vectors (tensors, arrays or
            plain sequences of numbers)
        preprompt: Optional text whose tokens start the context. ``None`` or
            an empty string starts without text context.
        max_words: Maximum number of words to generate (None for all embeddings)
        **generation_kwargs: Arguments passed to generate_word_from_eeg

    Returns:
        str: Generated text (including the preprompt); empty if there was
        neither a preprompt nor any embedding
    """
    if max_words is not None:
        embeddings = embeddings[:max_words]

    context_ids = None
    if preprompt is not None:
        # an empty token list is treated like "no preprompt"
        context_ids = tokenizer(preprompt, add_special_tokens=False)["input_ids"] or None

    for z in embeddings:
        if not isinstance(z, torch.Tensor):
            z = torch.as_tensor(z)
        context_ids = generate_word_from_eeg(z, context_ids=context_ids, **generation_kwargs)

        # update output stream
        clear_output(wait=True)
        display(Markdown(tokenizer.decode(context_ids, skip_special_tokens=True)))

    if context_ids is None:
        # nothing to decode: no preprompt and no embedding was processed
        return ""
    return tokenizer.decode(context_ids, skip_special_tokens=True)


# Generate one word per embedding for the first 10 embeddings, without a
# preprompt; the text is displayed live while it is generated.
generated_text = generate_text_from_eeg_sequence(
    df["embedding"].tolist()[:10],
    preprompt=None,
)