# Setup

In [None]:
from torchaudio.models.decoder import ctc_decoder
from torchaudio.utils import download_asset
from pytorch_model_summary import summary
from torch.nn import functional as F
from torch.utils.data import Dataset
import torchaudio.functional as AF
from torch import optim
import torch.nn as nn
import torchaudio
import torch

from torchaudio.utils import download_asset
import matplotlib.ticker as ticker
import matplotlib.pyplot as plt
from playsound import playsound
from natsort import natsorted
from typing import List
import numpy as np
import IPython
import time
import math
import os

In [None]:
# Show the installed torch and torchaudio versions, one per line.
print(torch.__version__, torchaudio.__version__, sep="\n")

In [None]:
# Select matplotlib's 'agg' backend for all figures created below.
plt.switch_backend('agg')

In [None]:
# Fix the torch RNG seed so runs are repeatable.
torch.random.manual_seed(0)
# Prefer the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(device)

In [None]:
# Load the torchaudio WAV2VEC2_ASR_BASE_10M pipeline bundle and build its model.
bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_10M
acoustic_model = bundle.get_model()

# The bundle's sample rate is the rate audio is resampled to in later cells.
print("Sample Rate:", bundle.sample_rate)

# The label list defines the output vocabulary used for decoding.
print("Labels:", bundle.get_labels())

# Collect Data

In [None]:
# List the recordings in natural sort order (presumably so that file i pairs
# with line i of sentences.txt -- verify against the dataset).
speech_files = natsorted(
    os.listdir("/media/nathanmon/389E28739E282BB6/Users/Natha/Datasets/MySpeechData/my_voice")
)

# One transcript per line; each entry keeps its trailing newline.
with open("/media/nathanmon/389E28739E282BB6/Users/Natha/Datasets/MySpeechData/sentences.txt", "r", encoding="utf-8") as f:
    sentences = f.readlines()

In [None]:
# Number of transcript lines that were read.
len(sentences)

In [None]:
# i = 0
# while i < len(sentences):
#     if len(sentences[i]) < 100:
#         sentences.pop(i)
#         speech_files.pop(i)
#         i -= 1
#     i+= 1

In [None]:
# Transcript count again (unchanged while the filter cell above stays commented out).
len(sentences)

In [None]:
# Vocabulary accepted in a transcription; a character's position in this list
# is its integer id.
characters = list("abcdefghijklmnopqrstuvwxyz-|'")


def char_to_num(sentence):
    """Return the vocabulary index of every character of ``sentence``.

    Characters that are not in ``characters`` are silently dropped.
    """
    return [characters.index(ch) for ch in sentence if ch in characters]


def num_to_char(sentence):
    """Return the characters for a sequence of ids.

    Ids greater than or equal to the vocabulary size are skipped.
    """
    return [characters[idx] for idx in sentence if idx < len(characters)]

In [None]:
# Directory that holds the recorded .wav files; file names are appended to it when loading.
base = "/media/nathanmon/389E28739E282BB6/Users/Natha/Datasets/MySpeechData/my_voice"

In [None]:
# Rate this scan resamples to, and the fixed waveform length (in samples)
# that load_wav pads or crops to.
sample_rate = 8000
max_length = 85000

# Walk every recording: print the index and name of any file whose native
# rate is not 16000 Hz, and count the clips longer than 85000 samples after
# resampling to `sample_rate`.
# NOTE(review): the length check uses the literal 85000 rather than
# `max_length`; the two are equal here but can drift apart.
count = 0
for i, file in enumerate(speech_files):
    wav, sr = torchaudio.load(base + "/" + file)
    wav = AF.resample(wav, sr, sample_rate)
    if sr != 16000:
        print(i)
        print(file)
    
    if len(wav[0]) > 85000:
        count += 1

# Display the number of over-length clips.
count

In [None]:
# Number of mel filterbanks (passed as `n_mels` to MelSpectrogram).
n_mels = 128
# Window length in samples (passed as `win_length` to MelSpectrogram).
win_length = 160
# Number of samples between successive frames (passed as `hop_length`).
# The FFT size is not set in this notebook, so MelSpectrogram's default applies.
hop_length = 80

def load_wav(filename):
    """Load one recording as a fixed-length, standardised waveform.

    The file is read from ``base``, resampled to ``bundle.sample_rate`` when
    its own rate differs, reduced to its first channel, zero-padded at the end
    or centre-cropped to exactly ``max_length`` samples, and finally scaled to
    zero mean and unit variance.

    Args:
        filename: Name of the audio file inside ``base``.

    Returns:
        tuple: ``(wav, sr)`` where ``wav`` has shape ``(1, max_length)`` and
        ``sr`` is the sample rate of the returned waveform.
    """
    wav, sr = torchaudio.load(base + "/" + filename)
    # Compare against the rate read from this file. The notebook-level
    # `sample_rate` variable is reassigned by other cells and says nothing
    # about the file being loaded.
    if sr != bundle.sample_rate:
        wav = torchaudio.functional.resample(wav, sr, bundle.sample_rate)
        sr = bundle.sample_rate

    samples = wav[0]
    n_samples = len(samples)
    if n_samples < max_length:
        samples = torch.cat((samples, torch.zeros(max_length - n_samples)))
    else:
        # Centre crop. Slicing `start:start + max_length` keeps the length
        # exact even when the surplus is odd, so batches can be concatenated.
        start = (n_samples - max_length) // 2
        samples = samples[start:start + max_length]
    wav = samples.unsqueeze(-2)

    mean = wav.mean()
    std = wav.std()
    # A constant (e.g. silent) clip has zero std; clamp to avoid NaNs.
    wav = (wav - mean) / torch.clamp(std, min=1e-8)

    return wav, sr

def create_spect(wav, sr):
    """Return the log mel spectrogram of ``wav``.

    Uses the notebook-level ``n_mels``, ``win_length`` and ``hop_length``
    settings.

    Args:
        wav: Waveform tensor to transform.
        sr: Sample rate of ``wav`` in Hz.

    Returns:
        Tensor with the natural log of the mel spectrogram.
    """
    to_mel = torchaudio.transforms.MelSpectrogram(
        sample_rate=sr,
        n_mels=n_mels,
        win_length=win_length,
        hop_length=hop_length,
    )
    mel = to_mel(wav)
    # torch.log keeps the result a torch tensor on the input's device instead
    # of round-tripping through NumPy; the epsilon avoids log(0).
    return torch.log(mel + 1e-14)

def process_text(label):
    """Convert a transcript into the list of token ids used as a CTC target.

    The text is lower-cased; spaces, hyphens/dashes, semicolons and colons are
    all treated as word boundaries (``|``). Each word is encoded with
    ``char_to_num`` (which drops characters outside the vocabulary), words
    that end up empty are discarded, and the remaining words are joined by a
    single boundary token, so punctuation next to a space never produces a
    run of consecutive boundaries.

    Args:
        label: Raw transcript text.

    Returns:
        list: Token ids for the transcript.
    """
    label = label.lower()
    # The spaced dash goes first: once spaces have become "|" the " -- "
    # pattern can no longer match.
    for boundary in (" -- ", " ", "-", ";", ":"):
        label = label.replace(boundary, "|")

    encoded_words = [char_to_num(word) for word in label.split("|")]
    separator = char_to_num("|")

    ids = []
    for word in encoded_words:
        if not word:
            continue
        if ids:
            ids.extend(separator)
        ids.extend(word)
    return ids

def encode_sample(file, label):
    """Prepare one (audio file, transcript) pair for the model.

    Args:
        file: Audio file name, forwarded to ``load_wav``.
        label: Transcript text, forwarded to ``process_text``.

    Returns:
        tuple: ``(waveform, token_ids)``; the sample rate reported by
        ``load_wav`` is discarded.
    """
    waveform, _ = load_wav(file)
    return waveform, process_text(label)

In [None]:
# Inspect one sample: print its file name and transcript.
index = 0
file = speech_files[index]
label = sentences[index]
print("Speech file: \n", file)
print("\nSentence: \n", label)

# Load the waveform and compute its log mel spectrogram; [0] selects the
# first entry along the leading dimension.
ex_wav, sr = load_wav(file)
ex_spect = create_spect(ex_wav, sr)[0]
print("Sample rate: \n", sr)
print("\nShape: \n", ex_spect.shape)
print()

# Render the spectrogram as an image with the colour scale capped at 1
# and the axes hidden.
ax = plt.subplot(1, 1, 1)
ax.imshow(ex_spect, vmax=1)
ax.axis("off")

In [None]:
class WaveformDataset(Dataset):
    """Serves pre-batched waveform/label tensors from parallel lists.

    Batches are requested explicitly through ``getitem(batch_index)``;
    ``len()`` reports the number of samples, not the number of batches.

    Args:
        speech_files: Audio file names, passed one by one to ``encode_sample``.
        labels: Transcripts, parallel to ``speech_files``.
        batch_size: Number of samples returned per ``getitem`` call.
        max_length: Fixed label length; shorter labels are right-padded with
            0 and longer ones are truncated.
    """

    def __init__(self, speech_files, labels,
                 batch_size=16, max_length=150):
        self.filenames = speech_files
        self.labels = labels
        self.batch_size = batch_size
        self.max_length = max_length

    def __len__(self):
        """Return the number of samples (not batches)."""
        return len(self.filenames)

    def getitem(self, idx):
        """Build batch number ``idx``.

        Args:
            idx: Zero-based batch index.

        Returns:
            tuple: ``(waves, labels)`` where ``waves`` is the concatenation of
            the encoded waveforms along dim 0 and ``labels`` is an integer
            tensor with one row of ``max_length`` ids per sample.

        Raises:
            IndexError: If ``idx`` addresses a batch past the end of the data.
        """
        # Slice with the instance's own batch size so the dataset does not
        # depend on whatever a notebook-level `batch_size` currently holds.
        start_idx = idx * self.batch_size
        stop_idx = start_idx + self.batch_size
        batch_files = self.filenames[start_idx:stop_idx]
        batch_labels = self.labels[start_idx:stop_idx]
        if len(batch_files) == 0:
            raise IndexError(f"batch index {idx} is out of range")

        waves, labels = [], []
        for filename, label in zip(batch_files, batch_labels):
            wav, label = encode_sample(filename, label)
            # as_tensor avoids the copy (and warning) of torch.tensor(tensor).
            waves.append(torch.as_tensor(wav, dtype=torch.float32))

            # Truncate, then right-pad with 0 up to the fixed length.
            label = label[:self.max_length]
            label = label + [0] * (self.max_length - len(label))
            labels.append(label)

        waves = torch.cat(waves, dim=0)
        labels = torch.tensor(labels)
        return waves, labels

In [None]:
# Build a dataset over every recording for a quick sanity check, 16 samples per batch.
batch_size = 16
dataloader = WaveformDataset(speech_files, sentences, 
                                   batch_size=batch_size)

In [None]:
# Show the padded label tensor of the first batch.
dataloader.getitem(0)[1]

# Test

In [None]:
# Fetch two torchaudio tutorial recordings.
# NOTE(review): SPEECH_FILE is not referenced again in the visible cells.
SPEECH_FILE = download_asset("tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav")

speech_file = download_asset("tutorial-assets/ctc-decoding/1688-142285-0007.wav")

# Embed an audio player for the second recording.
IPython.display.Audio(speech_file)

In [None]:
# Lower-cased copy of the bundle's label list; index i is the model's output class i.
tokens = [label.lower() for label in bundle.get_labels()]

In [None]:
# Replace the tutorial clip with the first recording of the local dataset.
speech_file = "/media/nathanmon/389E28739E282BB6/Users/Natha/Datasets/MySpeechData/my_voice/0.wav"

In [None]:
# Embed an audio player for the selected recording.
IPython.display.Audio(speech_file)

In [None]:
# Load the recording. NOTE: this rebinds the notebook-level `sample_rate`
# (set to 8000 in an earlier cell) to the file's own rate.
waveform, sample_rate = torchaudio.load(speech_file)

# Bring the audio to the rate reported by the bundle.
if sample_rate != bundle.sample_rate:
    waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)

In [None]:
def char_to_num(sentence):
    """Return the index in ``tokens`` of every character of ``sentence``.

    Characters that are not in ``tokens`` are dropped.
    """
    return [tokens.index(ch) for ch in sentence if ch in tokens]

def num_to_char(sentence):
    """Return the ``tokens`` entries for a sequence of ids.

    Ids greater than or equal to ``len(tokens)`` are skipped.
    """
    return [tokens[idx] for idx in sentence if idx < len(tokens)]

In [None]:
def CTCLoss(y_true, y_pred):
    """Compute the CTC loss for a batch of zero-padded targets.

    Args:
        y_true: Integer tensor of shape ``(batch, max_label_len)``. Rows are
            right-padded with 0, which is also the CTC blank index.
        y_pred: Log-probabilities of shape ``(time, batch, classes)``.

    Returns:
        Scalar loss tensor (``nn.CTCLoss`` with its default reduction).
    """
    blank = 0
    batch_len, max_label_len = y_true.shape
    n_frames = y_pred.shape[0]

    # Every sequence in the batch uses all time steps.
    input_length = torch.full((batch_len,), n_frames, dtype=torch.int64)

    # Count only the real tokens: trailing zeros are padding, and reporting
    # the padded width would make the blank index part of the target.
    trailing_pad = (y_true.flip(1) == blank).long().cumprod(dim=1).sum(dim=1)
    label_length = max_label_len - trailing_pad

    criterion = nn.CTCLoss(blank=blank, zero_infinity=True)

    loss = criterion(y_pred, y_true, input_length, label_length)

    return loss

In [None]:
def get_n_params(model):
    """Return the total number of scalar parameters in ``model``.

    Args:
        model: Any object exposing ``parameters()`` that yields tensors.

    Returns:
        int: Sum of the element counts of all parameters.
    """
    # numel() replaces the hand-rolled product, whose local `nn` shadowed the
    # imported torch.nn module inside this function.
    return sum(p.numel() for p in model.parameters())
# Report the size of the acoustic model.
print("Num params: ", get_n_params(acoustic_model))

In [None]:
# Alternative reference (the transcript of the torchaudio tutorial clip):
# actual_transcript = "i really was very much afraid of showing him how much shocked i was at some parts of what he said"
actual_transcript = "Many non-infectious diseases have a partly or completely genetic basis and may thus be transmitted from one generation to another"
# Encode through process_text so the text is lower-cased and word boundaries
# become "|"; raw char_to_num drops the capital letter and the spaces and
# maps "-" to index 0, the CTC blank.
label = torch.tensor([process_text(actual_transcript)])
# The decoder vocabulary (`tokens`) is lower-case, so compare in lower case.
actual_transcript = actual_transcript.lower().split()

# `waveform` is already a tensor; wrapping it in torch.tensor() only copies it.
out, hidden = acoustic_model(waveform)
out = F.log_softmax(out, dim=2)
# nn.CTCLoss expects (time, batch, classes).
logits = out.transpose(0, 1)
loss = CTCLoss(label, logits)
loss

In [None]:
# Quick greedy read-out: take the arg-max class per frame, merge consecutive
# repeats, and map the ids back to characters.
''.join(num_to_char(torch.unique_consecutive(torch.argmax(logits, 2))))

In [None]:
# Display the bound `parameters` method (not called here).
acoustic_model.parameters

In [None]:
# Freeze every parameter of the acoustic model.
for param in acoustic_model.parameters():
    param.requires_grad = False

In [None]:
# Re-enable gradients for the `aux` layer only, so it is the sole trainable part.
acoustic_model.aux.weight.requires_grad = True
acoustic_model.aux.bias.requires_grad = True

In [None]:
# List each parameter with its trainable flag to confirm the freeze.
for name, param in acoustic_model.named_parameters():
    print(f"{name}: {param.requires_grad}")

In [None]:
class GreedyCTCDecoder(torch.nn.Module):
    """Best-path CTC decoder.

    Args:
        labels: Sequence mapping a class index to its character.
        blank: Index of the CTC blank class.
    """

    def __init__(self, labels, blank=0):
        super().__init__()
        self.labels = labels
        self.blank = blank

    def forward(self, emission: torch.Tensor) -> List[str]:
        """Decode one emission matrix into a list of words.

        Args:
          emission (Tensor): Logit tensors. Shape `[num_seq, num_label]`.

        Returns:
          List[str]: The resulting transcript
        """
        # Most likely class per frame, with consecutive repeats merged.
        best_path = torch.argmax(emission, dim=-1)  # [num_seq,]
        collapsed = torch.unique_consecutive(best_path, dim=-1)
        # Drop blanks and spell out the remaining classes.
        text = "".join(self.labels[i] for i in collapsed if i != self.blank)
        # "|" marks a word boundary.
        return text.replace("|", " ").strip().split()


# Decoder over the lower-cased bundle vocabulary.
greedy_decoder = GreedyCTCDecoder(tokens)

In [None]:
# Decode the first sequence of the batch and score it against the reference.
greedy_result = greedy_decoder(out[0])
greedy_transcript = " ".join(greedy_result)
# Word error rate: word-level edit distance divided by the reference length.
greedy_wer = torchaudio.functional.edit_distance(actual_transcript, greedy_result) / len(actual_transcript)

print(f"Transcript: {greedy_transcript}")
print(f"WER: {greedy_wer}")

# Train Model

In [None]:
def train_epoch(dataloader, model, optimizer, train=True):
    """Run one pass over ``dataloader`` and return the mean CTC loss per batch.

    Args:
        dataloader: Object exposing ``batch_size``, ``__len__`` (number of
            samples) and ``getitem(batch_index)`` returning
            ``(waveform, labels)``.
        model: Callable returning ``(logits, hidden)`` for a waveform batch.
        optimizer: Optimizer stepped after each batch when ``train`` is True.
        train: When False the pass only evaluates; no gradients are computed
            and the optimizer is not touched.

    Returns:
        float: Average loss over the full batches that were processed, or
        ``0.0`` when the data holds less than one full batch.
    """
    # Use the loader's own batch size so the batch count always matches the
    # slicing done inside the loader; a trailing partial batch is skipped.
    batch_size = dataloader.batch_size
    n_batches = len(dataloader) // batch_size

    total_loss = 0
    for batch in range(n_batches):
        waveform, labels = dataloader.getitem(batch)

        # Skip autograd bookkeeping on the evaluation pass.
        with torch.set_grad_enabled(train):
            logits, hidden = model(waveform)  # (B, N, C)
            logits = F.log_softmax(logits, dim=2)
            logits = logits.transpose(0, 1)  # (N, B, C)
            loss = CTCLoss(labels, logits)

        if train:
            optimizer.zero_grad()
            loss.backward()
            # Clip after backward(): before it there are no fresh gradients
            # to clip, so the call would have no effect on the update.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

        total_loss += loss.item()

    if n_batches == 0:
        return 0.0
    # Divide by the number of batches actually summed.
    return total_loss / n_batches

In [None]:
def asMinutes(s):
    """Format a duration given in seconds as ``'<minutes>m <seconds>s'``."""
    whole_minutes = math.floor(s / 60)
    leftover_seconds = s - whole_minutes * 60
    return '%dm %ds' % (whole_minutes, leftover_seconds)


def timeSince(since, percent):
    """Describe elapsed and estimated remaining time.

    Args:
        since: Start time as returned by ``time.time()``.
        percent: Fraction of the work completed so far (0 < percent <= 1).

    Returns:
        str: ``'<elapsed> (- <remaining>)'`` with both parts formatted by
        ``asMinutes``; the remaining time is a linear extrapolation.
    """
    elapsed = time.time() - since
    projected_total = elapsed / (percent)
    remaining = projected_total - elapsed
    return '%s (- %s)' % (asMinutes(elapsed), asMinutes(remaining))

In [None]:
# Loss histories, extended by every call to train().
plot_train_losses = []
plot_val_losses = []

def train(train_dataloader, val_dataloader, model, n_epochs, learning_rate=0.001,
          print_every=100, plot_every=100):
    """Train ``model`` for ``n_epochs`` epochs, validating after each one.

    Uses AdamW and halves the learning rate when the validation loss has not
    improved for 6 epochs (``ReduceLROnPlateau``). Averages of the train and
    validation losses are printed every ``print_every`` epochs and appended to
    ``plot_train_losses`` / ``plot_val_losses`` every ``plot_every`` epochs;
    both histories are plotted with ``showPlot`` at the end.

    Args:
        train_dataloader: Data passed to ``train_epoch`` for the training pass.
        val_dataloader: Data passed to ``train_epoch`` with ``train=False``.
        model: Model to optimise.
        n_epochs: Number of epochs to run.
        learning_rate: Initial AdamW learning rate.
        print_every: Epoch interval between progress lines.
        plot_every: Epoch interval between recorded plot points.
    """
    start = time.time()

    optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.50, patience=6)

    # Running sums, reset each time they are reported.
    train_sum_print = val_sum_print = 0
    train_sum_plot = val_sum_plot = 0

    for epoch in range(1, n_epochs + 1):
        train_loss = train_epoch(train_dataloader, model, optimizer)
        # Evaluate validation dataloader
        val_loss = train_epoch(val_dataloader, model, optimizer, train=False)
        scheduler.step(val_loss)

        train_sum_print += train_loss
        train_sum_plot += train_loss
        val_sum_print += val_loss
        val_sum_plot += val_loss

        if epoch % print_every == 0:
            train_avg = train_sum_print / print_every
            val_avg = val_sum_print / print_every
            train_sum_print = val_sum_print = 0
            progress = epoch / n_epochs
            print('%s (%d %d%%) %.4f %.4f' % (
                timeSince(start, progress), epoch, progress * 100,
                train_avg, val_avg))
            print()

        if epoch % plot_every == 0:
            plot_train_losses.append(train_sum_plot / plot_every)
            plot_val_losses.append(val_sum_plot / plot_every)
            train_sum_plot = val_sum_plot = 0

    showPlot(plot_train_losses)
    showPlot(plot_val_losses)

In [None]:
def CTCLoss(y_true, y_pred):
    """Compute the training-time CTC loss for zero-padded targets.

    Args:
        y_true: Integer tensor of shape ``(batch, max_label_len)``; rows are
            right-padded with 0, the default blank index of ``nn.CTCLoss``.
        y_pred: Log-probabilities of shape ``(time, batch, classes)``.

    Returns:
        Scalar loss tensor (``nn.CTCLoss`` with its default reduction).
    """
    batch_len, max_label_len = y_true.shape
    n_frames = y_pred.shape[0]

    # All sequences span the full time axis.
    input_length = torch.full((batch_len,), n_frames, dtype=torch.int64)

    # len() of a padded row is always the padded width, so measure the real
    # length instead: the width minus the run of trailing zeros. Padding must
    # not be counted because 0 is the blank class.
    trailing_pad = (y_true.flip(1) == 0).long().cumprod(dim=1).sum(dim=1)
    label_length = max_label_len - trailing_pad

    criterion = nn.CTCLoss(zero_infinity=True)
    loss = criterion(y_pred, y_true, input_length, label_length)

    return loss

In [None]:
def showPlot(points):
    """Plot ``points`` as a line on a new figure with y ticks every 0.2.

    Args:
        points: Sequence of values to plot against their index.
    """
    # plt.subplots() already creates the figure; a preceding plt.figure()
    # would leave an extra, empty figure open on every call.
    fig, ax = plt.subplots()
    # this locator puts ticks at regular intervals
    ax.yaxis.set_major_locator(ticker.MultipleLocator(base=0.2))
    ax.plot(points)

In [None]:
# bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
# acoustic_model = bundle.get_model()

In [None]:
batch_size = 32
# The first 4000 samples are used for training; the remainder for validation.
train_dataloader = WaveformDataset(speech_files[:4000], sentences[:4000], 
                                   batch_size=batch_size)
val_dataloader = WaveformDataset(speech_files[4000:], sentences[4000:], 
                                   batch_size=batch_size)

# Fine-tune for 5 epochs, printing and recording the losses after every epoch.
train(train_dataloader, val_dataloader, acoustic_model, 5, 
      learning_rate=1e-5, print_every=1, plot_every=1)

In [None]:
# Make every parameter trainable again (undoes the earlier freeze).
for param in acoustic_model.parameters():
    param.requires_grad = True

In [None]:
# Persist the model weights (state dict only) to the working directory.
torch.save(acoustic_model.state_dict(), "my_speech_recognition.pth")