# Tarea 2. Transformer con instrucciones

# Imports

In [1]:
import os
import torch
import torch, torchaudio, glob
import numpy as np
import random
import csv

from Trabajo_Utils import NoiseAug, RIRAug, identity, wer
from Trabajo_Model import *

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
SEQ_LEN = 12      # target length in tokens: <sos> + task token + words + <eos> (+ <pad>)
NB_EPOCHS = 5
BATCH_SIZE = 32

def seed_everything(seed):
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: integer seed applied to every RNG used in this notebook.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seeds every visible GPU (manual_seed only seeds the current one); lazy on CPU-only hosts.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune and pick kernels that can differ between
    # runs, which undoes deterministic=True; keep it off for reproducibility.
    torch.backends.cudnn.benchmark = False
seed_everything(42)

### Data Downloads

In [3]:
# dataset = load_dataset("FluidInference/musan", split="train")

# if not os.path.exists("musan_small"):
#     os.makedirs("musan_small", exist_ok=True)
#     for i, example in enumerate(dataset):
#         audio = example["audio"]["array"]
#         sr = example["audio"]["sampling_rate"]
#         sf.write(f"musan_small/file_{i}.wav", audio, sr)
        
# url = 'https://openslr.elda.org/resources/28/rirs_noises.zip'

# if not os.path.exists('RIRS_NOISES'):
#     if not os.path.exists('rirs_noises.zip'):
#         os.system('wget ' + url)
#     os.system('unzip -q rirs_noises.zip')
#     os.system('rm rirs_noises.zip')

## 2.1
A continuación, prepare un tokenizador que codifique, mediante un token especial situado antes del mensaje de texto, cuál de las 4 acciones posibles queremos que ejecute el transformer: <br>
- transcribe_es: transcribir directamente un audio en español a texto en español <br>
- transcribe_en: transcribir directamente un audio en inglés a texto en inglés <br>
- translate_en_es: traducir a texto en español un audio en inglés <br>
- translate_es_en: traducir a texto en inglés un audio en español <br>

In [4]:
class WordTokenizer:
    """Word-level tokenizer whose targets are prefixed with a task (instruction) token.

    Encoded sequences have the form ``<sos> <task> w1 ... wn <eos> <pad> ...``,
    where ``<task>`` is one of the instructions in ``TASKS``. The word vocabulary
    is built from the ``txt`` column of a CSV file.
    """

    # Instructions understood by the model (Tarea 2); encode() rejects anything else.
    TASKS = ('transcribe_es', 'transcribe_en', 'translate_en_es', 'translate_es_en')

    def __init__(self, csv_file):
        """Build the vocabulary.

        Args:
            csv_file: path to a UTF-8 CSV file with (at least) a ``txt`` column.
        """
        # Base special tokens
        self.word2index = {'<pad>': 0, '<sos>': 1, '<eos>': 2, '<unk>': 3}
        # Task tokens take ids 4-7 in TASKS order; changing these ids would
        # invalidate checkpoints trained with this tokenizer.
        for offset, task in enumerate(self.TASKS):
            self.word2index[f'<{task}>'] = 4 + offset

        self.index2word = {v: k for k, v in self.word2index.items()}
        # Everything registered before the corpus words is a special token.
        self.special_tokens = frozenset(self.word2index)

        self.build_vocab(csv_file)

    def build_vocab(self, csv_file):
        """Add every lower-cased, whitespace-separated word of the ``txt`` column.

        Ids are assigned in order of first appearance.
        """
        with open(csv_file, encoding='utf-8') as f:
            reader = csv.DictReader(f)
            for row in reader:
                text = row['txt'].lower()
                for word in text.split():
                    if word not in self.word2index:
                        idx = len(self.word2index)
                        self.word2index[word] = idx
                        self.index2word[idx] = word

    def encode(self, action, text, seq_len):
        """Encode ``text`` as a fixed-length id tensor prefixed by a task token.

        Args:
            action: task name without brackets, e.g. ``'translate_es_en'``.
            text: transcription/translation target; lower-cased and split on whitespace.
            seq_len: length of the returned tensor.

        Returns:
            1-D ``torch.long`` tensor with ``seq_len`` ids, right-padded with ``<pad>``.
            If the sequence is longer than ``seq_len`` it is truncated, but the last
            kept position is replaced by ``<eos>`` so the model still learns to stop.

        Raises:
            ValueError: if ``action`` is not one of ``TASKS``.
        """
        if action not in self.TASKS:
            # An unknown action used to map silently to <unk>, producing training
            # targets with a meaningless instruction.
            raise ValueError(f'unknown action {action!r}; expected one of {self.TASKS}')

        unk = self.word2index['<unk>']
        eos = self.word2index['<eos>']
        pad = self.word2index['<pad>']

        ids = [self.word2index['<sos>'], self.word2index[f'<{action}>']]
        ids += [self.word2index.get(word, unk) for word in text.lower().split()]
        ids.append(eos)

        if len(ids) > seq_len:
            ids = ids[:seq_len]
            if ids:
                ids[-1] = eos
        else:
            ids += [pad] * (seq_len - len(ids))

        return torch.tensor(ids, dtype=torch.long)

    def decode(self, ids):
        """Convert token ids back into a space-separated string.

        Decoding stops at the first ``<eos>``. Other special tokens (``<pad>``,
        ``<sos>``, ``<unk>`` and the task tokens) are dropped, and ids that are
        missing from the vocabulary are treated as ``<unk>``.

        Args:
            ids: sequence of ints or a 1-D ``torch.Tensor``.

        Returns:
            the decoded text.
        """
        if isinstance(ids, torch.Tensor):
            ids = ids.tolist()

        words = []
        for i in ids:
            word = self.index2word.get(i, '<unk>')
            if word == '<eos>':
                break
            if word not in self.special_tokens:
                words.append(word)

        return ' '.join(words)


## 2.2
Prepare un dataset/dataloader nuevo que, a partir de los ficheros y transcripciones de la tarea anterior, genere una combinación de las cuatro acciones posibles con ejemplos suficientes para entrenar el transformer; es decir, todas las combinaciones de audio en español/inglés y texto en español/inglés, con la instrucción correspondiente antes del texto.

In [5]:
train_es = "fechas2/fechas2_train.es.csv"
train_en = "fechas2/fechas2_train.en.csv"

out_csv = "fechas2/fechas2_train_instruct.csv"

# Load both parallel CSVs completely. The pairing below assumes that row i of the
# Spanish file and row i of the English file describe the same utterance.
with open(train_es, encoding="utf-8") as f_es, \
     open(train_en, encoding="utf-8") as f_en:
    rows_es = list(csv.DictReader(f_es))
    rows_en = list(csv.DictReader(f_en))

# zip() would silently drop the tail of the longer file; fail loudly instead.
if len(rows_es) != len(rows_en):
    raise ValueError(
        f"{train_es} has {len(rows_es)} rows but {train_en} has {len(rows_en)}; "
        "the ES/EN files must be row-aligned"
    )

# Every ES/EN pair yields four (wav, action, txt) examples, one per instruction.
rows = []
for r_es, r_en in zip(rows_es, rows_en):
    wav_es, txt_es = r_es["wav"], r_es["txt"]
    wav_en, txt_en = r_en["wav"], r_en["txt"]

    # Audio ES
    rows.append([wav_es, "transcribe_es", txt_es])
    rows.append([wav_es, "translate_es_en", txt_en])

    # Audio EN
    rows.append([wav_en, "transcribe_en", txt_en])
    rows.append([wav_en, "translate_en_es", txt_es])

with open(out_csv, "w", encoding="utf-8", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["wav", "action", "txt"])
    writer.writerows(rows)

print("TRAIN instruct samples:", len(rows))
print("Saved to:", out_csv)


TRAIN instruct samples: 40000
Saved to: fechas2/fechas2_train_instruct.csv


In [6]:
def _read_instruct_rows(csv_file):
    """Return ``(wav, action, txt)`` tuples from an instruction CSV with those columns."""
    with open(csv_file, encoding='utf-8') as f:
        return [(row['wav'], row['action'], row['txt']) for row in csv.DictReader(f)]


def _load_fixed_length(wav_path, audio_len):
    """Load ``wav_path``, zero-pad or crop it to ``audio_len`` samples, and return channel 0.

    NOTE(review): the sample rate reported by torchaudio is not checked; the default
    ``audio_len=4*16000`` presumably assumes 16 kHz files — confirm.
    """
    x, _ = torchaudio.load(wav_path)
    missing = audio_len - x.shape[1]
    if missing > 0:
        x = torch.nn.functional.pad(x, (0, missing), value=0)
    else:
        x = x[:, :audio_len]
    return x[0]


class TrainDataset(torch.utils.data.Dataset):
    """Training pairs (augmented audio, instruction-prefixed target ids).

    Args:
        csv_file: instruction CSV with ``wav``, ``action`` and ``txt`` columns.
        tokenizer: object with ``encode(action=..., text=..., seq_len=...)``.
        audio_len: number of samples each waveform is padded/cropped to.
        seq_len: length of the encoded target.
        transform: callables applied in order to the 1-D numpy waveform;
            ``None`` means ``[identity]``.
    """

    def __init__(
        self,
        csv_file,
        tokenizer,
        audio_len=4*16000,
        seq_len=SEQ_LEN,
        transform=None
    ):
        self.audio_len = audio_len
        self.seq_len = seq_len
        # A list literal as default would be shared by every instance; build it per call.
        self.transform = [identity] if transform is None else transform
        self.tokenizer = tokenizer

        self.data = _read_instruct_rows(csv_file)

        print("Train samples:", len(self.data))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        wav_path, action, text = self.data[idx]

        # --- Audio --- (augmentations operate on the numpy waveform)
        x = _load_fixed_length(wav_path, self.audio_len).numpy()
        for t in self.transform:
            x = t(x)

        # --- Target (ACTION + TEXT) ---
        y = self.tokenizer.encode(
            action=action,
            text=text,
            seq_len=self.seq_len
        )

        return x, y


class TestDataset(torch.utils.data.Dataset):
    """Evaluation pairs (audio tensor, instruction-prefixed target ids), no augmentation.

    Args:
        csv_file: instruction CSV with ``wav``, ``action`` and ``txt`` columns.
        tokenizer: object with ``encode(action=..., text=..., seq_len=...)``.
        audio_len: number of samples each waveform is padded/cropped to.
        seq_len: length of the encoded target.
    """

    def __init__(
        self,
        csv_file,
        tokenizer,
        audio_len=4*16000,
        seq_len=SEQ_LEN
    ):
        self.audio_len = audio_len
        self.seq_len = seq_len
        self.tokenizer = tokenizer

        self.data = _read_instruct_rows(csv_file)

        print("Test samples:", len(self.data))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        wav_path, action, text = self.data[idx]

        x = _load_fixed_length(wav_path, self.audio_len)

        y = self.tokenizer.encode(
            action=action,
            text=text,
            seq_len=self.seq_len
        )

        return x, y


In [7]:
# The vocabulary is built from the training split only: words that appear only in
# the test split are encoded as <unk> by WordTokenizer.encode.
task_tokenizer = WordTokenizer(csv_file='fechas2/fechas2_train_instruct.csv')

# Training audio goes through NoiseAug and RIRAug (from Trabajo_Utils);
# prob=0.5 is presumably the chance each one is applied — confirm in Trabajo_Utils.
trainset = TrainDataset(
    csv_file='fechas2/fechas2_train_instruct.csv',
    tokenizer=task_tokenizer,
    transform=[NoiseAug(prob=0.5), RIRAug(prob=0.5)],
)

# NOTE(review): fechas2_test_instruct.csv is not generated in this notebook
# (only the train instruct CSV is written above); it must already exist on disk.
testset = TestDataset(
    csv_file='fechas2/fechas2_test_instruct.csv',
    tokenizer=task_tokenizer,
)

# Dataloaders: shuffled mini-batches for training, one utterance at a time for test.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=False)

vocab_size = len(task_tokenizer.word2index)

Train samples: 40000
Test samples: 1000


In [8]:
# model = AudioTransformer(
#     vocab_size=vocab_size,
#     d_model=256,
#     nb_layers=8,
#     d_ff=512,
#     n_heads=8,
#     d_head=32,
#     dropout=0.1,
#     seq_len=SEQ_LEN
# )

# model.to(device)
# optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
# pad_idx = task_tokenizer.word2index['<pad>']

# import math

# BATCH_FACTOR = 5
# MAX_BATCH = None
# if device == 'cpu':
#     MAX_BATCH = math.ceil(len(trainloader) / BATCH_FACTOR)

# model.train()
# for epoch in range(NB_EPOCHS):
#     total_loss = 0
#     for i, (x, y) in enumerate(trainloader):
#         if MAX_BATCH and i >= MAX_BATCH:
#             break
#         x = x.to(device)
#         y = y.to(device)

#         optimizer.zero_grad()
#         loss = model.loss(x, y, pad_idx=pad_idx)
#         loss.backward()
#         optimizer.step()

#         total_loss += loss.item()

#     print(f"Epoch {epoch+1}/{NB_EPOCHS} - Loss: {total_loss/len(trainloader):.3f}")

# torch.save(model.state_dict(), "model_instruct.pt")

In [9]:
def generate_with_instruction(model, x, action, tokenizer, max_len=None):
    """Greedy decoding conditioned on a task (instruction) token.

    Decoding starts from ``[<sos>, <action>]`` and appends the highest-scoring
    token until ``<eos>`` is produced or the sequence holds ``max_len`` tokens.
    The model is switched to eval mode as a side effect.

    Args:
        model: module exposing ``encoder(audio_batch)`` and ``decoder(tokens, enc)``;
            the decoder output is indexed as ``[batch, position, vocab]``.
        x: a single un-batched audio input (a batch dimension is added here).
        action: task name without brackets, e.g. ``'transcribe_es'``.
        tokenizer: object with a ``word2index`` mapping (e.g. WordTokenizer).
        max_len: maximum number of tokens including ``<sos>`` and the task token;
            ``None`` uses the notebook-level ``SEQ_LEN``.

    Returns:
        list of token ids beginning with ``<sos>`` and the task token.

    Raises:
        KeyError: if ``<action>`` is not in the tokenizer vocabulary.
    """
    if max_len is None:
        max_len = SEQ_LEN

    model.eval()
    model_device = next(model.parameters()).device

    sos = tokenizer.word2index['<sos>']
    action_tok = tokenizer.word2index[f'<{action}>']
    eos = tokenizer.word2index['<eos>']

    y = [sos, action_tok]

    with torch.no_grad():
        enc = model.encoder(x.unsqueeze(0).to(model_device))

        while len(y) < max_len and y[-1] != eos:
            y_tensor = torch.tensor(y).unsqueeze(0).to(model_device)
            logits = model.decoder(y_tensor, enc)
            # Only the last position predicts the next token.
            next_token = logits[0, -1].argmax(-1).item()
            y.append(next_token)

    return y

In [10]:
# The architecture must match the one used to train model_instruct.pt (same values
# as the commented-out training cell), otherwise loading the state dict can fail.
model = AudioTransformer(
    vocab_size=vocab_size,
    seq_len=SEQ_LEN,
    d_model=256,
    d_ff=512,
    nb_layers=8,
    n_heads=8,
    d_head=32,
    dropout=0.1,
)

# map_location lets a checkpoint saved on GPU be restored on a CPU-only machine.
# torch.load unpickles the file: only load checkpoints from trusted sources.
state = torch.load("model_instruct.pt", map_location=device)
model.load_state_dict(state)
model.to(device)
model.eval()

AudioTransformer(
  (fe): AudioFeatures(
    (fe): MelSpectrogram(
      (spectrogram): Spectrogram()
      (mel_scale): MelScale()
    )
    (spec_aug): SpecAug()
    (linear): Linear(in_features=80, out_features=256, bias=True)
  )
  (enc): Encoder(
    (att): ModuleList(
      (0-7): 8 x SelfAttention(
        (norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (q_linear): Linear(in_features=256, out_features=256, bias=True)
        (v_linear): Linear(in_features=256, out_features=256, bias=True)
        (k_linear): Linear(in_features=256, out_features=256, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (out): Linear(in_features=256, out_features=256, bias=True)
      )
    )
    (ff): ModuleList(
      (0-7): 8 x FeedForward(
        (ff): Sequential(
          (0): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (1): Linear(in_features=256, out_features=512, bias=True)
          (2): ReLU()
          (3): Dropout(p=0.1, inplac

In [11]:
# Print reference vs. greedy hypothesis for the first 5 test utterances.
for idx in range(5):
    audio, target = testset[idx]
    _, action, _ = testset.data[idx]   # instruction stored for this row

    reference = task_tokenizer.decode(target)
    hyp_ids = model.generate(audio.unsqueeze(0), task_tokenizer, action)
    hypothesis = task_tokenizer.decode(hyp_ids)

    print(f"ACTION: {action}")
    print("REF:", reference)
    print("HYP:", hypothesis)
    print("-" * 40)


ACTION: translate_es_en
REF: the day after tomorrow thanks
HYP: the day after tomorrow thanks
----------------------------------------
ACTION: transcribe_es
REF: por favor en un par de días
HYP: por favor en un par de días
----------------------------------------
ACTION: translate_en_es
REF: por favor en un par de días
HYP: por favor en un par de días
----------------------------------------
ACTION: translate_en_es
REF: este jueves gracias
HYP: este jueves gracias
----------------------------------------
ACTION: translate_en_es
REF: por favor el miércoles siguiente
HYP: por favor el siguiente miércoles
----------------------------------------


In [12]:
def evaluate_wer_instruct(model, dataset, tokenizer, gen_mode = 'greedy'):
    """Average per-utterance WER of ``model`` on an instruction dataset.

    Args:
        model: object exposing ``generate``, ``generate_sampling`` and
            ``generate_topk``, each called as ``fn(x_batch, tokenizer, action=...)``.
        dataset: indexable dataset returning ``(audio, target_ids)`` whose
            ``data[i]`` is ``(wav, action, txt)`` (e.g. TestDataset).
        tokenizer: used to decode both reference and hypothesis ids.
        gen_mode: ``'greedy'``, ``'sampling'`` or ``'topk'``.

    Returns:
        mean of ``wer(ref, hyp)`` over all items of ``dataset``.

    Raises:
        ValueError: if ``gen_mode`` is unknown or ``dataset`` is empty.
    """
    generator_names = {
        'greedy': 'generate',
        'sampling': 'generate_sampling',
        'topk': 'generate_topk',
    }
    # Previously any unrecognised mode (e.g. a typo) silently ran top-k decoding.
    if gen_mode not in generator_names:
        raise ValueError(
            f"unknown gen_mode {gen_mode!r}; expected one of {sorted(generator_names)}"
        )
    if len(dataset) == 0:
        raise ValueError("cannot compute WER on an empty dataset")

    generate = getattr(model, generator_names[gen_mode])

    total_wer = 0
    for i in range(len(dataset)):
        x, y = dataset[i]
        action = dataset.data[i][1]

        ref = tokenizer.decode(y)
        hyp_ids = generate(x.unsqueeze(0), tokenizer, action=action)
        hyp = tokenizer.decode(hyp_ids)
        total_wer += wer(ref, hyp)

    return total_wer / len(dataset)

In [13]:
model.load_state_dict(torch.load("model_instruct.pt", map_location=device))
model.eval()

# Score each decoding strategy with a freshly seeded RNG. Without this, the
# stochastic modes (sampling, top-k) report a different WER every time the
# cell is re-run, because they depend on whatever RNG state earlier cells left.
for gen_mode, label in (('greedy', ''), ('sampling', ' - sampling'), ('topk', ' - topk')):
    seed_everything(42)
    wer_instr = evaluate_wer_instruct(model, testset, task_tokenizer, gen_mode=gen_mode)
    print(f"WER Tarea 2{label}: {wer_instr:.3f}")

WER Tarea 2: 0.198
WER Tarea 2 - sampling: 0.255
WER Tarea 2 - topk: 0.267
