# Final2 Model Training (Mamba-based B2T)

This notebook trains the Final2 (NeuralDecoder) model using a simplified Mamba-style architecture (gated depthwise-convolution layers; no selective state-space scan).

**Architecture:**
- FeatureExtractor: two stride-2 Conv1d layers, each followed by a 2-layer Highway network
- Encoder: 5-layer bidirectional Mamba
- Decoder: 5-layer bidirectional Mamba
- Output: Phoneme logits (40 phoneme classes + 1 CTC blank = 41 outputs)

## 1. Setup and Imports

In [None]:
# Install required packages (uncomment if needed)
# !pip install edit-distance

import os
import sys
import pickle
import time
import random
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from edit_distance import SequenceMatcher
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset

# Report the runtime environment (query CUDA once and reuse the answer).
_cuda_ok = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {_cuda_ok}")
if _cuda_ok:
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

## 2. Configuration

In [None]:
# Configuration (keyword form; insertion order matches the key listing below)
config = dict(
    variant='final2',
    batchSize=32,  # Reduced for debugging
    nBatch=5000,   # Reduced for debugging
    seed=0,

    # Architecture
    nInputFeatures=256,
    nClasses=40,        # phoneme classes; the model's output layer adds one CTC blank
    conv_size=1024,
    conv_kernel1=7,
    conv_kernel2=3,
    conv_g1=256,
    conv_g2=1,
    hidden_size=512,
    encoder_n_layer=5,
    decoder_n_layer=5,
    decoders=['ph'],    # Only phoneme decoder for simplicity
    update_probs=0.7,   # mamba_block turns this into dropout p = 1 - update_probs

    # Optimizer
    lrStart=0.0001,     # Very conservative learning rate
    lrEnd=0.00001,
    l2_decay=0.00001,

    # Augmentation
    # NOTE(review): no cell in this notebook currently reads these two values.
    whiteNoiseSD=0.0,   # Disabled for debugging
    constantOffsetSD=0.0,
)

if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
print(f"Using device: {DEVICE}")

## 3. Dataset Class

In [None]:
class SpeechDataset(Dataset):
    """Flatten per-day trial lists into a single indexable dataset of trials.

    Args:
        data: sequence of per-day dicts. Each dict holds parallel per-trial
            lists under "sentenceDat" (feature arrays whose axis 0 is counted
            as time bins), "phonemes" and "phoneLens".

    Each item is a tuple of tensors:
    ``(features float32, phonemes int32, n_time_bins int32, phone_len int32, day int64)``.
    """

    def __init__(self, data):
        self.data = data
        self.n_days = len(data)
        self.n_trials = sum(len(day_data["sentenceDat"]) for day_data in data)

        # Parallel per-trial lists, addressed by the flat trial index.
        self.neural_feats = []
        self.phone_seqs = []
        self.neural_time_bins = []
        self.phone_seq_lens = []
        self.days = []

        for day_idx in range(self.n_days):
            day_data = data[day_idx]
            for trial_idx, feats in enumerate(day_data["sentenceDat"]):
                self.neural_feats.append(feats)
                self.phone_seqs.append(day_data["phonemes"][trial_idx])
                self.neural_time_bins.append(feats.shape[0])
                self.phone_seq_lens.append(day_data["phoneLens"][trial_idx])
                self.days.append(day_idx)

    def __len__(self):
        return self.n_trials

    def __getitem__(self, idx):
        feats = torch.tensor(self.neural_feats[idx], dtype=torch.float32)
        phones = torch.tensor(self.phone_seqs[idx], dtype=torch.int32)
        n_bins = torch.tensor(self.neural_time_bins[idx], dtype=torch.int32)
        n_phones = torch.tensor(self.phone_seq_lens[idx], dtype=torch.int32)
        day = torch.tensor(self.days[idx], dtype=torch.int64)
        return feats, phones, n_bins, n_phones, day

def collate_fn(batch):
    """Merge a list of SpeechDataset items into one batch.

    Variable-length features and phoneme sequences are zero-padded along their
    first axis (``batch_first=True``); the scalar length and day tensors are
    stacked into 1-D tensors.

    Returns:
        (padded_features, padded_phonemes, feature_lens, phoneme_lens, days)
    """
    feats, phones, feat_lens, phone_lens, day_ids = zip(*batch)
    padded_feats, padded_phones = (
        pad_sequence(seqs, batch_first=True, padding_value=0) for seqs in (feats, phones)
    )
    return (
        padded_feats,
        padded_phones,
        *(torch.stack(scalars) for scalars in (feat_lens, phone_lens, day_ids)),
    )

## 4. Module Definitions

In [None]:
# Phoneme vocabulary (40 classes), written as one whitespace-separated string.
# NOTE(review): the CTC loss in this notebook uses class 0 as the blank and the model
# emits nClasses + 1 logits, so model class k presumably maps to phoneme_vocab[k - 1].
# Confirm against how the dataset encodes its "phonemes" ids before decoding to text.
phoneme_vocab = (
    "| AA AE AH AO AW AY B CH D DH "
    "EH ER EY F G HH IH IY JH K L "
    "M N NG OW OY P R S SH T TH "
    "UH UW V W Y Z ZH"
).split()

print(f"Vocab size: {len(phoneme_vocab)} phonemes")

In [None]:
class MambaLayer(nn.Module):
    """Pre-norm gated depthwise-convolution layer with a residual connection.

    NOTE: despite the name, this is a simplified stand-in for a Mamba block. It
    contains no selective state-space scan; mixing along time comes only from
    a kernel-3, non-causal (padding=1 on both sides) depthwise ``Conv1d``.

    Args:
        d_model: feature size of the input and output.
        expand_factor: width multiplier for the inner gated projection.
    """

    def __init__(self, d_model, expand_factor=2):
        super(MambaLayer, self).__init__()
        self.d_model = d_model
        self.d_inner = d_model * expand_factor

        # One projection produces both the conv branch and its gate (split in forward).
        self.in_proj = nn.Linear(d_model, self.d_inner * 2)
        self.conv1d = nn.Conv1d(self.d_inner, self.d_inner, kernel_size=3, padding=1, groups=self.d_inner)
        self.out_proj = nn.Linear(self.d_inner, d_model)
        self.activation = nn.SiLU()
        self.norm = nn.LayerNorm(d_model)

        # Tiny projection weights keep each layer close to the identity (x + ~0) at init.
        nn.init.xavier_uniform_(self.in_proj.weight, gain=0.01)
        nn.init.xavier_uniform_(self.out_proj.weight, gain=0.01)
        nn.init.zeros_(self.in_proj.bias)
        nn.init.zeros_(self.out_proj.bias)

    def forward(self, x):
        """Map ``x`` of shape (batch, time, d_model) to the same shape."""
        residual = x
        h, gate = self.in_proj(self.norm(x)).chunk(2, dim=-1)
        # Conv1d expects (batch, channels, time): move features to dim 1 and back.
        h = self.conv1d(h.transpose(1, 2)).transpose(1, 2)
        h = self.activation(h) * torch.sigmoid(gate)
        return self.out_proj(h) + residual


class mamba_block(nn.Module):
    """Stack of ``n_layer`` MambaLayers, optionally paired with a reverse-time stack.

    With ``bidirectional=True``, the forward stack and a separate backward stack
    both read the block's input. The backward stack sees the input flipped along
    dim 1 (time), and its output is flipped back before the two are averaged.

    Args:
        d_model: feature size.
        n_layer: number of MambaLayers per direction.
        bidirectional: whether to add the reverse-time stack.
        update_probs: when < 1.0, dropout with p = 1 - update_probs is applied
            after every layer (on the full residual stream).
            NOTE(review): the name suggests layer-drop / stochastic depth may
            have been intended — confirm against the reference implementation.

    ``lens`` is passed through unchanged; padded time steps are not masked.
    """

    def __init__(self, d_model, n_layer=1, bidirectional=False, update_probs=0.7):
        super(mamba_block, self).__init__()
        self.d_model = d_model
        self.n_layer = n_layer
        self.bidirectional = bidirectional

        self.forward_layers = nn.ModuleList([MambaLayer(d_model) for _ in range(n_layer)])
        if bidirectional:
            self.backward_layers = nn.ModuleList([MambaLayer(d_model) for _ in range(n_layer)])
        self.dropout = nn.Dropout(1.0 - update_probs) if update_probs < 1.0 else None

    def _run_stack(self, layers, x):
        """Apply ``layers`` in order, with optional dropout after each one."""
        for layer in layers:
            x = layer(x)
            if self.dropout is not None:
                x = self.dropout(x)
        return x

    def forward(self, x, lens):
        out = self._run_stack(self.forward_layers, x)
        if self.bidirectional:
            # The backward stack must consume the block INPUT (time-reversed),
            # not the forward stack's output; otherwise the two "directions"
            # are just one deeper serial stack averaged with its own input.
            backward = self._run_stack(self.backward_layers, torch.flip(x, dims=[1]))
            backward = torch.flip(backward, dims=[1])
            out = (out + backward) / 2
        return out, lens


print("✓ Mamba modules defined")

In [None]:
class Highway(nn.Module):
    """Stack of highway layers operating on the last dimension.

    Each layer computes ``t * relu(W_h x) + (1 - t) * x`` with
    ``t = sigmoid(W_t x)``, so the gate blends a transformed signal with the
    untouched input. ``lens`` is returned unchanged.
    """

    def __init__(self, input_dim, num_layers=1):
        super().__init__()
        self.input_dim = input_dim
        self.num_layers = num_layers
        self.layers = nn.ModuleList(
            nn.ModuleDict({
                'transform': nn.Linear(input_dim, input_dim),
                'gate': nn.Linear(input_dim, input_dim),
            })
            for _ in range(num_layers)
        )

    def forward(self, x, lens):
        out = x
        for block in self.layers:
            carry_gate = torch.sigmoid(block['gate'](out))
            candidate = torch.relu(block['transform'](out))
            out = carry_gate * candidate + (1 - carry_gate) * out
        return out, lens

class conv_block(nn.Module):
    """Conv1d -> BatchNorm1d -> ReLU over the time axis, with length tracking.

    Input is (batch, time, channels); it is transposed to channels-first for the
    convolution and back afterwards. ``lens`` is mapped through the standard
    Conv1d output-length formula and clamped to at least 1.
    """

    def __init__(self, input_dims, output_dims, kernel_size, stride, groups=1):
        super().__init__()
        self.stride = stride
        self.kernel_size = kernel_size
        self.conv = nn.Conv1d(
            input_dims,
            output_dims,
            kernel_size=kernel_size,
            stride=stride,
            padding=(kernel_size - 1) // 2,
            groups=groups,
        )
        self.bn = nn.BatchNorm1d(output_dims)
        self.relu = nn.ReLU()

    def forward(self, x, lens):
        y = self.relu(self.bn(self.conv(x.transpose(1, 2)))).transpose(1, 2)
        pad = (self.kernel_size - 1) // 2
        out_lens = ((lens + 2 * pad - self.kernel_size) // self.stride + 1).clamp(min=1)
        return y, out_lens

class UnPack(nn.Module):
    """Identity layer in the ``(x, lens) -> (x, lens)`` interface.

    Despite the name, it performs no unpacking; both inputs are returned as-is.
    """

    def forward(self, x, lens):
        passthrough = (x, lens)
        return passthrough

print("✓ Support modules defined")

## 5. Model Definition

In [None]:
class ModuleStack(nn.Module):
    """Build and run a sequential stack of ``(x, lens) -> (x, lens)`` layers.

    Each spec in ``layers`` is a list whose first element selects the type:

    * ``["mamba", d_model, n_layer, bidirectional, update_probs]``
    * ``["highway", dim, n_layer]``
    * ``["conv", input_dims, output_dims, kernel_size, stride, groups]``
    * ``["unpack"]``

    ``output_dims`` records the feature size set by the last mamba/highway/conv
    spec; ``"unpack"`` does not change it.

    Raises:
        ValueError: if a spec names an unknown layer type, or if no spec sets
            ``output_dims`` (for example an empty list or only ``"unpack"``).
    """

    def __init__(self, layers):
        super(ModuleStack, self).__init__()
        modules_list = []
        self.output_dims = 0

        for spec in layers:
            kind = spec[0]
            if kind == "mamba":
                _, in_channels, n_layer, bidirectional, update_probs = spec
                modules_list.append(mamba_block(d_model=in_channels, n_layer=n_layer, bidirectional=bidirectional, update_probs=update_probs))
                self.output_dims = in_channels
            elif kind == "highway":
                _, in_channels, n_layer = spec
                modules_list.append(Highway(input_dim=in_channels, num_layers=n_layer))
                self.output_dims = in_channels
            elif kind == "conv":
                _, input_dims, output_dims, kernel_size, stride, groups = spec
                modules_list.append(conv_block(input_dims, output_dims, kernel_size, stride, groups))
                self.output_dims = output_dims
            elif kind == "unpack":
                modules_list.append(UnPack())
            else:
                # Previously unknown types were silently skipped, hiding typos in specs.
                raise ValueError(f"ModuleStack: unknown layer type {kind!r} in spec {spec!r}")

        self.layers = nn.ModuleList(modules_list)
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if self.output_dims == 0:
            raise ValueError("ModuleStack: no mamba/highway/conv layer defines output_dims")

    def forward(self, hidden_states, lens):
        """Apply every layer in order, threading the lengths tensor through."""
        for layer in self.layers:
            hidden_states, lens = layer(hidden_states, lens)
        return hidden_states, lens


print("✓ ModuleStack defined")

In [None]:
class SimpleNeuralDecoder(nn.Module):
    """Simplified NeuralDecoder for Colab.

    Pipeline: conv/highway feature extractor (two stride-2 convolutions), then a
    Mamba-style encoder stack, a Mamba-style decoder stack, and a linear head
    producing ``nClasses + 1`` logits per output frame. The extra class is the
    CTC blank.

    Args:
        config: dict providing 'conv_size', 'conv_kernel1', 'conv_kernel2',
            'conv_g1', 'conv_g2', 'hidden_size', 'encoder_n_layer',
            'decoder_n_layer', 'update_probs' and 'nClasses'. The input feature
            size is read from 'nInputFeatures' when present and otherwise
            defaults to 256 (the value previously hard-coded here).
    """

    def __init__(self, config):
        super(SimpleNeuralDecoder, self).__init__()
        n_input_features = config.get('nInputFeatures', 256)

        # Feature extractor
        self.feature_extractor = ModuleStack([
            ["unpack"],
            ["conv", n_input_features, config['conv_size'], config['conv_kernel1'], 2, config['conv_g1']],
            ["highway", config['conv_size'], 2],
            ["conv", config['conv_size'], config['hidden_size'], config['conv_kernel2'], 2, config['conv_g2']],
            ["highway", config['hidden_size'], 2],
        ])

        # Encoder
        self.encoder = ModuleStack([
            ["mamba", config['hidden_size'], config['encoder_n_layer'], True, config['update_probs']]
        ])

        # Decoder
        self.decoder = ModuleStack([
            ["mamba", config['hidden_size'], config['decoder_n_layer'], True, config['update_probs']]
        ])

        # Output layer (40 phonemes + 1 blank for CTC)
        self.linear = nn.Linear(config['hidden_size'], config['nClasses'] + 1)

        # Per-sample output lengths from the most recent forward() call (for CTC).
        self._last_output_lens = None

    def forward(self, x, day_idx=None, x_lens=None):
        """Decode a padded batch of neural features into per-frame logits.

        Args:
            x: padded features of shape (batch, time, features).
            day_idx: accepted for interface compatibility; not used.
            x_lens: optional per-sample valid lengths along the time axis,
                before downsampling. When omitted, every sample is treated as
                spanning the full padded length ``x.shape[1]``. This over-counts
                frames for shorter sequences, so pass it when lengths matter
                (e.g. CTC).

        Returns:
            Logits of shape (batch, time', nClasses + 1). The matching
            per-sample output lengths are stored in ``self._last_output_lens``.
        """
        if x_lens is None:
            input_lens = torch.full((x.shape[0],), x.shape[1], device=x.device, dtype=torch.int32)
        else:
            input_lens = torch.as_tensor(x_lens, device=x.device).to(torch.int32)

        x, lens = self.feature_extractor(x, input_lens)
        x, lens = self.encoder(x, lens)
        x, lens = self.decoder(x, lens)
        logits = self.linear(x)

        self._last_output_lens = lens
        return logits


print("✓ SimpleNeuralDecoder defined")

## 6. Load Data

In [None]:
# IMPORTANT: Update this path to your data location
# For Colab, upload the data first or mount Google Drive
DATASET_PATH = "/home/ivansit1214/competitionData/ptDecoder_ctc"  # Update this!

# For Google Drive (uncomment if using Drive):
# from google.colab import drive
# drive.mount('/content/drive')
# DATASET_PATH = '/content/drive/MyDrive/competitionData/ptDecoder_ctc'

print(f"Loading data from: {DATASET_PATH}")
# NOTE(security): pickle.load can execute arbitrary code embedded in the file;
# only point DATASET_PATH at a file you trust.
with open(DATASET_PATH, "rb") as dataset_file:
    data = pickle.load(dataset_file)

train_ds, test_ds = SpeechDataset(data["train"]), SpeechDataset(data["test"])

# Settings shared by both loaders; only shuffling differs.
_loader_kwargs = dict(
    batch_size=config['batchSize'],
    num_workers=0,
    pin_memory=True,
    collate_fn=collate_fn,
)
train_loader = DataLoader(train_ds, shuffle=True, **_loader_kwargs)
test_loader = DataLoader(test_ds, shuffle=False, **_loader_kwargs)

print(f"✓ Train samples: {len(train_ds)}")
print(f"✓ Test samples: {len(test_ds)}")
print(f"✓ Number of days: {len(data['train'])}")

## 7. Create Model

In [None]:
# Seed every RNG the notebook draws from (torch, NumPy, Python `random`), in that order.
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    _seed_fn(config['seed'])

# Create model
print("Creating model...")
model = SimpleNeuralDecoder(config).to(DEVICE)

_params = list(model.parameters())
total_params = sum(p.numel() for p in _params)
trainable_params = sum(p.numel() for p in _params if p.requires_grad)
print(f"✓ Total parameters: {total_params:,}")
print(f"✓ Trainable parameters: {trainable_params:,}")

# Loss and optimizer: CTC with class 0 as the blank; infinite losses are zeroed.
loss_ctc = nn.CTCLoss(blank=0, reduction="mean", zero_infinity=True)
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=config['lrStart'],
    weight_decay=config['l2_decay'],
)
# Linear decay from lrStart to lrEnd over nBatch scheduler steps.
scheduler = torch.optim.lr_scheduler.LinearLR(
    optimizer,
    start_factor=1.0,
    end_factor=config['lrEnd'] / config['lrStart'],
    total_iters=config['nBatch'],
)

print(f"✓ Optimizer: Adam, lr={config['lrStart']}")
print("✓ Model ready to train!")

## 8. Test Forward Pass (DEBUG)

In [None]:
# Smoke-test the forward pass on a single training batch (no gradients).
print("Testing forward pass...")
model.eval()
with torch.no_grad():
    X, y, X_len, y_len, day_idx = next(iter(train_loader))
    X = X.to(DEVICE)
    print(f"Input shape: {X.shape}")

    pred = model(X)
    out_min, out_max = pred.min().item(), pred.max().item()
    for line in (
        f"Output shape: {pred.shape}",
        f"Output lens: {model._last_output_lens}",
        f"Output min: {out_min:.4f}, max: {out_max:.4f}",
        "✓ Forward pass works!",
    ):
        print(line)

## 9. Training Loop

In [None]:
EVAL_EVERY = 100        # training batches between evaluations / timing reports
EVAL_MAX_BATCHES = 10   # test batches per evaluation (a subset, for speed)


def _conv_output_lengths(input_lens):
    """Return per-sample sequence lengths after the model's feature extractor.

    For every ``nn.Conv1d`` inside ``model.feature_extractor`` (in registration
    order), apply the Conv1d output-length formula
    ``(L + 2*padding - dilation*(kernel_size - 1) - 1) // stride + 1`` and clamp
    at 1, mirroring conv_block. The Mamba layers (kernel 3, padding 1,
    stride 1) and Highway layers keep the time axis unchanged.
    """
    lens = input_lens
    for module in model.feature_extractor.modules():
        if isinstance(module, nn.Conv1d):
            lens = (
                lens
                + 2 * module.padding[0]
                - module.dilation[0] * (module.kernel_size[0] - 1)
                - 1
            ) // module.stride[0] + 1
            lens = lens.clamp(min=1)
    return lens


def _infinite_batches(loader):
    """Yield batches endlessly, starting a new (reshuffled) pass after each epoch.

    Replaces calling ``next(iter(loader))`` every step, which rebuilt the
    DataLoader iterator on every batch.
    """
    if len(loader) == 0:
        raise ValueError("train_loader yields no batches")
    while True:
        yield from loader


def _evaluate(max_batches=EVAL_MAX_BATCHES):
    """Return (mean CTC loss, phoneme error rate) over the first test batches.

    Greedy CTC decoding: per-frame argmax, collapse repeats, drop blanks (0).
    Returns NaN for a metric when there is nothing to average over.
    """
    model.eval()
    losses, total_edit, total_len = [], 0, 0
    with torch.no_grad():
        for eval_idx, (feats, targets, feat_lens, target_lens, _) in enumerate(test_loader):
            if eval_idx >= max_batches:
                break
            feats, targets, feat_lens, target_lens = (
                t.to(DEVICE) for t in (feats, targets, feat_lens, target_lens)
            )
            logits = model(feats)
            out_lens = _conv_output_lengths(feat_lens)

            loss = loss_ctc(logits.log_softmax(2).permute(1, 0, 2), targets, out_lens, target_lens)
            losses.append(loss.item())

            for i in range(logits.shape[0]):
                decoded = torch.unique_consecutive(logits[i, :out_lens[i]].argmax(dim=-1))
                decoded = decoded[decoded != 0].cpu().tolist()
                target = targets[i, :target_lens[i]].cpu().tolist()
                total_edit += SequenceMatcher(a=target, b=decoded).distance()
                total_len += len(target)

    avg_loss = float(np.mean(losses)) if losses else float("nan")
    per = total_edit / total_len if total_len else float("nan")
    return avg_loss, per


print(f"\nStarting training for {config['nBatch']} batches...")
print("=" * 70)

test_loss_list = []
test_per_list = []
best_per = None
start_time = time.time()
train_batches = _infinite_batches(train_loader)

for batch_idx in range(config['nBatch']):
    model.train()

    # Get batch
    X, y, X_len, y_len, day_idx = next(train_batches)
    X, y, X_len, y_len = X.to(DEVICE), y.to(DEVICE), X_len.to(DEVICE), y_len.to(DEVICE)

    # Forward
    pred = model(X)
    # Per-sample output lengths derived from the true input lengths. When the
    # model is called as model(X), its _last_output_lens is computed from the
    # padded length X.shape[1] for every sample, which lets CTC align over padding.
    out_lens = _conv_output_lengths(X_len)

    # CTC loss expects (time, batch, classes) log-probabilities
    loss = loss_ctc(pred.log_softmax(2).permute(1, 0, 2), y, out_lens, y_len)

    # Check for NaN
    if torch.isnan(loss):
        print(f"\n⚠️  NaN loss at batch {batch_idx}!")
        print(f"   pred stats: min={pred.min():.4f}, max={pred.max():.4f}, mean={pred.mean():.4f}")
        print(f"   Stopping training...")
        break

    # Backward
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    scheduler.step()

    # Periodic evaluation on a subset of the test set
    if batch_idx % EVAL_EVERY == 0:
        avg_loss, per = _evaluate()
        elapsed = (time.time() - start_time) / EVAL_EVERY if batch_idx > 0 else 0.0
        current_lr = optimizer.param_groups[0]['lr']

        print(f"batch {batch_idx:5d}, loss: {avg_loss:.4f}, PER: {per:.4f}, lr: {current_lr:.6f}, time/batch: {elapsed:.3f}s")
        start_time = time.time()

        test_loss_list.append(avg_loss)
        test_per_list.append(per)

        if best_per is None or per < best_per:
            best_per = per
            print(f"  → New best PER: {per:.4f}")

print("\n" + "=" * 70)
print("TRAINING COMPLETE!")
print("=" * 70)
if best_per is not None:
    print(f"Best PER: {best_per:.4f} ({best_per*100:.2f}%)")

## 10. Plot Results

In [None]:
import matplotlib.pyplot as plt

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

# (axis, series, y-label, title) for each panel; every panel shares the x-label and grid.
_panels = (
    (ax1, test_loss_list, 'Test Loss', 'Test Loss Over Training'),
    (ax2, test_per_list, 'Phoneme Error Rate', 'PER Over Training'),
)
for ax, series, ylabel, title in _panels:
    ax.plot(series)
    ax.set_xlabel('Evaluation Step (x100 batches)')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.grid(True)

plt.tight_layout()
plt.show()

if test_per_list:
    print(f"Final PER: {test_per_list[-1]:.4f}")
else:
    print("No results yet")