In [None]:
!pip install jiwer torch torchaudio comet_ml fastdtw

In [None]:
import torch
import torchaudio
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from torchaudio.models.decoder import ctc_decoder
from torchaudio.models.decoder._ctc_decoder import download_pretrained_files
import os
import matplotlib.pyplot as plt
from torch.utils.data import Dataset
import random
from jiwer import wer, cer
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from comet_ml import Experiment


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
files = download_pretrained_files("librispeech-4-gram")


100%|██████████| 4.97M/4.97M [00:00<00:00, 64.0MB/s]
100%|██████████| 57.0/57.0 [00:00<00:00, 88.2kB/s]
100%|██████████| 2.91G/2.91G [00:51<00:00, 61.0MB/s]


In [None]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [None]:

# ---- Numerical constants ----

# Audio / feature extraction
SR = 16_000          # sample rate handed to the torchaudio transforms
HOP_LEN = 160        # MFCC hop length (samples)
N_FFT = 400          # MFCC FFT size
N_MELS = 128         # number of mel bins
N_MFCC = 13          # number of MFCC coefficients
FILE_2CHECK = 13     # NOTE(review): not referenced anywhere in the visible code

# Optimisation
N_EPOCHS = 50
BATCH_SIZE = 16
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-3  # NOTE(review): not passed to any optimizer in the visible code

# Network architecture
N_CNN_LAYERS = 3
N_RNN_LAYERS = 5
RNN_DIM = 512
STRIDE = 2
HIDDEN_DIM = 256
EMBED_DIM = 300      # NOTE(review): not referenced anywhere in the visible code
NUM_CLASSES = 29
DROPOUT = 0.1

# ---- Paths ----
# TODO - point these at the actual locations of the model weights, the data,
# the lexicon file and the tokens file before running.
CTC_MODEL_PATH = ''
DATA_PATH = ''

LEXICON_PATH = ''
TOKENS_PATH = ''


In [None]:
import numpy as np
from torch.nn.functional import pairwise_distance
from scipy.spatial.distance import cdist
from fastdtw import fastdtw

# Number of samples every waveform is zero-padded to by the distance models.
MAX_LEN = 102400


def extract_features(wavs):
    """
    Compute MFCC features for a waveform tensor (single clip or batch).

    A fresh `torchaudio.transforms.MFCC` is built on every call from the
    SR / N_MFCC / HOP_LEN / N_FFT / N_MELS constants.  All size-1 dimensions
    are squeezed out of the result.
    NOTE(review): the bare `squeeze()` would also drop a batch dimension of
    size 1 - confirm callers never pass a single-item batch.
    More ideas: try Time Domain / STFT / Mel Spectrogram
    """
    mel_options = {'hop_length': HOP_LEN, 'n_fft': N_FFT, 'n_mels': N_MELS}
    to_mfcc = torchaudio.transforms.MFCC(sample_rate=SR,
                                         n_mfcc=N_MFCC,
                                         melkwargs=mel_options)
    return to_mfcc(wavs).squeeze()


class DTWModel:
    """
    1-nearest-neighbour classifier over MFCC features using DTW distance.

    Reference waveforms are converted to MFCC features when they are added;
    a query waveform receives the label of the reference with the smallest
    `fastdtw` distance.
    """

    def __init__(self, x_train, y_train):
        """
        x_train: either a tensor that `extract_features` accepts as-is, or a
         list/tuple of [Channels, Time] waveforms (these are padded to
         MAX_LEN and stacked first, exactly like `add_data` does).
        y_train: labels, one per reference waveform.
        """
        if isinstance(x_train, (list, tuple)):
            # A plain list of variable-length clips cannot be fed to the MFCC
            # transform directly; bring it to a fixed-size batch first.
            x_train = self._pad_and_stack(x_train)
        self.x_train = extract_features(x_train)
        self.y_train = y_train

    # The annotation used to be `tp.List[int]`, but `tp` is never imported in
    # this notebook, which made the class definition itself raise NameError.
    def classify_using_DTW_distance(self, audio_files) -> list:
        """
        Classify each given waveform by its nearest reference under DTW.
        audio_files: iterable of waveforms of shape [Channels, Time]
         (e.g. a list of tensors or a batch of shape [Batch, Channels, Time])
        return: list with the predicted label for each entry
        """
        predictions = []

        for wav in tqdm(audio_files):
            mfcc = extract_features(self._pad_waveform(wav))
            best_dist, best_label = float('inf'), None
            for i, x in enumerate(self.x_train):
                # NOTE(review): the features look like (n_mfcc, frames), so
                # fastdtw walks along the coefficient axis rather than time -
                # confirm whether a transpose is intended here.
                cur_dist = fastdtw(mfcc, x)[0]
                if cur_dist < best_dist:
                    best_dist, best_label = cur_dist, self.y_train[i]
            predictions.append(best_label)

        return predictions

    def add_data(self, x, y):
        """
        Append more reference waveforms `x` (iterable of [Channels, Time]
        tensors) and their labels `y` to the model.
        """
        new_features = extract_features(self._pad_and_stack(x))
        self.x_train = torch.cat([self.x_train, new_features])
        self.y_train = self.y_train + y

    @staticmethod
    def _pad_waveform(wav, max_len=None):
        """
        Return `wav` ([Channels, Time]) with exactly `max_len` samples
        (MAX_LEN when not given): shorter clips are zero-padded on the right,
        longer clips are truncated instead of crashing on a negative pad size.
        """
        if max_len is None:
            max_len = MAX_LEN
        missing = max_len - wav.size(1)
        if missing <= 0:
            return wav[:, :max_len]
        padding = torch.zeros((wav.size(0), missing), device=wav.device)
        return torch.cat([wav, padding], dim=1)

    @classmethod
    def _pad_and_stack(cls, wavs):
        """Pad every waveform to MAX_LEN and stack them into one batch tensor."""
        return torch.stack([cls._pad_waveform(wav) for wav in wavs])

    @staticmethod
    def DTW_distance(x, y):
        """
        Plain O(n*m) dynamic-time-warping distance between sequences x and y.
        The local cost of aligning x[i] with y[j] is their L2
        `pairwise_distance`; the accumulated cost of the best path ending at
        the last elements of both sequences is returned.
        """
        n, m = len(x), len(y)

        def cost(i, j):
            return torch.sum(pairwise_distance(x[i], y[j], p=2))

        dtw_mat = np.zeros((n, m))
        dtw_mat[0, 0] = cost(0, 0)

        # First column / first row can only be reached by moving straight.
        for i in range(1, n):
            dtw_mat[i, 0] = cost(i, 0) + dtw_mat[i-1, 0]

        for j in range(1, m):
            dtw_mat[0, j] = cost(0, j) + dtw_mat[0, j-1]

        for i in range(1, n):
            for j in range(1, m):
                dtw_mat[i, j] = cost(i, j) + min(dtw_mat[i-1, j],
                                                 dtw_mat[i, j-1],
                                                 dtw_mat[i-1, j-1])

        return dtw_mat[n-1, m-1]


class EuclideanModel:
    """
    1-nearest-neighbour classifier over MFCC features using Euclidean distance.

    Reference waveforms are converted to MFCC features when they are added;
    a query waveform receives the label of the reference whose features have
    the smallest `torch.norm` difference to the query features.
    """

    def __init__(self, x_train, y_train):
        """
        x_train: either a tensor that `extract_features` accepts as-is, or a
         list/tuple of [Channels, Time] waveforms (these are padded to
         MAX_LEN and stacked first, exactly like `add_data` does).
        y_train: labels, one per reference waveform.
        """
        if isinstance(x_train, (list, tuple)):
            # A plain list of variable-length clips cannot be fed to the MFCC
            # transform directly; bring it to a fixed-size batch first.
            x_train = self._pad_and_stack(x_train)
        self.x_train = extract_features(x_train)
        self.y_train = y_train

    # The annotation used to be `tp.List[int]`, but `tp` is never imported in
    # this notebook, which made the class definition itself raise NameError.
    def classify_using_euclidean_distance(self, audio_files) -> list:
        """
        Classify each given waveform by its nearest reference (Euclidean).
        audio_files: iterable of waveforms of shape [Channels, Time]
         (e.g. a list of tensors or a batch of shape [Batch, Channels, Time])
        return: list with the predicted label for each entry
        """
        predictions = []

        for wav in tqdm(audio_files):
            mfcc = extract_features(self._pad_waveform(wav))
            best_dist, best_label = float('inf'), None
            for i, x in enumerate(self.x_train):
                cur_dist = torch.norm(mfcc - x)
                if cur_dist < best_dist:
                    best_dist, best_label = cur_dist, self.y_train[i]
            predictions.append(best_label)

        return predictions

    def add_data(self, x, y):
        """
        Append more reference waveforms `x` (iterable of [Channels, Time]
        tensors) and their labels `y` to the model.
        """
        new_features = extract_features(self._pad_and_stack(x))
        self.x_train = torch.cat([self.x_train, new_features])
        self.y_train = self.y_train + y

    @staticmethod
    def _pad_waveform(wav, max_len=None):
        """
        Return `wav` ([Channels, Time]) with exactly `max_len` samples
        (MAX_LEN when not given): shorter clips are zero-padded on the right,
        longer clips are truncated instead of crashing on a negative pad size.
        """
        if max_len is None:
            max_len = MAX_LEN
        missing = max_len - wav.size(1)
        if missing <= 0:
            return wav[:, :max_len]
        padding = torch.zeros((wav.size(0), missing), device=wav.device)
        return torch.cat([wav, padding], dim=1)

    @classmethod
    def _pad_and_stack(cls, wavs):
        """Pad every waveform to MAX_LEN and stack them into one batch tensor."""
        return torch.stack([cls._pad_waveform(wav) for wav in wavs])


In [None]:
class Data:
    """
    Loads the train / val / test splits of the AN4 data set into memory and
    offers a few plotting helpers for waveforms and spectral features.
    """

    def __init__(self, data_dir='an4'):
        """
        data_dir: root folder that contains the `train`, `val` and `test`
         sub-folders (defaults to 'an4', the previously hard-coded value).
        """
        print('Loading data...')
        self.data_dir = data_dir
        self.x_train_paths, self.x_train, self.y_train = self.load_data('train')
        self.x_val_paths, self.x_val, self.y_val = self.load_data('val')
        self.x_test_paths, self.x_test, self.y_test = self.load_data('test')
        print('Data loaded successfully')

    def load_data(self, split):
        """
        Load one split from `<data_dir>/<split>/an4/{wav,txt}`.

        Audio and transcript files are paired by sorted file name.  Only
        `wav` / `txt` files are considered (same filter as AN4Dataset), and
        only paired entries are returned so the three lists always have the
        same length.
        return: (audio paths, loaded waveforms, transcripts)
        """
        audio_dir = os.path.join(self.data_dir, split, 'an4', 'wav')
        transcript_dir = os.path.join(self.data_dir, split, 'an4', 'txt')

        audio_paths = [os.path.join(audio_dir, name)
                       for name in sorted(os.listdir(audio_dir))
                       if name.endswith('wav')]
        transcript_paths = [os.path.join(transcript_dir, name)
                            for name in sorted(os.listdir(transcript_dir))
                            if name.endswith('txt')]

        # zip() stops at the shorter list; keep the audio list in step with it
        # so that x and y can never end up with different lengths.
        pairs = list(zip(audio_paths, transcript_paths))
        audio_paths = [audio_path for audio_path, _ in pairs]

        transcripts = []
        for _, transcript_path in pairs:
            with open(transcript_path, 'r') as f:
                transcripts.append(f.read().strip())

        loaded_audios = [torchaudio.load(audio)[0] for audio in audio_paths]
        return audio_paths, loaded_audios, transcripts

    def get_data(self, split):
        """
        Return (waveforms, transcripts) of the requested split.
        Raises ValueError for anything other than 'train', 'val' or 'test'.
        """
        if split == 'train':
            return self.x_train, self.y_train
        elif split == 'val':
            return self.x_val, self.y_val
        elif split == 'test':
            return self.x_test, self.y_test
        else:
            raise ValueError(f"Invalid data split '{split}'")

    @staticmethod
    def plot_waveform(waveform, sample_rate, title="Waveform", xlim=None, ylim=None):
        """Plot every channel of a [Channels, Time] waveform against time in seconds."""
        waveform = waveform.numpy()

        num_channels, num_frames = waveform.shape
        time_axis = torch.arange(0, num_frames) / sample_rate

        figure, axes = plt.subplots(num_channels, 1)
        if num_channels == 1:
            # plt.subplots returns a bare Axes for a single row; normalise.
            axes = [axes]
        for c in range(num_channels):
            axes[c].plot(time_axis, waveform[c], linewidth=1)
            axes[c].grid(True)
            if num_channels > 1:
                axes[c].set_ylabel(f'Channel {c + 1}')
            if xlim:
                axes[c].set_xlim(xlim)
            if ylim:
                axes[c].set_ylim(ylim)
        figure.suptitle(title)
        plt.show(block=False)

    @staticmethod
    def extract_mel_spec(waveform):
        """Return the mel spectrogram of `waveform` (SR sample rate, N_MELS bins)."""
        mel_specgram = torchaudio.transforms.MelSpectrogram(SR, n_mels=N_MELS)(waveform)
        return mel_specgram

    @staticmethod
    def plot_mfcc(mfcc, title="MFCC", xlim=None, ylim=None):
        """Show a 2-D MFCC matrix as an image with a colour bar."""
        fig, ax = plt.subplots()
        im = ax.imshow(mfcc, origin='lower', aspect='auto')
        fig.colorbar(im, ax=ax)
        ax.set(title=title, xlabel='Time', ylabel='MFCC')
        if xlim:
            ax.set_xlim(xlim)
        if ylim:
            ax.set_ylim(ylim)
        plt.show(block=False)

    @staticmethod
    def plot_mel_spec(mel_spec, title="Mel Spectrogram", xlim=None, ylim=None):
        """Show a mel spectrogram tensor (leading dim squeezed) as an image."""
        mel_spec = mel_spec.squeeze(0).numpy()
        fig, ax = plt.subplots()
        im = ax.imshow(mel_spec, origin='lower', aspect='auto')
        fig.colorbar(im, ax=ax)
        ax.set(title=title, xlabel='Time', ylabel='Frequency (Hz)')
        if xlim:
            ax.set_xlim(xlim)
        if ylim:
            ax.set_ylim(ylim)
        plt.show(block=False)


def test_distance_algorithms(data):
    """
    Evaluate the naive DTW nearest-neighbour model.

    The model is built from the train split, extended with the validation
    split, and scored on the test split with WER and CER.  Only the DTW
    model is exercised here.
    data: object exposing `get_data(split)` -> (waveforms, transcripts)
    """
    x_train, y_train = data.get_data('train')
    x_val, y_val = data.get_data('val')
    x_test, y_test = data.get_data('test')

    # The class name previously contained a stray non-ASCII letter, which
    # made this line raise NameError.
    dtw = DTWModel(x_train, y_train)
    dtw.add_data(x_val, y_val)
    predictions_dtw = dtw.classify_using_DTW_distance(x_test)
    print('Predictions:')
    print(predictions_dtw[:5])
    print('True labels:')
    print(y_test[:5])
    print('Testing DTW algorithm...')
    wer_error = wer(y_test, predictions_dtw)
    cer_error = cer(y_test, predictions_dtw)
    print(f'DTW Test WER: {wer_error:.4f}')
    print(f'DTW Test CER: {cer_error:.4f}')
    print('DTW tested successfully')


Decoders

In [None]:
# Decoders
def greedy_decoder(output, labels, label_lengths,
                   blank_label=28, collapse_repeated=True):
    """
    Greedy (best-path) CTC decoding.

    For every batch entry the arg-max class of each frame is taken along
    dim 2 of `output`; blanks are dropped and, when `collapse_repeated` is
    set, a class equal to the one in the previous frame is skipped.  The
    reference text is rebuilt from the first `label_lengths[i]` entries of
    `labels[i]`.
    return: (decoded predictions, decoded targets) as lists of strings
    """
    text_transform = TextTransform()
    best_paths = torch.argmax(output, dim=2)

    decodes, targets = [], []
    for row, frame_ids in enumerate(best_paths):
        reference = labels[row][:label_lengths[row]].tolist()
        targets.append(text_transform.int_to_text(reference))

        ids = frame_ids.tolist()
        kept = [
            idx for pos, idx in enumerate(ids)
            if idx != blank_label
            and not (collapse_repeated and pos != 0 and idx == ids[pos - 1])
        ]
        decodes.append(text_transform.int_to_text(kept))

    return decodes, targets


# Decoders already built, keyed by the files they were created from.
_BEAM_DECODER_CACHE = {}


def _get_ctc_beam_decoder():
    """
    Return the lexicon/LM beam-search decoder, building it only once per
    (lexicon, tokens, lm) combination instead of once per batch.
    """
    key = (LEXICON_PATH, TOKENS_PATH, files.lm)
    if key not in _BEAM_DECODER_CACHE:
        _BEAM_DECODER_CACHE[key] = ctc_decoder(
            lexicon=LEXICON_PATH, tokens=TOKENS_PATH, lm=files.lm,
            blank_token='|', sil_token='SPACE', lm_weight=1, beam_size=100,
            word_score=-1)
    return _BEAM_DECODER_CACHE[key]


def beam_decoder(output, labels, label_lengths, compare=True):
    """
    Decode emissions with the torchaudio CTC beam-search decoder.

    output: emissions, moved to the CPU before decoding; its first
     dimension is iterated as the batch
    labels, label_lengths: references; only read when `compare` is True
    return: (predictions, actuals) - `actuals` stays empty when `compare`
     is False.  A batch entry without any hypothesis yields ''.
    """
    output = output.to('cpu')
    decoder = _get_ctc_beam_decoder()
    beam_search_result = decoder(output.contiguous())
    text_transform = TextTransform()
    preds, actuals = [], []
    for i in range(output.shape[0]):
        hypotheses = beam_search_result[i]
        # Guard against an empty hypothesis list instead of raising IndexError.
        pred = " ".join(hypotheses[0].words).strip() if hypotheses else ""
        preds.append(pred)

        if compare:
            actual = text_transform.int_to_text(labels[i][:label_lengths[i]].tolist())
            actuals.append(actual)

    return preds, actuals



Data processing

In [None]:
class TextTransform:
    """
    Two-way mapping between transcript characters and integer class ids.

    `char_map` maps each symbol of the table below to its id (the blank
    between words is stored under the key 'SPACE'); `index_map` is the
    inverse, with id 1 mapped to a literal ' '.
    """

    def __init__(self):
        char_map_str = """
            ' 0
            SPACE 1
            a 2
            b 3
            c 4
            d 5
            e 6
            f 7
            g 8
            h 9
            i 10
            j 11
            k 12
            l 13
            m 14
            n 15
            o 16
            p 17
            q 18
            r 19
            s 20
            t 21
            u 22
            v 23
            w 24
            x 25
            y 26
            z 27
            """
        self.char_map = {}
        self.index_map = {}
        rows = (row.split() for row in char_map_str.strip().split('\n'))
        for symbol, raw_id in rows:
            idx = int(raw_id)
            self.char_map[symbol] = idx
            self.index_map[idx] = symbol
        # Decode id 1 straight to a blank rather than the word 'SPACE'.
        self.index_map[1] = ' '

    def text_to_int(self, text):
        """Convert a string to its list of class ids (KeyError on unknown characters)."""
        lookup = self.char_map
        return [lookup['SPACE'] if c == ' ' else lookup[c] for c in text]

    def int_to_text(self, labels):
        """Convert a sequence of class ids back into a string."""
        decoded = ''.join(self.index_map[i] for i in labels)
        return decoded.replace('SPACE', ' ')


In [None]:
# data loader
class AN4Dataset(Dataset):
    """
    In-memory dataset over `<DATA_PATH>/<split>/an4/{wav,txt}`.

    Audio and transcript files are matched by sorted file name.  Items are
    `(waveform, transcript)` tuples.  `transform` / `target_transform` are
    stored on the instance but not applied by this class.
    """

    def __init__(self, split, transform=None, target_transform=None):
        self.transform = transform
        self.target_transform = target_transform

        split_root = os.path.join(DATA_PATH, split, 'an4')
        audio_dir = os.path.join(split_root, 'wav')
        transcript_dir = os.path.join(split_root, 'txt')

        audio_names = sorted(os.listdir(audio_dir))
        transcript_names = sorted(os.listdir(transcript_dir))

        audio_paths = [os.path.join(audio_dir, name)
                       for name in audio_names if name.endswith('wav')]
        transcript_paths = [os.path.join(transcript_dir, name)
                            for name in transcript_names if name.endswith('txt')]

        self.audios, self.transcripts = [], []

        # Transcripts are only read for entries that have an audio partner.
        n_pairs = min(len(audio_paths), len(transcript_paths))
        for transcript_path in transcript_paths[:n_pairs]:
            with open(transcript_path, 'r') as handle:
                self.transcripts.append(handle.read().strip())

        self.loaded_audios = [torchaudio.load(path)[0] for path in audio_paths]

    def __len__(self):
        """Number of loaded waveforms."""
        return len(self.loaded_audios)

    def __getitem__(self, idx):
        """Return the `(waveform, transcript)` pair at position `idx`."""
        waveform = self.loaded_audios[idx]
        transcript = self.transcripts[idx]
        return waveform, transcript


def data_processing(data, data_type="train"):
    """
    Collate function: turn a list of `(waveform, transcript)` items into
    padded batch tensors.

    Every waveform becomes a mel spectrogram (with frequency and time
    masking when `data_type == "train"`); transcripts are lower-cased and
    mapped to class ids with TextTransform.
    return: (spectrograms, labels, input_lengths, label_lengths) where
     `spectrograms` is zero-padded along time and `labels` is padded with
     `pad_sequence`; each input length is the spectrogram frame count
     floor-divided by 2.
    """
    mel = torchaudio.transforms.MelSpectrogram(sample_rate=SR, n_mels=N_MELS)
    if data_type == "train":
        transform = nn.Sequential(
            mel,
            torchaudio.transforms.FrequencyMasking(freq_mask_param=30),
            torchaudio.transforms.TimeMasking(time_mask_param=100))
    else:
        transform = mel

    text_transform = TextTransform()
    specs, spec_lengths, targets, target_lengths = [], [], [], []
    for wav, transcript in data:
        frames_by_mel = transform(wav).squeeze(0).transpose(0, 1)
        specs.append(frames_by_mel)
        spec_lengths.append(frames_by_mel.shape[0] // 2)

        ids = text_transform.text_to_int(str(transcript).lower())
        target = torch.Tensor(ids)
        targets.append(target)
        target_lengths.append(len(target))

    padded_frames = (max(spec_lengths) * 2) + 1

    # Copy every spectrogram into the top-left corner of a zero tensor.
    batch = torch.zeros((len(specs), padded_frames, N_MELS))
    for row, spec in enumerate(specs):
        batch[row, :spec.shape[0], :] = spec[:, :]

    batch = batch.unsqueeze(1).transpose(2, 3)
    padded_targets = nn.utils.rnn.pad_sequence(targets, batch_first=True)
    return batch, padded_targets, spec_lengths, target_lengths


Models

In [None]:
class BasicLSTMModel(nn.Module):
    """
    Baseline acoustic model: one bidirectional LSTM over the N_MELS input
    features followed by a linear projection to NUM_CLASSES scores per frame.
    """

    def __init__(self):
        super().__init__()
        self.rnn = nn.LSTM(input_size=N_MELS,
                           hidden_size=HIDDEN_DIM,
                           bidirectional=True,
                           batch_first=True)

        # Forward and backward states are concatenated -> 2 * HIDDEN_DIM.
        self.fc = nn.Linear(HIDDEN_DIM*2, NUM_CLASSES)

    def forward(self, x):
        """Merge dims 1 and 2 of the 4-D input, run the LSTM over dim 3 and classify."""
        shape = x.size()
        merged = x.view(shape[0], shape[1] * shape[2], shape[3])  # (batch, feature, time)
        sequence = merged.transpose(1, 2)  # (batch, time, feature)
        hidden, _ = self.rnn(sequence)
        return self.fc(hidden)



In [None]:
class AdvanceLSTMModel(nn.Module):
    """
    Two unpadded 3x3 convolutions, a linear bottleneck, a 3-layer
    bidirectional LSTM and a two-layer classifier head (29 outputs),
    with LayerNorm and dropout between the stages.
    """

    def __init__(self, n_mels=None):
        """
        n_mels: number of mel bins of the input spectrograms; defaults to
         N_MELS.  The flattened CNN feature size is derived from it - the
         previously hard-coded 5504 only fits 90 mel bins and made the
         forward pass fail for any other input height.
        """
        super(AdvanceLSTMModel, self).__init__()
        if n_mels is None:
            n_mels = N_MELS
        # Each unpadded 3x3 conv removes 2 rows; conv2 emits 64 channels.
        conv_feats = 64 * (n_mels - 4)

        # RNN layers
        self.rnn = nn.LSTM(input_size=512, hidden_size=128,
                           num_layers=3, batch_first=True, bidirectional=True)

        self.layer_norm1 = nn.LayerNorm(512)
        self.layer_norm2 = nn.LayerNorm(256)
        self.layer_norm3 = nn.LayerNorm(128)
        self.layer_norm4 = nn.LayerNorm(conv_feats)
        self.dropout = nn.Dropout(0.15)

        self.fc1 = nn.Linear(conv_feats, 512)

        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3)
        # NOTE(review): conv3 is never used in forward(); it is kept so that
        # the parameter set / state_dict keys stay unchanged.
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3)

        # Fully connected layer
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 29)

    def forward(self, x):
        """
        x: 4-D input; presumably (batch, 1, n_mels, time) given conv1 has one
         input channel - TODO confirm against the caller.
        return: per-frame class scores with the class dimension last.
        """
        x = self.conv1(x)
        x = self.conv2(x)
        sizes = x.size()
        x = x.view(sizes[0], sizes[1] * sizes[2], sizes[3])  # (batch, feature, time)
        x = x.transpose(1, 2)  # (batch, time, feature)
        x = self.layer_norm4(x)
        x = self.fc1(x)
        x = self.dropout(x)
        x = self.layer_norm1(x)
        x, _ = self.rnn(x)
        x = self.dropout(x)
        x = self.layer_norm2(x)
        x = self.fc2(x)
        x = self.dropout(x)
        x = self.layer_norm3(x)
        output = self.fc3(x)

        return output

In [None]:
class CNNLayerNorm(nn.Module):
    """
    LayerNorm applied over dim 2 of a 4-D tensor: dims 2 and 3 are swapped,
    `nn.LayerNorm(n_feats)` normalises the (now last) dimension, and the
    swap is undone.
    """

    def __init__(self, n_feats):
        super().__init__()
        self.layer_norm = nn.LayerNorm(n_feats)

    def forward(self, x):
        feats_last = x.transpose(2, 3).contiguous()
        normalised = self.layer_norm(feats_last)
        return normalised.transpose(2, 3).contiguous()


class ResidualCNN(nn.Module):
    """
    Residual block: two rounds of (layer norm -> GELU -> dropout -> conv)
    whose result is added to the block input.
    NOTE(review): the skip connection needs the conv output to keep the
    input's shape, which presumably requires stride 1 and
    in_channels == out_channels - confirm.
    """

    def __init__(self, in_channels, out_channels, kernel, stride, dropout,
                 n_feats):
        super().__init__()

        same_pad = kernel // 2
        self.cnn1 = nn.Conv2d(in_channels, out_channels, kernel, stride,
                              padding=same_pad)
        self.cnn2 = nn.Conv2d(out_channels, out_channels, kernel, stride,
                              padding=same_pad)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.layer_norm1 = CNNLayerNorm(n_feats)
        self.layer_norm2 = CNNLayerNorm(n_feats)

    def forward(self, x):
        out = self.cnn1(self.dropout1(F.gelu(self.layer_norm1(x))))
        out = self.cnn2(self.dropout2(F.gelu(self.layer_norm2(out))))
        out += x  # skip connection
        return out

class BidirectionalGRU(nn.Module):
    """
    Layer norm -> GELU -> single-layer bidirectional GRU -> dropout.
    The output feature size is 2 * hidden_size (both directions).
    """

    def __init__(self, rnn_dim, hidden_size, dropout, batch_first):
        super().__init__()

        self.BiGRU = nn.GRU(input_size=rnn_dim,
                            hidden_size=hidden_size,
                            num_layers=1,
                            batch_first=batch_first,
                            bidirectional=True)
        self.layer_norm = nn.LayerNorm(rnn_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        activated = F.gelu(self.layer_norm(x))
        recurrent_out, _ = self.BiGRU(activated)
        return self.dropout(recurrent_out)


class FinalASRModel(nn.Module):
    """
    Strided conv front-end, a stack of residual CNN blocks, a linear
    projection, a stack of bidirectional GRU layers and a two-layer
    classifier producing `n_class` scores per frame.
    """

    def __init__(self, n_cnn_layers, n_rnn_layers, rnn_dim, n_class,
                 n_feats, stride, dropout):
        """
        n_cnn_layers / n_rnn_layers: number of residual CNN / BiGRU layers
        rnn_dim: GRU hidden size (each BiGRU outputs 2 * rnn_dim features)
        n_class: number of output classes
        n_feats: number of input features (e.g. mel bins)
        stride: stride of the first convolution
        dropout: dropout probability used throughout
        """
        super(FinalASRModel, self).__init__()
        # Feature rows left after the first conv (kernel 3, padding 1).
        # Equals the former hard-coded `n_feats // 2` for even n_feats with
        # stride 2, but also stays correct for other strides / odd sizes.
        n_feats = (n_feats - 1) // stride + 1
        self.cnn = nn.Conv2d(1, 32, 3, stride=stride, padding=3 // 2)
        self.rescnn_layers = nn.Sequential(*[
            ResidualCNN(32, 32, kernel=3, stride=1, dropout=dropout,
                        n_feats=n_feats)
            for _ in range(n_cnn_layers)
        ])
        self.fully_connected = nn.Linear(n_feats * 32, rnn_dim)
        # NOTE(review): only the first GRU is batch_first; the later ones
        # receive the same (batch, time, feature) tensor but treat dim 0 as
        # the sequence axis.  Left untouched because existing checkpoints
        # were presumably trained this way - confirm before changing.
        self.birnn_layers = nn.Sequential(*[
            BidirectionalGRU(rnn_dim=rnn_dim if i == 0 else rnn_dim * 2,
                             hidden_size=rnn_dim, dropout=dropout, batch_first=i == 0)
            for i in range(n_rnn_layers)
        ])
        self.classifier = nn.Sequential(
            nn.Linear(rnn_dim * 2, rnn_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(rnn_dim, n_class)
        )

    def forward(self, x):
        """Map a 4-D spectrogram batch to per-frame class scores (class dim last)."""
        x = self.cnn(x)
        x = self.rescnn_layers(x)
        sizes = x.size()
        # Fold channels into the feature axis: (batch, feature, time)
        x = x.view(sizes[0], sizes[1] * sizes[2], sizes[3])
        x = x.transpose(1, 2)  # (batch, time, feature)
        x = self.fully_connected(x)
        x = self.birnn_layers(x)
        x = self.classifier(x)
        return x



Train and test loops

In [None]:
def save_model(model, path):
    """
    Write the model's `state_dict` (weights only, not the module object)
    to `path`.
    """
    weights = model.state_dict()
    torch.save(weights, f'{path}')


def load_model(model, path):
    """
    Load weights saved with `save_model` from `path` into `model`.

    The model must already be constructed (e.g. by calling its constructor)
    and is updated in place.  The checkpoint is first mapped to the CPU so
    that weights saved on a GPU machine also load on a CPU-only runtime;
    `load_state_dict` then copies them onto whatever device the model's
    parameters live on.
    """
    state_dict = torch.load('{}'.format(path), map_location='cpu')
    model.load_state_dict(state_dict)


def test(model, test_loader, criterion, experiment, counter=0, isEval=False,
         decoder=greedy_decoder):
    """
    Evaluate `model` on `test_loader`: average loss, CER and WER.

    criterion: loss called as criterion(log_probs, labels, input_lengths,
     label_lengths) with log_probs shaped (time, batch, n_class)
    experiment: logging object providing `test()` and `log_metric()`
    counter: step under which metrics are logged when `isEval` is True
    isEval: log the metrics to `experiment` instead of printing them
    decoder: callable(output, labels, label_lengths) -> (predictions, targets)

    The model is left in eval mode.  Nothing is logged for an empty loader.
    """
    model.eval()
    t_loss, test_cer, test_wer = [], [], []
    with experiment.test():
        with torch.no_grad():
            for spectrograms, labels, input_lengths, label_lengths in test_loader:
                spectrograms, labels = spectrograms.to(device), labels.to(device)

                output = model(spectrograms)  # (batch, time, n_class)
                output = F.log_softmax(output, dim=2)
                output = output.transpose(0, 1)  # (time, batch, n_class)

                loss = criterion(output, labels, input_lengths, label_lengths)
                t_loss.append(loss.item())

                # The decoders expect batch-first emissions again.
                decoded_preds, decoded_targets = decoder(output.transpose(0, 1),
                                                         labels, label_lengths)
                for target, pred in zip(decoded_targets, decoded_preds):
                    test_cer.append(cer(target, pred))
                    test_wer.append(wer(target, pred))

        if not t_loss or not test_cer:
            # Avoid a ZeroDivisionError when the loader yields no samples.
            print('Test set is empty - nothing to evaluate')
            return

        avg_cer = sum(test_cer)/len(test_cer)
        avg_wer = sum(test_wer)/len(test_wer)
        avg_loss = sum(t_loss)/len(t_loss)

        if isEval:
            experiment.log_metric('test_loss', avg_loss, step=counter)
            experiment.log_metric('cer', avg_cer, step=counter)
            experiment.log_metric('wer', avg_wer, step=counter)
        else:
            print('Test set: Average loss: {:.4f}, Average CER: {:.4f} Average WER: {:.4f}\n'.format(avg_loss, avg_cer, avg_wer))


def train(model, train_loader, val_loader, criterion, experiment):
    """
    Train `model` for N_EPOCHS with AdamW and a linear one-cycle LR schedule.

    After every epoch the mean training loss and the current learning rate
    are logged to `experiment`, and the model is evaluated on `val_loader`
    via `test`.  Weights are written to CTC_MODEL_PATH at the start of every
    10th epoch and once more after training.
    NOTE(review): WEIGHT_DECAY is defined in the config cell but not passed
    to the optimizer - confirm whether that is intended.
    """
    optimizer = torch.optim.AdamW(model.parameters(), LEARNING_RATE)
    scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=LEARNING_RATE,
                                              steps_per_epoch=int(len(train_loader)),
                                              epochs=N_EPOCHS,
                                              anneal_strategy='linear')
    with experiment.train():
        for epoch in range(N_EPOCHS):
            # test() switches to eval mode, so re-enable training every epoch.
            model.train()
            if (epoch + 1) % 10 == 0:
                save_model(model, CTC_MODEL_PATH)

            e_loss = 0
            losses = []
            for spectrograms, labels, input_lengths, label_lengths in train_loader:
                spectrograms, labels = spectrograms.to(device), labels.to(device)

                optimizer.zero_grad()
                output = model(spectrograms)  # (batch, time, n_class)
                output = F.log_softmax(output, dim=2)
                output = output.transpose(0, 1)  # (time, batch, n_class)

                loss = criterion(output, labels, input_lengths, label_lengths)
                loss.backward()

                optimizer.step()
                scheduler.step()  # OneCycleLR is stepped once per batch
                e_loss += loss.item()
                losses.append(loss.item())

            print(f"Train Epoch: {epoch}, loss = {e_loss}")
            experiment.log_metric('loss', sum(losses)/len(losses), step=epoch+1)
            # get_last_lr() returns one value per parameter group; log the
            # scalar of the (single) group rather than the list itself.
            experiment.log_metric('learning_rate', scheduler.get_last_lr()[0],
                                  step=epoch+1)

            test(model, val_loader, criterion, experiment, epoch+1, True)

    save_model(model, CTC_MODEL_PATH)


Telemetry init

In [None]:
# set up the logging platform

from datetime import datetime

# Comet settings - leave the key empty to run without remote logging.
comet_api_key = ""  # todo - add
project_name = "67455 Introduction to Speech Processing - Final project"
experiment_name = ""  # todo - add

if not comet_api_key:
    # No key configured: create a disabled experiment so that the logging
    # calls in train()/test() still have an object to talk to.
    experiment = Experiment(api_key='dummy_key', disabled=True)
else:
    experiment = Experiment(api_key=comet_api_key,
                            project_name=project_name,
                            parse_args=False)
    experiment.set_name(experiment_name)
    experiment.display()



plotting

In [None]:
def parse_for_plot(dataset, model, num):
    """
    Decode the first `num` items of `dataset` with the beam-search decoder
    and plot the resulting character/word alignment over each waveform.

    dataset: indexable, yielding `(waveform, transcript)` items
    model: acoustic model mapping a spectrogram batch to emissions
    num: number of leading dataset items to plot
    """
    if num <= 0:
        return

    transform = torchaudio.transforms.MelSpectrogram(sample_rate=SR, n_mels=N_MELS)
    # Building the decoder loads the lexicon and language model, so do it
    # once rather than for every sample.
    decoder = ctc_decoder(lexicon=LEXICON_PATH, tokens=TOKENS_PATH,
                          lm=files.lm, blank_token='|',
                          sil_token='SPACE')
    model.eval()

    for i in range(num):
        sample, _ = dataset[i]
        spec = transform(sample).squeeze(0).transpose(0, 1)
        spectrograms = torch.stack([spec], dim=0)
        spectrograms = spectrograms.unsqueeze(1).transpose(2, 3)

        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            output = model(spectrograms.to(device)).to('cpu')

        beam_search_result = decoder(output.contiguous())
        best = beam_search_result[0][0]
        pred = " ".join(best.words).strip()
        # One token per character, framed by a blank on each side.
        tokens = [' '] + list(pred) + [' ']

        plot_alignments(sample[0], output, tokens, best.timesteps, SR)

def plot_alignments(waveform, emission, tokens, timesteps, sample_rate):
    """
    Plot a waveform with the decoded tokens annotated at their time stamps
    and every word span shaded.

    waveform: 1-D waveform tensor
    emission: model output; `emission.size(1)` is used as the frame count
    tokens: sequence of single-character tokens, ' ' marking word boundaries
    timesteps: frame index of each token (paired with `tokens` via zip)
    sample_rate: sample rate of `waveform`
    """
    t = torch.arange(waveform.size(0)) / sample_rate
    # Seconds of audio covered by one emission frame.
    ratio = waveform.size(0) / emission.size(1) / sample_rate

    chars = []
    words = []
    word_start = None
    for token, timestep in zip(tokens, timesteps * ratio):
        if token == " ":
            # A blank closes the word that is currently open (if any).
            if word_start is not None:
                words.append((word_start, timestep))
            word_start = None
        else:
            chars.append((token, timestep))
            if word_start is None:
                word_start = timestep

    fig, ax = plt.subplots(1, 1)
    ax.plot(t, waveform)
    for token, timestep in chars:
        ax.annotate(token.upper(), (timestep, 0.4))
    for start, end in words:
        ax.axvspan(start, end, alpha=0.1, color="red")
    ax.set_ylim(-0.6, 0.7)
    ax.set_yticks([0])
    ax.grid(True, axis="y")
    # NOTE(review): an x-limit of (0, last timestep + 1.5 s) used to be
    # computed here but was never applied to the axes; the dead computation
    # was dropped because it raised IndexError for an empty `timesteps`.
    # Confirm whether the limit should be applied instead.
    ax.set_xlabel("Time (Sec)")
    # Same display call as the other plotting helpers in this notebook.
    plt.show(block=False)


driver

In [None]:
def main():
    """
    Driver: builds the AN4 datasets and loaders, the final ASR model and the
    CTC loss.  Loading weights, training, testing and plotting are left as
    commented-out steps to be enabled as needed.
    """
    print('--- Start running ---')

    def make_loader(dataset, kind, shuffle):
        # `kind` selects the collate behaviour (augmentation for 'train').
        return DataLoader(dataset=dataset,
                          batch_size=BATCH_SIZE,
                          shuffle=shuffle,
                          collate_fn=lambda batch: data_processing(batch, kind))

    train_data_set = AN4Dataset('train')
    val_data_set = AN4Dataset('val')
    test_data_set = AN4Dataset('test')

    train_loader = make_loader(train_data_set, 'train', True)
    val_loader = make_loader(val_data_set, 'val', False)
    test_loader = make_loader(test_data_set, 'test', False)

    ctc_lstm = FinalASRModel(N_CNN_LAYERS, N_RNN_LAYERS, RNN_DIM, NUM_CLASSES,
                             N_MELS, STRIDE, DROPOUT)
    ctc_lstm = ctc_lstm.to(device)

    lossFunc = nn.CTCLoss(blank=28, zero_infinity=True).to(device)

    # TODO - load the model using the following weight file:
    # "https://drive.google.com/file/d/1-914Naz8MyxPLNnzywZ0_WrG7tCnIRzv/view?usp=sharing"
    # Remember to point the path constants at the current locations.

    # load_model(ctc_lstm, CTC_MODEL_PATH)

    # train(ctc_lstm, train_loader, val_loader, lossFunc, experiment)
    # test(ctc_lstm, test_loader, lossFunc, experiment, decoder=beam_decoder)

    # plot alignments from the test set
    # parse_for_plot(test_data_set, ctc_lstm, 10)
    print('-- Finished running ---')


if __name__ == '__main__':
    main()


--- Start running ---
-- Finished running ---
