This notebook implements a version of Facebook's Wav2Letter Model, described in https://arxiv.org/pdf/1609.03193.pdf. 

# Step 1: Import necessary libraries and download + visualize the dataset

## Step 1.1: Install libraries and import things


In [None]:
# Install jiwer, which provides the word-error-rate metric used in eval_model.
# NOTE(review): `Levenshtein` and `torchsummary` are also imported below; install
# them here as well if the runtime does not already provide them.
!pip install jiwer

In [None]:
# Mount Google Drive at /content/drive. PATH (step 1.2) points inside it, so the
# dataset, lexicon/token files, cached train/test sets and model checkpoints are stored on Drive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import torch
import os
import matplotlib.pyplot as plt

import torchaudio
from torchaudio.datasets.librispeech import LIBRISPEECH
import torchaudio.transforms as transforms
from torchaudio.models.decoder import ctc_decoder 
from torchaudio.models import wav2letter

import numpy as np
from numpy import random

from IPython.display import Audio, display

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.nn.modules.loss import CTCLoss

from torchsummary import summary
from jiwer import wer # https://pypi.org/project/jiwer/
from Levenshtein import distance # https://maxbachmann.github.io/Levenshtein/levenshtein.html#distance

from datetime import datetime
from enum import unique

In [None]:
print(f"PyTorch Version:  {torch.__version__}")

# Use the first CUDA device when PyTorch can see one; otherwise run on the CPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if DEVICE.type == "cuda":
    print("Using the GPU!")
else:
    print("WARNING: Could not find GPU! Using CPU only. If you want to enable GPU, please to go Edit > Notebook Settings > Hardware Accelerator and select GPU.")

# Fix the NumPy and PyTorch (CPU + every CUDA device) RNG seeds so runs are reproducible.
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)

## Step 1.2: Global Variables

In [None]:
# Root folder on Google Drive that holds the dataset, lexicon/token files,
# cached train/test sets and model checkpoints.
PATH = "/content/drive/MyDrive/EECS 442 Final Project: Wav2Letter/Code/"
LEXICON_PATH = f"{PATH}/lexicon.txt"
TOKENS_PATH = f"{PATH}/tokens.txt"
COMPLETE_DATASET_PATH = f"{PATH}/librispeech"
TRAIN_SET_PATH = f"{PATH}/train_set"
TEST_SET_PATH = f"{PATH}/test_set"

# True: re-parse and re-preprocess the dataset (and re-save the cached sets).
# False: skip that work and load the cached train/test sets instead.
RERUN_DATALOADS = True

# Other global variables defined elsewhere:
#   DEVICE: "cuda:0" or "cpu", defined in 1.1
#   COMPLETE_DATASET: original Librispeech dataset
#   TRAIN_SET: training set ready to pass into model
#   TEST_SET: test set ready to pass into model

## Step 1.3: Download Librispeech Dataset 

In [None]:
# Download the Librispeech dataset (dev-clean portion)
# makedirs also creates any missing parent folders and is a no-op if the folder exists
# (os.mkdir would fail when e.g. the Code/ folder has not been created yet).
os.makedirs(COMPLETE_DATASET_PATH, exist_ok=True)

COMPLETE_DATASET = torchaudio.datasets.LIBRISPEECH(COMPLETE_DATASET_PATH, url='dev-clean', download=True)

## Step 1.4: Define utility functions for visualizing data

In [None]:
# These functions are from https://pytorch.org/audio/stable/tutorials/audio_datasets_tutorial.html 
def plot_specgram(waveform, sample_rate, title="Spectrogram", xlim=None):
    """
    Draw one spectrogram subplot per channel of an audio waveform.

    Args:
        waveform: torch.Tensor of shape [n_channels, time]
        sample_rate: sampling rate in Hz, forwarded to Axes.specgram as Fs
        title: figure-level title
        xlim: optional x-axis limits applied to every subplot
    """
    channels = waveform.numpy()
    n_ch, _ = channels.shape

    fig, ax_list = plt.subplots(n_ch, 1)
    # plt.subplots returns a bare Axes (not an array) for a single row.
    if n_ch == 1:
        ax_list = [ax_list]

    for ch_idx, (ax, signal) in enumerate(zip(ax_list, channels)):
        ax.specgram(signal, Fs=sample_rate)
        if n_ch > 1:
            ax.set_ylabel(f"Channel {ch_idx+1}")
        if xlim:
            ax.set_xlim(xlim)

    fig.suptitle(title)
    plt.show(block=False)

def play_audio(waveform, sample_rate):
    """
    Display an interactive IPython audio player for a waveform.

    Args:
        waveform: torch.Tensor of shape [n_channels, time]; mono or stereo
        sample_rate: sampling rate in Hz

    Raises:
        ValueError: if the waveform does not have exactly 1 or 2 channels
    """
    samples = waveform.numpy()
    n_ch, _ = samples.shape

    if n_ch not in (1, 2):
        raise ValueError("Waveform with more than 2 channels are not supported.")

    audio_data = samples[0] if n_ch == 1 else (samples[0], samples[1])
    display(Audio(audio_data, rate=sample_rate))

def print_sample(idx: int, spectrogram=False):
  """
  Play the audio and print the transcript of COMPLETE_DATASET[idx] (the global
  loaded in step 1.3), optionally plotting its spectrogram.

  Args:
    idx: index into the original (un-preprocessed) dataset
    spectrogram: when True, also draw the spectrogram via plot_specgram
  """
  waveform, sample_rate, transcript = COMPLETE_DATASET[idx][:3]

  print("Sample %d from %s:" % (idx, COMPLETE_DATASET_PATH))
  play_audio(waveform, sample_rate)
  print(transcript)

  if spectrogram:
    plot_specgram(waveform, sample_rate, title="Spectrogram For Sample %d" % idx)

## Step 1.5: Define utility functions for retrieving lexicon and tokens in dataset

In [None]:
def lexiconToLexFile(lexicon, path):
  """
  Given a list of words, creates a lexicon file at path from those words.

  Each line of the file contains a word, followed by a space-separated spelling
  that is terminated by the "silence" character, which is designated as a pipe |.

  For instance:
      a a |
      able a b l e |
      about a b o u t |
      ...

  Args:
    lexicon: iterable of words (str)
    path: output file path; an existing file is overwritten
  """
  # `with` closes the file even if a write fails part-way.
  with open(path, "w") as f:
    for word in lexicon:
      # "<word> <c1> <c2> ... |": spelling letter by letter, ended by the silence token
      f.write(" ".join([word, *word, "|"]) + "\n")

def tokensToTokenFile(tokens, path):
  """
  Given a list of tokens, writes them to a token file at path, one per line,
  in the given order.

  The space token is written as the silence token, a pipe (|). The blank
  token (-) used by CTC decoding is written only if the caller included it in
  `tokens`. The line order should match the model's class ids (see charToClass).

  Args:
    tokens: iterable of single-character tokens
    path: output file path; an existing file is overwritten
  """
  with open(path, "w") as f:
    for token in tokens:
      # CTC decoding uses | as the word boundary instead of a literal space.
      f.write(("|" if token == ' ' else token) + "\n")

def parseDataset(dataset, n_samples=10000):
  """
  Collect the character set, word set and mean audio duration of a dataset.

  Args:
    dataset: indexable dataset whose items start with
             (waveform [1, time], sample_rate, transcript, ...)
    n_samples: maximum number of leading samples to scan

  Returns:
    tokens: set of characters appearing in the scanned transcripts
    lexicon: set of space-separated words appearing in the scanned transcripts
    average_len: mean duration in seconds of the scanned samples
                 (0.0 if no samples were scanned)
  """
  tokens = set()
  lexicon = set()
  len_sum = 0

  n_samples = min(n_samples, len(dataset))

  for i in range(n_samples):
    waveform, sample_rate, transcript = dataset[i][:3]
    len_sum += waveform.shape[1] / sample_rate

    lexicon.update(transcript.split(' '))
    tokens.update(transcript)

    # progress indicator
    if i % 100 == 0:
      print(i)

  # Average over the samples actually scanned, not over the whole dataset.
  average_len = len_sum / n_samples if n_samples > 0 else 0.0

  return tokens, lexicon, average_len

## Step 1.6: Visualize dataset and create lexicon, token files

In [None]:
"""
The original Dataset is an array of tuples, each like:
    waveform:     Tensor [1, time]  audio
    sample_rate:  int               sampling rate, usually 16000
    utterance:    str               transcript
    speaker_id:   int     
    chapter_id:   int
    utterance_id: int
https://pytorch.org/audio/stable/generated/torchaudio.datasets.LIBRISPEECH.html#librispeech 
"""

if RERUN_DATALOADS:
  tokens, lexicon, average_len = parseDataset(COMPLETE_DATASET)
  # Add the CTC blank '-' and sort. Assuming transcripts contain only space,
  # apostrophe and A-Z, the sorted order (' ', "'", '-', 'A'..'Z') lines up with
  # the class ids used by charToClass.
  tokens = np.sort(np.append(list(tokens), '-'))
  lexicon = np.sort(list(lexicon))

  print("%d audio samples retrieved from dataset with average length of %.3f seconds" % (len(COMPLETE_DATASET), average_len))
  print("%d unique characters appear in the transcriptions:" % len(tokens))
  print(tokens)
  print("%d unique words appear in the transcriptions." % len(lexicon))

  # Print a random sample from the dataset to visualize
  idx = np.random.randint(0, len(COMPLETE_DATASET))
  print_sample(idx, spectrogram=True)

  # eval_model's ctc_decoder reads these files. Create them when missing so a
  # fresh run works end to end, without overwriting files that already exist.
  if not os.path.exists(LEXICON_PATH):
    lexiconToLexFile(lexicon, LEXICON_PATH)
  if not os.path.exists(TOKENS_PATH):
    tokensToTokenFile(tokens, TOKENS_PATH)
else:
  print("Dataloads not rerun. Set RERUN_DATALOADS option in 1.2 to rerun.")

# Step 2: Preprocess data and put into DataLoaders 

## Step 2.1: Define encoding from characters to class labels

We use 29 classes in our encoding scheme: the 28 characters (tokens) that appear in the transcripts, plus the blank character required by the CTC decoder. Class ids follow the sorted order of the tokens file (space, apostrophe, blank, A–Z):

* 0: space (represented as | for CTC decoding, but as a space everywhere else)
* 1: apostrophe '
* 2: blank character (-), used only by CTC
* 3-28: uppercase English letters (A - Z)


In [None]:
def charToClass(c: str):
  """
  Map a single-character string to its class id in 0 - 28:
  space -> 0, apostrophe -> 1, CTC blank '-' -> 2, 'A'..'Z' -> 3..28.
  Any other character is mapped by its offset from 'A' (no validation).
  """
  specials = {' ': 0, '\'': 1, '-': 2}
  if c in specials:
    return specials[c]
  return ord(c) - ord('A') + 3

def classToChar(id: int):
  """
  Inverse of charToClass: map a class id back to a one-character string.
  0 -> space, 1 -> apostrophe, 2 -> CTC blank '-', 3..28 -> 'A'..'Z'.
  Also accepts integer 0-d tensors (e.g. argmax results).
  """
  if id == 0:
    return ' '
  if id == 1:
    return "'"
  if id == 2:
    return '-'
  return chr(id - 3 + ord('A'))

def transcriptToTensor(transcript: str):
  """
  Given a transcript (represented as a string), returns a one-dimensional
  int8 tensor on DEVICE holding the class value (see charToClass) of each character.
  """
  # Encode on the host and create the tensor in a single call. Assigning
  # element by element into a DEVICE tensor would issue a separate device
  # write per character.
  return torch.tensor([charToClass(c) for c in transcript], dtype=torch.int8, device=DEVICE)

def tensorToTranscript(t: torch.Tensor, valid_len=None):
  """
  Given a 1D Tensor of class values, returns the corresponding string by using
  classToChar to translate back into characters.

  Args:
    t: 1D tensor of class values (any numeric dtype; values are cast to int)
    valid_len: number of leading entries to decode (int or 0-d tensor). When
               None, the whole tensor is decoded. Pass the original length to
               drop padding; a length of 0 yields "".
  """
  # `is None` (not truthiness) so that a valid length of 0 is honoured.
  n = t.shape[0] if valid_len is None else int(valid_len)
  # One host transfer instead of a per-element read.
  values = t[:n].tolist()
  return "".join(classToChar(int(v)) for v in values)

## Step 2.2: Define and Apply Wav Data Pre-processing

### Step 2.2.1: Define transformation functions

In [None]:
def splitAndTransformData(complete_dataset, train_split: float=0.3, sample_rate=16000, n_mfcc=13):
  """
  Split data into train and test sets and convert each sample into model inputs.

  The first `train_split` fraction of samples (in dataset order) goes to the
  training set and the rest to the test set. Audio is converted into a 2D MFCC
  feature tensor [n_mfcc, frames]. The transcript becomes a 1D tensor holding
  each character's class (see charToClass). Speaker/chapter/utterance ids and
  the per-sample sample rate are dropped.

  Args:
      complete_dataset: the complete librispeech dataset
                        array of tuples in the form (waveform, sample_rate, utterance, speaker_id, chapter_id, utterance_id)
                        where waveform is torch.Tensor of shape (1, num_sample_points)
      train_split: fraction of samples assigned to the training set
      sample_rate: sample rate given to the MFCC transform (usually 16000 for this dataset)
      n_mfcc: number of MFCC coefficients per frame

  Returns:
      train_dataset: list of tuples in the form (mfccWaveform, intUtterance, originalIdx)
      test_dataset: list of tuples in the form (mfccWaveform, intUtterance, originalIdx)
  """
  train_dataset = []
  test_dataset = []

  split = int(train_split * len(complete_dataset))

  MFCC_transformer = transforms.MFCC(sample_rate, n_mfcc)

  for i in range(len(complete_dataset)):
    # Index the dataset once per sample: each access may re-read the audio from disk.
    waveform, _, utterance = complete_dataset[i][:3]
    # CPU tensors are used here because they need to interface with NumPy later
    sample = (torch.squeeze(MFCC_transformer(waveform)), transcriptToTensor(utterance).cpu(), i)
    if i < split:
      train_dataset.append(sample)
    else:
      test_dataset.append(sample)

  return train_dataset, test_dataset


def sortTimeData(unsorted: list):
  """
  Return a new list of samples ordered by audio length (number of MFCC frames),
  shortest first. The sort is stable, so samples of equal length keep their
  relative order. The input list is not modified.

  Args:
      unsorted: list of tuples (mfccWaveform, intUtterance, idx) where
                mfccWaveform is a torch.Tensor of shape (n_mfcc, sample_points)

  Returns:
      list of the same tuples sorted on mfccWaveform.shape[1]
  """
  def n_frames(sample):
    return sample[0].shape[1]

  return sorted(unsorted, key=n_frames)


def batchTimeData(dataset: list, batch_size: int = 64):
  """
  Split a (length-sorted) dataset into consecutive batches of `batch_size`
  samples. The last batch holds the remainder and may be smaller.

  Args:
    dataset: list of tuples in form (mfccWaveform, intUtterance, idx), sorted on waveform length
    batch_size: positive number of samples per batch

  Returns:
    batched_dataset: list of batches, where each batch is a list of 3-tuples
                     (an empty list if `dataset` is empty)

  Raises:
    ValueError: if batch_size < 1
  """
  if batch_size < 1:
    raise ValueError("batch_size must be a positive integer, got %r" % (batch_size,))

  # Plain slicing keeps the tuples intact. np.asarray on tuples of differently
  # shaped tensors builds a ragged array, which newer NumPy versions reject.
  return [list(dataset[start:start + batch_size]) for start in range(0, len(dataset), batch_size)]


def padTimeData(batch: list):
  """
  Pad every sample in a batch to a common size:
    - MFCC audio is zero-padded along the time axis (dim 1) up to the longest
      audio in the batch;
    - transcripts are zero-padded (dim 0) up to the longest transcript in the batch.
  The padding value 0 is also the space class, so the original lengths are
  returned alongside each sample to let downstream code ignore the padding.

  Args:
    batch: non-empty list of tuples (mfccWaveform [n_mfcc, frames], intUtterance [chars], idx)

  Returns:
    padded_batch: list of 5-tuples, one per input sample and in the same order:
      (mfccWaveform: torch.Tensor [n_mfcc, MAX_LENGTH_OF_BATCH_AUDIO],
       intUtterance: torch.Tensor [MAX_LENGTH_OF_BATCH_UTTERANCE],
       origWaveformLength: int,
       origUtteranceLength: int,
       idx: index into the original dataset)
  """
  n_mfcc = batch[-1][0].shape[0]
  # Use the true maximum rather than assuming the batch is sorted by length.
  max_audio_length = max(item[0].shape[1] for item in batch)
  max_transcript_length = max(item[1].shape[0] for item in batch)

  padded_batch = [(torch.cat((audio, torch.zeros((n_mfcc, max_audio_length - audio.shape[1]))), dim=1),
                   torch.cat((caption, torch.zeros((max_transcript_length - caption.shape[0]))), dim=0),
                   audio.shape[1],
                   caption.shape[0],
                   idx) for audio, caption, idx in batch]

  return padded_batch


def tensorizeDataset(regularized_dataset):
  """
  Stack each padded batch into tensors, shuffling the sample order within the batch.

  Args:
    regularized_dataset: list of batches (e.g. from padTimeData). Each batch is a
                         list of 5-tuples whose tensors are equally sized within the batch:
                            audio: Tensor [n_mfcc, frames]
                            transcript: Tensor [chars]
                            origAudioLength: int
                            origTranscriptLength: int
                            idx: int, index into the original dataset

  Returns:
    list of 5-tuples, one per batch:
      audio:          float32 Tensor [batch, n_mfcc, frames]
      transcript:     int64 Tensor [batch, chars] of class values
      audio_len:      int64 Tensor [batch]
      transcript_len: int64 Tensor [batch]
      idx:            int64 Tensor [batch]
  """
  final_data = [] # list of batches

  for batch in regularized_dataset:
    # Batches arrive sorted by length; shuffle the sample order inside each one.
    order = np.arange(len(batch))
    np.random.shuffle(order)
    shuffled = [batch[j] for j in order]

    audio_tensor = torch.stack([item[0] for item in shuffled]).float()
    # Integer dtypes: transcripts are class ids, the rest are counts/indices.
    # int64 avoids the overflow int16 would hit beyond 32767 (e.g. sample indices).
    transcript_tensor = torch.stack([item[1] for item in shuffled]).long()
    audio_len_tensor = torch.tensor([item[2] for item in shuffled], dtype=torch.long)
    transcript_len_tensor = torch.tensor([item[3] for item in shuffled], dtype=torch.long)
    idx_tensor = torch.tensor([item[4] for item in shuffled], dtype=torch.long)

    final_data.append((audio_tensor, transcript_tensor,
                       audio_len_tensor, transcript_len_tensor, idx_tensor))

  return final_data


def dataPreProcess(dataset_in, train_split : float = 0.70, batch_size : int = 64, n_mfcc: int = 13):
  """
  Perform the entire audio data pre-processing pipeline:
    splitting train and test sets (and computing MFCC features)
    sorting samples by audio length
    batching the sorted samples
    padding each batch to a common size
    shuffling samples within each batch and stacking them into tensors

  Args:
    dataset_in: complete dataset (see splitAndTransformData)
    train_split: float, fraction of data to allocate to the training set (default 0.70)
    batch_size: int, number of audio samples per batch (default 64)
    n_mfcc: int, number of MFCC coefficients per frame (default 13)

  Returns:
    final_train: batched, padded and shuffled training dataset
    final_test: batched, padded and shuffled testing dataset
  """
  # Forward batch_size and n_mfcc so the caller's settings actually take effect.
  train_set, test_set = splitAndTransformData(dataset_in, train_split, n_mfcc=n_mfcc)

  final_sets = []
  # Train is processed before test, so the shuffling RNG is consumed in the same order as before.
  for split_set in (train_set, test_set):
    batched = batchTimeData(sortTimeData(split_set), batch_size)
    padded = [padTimeData(batch) for batch in batched]
    final_sets.append(tensorizeDataset(padded))

  final_train, final_test = final_sets
  return final_train, final_test
  

### Step 2.2.2: Define a verification function to ensure data translation is done correctly


In [None]:
def verify_random(dataset):
  """
  Print a random pre-processed sample next to its original dataset entry, to
  sanity-check the preprocessing.

  Picks a random batch and a random element in it. The stored transcript is
  decoded only up to its recorded valid length, so padding is not shown. The
  original sample with the same index is then played and printed via print_sample.

  Args:
    dataset: list of 5-tuples (audio, transcript, audio_len, transcript_len, idx)
             as produced by dataPreProcess
  """
  batch_idx = np.random.randint(0, len(dataset))
  audio, transcripts, _, transcript_lens, orig_idxs = dataset[batch_idx]
  sample_idx = np.random.randint(0, audio.shape[0])

  stored_transcript = tensorToTranscript(transcripts[sample_idx], transcript_lens[sample_idx])
  stored_idx = int(orig_idxs[sample_idx])

  print_sample(stored_idx)
  print("Stored Index: %d" % stored_idx)
  print("Stored Transcript:", stored_transcript)

### Step 2.2.3: Apply data transforms to obtain final train and test set


In [None]:
# Pre-processing settings passed to dataPreProcess in the next cell:
#   TRAIN_SPLIT - fraction of samples (in dataset order) used for training
#   BATCH_SIZE  - samples per batch
#   N_MFCC      - MFCC coefficients per frame
TRAIN_SPLIT, BATCH_SIZE, N_MFCC = 0.8, 64, 13

In [None]:
# Pre-processing takes about 1 second for every 100 samples, so the result is
# cached on Drive (TRAIN_SET_PATH / TEST_SET_PATH) and reloaded when RERUN_DATALOADS is False.
if RERUN_DATALOADS:
  train_set, test_set = dataPreProcess(COMPLETE_DATASET, train_split=TRAIN_SPLIT, batch_size=BATCH_SIZE, n_mfcc=N_MFCC)
  torch.save(train_set, TRAIN_SET_PATH)
  torch.save(test_set, TEST_SET_PATH)
else:
  print("Dataloads not rerun. Set RERUN_DATALOADS option in 1.2 to rerun.")
  train_set = torch.load(TRAIN_SET_PATH)
  test_set = torch.load(TEST_SET_PATH)

# Each batch is a 5-tuple (audio, transcript, audio_len, transcript_len, idx);
# the shapes below are for the whole first batch.
print("\n", "------"*20)
print("Audio Shape (Sample):", train_set[0][0].shape)  # (batch, n_mfcc, frames)
print("Transcript Shape (Sample):", train_set[0][1].shape)  # (batch, longest transcript)
print("Audio-Original-Length Shape:", train_set[0][2].shape)  # (batch,)
print("Transcript-Original-Length Shape:", train_set[0][3].shape)  # (batch,)
print("Idx Shape:", train_set[0][4].shape)  # (batch,)

print("\n", "------"*20)
verify_random(train_set)

print("\n", "------"*20)
verify_random(test_set)

# Step 3: Define model

## Step 3.1: Define network architecture

In [None]:
class Wav442Letter(nn.Module):
  """
  Fully-convolutional acoustic model following the Wav2Letter layout.

  Input:  (batch, in_channels, frames) feature tensor (e.g. MFCCs).
  Output: (logits, log_probs), both (batch, out_channels, frames // 2), because
          the first convolution has stride 2.
  """
  def __init__(self, in_channels, out_channels, bottleneck_size=2000):
    super(Wav442Letter, self).__init__()

    # Strided front-end: kernel 48, stride 2, padding 23 -> floor(T / 2) output frames.
    layers = [nn.Conv1d(in_channels, 250, kernel_size=48, stride=2, padding=23), nn.ReLU()]

    # Seven identical 250-channel convolutions that preserve the time length.
    for _ in range(7):
      layers += [nn.Conv1d(250, 250, kernel_size=7, stride=1, padding='same'), nn.ReLU()]

    # Wide-context conv into the bottleneck, then two 1x1 convs down to the classes.
    layers += [
        nn.Conv1d(250, bottleneck_size, kernel_size=32, stride=1, padding='same'),
        nn.ReLU(),
        nn.Conv1d(bottleneck_size, bottleneck_size, kernel_size=1, stride=1, padding='same'),
        nn.ReLU(),
        nn.Conv1d(bottleneck_size, out_channels, kernel_size=1, stride=1, padding='same'),
        # NOTE(review): this ReLU clamps the class scores to >= 0 before
        # log_softmax, so the model can never push a class score below zero.
        # Confirm this is intended.
        nn.ReLU(),
    ]

    # Same module order as a hand-written Sequential, so state_dict keys are network.0 ... network.21.
    self.network = nn.Sequential(*layers)

  def forward(self, x: torch.Tensor):
    """
    x: (batchSize, num_features, length). Returns the raw class scores and
    their log-softmax over the class dimension (dim 1).
    """
    scores = self.network(x)
    return scores, F.log_softmax(scores, dim=1)


In [None]:
# Sanity-check the architecture on a dummy 400-frame input.
# 29 output classes: space, apostrophe, CTC blank and A-Z (see charToClass).
# DEVICE instead of a hard-coded "cuda" so this cell also runs on CPU-only runtimes.
model = Wav442Letter(N_MFCC, 29, bottleneck_size=500).to(DEVICE)
summary(model, (N_MFCC, 400), device=DEVICE.type)
temp = model(torch.zeros((5, N_MFCC, 400)).to(DEVICE))
print(temp[0].shape, temp[1].shape)

## Step 3.2: Define train function

See the NVIDIA implementation for hyperparameters: https://nvidia.github.io/OpenSeq2Seq/html/speech-recognition/wave2letter.html


**List of hyperparameters:**

* Epochs
* Optimizer: SGD, Adam
* Learning Rate (Initial)
* Momentum
* Weight Decay
* Scheduler Patience
* Scheduler Rate Factore
* Scheduler Threshold


**List of potential architecture changes:**
* Stride in first layer (original 2)
  * If changed to 1, need to change padding to 'same'
  * Also need to modify train and eval functions for output size changes
* Bottleneck size (Fully-Connected): 2000
  * Can experiment going down to 500

In [None]:
from matplotlib.scale import LogitScale
def train_model(model, dataset, epochs=1, ctc_loss=None, optimizer=None, checkpoint_every=100):
  """
  Train `model` with CTC loss on a pre-batched dataset.

  Args:
    model: network whose forward returns (logits, log_probs), each (N, C, T)
    dataset: list of batches as produced by dataPreProcess; each batch is a
             5-tuple (audio, transcript, audio_len, transcript_len, idx).
             The list itself is not reordered.
    epochs: number of passes over the dataset
    ctc_loss: loss module; defaults to CTCLoss with blank class 2 ('-')
    optimizer: defaults to Adam with lr=1e-5
    checkpoint_every: save the full model under PATH every this many epochs
                      (counting from epoch 0)

  Returns:
    loss_by_batch: list of float, loss of every batch in training order
    loss_by_epoch: list of float, sum of the batch losses of each epoch
    score_by_batch: list, currently always empty (decoder scoring is disabled)
  """
  if ctc_loss is None:
    ctc_loss = CTCLoss(blank=2, reduction='mean', zero_infinity=False)

  if optimizer is None:
    # optimizer = optim.SGD(model.parameters(), lr=0.001, weight_decay=1e-4, momentum=0.9)
    optimizer = optim.Adam(model.parameters(), lr=1e-5)

  model.train()
  loss_by_batch = []
  loss_by_epoch = []
  score_by_batch = []

  for epoch in range(epochs):
    loss_epoch = 0.0

    # Visit batches in a fresh random order without shuffling the caller's list.
    for batch_idx in np.random.permutation(len(dataset)):
      batch = dataset[batch_idx]
      optimizer.zero_grad()

      # T: input length, N: batch size, C: num_classes, S: max target length
      _, log_probs = model(batch[0].to(DEVICE))
      log_probs = log_probs.permute(2, 0, 1)  # (N, C, T) --> (T, N, C), the layout CTCLoss expects

      targets = batch[1].to(DEVICE)       # (N, S)
      input_lens = batch[2].to(DEVICE)    # (N)
      target_lens = batch[3].to(DEVICE)   # (N)

      # The stride-2 first layer halves the number of valid frames.
      loss = ctc_loss(log_probs, targets, input_lens // 2, target_lens)
      loss.backward()
      optimizer.step()

      # Store plain floats: keeping CUDA tensors would retain device memory and
      # cannot be plotted directly.
      loss_value = loss.item()
      loss_by_batch.append(loss_value)
      loss_epoch += loss_value

    print("Epoch %d Loss: %.3f" % (epoch, loss_epoch))
    loss_by_epoch.append(loss_epoch)

    if epoch % checkpoint_every == 0:
      dt_string = datetime.now().strftime("%Y-%m-%d_%H-%M")
      torch.save(model, PATH + "/model-full_" + dt_string)

  return loss_by_batch, loss_by_epoch, score_by_batch

## Step 3.3: Define eval function

In [None]:
def eval_model(model, dataset):
  """
  Given a model and a (test) dataset, runs the dataset through the model,
  decodes the outputs with a lexicon-constrained CTC decoder and scores them
  against the ground truth.

  Returns:
    eval_out: list with one 7-tuple per batch:
      - probs: class probabilities (exp of log_probs), tensor (N, T, C) on DEVICE
      - targets: padded transcripts as class values, (N, S) on DEVICE
      - input_lens: valid MFCC lengths (before the stride-2 first layer), (N)
      - target_lens: valid transcript lengths, (N)
      - original dataset indices, (N)
      - list of N ground-truth transcript strings
      - list of N 5-tuples (pred transcript, pred tokens, score, WER,
        Levenshtein distance), aligned with the ground-truth list. A sample for
        which the decoder returns no hypothesis gets ("", "", -inf, 1.0, len(gt)).
    mean_wer: WER averaged over all samples (failed decodes count as 1.0)
    mean_lev_dist: Levenshtein distance divided by the ground-truth length
                   (character error rate), averaged over all samples
  """
  # there are hyperparams that can be set for this decoder, found here:
  # https://pytorch.org/audio/master/generated/torchaudio.models.decoder.ctc_decoder.html#torchaudio.models.decoder.ctc_decoder
  decoder = ctc_decoder(lexicon=LEXICON_PATH, tokens=TOKENS_PATH, blank_token='-', sil_token='|')

  model.eval()
  eval_out = []
  n_samples = 0
  total_wer = 0.0
  total_cer = 0.0

  with torch.no_grad():
    for batch in dataset:
      # T: input length, N: batch size, C: num_classes
      logits, log_probs = model(batch[0].to(DEVICE))

      # (N, C, T) --> (N, T, C). The decoder reads the raw buffer, so the CPU
      # copy must be contiguous; .to('cpu') alone keeps the permuted strides.
      emissions = logits.permute(0, 2, 1).cpu().contiguous()
      probs = torch.exp(log_probs).permute(0, 2, 1)

      targets = batch[1].to(DEVICE)       # (N, S)
      input_lens = batch[2].to(DEVICE)    # (N)
      target_lens = batch[3].to(DEVICE)   # (N)

      # List of List[CTCHypothesis]; the stride-2 first layer halves the valid frames.
      # https://pytorch.org/audio/master/generated/torchaudio.models.decoder.CTCDecoder.html#torchaudio.models.decoder.CTCHypothesis
      pred = decoder(emissions, (input_lens // 2).cpu())

      gt_transcripts = []
      pred_transcripts = []

      for j in range(targets.shape[0]):
        gt = tensorToTranscript(targets[j], target_lens[j])
        gt_transcripts.append(gt)

        if len(pred[j]):
          best = pred[j][0]
          transcript = " ".join(best.words).strip()
          tokens = "".join(decoder.idxs_to_tokens(best.tokens))
          score = best.score
          word_err_rate = wer(gt, transcript)
          lev_dist = distance(gt, transcript)
        else:
          # No hypothesis: score as an empty prediction so the per-sample lists
          # stay aligned with gt_transcripts and failures count in the means.
          transcript, tokens, score = "", "", float("-inf")
          word_err_rate = 1.0
          lev_dist = len(gt)
        pred_transcripts.append((transcript, tokens, score, word_err_rate, lev_dist))

        total_wer += word_err_rate
        # Normalise by the reference length (character error rate).
        if len(gt):
          total_cer += lev_dist / len(gt)
        n_samples += 1

      eval_out.append((probs, targets, input_lens, target_lens, batch[4], gt_transcripts, pred_transcripts))

  if n_samples == 0:
    return eval_out, 0.0, 0.0
  return eval_out, total_wer / n_samples, total_cer / n_samples


# Step 4: Define Visualization Functions


## Step 4.1: Define loss visualization

In [None]:
def plot_loss(loss_list, x_label="Epoch"):
  """
  Plot a loss curve with roughly 10 x-axis ticks.

  Args:
    loss_list: sequence of losses as floats or scalar tensors (CPU or CUDA)
    x_label: label for the x axis, e.g. "Epoch" or "Batch"
  """
  # float() pulls scalar tensors to the host; matplotlib cannot plot CUDA tensors.
  losses = [float(loss) for loss in loss_list]

  fig, ax = plt.subplots()
  ax.plot(losses)
  ax.set_ylabel("Loss")
  ax.set_xlabel(x_label)
  ax.set_title("Loss over time")
  ax.set_xticks(range(0, len(losses), len(losses)//10 + 1))

In [None]:
def print_eval(eval_out, n_samples=10):
  """
  Show evaluation results for `n_samples` randomly chosen batches (one random
  sample per batch).

  For each result this function:
    - plays and prints the original sample (print_sample);
    - prints the ground truth, decoded prediction and decoded tokens;
    - prints the frame-wise argmax string, decoder score, WER and Levenshtein distance;
    - draws the per-frame class probabilities as a heat map.

  Args:
    eval_out: first return value of eval_model (7-tuples per batch)
    n_samples: number of results to show (capped at the number of batches)
  """
  n_samples = min(n_samples, len(eval_out))
  # Choose distinct batches without reordering the caller's list.
  chosen_batches = np.random.permutation(len(eval_out))[:n_samples]

  for sample_idx, eval_idx in enumerate(chosen_batches):
    batch_probs, _, _, _, orig_idxs, gt_transcripts, pred_transcripts = eval_out[eval_idx]
    batch_idx = np.random.randint(batch_probs.shape[0])

    probs = batch_probs[batch_idx]  # (frames, classes)
    orig_idx = orig_idxs[batch_idx]
    gt_transcript = gt_transcripts[batch_idx]
    pred_transcript, pred_tokens, pred_score, word_err_rate, lev_dist = pred_transcripts[batch_idx]

    # Best class per frame, before CTC collapsing / lexicon decoding.
    pred_str = "".join(classToChar(c) for c in torch.argmax(probs, dim=1).tolist())

    print("Sample %d" % sample_idx)
    print_sample(orig_idx)
    print("GT:", gt_transcript)
    print("Pred:", pred_transcript)
    print("Pred Tokens:", pred_tokens)
    print("Raw Max Tokens:", pred_str)
    print("Score: %.3f" % pred_score)
    print("Word Error Rate: %.3f" % word_err_rate)
    print("Levenstein Distance: %.3f" % lev_dist)

    fig, ax = plt.subplots(dpi=300)
    ax.imshow(probs.permute(1, 0).cpu().numpy(), cmap='viridis', interpolation='nearest')
    ax.set_xlabel("Frame")
    ax.set_ylabel("Class")
    ax.set_title("Class probabilities for sample %d" % sample_idx)

# Step 5: Train model and visualize loss

In [None]:
# LOAD_PATH = PATH + "/model-weights_2022-12-09_11-14"
# model = Wav442Letter(13, 29, bottleneck_size=500).to(DEVICE)
#model.load_state_dict(torch.load(LOAD_PATH))

# 21-27 is good (~40 loss), trained with 1e-4. Trained on 80% of train_set
#

# Resume training from a saved full-model checkpoint.
LOAD_PATH = PATH + "/model-full_2022-12-09_21-27"
# map_location lets a checkpoint saved on a GPU load on a CPU-only runtime.
# NOTE(review): newer PyTorch versions default torch.load to weights_only=True,
# which cannot unpickle a full nn.Module; pass weights_only=False there (trusted files only).
model = torch.load(LOAD_PATH, map_location=DEVICE)

loss_by_batch, loss_by_epoch, score_by_batch = train_model(model, train_set, epochs=10000)

dt_string = datetime.now().strftime("%Y-%m-%d_%H-%M")
torch.save(model, PATH + "/model-full_" + dt_string)

In [None]:
# LOAD_PATH = PATH + "/model-weights_2022-12-09_20-15"
# model = Wav442Letter(13, 29, bottleneck_size=500).to(DEVICE)
# model.load_state_dict(torch.load(LOAD_PATH))

# "/model-full_2022-12-09_21-58"

LOAD_PATH = PATH + "/model-full_2022-12-09_21-58"
# map_location lets a checkpoint saved on a GPU load on a CPU-only runtime.
# NOTE(review): newer PyTorch versions default torch.load to weights_only=True,
# which cannot unpickle a full nn.Module; pass weights_only=False there (trusted files only).
model = torch.load(LOAD_PATH, map_location=DEVICE)

# NOTE(review): this scores the model on train_set, i.e. data it was trained on;
# use test_set for a held-out WER.
eval_out, mean_wer, mean_lev_dist = eval_model(model, train_set)

print("Mean WER: %.3f" % mean_wer)
print("Mean Distance: %.3f\n" % mean_lev_dist)
print_eval(eval_out, n_samples=1)