# Assignment 7: Transformers

To get you warmed up and familiar with some of the libraries, we start with a BERT tutorial by J. Alammar.
The tutorial builds a simple sentiment analysis model on top of pretrained BERT models using the [HuggingFace](https://huggingface.co/) library.
It will familiarize you with the library and make the next exercise a bit easier.
The [Visual Guide](https://jalammar.github.io/a-visual-guide-to-using-bert-for-the-first-time/) has helpful graphics and visualizations that will deepen your general understanding of transformers, and of the BERT model in particular.

---

## Task 1) Wav2vec 2.0 for keyword recognition

After the warm-up with BERT, this exercise is a bit more advanced and you will be mostly on your own.
The task in this exercise is to build a keyword recognition system based on wav2vec 2.0. 
You will have to weigh several options and decide which implementation path to follow.

You can use the HuggingFace [Audio Classification Tutorial](https://github.com/huggingface/notebooks/blob/main/examples/audio_classification.ipynb) as a starting point.
The options listed below lead to different results on this problem; they vary in complexity as well as in performance.
You should be able to justify the design and implementation choices you make.
Choose the option that suits you best or the one that you think might yield the best performance.
1. What model will you use? ```BASE vs. LARGE``` and what pretrained weights ```ASR vs BASE```, ```XLSR53 vs ENGLISH```?
2. HuggingFace or ```torchaudio.pipelines```?
3. Use a simple neural classification head?
4. Extract features and use them with some downstream classifier (e.g., SVM, Naive Bayes, etc.)
    1. What pooling strategy will you use (mean, statistical, etc.)?
    2. Compare downstream classifiers (e.g., SVM vs. MLP vs. CNN).
    3. Should you use a dimensionality reduction method?
5. Or use CTC loss and a greedy decoder? (closed vocab!)
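
To illustrate the pooling question in the feature-extraction option, here is a minimal sketch of mean and statistical pooling over frame-level features. It assumes the features are already extracted and zero-padded to a common length. The random array only stands in for wav2vec 2.0 hidden states, and the function names `mean_pool` and `stat_pool` are made up for this example.

```python
import numpy as np


def mean_pool(frames: np.ndarray, lengths: np.ndarray) -> np.ndarray:
    """Average frame features per utterance, ignoring padded frames.

    frames:  (batch, time, dim) frame-level features, zero-padded along time
    lengths: (batch,) number of valid frames per utterance
    """
    # mask[b, t] is True for valid frames and False for padding
    mask = np.arange(frames.shape[1])[None, :] < lengths[:, None]
    summed = (frames * mask[:, :, None]).sum(axis=1)
    return summed / lengths[:, None]


def stat_pool(frames: np.ndarray, lengths: np.ndarray) -> np.ndarray:
    """Concatenate per-utterance mean and standard deviation: (batch, 2 * dim)."""
    mask = np.arange(frames.shape[1])[None, :] < lengths[:, None]
    mean = mean_pool(frames, lengths)
    # Zero out padded frames after centering so they do not affect the variance
    centered = (frames - mean[:, None, :]) * mask[:, :, None]
    var = (centered ** 2).sum(axis=1) / lengths[:, None]
    return np.concatenate([mean, np.sqrt(var)], axis=1)


# Stand-in for wav2vec 2.0 hidden states: 2 utterances, up to 5 frames, 4 dims.
rng = np.random.default_rng(0)
frames = rng.normal(size=(2, 5, 4))
lengths = np.array([5, 3])
frames[1, 3:] = 0.0  # the second utterance only has 3 valid frames

pooled_mean = mean_pool(frames, lengths)
pooled_stat = stat_pool(frames, lengths)
print(pooled_mean.shape, pooled_stat.shape)
```

Either pooled matrix has one row per utterance and could be passed to a downstream classifier such as an SVM.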

## Dataset

For this exercise, please use the [speech-commands-dataset](https://ai.googleblog.com/2017/08/launching-speech-commands-dataset.html) from Google to train and evaluate your keyword recognition systems.
The data can also be obtained using the [HuggingFace API](https://huggingface.co/datasets/speech_commands) or [torchaudio](https://pytorch.org/audio/stable/_modules/torchaudio/datasets/speechcommands.html).

*In this Jupyter Notebook, we will provide the steps to solve this task and give hints via functions & comments. However, code modifications (e.g., function naming, arguments) and implementation of additional helper functions & classes are allowed. The code aims to help you get started.*

---

### Prepare the Data

In [None]:
### YOUR CODE HERE

import tarfile
from typing import Iterable, Optional
import requests
import os
from sklearn.model_selection import train_test_split
from torch import Tensor
import torchaudio
from torch.utils.data import Dataset, DataLoader
import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence

DATASET_URL = "http://download.tensorflow.org/data/speech_commands_v0.01.tar.gz"
DATA_DIR_PATH = "./data"
DATASET_TARBALL_PATH = f"{DATA_DIR_PATH}/speech_commands_v0.01.tar.gz"
DATASET_PATH = f"{DATA_DIR_PATH}/speech_commands_v0.01"

RANDOM_SEED = 42

TEST_RATIO = 0.2
VAL_RATIO = 0.15

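os.makedirs(DATA_DIR_PATH, exist_ok=True)
if not os.path.exists(DATASET_TARBALL_PATH):
    # Stream the download to disk and fail loudly on HTTP errors
    # instead of silently saving an error page as the tarball.
    with requests.get(DATASET_URL, stream=True) as response:
        response.raise_for_status()
        with open(DATASET_TARBALL_PATH, "wb") as fp:
            for chunk in response.iter_content(chunk_size=1 << 20):
                fp.write(chunk)
if not os.path.exists(DATASET_PATH):
    with tarfile.open(DATASET_TARBALL_PATH) as tar:
        tar.extractall(DATASET_PATH)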

audios = []
labels = []
label_set = set()
for p in os.listdir(DATASET_PATH):
    dir_p = f"{DATASET_PATH}/{p}"
    if p.startswith("_") or not os.path.isdir(dir_p):
        continue
    label_set.add(p)
    for wp in os.listdir(dir_p):
        audio, _ = torchaudio.load(f"{dir_p}/{wp}")
        audios.append(audio.flatten())
        labels.append(p)
distinct_labels = sorted(label_set)
print(f"distinct labels: {distinct_labels}")

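# Two-stage split: hold out TEST_RATIO + VAL_RATIO first, then divide the held-out part.
# Overall this gives roughly 65% train, 20% test and 15% validation.
train_idcs, temp = train_test_split(np.arange(len(audios)), test_size=TEST_RATIO + VAL_RATIO, random_state=RANDOM_SEED)
test_idcs, val_idcs = train_test_split(temp, test_size=VAL_RATIO / (VAL_RATIO + TEST_RATIO), random_state=RANDOM_SEED)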

class CommandsDataset(Dataset):
    def __init__(self, audios: Iterable[Tensor], labels: Iterable[str], distinct_labels: Optional[list[str]] = None):
        self.__audios = list(audios)
        self.__labels = list(labels)
        assert len(self.__audios) == len(self.__labels)
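        # Use the materialized list: `labels` may be a one-shot iterator that list() has already consumed.
        self.__distinct_labels = sorted(set(self.__labels)) if distinct_labels is None else list(distinct_labels)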
        self.__label_index_lookup = {l: i for i, l in enumerate(self.__distinct_labels)}

    def __len__(self) -> int:
        return len(self.__audios)

    def __getitem__(self, i: int) -> tuple[Tensor, int]:
        return self.__audios[i], self.__label_index_lookup[self.__labels[i]]
    
    def get_label(self, i: int) -> str:
        return self.__distinct_labels[i]
    
    def loader(self, batch_size: int) -> DataLoader:
        return DataLoader(self, batch_size=batch_size, collate_fn=CommandsDataset.__collate)
    
    @staticmethod
    def __collate(tups: Iterable[tuple[Tensor, int]]) -> tuple[Tensor, Tensor]:
        xs = []
        ys = []
        for x, y in tups:
            xs.append(x)
            ys.append(y)
        return pad_sequence(xs, True), torch.tensor(ys)

train_dataset = CommandsDataset(map(audios.__getitem__, train_idcs), map(labels.__getitem__, train_idcs), distinct_labels)
val_dataset = CommandsDataset(map(audios.__getitem__, val_idcs), map(labels.__getitem__, val_idcs), distinct_labels)
test_dataset = CommandsDataset(map(audios.__getitem__, test_idcs), map(labels.__getitem__, test_idcs), distinct_labels)

### END YOUR CODE
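
The ratio arithmetic of the two-stage split above can be checked with a small stand-alone sketch. It mimics the split with the standard library instead of `train_test_split`, so rounding may differ slightly from scikit-learn, and `two_stage_split` is a hypothetical helper that is not part of the notebook.

```python
import random

TEST_RATIO = 0.2
VAL_RATIO = 0.15


def two_stage_split(n: int, seed: int = 42) -> tuple[list[int], list[int], list[int]]:
    """Shuffle range(n) and cut it into train/test/validation index lists."""
    indices = list(range(n))
    random.Random(seed).shuffle(indices)
    # Stage 1: hold out TEST_RATIO + VAL_RATIO of all samples.
    n_held_out = round(n * (TEST_RATIO + VAL_RATIO))
    held_out, train = indices[:n_held_out], indices[n_held_out:]
    # Stage 2: the validation share is relative to the held-out part only.
    n_val = round(n_held_out * VAL_RATIO / (VAL_RATIO + TEST_RATIO))
    val, test = held_out[:n_val], held_out[n_val:]
    return train, test, val


train, test, val = two_stage_split(1000)
print(len(train), len(test), len(val))  # 650 200 150
```

Dividing `VAL_RATIO` by the held-out fraction is what keeps the validation set at 15% of the full dataset rather than 15% of the held-out part.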

### Train the wav2vec model

In [None]:
### YOUR CODE HERE

from typing import Callable
from sklearn.metrics import confusion_matrix
from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2FeatureExtractor
from torch.optim import Adam
from torch.nn import Linear, CrossEntropyLoss
from torch.optim import Optimizer

MODEL_NAME = "superb/wav2vec2-large-superb-ks"
BEST_MODEL_PATH = f"{DATA_DIR_PATH}/best_model.pt"
CACHE_DIR = f"{DATA_DIR_PATH}/model_cache"

SAMPLING_RATE = 16000

# Fall back to the CPU when no GPU is available.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 4
LR = 0.001
EPOCHS = 100
PATIENCE = 10
VALIDATION_INTERVAL = 5

def build_model(model_name: str, class_count: int) -> tuple[Wav2Vec2FeatureExtractor, Wav2Vec2ForSequenceClassification]:
    if not os.path.exists(CACHE_DIR):
        os.mkdir(CACHE_DIR)
    ftx: Wav2Vec2FeatureExtractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name, cache_dir=CACHE_DIR)
    model: Wav2Vec2ForSequenceClassification = Wav2Vec2ForSequenceClassification.from_pretrained(model_name, cache_dir=CACHE_DIR) # type: ignore
    # Replace the pretrained classification head with a fresh one sized for our label set.
    model.classifier = Linear(model.projector.out_features, class_count)
    return ftx, model

def train_epoch(data_loader: DataLoader, ftx: Wav2Vec2FeatureExtractor, model: Wav2Vec2ForSequenceClassification, crit: Callable[[Tensor, Tensor], Tensor], optim: Optimizer):
    batch_count = len(data_loader)
    running_loss = 0.0
    model.train()
    for i, (X, y) in enumerate(data_loader):
        print(f"\r  training batch {i + 1}/{batch_count}", end="")
        y = y.to(model.device)
        optim.zero_grad()
        features = ftx(X, return_tensors="pt", sampling_rate=SAMPLING_RATE).input_values[0].to(model.device)
        logits = model(features).logits
        loss = crit(logits, y)
        running_loss += loss.item()
        loss.backward()
        optim.step()
    print(f"\n    average loss: {running_loss / batch_count}")

@torch.no_grad()
def validate(data_loader: DataLoader, ftx: Wav2Vec2FeatureExtractor, model: Wav2Vec2ForSequenceClassification, crit: Callable[[Tensor, Tensor], Tensor], compute_metrics: bool) -> float:
    ground_truth = []
    predictions = []
    running_loss = 0.0
    batch_count = len(data_loader)
    model.eval()
    for i, (X, y) in enumerate(data_loader):
        print(f"\r  validation batch {i + 1}/{batch_count}", end="")
        y = y.to(model.device)
        features = ftx(X, return_tensors="pt", sampling_rate=SAMPLING_RATE).input_values[0].to(model.device)
        logits = model(features).logits
        loss = crit(logits, y)
        running_loss += loss.item()
        if compute_metrics:
            ground_truth += y.tolist()
            predictions += torch.argmax(logits, dim=1).tolist()
    average_loss = running_loss / batch_count
    print(f"\n    average loss: {average_loss}")
    if compute_metrics:
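        C = confusion_matrix(ground_truth, predictions)
        acc = C.diagonal().sum() / C.sum()
        # Rows are true classes, columns are predictions. Clamp the denominators so that
        # classes that never occur (or are never predicted) score 0 instead of NaN.
        prec = C.diagonal() / np.maximum(C.sum(0), 1)
        rec = C.diagonal() / np.maximum(C.sum(1), 1)
        f1 = 2 * prec * rec / np.maximum(prec + rec, 1e-24)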
        print(f"    accuracy:  {acc.item()}")
        print(f"    precision: {prec.tolist()}")
        print(f"    recall:    {rec.tolist()}")
        print(f"    f1:        {f1.tolist()}")
    return average_loss

ftx, model = build_model(MODEL_NAME, len(distinct_labels))
model.to(DEVICE) # type: ignore
optim = Adam(model.parameters(), LR)
crit = CrossEntropyLoss()
best_model_loss = torch.inf
best_model_epoch = 0

dl_train = train_dataset.loader(BATCH_SIZE)
dl_val = val_dataset.loader(BATCH_SIZE)

for e in range(EPOCHS):
    print(f"training epoch {e + 1}...")
    train_epoch(dl_train, ftx, model, crit, optim)
    if (e + 1) % VALIDATION_INTERVAL == 0:
        print("validating model...")
        new_average_loss = validate(dl_val, ftx, model, crit, False)
        if new_average_loss < best_model_loss:
            print("  new best model found")
            torch.save(model, BEST_MODEL_PATH)
            best_model_loss = new_average_loss
            best_model_epoch = e
        else:
            print("  no improvement")
            if e - best_model_epoch >= PATIENCE:
                print("  aborting training")
                break

### END YOUR CODE
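
The patience-based early stopping used in the training loop above can be isolated in a small helper, which makes the stopping rule easy to test without training a model. This is only a sketch: the `EarlyStopping` class is hypothetical and the validation losses are made up.

```python
class EarlyStopping:
    """Track the best validation loss and signal when patience runs out."""

    def __init__(self, patience: int):
        self.patience = patience
        self.best_loss = float("inf")
        self.best_epoch = 0

    def update(self, epoch: int, loss: float) -> bool:
        """Record a validation result; return True if training should stop."""
        if loss < self.best_loss:
            self.best_loss = loss
            self.best_epoch = epoch
            return False
        return epoch - self.best_epoch >= self.patience


# Made-up validation losses, keyed by the (0-based) epoch that was validated.
val_losses = {4: 0.90, 9: 0.70, 14: 0.72, 19: 0.71, 24: 0.65}
stopper = EarlyStopping(patience=10)
stopped_at = None
for epoch, loss in val_losses.items():
    if stopper.update(epoch, loss):
        stopped_at = epoch
        break  # actually leave the loop once patience is exhausted
print(stopper.best_epoch, stopped_at)  # 9 19
```

In this run the later improvement at epoch 24 is never seen, which is the usual trade-off between patience and training time.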

### Evaluate your model

In [None]:
### YOUR CODE HERE

dl_test = test_dataset.loader(BATCH_SIZE)
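# The whole model object was pickled, so full unpickling is needed
# (recent PyTorch versions default to weights_only=True).
model = torch.load(BEST_MODEL_PATH, weights_only=False)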
print("testing best model...")
validate(dl_test, ftx, model, crit, True)

### END YOUR CODE
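
The per-class metrics printed during evaluation can also be computed in a stand-alone function that guards against empty rows or columns in the confusion matrix. The sketch below assumes the scikit-learn layout (rows are true classes, columns are predictions); `per_class_metrics` and the toy matrix are illustrative only.

```python
import numpy as np


def per_class_metrics(conf: np.ndarray) -> tuple[float, np.ndarray, np.ndarray, np.ndarray]:
    """Accuracy plus per-class precision, recall and F1 from a confusion matrix.

    conf[i, j] counts samples of true class i predicted as class j.
    """
    tp = conf.diagonal().astype(float)
    predicted = conf.sum(axis=0)  # column sums: how often each class was predicted
    actual = conf.sum(axis=1)     # row sums: how often each class truly occurs
    # `where` leaves the pre-filled zeros wherever the denominator is 0.
    precision = np.divide(tp, predicted, out=np.zeros_like(tp), where=predicted > 0)
    recall = np.divide(tp, actual, out=np.zeros_like(tp), where=actual > 0)
    denom = precision + recall
    f1 = np.divide(2 * precision * recall, denom, out=np.zeros_like(tp), where=denom > 0)
    accuracy = float(tp.sum() / conf.sum())
    return accuracy, precision, recall, f1


# Toy 3-class confusion matrix; class 2 is never predicted.
conf = np.array([[5, 1, 0],
                 [2, 3, 0],
                 [1, 1, 0]])
acc, prec, rec, f1 = per_class_metrics(conf)
print(acc, prec, rec, f1)
```

A class that is never predicted gets precision and F1 of 0 here rather than NaN, which keeps averages over classes well defined.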