In [1]:

# Brain-to-Text CTC Training - Full Dataset + PER

import os
import h5py
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import editdistance
import warnings
warnings.filterwarnings("ignore")

# --------------------------
# CONFIG
# --------------------------
# Root folder that is scanned for per-session "t15.*" sub-folders.
BASE_PATH = "/kaggle/input/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final"
BATCH_SIZE = 4  # Training batch size; kept small for the full dataset to avoid OOM
BATCH_SIZE_VAL = 2  # Validation batch size
NUM_WORKERS = 2  # DataLoader worker processes (used for both loaders)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
EPOCHS = 30  # Also fixes the total step count of the OneCycleLR schedule
LEARNING_RATE = 5e-4  # Used both as the AdamW lr and as the OneCycleLR max_lr
WEIGHT_DECAY = 0.01  # AdamW weight decay
SAVE_PATH = "best_model_full_per.pt"  # Where the best-PER state_dict is written

# Model architecture
D_MODEL = 256  # Transformer width
NHEAD = 8  # Attention heads per encoder layer
NUM_LAYERS = 4  # Number of stacked encoder layers
DROPOUT = 0.1  # Dropout used in the input projection and encoder layers

# NO downsampling, NO max sequence length limit
# NOTE(review): neither of the two constants below is read by the code visible
# in this cell; they only record the intended setup.
DOWNSAMPLE_FACTOR = 1  # No downsampling
MAX_SEQ_LEN = None  # No length cap

# Use ALL data - no subset limits
# Passed as `max_trials` to the datasets; None keeps every trial.
MAX_TRAIN_TRIALS = None  # Use all training trials
MAX_VAL_TRIALS = None  # Use all validation trials

# --------------------------
# Phoneme Vocab 
# --------------------------
# Class inventory for the CTC head: the position in this list IS the class id.
# Index 0 is the CTC blank, indices 1-39 are the phonemes, and index 40 is the
# silence / word-boundary symbol.
LOGIT_TO_PHONEME = [
    "BLANK",
    *(
        "AA AE AH AO AW AY B CH D DH "
        "EH ER EY F G HH IH IY JH K "
        "L M N NG OW OY P R S SH "
        "T TH UH UW V W Y Z ZH"
    ).split(),
    " | ",
]

VOCAB_SIZE = len(LOGIT_TO_PHONEME)  # 41 classes (including BLANK)

# Lookup tables in both directions (symbol -> class id, class id -> symbol).
idx_to_phoneme = dict(enumerate(LOGIT_TO_PHONEME))
phoneme_to_idx = {symbol: class_id for class_id, symbol in idx_to_phoneme.items()}

print(f"VOCAB_SIZE = {VOCAB_SIZE} (BLANK + 39 phonemes + silence)")
print(f"Phoneme set: {LOGIT_TO_PHONEME[1:]}")  # BLANK is left out of the display

# --------------------------
# Full Dataset Loading Functions
# --------------------------
def get_all_hdf5_files(base_path, split_name):
    """Collect the per-session HDF5 file of one split.

    Only the immediate children of ``base_path`` are inspected (there is no
    deeper recursion): every entry whose name starts with ``"t15."`` is
    checked for a ``data_<split_name>.hdf5`` file, and entries that do not
    contain one are skipped.

    Args:
        base_path: directory that holds the ``t15.*`` session folders.
        split_name: split suffix used in the file name (e.g. ``"train"``).

    Returns:
        List of existing file paths, ordered by sorted folder name.
    """
    target_name = f"data_{split_name}.hdf5"
    session_dirs = (
        entry for entry in sorted(os.listdir(base_path)) if entry.startswith("t15.")
    )
    candidate_paths = (
        os.path.join(base_path, session, target_name) for session in session_dirs
    )
    return [path for path in candidate_paths if os.path.exists(path)]

def load_full_dataset_info(h5_paths):
    """Index every labelled trial across a list of HDF5 files.

    Each file is opened read-only and every top-level entry that contains a
    ``seq_class_ids`` member (the phoneme labels) is recorded. Paths that do
    not exist are skipped silently. A per-file count and the grand total are
    printed.

    Args:
        h5_paths: iterable of HDF5 file paths.

    Returns:
        List of ``(file_path, key)`` tuples, in file order and then in the
        key order reported by the file.
    """
    all_trials = []

    for h5_path in h5_paths:
        if not os.path.exists(h5_path):
            continue

        with h5py.File(h5_path, "r") as f:
            # Only trials that carry phoneme labels are usable for CTC training.
            labeled_keys = [k for k in f.keys() if "seq_class_ids" in f[k]]

        all_trials.extend((h5_path, k) for k in labeled_keys)

        # Every session stores its split under the same file name, so the bare
        # basename is identical on every line; prefix the parent folder to
        # make the log lines distinguishable.
        session = os.path.basename(os.path.dirname(h5_path))
        label = os.path.join(session, os.path.basename(h5_path))
        print(f"  {label}: {len(labeled_keys)} trials")

    print(f"\nTotal trials found: {len(all_trials)}")
    return all_trials

# --------------------------
# Dataset (Full Data, No Downsampling)
# --------------------------
class Brain2TextDatasetFull(Dataset):
    """Dataset of (neural features, phoneme target) pairs read from HDF5.

    Trials are addressed by ``(file_path, key)``. Files are opened lazily on
    first access and the handles are kept in ``self.file_cache`` so a file is
    not reopened for every sample. Call :meth:`close` to release the handles
    explicitly; ``__del__`` does so as a best effort.
    """

    def __init__(self, trial_list, max_trials=None):
        """
        Args:
            trial_list: list of (file_path, key) tuples
            max_trials: optional limit on number of trials (None = use all)
        """
        self.trials = trial_list if max_trials is None else trial_list[:max_trials]
        self.file_cache = {}  # file_path -> open h5py.File handle
        print(f"Dataset initialized with {len(self.trials)} trials")

    def __len__(self):
        """Return the number of trials in the dataset."""
        return len(self.trials)

    def __getitem__(self, idx):
        """Load one trial.

        Returns:
            ``(x, y)`` where ``x`` is a float32 tensor built from the trial's
            ``input_features`` (full length, no downsampling) and ``y`` is a
            long tensor of phoneme class ids, empty when the trial has no
            ``seq_class_ids``.
        """
        file_path, key = self.trials[idx]

        # Open file (or use cached handle)
        if file_path not in self.file_cache:
            self.file_cache[file_path] = h5py.File(file_path, "r")

        grp = self.file_cache[file_path][key]

        # Load neural features - NO downsampling, NO length cap
        x = torch.tensor(grp["input_features"][()], dtype=torch.float32)

        if "seq_class_ids" in grp:
            # Stored values are used directly as class ids. Only the first
            # `seq_len` entries are kept (the rest is treated as padding);
            # without the attribute the whole array is used.
            seq_class_ids = grp["seq_class_ids"][()]
            seq_len = grp.attrs.get("seq_len", len(seq_class_ids))
            y = torch.tensor(seq_class_ids[:seq_len], dtype=torch.long)
        else:
            # Fallback: empty sequence
            y = torch.tensor([], dtype=torch.long)

        return x, y

    def close(self):
        """Close every cached file handle and empty the cache (idempotent)."""
        # getattr: __del__ may run on an instance whose __init__ never got as
        # far as creating the cache.
        cache = getattr(self, "file_cache", None)
        if not cache:
            return
        for handle in cache.values():
            try:
                handle.close()
            except Exception:
                # Best effort: a handle may already be closed or invalid
                # (e.g. during interpreter shutdown); nothing useful to do.
                pass
        cache.clear()

    def __del__(self):
        # Release cached file handles when the dataset is garbage collected.
        self.close()

def ctc_collate(batch):
    """Assemble a list of ``(x, y)`` samples into one CTC training batch.

    Args:
        batch: sequence of ``(features, targets)`` tensor pairs.

    Returns:
        Tuple ``(X, Y, x_lens, y_lens)``: the features zero-padded to the
        longest sample with batch as the first dimension, all targets
        concatenated into one 1-D tensor, and the per-sample feature and
        target lengths as long tensors.
    """
    features, targets = map(list, zip(*batch))

    padded_features = nn.utils.rnn.pad_sequence(features, batch_first=True)
    flat_targets = torch.cat(targets)

    feature_lengths = torch.tensor(list(map(len, features)), dtype=torch.long)
    target_lengths = torch.tensor(list(map(len, targets)), dtype=torch.long)

    return padded_features, flat_targets, feature_lengths, target_lengths

# --------------------------
# Model
# --------------------------
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to a batch-first sequence.

    A ``(1, max_len, d_model)`` table is precomputed and stored as the buffer
    ``pe``: even feature columns hold sines and odd columns hold cosines of
    the position scaled by geometrically spaced frequencies.

    Args:
        d_model: feature width of the input (even or odd).
        max_len: number of positions in the table; inputs must not be longer.
    """

    def __init__(self, d_model, max_len=10000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        angles = pos * div  # (max_len, ceil(d_model / 2))
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one odd column fewer than even columns, so
        # trim the angles to match; for even d_model this slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(1)`` positions."""
        return x + self.pe[:, :x.size(1)]

class BrainTransformer(nn.Module):
    """Transformer encoder that emits per-frame class logits for CTC.

    Pipeline: a projection from 512 input features to ``D_MODEL`` (linear,
    layer norm, GELU, dropout), sinusoidal positional encoding,
    ``NUM_LAYERS`` pre-norm encoder layers, and a final linear layer to
    ``VOCAB_SIZE`` logits. All tensors are batch-first.
    """

    def __init__(self):
        super().__init__()
        # Submodules are created in a fixed order so that parameter
        # initialisation consumes the RNG identically from run to run.
        input_stem = [
            nn.Linear(512, D_MODEL),
            nn.LayerNorm(D_MODEL),
            nn.GELU(),
            nn.Dropout(DROPOUT),
        ]
        self.proj = nn.Sequential(*input_stem)
        self.pos = PositionalEncoding(D_MODEL)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=D_MODEL,
            nhead=NHEAD,
            dim_feedforward=D_MODEL * 4,
            dropout=DROPOUT,
            activation='gelu',
            batch_first=True,
            norm_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, NUM_LAYERS)
        self.out = nn.Linear(D_MODEL, VOCAB_SIZE)

    def forward(self, x, mask=None):
        """Map input features to logits.

        Args:
            x: batch-first input whose last dimension is 512.
            mask: optional key-padding mask forwarded to the encoder as
                ``src_key_padding_mask``.

        Returns:
            Logits with ``VOCAB_SIZE`` entries per time step.
        """
        embedded = self.pos(self.proj(x))
        encoded = self.transformer(embedded, src_key_padding_mask=mask)
        return self.out(encoded)

# --------------------------
# Phoneme Error Rate (PER) Calculation
# --------------------------
def ctc_greedy_decode(logits, lengths=None):
    """
    CTC greedy decoding: argmax -> collapse repeats -> remove blanks

    Args:
        logits: (batch, time, vocab_size) tensor
        lengths: optional per-item number of valid time steps (sequence of
            ints or 1-D tensor). When given, frames at or beyond an item's
            length are ignored so padded frames cannot contribute tokens.
            ``None`` decodes every frame.
    Returns:
        List of phoneme index sequences (one per batch item), as lists of
        Python ints.
    Raises:
        ValueError: if ``lengths`` does not have one entry per batch item.
    """
    best_path = logits.argmax(-1).cpu().numpy()  # (batch, time)
    batch_size, num_frames = best_path.shape[0], best_path.shape[1]

    if lengths is None:
        valid_lens = [num_frames] * batch_size
    else:
        valid_lens = [int(n) for n in lengths]
        if len(valid_lens) != batch_size:
            raise ValueError(
                f"lengths has {len(valid_lens)} entries but the batch has {batch_size} items"
            )

    decoded_sequences = []
    for row, n_valid in zip(best_path, valid_lens):
        frames = row[:n_valid]
        # A frame starts a new token when it differs from its predecessor;
        # the first frame always does.
        keep = np.ones(frames.shape[0], dtype=bool)
        keep[1:] = frames[1:] != frames[:-1]
        collapsed = frames[keep]
        # Remove blanks (index 0)
        decoded_sequences.append(collapsed[collapsed != 0].tolist())

    return decoded_sequences

def calculate_per(predictions, ground_truths):
    """
    Phoneme Error Rate over a corpus, based on edit distance.

    Predictions and ground truths are paired positionally; the summed edit
    distance is divided by the summed ground-truth length.

    Args:
        predictions: list of predicted phoneme index sequences
        ground_truths: list of ground truth phoneme index sequences
    Returns:
        per: Phoneme Error Rate as percentage (100.0 when the total
            ground-truth length is zero)
        total_edit_distance: sum of edit distances
        total_length: sum of ground truth lengths
    """
    pairs = list(zip(predictions, ground_truths))

    total_edit_distance = sum(
        editdistance.eval(reference, hypothesis) for hypothesis, reference in pairs
    )
    total_length = sum(len(reference) for _, reference in pairs)

    if total_length > 0:
        per = total_edit_distance / total_length * 100
    else:
        # Nothing to score against: report the worst case rather than divide by zero.
        per = 100.0

    return per, total_edit_distance, total_length

# --------------------------
# Validation (Full Dataset)
# --------------------------
def validate(model, dl):
    """Run greedy CTC decoding over a DataLoader and report the PER.

    Args:
        model: network called as ``model(X, mask)`` returning
            (batch, time, vocab_size) logits. It is switched to eval mode and
            left there.
        dl: DataLoader yielding ``(X, Y, x_len, y_len)`` batches as produced
            by ``ctc_collate``.

    Returns:
        ``(per, sample_preds, sample_truths)``: the phoneme error rate in
        percent and, for up to the first five items, the predicted and true
        phoneme strings (symbols joined by spaces).
    """
    model.eval()
    all_pred_seqs = []
    all_true_seqs = []

    print("Starting validation...")
    with torch.no_grad():
        for batch_idx, (X, Y, x_len, y_len) in enumerate(dl):
            X = X.to(DEVICE)
            mask = torch.arange(X.size(1), device=DEVICE)[None, :] >= x_len.to(DEVICE)[:, None]

            logits = model(X, mask)  # (batch, time, vocab_size)

            # CTC greedy decode, one item at a time and restricted to that
            # item's real frames: decoding the padded tail of shorter items
            # would add spurious tokens and inflate the error rate.
            for item, n_frames in enumerate(x_len.tolist()):
                item_logits = logits[item:item + 1, :n_frames]
                all_pred_seqs.append(ctc_greedy_decode(item_logits)[0])

            # Targets arrive concatenated; cut them back into per-item lists.
            start = 0
            for n_labels in y_len.tolist():
                all_true_seqs.append(Y[start:start + n_labels].tolist())
                start += n_labels

            if (batch_idx + 1) % 50 == 0:
                print(f"  Validated {batch_idx + 1}/{len(dl)} batches")

            # Drop the batch tensors before the next forward pass to keep
            # peak GPU memory low.
            del logits, X, mask
            torch.cuda.empty_cache()

    # Calculate PER
    per, total_edit_dist, total_len = calculate_per(all_pred_seqs, all_true_seqs)

    print(f"\nValidation Results:")
    print(f"  Total edit distance: {total_edit_dist}")
    print(f"  Total sequence length: {total_len}")
    print(f"  PER: {per:.2f}%")

    # Get sample predictions for display (first 5); ids outside the
    # vocabulary are skipped rather than raising a KeyError.
    sample_preds = []
    sample_truths = []
    for pred_ids, true_ids in zip(all_pred_seqs[:5], all_true_seqs[:5]):
        sample_preds.append(' '.join(idx_to_phoneme[i] for i in pred_ids if i < VOCAB_SIZE))
        sample_truths.append(' '.join(idx_to_phoneme[i] for i in true_ids if i < VOCAB_SIZE))

    return per, sample_preds, sample_truths

# --------------------------
# Main Training Loop
# --------------------------
print("="*60)
print("LOADING FULL DATASET...")
print("="*60)

# Get all HDF5 files
train_files = get_all_hdf5_files(BASE_PATH, "train")
val_files = get_all_hdf5_files(BASE_PATH, "val")

print(f"\nFound {len(train_files)} training files")
print(f"Found {len(val_files)} validation files")

# Load trial information
print("\nLoading training trials...")
train_trials = load_full_dataset_info(train_files)

print("\nLoading validation trials...")
val_trials = load_full_dataset_info(val_files)

# Create datasets (full data)
train_ds = Brain2TextDatasetFull(train_trials, max_trials=MAX_TRAIN_TRIALS)
val_ds = Brain2TextDatasetFull(val_trials, max_trials=MAX_VAL_TRIALS)

# Create dataloaders
train_dl = DataLoader(
    train_ds, batch_size=BATCH_SIZE, shuffle=True,
    collate_fn=ctc_collate, num_workers=NUM_WORKERS, pin_memory=True
)
val_dl = DataLoader(
    val_ds, batch_size=BATCH_SIZE_VAL, shuffle=False,
    collate_fn=ctc_collate, num_workers=NUM_WORKERS, pin_memory=True
)

print(f"\nDataLoader created:")
print(f"  Train batches: {len(train_dl)}")
print(f"  Val batches: {len(val_dl)}")

# --------------------------
# Model + Optimizer
# --------------------------
model = BrainTransformer().to(DEVICE)
print(f"\nModel parameters: {sum(p.numel() for p in model.parameters()):,}")

optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
# The one-cycle schedule is stepped once per batch, so its length is
# EPOCHS * len(train_dl) steps.
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=LEARNING_RATE, epochs=EPOCHS,
    steps_per_epoch=len(train_dl), pct_start=0.1, anneal_strategy='cos'
)
ctc_loss = nn.CTCLoss(blank=0, zero_infinity=True)

# PER is edit distance over reference length and can be 100% or more, so a
# starting value of 100.0 could prevent any checkpoint from ever being saved.
# Start from infinity so the first validation result always becomes the best.
best_per = float("inf")

print("\n" + "="*60)
print("STARTING TRAINING ON FULL DATASET")
print("="*60)

for epoch in range(EPOCHS):
    model.train()  # validate() leaves the model in eval mode
    total_loss = 0

    for batch_idx, (X, Y, x_len, y_len) in enumerate(train_dl):
        X = X.to(DEVICE)
        Y = Y.to(DEVICE)
        x_len = x_len.to(DEVICE)
        y_len = y_len.to(DEVICE)

        # True at padded time steps (position >= the item's real length)
        mask = torch.arange(X.size(1), device=DEVICE)[None,:] >= x_len[:,None]

        logits = model(X, mask)
        log_probs = logits.log_softmax(-1).transpose(0, 1)  # (T, B, C)

        loss = ctc_loss(log_probs, Y, x_len, y_len)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

        total_loss += loss.item()

        if (batch_idx + 1) % 100 == 0:
            print(f"  Batch {batch_idx + 1}/{len(train_dl)} | Loss: {loss.item():.4f}")

    avg_loss = total_loss / len(train_dl)
    print(f"\nEPOCH {epoch+1}/{EPOCHS} | Avg Loss: {avg_loss:.4f}")

    # Validate every 3 epochs or in last 5 epochs
    if (epoch+1) % 3 == 0 or epoch >= EPOCHS - 5:
        print("\nRunning full validation...")
        val_per, sample_preds, sample_truths = validate(model, val_dl)

        if val_per < best_per:
            best_per = val_per
            torch.save(model.state_dict(), SAVE_PATH)
            print("✓ NEW BEST MODEL SAVED!")

        print("\nSample predictions (phonemes):")
        for i, (truth, pred) in enumerate(zip(sample_truths[:3], sample_preds[:3])):
            print(f"Sample {i+1}:")
            print(f"  Truth: {truth}")
            print(f"  Pred:  {pred}")
        print("-"*60)

print(f"\n{'='*60}")
print(f"Training complete!")
if best_per < float("inf"):
    print(f"Best PER: {best_per:.2f}%")
    print(f"Model saved: {SAVE_PATH}")
else:
    # Only report a checkpoint when one was actually written.
    print("No validation result was recorded; no model was saved.")
print(f"{'='*60}")

VOCAB_SIZE = 41 (BLANK + 39 phonemes + silence)
Phoneme set: ['AA', 'AE', 'AH', 'AO', 'AW', 'AY', 'B', 'CH', 'D', 'DH', 'EH', 'ER', 'EY', 'F', 'G', 'HH', 'IH', 'IY', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 'OY', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UW', 'V', 'W', 'Y', 'Z', 'ZH', ' | ']
LOADING FULL DATASET...

Found 45 training files
Found 41 validation files

Loading training trials...
  data_train.hdf5: 288 trials
  data_train.hdf5: 348 trials
  data_train.hdf5: 197 trials
  data_train.hdf5: 278 trials
  data_train.hdf5: 88 trials
  data_train.hdf5: 150 trials
  data_train.hdf5: 297 trials
  data_train.hdf5: 322 trials
  data_train.hdf5: 245 trials
  data_train.hdf5: 153 trials
  data_train.hdf5: 218 trials
  data_train.hdf5: 174 trials
  data_train.hdf5: 284 trials
  data_train.hdf5: 155 trials
  data_train.hdf5: 239 trials
  data_train.hdf5: 98 trials
  data_train.hdf5: 134 trials
  data_train.hdf5: 149 trials
  data_train.hdf5: 80 trials
  data_train.hdf5: 100 trials
  data_train.