# Biosignal to Phoneme Conversion Demo

This notebook trains two models that convert biosignals to phonemes, then runs inference with each:
1. **UniGRU + CTC**: The baseline model.
2. **Transformer**: An Encoder-Decoder attention-based model.

The dataset, model, and utility code has been refactored into modular Python files under `src/` for better maintainability.

In [None]:
import torch
import numpy as np
import sys
import os

# Add project root to path
sys.path.append(os.path.abspath('..'))

from src.dataset import SyntheticDataset, MAX_SEQUENCE_LENGTH
from src.models.gru import SensorToPhonemeGRU
from src.models.transformer import SensorToPhonemeTransformer
from src.utils import INDEX_TO_PH, greedy_decode, FEATURE_DIM, SOS_INDEX, NUM_PHONEMES

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

## 1. Train UniGRU + CTC Model
We will run the training script `src/train.py` as a module.

In [None]:
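# Run from the project root so that `python -m src.train` can resolve the package.
# The guard keeps this cell safe to re-run: it only moves up if `src/` is not already here.
if not os.path.isdir('src'):
    os.chdir('..')
!python -m src.train --model_type gru --epochs 5 --num_samples 1000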

## 2. Train Transformer Model
Now we train the Transformer model.

In [None]:
!python -m src.train --model_type transformer --epochs 5 --num_samples 1000
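
The `!python -m ...` shell escape used in the two training cells only works inside a notebook. As a minimal sketch, the same runs could be launched from a plain Python script with `subprocess`. The helper names `build_train_command` and `run_training` are hypothetical, and the sketch assumes `src.train` accepts exactly the flags shown above.

```python
import subprocess
import sys


def build_train_command(model_type, epochs=5, num_samples=1000):
    """Return the argv list for one training run (hypothetical helper)."""
    return [
        sys.executable, "-m", "src.train",
        "--model_type", model_type,
        "--epochs", str(epochs),
        "--num_samples", str(num_samples),
    ]


def run_training(model_type, **kwargs):
    """Run one training job from the project root (hypothetical helper)."""
    # check=True raises CalledProcessError if the training script exits non-zero.
    subprocess.run(build_train_command(model_type, **kwargs), check=True)
```

Calling `run_training("gru")` followed by `run_training("transformer")` from the project root would mirror the two cells above.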

## 3. Inference
Let's load the trained models and run inference on a sample sentence.

In [None]:
# Helper to generate a sample
ds = SyntheticDataset(num_samples=1)
sample_idx = 0
X, Y, L, L_target = ds[sample_idx]

# Prepare input
X_tensor = X.unsqueeze(0).to(device) # (1, T, F)
input_len = torch.tensor([L], dtype=torch.long).to(device)

print("Ground Truth:", [INDEX_TO_PH.get(int(y), "?") for y in Y])

### UniGRU Inference

In [None]:
model_gru = SensorToPhonemeGRU().to(device)
try:
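    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    model_gru.load_state_dict(torch.load("model_gru.pth", map_location=device))
    model_gru.eval()
    with torch.no_grad():
        logp = model_gru(X_tensor)
        # CTC decoding needs the output length, not the input length. The // 4
        # assumes the model downsamples the time axis by a factor of 4.
        input_len_reduced = (input_len // 4).clamp(min=1)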
        decoded = greedy_decode(logp, input_len_reduced)[0]
        print("UniGRU Prediction:", decoded)
except FileNotFoundError:
    print("Model file not found. Please run training first.")
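
The `greedy_decode` helper from `src.utils` is not shown in this notebook. For orientation, greedy CTC decoding usually takes the best class per frame, merges consecutive repeats, and drops blanks. The sketch below is framework-free and is not the project's implementation: the function name `ctc_greedy_collapse` is hypothetical, and blank index `0` is an assumption.

```python
def ctc_greedy_collapse(frame_scores, blank=0):
    """Greedy CTC decoding for a single utterance.

    frame_scores: list of per-frame score lists, shape (T, C).
    blank: index of the CTC blank symbol (assumed to be 0 here).
    """
    # 1. Pick the best-scoring class in every frame.
    best_path = [max(range(len(row)), key=row.__getitem__) for row in frame_scores]

    # 2. Merge consecutive repeats and 3. drop blanks.
    decoded = []
    previous = None
    for idx in best_path:
        if idx != previous and idx != blank:
            decoded.append(idx)
        previous = idx
    return decoded
```

A blank between two identical labels keeps them separate, which is how CTC represents a genuinely repeated phoneme.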

### Transformer Inference
The Transformer generates its output autoregressively: each predicted phoneme is appended to the decoder input for the next step.
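
Because the decoder is fed its own previous outputs, each position may attend only to itself and earlier positions. The cell below enforces this with a `tgt_mask`. As a framework-free sketch of what an additive causal mask looks like: `0.0` means attention is allowed and `-inf` means it is blocked. The function name `causal_mask` is hypothetical, and the project's `generate_square_subsequent_mask` may use a different representation, such as a boolean mask.

```python
def causal_mask(size):
    """Return a size x size additive attention mask as nested lists.

    Entry [row][col] is 0.0 when position `row` may attend to position `col`
    (col <= row) and -inf when `col` lies in the future.
    """
    neg_inf = float("-inf")
    return [
        [0.0 if col <= row else neg_inf for col in range(size)]
        for row in range(size)
    ]
```

Adding this mask to the attention scores before the softmax drives the weights of future positions to zero.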

In [None]:
model_tf = SensorToPhonemeTransformer().to(device)
try:
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    model_tf.load_state_dict(torch.load("model_transformer.pth", map_location=device))
    model_tf.eval()
    
    # Autoregressive decoding
    # Start with <SOS>
    decoder_input = torch.tensor([[SOS_INDEX]], dtype=torch.long).to(device)
    
    pred_seq = []
    max_len = 50
    
    with torch.no_grad():
        for _ in range(max_len):
            # nn.Transformer expects the full sequence generated so far, so the
            # causal mask is rebuilt at the current decoder length on every step.
            tgt_mask = model_tf.generate_square_subsequent_mask(decoder_input.size(1)).to(device)
            
            output = model_tf(X_tensor, decoder_input, tgt_mask=tgt_mask)
            
            # Get last token prediction
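            last_token_logits = output[:, -1, :]  # (B, C)
            next_token = last_token_logits.argmax(dim=-1).item()

            # <SOS> should not normally be predicted. If it is, stop: the decoder input
            # would not change, so every remaining step would repeat the same prediction.
            if next_token == SOS_INDEX:
                break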
                
            pred_seq.append(INDEX_TO_PH.get(next_token, "?"))
            
            # Append to input
            decoder_input = torch.cat([decoder_input, torch.tensor([[next_token]], device=device)], dim=1)
            
            # No early stop: training did not use an explicit EOS token, so
            # generation normally runs for the full max_len steps.
            
    print("Transformer Prediction:", pred_seq)
    
except FileNotFoundError:
    print("Model file not found. Please run training first.")
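
The loop above has no stop condition because the model was trained without an EOS token. If an end-of-sequence token were added to the vocabulary, the same greedy loop could stop early. The sketch below shows that control flow independent of any model. `greedy_generate`, `step_fn`, and `eos_index` are hypothetical names, and no EOS index exists in this project as written.

```python
def greedy_generate(step_fn, sos_index, eos_index, max_len=50):
    """Greedy autoregressive decoding with an early stop.

    step_fn: callable that takes the token list generated so far (starting
             with sos_index) and returns the next token id. In this notebook
             it would wrap the model call and the argmax over the last step.
    eos_index: hypothetical end-of-sequence id; generation stops when it appears.
    """
    tokens = [sos_index]
    for _ in range(max_len):
        next_token = step_fn(tokens)
        if next_token == eos_index:
            break
        tokens.append(next_token)
    return tokens[1:]  # drop the leading <SOS>
```

The `max_len` cap remains as a safeguard for the case where the model never emits the stop token.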