In [1]:
!ls -R /kaggle/input/brain-to-text-25


/kaggle/input/brain-to-text-25:
data_link.txt  t15_copyTask_neuralData	t15_pretrained_rnn_baseline

/kaggle/input/brain-to-text-25/t15_copyTask_neuralData:
hdf5_data_final

/kaggle/input/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final:
t15.2023.08.11	t15.2023.09.29	t15.2023.11.04	t15.2024.02.25	t15.2024.07.19
t15.2023.08.13	t15.2023.10.01	t15.2023.11.17	t15.2024.03.03	t15.2024.07.21
t15.2023.08.18	t15.2023.10.06	t15.2023.11.19	t15.2024.03.08	t15.2024.07.28
t15.2023.08.20	t15.2023.10.08	t15.2023.11.26	t15.2024.03.15	t15.2025.01.10
t15.2023.08.25	t15.2023.10.13	t15.2023.12.03	t15.2024.03.17	t15.2025.01.12
t15.2023.08.27	t15.2023.10.15	t15.2023.12.08	t15.2024.04.25	t15.2025.03.14
t15.2023.09.01	t15.2023.10.20	t15.2023.12.10	t15.2024.04.28	t15.2025.03.16
t15.2023.09.03	t15.2023.10.22	t15.2023.12.17	t15.2024.05.10	t15.2025.03.30
t15.2023.09.24	t15.2023.11.03	t15.2023.12.29	t15.2024.06.14	t15.2025.04.13

/kaggle/input/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final/t15.2023

In [2]:
import h5py

sample = "/kaggle/input/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final/t15.2025.01.12/data_train.hdf5"
with h5py.File(sample, "r") as f:
    first_trial = list(f.keys())[0]
    print("Example trial:", first_trial)
    print("Contents:")
    for key in f[first_trial].keys():
        print("  ", key)


Example trial: trial_0000
Contents:
   input_features
   seq_class_ids
   transcription


In [3]:
train_h5 = "/kaggle/input/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final/t15.2025.01.12/data_train.hdf5"


In [4]:
# Extract all label IDs from train set to infer vocabulary
all_ids = set()

with h5py.File(train_h5, "r") as f:
    for k in list(f.keys())[:200]:   # first 200 trials are enough
        ids = f[f"{k}/seq_class_ids"][:]
        all_ids.update(ids.tolist())

print("Unique token IDs:", sorted(all_ids))
print("Count:", len(all_ids))


Unique token IDs: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40]
Count: 41


In [5]:
import h5py, numpy as np, torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import math, random
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", DEVICE)


Device: cuda


In [6]:
# ASCII printable characters + mandatory blank at index 0
chars = ['<blank>'] + [chr(i) for i in range(32,127)]
stoi = {c:i for i,c in enumerate(chars)}
itos = {i:c for c,i in stoi.items()}
VOCAB_SIZE = len(stoi)

def ascii_to_tokens(arr):
    # Converts raw ascii ints → model token ids
    out=[]
    for x in arr:
        if x == 0: 
            continue
        c = chr(int(x))
        if c in stoi:
            out.append(stoi[c])
    return torch.tensor(out, dtype=torch.long)


In [7]:
class Brain2TextDataset(Dataset):
    def __init__(self, h5_path, max_trials=None):
        self.f = h5py.File(h5_path, "r")
        keys = list(self.f.keys())

        # Keep only trials that contain *at least one* label source
        self.valid_keys = []
        for k in keys:
            grp = self.f[k]
            if "transcription" in grp or "seq_class_ids" in grp:
                self.valid_keys.append(k)

        if max_trials is not None:
            self.valid_keys = self.valid_keys[:max_trials]

    def __len__(self):
        return len(self.valid_keys)

    def __getitem__(self, idx):
        k = self.valid_keys[idx]
        grp = self.f[k]

        # Neural data
        x = grp["input_features"][()]  # (T, 512)

        # Labels: use transcription first, else fallback
        if "transcription" in grp:
            y_raw = grp["transcription"][()]  # ASCII ints
            y = ascii_to_tokens(y_raw)

        else:
            # seq_class_ids are already integers (but may include 0 padding)
            y_raw = grp["seq_class_ids"][()]
            y = torch.tensor([t for t in y_raw if t != 0], dtype=torch.long)

        return torch.tensor(x, dtype=torch.float32), y


In [8]:
def ctc_collate(batch):
    xs, ys = zip(*batch)  # lists of tensors
    x_lens = torch.tensor([len(x) for x in xs], dtype=torch.long)
    y_lens = torch.tensor([len(y) for y in ys], dtype=torch.long)

    X = nn.utils.rnn.pad_sequence(xs, batch_first=True)
    Y = torch.cat(ys)

    return X, Y, x_lens, y_lens


In [9]:
train_h5 = "/kaggle/input/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final/t15.2025.01.12/data_train.hdf5"
val_h5   = "/kaggle/input/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final/t15.2025.01.12/data_val.hdf5"
test_h5  = "/kaggle/input/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final/t15.2025.01.12/data_test.hdf5"

train_ds = Brain2TextDataset(train_h5, max_trials=800)   # adjust subset size
val_ds   = Brain2TextDataset(val_h5, max_trials=200)
test_ds  = Brain2TextDataset(test_h5, max_trials=200)

train_dl = DataLoader(train_ds, batch_size=8, shuffle=True, collate_fn=ctc_collate)
val_dl   = DataLoader(val_ds, batch_size=8, shuffle=False, collate_fn=ctc_collate)
test_dl  = DataLoader(test_ds, batch_size=8, shuffle=False, collate_fn=ctc_collate)


In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ResidualGRU(nn.Module):
    def __init__(self, feat_dim, hidden=256, layers=4, dropout=0.15):
        super().__init__()
        self.layers = nn.ModuleList()
        self.projections = nn.ModuleList()

        for _ in range(layers):
            gru = nn.GRU(
                input_size=feat_dim,
                hidden_size=hidden,
                num_layers=1,
                batch_first=True,
                bidirectional=True
            )
            self.layers.append(gru)

            # project from (hidden*2) → feat_dim so residual matches
            self.projections.append(nn.Linear(hidden * 2, feat_dim))

        self.dropout = dropout

        self.conv = nn.Conv1d(feat_dim, feat_dim, kernel_size=5, padding=2)


    def forward(self, x, xl):
        for gru, proj in zip(self.layers, self.projections):
            packed = nn.utils.rnn.pack_padded_sequence(x, xl.cpu(), batch_first=True, enforce_sorted=False)
            out, _ = gru(packed)
            out, _ = nn.utils.rnn.pad_packed_sequence(out, batch_first=True)
            out = proj(out)
            x = x + F.dropout(out, p=self.dropout, training=self.training)
        return x, xl


class BrainRNN(nn.Module):
    def __init__(self, feat_dim, vocab):
        super().__init__()
        self.rnn = ResidualGRU(feat_dim, hidden=256, layers=4, dropout=0.1)
        self.fc = nn.Linear(feat_dim, vocab)

    def forward(self, x, xl):
        out, xl = self.rnn(x, xl)
        logits = self.fc(out)
        logp = F.log_softmax(logits, dim=-1)
        return logp, xl


In [11]:
sample_X, _, _, _ = next(iter(train_dl))
FEAT_DIM = sample_X.shape[-1]
print("Feature dimension:", FEAT_DIM)

model = BrainRNN(FEAT_DIM, VOCAB_SIZE).to(DEVICE)
optimizer = torch.optim.AdamW(model.parameters(), lr=4e-4, weight_decay=1e-3)
ctc_loss = nn.CTCLoss(blank=0, zero_infinity=True)


Feature dimension: 512


In [12]:
def train_step(dataloader):
    model.train()
    total_loss = 0

    for X, Y, x_lens, y_lens in dataloader:
        X = X.to(DEVICE)
        Y = Y.to(DEVICE)
        x_lens = x_lens.to(DEVICE)
        y_lens = y_lens.to(DEVICE)

        optimizer.zero_grad()

        # Forward pass
        logp, out_lens = model(X, x_lens)

        # CTC expects (T, B, C)
        logp = logp.transpose(0, 1)

        loss = ctc_loss(logp, Y, out_lens, y_lens)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(dataloader)


In [13]:
def greedy_decode(logp, lens):
    # logp: (B, T, vocab)
    pred = logp.argmax(dim=-1)  # (B, T)
    output = []

    for i in range(pred.size(0)):
        seq = pred[i][:lens[i]].tolist()
        # collapse repeats + remove blanks (0)
        cleaned = []
        last = None
        for x in seq:
            if x != last and x != 0:
                cleaned.append(x)
            last = x
        output.append(cleaned)
    return output


def validate(dataloader):
    model.eval()
    total_wer = 0
    count = 0

    with torch.no_grad():
        for X, Y, x_lens, y_lens in dataloader:
            X = X.to(DEVICE)
            Y = Y.to(DEVICE)
            x_lens = x_lens.to(DEVICE)
            y_lens = y_lens.to(DEVICE)

            logp, out_lens = model(X, x_lens)

            preds = greedy_decode(logp.cpu(), out_lens.cpu())

            # split targets by lengths
            idx = 0
            targets = []
            for L in y_lens:
                targets.append(Y[idx:idx+L].tolist())
                idx += L

            # compute a VERY rough WER (edit distance / target length)
            for p, t in zip(preds, targets):
                if len(t) == 0:
                    continue
                dist = edit_distance(p, t)
                total_wer += dist / len(t)
                count += 1

    return total_wer / max(1, count)


In [14]:
def edit_distance(a, b):
    dp = [[0]*(len(b)+1) for _ in range(len(a)+1)]

    for i in range(len(a)+1):
        dp[i][0] = i
    for j in range(len(b)+1):
        dp[0][j] = j

    for i in range(1, len(a)+1):
        for j in range(1, len(b)+1):
            cost = 0 if a[i-1] == b[j-1] else 1
            dp[i][j] = min(
                dp[i-1][j] + 1,
                dp[i][j-1] + 1,
                dp[i-1][j-1] + cost
            )
    return dp[-1][-1]


In [15]:
for epoch in range(1,21):
    tr = train_step(train_dl)
    val = validate(val_dl)
    print(f"Epoch {epoch:02d} | Train {tr:.3f} | Val WER {val:.3f}")


Epoch 01 | Train 36.023 | Val WER 0.813
Epoch 02 | Train 4.606 | Val WER 0.857
Epoch 03 | Train 4.631 | Val WER 0.869
Epoch 04 | Train 3.603 | Val WER 0.872
Epoch 05 | Train 3.479 | Val WER 0.854
Epoch 06 | Train 3.300 | Val WER 0.824
Epoch 07 | Train 3.204 | Val WER 0.886
Epoch 08 | Train 3.067 | Val WER 0.886
Epoch 09 | Train 2.946 | Val WER 0.863
Epoch 10 | Train 2.821 | Val WER 0.858
Epoch 11 | Train 2.746 | Val WER 0.878
Epoch 12 | Train 2.618 | Val WER 0.832
Epoch 13 | Train 2.529 | Val WER 0.822
Epoch 14 | Train 2.394 | Val WER 0.757
Epoch 15 | Train 2.241 | Val WER 0.770
Epoch 16 | Train 2.072 | Val WER 0.759
Epoch 17 | Train 1.859 | Val WER 0.756
Epoch 18 | Train 1.612 | Val WER 0.762
Epoch 19 | Train 1.365 | Val WER 0.772
Epoch 20 | Train 1.118 | Val WER 0.776
