In [None]:
"""
Training script for the final2 model using NeuralDecoder architecture
"""
import sys
# Put the relative 'src' directory first on the import path (presumably so the
# `neural_decoder` package imported below resolves from there). The path is
# relative, so it is resolved against the kernel's current working directory.
sys.path.insert(0, 'src')

In [None]:
import os
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor
from lightning.pytorch.loggers import TensorBoardLogger
from omegaconf import OmegaConf
import torch

In [None]:
from neural_decoder.neural_decoder_trainer import DataModule
from neural_decoder.final2_model import NeuralDecoder

In [None]:
# Run configuration. All paths are relative and therefore resolve against the
# kernel's current working directory.
CONFIG_PATH = "src/neural_decoder/conf/decoder/final2.yaml"  # YAML file loaded with OmegaConf
DATASET_NAME = "competition_data"  # forwarded to DataModule as dataset_name
OUTPUT_DIR = "results/final2_training"  # checkpoints and TensorBoard logs are written here

In [None]:
# Header banner: title framed by two 70-character rules.
print("=" * 70, "FINAL2 MODEL TRAINING", "=" * 70, sep="\n")

In [None]:
# Load the run configuration; later cells read hyperparameters from it.
config = OmegaConf.load(CONFIG_PATH)
print(f"\nLoaded configuration from: {CONFIG_PATH}")
# `variant` is read as a plain attribute (no fallback), unlike the
# config.get(..., default) lookups used for the other settings.
print(f"Model variant: {config.variant}")

In [None]:
# Create the output directory (including parents); exist_ok=True makes
# re-running this cell harmless.
os.makedirs(OUTPUT_DIR, exist_ok=True)

In [None]:
# Seed every RNG the run may touch. torch.manual_seed alone leaves Python's
# `random` and NumPy unseeded; seed_everything covers all three, and
# workers=True also gives the DataLoader worker processes deterministic seeds.
# Falls back to seed 0 when the config does not define `seed`.
L.seed_everything(config.get('seed', 0), workers=True)

In [None]:
# Data pipeline for the run. The same object is passed to trainer.fit() and
# later supplies the test dataloader for the language-model evaluation.
datamodule = DataModule(
    dataset_name=DATASET_NAME,
    batch_size=config.batchSize,  # read as an attribute: no fallback if the key is missing
    num_workers=4,  # hard-coded worker count (not taken from the config)
)

In [None]:
# Fallback values for every NeuralDecoder constructor argument. A key with the
# same name in `config` takes precedence over the value listed here.
model_defaults = {
    'conv_size': 1024,
    'conv_kernel1': 7,
    'conv_kernel2': 3,
    'conv_g1': 256,
    'conv_g2': 1,
    'hidden_size': 512,
    'encoder_n_layer': 5,
    'decoder_n_layer': 5,
    'decoders': ['al', 'ph'],
    'update_probs': 0.7,
    'al_loss_weight': 0.5,
    'peak_lr': 1e-4,
    'last_lr': 1e-6,
    'beta_1': 0.9,
    'beta_2': 0.95,
    'weight_decay': 0.1,
    'eps': 1e-08,
    'lr_warmup_perc': 0.1,
}

# Resolve each argument against the config and build the model.
model = NeuralDecoder(
    **{name: config.get(name, fallback) for name, fallback in model_defaults.items()}
)

In [None]:
# Collect (element count, requires_grad) once per parameter, then derive both totals.
param_counts = [(p.numel(), p.requires_grad) for p in model.parameters()]
total_params = sum(count for count, _ in param_counts)
trainable_params = sum(count for count, needs_grad in param_counts if needs_grad)

print("\nModel parameters:")
print(f"  Total: {total_params:,}")
print(f"  Trainable: {trainable_params:,}")

In [None]:
# Checkpointing policy: rank checkpoints by the logged "wer" metric (lower is
# better) and keep the three best, plus the most recent one.
# NOTE(review): assumes the model logs a metric named exactly "wer" — confirm
# against NeuralDecoder's validation step.
checkpoint_callback = ModelCheckpoint(
    dirpath=OUTPUT_DIR,
    filename="final2-{epoch:02d}-{wer:.4f}",  # epoch and metric value are embedded in the file name
    monitor="wer",
    mode="min",
    save_top_k=3,
    save_last=True,
)

In [None]:
# Log the learning rate at every optimizer step so the schedule is visible in the logger.
lr_monitor = LearningRateMonitor(logging_interval='step')

In [None]:
# TensorBoard event files go under OUTPUT_DIR/final2_logs/, next to the checkpoints.
logger = TensorBoardLogger(
    save_dir=OUTPUT_DIR,
    name="final2_logs",
)

In [None]:
# Trainer setup. Settings read with config.get(...) fall back to the literal
# shown when the key is absent from the YAML.
trainer = L.Trainer(
    max_epochs=config.get('max_epochs', 100),
    accelerator="auto",  # let Lightning choose the available accelerator
    devices=1,  # single-device run
    logger=logger,
    callbacks=[checkpoint_callback, lr_monitor],
    gradient_clip_val=config.get('gradient_clip_val', 1.0),
    accumulate_grad_batches=config.get('accumulate_grad_batches', 1),
    precision=config.get('precision', '32'),
    log_every_n_steps=10,
)

In [None]:
# Announce the run length (same config lookup and fallback as the Trainer) followed by a rule.
print(
    f"\nStarting training for {config.get('max_epochs', 100)} epochs...",
    "=" * 70,
    sep="\n",
)

In [None]:
# Run training; batches come from `datamodule`, and the callbacks configured on
# the trainer handle checkpointing and learning-rate logging.
trainer.fit(model, datamodule=datamodule)

In [None]:
# Completion banner, including the path of the best checkpoint recorded by the callback.
print(
    "\n" + "=" * 70,
    "TRAINING COMPLETE!",
    "=" * 70,
    f"Best model saved to: {checkpoint_callback.best_model_path}",
    "=" * 70,
    sep="\n",
)

# Language Model Evaluation

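Next, we decode the test set in two ways — greedy decoding as a baseline, and beam search guided by a phoneme language model — and compare their phoneme error rates (PER) to see whether the language model improves accuracy.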

In [None]:
# Install KenLM if not already installed
try:
    import kenlm
    print("✓ KenLM already installed")
except ImportError:
    print("Installing KenLM...")
    !pip install https://github.com/kpu/kenlm/archive/master.zip
    import kenlm
    print("✓ KenLM installed successfully")

# Import LM module
from neural_decoder.phoneme_lm import PhonemeLM, beam_search_decode, create_phoneme_map
from edit_distance import SequenceMatcher
import numpy as np

# Load the language model
# LM_PATH is relative, so the ARPA file must be in the kernel's working directory.
LM_PATH = "phoneme_lm.arpa"
# The map from create_phoneme_map() is handed to PhonemeLM unchanged.
phoneme_map = create_phoneme_map()
lm = PhonemeLM(LM_PATH, phoneme_map=phoneme_map)
print(f"✓ Language Model loaded from: {LM_PATH}")

In [None]:
# Load best model checkpoint
# best_model_path is an empty string until the callback has saved a checkpoint
# (e.g. the callback cell was re-run without training). Fail with a clear
# message instead of an opaque file-loading error.
best_ckpt_path = checkpoint_callback.best_model_path
if not best_ckpt_path:
    raise RuntimeError(
        "No best checkpoint recorded: run trainer.fit(...) in this session "
        "before the evaluation cells."
    )

best_model = NeuralDecoder.load_from_checkpoint(best_ckpt_path)
best_model.eval()  # inference mode for dropout / normalisation layers
device = 'cuda' if torch.cuda.is_available() else 'cpu'
best_model = best_model.to(device)

print(f"✓ Loaded best model from: {best_ckpt_path}")
print(f"Using device: {device}")

In [None]:
# Evaluate with Language Model
print("=" * 70, "EVALUATING WITH LANGUAGE MODEL", "=" * 70, sep="\n")

# Decoding hyperparameters for the LM-guided beam search.
LM_WEIGHT = 0.8  # Can adjust: 0.4-1.2
BEAM_WIDTH = 10  # Can adjust: 5-50
print(f"LM Weight: {LM_WEIGHT}", f"Beam Width: {BEAM_WIDTH}\n", sep="\n")

# Per-utterance accumulators, filled in parallel by the evaluation loop.
all_predictions_baseline, all_predictions_lm, all_targets = [], [], []

# Get test dataloader
test_dataloader = datamodule.test_dataloader()

# Blank token id, shared by the greedy baseline and the beam search so the two
# decoders cannot drift apart.
BLANK_ID = 0

# Decode every test utterance twice: greedy (baseline) and LM-guided beam search.
# Predictions and targets are appended in lockstep to the three accumulators.
with torch.no_grad():
    for batch_idx, batch in enumerate(test_dataloader):
        # Move model inputs to the device; targets are only read back as lists.
        x = batch['x'].to(device)
        day_idx = batch['dayIndex'].to(device)
        y = batch['y']
        y_len = batch['yLen']

        # Forward pass - get phoneme predictions
        outputs = best_model(x, day_idx)
        logits = outputs['ph']  # Phoneme logits
        log_probs = torch.log_softmax(logits, dim=-1)

        # NOTE(review): each sample is decoded over the full time axis of the
        # batch tensor. If inputs are padded, trailing padded frames are decoded
        # as well — confirm whether the batch exposes input lengths to trim to.
        for i in range(len(y)):
            lp = log_probs[i]  # [T, V]

            # Baseline: greedy decoding. Take the best token per frame, merge
            # consecutive repeats, then drop blanks. Doing this on the tensor
            # replaces the per-token Python loop and yields plain ints, the
            # same element type as the targets.
            greedy = torch.argmax(lp, dim=-1)
            collapsed = torch.unique_consecutive(greedy)
            decoded_baseline = collapsed[collapsed != BLANK_ID].tolist()
            all_predictions_baseline.append(decoded_baseline)

            # With LM: Beam search
            decoded_lm = beam_search_decode(
                lp,
                lm=lm,
                lm_weight=LM_WEIGHT,
                beam_width=BEAM_WIDTH,
                blank_id=BLANK_ID,
                topk_acoustic=5
            )
            all_predictions_lm.append(decoded_lm)

            # Target: the first y_len[i] entries of row i are the real labels.
            target = y[i, :y_len[i]].tolist()
            all_targets.append(target)

        if (batch_idx + 1) % 10 == 0:
            print(f"  Processed {batch_idx + 1}/{len(test_dataloader)} batches")

print("✓ Evaluation complete!")

In [None]:
# Compute PER
def compute_error_rate(predictions, targets):
    """Return the corpus-level error rate of `predictions` against `targets`.

    The rate is the sum of per-pair edit distances (as reported by
    ``SequenceMatcher.distance()``) divided by the summed target lengths.
    Pairs are formed with ``zip``, so surplus items in the longer input are
    ignored.

    Args:
        predictions: iterable of predicted token sequences.
        targets: iterable of reference token sequences.

    Returns:
        The error rate as a float, or 0.0 when the paired targets contain no
        tokens at all.
    """
    edit_sum = 0
    ref_tokens = 0
    for hyp, ref in zip(predictions, targets):
        edit_sum += SequenceMatcher(a=ref, b=hyp).distance()
        ref_tokens += len(ref)

    # Guard the division: no reference tokens means there is nothing to score.
    if ref_tokens > 0:
        return edit_sum / ref_tokens
    return 0.0

# Score both decoders against the same targets.
per_baseline = compute_error_rate(all_predictions_baseline, all_targets)
per_lm = compute_error_rate(all_predictions_lm, all_targets)

# Positive gain means the LM lowered the PER.
abs_gain = per_baseline - per_lm

# The relative improvement is undefined when the baseline PER is 0 (perfect
# baseline, or no targets were scored); avoid the ZeroDivisionError and report
# it as not applicable.
if per_baseline > 0:
    improvement = abs_gain / per_baseline * 100
    improvement_text = f"{improvement:.2f}% relative"
else:
    improvement = 0.0
    improvement_text = "n/a (baseline PER is 0)"

print("\n" + "=" * 70)
print("FINAL RESULTS")
print("=" * 70)
print(f"\nBaseline (Greedy):")
print(f"  PER: {per_baseline:.4f} ({per_baseline*100:.2f}%)")
print(f"\nWith Language Model:")
print(f"  PER: {per_lm:.4f} ({per_lm*100:.2f}%)")
print(f"  Improvement: {improvement_text}")
print(f"  Absolute gain: {abs_gain*100:.2f} percentage points")
print("\n" + "=" * 70)
# Only claim an improvement when the LM actually lowered the PER.
if abs_gain > 0:
    print(f"Language Model improved accuracy by {improvement:.1f}%!")
elif abs_gain < 0:
    print(f"Language Model did not help: PER rose by {-abs_gain*100:.2f} percentage points.")
else:
    print("Language Model left the PER unchanged.")
print("=" * 70)