In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import h5py
import numpy as np


In [2]:
BASE = "/kaggle/input/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final/t15.2025.01.12"

# One HDF5 file per split, all located directly under BASE.
TRAIN_H5 = f"{BASE}/data_train.hdf5"
VAL_H5   = f"{BASE}/data_val.hdf5"
TEST_H5  = f"{BASE}/data_test.hdf5"

# Echo the resolved paths, one per line, so a wrong BASE is obvious immediately.
print(TRAIN_H5, VAL_H5, TEST_H5, sep="\n")


/kaggle/input/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final/t15.2025.01.12/data_train.hdf5
/kaggle/input/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final/t15.2025.01.12/data_val.hdf5
/kaggle/input/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final/t15.2025.01.12/data_test.hdf5


In [3]:
class Brain2TextDataset(Dataset):
    """Map-style dataset over the top-level groups of a Brain-to-Text HDF5 file.

    Each top-level key of the file is treated as one trial. ``__getitem__``
    returns ``(X, Y)`` where ``X`` is the group's ``input_features`` as a
    float32 tensor and ``Y`` is a 1-D int64 target tensor:

    * ``seq_class_ids`` present -> those ids with every 0 removed (0 is used as
      padding here and is also the CTC blank index).
    * else ``transcription`` present -> character codes of the text (fallback).
    * else (e.g. an unlabeled test split) -> an empty target tensor.

    The HDF5 file stays open for the dataset's lifetime. Call :meth:`close`, or
    use the dataset as a context manager, to release the handle deterministically.

    Args:
        h5_path: path to the ``.hdf5`` file.
        max_trials: if not None, keep only the first ``max_trials`` keys.
    """

    def __init__(self, h5_path, max_trials=None):
        self.f = h5py.File(h5_path, "r")
        self.keys = list(self.f.keys())

        if max_trials is not None:
            self.keys = self.keys[:max_trials]

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, idx):
        grp = self.f[self.keys[idx]]

        # Neural features; the original comment assumes shape (T, 512) — TODO confirm per file.
        X = torch.tensor(grp["input_features"][:], dtype=torch.float32)

        if "seq_class_ids" in grp:
            ids = np.asarray(grp["seq_class_ids"][:])
            ids = ids[ids != 0]  # strip zero padding
        elif "transcription" in grp:
            # NOTE(review): character codes live in a different label space than
            # phoneme ids and can exceed the model's vocab size (e.g. ord('a') == 97),
            # which makes CTCLoss targets out of range. Only use for inspection.
            ids = self._transcription_to_ids(grp["transcription"][()])
        else:
            # No labels stored: return an empty target so unlabeled splits still load.
            ids = np.empty(0, dtype=np.int64)

        Y = torch.tensor(ids, dtype=torch.long)
        return X, Y

    @staticmethod
    def _transcription_to_ids(raw):
        """Convert a stored transcription (bytes, str, or array of char codes) to int ids."""
        if isinstance(raw, bytes):
            text = raw.decode("utf-8")
        elif isinstance(raw, str):
            text = raw
        else:
            # Assumed: numeric array of character codes padded with zeros — TODO confirm.
            codes = np.asarray(raw).astype(np.int64).ravel()
            return codes[codes != 0]
        return np.array([ord(c) for c in text], dtype=np.int64)

    def close(self):
        """Close the underlying HDF5 file. Calling it more than once is a no-op."""
        f = getattr(self, "f", None)
        if f is not None:
            f.close()
            self.f = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()


In [4]:
def ctc_collate(batch):
    """Collate ``(features, target)`` pairs into the layout ``nn.CTCLoss`` consumes.

    Args:
        batch: sequence of ``(X, Y)`` tensor pairs; ``X`` is time-major along dim 0.

    Returns:
        ``(padded_X, flat_Y, x_lens, y_lens)`` where ``padded_X`` is (B, T_max, ...)
        zero-padded, ``flat_Y`` is every target concatenated end to end, and the two
        length tensors are int64 with one entry per sample.
    """
    feats, labels = zip(*batch)

    input_lengths = torch.tensor([f.size(0) for f in feats], dtype=torch.long)
    target_lengths = torch.tensor([lab.size(0) for lab in labels], dtype=torch.long)

    padded = nn.utils.rnn.pad_sequence(feats, batch_first=True)  # (B, T, C)
    flat_targets = torch.cat(labels)  # concatenated form accepted by CTCLoss

    return padded, flat_targets, input_lengths, target_lengths


In [5]:
class ConvBiGRU(nn.Module):
    """Conv1d front-end + bidirectional GRU + linear head emitting CTC log-probs.

    Args:
        feat_dim: number of input feature channels (also the conv output width).
        hidden: GRU hidden size per direction.
        vocab_size: number of output classes, including the CTC blank.
        num_layers: number of stacked GRU layers.
        dropout: dropout between GRU layers (ignored when ``num_layers == 1``).
        kernel_size: temporal conv kernel; must be odd so sequence length is preserved.

    Raises:
        ValueError: if ``kernel_size`` is even.
    """

    def __init__(self, feat_dim=512, hidden=256, vocab_size=100,
                 num_layers=3, dropout=0.2, kernel_size=5):
        super().__init__()

        if kernel_size % 2 != 1:
            raise ValueError(
                f"kernel_size must be odd to preserve sequence length, got {kernel_size}"
            )

        # "Same" padding keeps T unchanged, so x_lens stay valid after the conv.
        self.conv = nn.Sequential(
            nn.Conv1d(feat_dim, feat_dim, kernel_size=kernel_size, padding=kernel_size // 2),
            nn.GELU(),
        )

        self.gru = nn.GRU(
            input_size=feat_dim,
            hidden_size=hidden,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
            # nn.GRU only applies dropout between layers and warns if set with 1 layer.
            dropout=dropout if num_layers > 1 else 0.0,
        )

        self.fc = nn.Linear(hidden * 2, vocab_size)

    def forward(self, x, x_lens):
        """Run the network.

        Args:
            x: padded features, (B, T, feat_dim).
            x_lens: valid length of each sequence, shape (B,).

        Returns:
            ``(log_probs, out_lens)``: ``log_probs`` is (T, B, vocab_size) as
            ``nn.CTCLoss`` expects, and ``out_lens`` is the per-sequence length
            tensor returned by ``pad_packed_sequence``.
        """
        total_len = x.size(1)

        x = self.conv(x.transpose(1, 2)).transpose(1, 2)  # conv runs on (B, C, T)

        packed = nn.utils.rnn.pack_padded_sequence(
            x, x_lens.cpu(), batch_first=True, enforce_sorted=False
        )
        out, _ = self.gru(packed)

        # total_length keeps the output time axis equal to the input's padded T,
        # even when the batch is padded beyond max(x_lens).
        out, out_lens = nn.utils.rnn.pad_packed_sequence(
            out, batch_first=True, total_length=total_len
        )

        log_probs = F.log_softmax(self.fc(out), dim=-1)  # (B, T, vocab)
        return log_probs.transpose(0, 1), out_lens


In [6]:
def greedy_decode_ctc(log_probs, out_lens, blank=0):
    """Best-path (greedy) CTC decoding.

    Args:
        log_probs: scores shaped (T, B, V); the argmax over the last axis is taken.
        out_lens: valid length of each of the B sequences (tensor or list of ints).
        blank: index of the CTC blank symbol. The default 0 matches
            ``nn.CTCLoss(blank=0)`` used in this notebook.

    Returns:
        A list of B lists of int token ids. Consecutive repeats are merged first,
        then blanks are dropped, so a blank between two equal symbols keeps both.
    """
    from itertools import groupby

    best_path = log_probs.argmax(dim=-1).cpu()  # (T, B)

    decoded = []
    for b in range(best_path.size(1)):
        steps = best_path[: int(out_lens[b]), b].tolist()
        decoded.append([tok for tok, _ in groupby(steps) if tok != blank])
    return decoded


In [7]:
def edit_distance(a, b):
    """Levenshtein distance between two sequences with unit insert/delete/substitute costs.

    Uses two rolling rows instead of a full (len(a)+1) x (len(b)+1) matrix, and plain
    Python ints instead of per-element numpy indexing, which is slow in a Python loop.

    Args:
        a, b: sequences of comparable items (e.g. lists of token ids or strings).

    Returns:
        The minimum number of edits, as a Python ``int``.
    """
    if len(a) < len(b):
        a, b = b, a  # distance is symmetric; keep the inner row the shorter one

    prev = list(range(len(b) + 1))
    for i, x in enumerate(a, start=1):
        curr = [i]
        for j, y in enumerate(b, start=1):
            curr.append(min(
                prev[j] + 1,              # deletion
                curr[j - 1] + 1,          # insertion
                prev[j - 1] + (x != y),   # substitution (free when equal)
            ))
        prev = curr
    return int(prev[-1])


In [8]:
def phoneme_error_rate(pred_seq, true_seq):
    """Edit distance between prediction and reference, divided by the reference length.

    Returns 0.0 whenever the reference is empty, even if ``pred_seq`` is not.
    Insertions against an empty reference therefore never count as error.
    """
    ref_len = len(true_seq)
    if not ref_len:
        return 0.0
    return edit_distance(pred_seq, true_seq) / ref_len


In [9]:
def validate_per(dataloader, net=None, dev=None):
    """Mean per-utterance phoneme error rate (PER) of greedy CTC decoding.

    Args:
        dataloader: yields ``(X, Y, x_lens, y_lens)`` batches in the ``ctc_collate``
            layout, where ``Y`` is the flat concatenation of the batch's targets.
        net: model to evaluate; defaults to the notebook-level ``model``.
        dev: device to run on; defaults to the notebook-level ``device``.

    Returns:
        The unweighted mean of ``phoneme_error_rate`` over all utterances.
        NOTE(review): this averages per-utterance ratios. An aggregate PER
        (total edits / total reference length) weights long utterances more,
        so pick one deliberately when comparing runs.

    Raises:
        ValueError: if the dataloader yields no utterances, instead of dividing by zero.

    Leaves ``net`` in eval mode.
    """
    net = model if net is None else net
    dev = device if dev is None else dev

    net.eval()
    total_per = 0.0
    n = 0

    with torch.no_grad():
        for X, Y, x_lens, y_lens in dataloader:
            logp, out_lens = net(X.to(dev), x_lens.to(dev))
            preds = greedy_decode_ctc(logp.cpu(), out_lens.cpu())

            # Undo the concatenation done by the collate fn: one target list per utterance.
            targets = [t.tolist() for t in torch.split(Y, y_lens.tolist())]

            for p, t in zip(preds, targets):
                total_per += phoneme_error_rate(p, t)
                n += 1

    if n == 0:
        raise ValueError("validate_per received a dataloader with no utterances")
    return total_per / n


In [10]:
def train_step(dataloader, net=None, opt=None, loss_fn=None, dev=None, max_grad_norm=None):
    """Train for one pass over ``dataloader`` and return the mean batch loss.

    Args:
        dataloader: yields ``(X, Y, x_lens, y_lens)`` batches in the ``ctc_collate`` layout.
        net: model returning ``(log_probs, out_lens)``; defaults to the notebook-level ``model``.
        opt: optimizer; defaults to the notebook-level ``optimizer``.
        loss_fn: CTC-style loss ``(log_probs, targets, input_lens, target_lens)``;
            defaults to the notebook-level ``ctc_loss``.
        dev: device for the batch tensors; defaults to the notebook-level ``device``.
        max_grad_norm: if not None, clip the global gradient norm to this value
            before each optimizer step (often needed for stable RNN + CTC training).

    Returns:
        Average of ``loss.item()`` over the batches seen.

    Raises:
        ValueError: if the dataloader yields no batches, instead of dividing by zero.
    """
    net = model if net is None else net
    opt = optimizer if opt is None else opt
    loss_fn = ctc_loss if loss_fn is None else loss_fn
    dev = device if dev is None else dev

    net.train()
    total_loss = 0.0
    n_batches = 0

    for X, Y, x_lens, y_lens in dataloader:
        X = X.to(dev)
        Y = Y.to(dev)
        x_lens = x_lens.to(dev)

        logp, out_lens = net(X, x_lens)
        loss = loss_fn(logp, Y, out_lens, y_lens.to(dev))

        opt.zero_grad()
        loss.backward()
        if max_grad_norm is not None:
            nn.utils.clip_grad_norm_(net.parameters(), max_grad_norm)
        opt.step()

        total_loss += loss.item()
        n_batches += 1

    if n_batches == 0:
        raise ValueError("train_step received an empty dataloader")
    return total_loss / n_batches


In [11]:
class EarlyStopping:
    """Track a validation metric and flag when it stops improving.

    Args:
        patience: number of consecutive non-improving steps before ``should_stop``
            becomes True.
        min_delta: minimum change that counts as an improvement (default 0.0, so any
            strict improvement counts).
        mode: ``"min"`` if lower is better (e.g. loss, PER), ``"max"`` if higher is better.

    Attributes:
        best: best metric value seen so far.
        counter: consecutive non-improving steps since the last improvement.
        should_stop: latches to True once ``counter`` reaches ``patience``.

    Raises:
        ValueError: if ``mode`` is not ``"min"`` or ``"max"``.
    """

    def __init__(self, patience=10, min_delta=0.0, mode="min"):
        if mode not in ("min", "max"):
            raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.counter = 0
        self.best = float("inf") if mode == "min" else float("-inf")
        self.should_stop = False

    def _is_improvement(self, metric):
        """True if ``metric`` beats ``best`` by more than ``min_delta`` in the configured direction."""
        if self.mode == "min":
            return metric < self.best - self.min_delta
        return metric > self.best + self.min_delta

    def step(self, metric):
        """Record one metric value and return whether it was an improvement.

        A NaN metric never counts as an improvement, because every comparison
        with NaN is False.
        """
        improved = self._is_improvement(metric)
        if improved:
            self.best = metric
            self.counter = 0
        else:
            self.counter += 1

        if self.counter >= self.patience:
            self.should_stop = True
        return improved


In [12]:
# --- Run configuration -------------------------------------------------------
SEED = 0
torch.manual_seed(SEED)  # seeds weight init and DataLoader shuffling for reproducible runs

device = "cuda" if torch.cuda.is_available() else "cpu"

MAX_TRAIN_TRIALS = 2000  # subset sizes for quick iteration; None = full split
MAX_VAL_TRIALS = 300
BATCH_SIZE = 8
VOCAB_SIZE = 96  # change to your phoneme vocab size
LR = 1e-3
MAX_EPOCHS = 200  # range(1, MAX_EPOCHS + 1) runs exactly MAX_EPOCHS epochs
PATIENCE = 15

# --- Data ----------------------------------------------------------------------
train_ds = Brain2TextDataset(TRAIN_H5, max_trials=MAX_TRAIN_TRIALS)
val_ds   = Brain2TextDataset(VAL_H5,   max_trials=MAX_VAL_TRIALS)

train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, collate_fn=ctc_collate)
val_dl   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False, collate_fn=ctc_collate)

# --- Model / optimisation ------------------------------------------------------
model = ConvBiGRU(feat_dim=512, hidden=256, vocab_size=VOCAB_SIZE).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=LR)
ctc_loss = nn.CTCLoss(blank=0, zero_infinity=True)

early_stop = EarlyStopping(patience=PATIENCE)
best_state = None  # detached copy of the weights from the best-validation epoch

for epoch in range(1, MAX_EPOCHS + 1):
    train_loss = train_step(train_dl)
    val_per = validate_per(val_dl)

    print(f"Epoch {epoch:02d} | Train {train_loss:.3f} | Val PER {val_per:.3f}")

    early_stop.step(val_per)
    # EarlyStopping resets its counter to 0 exactly when the metric improved.
    if early_stop.counter == 0:
        best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
    if early_stop.should_stop:
        print("Early stopping!")
        break

# Without this, `model` would hold the last epoch's weights, which early
# stopping guarantees are *not* the best ones.
if best_state is not None:
    model.load_state_dict(best_state)
    print(f"Restored best weights (Val PER {early_stop.best:.3f})")



Epoch 01 | Train 32.154 | Val PER 0.965
Epoch 02 | Train 18.354 | Val PER 0.965
Epoch 03 | Train 16.170 | Val PER 1.000
Epoch 04 | Train 3.463 | Val PER 1.000
Epoch 05 | Train 3.367 | Val PER 1.000
Epoch 06 | Train 3.314 | Val PER 1.000
Epoch 07 | Train 3.289 | Val PER 1.000
Epoch 08 | Train 3.267 | Val PER 1.000
Epoch 09 | Train 3.254 | Val PER 1.000
Epoch 10 | Train 3.247 | Val PER 1.000
Epoch 11 | Train 3.232 | Val PER 1.000
Epoch 12 | Train 3.219 | Val PER 1.000
Epoch 13 | Train 3.214 | Val PER 1.000
Epoch 14 | Train 3.219 | Val PER 1.000
Epoch 15 | Train 3.210 | Val PER 1.000
Epoch 16 | Train 3.206 | Val PER 0.960
Epoch 17 | Train 3.197 | Val PER 0.960
Epoch 18 | Train 3.193 | Val PER 0.960
Epoch 19 | Train 3.198 | Val PER 0.960
Epoch 20 | Train 3.198 | Val PER 0.960
Epoch 21 | Train 3.192 | Val PER 0.960
Epoch 22 | Train 3.189 | Val PER 0.960
Epoch 23 | Train 3.179 | Val PER 0.960
Epoch 24 | Train 3.181 | Val PER 0.960
Epoch 25 | Train 3.188 | Val PER 0.960
Epoch 26 | Train 3.182