In [4]:
%load_ext autoreload
%autoreload 2
import pandas as pd
import torch
from torch.utils.data import Subset, ConcatDataset
from pathlib import Path
import numpy as np
from lib import *
from data_utils import combine_fixed_length, decollate_tensor
from ctcdecode import CTCBeamDecoder
from read_eeg import EEGDataset
from data_utils import TextTransform, TextTransformOrig
import tqdm
import jiwer
import gc
torch.cuda.empty_cache()
gc.collect()
device = "cuda" if torch.cuda.is_available() else "cpu"
base_dir = Path("/ocean/projects/cis240129p/shared/data/eeg_alice")
subjects_used = ["S04"]  # exclude 'S05' - less channels # , "S13", "S19"
transform = TextTransform()
# ds = BrennanDataset(
#     root_dir=base_dir,
#     phoneme_dir=base_dir / "phonemes",
#     idx="S04",
#     text_transform=transform,
#     phoneme_dict_path=base_dir / "phoneme_dict.txt",
#     debug=False,
# )

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
subjects = [
    "S01",
    "S03",
    "S04",
    # "S05", missing one channel
    "S08",
    "S11",
    "S12",
    "S13",
    "S16",
    "S17",
    "S18",
    "S19",
    "S22",
    "S26",
    "S36",
    "S37",
    # "S38", missing one channel
    "S40",
    "S41",
    "S42",
    "S44",
    "S48",
]
# trainset, devset, testset = EEGDataset.from_subjects(
#     subjects=subjects,
#     # generated_subjects=generated_subjects,
#     base_dir=base_dir,
#     train_ratio=0.8,
#     dev_ratio=0.1,
#     test_ratio=0.1,
# )

In [5]:
from collections import defaultdict

all_data = []  # (subject_id, idx, sample)
sentence_to_indices = defaultdict(list)  # sentence -> list of indices in all_data
text_transform = TextTransformOrig()
preload_dataset = {}

for subject_id, subject in enumerate(subjects):

    dataset = BrennanDataset(
        text_transform=text_transform,
        root_dir=base_dir,
        phoneme_dir=base_dir / "phonemes",
        idx=subject,
        phoneme_dict_path=base_dir / "phoneme_dict.txt",
    )
    preload_dataset[subject_id] = dataset
    for idx, sample in enumerate(dataset):
        sentence = sample["label"]
        all_data.append((subject_id, idx, sample))
        sentence_to_indices[sentence].append(len(all_data) - 1)

Extracting parameters from /ocean/projects/cis240129p/shared/data/eeg_alice/S01.vhdr...
Setting channel info structure...
Reading 0 ... 366524  =      0.000 ...   733.048 secs...
Extracting parameters from /ocean/projects/cis240129p/shared/data/eeg_alice/S03.vhdr...
Setting channel info structure...
Reading 0 ... 367299  =      0.000 ...   734.598 secs...
Extracting parameters from /ocean/projects/cis240129p/shared/data/eeg_alice/S04.vhdr...
Setting channel info structure...
Reading 0 ... 368449  =      0.000 ...   736.898 secs...
Extracting parameters from /ocean/projects/cis240129p/shared/data/eeg_alice/S08.vhdr...
Setting channel info structure...
Reading 0 ... 369574  =      0.000 ...   739.148 secs...
Extracting parameters from /ocean/projects/cis240129p/shared/data/eeg_alice/S11.vhdr...
Setting channel info structure...
Reading 0 ... 369574  =      0.000 ...   739.148 secs...
Extracting parameters from /ocean/projects/cis240129p/shared/data/eeg_alice/S12.vhdr...
Setting channel i

In [30]:
len(train_indices[1])

84

In [28]:
from sklearn.model_selection import train_test_split

train_indices, dev_indices, test_indices = (
    defaultdict(list),
    defaultdict(list),
    defaultdict(list),
)
for sentence, indices in sentence_to_indices.items():
    # stratify by sentences
    if len(indices) < 10:
        continue
    train_idx, test_idx = train_test_split(indices, test_size=0.2, random_state=1)
    dev_idx, test_idx = train_test_split(test_idx, test_size=0.5, random_state=1)
    for idx in train_idx:
        subject_id, _, _ = all_data[idx]
        train_indices[subject_id].append(idx)
    for idx in dev_idx:
        subject_id, _, _ = all_data[idx]
        dev_indices[subject_id].append(idx)
    for idx in test_idx:
        subject_id, _, _ = all_data[idx]
        test_indices[subject_id].append(idx)
trainsets, devsets, testsets = [], [], []
for subject_id, subject in enumerate(subjects):
    trainsets.append(Subset(preload_dataset[subject_id], train_indices[subject_id]))
    devsets.append(Subset(preload_dataset[subject_id], dev_indices[subject_id]))
    testsets.append(Subset(preload_dataset[subject_id], test_indices[subject_id]))

In [None]:
sentence_freq = {}
for i in range(len(trainset)):
    sentence = trainset[i]["label"]
    if sentence not in sentence_freq:
        sentence_freq[sentence] = {"train": 0, "dev": 0}
    sentence_freq[sentence]["train"] += 1

for i in range(len(devset)):
    sentence = devset[i]["label"]
    if sentence not in sentence_freq:
        sentence_freq[sentence] = {"train": 0, "dev": 0}
    sentence_freq[sentence]["dev"] += 1

for i in range(len(testset)):
    sentence = testset[i]["label"]
    if sentence not in sentence_freq:
        sentence_freq[sentence] = {"train": 0, "dev": 0}
    sentence_freq[sentence]["test"] = 1

In [9]:
# convert to dataframe
df = pd.DataFrame(sentence_freq).T
df.head()

Unnamed: 0,train,dev
THE POOR LITTLE THING SAT DOWN AND CRIED COME THERE S NO USE CRYING LIKE THAT SAID ALICE TO HERSELF RATHER SHARPLY I ADVISE YOU TO LEAVE OFF THIS MINUTE SHE GENERALLY GAVE HERSELF VERY GOOD ADVICE THOUGH SHE VERY SELDOM FOLLOWED IT AND SOMETIMES SHE SCOLDED HERSELF SO SEVERELY AS TO BRING TEARS INTO HER EYES,20,0
AND THE WHITE RABBIT WAS STILL IN SIGHT HURRYING DOWN,20,0
BUT ALAS EITHER THE LOCKS WERE TOO LARGE OR THE KEY WAS TOO SMALL BUT AT ANY RATE IT WOULD NOT OPEN ANY OF THEM HOWEVER ON THE SECOND TIME ROUND SHE CAME UPON A LOW CURTAIN SHE HAD NOT NOTICED BEFORE,18,2
THE RABBIT HOLE WENT STRAIGHT ON LIKE A TUNNEL FOR SOME WAY AND THEN DIPPED SUDDENLY DOWN SO SUDDENLY THAT ALICE HAD NOT A MOMENT TO THINK ABOUT STOPPING HERSELF BEFORE SHE FOUND HERSELF FALLING DOWN A VERY DEEP WELL,20,0
PLEASE MAAM IS THIS NEW ZEALAND OR AUSTRALIA AND SHE TRIED TO CURTSEY AS SHE SPOKE FANCY CURTSEYING AS YOU RE FALLING THROUGH THE AIR DO YOU THINK YOU COULD MANAGE IT AND WHAT AN IGNORANT LITTLE GIRL SHE LL THINK ME FOR ASKING NO,18,0


In [14]:
df.shape

(83, 2)

In [11]:
print(f"Total number of unique sentences: {len(df)}")
print(f"# only in trainset: {len(df[df['dev'] == 0])}")
print(f"# only in devset: {len(df[df['train'] == 0])}")
print(f"# in both train and devset: {len(df[(df['train'] > 0) & (df['dev'] > 0)])}")

Total number of unique sentences: 83
# only in trainset: 70
# only in devset: 7
# in both train and devset: 6


In [None]:
# Count how many times sentences only appear in trainset times the frequency
# of the sentence in the trainset

df[df["test"] == 0]["train"].sum()

KeyError: 'test'

In [24]:
df[df["train"] == 0]["dev"].sum(), df["dev"].sum()

(116, 160)

In [None]:
len(trainset), len(devset), len(testset)


KeyboardInterrupt



In [None]:
df[df["train"] == 0]

Unnamed: 0,train,dev
SO ALICE SOON BEGAN TALKING AGAIN DINAH LL MISS ME VERY MUCH TONIGHT I SHOULD THINK DINAH WAS THE CAT,0,18
EITHER THE WELL WAS VERY DEEP OR SHE FELL VERY SLOWLY FOR SHE HAD PLENTY OF TIME AS SHE WENT DOWN TO LOOK ABOUT HER AND TO WONDER WHAT WAS GOING TO HAPPEN NEXT,0,18
FOR FEAR OF KILLING SOMEBODY SO SHE MANAGED TO PUT IT INTO ONE OF THE CUPBOARDS AS SHE FELL PAST IT,0,20
DINAH MY DEAR I WISH YOU WERE DOWN HERE WITH ME THERE ARE NO MICE IN THE AIR I M AFRAID BUT YOU MIGHT CATCH A BAT AND THAT S VERY LIKE A MOUSE,0,18
MUST BE GETTING SOMEWHERE NEAR THE CENTER OF THE EARTH,0,20
IN ANOTHER MOMENT DOWN WENT ALICE AFTER IT NEVER ONCE CONSIDERING HOW IN THE WORLD SHE WAS TO GET OUT AGAIN,0,20
AND SHE TRIED TO FANCY WHAT THE FLAME OF A CANDLE IS LIKE AFTER THE CANDLE IS BLOWN OUT,0,2


In [None]:
df[(df["train"] > 0) & (df["dev"] > 0)]

Unnamed: 0,train,dev
BUT ALAS EITHER THE LOCKS WERE TOO LARGE OR THE KEY WAS TOO SMALL BUT AT ANY RATE IT WOULD NOT OPEN ANY OF THEM HOWEVER ON THE SECOND TIME ROUND SHE CAME UPON A LOW CURTAIN SHE HAD NOT NOTICED BEFORE,18,2
SHE FOUND HERSELF IN A LONG LOW HALL WHICH WAS LIT UP BY A ROW OF LAMPS HANGING FROM THE ROOF,18,2
IT LL NEVER DO TO ASK PERHAPS I SHALL SEE IT WRITTEN UP SOMEWHERE,18,2
FIRST HOWEVER SHE WAITED FOR A FEW MINUTES TO SEE IF SHE WAS GOING TO SHRINK ANY FURTHER,18,2
AND SAYING TO HER VERY EARNESTLY NOW DINAH TELL ME THE TRUTH DID YOU EVER EAT A BAT WHEN SUDDENLY THUMP THUMP DOWN SHE CAME UPON A HEAP OF STICKS AND DRY LEAVES,2,18
HOWEVER THIS BOTTLE WAS NOT MARKED POISON SO ALICE VENTURED TO TASTE IT AND FINDING IT VERY NICE IT HAD IN FACT A SORT OF MIXED FLAVOR OF CHERRY TART CUSTARD PINEAPPLE ROAST TURKEY TOFFEE AND HOT BUTTERED TOAST,2,18


In [7]:
trainset, devset, testset = EEGDataset.from_subjects(
    subjects=["S04"],
    # generated_subjects=generated_subjects,
    base_dir=base_dir,
)
train_max_seq_len = trainset.verify_dataset()
dev_max_seq_len = devset.verify_dataset()
test_max_seq_len = testset.verify_dataset()

max_seq_len = max(train_max_seq_len, dev_max_seq_len, test_max_seq_len)

Extracting parameters from /ocean/projects/cis240129p/shared/data/eeg_alice/S04.vhdr...
Setting channel info structure...
Reading 0 ... 368449  =      0.000 ...   736.898 secs...
Subject S04 splits:
  Train: 58 (69.0%)
  Val: 13 (15.5%)
  Test: 13 (15.5%)
Verifying dataset...
Dataset verification complete. 58 samples checked.
EEG feature dimensions: 60
Longest sequence length: 9539
Verifying dataset...
Dataset verification complete. 13 samples checked.
EEG feature dimensions: 60
Longest sequence length: 9472
Verifying dataset...
Dataset verification complete. 13 samples checked.
EEG feature dimensions: 60
Longest sequence length: 6668


In [8]:
from torch import nn
from eeg_architecture import ResBlock
import torch.nn.functional as F

from transformer import TransformerEncoderLayer


class EEGModel(nn.Module):
    def __init__(self, num_features, num_outs):
        super().__init__()

        self.conv_blocks = nn.Sequential(
            ResBlock(num_features, 768, 2),
            ResBlock(768, 768, 2),
            # ResBlock(768, 768, 2),
        )
        self.w_raw_in = nn.Linear(768, 768)

        encoder_layer = TransformerEncoderLayer(
            d_model=768,
            nhead=8,
            relative_positional=True,
            relative_positional_distance=100,
            dim_feedforward=3072,
            dropout=0.2,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, 6)
        self.w_out = nn.Linear(768, num_outs)

    def forward(self, x_raw):
        # x shape is (batch, time, electrode)
        x_raw = x_raw.transpose(1, 2)  # put channel before time for conv
        x_raw = self.conv_blocks(x_raw)
        x_raw = x_raw.transpose(1, 2)  # transpose back
        x_raw = self.w_raw_in(x_raw)

        x = x_raw

        # put time first because transformers expect input int the shape (sequence length, batch size, feature dim)
        x = x.transpose(0, 1)
        x = self.transformer(x)
        x = x.transpose(0, 1)

        return self.w_out(x)

In [9]:
n_chars = len(devset.text_transform.chars)
model = EEGModel(devset.num_features, n_chars + 1)

In [10]:
state_dict = torch.load("models/recognition_model/model.pt")
model.load_state_dict(state_dict)
model = model.to(device)

RuntimeError: Error(s) in loading state_dict for EEGModel:
	size mismatch for w_out.weight: copying a param with shape torch.Size([39, 768]) from checkpoint, the shape in current model is torch.Size([40, 768]).
	size mismatch for w_out.bias: copying a param with shape torch.Size([39]) from checkpoint, the shape in current model is torch.Size([40]).

## Testing

In [None]:
dataloader = torch.utils.data.DataLoader(
    dataset=trainset,
    pin_memory=(device == "cuda"),
    num_workers=0,
    collate_fn=EEGDataset.collate_raw,
    batch_size=1,
)
n_chars = len(devset.text_transform.chars)
blank_id = len(testset.text_transform.chars)
chars = "".join(testset.text_transform.chars) + "_"
decoder = CTCBeamDecoder(
    chars,
    blank_id=blank_id,
    log_probs_input=True,
    model_path="lm.binary",
    alpha=1.5,
    beta=1.85,
    beam_width=20,
)
optim = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)
# consine annealing scheduler
lr_sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=100)

# eval mode
model.eval()
torch.cuda.empty_cache()
gc.collect()
# loop through epochs here
with torch.no_grad():
    for e in range(10):
        losses = []
        references = []
        predictions = []
        wers = []
        for batch_i, example in tqdm.tqdm(
            enumerate(dataloader), "Train step", disable=None
        ):
            X = combine_fixed_length(example["eeg_raw"], 1000).float().to(device)
            pred = model(X)
            pred = F.log_softmax(pred, 2)
            pred_lengths = [l // 4 for l in example["lengths"]]
            pred = nn.utils.rnn.pad_sequence(
                decollate_tensor(pred, pred_lengths),
                batch_first=False,
                padding_value=trainset.text_transform.pad_token_id,
            )
            y = nn.utils.rnn.pad_sequence(
                example["text_int"],
                batch_first=True,
                padding_value=trainset.text_transform.pad_token_id,
            ).to(device)
            loss = F.ctc_loss(
                pred, y, pred_lengths, example["text_int_lengths"], blank=blank_id
            )
            losses.append(loss.item())
            # loss.backward()
            pred = pred.permute(1, 0, 2)
            beam_results, beam_scores, timesteps, out_lens = decoder.decode(
                pred  # TODO: , seq_lens=example["text_int_lengths"]
            )
            for i in range(len(y)):
                target_text = trainset.text_transform.int_to_text(y[i].cpu().numpy())
                # target_text = target_text.strip()
                target_text = target_text.replace(trainset.text_transform.pad_token, "")
                references.append(target_text)
                if i < len(beam_results):
                    pred_int = beam_results[i, 0, : out_lens[i, 0]].tolist()
                    try:
                        pred_text = trainset.text_transform.int_to_text(pred_int)
                        # pred_text = pred_text.strip()
                        pred_text = pred_text.replace(
                            trainset.text_transform.pad_token, ""
                        )
                    except:
                        print(f"!!!ERROR!!! batch idx: {batch_i}, i: {i}")
                        break
                    predictions.append(pred_text)
            torch.cuda.empty_cache()
        train_loss = np.mean(losses)
        train_wer = jiwer.wer(references, predictions)
        print(f"Epoch {e} train loss: {train_loss}, train wer: {train_wer}")

## Training loop

In [None]:
dataloader = torch.utils.data.DataLoader(
    dataset=trainset,
    pin_memory=(device == "cuda"),
    num_workers=0,
    collate_fn=EEGDataset.collate_raw,
    batch_size=1,
)
n_chars = len(devset.text_transform.chars)
blank_id = len(testset.text_transform.chars)
chars = "".join(testset.text_transform.chars) + "_"
decoder = CTCBeamDecoder(
    chars,
    blank_id=blank_id,
    log_probs_input=True,
    model_path="lm.binary",
    alpha=1.5,
    beta=1.85,
    beam_width=20,
)
optim = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)
# consine annealing scheduler
lr_sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=100)

# eval mode
# model.eval()
torch.cuda.empty_cache()
gc.collect()
# loop through epochs here
# with torch.no_grad():
model.train()
for e in range(10):
    losses = []
    references = []
    predictions = []
    wers = []
    for batch_i, example in tqdm.tqdm(
        enumerate(dataloader), "Train step", disable=None
    ):
        X = combine_fixed_length(example["eeg_raw"], 1000).float().to(device)
        pred = model(X)
        pred = F.log_softmax(pred, 2)
        pred_lengths = [l // 4 for l in example["lengths"]]
        pred = nn.utils.rnn.pad_sequence(
            decollate_tensor(pred, pred_lengths),
            batch_first=False,
            # padding_value = n_chars-1,
            padding_value=trainset.text_transform.pad_token_id,
        )
        y = nn.utils.rnn.pad_sequence(
            example["text_int"],
            batch_first=True,
            # padding_value = n_chars-1,
            padding_value=trainset.text_transform.pad_token_id,
        ).to(device)
        loss = F.ctc_loss(
            pred, y, pred_lengths, example["text_int_lengths"], blank=blank_id
        )
        losses.append(loss.item())
        loss.backward()
        pred = pred.permute(1, 0, 2)
        beam_results, beam_scores, timesteps, out_lens = decoder.decode(
            pred  # TODO: , seq_lens=example["text_int_lengths"]
        )
        for i in range(len(y)):
            target_text = trainset.text_transform.int_to_text(y[i].cpu().numpy())
            # target_text = target_text.strip()
            target_text = target_text.replace(trainset.text_transform.pad_token, "")
            references.append(target_text)
            if i < len(beam_results):
                pred_int = beam_results[i, 0, : out_lens[i, 0]].tolist()
                try:
                    pred_text = trainset.text_transform.int_to_text(pred_int)
                    # pred_text = pred_text.strip()
                    pred_text = pred_text.replace(trainset.text_transform.pad_token, "")
                except:
                    print(f"!!!ERROR!!! batch idx: {batch_i}, i: {i}")
                    break
                predictions.append(pred_text)
        if (batch_i + 1) % 2 == 0:
            optim.step()
            optim.zero_grad()
        torch.cuda.empty_cache()
    train_loss = np.mean(losses)
    train_wer = jiwer.wer(references, predictions)
    print(f"Epoch {e} train loss: {train_loss}, train wer: {train_wer}")

: 

: 

: 

In [23]:
predictions

['her there was nothing so very remarkable in that nor did alice think it so very much out of the way to hear the rabbit say to itself oh dear oh dear i shall be late when she thought it over afterwards it occurred to her',
 'she found herself in a long low hall which was lit up by a row of lamps hanging from the roof nabakatokia bagabornabou cacahuamilpa abecedarian cabalistically cadaverousness lafayette chaboisseau bababalouk sabachthani abacadabra eatanswill babalatchi bababalouk hadadrimmon',
 'it ll never do to ask perhaps i shall see it written up somewhere ekateringofsky abdalmalek dabulamanzi babalatchi alcacarquivir arbalestriers balachulish bababalouk babebibobubybaeboe babebibobubybaeboe abandonment habareskul academicianship babalatchi adachigahara academicianship jagadananda cabalistically',
 'i hope they ll remember her saucer of milk at tea time babebibobubybaeboe babebibobubybaeboe sbarovitch babalatchi academicianship anachronistically gabardines babalatchi catachrest

In [24]:
references

['her there was nothing so very remarkable in that nor did alice think it so very much out of the way to hear the rabbit say to itself oh dear oh dear i shall be late when she thought it over afterwards it occurred to her',
 'she found herself in a long low hall which was lit up by a row of lamps hanging from the roofaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa',
 'it ll never do to ask perhaps i shall see it written up somewhereaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa',
 'i hope they ll remember her saucer of milk at tea timeaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa',
 'oh my ears and whiskers how late it s getting she was close behind it when she turned the corner but t

In [25]:
dataloader = torch.utils.data.DataLoader(
    dataset=devset,
    pin_memory=(device == "cuda"),
    num_workers=0,
    collate_fn=EEGDataset.collate_raw,
    batch_size=8,
)
n_chars = len(devset.text_transform.chars)
blank_id = len(testset.text_transform.chars)
decoder = CTCBeamDecoder(
    devset.text_transform.chars + "_",
    blank_id=blank_id,
    log_probs_input=True,
    model_path="lm.binary",
    alpha=1.5,
    beta=1.85,
)
optim = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)
lr_sched = torch.optim.lr_scheduler.MultiStepLR(
    optim, milestones=[125, 150, 175], gamma=0.5
)
# eval mode
model.eval()
# loop through epochs here
losses = []
references = []
predictions = []
with torch.no_grad():
    for batch_i, example in tqdm.tqdm(
        enumerate(dataloader), "Train step", disable=None
    ):
        X = combine_fixed_length(example["eeg_raw"], 5000).float().to(device)
        pred = model(X)
        pred = F.log_softmax(pred, 2)

        pred_lengths = [l // 4 for l in example["lengths"]]
        pred_pad = nn.utils.rnn.pad_sequence(
            decollate_tensor(pred, pred_lengths), batch_first=False
        )

        y = nn.utils.rnn.pad_sequence(example["text_int"], batch_first=True).to(
            device
        )  # TODO: padding_value
        loss = F.ctc_loss(
            pred_pad, y, pred_lengths, example["text_int_lengths"], blank=n_chars
        )
        losses.append(loss.item())
        pred_pad = pred_pad.permute(1, 0, 2)
        beam_results, beam_scores, timesteps, out_lens = decoder.decode(pred_pad)
        for i in range(len(y)):
            target_text = trainset.text_transform.int_to_text(y[i].cpu().numpy())
            references.append(target_text)
            if i < len(beam_results):
                pred_int = beam_results[i, 0, : out_lens[i, 0]].tolist()
                try:
                    pred_text = trainset.text_transform.int_to_text(pred_int)
                except:
                    print(f"batch idx: {batch_i}, i: {i}")
                    break
                predictions.append(pred_text)
        break

In [29]:
pred.shape, y.shape, pred_pad.shape

(torch.Size([6, 1250, 38]), torch.Size([8, 309]), torch.Size([8, 2368, 38]))

In [26]:
predictions

['esemplastic babebibobubybaeboe babebibobubybaeboe icaromenippus balancing garamapingwe academicianship cabalistically achaemenidae babalatchi academicians macadamization cacaracamouchen haberdashery hadadrimmon cabalistically araucanians achaemenidae cacaracamouchen academicians achaemenidae arabesques babalatchi cacahuamilpa cabalistically academicians abhandlungen mablethorpe cabalistically babalatchi bababalouk sagaciously abhandlungen',
 'deinde astonished eyatonkawee capabilities babalatchi bababalouk gabardines cabalistically chakamankabudibaba cadaverousness babalatchi sagaciously bagabornabou alcacarquivir tablecloths eachdaireachd gcalekaland babebibobubybaeboe cabalistically dabulamanzi bagabornabou academicianship',
 'ebenezer gablehurst falcinellus cabalistically anabaptists jadakweniyosaon capabilities damanarkist cacahuamilpa babebibobubybaeboe lagadigadeou jadakweniyosaon bibativeness babalatchi anabaptists afanassievna eachdaireachd nayakoghstonde adachigahara adachig

In [27]:
references

['must be getting somewhere near the center of the earthaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa',
 'either the well was very deep or she fell very slowly for she had plenty of time as she went down to look about her and to wonder what was going to happen nextaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa',
 'please maam is this new zealand or australia and she tried to curtsey as she spoke fancy curtseying as you re falling through the air do you think you could manage it and what an ignorant little girl she ll think me for asking noaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa',
 'the poor little thing sat down and cried come there s no

In [27]:
text_transform = TextTransform()
text_transform.chars += "_"
text_transform.int_to_text(p1.cpu().numpy())

'heeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeer theeeeereeeeeeeeeeeee was nothing so veeeeeeeeeeeeeeeeeeeeeeeeeery reeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeemarrkableeeeeeeeeeeeeeeeeeeeeeeeeeeee in that nor did aliceeeeeeeeeeeeeeeeeeeeeee think it  so veery much  out of theeeeeeeeeeeeee way  to hear the rab_bit say to itself oh  deeeeeeeeeeeear oh deeeeeeeeeeeeeeeeeeeeeeeeeeeear i sshal_l beeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee       latteeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee wheeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee

In [None]:
beam_results, beam_scores, timesteps, out_lens = decoder.decode(pred)
pred_int = beam_results[0, 0, : out_lens[0, 0]].tolist()
pred_text = testset.text_transform.int_to_text(pred_int)
pred_text

'her there was nothing so very remarkable in that nor did alice think it so very much out of the way to hear the rabbit say to itself oh dear oh dear i shall be late when she thought it ove'

In [36]:
beam_results.shape

torch.Size([5, 100, 1250])

In [29]:
target_txt = testset.text_transform.int_to_text(y[0].cpu().numpy())
jiwer.wer(target_txt, pred_text)

0.13043478260869565

In [None]:
model = model.to(device)
batch_size = 2
dataloader = torch.utils.data.DataLoader(
    dataset=trainset,
    pin_memory=(device == "cuda"),
    num_workers=0,
    collate_fn=EEGDataset.collate_raw,
    batch_size=batch_size,
)
blank_id = len(testset.text_transform.chars)
decoder = CTCBeamDecoder(
    devset.text_transform.chars + "_",
    blank_id=blank_id,
    log_probs_input=True,
    model_path="lm.binary",
    alpha=1.5,
    beta=1.85,
)
model.eval()
references = []
predictions = []
losses = []
with torch.no_grad():
    for example in tqdm.tqdm(dataloader, "Evaluate", disable=None):
        # X = example["eeg_raw"][0].float().to(device)
        X = combine_fixed_length(example["eeg_raw"], 5000).float().to(device)
        pred = F.log_softmax(model(X), -1)
        pred_lengths = [l // 4 for l in example["lengths"]]
        pred_pad = nn.utils.rnn.pad_sequence(
            decollate_tensor(pred, pred_lengths), batch_first=False
        )

        # y = nn.utils.rnn.pad_sequence(example["text_int"], batch_first=True).to(device)
        y = torch.cat(example["text_int"]).to(device)
        loss = F.ctc_loss(
            pred_pad, y, pred_lengths, example["text_int_lengths"], blank=n_chars
        )
        losses.append(loss.item())

        beam_results, beam_scores, timesteps, out_lens = decoder.decode(pred)
        pred_int = beam_results[0, 0, : out_lens[0, 0]].tolist()

        pred_text = testset.text_transform.int_to_text(pred_int)
        target_text = testset.text_transform.clean_text(example["labels"][0])

        references.append(target_text)
        predictions.append(pred_text)
wer = jiwer.wer(references, predictions)

In [116]:
print(f"pred_pad: {pred_pad.shape}")
print(f"pred.shape: {pred.shape}")
print(f"y.shape: {y.shape}")
print(f"pred_lengths: {pred_lengths}")
print(f"n_chars: {n_chars}")
loss = F.ctc_loss(pred_pad, y, pred_lengths, example["text_int_lengths"], blank=n_chars)
print(f"loss: {loss}")

pred_pad: torch.Size([761, 2, 38])
pred.shape: torch.Size([1, 1250, 38])
y.shape: torch.Size([155])
pred_lengths: [248, 761]
n_chars: 37
loss: 2.801051139831543


In [113]:
predictions

['tetootne',
 'tetootne',
 'tetootne',
 'teetota',
 'tetootne',
 'tetootne',
 'tetootne',
 'tetaheite',
 'tetootne',
 'teetotalis',
 'tetootne',
 'tetootne',
 'tetootne',
 'teetota',
 'teetota',
 'tetootne',
 'tetootne',
 'tetootne',
 'tetootne',
 'teetotale',
 'tetootne',
 'tetootne',
 'teetotalis',
 'etteniot',
 'tetootne',
 'teetotale',
 'etonensis',
 'tetootne',
 'tottontai']

In [104]:
pred_exp = torch.exp(pred)  # Convert log_probs back to probabilities
pred_max = pred_exp.argmax(dim=-1)  # Get the most likely token at each time step
print(pred_max.squeeze().tolist())  # Inspect the token indices over time

[37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37,

In [65]:
beam_results, beam_scores, timesteps, out_lens = decoder.decode(pred_pad)
beam_results_o, beam_scores_o, timesteps_o, out_lens_o = decoder.decode(pred)
# pred_int = beam_results[0, 0, : out_lens[0, 0]].tolist()

In [67]:
beam_results.shape, beam_results_o.shape

(torch.Size([761, 100, 1]), torch.Size([1, 100, 1250]))

In [92]:
pred.shape

torch.Size([1, 1250, 38])

In [None]:
pred_int = beam_results_o[0, 0, : out_lens_o[:, 0]].tolist()  # out_lens_o[:, 0]

In [98]:
len(testset.text_transform.int_to_text(pred_int))

10

In [95]:
pred_int

[]

In [62]:
beam_results.shape, len(example["labels"][0])

(torch.Size([761, 100, 1]), 120)

In [None]:
loss = F.ctc_loss(pred_prd, y, pred_lengths, example["text_int_lengths"], blank=n_chars)
loss

tensor(2.8002, device='cuda:0')

In [56]:
for i, data in enumerate(dataloader):
    break

In [57]:
data.keys()

dict_keys(['eeg_raw', 'labels', 'lengths', 'text_int', 'text_int_lengths'])

In [19]:
predictions[0]

'tetootne'

In [108]:
trainset.text_transform.chars

'abcdefghijklmnopqrstuvwxyz0123456789 '

In [88]:
len(testset.text_transform.chars)

37

In [None]:
from eeg_architecture import EEGModel
from torch import nn
import torch.nn.functional as F
import tqdm
from data_utils import combine_fixed_length, decollate_tensor
from IPython.core.debugger import Pdb


device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size = 32
learning_rate = 1e-4
l2 = 1e-5
n_epochs = 2
learning_rate_warmup = 100

dataloader = torch.utils.data.DataLoader(devset, batch_size=1)
for example in tqdm.tqdm(dataloader, "Evaluate", disable=None):
    target = example["label"]
    break

In [16]:
target[0]

'MUST BE GETTING SOMEWHERE NEAR THE CENTER OF THE EARTH'

In [None]:
ds_brennan = BrennanDataset(
    root_dir=base_dir,
    phoneme_dir=base_dir / "phonemes",
    idx="S04",
    phoneme_dict_path=base_dir / "phoneme_dict.txt",
    debug=True,
)

Extracting parameters from /ocean/projects/cis240129p/shared/data/eeg_alice/S04.vhdr...
Setting channel info structure...
Reading 0 ... 368449  =      0.000 ...   736.898 secs...


In [11]:
item2 = ds_brennan[0]
print(item2.keys())
print(item2["audio_feats"].shape, item2["eeg_feats"].shape, len(item2["phonemes"]))

0 Alice
dict_keys(['label', 'audio_feats', 'audio_raw', 'eeg_raw', 'eeg_feats', 'phonemes'])
(104, 128) (159, 300) 104


In [16]:
train_dataset, test_dataset = create_datasets(subjects_used, base_dir)

print(
    f"Train dataset length: {len(train_dataset)}, Test dataset length: {len(test_dataset)}"
)

Extracting parameters from /ocean/projects/cis240129p/shared/data/eeg_alice/S04.vhdr...
Setting channel info structure...
Reading 0 ... 368449  =      0.000 ...   736.898 secs...


Train dataset length: 1703, Test dataset length: 426


In [17]:
# def collate_fn(batch):
#     """
#     A custom collate function that handles different types of data in a batch.
#     It dynamically creates batches by converting arrays or lists to tensors and
#     applies padding to variable-length sequences.
#     """
#     batch_dict = {}
#     for key in batch[0].keys():
#         batch_items = [item[key] for item in batch]
#         if isinstance(batch_items[0], np.ndarray) or isinstance(
#             batch_items[0], torch.Tensor
#         ):
#             if isinstance(batch_items[0], np.ndarray):
#                 batch_items = [torch.tensor(b) for b in batch_items]
#             if len(batch_items[0].shape) > 0:
#                 batch_dict[key] = torch.nn.utils.rnn.pad_sequence(
#                     batch_items, batch_first=True  # pad with zeros
#                 )
#             else:
#                 batch_dict[key] = torch.stack(batch_items)
#         else:
#             batch_dict[key] = batch_items

#     return batch_dict


train_dataloder = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=2,
    num_workers=1,
    shuffle=True,
    collate_fn=collate_fn,
)

test_dataloder = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=2,
    num_workers=1,
    shuffle=False,
    collate_fn=collate_fn,
)

In [18]:
item = train_dataset[0]
for k, v in item.items():
    try:
        print(k, v.shape, type(v))
    except:
        print(k, type(v))

label <class 'str'>
audio_feats (104, 128) <class 'numpy.ndarray'>
audio_raw (16735,) <class 'numpy.ndarray'>
eeg_raw (520, 62) <class 'numpy.ndarray'>
eeg_feats (159, 310) <class 'numpy.ndarray'>
phonemes (104,) <class 'numpy.ndarray'>


In [23]:
item["eeg_raw"].dtype

dtype('float64')

In [19]:
# test dataloader
i = 0
for batch in train_dataloder:
    print(i)
    for k, v in batch.items():
        try:
            print(k, v.shape, type(v))
        except:
            print(k, type(v))
    i += 1
    if i > 4:
        break

0
label <class 'list'>
audio_feats torch.Size([2, 130, 128]) <class 'torch.Tensor'>
audio_raw torch.Size([2, 20800]) <class 'torch.Tensor'>
eeg_raw torch.Size([2, 520, 62]) <class 'torch.Tensor'>
eeg_feats torch.Size([2, 159, 310]) <class 'torch.Tensor'>
phonemes torch.Size([2, 130]) <class 'torch.Tensor'>
1
label <class 'list'>
audio_feats torch.Size([2, 130, 128]) <class 'torch.Tensor'>
audio_raw torch.Size([2, 20800]) <class 'torch.Tensor'>
eeg_raw torch.Size([2, 520, 62]) <class 'torch.Tensor'>
eeg_feats torch.Size([2, 159, 310]) <class 'torch.Tensor'>
phonemes torch.Size([2, 130]) <class 'torch.Tensor'>
2
label <class 'list'>
audio_feats torch.Size([2, 130, 128]) <class 'torch.Tensor'>
audio_raw torch.Size([2, 20800]) <class 'torch.Tensor'>
eeg_raw torch.Size([2, 520, 62]) <class 'torch.Tensor'>
eeg_feats torch.Size([2, 159, 310]) <class 'torch.Tensor'>
phonemes torch.Size([2, 130]) <class 'torch.Tensor'>
3
label <class 'list'>
audio_feats torch.Size([2, 130, 128]) <class 'torch.T