In [4]:
!ls -R /kaggle/input/brain-to-text-25


/kaggle/input/brain-to-text-25:
data_link.txt  t15_copyTask_neuralData	t15_pretrained_rnn_baseline

/kaggle/input/brain-to-text-25/t15_copyTask_neuralData:
hdf5_data_final

/kaggle/input/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final:
t15.2023.08.11	t15.2023.09.29	t15.2023.11.04	t15.2024.02.25	t15.2024.07.19
t15.2023.08.13	t15.2023.10.01	t15.2023.11.17	t15.2024.03.03	t15.2024.07.21
t15.2023.08.18	t15.2023.10.06	t15.2023.11.19	t15.2024.03.08	t15.2024.07.28
t15.2023.08.20	t15.2023.10.08	t15.2023.11.26	t15.2024.03.15	t15.2025.01.10
t15.2023.08.25	t15.2023.10.13	t15.2023.12.03	t15.2024.03.17	t15.2025.01.12
t15.2023.08.27	t15.2023.10.15	t15.2023.12.08	t15.2024.04.25	t15.2025.03.14
t15.2023.09.01	t15.2023.10.20	t15.2023.12.10	t15.2024.04.28	t15.2025.03.16
t15.2023.09.03	t15.2023.10.22	t15.2023.12.17	t15.2024.05.10	t15.2025.03.30
t15.2023.09.24	t15.2023.11.03	t15.2023.12.29	t15.2024.06.14	t15.2025.04.13

/kaggle/input/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final/t15.2023

In [5]:
import h5py

# Peek inside one session file: show the name of its first trial group and
# the datasets stored under that group.
sample = "/kaggle/input/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final/t15.2025.01.12/data_train.hdf5"
with h5py.File(sample, "r") as f:
    trial_names = list(f.keys())
    first_trial = trial_names[0]
    print("Example trial:", first_trial)
    print("Contents:")
    trial_group = f[first_trial]
    for key in trial_group.keys():
        print("  ", key)


Example trial: trial_0000
Contents:
   input_features
   seq_class_ids
   transcription


In [6]:
# Training split inside the t15.2025.01.12 session folder; the label-ID scan
# that follows opens this file.
train_h5 = "/kaggle/input/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final/t15.2025.01.12/data_train.hdf5"


In [7]:
# Gather every distinct value stored in `seq_class_ids` across a sample of
# training trials, to see which integer IDs that label stream uses.
all_ids = set()

with h5py.File(train_h5, "r") as f:
    trial_names = list(f.keys())[:200]   # only the first 200 trials are scanned
    for k in trial_names:
        ids = f[f"{k}/seq_class_ids"][:]
        all_ids |= set(ids.tolist())

unique_ids = sorted(all_ids)
print("Unique token IDs:", unique_ids)
print("Count:", len(all_ids))


Unique token IDs: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40]
Count: 41


In [8]:
import h5py, numpy as np, torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import math, random
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", DEVICE)


Device: cuda


In [9]:
# Character vocabulary: index 0 is reserved for the CTC blank, followed by the
# 95 printable ASCII characters (codes 32..126) in code order.
chars = ['<blank>', *map(chr, range(32, 127))]
stoi = dict(zip(chars, range(len(chars))))   # symbol -> token id
itos = dict(enumerate(chars))                # token id -> symbol
VOCAB_SIZE = len(chars)


def ascii_to_tokens(arr):
    """Map a sequence of ASCII code points to a 1-D LongTensor of token ids.

    Zero entries are skipped, and so is any character that has no entry in
    ``stoi`` (anything outside printable ASCII). The relative order of the
    remaining characters is kept.
    """
    symbols = (chr(int(code)) for code in arr if code != 0)
    token_ids = [stoi[sym] for sym in symbols if sym in stoi]
    return torch.tensor(token_ids, dtype=torch.long)


In [10]:
class Brain2TextDataset(Dataset):
    """Trials of one HDF5 split, served as ``(features, target_ids)`` pairs.

    Only trial groups that contain a ``transcription`` or a ``seq_class_ids``
    dataset are indexed; everything else in the file is ignored.

    The HDF5 file is scanned once with a short-lived handle and is then
    reopened lazily on first item access. The open handle is never pickled,
    so a copy of the dataset sent to a DataLoader worker opens its own handle.
    Call :meth:`close` to release the handle explicitly.

    Args:
        h5_path: path of the HDF5 file to read.
        max_trials: if given, keep only the first ``max_trials`` labelled
            trials (in the file's key order).
    """

    def __init__(self, h5_path, max_trials=None):
        self.h5_path = h5_path
        self._file = None  # opened on demand by the ``f`` property

        # Scan inside a context manager so construction leaves no file open.
        with h5py.File(h5_path, "r") as f:
            self.valid_keys = [
                k for k in f.keys()
                if "transcription" in f[k] or "seq_class_ids" in f[k]
            ]

        if max_trials is not None:
            self.valid_keys = self.valid_keys[:max_trials]

    @property
    def f(self):
        """Read-only ``h5py.File`` handle, opened on first access."""
        if self._file is None:
            self._file = h5py.File(self.h5_path, "r")
        return self._file

    def close(self):
        """Close the HDF5 handle if one is open; safe to call repeatedly."""
        if self._file is not None:
            self._file.close()
            self._file = None

    def __del__(self):
        # Best effort only: during interpreter shutdown h5py may already be
        # partially torn down, and a destructor must not raise.
        try:
            self.close()
        except Exception:
            pass

    def __getstate__(self):
        # Drop the live handle when pickling; the copy reopens the file itself.
        state = self.__dict__.copy()
        state["_file"] = None
        return state

    def __len__(self):
        return len(self.valid_keys)

    def __getitem__(self, idx):
        """Return ``(x, y)``: float32 features and a 1-D LongTensor of target ids."""
        grp = self.f[self.valid_keys[idx]]

        # Neural features; the original notebook annotates this as (T, 512).
        x = grp["input_features"][()]

        if "transcription" in grp:
            # Preferred label source: ASCII codes mapped into the char vocabulary.
            y = ascii_to_tokens(grp["transcription"][()])
        else:
            # Fallback: integer class ids with zero entries removed.
            # NOTE(review): these ids are used as-is and are not mapped through
            # the character vocabulary - confirm that mixing both label
            # sources in one dataset is intended.
            y_raw = grp["seq_class_ids"][()]
            y = torch.tensor([int(t) for t in y_raw if t != 0], dtype=torch.long)

        return torch.tensor(x, dtype=torch.float32), y


In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F


In [12]:
def ctc_collate(batch):
    """Collate ``(features, targets)`` pairs into the tensors a CTC loss needs.

    Returns a 4-tuple:
        * features zero-padded along the first (time) axis, batch first,
        * all target sequences concatenated into one flat tensor,
        * the unpadded length of every feature sequence (LongTensor),
        * the length of every target sequence (LongTensor).
    """
    features, targets = zip(*batch)

    padded_features = nn.utils.rnn.pad_sequence(features, batch_first=True)
    flat_targets = torch.cat(targets)

    feature_lengths = torch.tensor(list(map(len, features)), dtype=torch.long)
    target_lengths = torch.tensor(list(map(len, targets)), dtype=torch.long)

    return padded_features, flat_targets, feature_lengths, target_lengths


In [13]:
# train_h5 = "/kaggle/input/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final/t15.2025.01.12/data_train.hdf5"
# val_h5   = "/kaggle/input/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final/t15.2025.01.12/data_val.hdf5"
# test_h5  = "/kaggle/input/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final/t15.2025.01.12/data_test.hdf5"

# train_ds = Brain2TextDataset(train_h5, max_trials=800)   # adjust subset size
# val_ds   = Brain2TextDataset(val_h5, max_trials=200)
# test_ds  = Brain2TextDataset(test_h5, max_trials=200)

# train_dl = DataLoader(train_ds, batch_size=8, shuffle=True, collate_fn=ctc_collate)
# val_dl   = DataLoader(val_ds, batch_size=8, shuffle=False, collate_fn=ctc_collate)
# test_dl  = DataLoader(test_ds, batch_size=8, shuffle=False, collate_fn=ctc_collate)


In [None]:
# Session folder that holds the three split files used below.
BASE = "/kaggle/input/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final/t15.2025.01.12"

train_h5 = f"{BASE}/data_train.hdf5"
val_h5   = f"{BASE}/data_val.hdf5"
test_h5  = f"{BASE}/data_test.hdf5"

# Echo the resolved paths, one per line.
print(train_h5, val_h5, test_h5, sep="\n")


In [None]:
# Per-split caps on how many labelled trials are loaded.
MAX_TRAIN_TRIALS = 3000
MAX_EVAL_TRIALS = 500
# Same batch size as the commented-out loader setup earlier in the notebook.
BATCH_SIZE = 8

train_ds = Brain2TextDataset(train_h5, max_trials=MAX_TRAIN_TRIALS)
val_ds   = Brain2TextDataset(val_h5,   max_trials=MAX_EVAL_TRIALS)
test_ds  = Brain2TextDataset(test_h5,  max_trials=MAX_EVAL_TRIALS)

# The training loop reads `train_dl` / `val_dl`; they must be created by live
# code so the notebook survives Restart & Run All. Only the training loader
# shuffles, so evaluation order stays fixed between epochs.
train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,  collate_fn=ctc_collate)
val_dl   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False, collate_fn=ctc_collate)
test_dl  = DataLoader(test_ds,  batch_size=BATCH_SIZE, shuffle=False, collate_fn=ctc_collate)


In [14]:
class TransformerBlock(nn.Module):
    """Post-norm Transformer encoder block.

    Self-attention followed by a two-layer GELU MLP; each sub-layer is wrapped
    in a residual connection and followed by LayerNorm.

    Args:
        dim: feature size of the input and output.
        heads: number of attention heads.
        mlp_ratio: hidden width of the MLP as a multiple of ``dim``.
        dropout: dropout probability passed to the attention module.
    """

    def __init__(self, dim, heads=4, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, heads, dropout=dropout, batch_first=True)
        self.ln1 = nn.LayerNorm(dim)

        hidden_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )
        self.ln2 = nn.LayerNorm(dim)

    def forward(self, x, key_padding_mask=None):
        """Apply the block to ``x`` of shape ``(batch, time, dim)``.

        Args:
            x: batch-first input sequence.
            key_padding_mask: optional ``(batch, time)`` bool tensor, ``True``
                at padded positions. Masked positions are excluded as
                attention keys, so padding cannot influence the real frames.
                ``None`` (the default) attends over every position.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # The attention weights are never used, so skip computing them.
        attn_out, _ = self.attn(
            x, x, x,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )
        x = self.ln1(x + attn_out)
        x = self.ln2(x + self.mlp(x))
        return x


In [15]:
class CNNTransformerCTC(nn.Module):
    """Conv1d front-end, a stack of TransformerBlocks and a per-frame classifier.

    The forward pass returns per-frame log-probabilities over ``vocab_size``
    classes together with the sequence lengths it was given.
    """

    def __init__(self, feat_dim=512, model_dim=256, num_layers=4, vocab_size=VOCAB_SIZE):
        super().__init__()

        # Two width-5 convolutions over time; padding=2 keeps the length unchanged.
        conv_stack = [
            nn.Conv1d(feat_dim, model_dim, kernel_size=5, padding=2),
            nn.GELU(),
            nn.Conv1d(model_dim, model_dim, kernel_size=5, padding=2),
            nn.GELU(),
        ]
        self.conv = nn.Sequential(*conv_stack)

        # Transformer encoder stack.
        encoder_blocks = []
        for _ in range(num_layers):
            encoder_blocks.append(
                TransformerBlock(model_dim, heads=4, mlp_ratio=4.0, dropout=0.1)
            )
        self.layers = nn.ModuleList(encoder_blocks)

        # Per-frame projection onto the vocabulary.
        self.fc = nn.Linear(model_dim, vocab_size)

    def forward(self, x, x_lens):
        """Map ``x`` of shape (B, T, feat_dim) to log-probs of shape (B, T, vocab_size).

        ``x_lens`` is handed back untouched: the convolutions preserve the
        time dimension, so the output lengths equal the input lengths.
        """
        # Conv1d wants channels first: (B, T, C) -> (B, C, T) and back again.
        feats = self.conv(x.transpose(1, 2)).transpose(1, 2)

        # NOTE(review): x_lens is not turned into a padding mask, so padded
        # frames take part in self-attention.
        for block in self.layers:
            feats = block(feats)

        return F.log_softmax(self.fc(feats), dim=-1), x_lens


In [16]:
class EarlyStopping:
    """Track a lower-is-better score and signal when it has stopped improving.

    A step counts as an improvement only if the new score beats the best one
    so far by more than ``min_delta``. After ``patience`` consecutive steps
    without improvement ``should_stop`` becomes ``True`` and stays that way.
    """

    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_wer = float("inf")
        self.counter = 0
        self.should_stop = False

    def step(self, current_wer):
        """Record one new score and refresh ``counter`` / ``should_stop``."""
        improved = current_wer < self.best_wer - self.min_delta

        if improved:
            self.best_wer, self.counter = current_wer, 0
        else:
            self.counter += 1

        # Latch the flag: once set it is never cleared.
        self.should_stop = self.should_stop or self.counter >= self.patience


In [17]:
# Seed torch's RNG before any weights are created so that initialisation,
# DataLoader shuffling and dropout repeat from run to run.
SEED = 0
torch.manual_seed(SEED)

LEARNING_RATE = 3e-4
WEIGHT_DECAY = 0.01

model = CNNTransformerCTC().to(DEVICE)
# blank=0 matches the '<blank>' entry at index 0 of the vocabulary;
# zero_infinity=True zeroes infinite losses (and their gradients) instead of
# letting them propagate into the update.
ctc_loss = nn.CTCLoss(blank=0, zero_infinity=True)
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)


In [18]:
def train_step(dataloader):
    """Run one optimisation pass over ``dataloader``.

    Uses the module-level ``model``, ``optimizer``, ``ctc_loss`` and
    ``DEVICE``. Returns the CTC loss averaged over the batches.
    """
    model.train()
    running_loss = 0

    for batch in dataloader:
        # Batches arrive as (features, flat targets, feature lengths, target lengths).
        feats, targets, feat_lens, target_lens = (t.to(DEVICE) for t in batch)

        optimizer.zero_grad()
        log_probs, out_lens = model(feats, feat_lens)

        # The model emits batch-first log-probs; the loss is fed them time-major.
        batch_loss = ctc_loss(log_probs.transpose(0, 1), targets, out_lens, target_lens)
        batch_loss.backward()

        # Cap the global gradient norm at 5.0 before the parameter update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        optimizer.step()

        running_loss += batch_loss.item()

    return running_loss / len(dataloader)


In [19]:
def edit_distance(a, b):
    """Levenshtein distance between sequences ``a`` and ``b``.

    Insertions, deletions and substitutions each cost 1. Only the previous
    and the current row of the dynamic-programming table are kept.
    """
    # Row 0: turning an empty prefix of `a` into b[:j] takes j insertions.
    prev_row = list(range(len(b) + 1))

    for i, item_a in enumerate(a, start=1):
        # Column 0: turning a[:i] into an empty sequence takes i deletions.
        row = [i]
        for j, item_b in enumerate(b, start=1):
            substitution = prev_row[j - 1] + (0 if item_a == item_b else 1)
            deletion = prev_row[j] + 1
            insertion = row[j - 1] + 1
            row.append(min(deletion, insertion, substitution))
        prev_row = row

    return prev_row[-1]


def greedy_decode(logp, lens):
    """Best-path CTC decoding.

    For each sequence in the batch, take the arg-max class of every frame up
    to that sequence's length, collapse runs of identical classes and drop
    class 0 (the blank). Returns one list of class ids per sequence.
    """
    best_path = logp.argmax(dim=-1)
    decoded = []

    for row in range(best_path.size(0)):
        frames = best_path[row][:lens[row]].tolist()
        # Pair every frame with its predecessor (None before the first frame)
        # and keep it only if it is a non-blank that differs from the previous.
        decoded.append([
            tok
            for prev, tok in zip([None] + frames, frames)
            if tok != 0 and tok != prev
        ])

    return decoded


def validate(dataloader):
    """Mean per-utterance token error rate of ``model`` on ``dataloader``.

    Each batch is decoded with :func:`greedy_decode`. For every utterance the
    edit distance between the predicted and the target token sequence is
    divided by the target length, and those ratios are averaged. Utterances
    with an empty target are skipped. The score is computed over token ids,
    not over whitespace-separated words.

    Uses the module-level ``model`` and ``DEVICE``.

    Returns:
        The mean error rate as a float, or ``nan`` if no utterance with a
        non-empty target was seen (empty loader or only unlabelled trials).
    """
    model.eval()
    total = 0.0
    count = 0

    with torch.no_grad():
        for X, Y, x_lens, y_lens in dataloader:
            X = X.to(DEVICE)
            x_lens = x_lens.to(DEVICE)

            logp, out_lens = model(X, x_lens)
            preds = greedy_decode(logp.cpu(), out_lens.cpu())

            # Y holds all targets of the batch concatenated; cut it back into
            # one list per utterance using plain Python ints as split sizes.
            targets = [chunk.tolist() for chunk in torch.split(Y, y_lens.tolist())]

            for p, t in zip(preds, targets):
                if len(t) == 0:
                    continue
                total += edit_distance(p, t) / len(t)
                count += 1

    if count == 0:
        # Undefined rather than a ZeroDivisionError when nothing was scored.
        return float("nan")
    return total / count


In [20]:
early_stop = EarlyStopping(patience=8)
min_epochs = 10   # early stopping only starts counting after this many epochs
max_epochs = 59   # hard cap on the number of epochs

# Remember the weights of the best validation epoch. Without this the model
# is left at whatever the last epoch produced, which with patience=8 can be
# several epochs past the best score.
best_val = float("inf")
best_state = None

for epoch in range(1, max_epochs + 1):
    tr = train_step(train_dl)
    val = validate(val_dl)

    # The value labelled "WER" is the token-level error rate returned by validate().
    print(f"Epoch {epoch:02d} | Train {tr:.3f} | Val WER {val:.3f}")

    if val < best_val:
        best_val = val
        # Snapshot on the CPU so the copy does not take up GPU memory.
        best_state = {
            name: tensor.detach().cpu().clone()
            for name, tensor in model.state_dict().items()
        }

    if epoch > min_epochs:
        early_stop.step(val)
        if early_stop.should_stop:
            print(f"Stopped at epoch {epoch}, best WER={early_stop.best_wer:.3f}")
            break

if best_state is not None:
    model.load_state_dict(best_state)
    print(f"Restored weights from best validation epoch (val={best_val:.3f})")


Epoch 01 | Train 12.875 | Val WER 1.000
Epoch 02 | Train 3.243 | Val WER 1.000
Epoch 03 | Train 3.154 | Val WER 1.000
Epoch 04 | Train 3.101 | Val WER 1.000
Epoch 05 | Train 3.015 | Val WER 1.000
Epoch 06 | Train 2.828 | Val WER 1.000
Epoch 07 | Train 2.543 | Val WER 1.000
Epoch 08 | Train 2.280 | Val WER 1.000
Epoch 09 | Train 2.040 | Val WER 1.000
Epoch 10 | Train 1.794 | Val WER 0.999
Epoch 11 | Train 1.537 | Val WER 0.995
Epoch 12 | Train 1.293 | Val WER 0.958
Epoch 13 | Train 1.051 | Val WER 0.916
Epoch 14 | Train 0.804 | Val WER 0.805
Epoch 15 | Train 0.600 | Val WER 0.808
Epoch 16 | Train 0.464 | Val WER 0.847
Epoch 17 | Train 0.361 | Val WER 0.790
Epoch 18 | Train 0.284 | Val WER 0.770
Epoch 19 | Train 0.253 | Val WER 0.772
Epoch 20 | Train 0.235 | Val WER 0.744
Epoch 21 | Train 0.215 | Val WER 0.773
Epoch 22 | Train 0.206 | Val WER 0.717
Epoch 23 | Train 0.166 | Val WER 0.722
Epoch 24 | Train 0.139 | Val WER 0.701
Epoch 25 | Train 0.128 | Val WER 0.688
Epoch 26 | Train 0.092 |