In [None]:
# Step 1: Mount Google Drive so the project folder is reachable from the runtime
import os

from google.colab import drive

drive.mount('/content/drive')

# Step 2: Set Working Directory
# BASE_DIR is the project root on Drive; DATA_DIR is its `data` sub-folder,
# created on first run if it does not exist yet.
BASE_DIR = '/content/drive/MyDrive/speech_understanding_project'
DATA_DIR = os.path.join(BASE_DIR, 'data')
os.makedirs(DATA_DIR, exist_ok=True)


In [None]:
!pip install torch torchvision torchaudio transformers sentencepiece


In [None]:

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
import random
from tqdm import tqdm


In [None]:
# Step 4: Load Data
# Format: <utt_id>\t<unit_seq>
def load_units(file_path):
    """Read the unit sequences from a `<utt_id>\\t<unit_seq>` file.

    Args:
        file_path: Path to a UTF-8 text file with one utterance per line, the
            utterance id and the unit sequence separated by a tab.

    Returns:
        A list with the second tab-separated field of every non-blank line,
        in file order. An utterance whose unit field is empty yields "".

    Raises:
        ValueError: If a non-blank line contains no tab separator.
    """
    units = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line_no, raw in enumerate(f, start=1):
            # Only drop the line terminator here: a full strip() would also
            # remove a trailing tab and make an empty unit field look malformed.
            line = raw.rstrip('\r\n')
            if not line.strip():
                # Blank lines (typically a trailing newline) carry no utterance.
                continue
            fields = line.split('\t')
            if len(fields) < 2:
                raise ValueError(
                    f"{file_path}:{line_no}: expected '<utt_id>\\t<unit_seq>', got {line!r}"
                )
            units.append(fields[1].strip())
    return units

def load_text(file_path):
    """Read one text sentence per line from a UTF-8 file.

    Args:
        file_path: Path to the text file.

    Returns:
        A list with every line of the file, stripped of surrounding
        whitespace. Blank lines are kept as "" so that line i still
        corresponds to line i of any parallel file.
    """
    # Decode explicitly as UTF-8 instead of relying on the platform locale.
    with open(file_path, 'r', encoding='utf-8') as f:
        return [line.strip() for line in f]

# Source side: discrete unit sequences.
unit_file = os.path.join(DATA_DIR, 'units/lrl_units.txt')
unit_seqs = load_units(unit_file)

# Target side: text sentences, one per line.
text_file = os.path.join(DATA_DIR, 'text/hrl_text.txt')
text_seqs = load_text(text_file)

# The two files are paired purely by line position, so their lengths must agree.
assert len(unit_seqs) == len(text_seqs), "Mismatch between unit and text pairs"


In [None]:
# Step 5: Tokenization (Units & Text)
from transformers import AutoTokenizer

# Tokenizer for the HRL text: the pretrained subword tokenizer published with
# the "facebook/mbart-large-50" checkpoint. It is downloaded (or read from the
# local cache) when this cell runs.
text_tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-50")

# Unit tokenizer (integer units, vocab size = n_clusters)
class UnitTokenizer:
    """Converts between whitespace-separated unit strings and integer id lists.

    Args:
        vocab_size: Number of distinct units; valid ids are 0 .. vocab_size - 1.
    """

    def __init__(self, vocab_size):
        self.vocab_size = vocab_size

    def encode(self, seq):
        """Parse a string such as "3 17 42" into [3, 17, 42].

        Raises:
            ValueError: If a token is not an integer or lies outside
                [0, vocab_size). Failing here gives a readable message instead
                of an index error inside the embedding layer later on.
        """
        ids = [int(tok) for tok in seq.split()]
        for unit_id in ids:
            if not 0 <= unit_id < self.vocab_size:
                raise ValueError(
                    f"unit id {unit_id} is outside the vocabulary range [0, {self.vocab_size})"
                )
        return ids

    def decode(self, ids):
        """Join ids back into a single space-separated string."""
        return " ".join(map(str, ids))


unit_tokenizer = UnitTokenizer(vocab_size=100)


In [None]:
# Step 6: Create PyTorch Dataset
# Truncation limits applied per example in __getitem__.
MAX_LEN_UNITS = 128
MAX_LEN_TEXT = 64


class UnitToTextDataset(Dataset):
    """Pairs the i-th unit sequence with the i-th text sentence.

    Each item is a dict with two 1-D long tensors: "input_ids" (unit ids,
    cut to MAX_LEN_UNITS) and "labels" (text token ids, truncated by the
    text tokenizer to MAX_LEN_TEXT).
    """

    def __init__(self, unit_seqs, text_seqs):
        self.unit_seqs = unit_seqs
        self.text_seqs = text_seqs

    def __len__(self):
        # The dataset length follows the unit side.
        return len(self.unit_seqs)

    def __getitem__(self, idx):
        unit_ids = unit_tokenizer.encode(self.unit_seqs[idx])
        unit_ids = unit_ids[:MAX_LEN_UNITS]
        text_ids = text_tokenizer.encode(
            self.text_seqs[idx],
            truncation=True,
            max_length=MAX_LEN_TEXT,
        )
        sample = {}
        sample["input_ids"] = torch.tensor(unit_ids, dtype=torch.long)
        sample["labels"] = torch.tensor(text_ids, dtype=torch.long)
        return sample

def _collate_batch(samples):
    """Right-pad a list of dataset items into one batch-first batch.

    "input_ids" is padded with 0 and "labels" with -100.
    NOTE(review): 0 is also a valid unit id, so after padding a real unit 0
    cannot be told apart from padding by anything that tests `input_ids == 0`.
    """
    unit_tensors = [sample["input_ids"] for sample in samples]
    label_tensors = [sample["labels"] for sample in samples]
    return {
        "input_ids": nn.utils.rnn.pad_sequence(unit_tensors, batch_first=True, padding_value=0),
        "labels": nn.utils.rnn.pad_sequence(label_tensors, batch_first=True, padding_value=-100),
    }


dataset = UnitToTextDataset(unit_seqs, text_seqs)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=_collate_batch)


In [None]:
# Step 7: Define Transformer Encoder-Decoder Model (units -> text)
class SeqDecoder(nn.Module):
    """Transformer encoder-decoder mapping unit-id sequences to text-token logits.

    All tensors are batch-first.

    Args:
        unit_vocab_size: Number of rows in the source (unit) embedding table.
        text_vocab_size: Number of rows in the target embedding table and size
            of the output projection.
        d_model: Embedding / hidden width.
        nhead: Number of attention heads.
        num_layers: Number of layers in both the encoder and the decoder.
        pad_id: Source id that is treated as padding and masked out.
    """

    def __init__(self, unit_vocab_size, text_vocab_size, d_model=256, nhead=4, num_layers=4, pad_id=0):
        super().__init__()
        self.d_model = d_model
        self.pad_id = pad_id
        self.unit_emb = nn.Embedding(unit_vocab_size, d_model)
        self.text_emb = nn.Embedding(text_vocab_size, d_model)
        # batch_first=True: inputs arrive as (batch, seq, d_model). With the
        # default (seq-first) layout the batch axis would be read as time.
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True),
            num_layers=num_layers,
            # Keep the encoder output at the full padded length so it always
            # lines up with the padding mask handed to the decoder.
            enable_nested_tensor=False,
        )
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead, batch_first=True),
            num_layers=num_layers,
        )
        self.out_proj = nn.Linear(d_model, text_vocab_size)

    def _positional_encoding(self, length, device, dtype):
        """Return fixed sinusoidal position encodings of shape (1, length, d_model)."""
        positions = torch.arange(length, device=device, dtype=torch.float32).unsqueeze(1)
        even_dims = torch.arange(0, self.d_model, 2, device=device, dtype=torch.float32)
        inv_freq = 1.0 / (10000.0 ** (even_dims / self.d_model))
        angles = positions * inv_freq
        encoding = torch.zeros(length, self.d_model, device=device, dtype=torch.float32)
        encoding[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one cosine column fewer than sine columns.
        encoding[:, 1::2] = torch.cos(angles[:, : self.d_model // 2])
        return encoding.unsqueeze(0).to(dtype)

    def forward(self, src, tgt):
        """Compute next-token logits.

        Args:
            src: Long tensor (batch, src_len) of unit ids, padded with `pad_id`.
            tgt: Long tensor (batch, tgt_len) of decoder input token ids.

        Returns:
            Float tensor (batch, tgt_len, text_vocab_size) of logits.
        """
        # The mask must come from the integer ids, before `src` is embedded.
        src_pad_mask = src == self.pad_id
        # A row in which every position is masked makes the attention softmax
        # produce NaN; leave such rows fully visible instead.
        src_pad_mask = src_pad_mask & ~src_pad_mask.all(dim=1, keepdim=True)
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt.size(1)).to(tgt.device)

        # Attention is order-agnostic, so position information is added explicitly.
        src_emb = self.unit_emb(src)
        src_emb = src_emb + self._positional_encoding(src.size(1), src_emb.device, src_emb.dtype)
        tgt_emb = self.text_emb(tgt)
        tgt_emb = tgt_emb + self._positional_encoding(tgt.size(1), tgt_emb.device, tgt_emb.dtype)

        memory = self.encoder(src_emb, src_key_padding_mask=src_pad_mask)
        # Cross-attention must skip padded source positions as well.
        output = self.decoder(
            tgt_emb,
            memory,
            tgt_mask=tgt_mask,
            memory_key_padding_mask=src_pad_mask,
        )
        return self.out_proj(output)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# len(tokenizer) counts added/special tokens too, whereas `.vocab_size` is only
# the base vocabulary; the embedding table must cover every id the tokenizer
# can emit. The unit vocabulary size is taken from the unit tokenizer so the
# two cannot drift apart.
model = SeqDecoder(
    unit_vocab_size=unit_tokenizer.vocab_size,
    text_vocab_size=len(text_tokenizer),
).to(device)


In [None]:
# Step 8: Train the Model
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4)
# Label positions equal to -100 (the label padding value) do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=-100)

EPOCHS = 5

# Move batches to wherever the model lives instead of assuming CUDA.
device = next(model.parameters()).device
# -100 is only meaningful to the loss; it is not a valid embedding index, so
# decoder inputs need a real token id at padded positions.
pad_token_id = text_tokenizer.pad_token_id if text_tokenizer.pad_token_id is not None else 0

for epoch in range(EPOCHS):
    model.train()
    total_loss = 0.0
    for batch in tqdm(dataloader):
        input_ids = batch["input_ids"].to(device)
        labels = batch["labels"].to(device)

        # Teacher forcing: feed tokens [0, T-1) and predict tokens [1, T).
        decoder_input = labels[:, :-1]
        decoder_input = decoder_input.masked_fill(decoder_input == -100, pad_token_id)
        target = labels[:, 1:]

        outputs = model(input_ids, decoder_input)
        # reshape (not view) so a non-contiguous output cannot raise here.
        loss = criterion(outputs.reshape(-1, outputs.shape[-1]), target.reshape(-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # max(..., 1) avoids a ZeroDivisionError when the dataloader is empty.
    print(f"Epoch {epoch+1}/{EPOCHS} | Loss: {total_loss/max(len(dataloader), 1):.4f}")


In [None]:
# Step 9: Save Model
# Only the parameters (state_dict) are written; loading them back requires
# building a SeqDecoder with the same constructor arguments first.
model_path = os.path.join(BASE_DIR, 'trained_decoder.pt')
state_dict = model.state_dict()
torch.save(state_dict, model_path)
print("Model saved to:", model_path)
