In [1]:
import os, pathlib, time, copy, librosa, gc

import torch
import torchaudio
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.optim import lr_scheduler

from rich_logger import RichTablePrinter
from tqdm import trange
from rich import pretty
from rich.console import Console
from IPython.display import Audio

# Notebook-wide helpers: a rich console for logging, rich pretty-printing of
# cell results, and matplotlib interactive mode.
console = Console()
pretty.install()
plt.ion()

# Pick the compute device once; later cells read both `use_gpu` and `device`.
use_gpu = torch.cuda.is_available()
device = 'cuda' if use_gpu else 'cpu'
print("Using CUDA" if use_gpu else 'Using CPU')

Using CUDA


In [2]:
class GreedyCTCDecoder(torch.nn.Module):
    """Greedy (best-path) CTC decoder turning frame-wise scores into text."""

    def __init__(self, labels, blank=0):
        super().__init__()
        # Sequence mapping a label index to its output character.
        self.labels = labels
        # Index of the CTC blank symbol, dropped from the transcript.
        self.blank = blank

    def forward(self, emission: torch.Tensor) -> str:
        """Decode one emission matrix into its best-path transcript.

        Args:
          emission (Tensor): Logit tensors. Shape `[num_seq, num_label]`.

        Returns:
          str: The resulting transcript
        """
        # Highest-scoring label for every frame.
        best_path = emission.argmax(dim=-1)  # [num_seq,]
        # Merge runs of the same label into a single occurrence.
        collapsed = torch.unique_consecutive(best_path, dim=-1)
        chars = []
        for idx in collapsed:
            # Blank symbols only separate repeats; they produce no character.
            if idx != self.blank:
                chars.append(self.labels[idx])
        return "".join(chars)

In [3]:
# Take the WAV2VEC2_ASR_BASE_960H pipeline bundle from torchaudio, build its
# model and move that model to the selected `device`.
bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
model = bundle.get_model().to(device)

In [4]:
# Load one validation clip. A CPU copy is kept for playback, while `wav` (the
# tensor later cells feed to the model) lives on `device`.
test_file = './datasets/speech-handsign_commands/speech/val/down/bf5d409d_nohash_0.wav'
wav_cpu, sr = torchaudio.load(test_file)
wav = wav_cpu.to(device)
# The player has to be the cell's final expression to be rendered: an
# `Audio(...)` created in the middle of a cell is built and then discarded.
Audio(wav_cpu.numpy()[0], rate=sr)

In [59]:
# Two validation clips for each of three command words (down / right / stop);
# the rest of this cell loads them and runs them through `model`.
down_1 = 'datasets/speech-handsign_commands/speech/val/down/c0445658_nohash_3.wav'
down_2 = 'datasets/speech-handsign_commands/speech/val/down/c4e1f6e0_nohash_1.wav'
right_1 = 'datasets/speech-handsign_commands/speech/val/right/c44d2a58_nohash_2.wav'
right_2 = 'datasets/speech-handsign_commands/speech/val/right/c6a23ff5_nohash_2.wav'
stop_1 = 'datasets/speech-handsign_commands/speech/val/stop/bf8d5617_nohash_0.wav'
stop_2 = 'datasets/speech-handsign_commands/speech/val/stop/cc6bae0d_nohash_0.wav'


def pad_seq(seq, target_len=16000):
    """Force a `[channels, frames]` waveform to exactly `target_len` frames.

    Inputs longer than `target_len` are truncated on the right; shorter ones
    are zero-padded on the right.

    Args:
        seq: 2-D array-like (NumPy array or CPU torch tensor) with frames on
            axis 1.
        target_len: Number of frames in the result. Defaults to 16000.

    Returns:
        numpy.ndarray of shape `[channels, target_len]`. The result is an
        ndarray on both branches, so callers can always pass it to
        `torch.from_numpy`.
    """
    # Convert up front so the truncate branch no longer hands back a torch
    # tensor while the pad branch hands back an ndarray.
    seq = np.asarray(seq)
    num_frames = seq.shape[1]
    if num_frames > target_len:
        return seq[:, :target_len]
    pad_width = target_len - num_frames
    return np.pad(seq, [(0, 0), (0, pad_width)], mode="constant")


def _load_clip(path):
    """Load `path`, pad/trim it with `pad_seq`, and return a tensor on `device`."""
    waveform, _ = torchaudio.load(path)
    # `torch.as_tensor` accepts either an ndarray or a tensor from `pad_seq`.
    return torch.as_tensor(pad_seq(waveform)).to(device)


def _flat_emission(waveform):
    """Run `model` on one clip and flatten the first batch item's output to 1-D."""
    with torch.inference_mode():
        emission, _ = model(waveform)
    return emission[0].squeeze(0).flatten()


wav_down_1 = _load_clip(down_1)
wav_down_2 = _load_clip(down_2)
wav_right_1 = _load_clip(right_1)
wav_right_2 = _load_clip(right_2)
wav_stop_1 = _load_clip(stop_1)
wav_stop_2 = _load_clip(stop_2)

# Every clip goes through the same helper, so each feature vector is derived
# from its own waveform. The earlier copy-pasted version computed
# `right_features_2` from the already-flattened `right_features_1`.
down_features_1 = _flat_emission(wav_down_1)
down_features_2 = _flat_emission(wav_down_2)
right_features_1 = _flat_emission(wav_right_1)
right_features_2 = _flat_emission(wav_right_2)
stop_features_1 = _flat_emission(wav_stop_1)
stop_features_2 = _flat_emission(wav_stop_2)

In [60]:
down_features_1.shape

torch.Size([1421])

In [61]:
# Cosine similarity between flattened per-clip feature vectors (dim=0 because
# each vector is 1-D).
cosim = torch.nn.CosineSimilarity(dim=0)

# "down" and "right": same-word pairs versus cross-word pairs.
sim_1 = cosim(down_features_1, down_features_2)
sim_2 = cosim(down_features_1, right_features_1)
sim_3 = cosim(right_features_1, right_features_2)
sim_4 = cosim(down_features_2, right_features_2)

# "stop" against the other two words, then against its own second clip.
sim_5 = cosim(stop_features_1, down_features_1)
sim_6 = cosim(stop_features_1, right_features_1)
sim_7 = cosim(stop_features_1, stop_features_2)

similarity_report = "Similarity\ndown_1 X down_2 : {}\ndown_1 X right_1 : {}\nright_1 X right_2 : {}\ndown_2 X right_2 : {}\nstop_1 X down_1 : {}\nstop_1 X right_1 : {}\nstop_1 X stop_2 : {}"
console.log(similarity_report.format(sim_1, sim_2, sim_3, sim_4, sim_5, sim_6, sim_7))

In [58]:
cosim(stop_features_1, stop_features_2)

tensor(0.9542, device='cuda:0')

In [5]:
# Call `model.extract_features` on the test clip without autograd tracking and
# log the shape of the first entry of its first return value.
with torch.inference_mode():
    features, _ = model.extract_features(wav)
console.log(features[0].shape)

In [6]:
# Full forward pass on the test clip without autograd tracking. The first
# return value is kept as `emission` (decoded in the next cell); the second is
# discarded.
with torch.inference_mode():
    emission, _ = model(wav)
console.log(emission.shape, emission)

In [7]:
# Build the greedy decoder from the bundle's label set, place it on `device`,
# and decode the first item of the emission batch into text.
decoder = GreedyCTCDecoder(labels=bundle.get_labels()).to(device)
transcript = decoder(emission[0])
console.log(transcript)

In [4]:
console.log(model)

In [4]:
# Freeze every pretrained component so only the layer appended below can train.
# Autograd honours the attribute `requires_grad`; assigning to a misspelled
# `require_grad` merely creates an unused Python attribute and freezes nothing.
for pretrained in (model.feature_extractor, model.encoder, model.aux):
    for param in pretrained.parameters():
        param.requires_grad = False

# Keep the existing head and stack a Conv1d on top of it to produce 8 outputs.
# The conv reads axis 1 of the head's output as 49 channels and slides a
# 29-wide kernel over the last axis.
# NOTE(review): in_channels=49 / kernel_size=29 presumably match the
# [frames, labels] output for a 16000-sample clip (49 * 29 == 1421 flattened
# values) — confirm every input is padded/trimmed to that length.
# NOTE(review): re-running this cell wraps `model.aux` again; run it once per
# freshly loaded model.
aux = [
    model.aux,
    torch.nn.Conv1d(in_channels=49, out_channels=8, kernel_size=29, stride=1, device=device),
]
model.aux = torch.nn.Sequential(*aux)

In [5]:
console.log(model)

In [21]:
# Fresh wav2vec2-base whose head is replaced by Flatten -> Linear(4096) -> Linear(8).
# NOTE(review): in_features=37632 presumably equals 49 frames * 768 encoder
# features for a 16000-sample clip — confirm.
model2 = torchaudio.models.wav2vec2_base(aux_num_out=8)
model2.aux = torch.nn.Sequential(
    torch.nn.Flatten(),
    torch.nn.Linear(in_features=37632, out_features=4096),
    torch.nn.Linear(in_features=4096, out_features=8, bias=True),
)
# Move the whole network in one step. Creating only the new Linear layers on
# `device` left the feature extractor and encoder wherever the factory put
# them, which breaks the forward pass when `device` is a GPU.
model2 = model2.to(device)

In [22]:
# Copy the weights of `model` into `model2`.
# NOTE(review): the cell above that rebuilds `model.aux` makes it a Sequential
# of the pretrained head plus a Conv1d, while `model2.aux` is Flatten plus two
# Linear layers. With those definitions the state-dict keys presumably do not
# line up, so the recorded success message likely comes from an earlier
# `model.aux` — confirm on a fresh kernel.
model2.load_state_dict(model.state_dict())

<All keys matched successfully>

In [42]:
console.log(model2)

In [6]:
from utils.audio_dataloader import MiniSpeechCommands
# from typing import Tuple, List, Dict
# from torch.utils.data import Dataset

# Root directory containing one sub-folder per split.
data_dir = './datasets/speech-handsign_commands/speech'
TRAIN, VAL, TEST = 'train', 'val', 'test'

# One dataset per split, read from `<data_dir>/<split>`.
speech_datasets = dict(
    (split, MiniSpeechCommands(os.path.join(data_dir, split)))
    for split in (TRAIN, VAL, TEST)
)

# Shuffled batches of 128 for every split, using 14 loader workers.
dataloaders = dict(
    (
        split,
        torch.utils.data.DataLoader(
            speech_datasets[split], batch_size=128, shuffle=True, num_workers=14  #os.cpu_count() = 24
        ),
    )
    for split in (TRAIN, VAL, TEST)
)

# Number of items per split and the class list reported by the training set.
dataset_sizes = dict((split, len(speech_datasets[split])) for split in (TRAIN, VAL, TEST))
class_names = speech_datasets[TRAIN].classes

for x in [TRAIN, VAL, TEST]:
    console.log("Loaded {} audios under {}".format(dataset_sizes[x], x))
console.log("Classes: ", class_names)

In [12]:
# Pull a single batch from the training loader for shape checks; the labels are discarded.
inputs, _ = next(iter(dataloaders[TRAIN]))

In [13]:
inputs.shape

torch.Size([128, 16000])

In [14]:
# Forward the sample batch through `model`. Autograd is left enabled here, so
# the result carries a grad_fn (visible in the printed tensor below).
outputs = model(inputs.to(device))

In [15]:
outputs[0].shape

torch.Size([128, 8])

In [16]:
outputs[0]

tensor([[ 0.0415,  0.0131,  0.0660,  ..., -0.0428, -0.0345,  0.1240],
        [ 0.1763,  0.0545, -0.0361,  ..., -0.0921,  0.1089,  0.1063],
        [-0.0022, -0.0502,  0.0143,  ...,  0.0030, -0.0025,  0.2192],
        ...,
        [-0.0282,  0.0153, -0.0244,  ..., -0.0377, -0.0734,  0.0102],
        [-0.0063,  0.0209, -0.0725,  ..., -0.0924,  0.1185,  0.1341],
        [ 0.0175,  0.0481,  0.0086,  ..., -0.0490,  0.0337,  0.0410]],
       device='cuda:0', grad_fn=<AddmmBackward0>)

In [17]:
torch.softmax(outputs[0], dim=1).argmax(1).shape

torch.Size([128])

In [18]:
torch.softmax(outputs[0], dim=1).argmax(1)

tensor([7, 0, 7, 7, 7, 0, 7, 7, 7, 1, 7, 0, 7, 7, 7, 6, 7, 7, 7, 6, 7, 6, 7, 7,
        7, 7, 3, 7, 2, 7, 7, 7, 0, 7, 0, 1, 2, 0, 0, 7, 0, 7, 7, 7, 7, 7, 6, 7,
        7, 7, 7, 6, 7, 0, 0, 7, 2, 7, 6, 7, 1, 7, 6, 7, 6, 7, 7, 7, 7, 7, 7, 7,
        0, 0, 7, 6, 7, 0, 7, 7, 6, 0, 6, 3, 7, 1, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
        7, 7, 3, 7, 7, 7, 7, 7, 7, 0, 7, 7, 7, 7, 3, 7, 7, 7, 7, 7, 7, 0, 6, 7,
        7, 7, 7, 7, 7, 1, 7, 1], device='cuda:0')

In [7]:
def accuracy_fn(y_true: torch.Tensor, y_pred: torch.Tensor):
    """Return the fraction of positions where `y_pred` equals `y_true`.

    The number of element-wise matches is divided by `len(y_pred)`, i.e. the
    size of its first dimension.
    """
    matches = y_true.eq(y_pred).sum().item()
    return matches / len(y_pred)

def train(model, criterion, optimizer, scheduler, dataloader):
    """Run one training epoch and return sample-averaged metrics.

    Args:
        model: Network whose forward returns a `(logits, extra)` pair.
        criterion: Loss called as `criterion(logits, class_indices)`; assumed
            to use mean reduction (the `nn.CrossEntropyLoss` default).
        optimizer: Optimizer updating the trainable parameters of `model`.
        scheduler: Learning-rate scheduler, stepped once per call, i.e. once
            per epoch.
        dataloader: Iterable of `(wav, label)` batches.

    Returns:
        dict with "train/loss" (mean loss per sample) and "train/acc"
        (fraction of correct predictions) over the samples seen this epoch.
    """
    model.train()
    loss_sum, correct, seen = 0.0, 0.0, 0
    for wav, label in dataloader:
        # Class-index targets must be integer typed for the loss.
        # assumes labels are integer class ids — TODO confirm against MiniSpeechCommands
        wav, label = wav.to(device), label.to(device).long()
        optimizer.zero_grad()
        logits, _ = model(wav)
        # A conv head leaves a trailing length axis: [batch, classes, 1] -> [batch, classes].
        if logits.dim() == 3:
            logits = logits.squeeze(2)
        # The loss is taken on the raw logits. argmax is not differentiable, so
        # a loss built from predicted ids can never send gradients to the model.
        loss = criterion(logits, label)
        loss.backward()
        optimizer.step()

        pred = logits.argmax(dim=1)
        batch_size = label.shape[0]
        # Weight per-batch means by batch size so a smaller final batch does
        # not skew the epoch averages.
        loss_sum += loss.item() * batch_size
        correct += accuracy_fn(label, pred) * batch_size
        seen += batch_size
    # Step once per epoch: stepping per batch makes StepLR(step_size=10) cut
    # the learning rate tenfold every 10 batches.
    scheduler.step()
    seen = max(seen, 1)  # an empty dataloader reports 0.0 instead of dividing by zero
    return {
        "train/loss": loss_sum / seen,
        "train/acc": correct / seen
    }

def validate(model, criterion, dataloader):
    """Evaluate `model` on a validation dataloader without gradient tracking.

    Args:
        model: Network whose forward returns a `(logits, extra)` pair.
        criterion: Loss called as `criterion(logits, class_indices)`; assumed
            to use mean reduction (the `nn.CrossEntropyLoss` default).
        dataloader: Iterable of `(wav, label)` batches.

    Returns:
        dict with "val/loss" (mean loss per sample) and "val/acc" (fraction of
        correct predictions) over the samples seen.
    """
    model.eval()
    loss_sum, correct, seen = 0.0, 0.0, 0
    with torch.inference_mode():
        for wav, label in dataloader:
            # Class-index targets must be integer typed for the loss.
            # assumes labels are integer class ids — TODO confirm against MiniSpeechCommands
            wav, label = wav.to(device), label.to(device).long()
            logits, _ = model(wav)
            # A conv head leaves a trailing length axis: [batch, classes, 1] -> [batch, classes].
            if logits.dim() == 3:
                logits = logits.squeeze(2)
            # Loss on the raw logits, not on argmax ids, so the number is a
            # real cross-entropy value.
            loss = criterion(logits, label)
            pred = logits.argmax(dim=1)
            batch_size = label.shape[0]
            # Weight per-batch means by batch size for a true per-sample average.
            loss_sum += loss.item() * batch_size
            correct += accuracy_fn(label, pred) * batch_size
            seen += batch_size
    seen = max(seen, 1)  # an empty dataloader reports 0.0 instead of dividing by zero
    return {
        "val/loss": loss_sum / seen,
        "val/acc": correct / seen
    }

def evaluate(model, criterion, dataloader):
    """Evaluate `model` on a test dataloader without gradient tracking.

    Args:
        model: Network whose forward returns a `(logits, extra)` pair.
        criterion: Loss called as `criterion(logits, class_indices)`; assumed
            to use mean reduction (the `nn.CrossEntropyLoss` default).
        dataloader: Iterable of `(wav, label)` batches.

    Returns:
        dict with "eval/loss" (mean loss per sample) and "eval/acc" (fraction
        of correct predictions) over the samples seen.
    """
    model.eval()
    loss_sum, correct, seen = 0.0, 0.0, 0
    with torch.inference_mode():
        for wav, label in dataloader:
            # Class-index targets must be integer typed for the loss.
            # assumes labels are integer class ids — TODO confirm against MiniSpeechCommands
            wav, label = wav.to(device), label.to(device).long()
            logits, _ = model(wav)
            # A conv head leaves a trailing length axis: [batch, classes, 1] -> [batch, classes].
            if logits.dim() == 3:
                logits = logits.squeeze(2)
            # Loss on the raw logits, not on argmax ids, so the number is a
            # real cross-entropy value.
            loss = criterion(logits, label)
            pred = logits.argmax(dim=1)
            batch_size = label.shape[0]
            # Weight per-batch means by batch size for a true per-sample average.
            loss_sum += loss.item() * batch_size
            correct += accuracy_fn(label, pred) * batch_size
            seen += batch_size
    seen = max(seen, 1)  # an empty dataloader reports 0.0 instead of dividing by zero
    return {
        "eval/loss": loss_sum / seen,
        "eval/acc": correct / seen
    }

In [8]:
def run_optimization_loop(num_epochs=100):
    """Train the notebook-level `model`, then evaluate it on the test split.

    Each epoch runs `train` on `dataloaders[TRAIN]` and `validate` on
    `dataloaders[VAL]`, logging both to a rich table. After the last epoch the
    model is evaluated once on `dataloaders[TEST]` and the results are logged.

    Args:
        num_epochs: Number of training epochs. Defaults to 100.

    Returns:
        tuple `(model, scheduler, optimizer)` as they stand after training.
    """
    logger_fields = {
        "epoch": {},
        "tr/loss": {
            "goal": "lower_is_better",
            "format": "{:.6f}",
            "name": r"train/loss",
        },
        "va/loss": {
            "goal": "lower_is_better",
            "format": "{:.6f}",
            "name": r"val/loss",
        },
        "tr/acc": {
            "goal": "higher_is_better",
            "format": "{:.6f}",
            "name": r"train/acc",
        },
        "va/acc": {
            "goal": "higher_is_better",
            "format": "{:.6f}",
            "name": r"val/acc",
        },
        "duration": {"format": "{:.1f}", "name": "dur(s)"},
        ".*": True,  # Any other field must be logged at the end
    }
    printer = RichTablePrinter(key="epoch", fields=logger_fields)
    printer.hijack_tqdm()

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr = 0.01, betas=(.9, .98), eps= 1e-08)
    scheduler = lr_scheduler.StepLR(optimizer, step_size= 10, gamma= .1)

    gc.collect()
    start_t = time.time()
    try:
        for epoch in trange(num_epochs):
            if use_gpu:
                torch.cuda.empty_cache()
            t = time.time()
            train_metrics = train(model, criterion, optimizer, scheduler, dataloaders[TRAIN])
            printer.log(
                {
                    "epoch": epoch,
                    "tr/loss": train_metrics['train/loss'],
                    "tr/acc": train_metrics['train/acc'],
                }
            )
            val_metrics = validate(model, criterion, dataloaders[VAL])
            printer.log(
                {
                    "epoch": epoch,
                    "va/loss": val_metrics['val/loss'],
                    "va/acc": val_metrics['val/acc'],
                    "duration": time.time() - t,
                }
            )
    finally:
        # Finalize the table printer even when training raises or is
        # interrupted from the notebook.
        printer.finalize()
    elapsed_time = time.time() - start_t
    console.log("Total Time Used : {:.0f}m {:.0f}s".format(elapsed_time // 60, elapsed_time % 60))
    start_e = time.time()
    eval_metrics = evaluate(model, criterion, dataloaders[TEST])
    elapsed_time = time.time() - start_e
    console.log("Average loss on Test set : {:.6f}".format(eval_metrics['eval/loss']))
    console.log("Average accuracy on Test set : {:.6f}".format(eval_metrics['eval/acc']))
    console.log("Evaluation completed in {:.0f}m {:.0f}s".format(elapsed_time // 60, elapsed_time % 60))
    return model, scheduler, optimizer

In [None]:
# Run training followed by the test-set evaluation; keeps the returned
# scheduler and optimizer and rebinds `model` to the returned model.
model, scheduler, optimizer = run_optimization_loop()

In [15]:
# Classify the test clip. The model is switched to eval mode and the forward
# pass runs without autograd, so the prediction does not depend on the mode an
# earlier cell left the model in and no graph is built.
model.eval()
with torch.inference_mode():
    out, _ = model(wav)
pred = torch.softmax(out, dim=1).argmax(dim=1).float()
console.log(pred)