In [1]:
# installing dependencies, though some dependencies are built-in in colab, we want to make sure that the dependencies are the same for the student's environment as well
!pip install tqdm==4.65.0
!pip install jiwer==3.0.1   
!pip install librosa==0.9.1
!pip install pandas==2.0.0rc

# download specific version of torch and torchaudio
!pip install torch==1.12.1+cu116 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu116 

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting jiwer==3.0.1
  Downloading jiwer-3.0.1-py3-none-any.whl (21 kB)
Collecting rapidfuzz==2.13.7 (from jiwer==3.0.1)
  Downloading rapidfuzz-2.13.7-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m87.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: rapidfuzz, jiwer
Successfully installed jiwer-3.0.1 rapidfuzz-2.13.7
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting librosa==0.9.1
  Downloading librosa-0.9.1-py3-none-any.whl (213 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m213.1/213.1 kB[0m [31m16.1 MB/s[0m eta [36m0:00:00[0m
Collecting resampy>=0.2.2 (from librosa==0.9.1)
  Down

In [2]:
"""
Custom Speech Dataset class to load the dataset

"""

import os
import pandas as pd
from typing import Tuple
import torch
import torchaudio


class CustomSpeechDataset(torch.utils.data.Dataset):
    
    """
    Custom torch dataset class to load the dataset 
    """
    
    def __init__(self, manifest_file: str, audio_dir: str, is_test_set: bool=False) -> None:

        """
        manifest_file: the csv file that contains the filename of the audio, and also the annotation if is_test_set is set to False
        audio_dir: the root directory of the audio datasets
        is_test_set: the flag variable to switch between loading of the train and the test set. Train set loads the annotation whereas test set does not
        """

        self.audio_dir = audio_dir
        self.is_test_set = is_test_set

        self.manifest = pd.read_csv(manifest_file)

        
    def __len__(self) -> int:
        
        """
        To get the number of loaded audio files in the dataset
        """

        return len(self.manifest)
    
    
    def __getitem__(self, index: int) -> Tuple[str, torch.Tensor]:

        """
        To get the values required to do the training
        """

        if torch.is_tensor(index):
            index.tolist()
            
        audio_path = self._get_audio_path(index)
        signal, sr = torchaudio.load(audio_path)
        
        if not self.is_test_set:
            annotation = self._get_annotation(index)
            return audio_path, signal, annotation
        
        return audio_path, signal
    
    
    def _get_audio_path(self, index: int) -> str:

        """
        Helper function to retrieve the audio path from the csv manifest file
        """
        
        path = os.path.join(self.audio_dir, self.manifest.iloc[index]['path'])

        return path
    
    
    def _get_annotation(self, index: int) -> str:

        """
        Helper function to retrieve the annotation from the csv manifest file
        """

        return self.manifest.iloc[index]['annotation']

In [3]:
"""
Transforms text by encoding the characters and decoding the integers corresponding to the characters
"""

from typing import List


class TextTransform:

    """
    Map characters to integers and vice versa (encoding/decoding)
    """
    
    def __init__(self) -> None:

        char_map_str = """
            <SPACE> 0
            A 1
            B 2
            C 3
            D 4
            E 5
            F 6
            G 7
            H 8
            I 9
            J 10
            K 11
            L 12
            M 13
            N 14
            O 15
            P 16
            Q 17
            R 18
            S 19
            T 20
            U 21
            V 22
            W 23
            X 24
            Y 25
            Z 26
        """
        
        self.char_map = {}
        self.index_map = {}
        
        for line in char_map_str.strip().split('\n'):
            ch, index = line.split()
            self.char_map[ch] = int(index)
            self.index_map[int(index)] = ch

        self.index_map[0] = ' '


    def get_char_len(self) -> int:

        """
        Gets the number of characters that are being encoded and decoded in the prediction
        Returns:
        --------
            the number of characters defined in the __init__ char_map_str
        """

        return len(self.char_map)
    

    def get_char_list(self) -> List[str]:

        """
        Gets the list of characters that are being encoded and decoded in the prediction
        
        Returns:
        -------
            a list of characters defined in the __init__ char_map_str
        """

        return list(self.index_map.values())
    

    def text_to_int(self, text: str) -> List[int]:

        """
        Use a character map and convert text to an integer sequence 
        Returns:
        -------
            a list of the text encoded to an integer sequence 
        """
        
        int_sequence = []
        for c in text:
            if c == ' ':
                ch = self.char_map['<SPACE>']
            else:
                ch = self.char_map[c]
            int_sequence.append(ch)

        return int_sequence
    

    def int_to_text(self, labels) -> str:

        """
        Use a character map and convert integer labels to an text sequence 
        
        Returns:
        -------
            the decoded transcription
        """
        
        string = []
        for i in labels:
            string.append(self.index_map[i])

        return ''.join(string).replace('<SPACE>', ' ')

In [4]:
"""
Decodes the logits into characters to form the final transciption using the greedy decoding approach
"""

import torch
from typing import List

class GreedyDecoder:

    """
    Decodes the logits into characters to form the final transciption using the greedy decoding approach
    """

    def __init__(self) -> None:
        pass


    def decode(
            self, 
            output: torch.Tensor, 
            labels: torch.Tensor=None, 
            label_lengths: List[int]=None, 
            collapse_repeated: bool=True, 
            is_test: bool=False
        ):
        
        """
        Main method to call for the decoding of the text from the predicted logits
        """
        
        text_transform = TextTransform()
        arg_maxes = torch.argmax(output, dim=2)
        decodes = []

        # refer to char_map_str in the TextTransform class -> only have index from 0 to 26, hence 27 represents the case where the character is decoded as blank (NOT <SPACE>)
        decoded_blank_idx = text_transform.get_char_len()

        if not is_test:
            targets = []

        for i, args in enumerate(arg_maxes):
            decode = []

            if not is_test:
                targets.append(text_transform.int_to_text(labels[i][:label_lengths[i]].tolist()))

            for j, char_idx in enumerate(args):
                if char_idx != decoded_blank_idx:
                    if collapse_repeated and j != 0 and char_idx == args[j-1]:
                        continue
                    decode.append(char_idx.item())
            decodes.append(text_transform.int_to_text(decode))

        return decodes, targets if not is_test else decodes

In [5]:
"""
building the model with adaption of deepspeech2 -> https://arxiv.org/abs/1512.02595

code adapted from https://towardsdatascience.com/customer-case-study-building-an-end-to-end-speech-recognition-model-in-pytorch-with-assemblyai-473030e47c7c
"""

import torch
import torch.nn.functional as F


class CNNLayerNorm(torch.nn.Module):
    
    """
    Layer normalization built for CNNs input
    """
    
    def __init__(self, n_feats: int) -> None:
        super(CNNLayerNorm, self).__init__()

        self.layer_norm = torch.nn.LayerNorm(n_feats)


    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Input x of dimension -> (batch, channel, feature, time)
        """
        
        x = x.transpose(2, 3).contiguous() # (batch, channel, time, feature)
        x = self.layer_norm(x)

        return x.transpose(2, 3).contiguous() # (batch, channel, feature, time) 


class ResidualCNN(torch.nn.Module):

    """
    Residual CNN inspired by https://arxiv.org/pdf/1603.05027.pdf except with layer norm instead of batch norm
    """
    
    def __init__(self, in_channels: int, out_channels: int, kernel: int, stride: int, dropout: float, n_feats: int) -> None:
        super(ResidualCNN, self).__init__()

        self.cnn1 = torch.nn.Conv2d(in_channels, out_channels, kernel, stride, padding=kernel//2)
        self.cnn2 = torch.nn.Conv2d(out_channels, out_channels, kernel, stride, padding=kernel//2)
        self.dropout1 = torch.nn.Dropout(dropout)
        self.dropout2 = torch.nn.Dropout(dropout)
        self.layer_norm1 = CNNLayerNorm(n_feats)
        self.layer_norm2 = CNNLayerNorm(n_feats)


    def forward(self, x: torch.Tensor) -> torch.Tensor:
        
        """
        Model building for the Residual CNN layers
        
        Input x of dimension -> (batch, channel, feature, time)
        """

        residual = x
        x = self.layer_norm1(x)
        x = F.gelu(x)
        x = self.dropout1(x)
        x = self.cnn1(x)
        x = self.layer_norm2(x)
        x = F.gelu(x)
        x = self.dropout2(x)
        x = self.cnn2(x)
        x += residual

        return x # (batch, channel, feature, time)


class BidirectionalGRU(torch.nn.Module):

    """
    The Bidirectional GRU composite code block which will be used in the main SpeechRecognitionModel class
    """
    
    def __init__(self, rnn_dim: int, hidden_size: int, dropout: int, batch_first: int) -> None:
        super(BidirectionalGRU, self).__init__()

        self.BiGRU = torch.nn.GRU(
            input_size=rnn_dim, 
            hidden_size=hidden_size,
            num_layers=1, 
            batch_first=batch_first, 
            bidirectional=True
        )
        self.layer_norm = torch.nn.LayerNorm(rnn_dim)
        self.dropout = torch.nn.Dropout(dropout)


    def forward(self, x: torch.Tensor) -> torch.Tensor:

        """
        Transformation of the layers in the Bidirectional GRU block
        """

        x = self.layer_norm(x)
        x = F.gelu(x)
        x, _ = self.BiGRU(x)
        x = self.dropout(x)

        return x


class SpeechRecognitionModel(torch.nn.Module):

    """
    The main ASR Model that the main code will interact with
    """
    
    def __init__(self, n_cnn_layers, n_rnn_layers, rnn_dim, n_class, n_feats, stride=2, dropout=0.1) -> None:
        super(SpeechRecognitionModel, self).__init__()
        
        n_feats = n_feats//2
        self.cnn = torch.nn.Conv2d(1, 32, 3, stride=stride, padding=3//2)  # cnn for extracting heirachal features

        # n residual cnn layers with filter size of 32
        self.rescnn_layers = torch.nn.Sequential(*[
            ResidualCNN(
                in_channels=32, 
                out_channels=32, 
                kernel=3, 
                stride=1, 
                dropout=dropout, 
                n_feats=n_feats
            ) for _ in range(n_cnn_layers)
        ])
        self.fully_connected = torch.nn.Linear(n_feats*32, rnn_dim)
        self.birnn_layers = torch.nn.Sequential(*[
            BidirectionalGRU(
                rnn_dim=rnn_dim if i==0 else rnn_dim*2,
                hidden_size=rnn_dim, 
                dropout=dropout, 
                batch_first=i==0
            ) for i in range(n_rnn_layers)
        ])
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(rnn_dim*2, rnn_dim),  # birnn returns rnn_dim*2
            torch.nn.GELU(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(rnn_dim, n_class)
        )


    def forward(self, x: torch.Tensor) -> torch.Tensor:

        """
        Transformation of the layers in the ASR model block
        """

        x = self.cnn(x)
        x = self.rescnn_layers(x)
        sizes = x.size()
        x = x.view(sizes[0], sizes[1] * sizes[2], sizes[3])  # (batch, feature, time)
        x = x.transpose(1, 2) # (batch, time, feature)
        x = self.fully_connected(x)
        x = self.birnn_layers(x)
        x = self.classifier(x)
        
        return x

In [6]:
"""
Data preprocessing and transformation of the audio files into melspectrogram
"""

import torch
import torchaudio


class DataProcessor:

    """
    Transforms the audio waveform tensors into a melspectrogram
    """

    def __init__(self) -> None:
        pass
    
    
    def _audio_transformation(self, is_train: bool=True):

        return torch.nn.Sequential(
                torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128),
                torchaudio.transforms.FrequencyMasking(freq_mask_param=30),
                torchaudio.transforms.TimeMasking(time_mask_param=100)
            ) if is_train else torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128)
    

    def data_processing(self, data, data_type='train'):

        """
        Process the audio data to retrieve the spectrograms that will be used for the training
        """

        text_transform = TextTransform()
        spectrograms = []
        input_lengths = []
        audio_path_list = []

        audio_transforms = self._audio_transformation(is_train=True) if data_type == 'train' else self._audio_transformation(is_train=False)

        if data_type != 'test':  
            labels = []
            label_lengths = []

            for audio_path, waveform, utterance in data:

                spec = audio_transforms(waveform).squeeze(0).transpose(0, 1)
                spectrograms.append(spec)
                label = torch.Tensor(text_transform.text_to_int(utterance))
                labels.append(label)
                input_lengths.append(spec.shape[0]//2)
                label_lengths.append(len(label))

            spectrograms = torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True).unsqueeze(1).transpose(2, 3)
            labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True)
            return audio_path, spectrograms, labels, input_lengths, label_lengths

        else:
            for audio_path, waveform in data:

                spec = audio_transforms(waveform).squeeze(0).transpose(0, 1)
                spectrograms.append(spec)
                input_lengths.append(spec.shape[0]//2)
                audio_path_list.append(audio_path)

            spectrograms = torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True).unsqueeze(1).transpose(2, 3)
            return audio_path_list, spectrograms, input_lengths

In [7]:
"""
The helper class for the training loop to do model training
"""

import torch
import torch.nn.functional as F
from jiwer import wer, cer


class IterMeter(object):

    """
    Keeps track of the total iterations during the training and validation loop
    """
    
    def __init__(self) -> None:
        self.val = 0


    def step(self):
        self.val += 1


    def get(self):
        return self.val
    

class TrainingLoop:

    """
    The main class to set up the training loop to train the model
    """

    def __init__(self) -> None:
        pass
    

    def train(self, model, device, train_loader, criterion, optimizer, scheduler, epoch, iter_meter) -> None:

        """
        Training Loop
        """
        
        model.train()
        data_len = len(train_loader.dataset)
        
        for batch_idx, _data in enumerate(train_loader):
            audio_path, spectrograms, labels, input_lengths, label_lengths = _data 
            spectrograms, labels = spectrograms.to(device), labels.to(device)

            optimizer.zero_grad()

            output = model(spectrograms)  # (batch, time, n_class)
            output = F.log_softmax(output, dim=2)
            output = output.transpose(0, 1) # (time, batch, n_class)

            loss = criterion(output, labels, input_lengths, label_lengths)
            loss.backward()

            optimizer.step()
            iter_meter.step()
            
            if batch_idx % 100 == 0 or batch_idx == data_len:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(spectrograms), data_len,
                    100. * batch_idx / len(train_loader), loss.item()))


    def dev(self, model, device, dev_loader, criterion, scheduler, epoch, iter_meter) -> None:

        """
        Validation Loop
        """
        
        print('\nevaluating...')
        model.eval()
        val_loss = 0
        test_cer, test_wer = [], []
        greedy_decoder = GreedyDecoder()
        
        with torch.no_grad():
            for i, _data in enumerate(dev_loader):
                audio_path, spectrograms, labels, input_lengths, label_lengths = _data 
                spectrograms, labels = spectrograms.to(device), labels.to(device)

                output = model(spectrograms)  # (batch, time, n_class)
                output = F.log_softmax(output, dim=2)
                output = output.transpose(0, 1) # (time, batch, n_class)

                loss = criterion(output, labels, input_lengths, label_lengths)
                val_loss += loss.item() / len(dev_loader)

                decoded_preds, decoded_targets = greedy_decoder.decode(output.transpose(0, 1), labels=labels, label_lengths=label_lengths, is_test=False)
                
                for j in range(len(decoded_preds)):
                    test_cer.append(cer(decoded_targets[j], decoded_preds[j]))
                    test_wer.append(wer(decoded_targets[j], decoded_preds[j]))

        avg_cer = sum(test_cer)/len(test_cer)
        avg_wer = sum(test_wer)/len(test_wer)
        
        scheduler.step(val_loss)

        print('Dev set: Average loss: {:.4f}, Average CER: {:4f} Average WER: {:.4f}\n'.format(val_loss, avg_cer, avg_wer))

In [8]:
# load and unzip the dataset from gdrive

import gdown
!gdown 1A_Gpv_tWmBecp9oExGfkUi5w87Vsayjg
!unzip Train.zip
!rm Train.zip

Downloading...
From: https://drive.google.com/uc?id=1A_Gpv_tWmBecp9oExGfkUi5w87Vsayjg
To: /content/Train.zip
100% 361M/361M [00:03<00:00, 110MB/s]
Archive:  Train.zip
   creating: Train/audio/
  inflating: Train/audio/train_00001.wav  
  inflating: Train/audio/train_00002.wav  
  inflating: Train/audio/train_00003.wav  
  inflating: Train/audio/train_00004.wav  
  inflating: Train/audio/train_00005.wav  
  inflating: Train/audio/train_00006.wav  
  inflating: Train/audio/train_00007.wav  
  inflating: Train/audio/train_00008.wav  
  inflating: Train/audio/train_00009.wav  
  inflating: Train/audio/train_00010.wav  
  inflating: Train/audio/train_00011.wav  
  inflating: Train/audio/train_00012.wav  
  inflating: Train/audio/train_00013.wav  
  inflating: Train/audio/train_00014.wav  
  inflating: Train/audio/train_00015.wav  
  inflating: Train/audio/train_00016.wav  
  inflating: Train/audio/train_00017.wav  
  inflating: Train/audio/train_00018.wav  
  inflating: Train/audio/train_00

In [9]:
import gdown
!gdown 1IUcHafdnPwVqeLKxbkdfRfNFMUC2E_Wh

Downloading...
From: https://drive.google.com/uc?id=1IUcHafdnPwVqeLKxbkdfRfNFMUC2E_Wh
To: /content/Train.csv
  0% 0.00/280k [00:00<?, ?B/s]100% 280k/280k [00:00<00:00, 136MB/s]


In [10]:
"""
Entry point of the code to do model training
"""

import os
import torch
from time import time
import random

# setting the random seed for reproducibility
SEED = 2022


def main(hparams, train_dataset, dev_dataset, saved_model_path) -> None:

    """
    The main method to call to do model training
    """ 

    use_cuda = torch.cuda.is_available()
    torch.manual_seed(SEED)
    
    data_processor = DataProcessor()
    iter_meter = IterMeter()
    text_transform = TextTransform()
    trainer = TrainingLoop()
    
    device = torch.device("cuda" if use_cuda else "cpu")
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=hparams['batch_size'],
        shuffle=True,
        collate_fn=lambda x: data_processor.data_processing(x, 'train'),
        **kwargs
    )
    
    dev_loader = torch.utils.data.DataLoader(
        dataset=dev_dataset,
        batch_size=hparams['batch_size'],
        shuffle=False,
        collate_fn=lambda x: data_processor.data_processing(x, 'dev'),
        **kwargs
    )

    model = SpeechRecognitionModel(
        hparams['n_cnn_layers'], 
        hparams['n_rnn_layers'], 
        hparams['rnn_dim'],
        hparams['n_class'], 
        hparams['n_feats'], 
        hparams['stride'], 
        hparams['dropout']
    ).to(device)

    print('Num Model Parameters', sum([param.nelement() for param in model.parameters()]))

    optimizer = torch.optim.AdamW(model.parameters(), hparams['learning_rate'], weight_decay=0.1)
    criterion = torch.nn.CTCLoss(blank=text_transform.get_char_len()).to(device)
    
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, mode='min', patience=3, verbose=True, factor=0.05)
    
    for epoch in range(1, hparams['epochs'] + 1):
        trainer.train(model, device, train_loader, criterion, optimizer, scheduler, epoch, iter_meter)
        trainer.dev(model, device, dev_loader, criterion, scheduler, epoch, iter_meter)
        
    # save the trained model
    torch.save(model.state_dict(), saved_model_path)


if __name__ == "__main__":

    MANIFEST_FILE_TRAIN = 'Train.csv'
    AUDIO_DIR_TRAIN = 'Train'
    SAVED_MODEL_PATH = 'model.pt'

    # simple check on the saved model path, will raise error if no directory found
    #if not os.path.exists(os.path.dirname(SAVED_MODEL_PATH)):
        #raise FileNotFoundError

    # loads the dataset
    dataset = CustomSpeechDataset(
        manifest_file=MANIFEST_FILE_TRAIN, 
        audio_dir=AUDIO_DIR_TRAIN, 
        is_test_set=False
    )

    data_list = list(dataset)
    random.shuffle(data_list)
    train_proportion = int(0.8 * len(dataset))
    dataset_train = data_list[:train_proportion]
    dataset_dev = data_list[train_proportion:]
    
    hparams = {
            "n_cnn_layers": 3,
            "n_rnn_layers": 5,
            "rnn_dim": 512,
            "n_class": 28, # 26 alphabets in caps + <SPACE> + blanks
            "n_feats": 128,
            "stride": 2,
            "dropout": 0.1,
            "learning_rate": 3e-4,
            "batch_size": 8,
            "epochs": 100
      }

    start_time = time()

    # start training the model
    main(
        hparams=hparams, 
        train_dataset=dataset_train, 
        dev_dataset=dataset_dev, 
        saved_model_path=SAVED_MODEL_PATH
    )
    
    end_time = time()
    
    print(f"Time taken for training: {(end_time-start_time)/(60*60)} hrs")
    

Num Model Parameters 23704860





evaluating...




Dev set: Average loss: 2.7433, Average CER: 0.931152 Average WER: 0.9464






evaluating...




Dev set: Average loss: 2.7130, Average CER: 0.900418 Average WER: 0.9345






evaluating...




Dev set: Average loss: 2.6901, Average CER: 0.896573 Average WER: 0.9405






evaluating...




Dev set: Average loss: 2.6799, Average CER: 0.891636 Average WER: 0.9389






evaluating...




Dev set: Average loss: 2.6559, Average CER: 0.883919 Average WER: 0.9285






evaluating...




Dev set: Average loss: 2.5318, Average CER: 0.896550 Average WER: 0.9405






evaluating...




Dev set: Average loss: 2.4768, Average CER: 0.684838 Average WER: 1.3236






evaluating...




Dev set: Average loss: 2.4123, Average CER: 0.678586 Average WER: 1.2636






evaluating...




Dev set: Average loss: 2.4367, Average CER: 0.722961 Average WER: 1.0641






evaluating...




Dev set: Average loss: 2.4200, Average CER: 0.690302 Average WER: 1.0663






evaluating...




Dev set: Average loss: 2.4657, Average CER: 0.696733 Average WER: 1.0954






evaluating...




Dev set: Average loss: 2.2472, Average CER: 0.635694 Average WER: 1.1435






evaluating...




Dev set: Average loss: 2.1574, Average CER: 0.618228 Average WER: 1.1165






evaluating...




Dev set: Average loss: 2.1037, Average CER: 0.609666 Average WER: 1.0495






evaluating...




Dev set: Average loss: 2.0126, Average CER: 0.587189 Average WER: 1.0772






evaluating...




Dev set: Average loss: 1.9220, Average CER: 0.569025 Average WER: 0.9942






evaluating...




Dev set: Average loss: 1.8339, Average CER: 0.557600 Average WER: 0.9781






evaluating...




Dev set: Average loss: 1.7562, Average CER: 0.527009 Average WER: 0.9892






evaluating...




Dev set: Average loss: 1.6494, Average CER: 0.498144 Average WER: 0.9239






evaluating...




Dev set: Average loss: 1.5737, Average CER: 0.465848 Average WER: 0.9090






evaluating...




Dev set: Average loss: 1.6051, Average CER: 0.474612 Average WER: 0.9185






evaluating...




Dev set: Average loss: 1.4624, Average CER: 0.443890 Average WER: 0.8617






evaluating...




Dev set: Average loss: 1.4023, Average CER: 0.424910 Average WER: 0.8485






evaluating...




Dev set: Average loss: 1.3258, Average CER: 0.401303 Average WER: 0.8060






evaluating...




Dev set: Average loss: 1.2549, Average CER: 0.383068 Average WER: 0.7806






evaluating...




Dev set: Average loss: 1.2215, Average CER: 0.372929 Average WER: 0.7736






evaluating...




Dev set: Average loss: 1.1637, Average CER: 0.353961 Average WER: 0.7396






evaluating...




Dev set: Average loss: 1.1045, Average CER: 0.337351 Average WER: 0.7229






evaluating...




Dev set: Average loss: 1.0741, Average CER: 0.323037 Average WER: 0.7013






evaluating...




Dev set: Average loss: 1.0306, Average CER: 0.313816 Average WER: 0.6889






evaluating...




Dev set: Average loss: 1.0078, Average CER: 0.305544 Average WER: 0.6629






evaluating...




Dev set: Average loss: 0.9467, Average CER: 0.290930 Average WER: 0.6553






evaluating...




Dev set: Average loss: 0.9024, Average CER: 0.280037 Average WER: 0.6244






evaluating...




Dev set: Average loss: 0.8785, Average CER: 0.265087 Average WER: 0.5965






evaluating...




Dev set: Average loss: 0.8301, Average CER: 0.252102 Average WER: 0.5783






evaluating...




Dev set: Average loss: 0.8320, Average CER: 0.255700 Average WER: 0.5772






evaluating...




Dev set: Average loss: 0.7760, Average CER: 0.236374 Average WER: 0.5466






evaluating...




Dev set: Average loss: 0.7439, Average CER: 0.225103 Average WER: 0.5422






evaluating...




Dev set: Average loss: 0.7258, Average CER: 0.222171 Average WER: 0.5372






evaluating...




Dev set: Average loss: 0.6713, Average CER: 0.204983 Average WER: 0.5000






evaluating...




Dev set: Average loss: 0.6545, Average CER: 0.197298 Average WER: 0.4991






evaluating...




Dev set: Average loss: 0.6387, Average CER: 0.190722 Average WER: 0.4923






evaluating...




Dev set: Average loss: 0.6135, Average CER: 0.184781 Average WER: 0.4783






evaluating...




Dev set: Average loss: 0.5906, Average CER: 0.177048 Average WER: 0.4603






evaluating...




Dev set: Average loss: 0.5490, Average CER: 0.167041 Average WER: 0.4491






evaluating...




Dev set: Average loss: 0.6214, Average CER: 0.187049 Average WER: 0.4921






evaluating...




Dev set: Average loss: 0.5139, Average CER: 0.154795 Average WER: 0.4188






evaluating...




Dev set: Average loss: 0.4785, Average CER: 0.144457 Average WER: 0.4066






evaluating...




Dev set: Average loss: 0.4706, Average CER: 0.141390 Average WER: 0.4034






evaluating...




Dev set: Average loss: 0.4609, Average CER: 0.136482 Average WER: 0.3984






evaluating...




Dev set: Average loss: 0.4422, Average CER: 0.130443 Average WER: 0.3811






evaluating...




Dev set: Average loss: 0.4282, Average CER: 0.130319 Average WER: 0.3846






evaluating...




Dev set: Average loss: 0.4038, Average CER: 0.119784 Average WER: 0.3654






evaluating...




Dev set: Average loss: 0.4134, Average CER: 0.120434 Average WER: 0.3611






evaluating...




Dev set: Average loss: 0.3929, Average CER: 0.117294 Average WER: 0.3617






evaluating...




Dev set: Average loss: 0.3811, Average CER: 0.111287 Average WER: 0.3373






evaluating...




Dev set: Average loss: 0.4076, Average CER: 0.118313 Average WER: 0.3632






evaluating...




Dev set: Average loss: 0.3581, Average CER: 0.102903 Average WER: 0.3234






evaluating...




Dev set: Average loss: 0.4176, Average CER: 0.121732 Average WER: 0.3729






evaluating...




Dev set: Average loss: 0.3352, Average CER: 0.096710 Average WER: 0.3145






evaluating...




Dev set: Average loss: 0.3372, Average CER: 0.095520 Average WER: 0.3056






evaluating...




Dev set: Average loss: 0.3171, Average CER: 0.091183 Average WER: 0.2983






evaluating...




Dev set: Average loss: 0.2999, Average CER: 0.084707 Average WER: 0.2782






evaluating...




Dev set: Average loss: 0.2921, Average CER: 0.082725 Average WER: 0.2795






evaluating...




Dev set: Average loss: 0.2947, Average CER: 0.084828 Average WER: 0.2767






evaluating...




Dev set: Average loss: 0.3014, Average CER: 0.084550 Average WER: 0.2848






evaluating...




Dev set: Average loss: 0.2754, Average CER: 0.081048 Average WER: 0.2742






evaluating...




Dev set: Average loss: 0.2634, Average CER: 0.075037 Average WER: 0.2568






evaluating...




Dev set: Average loss: 0.2469, Average CER: 0.070559 Average WER: 0.2411






evaluating...




Dev set: Average loss: 0.2514, Average CER: 0.069979 Average WER: 0.2349






evaluating...




Dev set: Average loss: 0.2462, Average CER: 0.072293 Average WER: 0.2464






evaluating...




Dev set: Average loss: 0.2471, Average CER: 0.068496 Average WER: 0.2430






evaluating...




Dev set: Average loss: 0.2362, Average CER: 0.070695 Average WER: 0.2503






evaluating...




Dev set: Average loss: 0.2377, Average CER: 0.067690 Average WER: 0.2334






evaluating...




Dev set: Average loss: 0.2188, Average CER: 0.061576 Average WER: 0.2175






evaluating...




Dev set: Average loss: 0.2393, Average CER: 0.064214 Average WER: 0.2281






evaluating...




Dev set: Average loss: 0.2373, Average CER: 0.065523 Average WER: 0.2301






evaluating...




Dev set: Average loss: 0.2264, Average CER: 0.063517 Average WER: 0.2197






evaluating...




Dev set: Average loss: 0.2028, Average CER: 0.058276 Average WER: 0.2106






evaluating...




Dev set: Average loss: 0.1955, Average CER: 0.054768 Average WER: 0.1963






evaluating...




Dev set: Average loss: 0.2118, Average CER: 0.059339 Average WER: 0.2123






evaluating...




Dev set: Average loss: 0.1892, Average CER: 0.051859 Average WER: 0.1871






evaluating...




Dev set: Average loss: 0.1864, Average CER: 0.051980 Average WER: 0.1922






evaluating...




Dev set: Average loss: 0.1815, Average CER: 0.050083 Average WER: 0.1848






evaluating...




Dev set: Average loss: 0.1818, Average CER: 0.050377 Average WER: 0.1853






evaluating...




Dev set: Average loss: 0.1709, Average CER: 0.046953 Average WER: 0.1725






evaluating...




Dev set: Average loss: 0.1787, Average CER: 0.048472 Average WER: 0.1780






evaluating...




Dev set: Average loss: 0.1649, Average CER: 0.046650 Average WER: 0.1709






evaluating...




Dev set: Average loss: 0.1869, Average CER: 0.051663 Average WER: 0.1832






evaluating...




Dev set: Average loss: 0.1788, Average CER: 0.049935 Average WER: 0.1824






evaluating...




Dev set: Average loss: 0.1730, Average CER: 0.048849 Average WER: 0.1814






evaluating...




Dev set: Average loss: 0.1586, Average CER: 0.042597 Average WER: 0.1582






evaluating...




Dev set: Average loss: 0.1565, Average CER: 0.042583 Average WER: 0.1593






evaluating...




Dev set: Average loss: 0.1454, Average CER: 0.040423 Average WER: 0.1524






evaluating...




Dev set: Average loss: 0.1521, Average CER: 0.040844 Average WER: 0.1514






evaluating...




Dev set: Average loss: 0.1496, Average CER: 0.040204 Average WER: 0.1495






evaluating...




Dev set: Average loss: 0.1527, Average CER: 0.041832 Average WER: 0.1544






evaluating...




Epoch 00098: reducing learning rate of group 0 to 1.5000e-05.
Dev set: Average loss: 0.1598, Average CER: 0.042422 Average WER: 0.1588






evaluating...




Dev set: Average loss: 0.1281, Average CER: 0.033173 Average WER: 0.1260






evaluating...




Dev set: Average loss: 0.1208, Average CER: 0.030714 Average WER: 0.1157

Time taken for training: 1.6951618633005354 hrs


In [11]:
# saving the model to gdrive
from google.colab import drive
drive.mount('/content/gdrive')

!cp model.pt "gdrive/My Drive/model.pt"

Mounted at /content/gdrive


In [12]:
import gdown
!gdown 1ukYdUG4k-Sf0GHmIxy7AYjyQW_sAOvVV
!unzip Test.zip
!rm Test.zip

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: Test/audio/evala_07001.wav  
  inflating: Test/audio/evala_07002.wav  
  inflating: Test/audio/evala_07003.wav  
  inflating: Test/audio/evala_07004.wav  
  inflating: Test/audio/evala_07005.wav  
  inflating: Test/audio/evala_07006.wav  
  inflating: Test/audio/evala_07007.wav  
  inflating: Test/audio/evala_07008.wav  
  inflating: Test/audio/evala_07009.wav  
  inflating: Test/audio/evala_07010.wav  
  inflating: Test/audio/evala_07011.wav  
  inflating: Test/audio/evala_07012.wav  
  inflating: Test/audio/evala_07013.wav  
  inflating: Test/audio/evala_07014.wav  
  inflating: Test/audio/evala_07015.wav  
  inflating: Test/audio/evala_07016.wav  
  inflating: Test/audio/evala_07017.wav  
  inflating: Test/audio/evala_07018.wav  
  inflating: Test/audio/evala_07019.wav  
  inflating: Test/audio/evala_07020.wav  
  inflating: Test/audio/evala_07021.wav  
  inflating: Test/audio/evala_07022.wav  
  inflating

In [13]:
import gdown
!gdown 1Si1tOCvTxR6_63omVBdjQSWT3z7dUNz9

Downloading...
From: https://drive.google.com/uc?id=1Si1tOCvTxR6_63omVBdjQSWT3z7dUNz9
To: /content/Test.csv
  0% 0.00/204k [00:00<?, ?B/s]100% 204k/204k [00:00<00:00, 120MB/s]


In [14]:
"""
Entry point of the code to do model inference, also the code to use to generate the submission
"""

import torch
import torch.nn.functional as F

from time import time
from typing import Dict
import pandas as pd
from tqdm import tqdm
import os

# setting the random seed for reproducibility
SEED = 2022


def infer(hparams, test_dataset, model_path) -> Dict[str, str]:
    
    print('\ngenerating inference ...')

    use_cuda = torch.cuda.is_available()
    torch.manual_seed(SEED)
    
    greedy_decoder = GreedyDecoder()
    data_processor = DataProcessor()
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=16,
        shuffle=False,
        collate_fn=lambda x: data_processor.data_processing(x, 'test'),
        **kwargs
    )
    
    # load the pretrained model
    model = SpeechRecognitionModel(
        hparams['n_cnn_layers'], 
        hparams['n_rnn_layers'], 
        hparams['rnn_dim'],
        hparams['n_class'], 
        hparams['n_feats'], 
        hparams['stride'], 
        hparams['dropout']
    ).to(device)
    
    model.load_state_dict(torch.load(model_path))
    model.eval()
    output_dict = {}
    
    with torch.no_grad():
        for i, _data in tqdm(enumerate(test_loader)):
            audio_path, spectrograms, input_lengths = _data
            spectrograms = spectrograms.to(device)
            output = model(spectrograms)  # (batch, time, n_class)
            output = F.log_softmax(output, dim=2)
            output = output.transpose(0, 1) # (time, batch, n_class) 
            decoded_preds_batch = greedy_decoder.decode(output.transpose(0, 1), labels=None, label_lengths=None, is_test=True)
            
            # batch prediction
            for decoded_idx in range(len(decoded_preds_batch[0])):
                output_dict[audio_path[decoded_idx]] = decoded_preds_batch[0][decoded_idx]

                
    print('done!\n')
    return output_dict


if __name__ == "__main__":

    # same hyperparams as what you have used to train the model
    hparams = {
            "n_cnn_layers": 3,
            "n_rnn_layers": 5,
            "rnn_dim": 512,
            "n_class": 28, # 26 alphabets in caps + <SPACE> + blanks
            "n_feats": 128,
            "stride": 2,
            "dropout": 0.1,
            "learning_rate": 3e-4,
            "batch_size": 8,
            "epochs": 100
    }

    # change the filepath as according
    SAVED_MODEL_PATH = 'model.pt'
    SUBMISSION_PATH = 'submission.csv' # or '/home/nicholas/models/til2023/Submission_Novice.csv' if novice tier

    MANIFEST_FILE_TEST = 'Test.csv' # or '/home/nicholas/datasets/til2023_asr_dataset/Test_Novice.csv' if novice tier 
    AUDIO_DIR_TEST = 'Test/audio' # or '/home/nicholas/datasets/til2023_asr_dataset/Test_Novice/' if novice tier
    
    dataset_test = CustomSpeechDataset(
        manifest_file=MANIFEST_FILE_TEST, 
        audio_dir=AUDIO_DIR_TEST, 
        is_test_set=True
    )

    start_time = time()

    submission_dict = infer(
        hparams=hparams, 
        test_dataset=dataset_test, 
        model_path=SAVED_MODEL_PATH
    )
    
    # producing the final csv file for submission
    submission_list = []

    for key in submission_dict:
        submission_list.append(
            {
                "path": os.path.basename(key),
                "annotation": submission_dict[key]
            }
        )

    submission_df = pd.DataFrame(submission_list)
    submission_df.to_csv(SUBMISSION_PATH, index=False)

    end_time = time()

    print(f"Time taken for inference: {(end_time-start_time)/60} min")

    


generating inference ...


750it [02:40,  4.68it/s]

done!

Time taken for inference: 2.679109724362691 min



