In [None]:
# ============================================================================
# COMPLETE WER EVALUATION
# ============================================================================
import os
import h5py
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import editdistance
from collections import defaultdict, Counter
import warnings
warnings.filterwarnings("ignore")

# --- Input locations --------------------------------------------------------
BASE_PATH = "/kaggle/input/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final"
MODEL_PATH = "/kaggle/input/transformer/other/default/1/best_model_full_per.pt"

# --- DataLoader / runtime settings ------------------------------------------
BATCH_SIZE_VAL = 2
NUM_WORKERS = 2
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# --- Network hyper-parameters (have to agree with the trained checkpoint) ----
VOCAB_SIZE = 41
D_MODEL = 256
NHEAD = 8
NUM_LAYERS = 4
DROPOUT = 0.1

# --- Output vocabulary (same ordering as in training) ------------------------
# Position in this list == class id emitted by the network.
# Id 0 is the CTC blank; the final entry " | " is the word-boundary symbol.
LOGIT_TO_PHONEME = (
    ["BLANK"]
    + "AA AE AH AO AW AY B CH D DH".split()
    + "EH ER EY F G HH IH IY JH K".split()
    + "L M N NG OW OY P R S SH".split()
    + "T TH UH UW V W Y Z ZH".split()
    + [" | "]
)

# class id -> phoneme string
idx_to_phoneme = dict(enumerate(LOGIT_TO_PHONEME))

# ============================================================================
# DATA LOADING 
# ============================================================================

def get_all_hdf5_files(base_path, split_name):
    """Collect the per-session HDF5 file of one split.

    Every direct sub-folder of ``base_path`` whose name starts with ``"t15."``
    is checked for a file called ``data_<split_name>.hdf5``. Only one
    directory level is inspected.

    Args:
        base_path: Directory holding the session folders.
        split_name: Split identifier used in the file name (e.g. ``"val"``).

    Returns:
        List of existing file paths, ordered by sorted folder name.
    """
    wanted_file = f"data_{split_name}.hdf5"
    session_paths = (
        os.path.join(base_path, entry, wanted_file)
        for entry in sorted(os.listdir(base_path))
        if entry.startswith("t15.")
    )
    return [path for path in session_paths if os.path.exists(path)]

def load_full_dataset_info(h5_paths):
    """Index every labelled trial contained in the given HDF5 files.

    A top-level entry of a file counts as a trial when it contains a
    ``"seq_class_ids"`` member. Paths that do not exist are skipped silently.
    A per-file count and the grand total are printed.

    Args:
        h5_paths: Iterable of HDF5 file paths.

    Returns:
        List of ``(file_path, trial_key)`` tuples in file order.
    """
    trial_index = []
    trial_count = 0

    for path in h5_paths:
        if os.path.exists(path):
            with h5py.File(path, "r") as handle:
                labelled = [
                    (path, name)
                    for name in handle.keys()
                    if "seq_class_ids" in handle[name]
                ]
                trial_index.extend(labelled)

                print(f"  {os.path.basename(path)}: {len(labelled)} trials")
                trial_count += len(labelled)

    print(f"\nTotal trials found: {trial_count}")
    return trial_index

class Brain2TextDatasetFull(Dataset):
    """Map-style dataset over ``(file_path, trial_key)`` pairs stored in HDF5.

    Files are opened lazily on first access and the open handles are kept in
    ``self.file_cache`` (keyed by path) so later reads from the same file do
    not reopen it.

    Args:
        trial_list: Sequence of ``(file_path, trial_key)`` tuples.
        max_trials: If given, only the first ``max_trials`` entries are used.
    """

    def __init__(self, trial_list, max_trials=None):
        # Create the cache first so __del__ is safe even if a later line fails.
        self.file_cache = {}
        self.trials = trial_list if max_trials is None else trial_list[:max_trials]
        print(f"Dataset initialized with {len(self.trials)} trials")

    def __len__(self):
        """Return the number of trials exposed by the dataset."""
        return len(self.trials)

    def __getitem__(self, idx):
        """Load one trial.

        Returns:
            ``(x, y)`` where ``x`` is a float32 tensor built from the trial's
            ``"input_features"`` and ``y`` is a long tensor holding the first
            ``seq_len`` entries of ``"seq_class_ids"`` (empty when the trial
            has no labels).
        """
        file_path, key = self.trials[idx]

        # Open each file at most once per dataset instance.
        if file_path not in self.file_cache:
            self.file_cache[file_path] = h5py.File(file_path, "r")

        f = self.file_cache[file_path]
        grp = f[key]

        # Load neural features
        x = grp["input_features"][()]
        x = torch.tensor(x, dtype=torch.float32)

        # Load phoneme target
        if "seq_class_ids" in grp:
            seq_class_ids = grp["seq_class_ids"][()]
            # Use the "seq_len" attribute to cut the stored array when present;
            # otherwise the full array is taken.
            seq_len = grp.attrs.get("seq_len", len(seq_class_ids))
            phoneme_seq = seq_class_ids[:seq_len]
            y = torch.tensor(phoneme_seq, dtype=torch.long)
        else:
            y = torch.tensor([], dtype=torch.long)

        return x, y

    def __del__(self):
        """Best-effort close of every cached file handle."""
        # The attribute may be missing if construction was bypassed or failed.
        cache = getattr(self, "file_cache", None)
        if not cache:
            return
        for handle in cache.values():
            try:
                handle.close()
            except Exception:
                # Cleanup must never raise from a finalizer, but do not
                # swallow KeyboardInterrupt/SystemExit like a bare except.
                pass
        # Drop the closed handles so a repeated call is a no-op.
        cache.clear()

def ctc_collate(batch):
    """Assemble ``(features, targets)`` samples into CTC-style batch tensors.

    Args:
        batch: Sequence of ``(x, y)`` pairs of tensors.

    Returns:
        Tuple ``(X, Y, x_lens, y_lens)``: the features zero-padded to a common
        length with the batch dimension first, all targets concatenated into
        one 1-D tensor, and the per-sample feature and target lengths as long
        tensors.
    """
    features, targets = zip(*batch)

    padded_features = nn.utils.rnn.pad_sequence(features, batch_first=True)
    flat_targets = torch.cat(targets)

    feature_lengths = torch.tensor(list(map(len, features)), dtype=torch.long)
    target_lengths = torch.tensor(list(map(len, targets)), dtype=torch.long)

    return padded_features, flat_targets, feature_lengths, target_lengths

# ============================================================================
# MODEL ARCHITECTURE (Same as training)
# ============================================================================

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to a batch-first sequence.

    The table is precomputed for ``max_len`` positions and stored as the
    buffer ``pe`` with shape ``(1, max_len, d_model)``: even feature columns
    hold sines, odd feature columns hold cosines.

    Args:
        d_model: Feature width of the sequences passed to ``forward``.
            Both even and odd widths are supported.
        max_len: Number of positions to precompute.
    """

    def __init__(self, d_model, max_len=10000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        angles = pos * div
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even columns, so
        # trim the angle table; for even d_model the slice is the full table.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encoding of its first ``x.size(1)`` positions."""
        return x + self.pe[:, :x.size(1)]

class BrainTransformer(nn.Module):
    """Transformer encoder that maps feature sequences to per-step class logits.

    Pipeline: linear projection (+ LayerNorm, GELU, dropout) -> sinusoidal
    positional encoding -> pre-norm Transformer encoder -> linear output layer.

    Submodule names (``proj``, ``pos``, ``transformer``, ``out``) and their
    order determine the ``state_dict`` keys, so they must stay as they are for
    saved checkpoints to load.

    Args:
        vocab_size: Number of output classes per time step.
        d_model: Internal feature width of the encoder.
        nhead: Number of attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
        dropout: Dropout probability used in the projection and the encoder.
        input_dim: Width of the incoming feature vectors. Defaults to 512,
            the value that was previously hard-coded.
    """

    def __init__(self, vocab_size=41, d_model=256, nhead=8, num_layers=4, dropout=0.1,
                 input_dim=512):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Linear(input_dim, d_model),
            nn.LayerNorm(d_model),
            nn.GELU(),
            nn.Dropout(dropout)
        )
        self.pos = PositionalEncoding(d_model)
        layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=d_model*4,
            dropout=dropout, activation='gelu', batch_first=True, norm_first=True
        )
        self.transformer = nn.TransformerEncoder(layer, num_layers)
        self.out = nn.Linear(d_model, vocab_size)

    def forward(self, x, mask=None):
        """Compute logits for every time step.

        Args:
            x: Batch-first feature tensor whose last dimension is ``input_dim``.
            mask: Optional boolean tensor forwarded as
                ``src_key_padding_mask`` (in PyTorch, ``True`` marks positions
                the attention should ignore).

        Returns:
            Tensor of logits whose last dimension is ``vocab_size``.
        """
        x = self.proj(x)
        x = self.pos(x)
        x = self.transformer(x, src_key_padding_mask=mask)
        return self.out(x)

# ============================================================================
# CTC DECODING (Same as training)
# ============================================================================

def ctc_greedy_decode(logits):
    """Greedy CTC decoding of a batch of logits.

    For every sequence: take the best class per step, keep only the first
    element of each run of identical ids, then drop the blank id ``0``.

    Args:
        logits: Tensor whose last dimension indexes the classes and whose
            first dimension indexes the sequences of the batch.

    Returns:
        List (one entry per sequence) of lists of Python ints.
    """
    best_paths = logits.argmax(-1).cpu().numpy()
    results = []

    for path in best_paths:
        # A step starts a new run when it differs from the step before it;
        # the very first step always starts a run.
        starts_run = np.ones(len(path), dtype=bool)
        starts_run[1:] = path[1:] != path[:-1]
        run_heads = path[starts_run]
        results.append(run_heads[run_heads != 0].tolist())

    return results

def calculate_per(predictions, ground_truths):
    """Corpus-level Phoneme Error Rate in percent.

    Args:
        predictions: Iterable of predicted id sequences.
        ground_truths: Iterable of reference id sequences, paired by position
            with ``predictions``.

    Returns:
        Tuple ``(per, total_edit_distance, total_length)`` where ``per`` is
        ``100 * total_edit_distance / total_length`` and ``total_length`` is
        the summed reference length. ``per`` is ``100.0`` when there are no
        reference tokens.
    """
    pairs = list(zip(predictions, ground_truths))

    total_edit_distance = sum(editdistance.eval(ref, hyp) for hyp, ref in pairs)
    total_length = sum(len(ref) for _, ref in pairs)

    if total_length > 0:
        per = total_edit_distance / total_length * 100
    else:
        per = 100.0
    return per, total_edit_distance, total_length

# ============================================================================
# PHONEME-TO-WORD CONVERSION
# ============================================================================

class PhonemeDictionary:
    """Phoneme-to-word dictionary for decoding"""
    
    def __init__(self):
        # Common words dictionary (phonemes as tuples -> words)
        self.phoneme_to_words = {
            # Pronouns
            ('Y', 'UW'): ['you'],
            ('HH', 'IY'): ['he'],
            ('SH', 'IY'): ['she'],
            ('W', 'IY'): ['we'],
            ('DH', 'EY'): ['they'],
            ('IH', 'T'): ['it'],
            
            # Common verbs
            ('K', 'AE', 'N'): ['can'],
            ('W', 'IH', 'L'): ['will'],
            ('HH', 'AE', 'V'): ['have'],
            ('W', 'AH', 'Z'): ['was'],
            ('W', 'ER'): ['were'],
            ('IH', 'Z'): ['is'],
            ('S', 'IY'): ['see'],
            ('G', 'OW'): ['go'],
            ('K', 'AH', 'M'): ['come'],
            ('G', 'EH', 'T'): ['get'],
            ('M', 'EY', 'K'): ['make'],
            ('K', 'IY', 'P'): ['keep'],
            
            # Articles & Prepositions
            ('DH', 'AH'): ['the'],
            ('AH'): ['a', 'uh'],
            ('AE', 'N'): ['an'],
            ('T', 'UW'): ['to', 'too'],
            ('AH', 'V'): ['of'],
            ('AE', 'T'): ['at'],
            ('IH', 'N'): ['in'],
            ('AO', 'N'): ['on'],
            ('F', 'AO', 'R'): ['for'],
            ('AE', 'Z'): ['as'],
            
            # Demonstratives
            ('DH', 'IH', 'S'): ['this'],
            ('DH', 'AE', 'T'): ['that'],
            
            # Conjunctions
            ('AE', 'N', 'D'): ['and'],
            ('B', 'AH', 'T'): ['but'],
            ('AO', 'R'): ['or'],
            
            # Question words
            ('HH', 'AW'): ['how'],
            ('W', 'AH', 'T'): ['what'],
            ('W', 'EH', 'N'): ['when'],
            ('W', 'EH', 'R'): ['where'],
            ('HH', 'UW'): ['who'],
            ('W', 'AY'): ['why'],
            
            # Common nouns
            ('T', 'AY', 'M'): ['time'],
            ('P', 'IY', 'P', 'AH', 'L'): ['people'],
            ('K', 'OW', 'D'): ['code'],
            ('P', 'OY', 'N', 'T'): ['point'],
            ('K', 'AA', 'S', 'T'): ['cost'],
            ('D', 'AW', 'N'): ['down'],
            
            # Adjectives
            ('N', 'UW'): ['new'],
            ('G', 'UH', 'D'): ['good'],
            ('N', 'AA', 'T'): ['not'],
            
            # Common expressions
            ('Y', 'EH', 'S'): ['yes'],
            ('N', 'OW'): ['no'],
            ('W', 'EH', 'L'): ['well'],
        }
        
        # Word frequencies
        self.word_freq = Counter({
            'the': 1000000, 'to': 500000, 'and': 450000,
            'of': 400000, 'a': 350000, 'in': 300000,
            'is': 250000, 'that': 200000, 'for': 180000,
            'it': 170000, 'you': 160000, 'can': 150000,
            'will': 140000, 'have': 130000, 'this': 120000,
            'at': 110000, 'see': 100000, 'what': 95000,
            'how': 90000, 'not': 85000, 'code': 50000,
        })
    
    def lookup(self, phoneme_tuple):
        """Exact dictionary lookup"""
        return self.phoneme_to_words.get(phoneme_tuple, None)
    
    def fuzzy_match(self, phoneme_tuple, max_edit_dist=2):
        """Find closest matches using edit distance"""
        if not phoneme_tuple:
            return []
        
        candidates = []
        for dict_phonemes, words in self.phoneme_to_words.items():
            dist = editdistance.eval(phoneme_tuple, dict_phonemes)
            
            if dist <= max_edit_dist:
                for word in words:
                    freq = self.word_freq.get(word, 1)
                    score = -dist + np.log(freq) / 10
                    candidates.append((word, score, dist))
        
        candidates.sort(key=lambda x: x[1], reverse=True)
        return candidates[:5]

class PhonemeToWordConverter:
    """Turn a flat phoneme stream into words using a pronunciation lexicon.

    The lexicon object must offer ``lookup(tuple)``, ``fuzzy_match(tuple)``
    and a ``word_freq`` mapping. The symbol ``' | '`` separates words.
    """

    def __init__(self, phoneme_dict):
        self.phoneme_dict = phoneme_dict
        self.silence_token = ' | '

    def convert(self, phoneme_list, use_fuzzy=True):
        """Split ``phoneme_list`` at the separator and decode each group.

        Args:
            phoneme_list: Iterable of phoneme strings.
            use_fuzzy: Allow approximate lexicon matches for unknown groups.

        Returns:
            List of decoded words; empty groups produce nothing.
        """
        decoded = []
        pending = []

        def flush():
            # Decode whatever has been collected since the last separator.
            if pending:
                word = self._phonemes_to_word(tuple(pending), use_fuzzy)
                if word:
                    decoded.append(word)
                pending.clear()

        for symbol in phoneme_list:
            if symbol == self.silence_token:
                flush()
            else:
                pending.append(symbol)

        # The stream may end without a trailing separator.
        flush()
        return decoded

    def _phonemes_to_word(self, phoneme_tuple, use_fuzzy=True):
        """Decode one phoneme group into a single word string.

        Preference order: most frequent exact match, then the top fuzzy
        candidate (if enabled), then a ``<P1_P2_...>`` placeholder.
        """
        lexicon = self.phoneme_dict

        exact = lexicon.lookup(phoneme_tuple)
        if exact:
            return max(exact, key=lambda word: lexicon.word_freq.get(word, 0))

        ranked = lexicon.fuzzy_match(phoneme_tuple) if use_fuzzy else None
        if ranked:
            return ranked[0][0]

        return '<' + '_'.join(phoneme_tuple) + '>'

# ============================================================================
# WER CALCULATION
# ============================================================================

def calculate_wer(predictions, references):
    """Corpus-level Word Error Rate in percent.

    Args:
        predictions: Iterable of predicted word lists.
        references: Iterable of reference word lists, paired by position with
            ``predictions``.

    Returns:
        Tuple ``(wer, total_errors, total_words)`` where ``wer`` is
        ``100 * total_errors / total_words`` and ``total_words`` counts the
        reference words. ``wer`` is ``100.0`` when there are no reference
        words.
    """
    total_errors, total_words = 0, 0

    for hypothesis, reference in zip(predictions, references):
        total_errors += editdistance.eval(hypothesis, reference)
        total_words += len(reference)

    if total_words > 0:
        wer = total_errors / total_words * 100
    else:
        wer = 100.0
    return wer, total_errors, total_words

# ============================================================================
# FULL EVALUATION FUNCTION
# ============================================================================

def _ids_to_phonemes(indices, idx_to_phoneme):
    """Map class ids to phoneme strings, skipping ids beyond the table size."""
    table_size = len(idx_to_phoneme)
    return [idx_to_phoneme[idx] for idx in indices if idx < table_size]


def evaluate_wer(model, dataloader, idx_to_phoneme, device='cuda'):
    """Run the model over ``dataloader`` and report PER and WER.

    Predictions are obtained by greedy CTC decoding. Both predicted and
    reference phoneme sequences are converted to words with the same
    dictionary-based converter, then compared.

    NOTE(review): reference words are reconstructed from reference phonemes
    (with fuzzy matching enabled) rather than taken from transcripts, so the
    reported WER depends on the dictionary's coverage -- confirm this is the
    intended metric.

    Args:
        model: Network called as ``model(X, mask)`` returning per-step logits.
        dataloader: Yields ``(X, Y, x_len, y_len)`` batches where ``Y`` is the
            concatenation of all targets and ``y_len`` their lengths.
        idx_to_phoneme: Mapping from class id to phoneme string.
        device: Device the inputs are moved to.

    Returns:
        Dict with keys ``'per'``, ``'wer'``, ``'wer_errors'`` and
        ``'total_words'``.
    """
    model.eval()

    phoneme_dict = PhonemeDictionary()
    converter = PhonemeToWordConverter(phoneme_dict)

    all_pred_phoneme_seqs = []
    all_true_phoneme_seqs = []
    all_pred_word_seqs = []
    all_true_word_seqs = []

    print("Running evaluation...")

    with torch.no_grad():
        for batch_idx, (X, Y, x_len, y_len) in enumerate(dataloader):
            X = X.to(device)
            # True at padded time steps (position index >= true length).
            mask = torch.arange(X.size(1), device=device)[None, :] >= x_len.to(device)[:, None]

            logits = model(X, mask)
            pred_seqs = ctc_greedy_decode(logits)

            # Convert predictions to phonemes and words
            for pred_indices in pred_seqs:
                pred_phonemes = _ids_to_phonemes(pred_indices, idx_to_phoneme)
                pred_words = converter.convert(pred_phonemes, use_fuzzy=True)

                all_pred_phoneme_seqs.append(pred_indices)
                all_pred_word_seqs.append(pred_words)

            # Convert ground truth: Y is flat, so walk it with running offsets.
            start = 0
            for length in (int(n) for n in y_len):  # plain ints, not 0-d tensors
                true_indices = Y[start:start + length].tolist()
                true_phonemes = _ids_to_phonemes(true_indices, idx_to_phoneme)
                true_words = converter.convert(true_phonemes, use_fuzzy=True)

                all_true_phoneme_seqs.append(true_indices)
                all_true_word_seqs.append(true_words)
                start += length

            if (batch_idx + 1) % 100 == 0:
                print(f"  Processed {batch_idx + 1}/{len(dataloader)} batches")

            # Release per-batch tensors before the next forward pass.
            del logits, X, mask
            torch.cuda.empty_cache()

    # Calculate metrics
    per, _, _ = calculate_per(all_pred_phoneme_seqs, all_true_phoneme_seqs)
    wer, wer_errors, total_words = calculate_wer(all_pred_word_seqs, all_true_word_seqs)

    print(f"\n{'='*60}")
    print(f"EVALUATION RESULTS")
    print(f"{'='*60}")
    print(f"Phoneme Error Rate (PER): {per:.2f}%")
    print(f"Word Error Rate (WER): {wer:.2f}%")
    print(f"Total word errors: {wer_errors}")
    print(f"Total words: {total_words}")
    print(f"{'='*60}")

    # Show samples (at most 30 phonemes per line). Uses the same guarded
    # lookup as the conversion above so an out-of-table id cannot raise here.
    print("\nSample Predictions (first 5):")
    for i in range(min(5, len(all_pred_phoneme_seqs))):
        print(f"\n--- Sample {i+1} ---")

        pred_phon_str = ' '.join(_ids_to_phonemes(all_pred_phoneme_seqs[i][:30], idx_to_phoneme))
        true_phon_str = ' '.join(_ids_to_phonemes(all_true_phoneme_seqs[i][:30], idx_to_phoneme))

        print(f"Pred Phonemes: {pred_phon_str}")
        print(f"True Phonemes: {true_phon_str}")
        print(f"Pred Words: {' '.join(all_pred_word_seqs[i])}")
        print(f"True Words: {' '.join(all_true_word_seqs[i])}")

    return {
        'per': per,
        'wer': wer,
        'wer_errors': wer_errors,
        'total_words': total_words,
    }

# ============================================================================
# MAIN EXECUTION
# ============================================================================

# Script entry: runs top to bottom at import/execution time. It builds the
# validation DataLoader, restores the trained network, evaluates it and prints
# the final WER.
print("="*60)
print("BRAIN-TO-TEXT WER EVALUATION")
print("="*60)

# 1. Load validation data
print("\nLoading validation data...")
val_files = get_all_hdf5_files(BASE_PATH, "val")
val_trials = load_full_dataset_info(val_files)
val_ds = Brain2TextDatasetFull(val_trials)
# shuffle=False keeps the trial order fixed, so printed samples are stable
# from run to run.
val_dl = DataLoader(
    val_ds, batch_size=BATCH_SIZE_VAL, shuffle=False,
    collate_fn=ctc_collate, num_workers=NUM_WORKERS, pin_memory=True
)

# 2. Load trained model
print("\nLoading trained model...")
# Hyper-parameters come from the module-level constants above.
model = BrainTransformer(
    vocab_size=VOCAB_SIZE,
    d_model=D_MODEL,
    nhead=NHEAD,
    num_layers=NUM_LAYERS,
    dropout=DROPOUT
)
# NOTE(review): torch.load unpickles the file; presumably it holds a plain
# state_dict -- consider weights_only=True if the checkpoint source is not
# fully trusted.
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.to(DEVICE)
model.eval()
print(f"✓ Model loaded from {MODEL_PATH}")
print(f"✓ Model has {sum(p.numel() for p in model.parameters()):,} parameters")

# 3. Run evaluation
print("\n" + "="*60)
# evaluate_wer prints PER/WER details itself and returns them as a dict.
results = evaluate_wer(
    model=model,
    dataloader=val_dl,
    idx_to_phoneme=idx_to_phoneme,
    device=DEVICE
)

# Final one-line summary of the headline metric.
print("\n" + "="*60)
print("FINAL RESULTS")
print("="*60)
print(f"WER: {results['wer']:.2f}%")
print("="*60)

BRAIN-TO-TEXT WER EVALUATION

Loading validation data...
  data_val.hdf5: 35 trials
  data_val.hdf5: 49 trials
  data_val.hdf5: 48 trials
  data_val.hdf5: 25 trials
  data_val.hdf5: 25 trials
  data_val.hdf5: 49 trials
  data_val.hdf5: 34 trials
  data_val.hdf5: 35 trials
  data_val.hdf5: 48 trials
  data_val.hdf5: 44 trials
  data_val.hdf5: 36 trials
  data_val.hdf5: 17 trials
  data_val.hdf5: 44 trials
  data_val.hdf5: 44 trials
  data_val.hdf5: 9 trials
  data_val.hdf5: 33 trials
  data_val.hdf5: 50 trials
  data_val.hdf5: 15 trials
  data_val.hdf5: 25 trials
  data_val.hdf5: 20 trials
  data_val.hdf5: 44 trials
  data_val.hdf5: 34 trials
  data_val.hdf5: 50 trials
  data_val.hdf5: 25 trials
  data_val.hdf5: 30 trials
  data_val.hdf5: 50 trials
  data_val.hdf5: 23 trials
  data_val.hdf5: 24 trials
  data_val.hdf5: 48 trials
  data_val.hdf5: 48 trials
  data_val.hdf5: 25 trials
  data_val.hdf5: 25 trials
  data_val.hdf5: 48 trials
  data_val.hdf5: 46 trials
  data_val.hdf5: 48 trials