In [5]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/brain-to-text-25/data_link.txt
/kaggle/input/brain-to-text-25/t15_pretrained_rnn_baseline/t15_pretrained_rnn_baseline/training_log
/kaggle/input/brain-to-text-25/t15_pretrained_rnn_baseline/t15_pretrained_rnn_baseline/checkpoint/args.yaml
/kaggle/input/brain-to-text-25/t15_pretrained_rnn_baseline/t15_pretrained_rnn_baseline/checkpoint/best_checkpoint
/kaggle/input/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final/t15.2025.03.14/data_test.hdf5
/kaggle/input/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final/t15.2025.03.14/data_train.hdf5
/kaggle/input/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final/t15.2025.03.14/data_val.hdf5
/kaggle/input/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final/t15.2023.08.11/data_train.hdf5
/kaggle/input/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final/t15.2023.11.19/data_test.hdf5
/kaggle/input/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final/t15.2023.11.19/data_train.hdf5
/kaggle/input/brain-to-te

In [6]:
# ============================================================
# CELL 1: SETUP & ENVIRONMENT CHECK
# ============================================================

import os
import sys
import h5py
import numpy as np
import pandas as pd
from collections import Counter, defaultdict
import math
import pickle
import itertools
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from scipy.ndimage import gaussian_filter1d
from scipy.signal import butter, filtfilt, iirnotch

# Check environment
print(" ENVIRONMENT CHECK")
print("="*60)
print(f"\nPython: {sys.version}")
print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
print("="*60)

# Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f" Using device: {device}")

BASE_DIR = '/kaggle/input/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final'
PHONEMES = [
    'BLANK','AA','AE','AH','AO','AW','AY','B','CH','D','DH',
    'EH','ER','EY','F','G','HH','IH','IY','JH','K','L','M',
    'N','NG','OW','OY','P','R','S','SH','T','TH','UH','UW',
    'V','W','Y','Z','ZH','|'
]
NUM_CLASSES = len(PHONEMES)
print(f"Phoneme classes: {NUM_CLASSES}")

 ENVIRONMENT CHECK

Python: 3.11.13 (main, Jun  4 2025, 08:57:29) [GCC 11.4.0]
PyTorch: 2.6.0+cu124
CUDA: True
GPU: Tesla P100-PCIE-16GB
 Using device: cuda
Phoneme classes: 41


In [7]:
# ============================================================
# CELL 2: DATA EXPLORATION
# ============================================================
import pickle
print(" DATA EXPLORATION")
print("="*60)

sessions = sorted(os.listdir(BASE_DIR))
print(f"\n Sessions: {len(sessions)}")
print(f"First: {sessions[0]}, Last: {sessions[-1]}")
splits = ['train', 'val', 'test']
data_counts = {s: 0 for s in splits}

for session in sessions:
    for split in splits:
        path = os.path.join(BASE_DIR, session, f'data_{split}.hdf5')
        if os.path.exists(path):
            with h5py.File(path, 'r') as f:
                data_counts[split] += len(f.keys())

print(f"\n Data split:")
for split, count in data_counts.items():
    print(f"{split.upper():5s}: {count:5d} samples")

 DATA EXPLORATION

 Sessions: 45
First: t15.2023.08.11, Last: t15.2025.04.13

 Data split:
TRAIN:  8072 samples
VAL  :  1426 samples
TEST :  1450 samples


In [8]:
# ============================================================
# CELL 3: NEURAL PREPROCESSOR
# ============================================================

class NeuralPreprocessor:
    """Preprocess neural signals"""
    
    def __init__(self, sampling_rate=50, smooth_sigma=1.0, clip_std=5.0):
        self.sampling_rate = sampling_rate
        self.smooth_sigma = smooth_sigma
        self.clip_std = clip_std
        self.global_mean = None
        self.global_std = None
    
    def compute_normalization_stats(self, data_list, n_steps_list, sample_size=2000):
        print(f"\n Computing stats from {min(sample_size, len(data_list))} samples...")
        
        all_data = [data_list[i][:n_steps_list[i]] 
                    for i in range(min(sample_size, len(data_list)))]
        concat = np.concatenate(all_data, axis=0)
        self.global_mean = concat.mean(axis=0)
        self.global_std = concat.std(axis=0) + 1e-8
        print(f" Stats computed! Shape: {self.global_mean.shape}")
    
    def preprocess(self, neural_data, n_timesteps):
        data = neural_data[:n_timesteps].copy()
        
        local_mean = data.mean(axis=0)
        local_std = data.std(axis=0) + 1e-8
        lower = local_mean - self.clip_std * local_std
        upper = local_mean + self.clip_std * local_std
        data = np.clip(data, lower, upper)
        
        if self.smooth_sigma > 0:
            for i in range(data.shape[1]):
                data[:, i] = gaussian_filter1d(data[:, i], sigma=self.smooth_sigma)
        
        if self.global_mean is not None:
            data = (data - self.global_mean) / self.global_std
        
        return data
    
    def save(self, path):
        with open(path, 'wb') as f:
            pickle.dump({
                'global_mean': self.global_mean,
                'global_std': self.global_std
            }, f)
        print(f"Saved to {path}")
    
    def load(self, path):
        with open(path, 'rb') as f:
            data = pickle.load(f)
        self.global_mean = data['global_mean']
        self.global_std = data['global_std']
        print(f"Loaded from {path}")

print(" NeuralPreprocessor class defined")

 NeuralPreprocessor class defined


In [9]:
# ============================================================
# CELL 4: LOAD DATA & COMPUTE PREPROCESSING STATS
# ============================================================

def load_h5py_file(file_path):
    data = {
        'neural_features': [], 'n_time_steps': [],
        'seq_class_ids': [], 'seq_len': [],
        'transcriptions': [], 'sentence_label': [],
        'session': [], 'block_num': [], 'trial_num': []
    }
    
    with h5py.File(file_path, 'r') as f:
        for key in f.keys():
            g = f[key]
            data['neural_features'].append(g['input_features'][:])
            data['n_time_steps'].append(g.attrs['n_time_steps'])
            data['seq_class_ids'].append(g['seq_class_ids'][:] if 'seq_class_ids' in g else None)
            data['seq_len'].append(g.attrs['seq_len'] if 'seq_len' in g.attrs else None)
            data['transcriptions'].append(g['transcription'][:] if 'transcription' in g else None)
            data['sentence_label'].append(g.attrs['sentence_label'][:] if 'sentence_label' in g.attrs else None)
            data['session'].append(g.attrs['session'])
            data['block_num'].append(g.attrs['block_num'])
            data['trial_num'].append(g.attrs['trial_num'])
    
    return data

def load_sessions(split):
    data_all = {k: [] for k in [
        'neural_features','n_time_steps','seq_class_ids','seq_len',
        'transcriptions','sentence_label','session','block_num','trial_num'
    ]}
    
    for session in tqdm(sessions, desc=f"Loading {split}"):
        path = os.path.join(BASE_DIR, session, f'data_{split}.hdf5')
        if not os.path.exists(path):
            continue
        
        d = load_h5py_file(path)
        for k in data_all.keys():
            data_all[k].extend(d[k])
    
    return data_all
# Load data
train_data = load_sessions('train')
val_data = load_sessions('val')
test_data = load_sessions('test')

print(f"\n Loaded:")
print(f"Train: {len(train_data['neural_features'])}")
print(f"Val: {len(val_data['neural_features'])}")
print(f"Test: {len(test_data['neural_features'])}")

# Create & fit preprocessor
preproc = NeuralPreprocessor()
preproc.compute_normalization_stats(
    train_data['neural_features'],
    train_data['n_time_steps'],
    sample_size=2000
)

# Save
preproc.save('preprocessor.pkl')

Loading train: 100%|██████████| 45/45 [02:44<00:00,  3.65s/it]
Loading val: 100%|██████████| 45/45 [00:33<00:00,  1.34it/s]
Loading test: 100%|██████████| 45/45 [00:37<00:00,  1.20it/s]



 Loaded:
Train: 8072
Val: 1426
Test: 1450

 Computing stats from 2000 samples...
 Stats computed! Shape: (512,)
Saved to preprocessor.pkl


In [10]:
# ============================================================
# CELL 5: PHONEME TO TEXT DECODER
# ============================================================

# Build CMU phoneme dictionary
import nltk
import re
nltk.download('cmudict', quiet=True)
from nltk.corpus import cmudict
cmu = cmudict.dict()
phon_to_words = defaultdict(list)
for word, pronunciations in cmu.items():
    word_clean = word.lower().strip()
    if not word_clean.replace("'", "").isalpha():
        continue
    for pron in pronunciations:
        normalized = tuple([p.rstrip('012') for p in pron])
        phon_to_words[normalized].append(word_clean)

print(f"Phoneme dictionary: {len(phon_to_words)} entries")

# Build unigram + bigram LM
unigram_counts = Counter()
bigram_counts = Counter()

for sent_label in train_data['sentence_label']:
    if sent_label is None:
        continue
    if isinstance(sent_label, bytes):
        sent_label = sent_label.decode('utf-8')
    words = re.findall(r"[a-z']+", sent_label.lower())
    if not words:
        continue
    unigram_counts.update(words)
    for i in range(len(words) - 1):
        bigram = (words[i], words[i+1])
        bigram_counts[bigram] += 1

total_words = sum(unigram_counts.values())
print(f" Language model: {len(unigram_counts)} words, {len(bigram_counts)} bigrams")

# Decoder function
def phonemes_to_text(phoneme_list):
    """
    Convert phoneme sequence to text
    """
    # Split by silence token '|'
    segments = []
    current = []
    for p in phoneme_list:
        if p == '|':
            if current:
                segments.append(current)
                current = []
        else:
            current.append(p)
    if current:
        segments.append(current)
    if not segments:
        return ''
    
    # Decode segments
    words = []
    prev_word = None
    for seg in segments:
        phon_tuple = tuple(seg)
        candidates = phon_to_words.get(phon_tuple, [])
        if not candidates:
            for trim in range(1, min(3, len(seg))):
                phon_tuple_trim = tuple(seg[:-trim])
                candidates = phon_to_words.get(phon_tuple_trim, [])
                if candidates:
                    break
        if not candidates:
            continue
        if prev_word is None:
            best = max(candidates[:10], key=lambda w: unigram_counts.get(w, 0))
        else:
            best = None
            best_score = -1
            
            for word in candidates[:10]:
                bigram_count = bigram_counts.get((prev_word, word), 0)
                unigram_count = unigram_counts.get(word, 0)
                score = bigram_count * 10 + unigram_count  # Prefer bigram
                
                if score > best_score:
                    best_score = score
                    best = word
            
            if best is None:
                best = candidates[0]
        words.append(best)
        prev_word = best
    return ' '.join(words) if words else ''

print("Decoder ready!")

Phoneme dictionary: 113375 entries
 Language model: 3745 words, 18263 bigrams
Decoder ready!


In [11]:
# ============================================================
# CELL 6: BASELINE MODEL
# ============================================================
import re
class BaselineGRU(nn.Module):
    def __init__(self, input_dim=512, window_size=14, hidden_dim=768, num_layers=5, num_classes=41, num_days=45):
        super().__init__()
        
        self.window_size = window_size
        stacked_dim = input_dim * window_size 
        self.day_weights = nn.ParameterList([
            nn.Parameter(torch.randn(input_dim, input_dim))
            for _ in range(num_days)
        ])
        self.day_biases = nn.ParameterList([
            nn.Parameter(torch.randn(1, input_dim))
            for _ in range(num_days)
        ])
        
        self.h0 = nn.Parameter(torch.randn(1, 1, hidden_dim))
        self.gru = nn.GRU(
            input_size=stacked_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True
        )
        self.out = nn.Linear(hidden_dim, num_classes)
    
    def create_windows(self, x):
        B, T, C = x.shape
        if T < self.window_size:
            pad = self.window_size - T
            x = F.pad(x, (0, 0, 0, pad))
            T = self.window_size
        x_unfold = x.unfold(1, self.window_size, 1)  # (B, T', C, window_size)
        x_unfold = x_unfold.permute(0, 1, 3, 2)  # (B, T', window_size, C)
        x_windows = x_unfold.reshape(B, -1, C * self.window_size)  # (B, T', 7168)
        
        return x_windows
    
    def forward(self, x, day_idx=0):
        
        day_idx = min(day_idx, len(self.day_weights) - 1)
        day_weight = self.day_weights[day_idx]
        day_bias = self.day_biases[day_idx]
        
        x = torch.matmul(x, day_weight) + day_bias
        x = self.create_windows(x)
        B = x.size(0)
        h = self.h0.expand(self.gru.num_layers, B, -1).contiguous()
        out, _ = self.gru(x, h)
        logits = self.out(out)
        return F.log_softmax(logits, dim=-1)

print("Creating baseline model (correct)...")
baseline_model = BaselineGRU().to(device)

print("Loading weights...")

# Load checkpoint
BASELINE_DIR = '/kaggle/input/brain-to-text-25/t15_pretrained_rnn_baseline/t15_pretrained_rnn_baseline'
checkpoint = torch.load(
    os.path.join(BASELINE_DIR, 'checkpoint', 'best_checkpoint'),
    map_location=device,
    weights_only=False)

state_dict = checkpoint['model_state_dict']
new_state_dict = {k.replace('_orig_mod.', ''): v for k, v in state_dict.items()}
baseline_model.load_state_dict(new_state_dict, strict=True)
baseline_model.eval()

print(" Baseline loaded!")
print(f"Params: {sum(p.numel() for p in baseline_model.parameters()):,}")

# Test
print("\n Testing...")
x_test = preproc.preprocess(val_data['neural_features'][0], val_data['n_time_steps'][0])
x_test = torch.tensor(x_test).unsqueeze(0).float().to(device)

with torch.no_grad():
    out_test = baseline_model(x_test, day_idx=0)
    preds = out_test.squeeze(0).argmax(-1).cpu().numpy()
    phonemes = [PHONEMES[i] for i, _ in itertools.groupby(preds) if i != 0]

print(f"Phonemes ({len(phonemes)}): {phonemes[:20]}")

from collections import Counter
silence_pct = Counter(phonemes).get('|', 0) / max(1, len(phonemes)) * 100
print(f"Silence: {silence_pct:.1f}%")
print(" Ready!")

Creating baseline model (correct)...
Loading weights...
 Baseline loaded!
Params: 44,315,177

 Testing...
Phonemes (56): ['Y', 'UW', '|', 'K', 'AE', 'D', 'AE', 'D', '|', 'DH', 'IY', 'S', '|', 'HH', 'IY', '|', 'DH', 'EY', '|', 'K']
Silence: 26.8%
 Ready!


In [12]:
# ============================================================
# CELL 7: Session Mapping
# ============================================================
sessions_sorted = sorted(set(test_data['session']))
session_to_day = {sess: i for i, sess in enumerate(sessions_sorted)}
print(f"   {len(session_to_day)} unique sessions")

# Inference function
def decode_trial_baseline(neural_data, n_timesteps, session):
    day_idx = session_to_day.get(session, 0)
    x = preproc.preprocess(neural_data, n_timesteps)
    x = torch.tensor(x, dtype=torch.float32).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = baseline_model(x, day_idx=day_idx).squeeze(0).cpu().numpy()
    preds = logits.argmax(-1)
    phonemes = [PHONEMES[i] for i, _ in itertools.groupby(preds) if i != 0]
    text = phonemes_to_text(phonemes)
    
    return text

   41 unique sessions


In [13]:
# ============================================================
# CELL 7.1: DATASET & DATALOADER
# ============================================================

class BrainDataset(Dataset):
    def __init__(self, data, preprocessor):
        self.data = data
        self.preproc = preprocessor
    def __len__(self):
        return len(self.data['neural_features'])
    def __getitem__(self, idx):
        x = self.preproc.preprocess(
            self.data['neural_features'][idx],
            self.data['n_time_steps'][idx]
        )
        
        # Get phoneme labels
        y = self.data['seq_class_ids'][idx]
        if y is None:
            y = []
        else:
            y = y[:self.data['seq_len'][idx]]
        
        return (
            torch.tensor(x, dtype=torch.float32),
            torch.tensor(y, dtype=torch.long)
        )

def collate_fn(batch):
    x_list, y_list = zip(*batch)
    x_lens = [len(x) for x in x_list]
    y_lens = [len(y) for y in y_list]
    max_x = max(x_lens)
    max_y = max(y_lens) if y_lens else 1
    x_pad = torch.zeros(len(batch), max_x, x_list[0].shape[1])
    y_pad = torch.zeros(len(batch), max_y, dtype=torch.long)
    
    for i, (x, y) in enumerate(batch):
        x_pad[i, :x_lens[i], :] = x
        if y_lens[i] > 0:
            y_pad[i, :y_lens[i]] = y
    
    return (x_pad, y_pad, torch.tensor(x_lens, dtype=torch.long), torch.tensor(y_lens, dtype=torch.long))

train_dataset = BrainDataset(train_data, preproc)
val_dataset = BrainDataset(val_data, preproc)

BATCH_SIZE = 6

train_loader = DataLoader(
    train_dataset, 
    batch_size=BATCH_SIZE, 
    shuffle=True, 
    collate_fn=collate_fn
)

val_loader = DataLoader(
    val_dataset, 
    batch_size=BATCH_SIZE, 
    shuffle=False, 
    collate_fn=collate_fn
)

print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")

Train batches: 1346
Val batches: 238


In [14]:
# ============================================================
# CELL 8: FINE-TUNE BASELINE GRU MODEL
# ============================================================
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from tqdm import tqdm
import numpy as np

# Check model structure
print(f"Model loaded: {sum(p.numel() for p in baseline_model.parameters()):,} params")
baseline_model.train()

# Fine-tuning hyperparameters
FINETUNE_EPOCHS = 15
LEARNING_RATE = 1e-5  
WARMUP_EPOCHS = 0  
GRADIENT_CLIP = 1.0

print(f"Epochs: {FINETUNE_EPOCHS}")
print(f"Learning rate: {LEARNING_RATE}")
print(f"Gradient clip: {GRADIENT_CLIP}")

FREEZE_LAYERS = True
FREEZE_DAY_WEIGHTS = False 

if FREEZE_LAYERS:
    baseline_model.h0.requires_grad = False
    
    # Optionally freeze day weights
    if FREEZE_DAY_WEIGHTS:
        for day_weight in baseline_model.day_weights:
            day_weight.requires_grad = False
        for day_bias in baseline_model.day_biases:
            day_bias.requires_grad = False
        print("Frozen: h0, day_weights, day_biases")
    else:
        print("Frozen: h0 only")
        trainable_params = sum(p.numel() for p in baseline_model.parameters() if p.requires_grad)
    print(f"   Trainable params: {trainable_params:,}")

optimizer = Adam(
    filter(lambda p: p.requires_grad, baseline_model.parameters()),
    lr=LEARNING_RATE,
    weight_decay=1e-5
)

scheduler = CosineAnnealingLR(
    optimizer,
    T_max=FINETUNE_EPOCHS * len(train_loader),
    eta_min=LEARNING_RATE / 10
)

print(f"Optimizer: Adam (lr={LEARNING_RATE})")
print(f"Scheduler: CosineAnnealingLR")

# Loss function
criterion = nn.CTCLoss(blank=0, zero_infinity=True)

best_val_loss = float('inf')
patience = 2
patience_counter = 0

# Map sessions to days for train/val
train_sessions = []
for i in range(len(train_data['session'])):
    train_sessions.append(train_data['session'][i])

val_sessions = []
for i in range(len(val_data['session'])):
    val_sessions.append(val_data['session'][i])

for epoch in range(FINETUNE_EPOCHS):
    print(f"\n{'='*60}")
    print(f"EPOCH {epoch+1}/{FINETUNE_EPOCHS}")
    print(f"{'='*60}")
    
    # ========================================
    # TRAINING
    # ========================================
    baseline_model.train()
    train_loss = 0
    train_steps = 0
    
    progress_bar = tqdm(train_loader, desc=f"Training")
    
    for batch_idx, batch in enumerate(progress_bar):
        x, y, x_lens, y_lens = batch
        x = x.to(device)
        y = y.to(device)
        batch_size = x.size(0)
        start_idx = batch_idx * train_loader.batch_size
        batch_sessions = train_sessions[start_idx:start_idx + batch_size]
        day_idx = session_to_day.get(batch_sessions[0], 0)
        optimizer.zero_grad()
        log_probs = baseline_model(x, day_idx=day_idx) 
        log_probs = log_probs.permute(1, 0, 2)  
        input_lengths = torch.tensor([log_probs.size(0)] * batch_size, dtype=torch.long)
        loss = criterion(log_probs, y, input_lengths, y_lens)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(baseline_model.parameters(), GRADIENT_CLIP)
        optimizer.step()
        scheduler.step()
        train_loss += loss.item()
        train_steps += 1
        progress_bar.set_postfix({
            'loss': f'{loss.item():.4f}',
            'avg_loss': f'{train_loss/train_steps:.4f}'
        })
    
    avg_train_loss = train_loss / train_steps
    
    # ========================================
    # VALIDATION
    # ========================================
    baseline_model.eval()
    val_loss = 0
    val_steps = 0
    
    with torch.no_grad():
        for batch_idx, batch in enumerate(tqdm(val_loader, desc="Validation")):
            x, y, x_lens, y_lens = batch
            x = x.to(device)
            y = y.to(device)
            batch_size = x.size(0)
            start_idx = batch_idx * val_loader.batch_size
            batch_sessions = val_sessions[start_idx:start_idx + batch_size]
            day_idx = session_to_day.get(batch_sessions[0], 0)
            log_probs = baseline_model(x, day_idx=day_idx)
            log_probs = log_probs.permute(1, 0, 2)
            input_lengths = torch.tensor([log_probs.size(0)] * batch_size, dtype=torch.long)
            
            loss = criterion(log_probs, y, input_lengths, y_lens)
            val_loss += loss.item()
            val_steps += 1
    
    avg_val_loss = val_loss / val_steps
    
    # ========================================
    # LOGGING
    # ========================================
    print(f"\n Epoch {epoch+1} Results:")
    print(f"Train Loss: {avg_train_loss:.4f}")
    print(f"Val Loss:   {avg_val_loss:.4f}")
    
    # ========================================
    # EARLY STOPPING & CHECKPOINTING
    # ========================================
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        patience_counter = 0
        torch.save({
            'epoch': epoch,
            'model_state_dict': baseline_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': avg_val_loss,
        }, 'finetuned_baseline.pt')
        
        print(f" New best model saved! (val_loss: {avg_val_loss:.4f})")
    else:
        patience_counter += 1
        print(f" No improvement ({patience_counter}/{patience})")
        
        if patience_counter >= patience:
            print(f"\nEarly stopping triggered!")
            break

    torch.cuda.empty_cache()
print("FINE-TUNING COMPLETE!")
print(f"Best validation loss: {best_val_loss:.4f}")
print(f"Model saved: finetuned_baseline.pt")

checkpoint = torch.load('finetuned_baseline.pt')
baseline_model.load_state_dict(checkpoint['model_state_dict'])
baseline_model.eval()

print("\n Testing on 10 samples...")

for i in range(10):
    session = test_data['session'][i]
    day_idx = session_to_day.get(session, 0)
    
    x = preproc.preprocess(
        test_data['neural_features'][i],
        test_data['n_time_steps'][i]
    )
    x = torch.tensor(x).unsqueeze(0).float().to(device)
    
    with torch.no_grad():
        logits = baseline_model(x, day_idx=day_idx).squeeze(0).cpu().numpy()
    
    preds = logits.argmax(-1)
    phonemes = [PHONEMES[idx] for idx, _ in itertools.groupby(preds) if idx != 0]
    text = phonemes_to_text(phonemes)
    
    print(f"\n[{i}] {text}")

Model loaded: 44,315,177 params
Epochs: 15
Learning rate: 1e-05
Gradient clip: 1.0
Frozen: h0 only
   Trainable params: 44,314,409
Optimizer: Adam (lr=1e-05)
Scheduler: CosineAnnealingLR

EPOCH 1/15


Training: 100%|██████████| 1346/1346 [18:35<00:00,  1.21it/s, loss=4.7822, avg_loss=3.6458]
Validation: 100%|██████████| 238/238 [01:30<00:00,  2.62it/s]



 Epoch 1 Results:
Train Loss: 3.6458
Val Loss:   1.7863
 New best model saved! (val_loss: 1.7863)

EPOCH 2/15


Training: 100%|██████████| 1346/1346 [18:28<00:00,  1.21it/s, loss=2.9506, avg_loss=2.0568]
Validation: 100%|██████████| 238/238 [01:30<00:00,  2.63it/s]



 Epoch 2 Results:
Train Loss: 2.0568
Val Loss:   1.3688
 New best model saved! (val_loss: 1.3688)

EPOCH 3/15


Training: 100%|██████████| 1346/1346 [18:34<00:00,  1.21it/s, loss=2.9000, avg_loss=1.6420]
Validation: 100%|██████████| 238/238 [01:31<00:00,  2.61it/s]



 Epoch 3 Results:
Train Loss: 1.6420
Val Loss:   1.1834
 New best model saved! (val_loss: 1.1834)

EPOCH 4/15


Training: 100%|██████████| 1346/1346 [18:33<00:00,  1.21it/s, loss=0.3336, avg_loss=1.4463]
Validation: 100%|██████████| 238/238 [01:30<00:00,  2.62it/s]



 Epoch 4 Results:
Train Loss: 1.4463
Val Loss:   1.0928
 New best model saved! (val_loss: 1.0928)

EPOCH 5/15


Training: 100%|██████████| 1346/1346 [18:32<00:00,  1.21it/s, loss=1.8453, avg_loss=1.3473]
Validation: 100%|██████████| 238/238 [01:30<00:00,  2.62it/s]



 Epoch 5 Results:
Train Loss: 1.3473
Val Loss:   1.0582
 New best model saved! (val_loss: 1.0582)

EPOCH 6/15


Training: 100%|██████████| 1346/1346 [18:38<00:00,  1.20it/s, loss=1.2732, avg_loss=1.2717]
Validation: 100%|██████████| 238/238 [01:30<00:00,  2.62it/s]



 Epoch 6 Results:
Train Loss: 1.2717
Val Loss:   1.0058
 New best model saved! (val_loss: 1.0058)

EPOCH 7/15


Training: 100%|██████████| 1346/1346 [18:36<00:00,  1.21it/s, loss=1.0608, avg_loss=1.2167]
Validation: 100%|██████████| 238/238 [01:31<00:00,  2.61it/s]



 Epoch 7 Results:
Train Loss: 1.2167
Val Loss:   0.9813
 New best model saved! (val_loss: 0.9813)

EPOCH 8/15


Training: 100%|██████████| 1346/1346 [18:35<00:00,  1.21it/s, loss=1.0524, avg_loss=1.1751]
Validation: 100%|██████████| 238/238 [01:29<00:00,  2.65it/s]



 Epoch 8 Results:
Train Loss: 1.1751
Val Loss:   0.9565
 New best model saved! (val_loss: 0.9565)

EPOCH 9/15


Training: 100%|██████████| 1346/1346 [18:32<00:00,  1.21it/s, loss=1.3438, avg_loss=1.1394]
Validation: 100%|██████████| 238/238 [01:30<00:00,  2.62it/s]



 Epoch 9 Results:
Train Loss: 1.1394
Val Loss:   0.9498
 New best model saved! (val_loss: 0.9498)

EPOCH 10/15


Training: 100%|██████████| 1346/1346 [18:32<00:00,  1.21it/s, loss=2.1256, avg_loss=1.1208]
Validation: 100%|██████████| 238/238 [01:30<00:00,  2.63it/s]



 Epoch 10 Results:
Train Loss: 1.1208
Val Loss:   0.9330
 New best model saved! (val_loss: 0.9330)

EPOCH 11/15


Training: 100%|██████████| 1346/1346 [18:28<00:00,  1.21it/s, loss=2.4278, avg_loss=1.1141]
Validation: 100%|██████████| 238/238 [01:30<00:00,  2.64it/s]



 Epoch 11 Results:
Train Loss: 1.1141
Val Loss:   0.9291
 New best model saved! (val_loss: 0.9291)

EPOCH 12/15


Training: 100%|██████████| 1346/1346 [18:33<00:00,  1.21it/s, loss=1.8556, avg_loss=1.0950]
Validation: 100%|██████████| 238/238 [01:30<00:00,  2.63it/s]



 Epoch 12 Results:
Train Loss: 1.0950
Val Loss:   0.9120
 New best model saved! (val_loss: 0.9120)

EPOCH 13/15


Training: 100%|██████████| 1346/1346 [18:30<00:00,  1.21it/s, loss=1.7643, avg_loss=1.0696]
Validation: 100%|██████████| 238/238 [01:30<00:00,  2.63it/s]



 Epoch 13 Results:
Train Loss: 1.0696
Val Loss:   0.9108
 New best model saved! (val_loss: 0.9108)

EPOCH 14/15


Training: 100%|██████████| 1346/1346 [18:31<00:00,  1.21it/s, loss=2.8017, avg_loss=1.0728]
Validation: 100%|██████████| 238/238 [01:30<00:00,  2.62it/s]



 Epoch 14 Results:
Train Loss: 1.0728
Val Loss:   0.9114
 No improvement (1/2)

EPOCH 15/15


Training: 100%|██████████| 1346/1346 [18:33<00:00,  1.21it/s, loss=1.7926, avg_loss=1.0813]
Validation:  63%|██████▎   | 151/238 [00:56<00:32,  2.69it/s]


KeyboardInterrupt: 

In [31]:
# ===============================================================
# LAST CELL: INFERENCE WITH FINE-TUNED BASELINE + CREATE SUBMISSION
# ===============================================================

import torch
import itertools
import pandas as pd
from tqdm import tqdm

# 讀取最佳 fine-tuned model
checkpoint = torch.load('finetuned_baseline.pt', map_location=device)
baseline_model.load_state_dict(checkpoint['model_state_dict'])
baseline_model.eval()

print(f"Loaded fine-tuned model (val_loss: {checkpoint.get('val_loss', 'N/A')})")
print(f"Running inference on {len(test_data['neural_features'])} test examples...\n")

predictions_finetuned = []

# 推論整個 test set
for i in tqdm(range(len(test_data['neural_features']))):
    session = test_data['session'][i]
    day_idx = session_to_day.get(session, 0)

    x = preproc.preprocess(
        test_data['neural_features'][i],
        test_data['n_time_steps'][i]
    )
    x = torch.tensor(x).unsqueeze(0).float().to(device)

    with torch.no_grad():
        logits = baseline_model(x, day_idx=day_idx).squeeze(0)
        preds = logits.argmax(-1)

    phonemes = [PHONEMES[idx] for idx, _ in itertools.groupby(preds)]
    text = phonemes_to_text(phonemes)

    # 若空字串就放一個空白，避免出錯
    predictions_finetuned.append(text if text else ' ')

print("\nInference finished.")
print("Number of predictions:", len(predictions_finetuned))
print("First 3 predictions:", predictions_finetuned[:3])

# 建立 submission DataFrame
submission = pd.DataFrame({
    'id': range(len(predictions_finetuned)),
    'text': predictions_finetuned
})

print("\nSubmission shape:", submission.shape)  # 預期應該是 (1450, 2)
print(submission.head())

# 輸出成 CSV
submission.to_csv('submission.csv', index=False)
print("\nSaved: submission.csv")

Loaded fine-tuned model (val_loss: 0.9108351614294934)
Running inference on 1450 test examples...



100%|██████████| 1450/1450 [04:27<00:00,  5.43it/s]


Inference finished.
Number of predictions: 1450
First 3 predictions: [' ', ' ', ' ']

Submission shape: (1450, 2)
   id text
0   0     
1   1     
2   2     
3   3     
4   4     

Saved: submission.csv





In [32]:
from IPython.display import FileLink
FileLink('submission.csv')

In [29]:
import pandas as pd

print("len(predictions_finetuned) =", len(predictions_finetuned))

submission = pd.DataFrame({
    'id': range(len(predictions_finetuned)),   # 若是 1450 就會是 0~1449
    'text': predictions_finetuned
})

print(submission.shape)   # 應該看到 (1450, 2)
print(submission.head())

submission.to_csv('submission.csv', index=False)
print("Saved: submission.csv")

len(predictions_finetuned) = 3
(3, 2)
   id                  text
0   0  example prediction 1
1   1  example prediction 2
2   2              Ellipsis
Saved: submission.csv


In [None]:
# 另外存一個版本，不會動到原本的 submission.csv
submission_v2 = pd.DataFrame({
    'id': range(len(predictions_finetuned)),
    'text': predictions_finetuned
})

submission_v2.to_csv('submission_v2.csv', index=False)
print("Saved: submission_v2.csv")

In [27]:
# 這一格最下面，加這幾行
print("\nInference finished.")
print("Number of predictions:", len(predictions_finetuned))
print("First 3 predictions:", predictions_finetuned[:3])


Inference finished.
Number of predictions: 3
First 3 predictions: ['example prediction 1', 'example prediction 2', Ellipsis]


In [18]:
submission_v2 = pd.DataFrame({
    'id': range(len(predictions_finetuned)),
    'text': predictions_finetuned
})
submission_v2.to_csv('submission_v2.csv', index=False)
print("Saved: submission_v2.csv")


Saved: submission_v2.csv


In [19]:
print(type(predictions_finetuned))
print(len(predictions_finetuned))  # Should equal number of test samples

<class 'list'>
3


In [20]:
import pandas as pd
submission = pd.DataFrame({
    'id': range(len(predictions_finetuned)),
    'text': predictions_finetuned
})
submission.to_csv('submission.csv', index=False)
print("Saved: submission.csv")

submission_v2 = pd.DataFrame({
    'id': range(len(predictions_finetuned)),
    'text': predictions_finetuned
})
submission_v2.to_csv('submission_v2.csv', index=False)
print("Saved: submission_v2.csv")

Saved: submission.csv
Saved: submission_v2.csv


In [21]:
import pandas as pd
print(pd.read_csv('submission.csv').head())
print(pd.read_csv('submission.csv').shape)


   id                  text
0   0  example prediction 1
1   1  example prediction 2
2   2              Ellipsis
(3, 2)


In [22]:
from IPython.display import FileLink

FileLink('submission.csv')

In [25]:
import pandas as pd

# 這裡直接做出 1450 筆資料，id 從 0 到 1449
n_rows = 1450

submission = pd.DataFrame({
    'id': range(n_rows),
    'text': ['dummy prediction'] * n_rows   # 先全部放一樣的文字
})

print(submission.shape)   # 應該會印出 (1450, 2)
print(submission.head())

submission.to_csv('submission.csv', index=False)
print("Saved: submission.csv")

(1450, 2)
   id              text
0   0  dummy prediction
1   1  dummy prediction
2   2  dummy prediction
3   3  dummy prediction
4   4  dummy prediction
Saved: submission.csv


In [30]:
from IPython.display import FileLink
FileLink('submission.csv')

In [33]:
import pandas as pd

submission = pd.read_csv('submission.csv')

print(submission.head())
print(submission.tail())

# 看每一列文字長度的統計
print(submission['text'].str.len().describe())

   id text
0   0     
1   1     
2   2     
3   3     
4   4     
        id text
1445  1445     
1446  1446     
1447  1447     
1448  1448     
1449  1449     
count    1450.0
mean        1.0
std         0.0
min         1.0
25%         1.0
50%         1.0
75%         1.0
max         1.0
Name: text, dtype: float64


In [34]:
# ===============================================================
# NEW LAST CELL: INFERENCE WITH FINE-TUNED BASELINE + SUBMISSION
# ===============================================================

import torch
import itertools
import pandas as pd
from tqdm import tqdm

# 讀取最佳 fine-tuned model
checkpoint = torch.load('finetuned_baseline.pt', map_location=device)
baseline_model.load_state_dict(checkpoint['model_state_dict'])
baseline_model.eval()

print(f"Loaded fine-tuned model (val_loss: {checkpoint.get('val_loss', 'N/A')})")
print(f"Running inference on {len(test_data['neural_features'])} test examples...\n")

predictions_finetuned = []

# ---------- 這裡就是「for 迴圈、跑 1450 筆推論」 ----------
for i in tqdm(range(len(test_data['neural_features']))):
    session = test_data['session'][i]
    day_idx = session_to_day.get(session, 0)

    x = preproc.preprocess(
        test_data['neural_features'][i],
        test_data['n_time_steps'][i]
    )
    x = torch.tensor(x).unsqueeze(0).float().to(device)

    with torch.no_grad():
        logits = baseline_model(x, day_idx=day_idx).squeeze(0)
        preds = logits.argmax(-1)

    # ✅ 關鍵：過濾掉 CTC 的 BLANK（index = 0）
    phonemes = [PHONEMES[idx] for idx, _ in itertools.groupby(preds) if idx != 0]

    text = phonemes_to_text(phonemes)

    # 若空字串就放一個空白，避免出錯
    predictions_finetuned.append(text if text else ' ')

print("\nInference finished.")
print("Number of predictions:", len(predictions_finetuned))
print("First 3 predictions:", predictions_finetuned[:3])

# 建立 submission DataFrame
submission = pd.DataFrame({
    'id': range(len(predictions_finetuned)),
    'text': predictions_finetuned
})

print("\nSubmission shape:", submission.shape)  # 預期 (1450, 2)
print(submission.head())

# 輸出成 CSV
submission.to_csv('submission.csv', index=False)
print("\nSaved: submission.csv")

Loaded fine-tuned model (val_loss: 0.9108351614294934)
Running inference on 1450 test examples...



100%|██████████| 1450/1450 [04:27<00:00,  5.42it/s]


Inference finished.
Number of predictions: 1450
First 3 predictions: ['i get tired with the song and days commit', 'here', 'you ought a mcgirr surprised']

Submission shape: (1450, 2)
   id                                       text
0   0  i get tired with the song and days commit
1   1                                       here
2   2               you ought a mcgirr surprised
3   3               i think mamie you like it it
4   4               hsiao that they do have prom

Saved: submission.csv





In [35]:
import pandas as pd

submission = pd.read_csv('submission.csv')

print(submission.head())
print(submission.tail())
print(submission['text'].str.len().describe())

   id                                       text
0   0  i get tired with the song and days commit
1   1                                       here
2   2               you ought a mcgirr surprised
3   3               i think mamie you like it it
4   4               hsiao that they do have prom
        id                              text
1445  1445  gees they daane have the ill fee
1446  1446           a lot of new sus it the
1447  1447                   an aue a have a
1448  1448            she wass at the aw one
1449  1449           shire beta that is fant
count    1450.000000
mean       24.041379
std         9.929479
min         1.000000
25%        17.000000
50%        24.000000
75%        31.000000
max        64.000000
Name: text, dtype: float64


In [38]:
from IPython.display import FileLink
FileLink('submission.csv')

In [37]:
# ============================================================
# ENSEMBLE INFERENCE:
#   original baseline + fine-tuned baseline (CTC)
#   產生 submission_ensemble.csv
# ============================================================

import os
import torch
import itertools
import pandas as pd
from tqdm import tqdm

# 1. 建兩個 model：原始 baseline + fine-tuned baseline

print("Loading ORIGINAL baseline model ...")
baseline_model_base = BaselineGRU().to(device)

BASELINE_DIR = "/kaggle/input/brain-to-text-25/t15_pretrained_rnn_baseline/t15_pretrained_rnn_baseline"
ckpt_base = torch.load(
    os.path.join(BASELINE_DIR, "checkpoint", "best_checkpoint"),
    map_location=device,
    weights_only=False,
)
state_dict_base = ckpt_base["model_state_dict"]
state_dict_base = {k.replace("_orig_mod.", ""): v for k, v in state_dict_base.items()}
baseline_model_base.load_state_dict(state_dict_base, strict=True)
baseline_model_base.eval()

print("Loading FINE-TUNED baseline model ...")
baseline_model_ft = BaselineGRU().to(device)
ckpt_ft = torch.load("finetuned_baseline.pt", map_location=device)
baseline_model_ft.load_state_dict(ckpt_ft["model_state_dict"])
baseline_model_ft.eval()

print("Both models ready.")
print(f"Running ENSEMBLE inference on {len(test_data['neural_features'])} test examples...\n")

# 2. 逐筆做 ensemble 推論
predictions_ens = []

for i in tqdm(range(len(test_data["neural_features"]))):
    session = test_data["session"][i]
    day_idx = session_to_day.get(session, 0)

    x = preproc.preprocess(
        test_data["neural_features"][i],
        test_data["n_time_steps"][i],
    )
    x = torch.tensor(x).unsqueeze(0).float().to(device)

    with torch.no_grad():
        # 兩個模型都算一次 logits（其實是 log-probs）
        logits_base = baseline_model_base(x, day_idx=day_idx)
        logits_ft = baseline_model_ft(x, day_idx=day_idx)

        # 做簡單平均 ensemble
        logits_ens = 0.5 * (logits_base + logits_ft)

        preds = logits_ens.squeeze(0).argmax(-1).cpu().numpy()

    # CTC: 合併重複 & 去掉 BLANK (index = 0)
    phonemes = [PHONEMES[idx] for idx, _ in itertools.groupby(preds) if idx != 0]

    text = phonemes_to_text(phonemes)
    predictions_ens.append(text if text else " ")

print("\nInference finished.")
print("Number of predictions:", len(predictions_ens))

# 3. 建立 submission DataFrame
submission_ens = pd.DataFrame({
    "id": range(len(predictions_ens)),
    "text": predictions_ens,
})

print("\nSubmission shape:", submission_ens.shape)
print(submission_ens.head())

# 4. 存檔
submission_ens.to_csv("submission_ensemble.csv", index=False)
print("\nSaved: submission_ensemble.csv")

Loading ORIGINAL baseline model ...
Loading FINE-TUNED baseline model ...
Both models ready.
Running ENSEMBLE inference on 1450 test examples...



100%|██████████| 1450/1450 [07:27<00:00,  3.24it/s]


Inference finished.
Number of predictions: 1450

Submission shape: (1450, 2)
   id   text
0   0    i'd
1   1       
2   2   earp
3   3  e you
4   4    luz

Saved: submission_ensemble.csv





In [41]:
import torch
import torch.nn as nn

# 1. 假設你的音素表像這樣
PHONEMES = ['BLANK', 'AA', 'AE', ..., 'OY', ...]   # 全部音素排序
num_phonemes = len(PHONEMES)

# 2. 指定稀有音素index（照你的表順序！）
rare_idx = [PHONEMES.index(p) for p in ['AW', 'UH', 'CH', 'JH', 'OY']]
weights = torch.ones(num_phonemes)
for idx in rare_idx:
    weights[idx] = 10.0  # 測試3~10之間，看score最優

criterion = nn.CTCLoss(blank=0, zero_infinity=True, weight=weights.to(device))
# 在訓練時: loss = criterion(logits, targets, input_lens, target_lens)


ValueError: 'AW' is not in list

In [42]:
augmented_data = []
for i, y in enumerate(train_labels):
    if any(target in rare_idx for target in y):  # y是一個trial對應的音素列表
        for _ in range(10):  # oversample十次
            augmented_data.append(train_data[i])
# 組合新訓練資料
train_data = train_data + augmented_data
train_labels = train_labels + [train_labels[i] for i, y in enumerate(train_labels)
                              for target in y if any(t in rare_idx for t in y) for _ in range(10)]


NameError: name 'train_labels' is not defined