In [None]:
# Step 1: Mount Google Drive so files stored there are reachable under /content/drive.
# NOTE: the google.colab package is only importable inside a Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

# Step 2: Set Working Directory
# Every project artifact is resolved relative to BASE_DIR on the mounted Drive.
import os
BASE_DIR = '/content/drive/MyDrive/speech_understanding_project'
# Checkpoint file that is later read with torch.load and fed to load_state_dict.
MODEL_PATH = os.path.join(BASE_DIR, 'trained_decoder.pt')
# Data folder under the project root (not referenced by the cells visible here).
DATA_DIR = os.path.join(BASE_DIR, 'data')


In [None]:
!pip install transformers sentencepiece torch

In [None]:
import torch
import torch.nn as nn
from transformers import AutoTokenizer

In [None]:
# 📍 Step 4: Define Tokenizers
# Text-side tokenizer: the pretrained "facebook/mbart-large-50" tokenizer, loaded by name.
# Its vocab_size sizes the model's output layer, and its BOS/EOS token ids
# start and stop the greedy decoding loop in translate_units.
text_tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-50")

class UnitTokenizer:
    """Minimal tokenizer for unit sequences written as whitespace-separated integers.

    Args:
        vocab_size: Number of distinct unit ids. Valid ids are
            ``0 .. vocab_size - 1``.
    """

    def __init__(self, vocab_size):
        self.vocab_size = vocab_size

    def encode(self, seq):
        """Parse a whitespace-separated string of unit ids into a list of ints.

        Args:
            seq: String such as ``"23 45 12"``. An empty or all-whitespace
                string yields an empty list.

        Returns:
            List of integer unit ids, in input order.

        Raises:
            ValueError: If a token is not an integer, or if an id falls
                outside ``[0, vocab_size)``.
        """
        ids = [int(token) for token in seq.split()]
        # Reject out-of-vocabulary ids here with a readable message; otherwise
        # they only fail later as an index error inside the embedding lookup.
        for unit_id in ids:
            if not 0 <= unit_id < self.vocab_size:
                raise ValueError(
                    f"unit id {unit_id} is outside the valid range [0, {self.vocab_size})"
                )
        return ids

    def decode(self, ids):
        """Join an iterable of unit ids into a single space-separated string."""
        return " ".join(map(str, ids))

# Shared tokenizer instance for unit sequences. Its vocab_size should agree with
# the unit_vocab_size the SeqDecoder is built with, since the encoded ids index
# the model's unit embedding table.
unit_tokenizer = UnitTokenizer(vocab_size=100)


In [None]:
# Step 5: Define Model Architecture
class SeqDecoder(nn.Module):
    """Transformer encoder-decoder mapping unit ids to text-token logits.

    Unit ids are embedded and passed through a Transformer encoder. Text-token
    ids are embedded and passed through a Transformer decoder that uses causal
    self-attention and cross-attends to the encoder output. A final linear
    layer projects decoder states to text-vocabulary logits.

    All tensors are batch-first: ids are ``(batch, seq_len)`` and the returned
    logits are ``(batch, tgt_len, text_vocab_size)``.

    NOTE(review): no positional encoding is added to either embedding.

    Args:
        unit_vocab_size: Size of the unit (source) embedding table.
        text_vocab_size: Size of the text (target) embedding table and of the
            output projection.
        d_model: Embedding / hidden width.
        nhead: Number of attention heads per layer.
        num_layers: Number of encoder layers and of decoder layers.
    """

    def __init__(self, unit_vocab_size, text_vocab_size, d_model=256, nhead=4, num_layers=4):
        super().__init__()
        self.unit_emb = nn.Embedding(unit_vocab_size, d_model)
        self.text_emb = nn.Embedding(text_vocab_size, d_model)
        # batch_first=True because forward() receives (batch, seq) ids and sizes
        # the causal mask from tgt.size(1). The flag adds no parameters, so the
        # state_dict keys and shapes are unchanged.
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True),
            num_layers=num_layers
        )
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead, batch_first=True),
            num_layers=num_layers
        )
        self.out_proj = nn.Linear(d_model, text_vocab_size)

    def forward(self, src, tgt):
        """Compute text-token logits for every target position.

        Args:
            src: Long tensor of unit ids, shape ``(batch, src_len)``. Id 0 is
                treated as padding and masked out of encoder self-attention.
            tgt: Long tensor of text-token ids, shape ``(batch, tgt_len)``.

        Returns:
            Float tensor of logits, shape ``(batch, tgt_len, text_vocab_size)``.
        """
        # The padding mask must be derived from the raw ids, before `src` is
        # replaced by its embeddings; it then has the (batch, src_len) shape
        # that src_key_padding_mask requires.
        src_padding_mask = src == 0
        # Causal mask: position i may only attend to positions <= i.
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt.size(1)).to(tgt.device)
        memory = self.encoder(self.unit_emb(src), src_key_padding_mask=src_padding_mask)
        # NOTE(review): the decoder is not given a memory_key_padding_mask, so it
        # can attend to padded source positions — confirm against the training code.
        output = self.decoder(self.text_emb(tgt), memory, tgt_mask=tgt_mask)
        return self.out_proj(output)


In [None]:
# 📍 Step 6: Load Model
# Run on the GPU when one is available; otherwise fall back to the CPU so the
# notebook still works on a CPU-only runtime.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The unit vocabulary size comes from the tokenizer so the two cannot drift apart.
model = SeqDecoder(
    unit_vocab_size=unit_tokenizer.vocab_size,
    text_vocab_size=text_tokenizer.vocab_size,
).to(DEVICE)

# map_location lets a checkpoint saved from a GPU load on whichever device is in use.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust
# (consider passing weights_only=True on PyTorch versions that support it).
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.eval()  # inference mode: disables dropout
print("Model Loaded")


In [None]:
# Step 7: Inference Function
def translate_units(unit_sequence, max_length=50):
    """Greedily decode text from a string of whitespace-separated unit ids.

    The unit ids are encoded once. The target sequence starts from the text
    tokenizer's BOS token and grows by the arg-max token at each step until
    EOS is produced or ``max_length`` tokens have been generated.

    Args:
        unit_sequence: String of integer unit ids, e.g. ``"23 45 12"``.
        max_length: Maximum number of tokens to generate after BOS.

    Returns:
        The decoded text with special tokens removed.

    Raises:
        ValueError: If ``unit_sequence`` contains no unit ids.
    """
    unit_ids = unit_tokenizer.encode(unit_sequence)
    if not unit_ids:
        raise ValueError("unit_sequence must contain at least one unit id")

    # Place inputs on whatever device the model lives on instead of assuming CUDA.
    device = next(model.parameters()).device
    src = torch.tensor([unit_ids], dtype=torch.long, device=device)

    # Start with decoder input as BOS token
    tgt_ids = torch.tensor([[text_tokenizer.bos_token_id]], dtype=torch.long, device=device)
    eos_id = text_tokenizer.eos_token_id

    with torch.no_grad():
        for _ in range(max_length):
            logits = model(src, tgt_ids)
            # Only the logits of the last position decide the next token.
            next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            if next_token.item() == eos_id:
                break
            tgt_ids = torch.cat((tgt_ids, next_token), dim=1)

    # Index the single batch row rather than squeeze(), so the result is always
    # a list even when nothing beyond BOS was generated.
    return text_tokenizer.decode(tgt_ids[0].tolist(), skip_special_tokens=True)


In [None]:
# Example pseudo-phoneme sequence
# Space-separated integer unit ids — the format UnitTokenizer.encode parses.
example_units = "23 45 12 7 9 4 3 6 23 45 12 3 9 2"
# Greedy decoding with the default max_length of 50 generated tokens.
translated_text = translate_units(example_units)
# "HRL" presumably stands for high-resource language — TODO confirm.
print("Translated HRL Text:\n", translated_text)
