In [None]:
import sys
sys.path.insert(0, 'src')

In [None]:
import os
import pickle
import time
import torch
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

In [None]:
from neural_decoder.dataset import SpeechDataset
from neural_decoder.advanced_trainer import train_advanced_model

In [None]:
# Dataset pickle to read and directory for training outputs; a leading "~"
# is expanded to the current user's home directory.
DATASET_PATH, OUTPUT_DIR = (
    os.path.expanduser(raw_path)
    for raw_path in ("~/competitionData/ptDecoder_ctc", "~/results/advanced_simple")
)

In [None]:
# Hyperparameters passed to train_advanced_model; several are reused further
# down to rebuild StreamingTransformerDecoder for evaluation, so keep them in
# sync with the trained checkpoint.
args = dict(
    variant='advanced',
    batchSize=16,
    nBatch=30000,
    seed=0,
    lr=0.005,
    weightDecay=0.0001,
    nClasses=40,
    nInputFeatures=256,
    strideLen=4,
    kernelLen=8,
    gaussianSmoothWidth=2.0,
    modelDim=512,
    modelLayers=6,
    modelHeads=8,
    dropout=0.2,
    intermediateLayer=3,
    timeMaskRatio=0.6,
    channelDropProb=0.3,
    featureMaskProb=0.1,
    minTimeMask=16,
    consistencyWeight=0.2,
    intermediateLossWeight=0.3,
    testTimeLR=0.0001,
    enableTestTimeAdaptation=True,
    enableOnlineAdaptation=True,
    onlineAdaptationLR=0.00001,
    diphoneContext=40,
    transformerTimeMaskProb=0.1,
    relPosMaxDist=None,
    relBiasByHead=True,
    ffMult=4,
    # Paths come from the configuration cell.
    outputDir=OUTPUT_DIR,
    datasetPath=DATASET_PATH,
)

In [None]:
# Header printed before the (long-running) training call.
_banner_rule = "=" * 70
print(
    _banner_rule,
    "ADVANCED TRANSFORMER MODEL TRAINING",
    "WARNING: This model typically gets ~72% CER (worse than baseline)",
    _banner_rule,
    "\nStarting training...",
    sep="\n",
)

In [None]:
# Train with the hyperparameters in `args`.
# NOTE(review): the evaluation cells below reload weights from
# os.path.join(OUTPUT_DIR, "modelWeights.pt"), so this call presumably saves
# the checkpoint there (args['outputDir']) — confirm in advanced_trainer.
train_advanced_model(args)

In [None]:
# Footer marking the end of training output.
print("\n" + "=" * 70, "TRAINING COMPLETE!", "=" * 70, sep="\n")

# Language Model Evaluation


In [None]:
# Install KenLM if not already installed
try:
    import kenlm
    print("✓ KenLM already installed")
except ImportError:
    print("Installing KenLM...")
    ! pip install https://github.com/kpu/kenlm/archive/master.zip
    import kenlm
    print("✓ KenLM installed successfully")

from neural_decoder.phoneme_lm import PhonemeLM, beam_search_decode, create_phoneme_map
from neural_decoder.advanced_models import StreamingTransformerDecoder
from edit_distance import SequenceMatcher
import numpy as np

LM_PATH = "phoneme_lm.arpa"
phoneme_map = create_phoneme_map()
lm = PhonemeLM(LM_PATH, phoneme_map=phoneme_map)
print(f"✓ Language Model loaded from: {LM_PATH}")

In [None]:
# Prefer the GPU when one is visible to torch.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# NOTE: pickle.load can execute arbitrary code — only load dataset files you trust.
with open(DATASET_PATH, "rb") as dataset_file:
    data = pickle.load(dataset_file)

# Only the "test" split is decoded here; "train" is used below just for its length (nDays).
test_ds = SpeechDataset(data["test"])

def collate_fn(batch):
    """Collate a list of samples into zero-padded batch tensors.

    Each sample is unpacked as ``(features, labels, feature_len, label_len, day)``
    and every element must be a tensor. ``features`` and ``labels`` may differ
    in length along their first dimension; they are right-padded with zeros to
    the longest item with the batch dimension first. The three remaining
    per-sample tensors are stacked along a new leading batch dimension.

    Note: label padding uses 0; consumers must slice labels by ``label_len``
    rather than rely on the padding value.

    Returns:
        ``(features_padded, labels_padded, feature_lens, label_lens, days)``
    """
    features, labels, feature_lens, label_lens, days = zip(*batch)
    features_padded = pad_sequence(features, batch_first=True, padding_value=0.0)
    labels_padded = pad_sequence(labels, batch_first=True, padding_value=0)
    stacked = tuple(torch.stack(group) for group in (feature_lens, label_lens, days))
    return (features_padded, labels_padded) + stacked

# Deterministic, single-process loader over the test split.
test_loader = DataLoader(
    test_ds,
    collate_fn=collate_fn,
    batch_size=args['batchSize'],
    shuffle=False,
    num_workers=0,
    pin_memory=False,
)

# Rebuild the architecture from `args`. These values must match the ones used
# at training time, otherwise load_state_dict (strict by default) rejects the
# checkpoint. NOTE(review): intermediateLayer / relBiasByHead /
# transformerTimeMaskProb from `args` are not passed here — confirm the
# constructor defaults match training.
model_kwargs = dict(
    neural_dim=args['nInputFeatures'],
    n_phonemes=args['nClasses'],
    d_model=args['modelDim'],
    nhead=args['modelHeads'],
    num_layers=args['modelLayers'],
    nDays=len(data["train"]),
    dropout=args['dropout'],
    device=device,
    strideLen=args['strideLen'],
    kernelLen=args['kernelLen'],
    gaussianSmoothWidth=args['gaussianSmoothWidth'],
    ff_mult=args['ffMult'],
    rel_pos_max_dist=args['relPosMaxDist'],
    diphone_context=args['diphoneContext'],
)
model = StreamingTransformerDecoder(**model_kwargs).to(device)

# torch.load unpickles the file; only load checkpoints you trust.
weights_path = os.path.join(OUTPUT_DIR, "modelWeights.pt")
model.load_state_dict(torch.load(weights_path, map_location=device))
model.eval()  # disable dropout for evaluation

for status_line in (f"✓ Model loaded from: {weights_path}", f"Using device: {device}"):
    print(status_line)

In [None]:
# Decode every test utterance two ways and keep the references for scoring:
#   * greedy (best-path) CTC decoding from the acoustic model alone;
#   * CTC beam search combined with the phoneme LM `lm`.
all_predictions_baseline = []
all_predictions_lm = []
all_targets = []

LM_WEIGHT = 0.8      # weight given to the LM score in beam_search_decode
BEAM_WIDTH = 10      # number of beams kept by beam_search_decode
TOPK_ACOUSTIC = 5    # forwarded as topk_acoustic; presumably candidate tokens per frame — TODO confirm
BLANK_ID = 0         # CTC blank index, shared by the greedy and beam-search decoders


def _greedy_ctc_decode(frame_log_probs, blank_id=BLANK_ID):
    """Best-path CTC decode of a single utterance.

    Takes the argmax class per frame (over the last dimension), then collapses
    runs of repeated tokens and removes ``blank_id``.

    Returns:
        list of Python ints (decoded token ids).
    """
    best_path = torch.argmax(frame_log_probs, dim=-1).cpu().tolist()
    decoded = []
    prev = None
    for tok in best_path:
        # Emit only at the start of a run and never for blank; `prev` is
        # updated on every frame so a blank separates genuine repeats.
        if tok != prev and tok != blank_id:
            decoded.append(tok)
        prev = tok
    return decoded


print(f"LM Weight: {LM_WEIGHT}")
print(f"Beam Width: {BEAM_WIDTH}\n")

with torch.no_grad():
    for batch_idx, (X, y, X_len, y_len, day_idx) in enumerate(test_loader):
        X = X.to(device)
        X_len = X_len.to(device)
        day_idx = day_idx.to(device)

        output = model(X, X_len, day_idx)
        # Apply log_softmax to the model output directly: if 'log_probs' are
        # already normalised log-probabilities this is a no-op, otherwise it
        # normalises them. Do NOT .exp() first — log_softmax of probabilities
        # is p - logsumexp(p), not log(p); it squashes all acoustic scores into
        # a range narrower than 1 and lets the LM term dominate beam search.
        log_probs = torch.log_softmax(output['log_probs'], dim=-1)
        lengths = output['eff_lengths']  # valid output frames per utterance

        for i in range(len(y)):
            # Drop padded frames beyond this utterance's effective length.
            lp = log_probs[i, :int(lengths[i]), :]

            all_predictions_baseline.append(_greedy_ctc_decode(lp))
            all_predictions_lm.append(
                beam_search_decode(
                    lp,
                    lm=lm,
                    lm_weight=LM_WEIGHT,
                    beam_width=BEAM_WIDTH,
                    blank_id=BLANK_ID,
                    topk_acoustic=TOPK_ACOUSTIC,
                )
            )
            # Strip label padding using the true label length.
            all_targets.append(y[i, :y_len[i]].tolist())

        if (batch_idx + 1) % 10 == 0:
            print(f"  Processed {batch_idx + 1}/{len(test_loader)} batches")

print("✓ Evaluation complete!")

In [None]:
def compute_error_rate(predictions, targets):
    """Corpus-level token error rate.

    Sums the per-utterance edit distance (as reported by
    ``edit_distance.SequenceMatcher.distance()``) between each reference and
    its prediction, then divides by the total number of reference tokens.

    Args:
        predictions: iterable of decoded token sequences, one per utterance.
        targets: iterable of reference token sequences, aligned with
            ``predictions``.

    Returns:
        total edit distance / total reference length, or 0.0 when there are
        no reference tokens at all.

    Raises:
        ValueError: if ``predictions`` and ``targets`` differ in length.
            ``zip`` would otherwise silently drop the unmatched tail and
            report a misleading rate.
    """
    predictions = list(predictions)
    targets = list(targets)
    if len(predictions) != len(targets):
        raise ValueError(
            f"got {len(predictions)} predictions but {len(targets)} targets"
        )
    total_edit = 0
    total_len = 0
    for pred, target in zip(predictions, targets):
        total_edit += SequenceMatcher(a=target, b=pred).distance()
        total_len += len(target)
    return total_edit / total_len if total_len > 0 else 0.0

# Corpus-level error rates of both decoders over the whole test split.
cer_baseline = compute_error_rate(all_predictions_baseline, all_targets)
cer_lm = compute_error_rate(all_predictions_lm, all_targets)
absolute_gain = cer_baseline - cer_lm
# The relative change is undefined when the greedy baseline is already perfect;
# report it as n/a instead of crashing with ZeroDivisionError.
improvement = absolute_gain / cer_baseline * 100 if cer_baseline > 0 else None

print("\n" + "=" * 70)
print("FINAL RESULTS")
print("=" * 70)
print("\nBaseline (Greedy):")
print(f"  CER: {cer_baseline:.4f} ({cer_baseline*100:.2f}%)")
print("\nWith Language Model:")
print(f"  CER: {cer_lm:.4f} ({cer_lm*100:.2f}%)")
if improvement is None:
    print("  Improvement: n/a (baseline CER is 0)")
else:
    print(f"  Improvement: {improvement:.2f}% relative")
print(f"  Absolute gain: {absolute_gain*100:.2f} percentage points")
# State the direction honestly: a negative value means the LM hurt, and the
# figure is a relative CER change, not an accuracy gain.
if improvement is None:
    print("Language Model relative CER change is undefined (baseline CER is 0)")
elif improvement >= 0:
    print(f"Language Model reduced CER by {improvement:.1f}% (relative)")
else:
    print(f"Language Model increased CER by {-improvement:.1f}% (relative)")
