# Brain-to-Text Model Evaluation Notebook

This notebook is designed for Kaggle competitions and allows you to:
1. Load pretrained RNN models
2. Evaluate on test/validation data
3. Generate predictions for submission

Based on the NEJM Brain-to-Text paper (Card et al., 2024)

## 1. Install Dependencies

In [None]:
# Install required packages
!pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install -q omegaconf pandas numpy h5py tqdm editdistance scipy g2p_en

## 2. Setup Project Structure

In [None]:
import os
import sys

# Create necessary directories
os.makedirs('model_training', exist_ok=True)
os.makedirs('nejm_b2txt_utils', exist_ok=True)
os.makedirs('data', exist_ok=True)

# Add to path
sys.path.append('/kaggle/working')
sys.path.append('/kaggle/working/model_training')
sys.path.append('/kaggle/working/nejm_b2txt_utils')

## 3. Copy Project Files

**Note:** On Kaggle, add the project files as a dataset or upload them manually.
The cell below assumes they are available under `/kaggle/input/`; adjust the paths to match your dataset name.


In [None]:
# Copy files from input dataset (adjust path as needed)
# Uncomment and modify if you have the files in a Kaggle dataset

# import shutil
# 
# # Example: Copy from input dataset
# # shutil.copytree('/kaggle/input/your-dataset-name/model_training', '/kaggle/working/model_training', dirs_exist_ok=True)
# # shutil.copytree('/kaggle/input/your-dataset-name/nejm_b2txt_utils', '/kaggle/working/nejm_b2txt_utils', dirs_exist_ok=True)

print("Files should be copied manually or via Kaggle dataset")


## 4. Import or Define Core Functions

The cell below first tries to import from the project files. If they are not available, it defines the essential functions inline.


In [None]:
# Try to import from files, otherwise define inline
try:
    from model_training.rnn_model import GRUDecoder
    from model_training.evaluate_model_helpers import (
        load_h5py_file, runSingleDecodingStep, 
        LOGIT_TO_PHONEME, rearrange_speech_logits_pt, remove_punctuation
    )
    from model_training.data_augmentations import gauss_smooth
    print("✓ Successfully imported from project files")
except ImportError as e:
    print(f"Files not found ({e}), defining functions inline...")
    
    # Import required libraries
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import numpy as np
    import h5py
    import re
    from scipy.ndimage import gaussian_filter1d
    
    # Define phoneme mapping
    LOGIT_TO_PHONEME = [
        'BLANK',
        'AA', 'AE', 'AH', 'AO', 'AW',
        'AY', 'B',  'CH', 'D', 'DH',
        'EH', 'ER', 'EY', 'F', 'G',
        'HH', 'IH', 'IY', 'JH', 'K',
        'L', 'M', 'N', 'NG', 'OW',
        'OY', 'P', 'R', 'S', 'SH',
        'T', 'TH', 'UH', 'UW', 'V',
        'W', 'Y', 'Z', 'ZH',
        ' | ',
    ]
    
    # Define GRUDecoder (simplified version - see rnn_model.py for full implementation)
    class GRUDecoder(nn.Module):
        def __init__(self, neural_dim, n_units, n_days, n_classes, 
                     rnn_dropout=0.0, input_dropout=0.0, n_layers=5, 
                     patch_size=0, patch_stride=0):
            super(GRUDecoder, self).__init__()
            self.neural_dim = neural_dim
            self.n_units = n_units
            self.n_classes = n_classes
            self.n_layers = n_layers
            self.n_days = n_days
            self.rnn_dropout = rnn_dropout
            self.input_dropout = input_dropout
            self.patch_size = patch_size
            self.patch_stride = patch_stride
            
            self.day_layer_activation = nn.Softsign()
            self.day_weights = nn.ParameterList(
                [nn.Parameter(torch.eye(self.neural_dim)) for _ in range(self.n_days)]
            )
            self.day_biases = nn.ParameterList(
                [nn.Parameter(torch.zeros(1, self.neural_dim)) for _ in range(self.n_days)]
            )
            self.day_layer_dropout = nn.Dropout(input_dropout)
            self.input_size = self.neural_dim
            
            if self.patch_size > 0:
                self.input_size *= self.patch_size
            
            self.gru = nn.GRU(
                input_size=self.input_size,
                hidden_size=self.n_units,
                num_layers=self.n_layers,
                dropout=self.rnn_dropout,
                batch_first=True,
                bidirectional=False,
            )
            
            for name, param in self.gru.named_parameters():
                if "weight_hh" in name:
                    nn.init.orthogonal_(param)
                if "weight_ih" in name:
                    nn.init.xavier_uniform_(param)
            
            self.out = nn.Linear(self.n_units, self.n_classes)
            nn.init.xavier_uniform_(self.out.weight)
            self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.n_units)))
        
        def forward(self, x, day_idx, states=None, return_state=False):
            day_weights = torch.stack([self.day_weights[i] for i in day_idx], dim=0)
            day_biases = torch.cat([self.day_biases[i] for i in day_idx], dim=0).unsqueeze(1)
            x = torch.einsum("btd,bdk->btk", x, day_weights) + day_biases
            x = self.day_layer_activation(x)
            if self.input_dropout > 0:
                x = self.day_layer_dropout(x)
            if self.patch_size > 0:
                x = x.unsqueeze(1)
                x = x.permute(0, 3, 1, 2)
                x_unfold = x.unfold(3, self.patch_size, self.patch_stride)
                x_unfold = x_unfold.squeeze(2)
                x_unfold = x_unfold.permute(0, 2, 3, 1)
                x = x_unfold.reshape(x.size(0), x_unfold.size(1), -1)
            if states is None:
                states = self.h0.expand(self.n_layers, x.shape[0], self.n_units).contiguous()
            output, hidden_states = self.gru(x, states)
            logits = self.out(output)
            if return_state:
                return logits, hidden_states
            return logits
    
    # Helper functions
    def gauss_smooth(inputs, device, smooth_kernel_std=2, smooth_kernel_size=100, padding='same'):
        inp = np.zeros(smooth_kernel_size, dtype=np.float32)
        inp[smooth_kernel_size // 2] = 1
        gaussKernel = gaussian_filter1d(inp, smooth_kernel_std)
        validIdx = np.argwhere(gaussKernel > 0.01)
        gaussKernel = gaussKernel[validIdx]
        gaussKernel = np.squeeze(gaussKernel / np.sum(gaussKernel))
        gaussKernel = torch.tensor(gaussKernel, dtype=torch.float32, device=device)
        gaussKernel = gaussKernel.view(1, 1, -1)
        B, T, C = inputs.shape
        inputs = inputs.permute(0, 2, 1)
        gaussKernel = gaussKernel.repeat(C, 1, 1)
        smoothed = F.conv1d(inputs, gaussKernel, padding=padding, groups=C)
        return smoothed.permute(0, 2, 1)
    
    def load_h5py_file(file_path, b2txt_csv_df):
        data = {
            'neural_features': [], 'n_time_steps': [], 'seq_class_ids': [],
            'seq_len': [], 'transcriptions': [], 'sentence_label': [],
            'session': [], 'block_num': [], 'trial_num': [], 'corpus': [],
        }
        with h5py.File(file_path, 'r') as f:
            for key in list(f.keys()):
                g = f[key]
                year, month, day = g.attrs['session'].split('.')[1:]
                date = f'{year}-{month}-{day}'
                row = b2txt_csv_df[(b2txt_csv_df['Date'] == date) & 
                                   (b2txt_csv_df['Block number'] == g.attrs['block_num'])]
                corpus_name = row['Corpus'].values[0] if len(row) > 0 else 'unknown'
                data['neural_features'].append(g['input_features'][:])
                data['n_time_steps'].append(g.attrs['n_time_steps'])
                data['seq_class_ids'].append(g['seq_class_ids'][:] if 'seq_class_ids' in g else None)
                data['seq_len'].append(g.attrs['seq_len'] if 'seq_len' in g.attrs else None)
                data['transcriptions'].append(g['transcription'][:] if 'transcription' in g else None)
                data['sentence_label'].append(g.attrs['sentence_label'][:] if 'sentence_label' in g.attrs else None)
                data['session'].append(g.attrs['session'])
                data['block_num'].append(g.attrs['block_num'])
                data['trial_num'].append(g.attrs['trial_num'])
                data['corpus'].append(corpus_name)
        return data
    
    def rearrange_speech_logits_pt(logits):
        return np.concatenate((logits[:, :, 0:1], logits[:, :, -1:], logits[:, :, 1:-1]), axis=-1)
    
    def remove_punctuation(sentence):
        sentence = re.sub(r'[^a-zA-Z\- \']', '', sentence)
        sentence = sentence.replace('- ', ' ').replace('--', '').replace(" '", "'").lower()
        return ' '.join([word for word in sentence.strip().split() if word != ''])
    
    def runSingleDecodingStep(x, input_layer, model, model_args, device):
        with torch.autocast(device_type="cuda", enabled=model_args.get('use_amp', False), dtype=torch.bfloat16):
            x = gauss_smooth(x, device,
                smooth_kernel_std=model_args['dataset']['data_transforms']['smooth_kernel_std'],
                smooth_kernel_size=model_args['dataset']['data_transforms']['smooth_kernel_size'],
                padding='valid')
            with torch.no_grad():
                logits, _ = model(x=x, day_idx=torch.tensor([input_layer], device=device),
                                states=None, return_state=True)
        return logits.float().cpu().numpy()
    
    print("✓ Functions defined inline")
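
The inline `gauss_smooth` builds its kernel by filtering a unit impulse with `scipy.ndimage.gaussian_filter1d`, keeping only taps above 0.01, and renormalizing. The pure-Python sketch below reproduces that recipe so you can see how many taps survive for a given `smooth_kernel_std`. It assumes SciPy's default `truncate=4.0` and a `smooth_kernel_size` large enough (at least `2 * radius + 1`) that the impulse response is not clipped; `gaussian_taps` is a hypothetical helper, not part of the project code.

```python
import math


def gaussian_taps(std, threshold=0.01, truncate=4.0):
    """Sketch of the smoothing kernel built by gauss_smooth (hypothetical helper).

    Assumes SciPy's gaussian_filter1d defaults: the kernel spans
    +/- int(truncate * std + 0.5) samples and is normalized to sum to 1.
    """
    radius = int(truncate * std + 0.5)
    weights = [math.exp(-0.5 * (k / std) ** 2) for k in range(-radius, radius + 1)]
    total = sum(weights)
    weights = [w / total for w in weights]

    # gauss_smooth keeps only taps above the threshold, then renormalizes them.
    kept = [w for w in weights if w > threshold]
    kept_total = sum(kept)
    return [w / kept_total for w in kept]


for std in (1, 2, 3):
    print(std, len(gaussian_taps(std)))
```

For `std=2` the sketch keeps 9 taps, so smoothing with `padding='valid'` trims 8 frames from each trial.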


## 5. Configuration


In [None]:
import torch
from omegaconf import OmegaConf
import pandas as pd

# Configuration - UPDATE THESE PATHS FOR YOUR KAGGLE SETUP
CONFIG = {
    'model_path': '/kaggle/input/your-model-path/t15_pretrained_rnn_baseline',  # Update this
    'data_dir': '/kaggle/input/your-data-path/hdf5_data_final',  # Update this
    'csv_path': '/kaggle/input/your-data-path/t15_copyTaskData_description.csv',  # Update this
    'eval_type': 'test',  # 'val' or 'test'
    'gpu_number': 0,  # GPU to use, or -1 for CPU
    'output_file': 'predictions.csv'
}

# Set device
if torch.cuda.is_available() and CONFIG['gpu_number'] >= 0:
    device = torch.device(f'cuda:{CONFIG["gpu_number"]}')
    print(f'✓ Using {device} for model inference.')
else:
    device = torch.device('cpu')
    print('⚠ Using CPU for model inference.')


## 6. Load Model


In [None]:
# Load model configuration
model_args = OmegaConf.load(os.path.join(CONFIG['model_path'], 'checkpoint/args.yaml'))

# Create model
model = GRUDecoder(
    neural_dim=model_args['model']['n_input_features'],
    n_units=model_args['model']['n_units'],
    n_days=len(model_args['dataset']['sessions']),
    n_classes=model_args['dataset']['n_classes'],
    rnn_dropout=model_args['model']['rnn_dropout'],
    input_dropout=model_args['model']['input_network']['input_layer_dropout'],
    n_layers=model_args['model']['n_layers'],
    patch_size=model_args['model']['patch_size'],
    patch_stride=model_args['model']['patch_stride'],
)

# Load model weights
checkpoint = torch.load(
    os.path.join(CONFIG['model_path'], 'checkpoint/best_checkpoint'),
    weights_only=False,
    map_location=device
)

# Strip wrapper prefixes from parameter names: 'module.' (DataParallel/DDP) and '_orig_mod.' (torch.compile)
state_dict = checkpoint['model_state_dict']
new_state_dict = {}
for key, value in state_dict.items():
    new_key = key.replace('module.', '').replace('_orig_mod.', '')
    new_state_dict[new_key] = value

model.load_state_dict(new_state_dict)
model.to(device)
model.eval()

print(f"✓ Model loaded successfully. Parameters: {sum(p.numel() for p in model.parameters()):,}")
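
The key-renaming loop above uses `str.replace`, which rewrites `'module.'` anywhere in a key, not only at the front. A slightly stricter sketch strips only leading wrapper prefixes: `module.` is added by `nn.DataParallel`/`DistributedDataParallel` and `_orig_mod.` by `torch.compile`. The helper name `strip_wrapper_prefixes` is hypothetical; for `GRUDecoder`'s parameter names both versions give the same result. PyTorch also provides `torch.nn.modules.utils.consume_prefix_in_state_dict_if_present`, which strips a single prefix in place.

```python
def strip_wrapper_prefixes(state_dict, prefixes=("module.", "_orig_mod.")):
    """Remove wrapper prefixes from the start of each key (hypothetical helper).

    Prefixes are stripped repeatedly, so nested wrappers such as
    '_orig_mod.module.' are handled; text later in the key is left alone.
    """
    cleaned = {}
    for key, value in state_dict.items():
        changed = True
        while changed:
            changed = False
            for prefix in prefixes:
                if key.startswith(prefix):
                    key = key[len(prefix):]
                    changed = True
        cleaned[key] = value
    return cleaned


print(strip_wrapper_prefixes({"_orig_mod.module.out.weight": 1, "gru.weight_hh_l0": 2}))
```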


## 7. Load Data


In [None]:
# Load CSV metadata
b2txt_csv_df = pd.read_csv(CONFIG['csv_path'])

# Load data for each session
test_data = {}
total_test_trials = 0

for session in model_args['dataset']['sessions']:
    session_dir = os.path.join(CONFIG['data_dir'], session)
    if os.path.exists(session_dir):
        eval_file = os.path.join(session_dir, f'data_{CONFIG["eval_type"]}.hdf5')
        
        if os.path.exists(eval_file):
            data = load_h5py_file(eval_file, b2txt_csv_df)
            test_data[session] = data
            total_test_trials += len(test_data[session]["neural_features"])
            print(f'✓ Loaded {len(test_data[session]["neural_features"])} {CONFIG["eval_type"]} trials for session {session}.')

print(f'✓ Total number of {CONFIG["eval_type"]} trials: {total_test_trials}')
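
`load_h5py_file` matches each trial to the metadata CSV by date and block number, deriving the date from the HDF5 `session` attribute with `split('.')[1:]`. That only works if session names follow a `<participant>.<year>.<month>.<day>` pattern. The sketch below makes that assumption explicit and fails loudly instead; `session_to_date` is a hypothetical helper and `t15.2023.08.11` a hypothetical session name. If the CSV's `Date` column used a different format, the inline loader would silently label those trials `'unknown'`.

```python
def session_to_date(session):
    """Convert '<participant>.<year>.<month>.<day>' into 'YYYY-MM-DD' (sketch)."""
    parts = session.split(".")
    if len(parts) != 4:
        raise ValueError(f"unexpected session name: {session!r}")
    _, year, month, day = parts
    return f"{year}-{month}-{day}"


print(session_to_date("t15.2023.08.11"))  # hypothetical session name
```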


## 8. Run Inference (Phoneme Predictions)


In [None]:
from tqdm import tqdm
import numpy as np

# Run inference to get phoneme logits
with tqdm(total=total_test_trials, desc='Predicting phoneme sequences', unit='trial') as pbar:
    for session, data in test_data.items():
        data['logits'] = []
        input_layer = model_args['dataset']['sessions'].index(session)
        
        for trial in range(len(data['neural_features'])):
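            # Get neural input as a [1, time, features] tensor. Use bfloat16 only when
            # autocast is active on CUDA; otherwise a bfloat16 input would clash with
            # the model's float32 weights (e.g. when running on CPU).
            neural_input = data['neural_features'][trial]
            neural_input = np.expand_dims(neural_input, axis=0)
            use_bf16 = device.type == 'cuda' and model_args.get('use_amp', False)
            input_dtype = torch.bfloat16 if use_bf16 else torch.float32
            neural_input = torch.tensor(neural_input, device=device, dtype=input_dtype)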
            
            # Run decoding
            logits = runSingleDecodingStep(neural_input, input_layer, model, model_args, device)
            data['logits'].append(logits)
            
            pbar.update(1)

print("✓ Inference completed!")
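
Each trial produces fewer logit frames than input time steps. Smoothing with `padding='valid'` removes `kernel_len - 1` frames, and when `patch_size > 0`, `Tensor.unfold(3, patch_size, patch_stride)` in `GRUDecoder.forward` yields `(T - patch_size) // patch_stride + 1` patches. The sketch below does that arithmetic. `n_output_frames` is a hypothetical helper, and the 9-tap kernel and `patch_size=14`, `patch_stride=4` in the example are illustrative values, not ones read from `args.yaml`.

```python
def n_output_frames(n_time_steps, kernel_len, patch_size, patch_stride):
    """Number of logit frames GRUDecoder emits for one trial (sketch)."""
    # 1) gauss_smooth with padding='valid' keeps T - kernel_len + 1 frames.
    smoothed = n_time_steps - kernel_len + 1
    if smoothed < 1:
        raise ValueError("trial is shorter than the smoothing kernel")
    # 2) Without patching, the GRU emits one frame per smoothed time step.
    if patch_size <= 0:
        return smoothed
    # 3) unfold() needs at least one full patch.
    if smoothed < patch_size:
        raise ValueError("trial is shorter than one patch")
    return (smoothed - patch_size) // patch_stride + 1


print(n_output_frames(500, kernel_len=9, patch_size=14, patch_stride=4))  # 120
print(n_output_frames(500, kernel_len=9, patch_size=0, patch_stride=0))   # 492
```

This matters for CTC: a trial needs at least as many output frames as label tokens, plus one extra frame for each pair of adjacent identical labels.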


## 9. Convert Logits to Phonemes


In [None]:
# Convert logits to phoneme sequences
for session, data in test_data.items():
    data['pred_seq'] = []
    for trial in range(len(data['logits'])):
        logits = data['logits'][trial][0]
        pred_seq = np.argmax(logits, axis=-1)
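        # Greedy CTC decoding: merge consecutive repeats first, then drop blanks (0).
        # Dropping blanks first would fuse a repeated phoneme that the model
        # separated with a blank.
        pred_seq = [int(p) for i, p in enumerate(pred_seq) if i == 0 or p != pred_seq[i - 1]]
        pred_seq = [p for p in pred_seq if p != 0]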
        # Convert to phonemes
        pred_seq = [LOGIT_TO_PHONEME[p] for p in pred_seq]
        data['pred_seq'].append(pred_seq)

print("✓ Phoneme conversion completed!")
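
Greedy CTC decoding is conventionally defined as: take the argmax label per frame, merge runs of identical labels, then drop blanks. The order matters when the model emits the same phoneme twice with a blank in between; merging after blank removal would fuse the two into one. The sketch below contrasts both orders on hypothetical frame ids; both helper names are illustrative.

```python
BLANK = 0  # index of 'BLANK' in LOGIT_TO_PHONEME


def ctc_greedy_collapse(frame_ids, blank=BLANK):
    """Standard greedy CTC: merge runs of identical ids, then remove blanks."""
    collapsed = [p for i, p in enumerate(frame_ids) if i == 0 or p != frame_ids[i - 1]]
    return [p for p in collapsed if p != blank]


def blank_first_collapse(frame_ids, blank=BLANK):
    """Reverse order (remove blanks, then merge runs): loses genuine repeats."""
    no_blanks = [p for p in frame_ids if p != blank]
    return [p for i, p in enumerate(no_blanks) if i == 0 or p != no_blanks[i - 1]]


frames = [0, 7, 7, 0, 7, 0, 0, 29, 29]  # hypothetical per-frame argmax ids
print(ctc_greedy_collapse(frames))   # [7, 7, 29]
print(blank_first_collapse(frames))  # [7, 29]
```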


## 10. Generate Predictions

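**Note:** The original project's full language-model pipeline requires Redis and additional setup, so it is not run here.
The cell below outputs phoneme strings as placeholders; replace them with a phoneme-to-text step (or the language model) before submitting real predictions.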


In [None]:
# Generate predictions without a language model.
# Placeholder: each "prediction" is the phoneme sequence joined with spaces.
# Replace this with real phoneme-to-text conversion before submitting.

predictions = []
ids = []

for session in test_data.keys():
    for trial in range(len(test_data[session]['pred_seq'])):
        phoneme_seq = ' '.join(test_data[session]['pred_seq'][trial])
        predictions.append(phoneme_seq)
        ids.append(len(ids))  # sequential ids: 0, 1, 2, ...

print(f"✓ Generated {len(predictions)} predictions")
print(f"\nSample predictions:")
for i in range(min(5, len(predictions))):
    print(f"  {i}: {predictions[i][:100]}...")
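
As a first step beyond raw phoneme strings, one could split each predicted sequence on the `' | '` word-boundary token and look each word's phonemes up in a pronunciation lexicon. The sketch below does this with a tiny hypothetical lexicon (`TOY_LEXICON`); a real one could be built from the CMU Pronouncing Dictionary with stress digits removed. This is a swapped-in toy technique, not the paper's decoder: it cannot recover from phoneme errors, out-of-vocabulary words, or homophones, which is what the language-model pipeline is for.

```python
SILENCE = " | "  # word-boundary token, the last entry of LOGIT_TO_PHONEME

# Hypothetical toy lexicon (no stress digits, matching LOGIT_TO_PHONEME).
TOY_LEXICON = {
    "hello": ["HH", "AH", "L", "OW"],
    "world": ["W", "ER", "L", "D"],
}


def phonemes_to_words(phonemes, lexicon, unk="<unk>"):
    """Split a phoneme list on SILENCE and look each word up (toy sketch)."""
    # Invert word -> pronunciation; for homophones the last word listed wins.
    reverse = {tuple(pron): word for word, pron in lexicon.items()}
    words, current = [], []
    for ph in list(phonemes) + [SILENCE]:  # trailing SILENCE flushes the last word
        if ph == SILENCE:
            if current:
                words.append(reverse.get(tuple(current), unk))
            current = []
        else:
            current.append(ph)
    return " ".join(words)


print(phonemes_to_words(["HH", "AH", "L", "OW", " | ", "W", "ER", "L", "D", " | "], TOY_LEXICON))
```

With a fuller lexicon (hypothetically `my_lexicon`), `phonemes_to_words(test_data[session]['pred_seq'][trial], my_lexicon)` could stand in for the joined phoneme string in the loop above.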


## 11. Save Predictions


In [None]:
import pandas as pd

# Create submission DataFrame
df_out = pd.DataFrame({
    'id': ids,
    'text': predictions
})

# Save to CSV
output_file = CONFIG['output_file']
df_out.to_csv(output_file, index=False)

print(f"✓ Predictions saved to {output_file}")
print(f"\nFirst few predictions:")
print(df_out.head(10))
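
Before uploading, it can help to re-read the CSV and check it. One pitfall: a trial with no predicted phonemes produces an empty string, which `to_csv` writes as an empty field and `pd.read_csv` reads back as `NaN`. The stdlib sketch below flags a wrong header, row-count mismatches, ids that do not follow the `0..N-1` scheme used above, and empty `text` values. `check_submission` is a hypothetical helper, and whether the competition scorer rejects empty rows is an assumption this notebook does not establish.

```python
import csv
import io


def check_submission(csv_text, expected_rows):
    """Return a list of problems found in an 'id,text' submission CSV (sketch)."""
    reader = csv.DictReader(io.StringIO(csv_text))
    if reader.fieldnames != ["id", "text"]:
        return [f"unexpected header: {reader.fieldnames}"]
    rows = list(reader)
    problems = []
    if len(rows) != expected_rows:
        problems.append(f"expected {expected_rows} rows, found {len(rows)}")
    if [row["id"] for row in rows] != [str(i) for i in range(len(rows))]:
        problems.append("ids are not 0..N-1 in order")
    # A missing field comes back as None from DictReader, so guard with `or ""`.
    empty = [row["id"] for row in rows if not (row["text"] or "").strip()]
    if empty:
        problems.append(f"empty text for ids {empty}")
    return problems


print(check_submission("id,text\n0,hello world\n1,\n", 2))
```

In the notebook, `check_submission(df_out.to_csv(index=False), total_test_trials)` returns an empty list when no problems are found.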


## 12. Evaluation (Validation Set Only)

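When `eval_type` is `'val'`, compute the aggregate Word Error Rate (WER): the total word-level edit distance divided by the total number of reference words. While the predictions are still phoneme-string placeholders, this WER is not meaningful.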


In [None]:
if CONFIG['eval_type'] == 'val':
    import editdistance
    
    total_true_length = 0
    total_edit_distance = 0
    
    trial_idx = 0
    for session in test_data.keys():
        for trial in range(len(test_data[session]['sentence_label'])):
            if trial_idx < len(predictions):
                true_sentence = remove_punctuation(test_data[session]['sentence_label'][trial]).strip()
                pred_sentence = remove_punctuation(predictions[trial_idx]).strip()
                
                ed = editdistance.eval(true_sentence.split(), pred_sentence.split())
                total_true_length += len(true_sentence.split())
                total_edit_distance += ed
            trial_idx += 1
    
    if total_true_length > 0:
        wer = 100 * total_edit_distance / total_true_length
        print(f'Total true sentence length: {total_true_length}')
        print(f'Total edit distance: {total_edit_distance}')
        print(f'Aggregate Word Error Rate (WER): {wer:.2f}%')
    else:
        print("⚠ No validation data available for WER calculation")
else:
    print("ℹ Test set evaluation - ground truth not available")
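
The WER above is a corpus-level rate: the sum of word-level edit distances divided by the total number of reference words, not the mean of per-sentence rates. `editdistance.eval` computes the Levenshtein distance; the pure-Python sketch below does the same on any token sequences, so it can also give a phoneme error rate (PER), which is the more informative metric while predictions are still phoneme placeholders. The helper names are hypothetical.

```python
def edit_distance(ref, hyp):
    """Levenshtein distance between two token sequences (unit costs)."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i] + [0] * len(hyp)
        for j, h in enumerate(hyp, start=1):
            curr[j] = min(
                prev[j] + 1,             # delete r
                curr[j - 1] + 1,         # insert h
                prev[j - 1] + (r != h),  # substitute (free on a match)
            )
        prev = curr
    return prev[-1]


def aggregate_error_rate(pairs):
    """Corpus-level error rate in percent: total edits / total reference tokens."""
    total_ref = sum(len(ref) for ref, _ in pairs)
    if total_ref == 0:
        raise ValueError("no reference tokens")
    total_edits = sum(edit_distance(ref, hyp) for ref, hyp in pairs)
    return 100.0 * total_edits / total_ref


word_pairs = [
    ("the cat sat".split(), "the bat sat".split()),  # 1 substitution
    ("hello".split(), "hello world".split()),        # 1 insertion
]
print(aggregate_error_rate(word_pairs))  # 50.0

phoneme_pairs = [(["HH", "AH", "L", "OW"], ["HH", "L", "OW"])]  # 1 deletion
print(aggregate_error_rate(phoneme_pairs))  # 25.0
```

On the validation set, a PER reference for trial `t` could be built as `[LOGIT_TO_PHONEME[i] for i in data['seq_class_ids'][t][:data['seq_len'][t]]]` and compared with `data['pred_seq'][t]`, assuming `seq_class_ids` is padded to a fixed length and `seq_len` gives the true label length.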


## Notes

1. **Data Paths**: Update the paths in the CONFIG section (Section 5) to match your Kaggle dataset structure.
2. **Language Model**: The full language model pipeline requires Redis and additional setup. For Kaggle, you may need to use a simplified approach or implement phoneme-to-text conversion.
3. **Phoneme-to-Text**: The current implementation uses phoneme sequences as placeholders. You'll need to implement proper phoneme-to-text conversion or use the language model.
4. **File Structure**: Make sure your Kaggle dataset includes:
   - Model checkpoint files (`checkpoint/args.yaml`, `checkpoint/best_checkpoint`)
   - HDF5 data files (`hdf5_data_final/`)
   - CSV metadata file (`t15_copyTaskData_description.csv`)
   - Project source files (optional, if not defining inline)

5. **Kaggle Dataset Setup**: 
   - Upload your model and data as Kaggle datasets
   - Add them to your notebook via "Add Data" button
   - Update the paths in CONFIG to point to `/kaggle/input/your-dataset-name/...`
