#### Automatic Speech Recognition (ASR) with CTC

In [34]:
### Import the necessary packages

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
import random

In [35]:
### Create the source vocabulary

# Default word tokens
BLK_token = 0  # Blank label
PAD_token = 1  # Used for padding short utterances

class ASR_Vocab(object):
    def __init__(self, digit_seqs):
        super(ASR_Vocab, self).__init__()
        self.digit2index = {}
        self.digit2count = {}
        self.index2digit = {BLK_token: "BLK", PAD_token: "PAD"}
        
        # Count SOS, EOS, PAD, EMP
        self.num_tokens  = 2
        self.digit_seqs  = digit_seqs
        
    def dig2idx(self, digit):
        if digit in self.digit2index:
            return self.digit2index[digit]
        
    def idx2dig(self, idx):
        if idx in self.index2digit:
            return self.index2digit[idx]
            
    def add_digit(self, digit):
        if digit in self.digit2index:
            self.digit2count[digit] += 1
            
        else:
            self.digit2index[digit] = self.num_tokens
            self.index2digit[self.num_tokens] = digit
            self.digit2count[digit] = 1
            self.num_tokens += 1
            
    def build_vocab(self):        
        for seq in self.digit_seqs:
            for digit in seq:      # Ignore EOS token
                self.add_digit(digit)
            
        # print("Vocabulary created with %d tokens ..." % self.num_tokens)
        # return self.num_tokens
    
    def vocab_size(self):
        return self.num_tokens
    
    def vocabulary(self):
        for idx in self.index2digit:
            print(idx, self.index2digit[idx])
    
    def encode(self, seq):
        return [self.dig2idx(digit) for digit in seq]
    
    def decode(self, seq):
        return "".join([self.idx2dig(idx) for idx in seq])

In [36]:
# Create a dictionary of audio files and their corresponding ASR labels
audio_df = pd.read_csv('./data/ASR/data.txt', sep=",")
audio_df = audio_df.drop(['Gender', 'Spk_ID', 'Utt_ID'], axis=1)
# print(audio_df.head())

# Create the vocabulary    
asr_labels = audio_df['Transcription'].values

print("Number of ASR utterances : ", len(asr_labels))
vocab_asr = ASR_Vocab(asr_labels)
vocab_asr.build_vocab()

# Test the vocabulary
print("\nOriginal sequence : ", asr_labels[4])
encoded_seq = vocab_asr.encode(asr_labels[4])
decoded_seq = vocab_asr.decode(encoded_seq)

print("Encoded sequence  : ", encoded_seq)
print("Decoded sequence  : ", decoded_seq)

Number of ASR utterances :  8511

Original sequence :  17914
Encoded sequence  :  [7, 5, 8, 7, 3]
Decoded sequence  :  17914


In [37]:
### Define the device

USE_CUDA = torch.cuda.is_available()
device = torch.device("cuda" if USE_CUDA else "cpu")
print(torch.cuda.get_device_name(torch.cuda.current_device()))

NVIDIA GeForce RTX 2070 Super with Max-Q Design


In [38]:
# Load the audio files
import torchaudio

n_mels = 32
mel_spectrogram = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_fft=512, win_length=512, hop_length=256, n_mels=n_mels, f_min=0.0, f_max=None, pad=0, power=2.0, normalized=False)

audio_files = audio_df['Audio'].values
data_path = './data/ASR/train/'

mfccs = []
max_feat_len = 0
for i, audio_file in enumerate(audio_files):
    # Load the audio file
    audio_path = data_path + audio_file
    waveform, sample_rate = torchaudio.load(audio_path)
    
    # Convert to mono and resample to 16kHz
    waveform = waveform.mean(dim=0, keepdim=True)
    if sample_rate != 16000 : waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
    
    # Obtain and normalize the MFCCs
    mel_spec = mel_spectrogram(waveform)
    mel_spec = abs(mel_spec)
    mel_spec = (mel_spec - mel_spec.mean()) / mel_spec.std()
    mfccs.append(mel_spec)
    max_feat_len = max(max_feat_len, mel_spec.shape[2])
    
# Pad the MFCCs to the maximum length
for i, mfcc in enumerate(mfccs):
    mfccs[i] = F.pad(mfcc, (0, max_feat_len - mfcc.shape[2]), "constant", 0)
    
mfccs = torch.cat(mfccs, dim=0)
mfccs = mfccs.permute(0, 2, 1)
print(mfccs.shape)

torch.Size([8511, 322, 32])


In [39]:
# Create the output labels
labels = []
max_labl_len = 0
for i, label in enumerate(asr_labels):
    labels.append(vocab_asr.encode(label))
    max_labl_len = max(max_labl_len, len(labels[i]))
    
# Pad the labels to the maximum length
for i, label in enumerate(labels):
    labels[i] = label + [PAD_token] * (max_labl_len - len(label))
    
labels = torch.tensor(labels)
print(labels.shape)

torch.Size([8511, 7])


In [40]:
# Create the dataloader

class ASR_Dataset(Dataset):
    def __init__(self, mfccs, labels):
        self.mfccs = mfccs
        self.labels = labels
        
    def __len__(self):
        return len(self.mfccs)
    
    def __getitem__(self, idx):
        return self.mfccs[idx], self.labels[idx]
    
dataset = ASR_Dataset(mfccs, labels)

In [41]:
# Define the seq2seq model
input_size  = n_mels
hidden_size = 300
output_size = vocab_asr.vocab_size()

n_epochs = 100
batch_size = 64

In [54]:
# Define the CTC based ASR model

class CTC_ASR(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(CTC_ASR, self).__init__()
        self.input_size = input_size        # U
        self.hidden_size = hidden_size      # H
        self.output_size = output_size      # V
        
        # Convolutional layers
        self.conv1 = nn.Conv1d(input_size, input_size, 10, 2, 5)
        self.maxp1 = nn.MaxPool1d(2, 2)
        self.batn1 = nn.BatchNorm1d(input_size)
        
        # GRU layer
        self.U = nn.Linear(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, num_layers=2, bidirectional=True, batch_first=True)
        self.W = nn.Linear(hidden_size, output_size)
        
    def _init_hidden(self, batch_size):
        return torch.zeros(4, batch_size, self.hidden_size).to(device)
        
    def forward(self, inputs):
        hidden = self._init_hidden(inputs.shape[0])
        
        inputs = inputs.transpose(1, 2)                                 # B x U x L
        cnn_out = F.gelu(self.batn1(self.conv1(inputs)))                # B x U x L'        
        cnn_out = cnn_out.transpose(1, 2)                               # B x L' x U
        
        gru_in = F.relu(self.U(cnn_out))                                # B x L' x H
        outputs, hidden = self.gru(gru_in, hidden)                      # B x L' x 2H, 4 x B x H
        outputs = outputs[:, :, :self.hidden_size] + outputs[:, :, self.hidden_size:] # B x L' x H
        
        outputs = self.W(outputs)                                       # B x L' x V
        # print(outputs.shape, hidden.shape)
        
        return outputs
    
    def predict(self, inputs):
        outputs = self.forward(inputs)
        outputs = F.softmax(outputs, dim=2)
        
        return torch.argmax(outputs, dim=2)

In [55]:
# Instantiate the model
ctc_asr = CTC_ASR(input_size, hidden_size, output_size).to(device)

print(ctc_asr.parameters)

<bound method Module.parameters of CTC_ASR(
  (conv1): Conv1d(32, 32, kernel_size=(10,), stride=(2,), padding=(5,))
  (maxp1): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (batn1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (U): Linear(in_features=32, out_features=300, bias=True)
  (gru): GRU(300, 300, num_layers=2, batch_first=True, bidirectional=True)
  (W): Linear(in_features=300, out_features=13, bias=True)
)>


In [56]:
# Define the optimizer and loss function
learning_rate = 0.001

ctc_asr_optimizer = optim.Adam(ctc_asr.parameters(), lr=learning_rate)
criterion = nn.CTCLoss(blank=0, zero_infinity=False)

In [80]:
# Train the model
ctc_asr.train()
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

print("Training started ...")
for epoch in range(n_epochs):
    for i, data in enumerate(dataloader):
        inputs, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)
        
        if inputs.size(0) != batch_size:
            continue
        
        # Zero the gradients
        ctc_asr.zero_grad()
        
        # Forward pass
        outputs = ctc_asr(inputs)
        
        # Compute the loss        
        loss = criterion(outputs.transpose(0, 1), labels, torch.tensor([inputs.shape[0]] * batch_size), torch.tensor([labels.shape[1]] * batch_size))
        
        loss.backward()
        ctc_asr_optimizer.step()
        print("Epoch: {}/{}, Step: {}/{}, Loss: {}".format(epoch+1, n_epochs, i+1, len(dataloader), np.round(loss.item(), 4)))
print("Training completed !!!")

Training started ...
Epoch: 1/100, Step: 1/133, Loss: 0.2228
Epoch: 1/100, Step: 2/133, Loss: 0.1175
Epoch: 1/100, Step: 3/133, Loss: 0.1588
Epoch: 1/100, Step: 4/133, Loss: 0.2672
Epoch: 1/100, Step: 5/133, Loss: 0.1646
Epoch: 1/100, Step: 6/133, Loss: 0.1789
Epoch: 1/100, Step: 7/133, Loss: 0.3478
Epoch: 1/100, Step: 8/133, Loss: 0.2332
Epoch: 1/100, Step: 9/133, Loss: 0.1251
Epoch: 1/100, Step: 10/133, Loss: 0.1056
Epoch: 1/100, Step: 11/133, Loss: 0.3321
Epoch: 1/100, Step: 12/133, Loss: 0.2779
Epoch: 1/100, Step: 13/133, Loss: 0.1892
Epoch: 1/100, Step: 14/133, Loss: 0.2318
Epoch: 1/100, Step: 15/133, Loss: 0.2783
Epoch: 1/100, Step: 16/133, Loss: 0.2489
Epoch: 1/100, Step: 17/133, Loss: 0.0533
Epoch: 1/100, Step: 18/133, Loss: 0.2193
Epoch: 1/100, Step: 19/133, Loss: 0.2725
Epoch: 1/100, Step: 20/133, Loss: 0.1055
Epoch: 1/100, Step: 21/133, Loss: 0.0934
Epoch: 1/100, Step: 22/133, Loss: 0.3188
Epoch: 1/100, Step: 23/133, Loss: 0.2413
Epoch: 1/100, Step: 24/133, Loss: 0.0782
Epoc

In [70]:
# Inference
def ctc_decode(output, vocab_asr):
    decoded_output = []
    prev = -1
    for i in range(len(output)):
        if output[i] != prev:
            decoded_output.append(output[i]) 
        prev = output[i]
    
    decoded_output = [x for x in decoded_output if x != 0 and x != 1]
    return vocab_asr.decode(decoded_output)

def inference(model, waveform, vocab_asr, max_feat_len=322):
    # Obtain and normalize the MFCCs
    mel_spec = mel_spectrogram(waveform)
    mel_spec = abs(mel_spec)
    mel_spec = (mel_spec - mel_spec.mean()) / mel_spec.std()
    
    mel_spec.unsqueeze_(0)
    # if mel_spec.shape[2] > max_feat_len:
    #     mel_spec = mel_spec[:, :, :max_feat_len]
    
    # Pad the MFCCs to the maximum length
    # else:
    #     mel_spec = F.pad(mel_spec, (0, max_feat_len - mel_spec.shape[2]), "constant", 0)

    mel_spec = mel_spec.transpose(1, 2)
    mel_spec = mel_spec.to(device)
    
    # Predict the output
    model.eval()
    output = model.predict(mel_spec).squeeze(0)
    output = output.tolist()
    
    return ctc_decode(output, vocab_asr)

In [95]:
# Test the model
ctc_asr.eval()
waveform, sample_rate = torchaudio.load("./test/68122Z7A.wav")
waveform = waveform.squeeze(0)

output_seq = inference(ctc_asr, waveform, vocab_asr)
print("Actual transcription : "+ "1Z88153")
print("Predicted transcription : "+ output_seq)

Actual transcription : 1Z88153
Predicted transcription : 68122Z7


In [96]:
# Save the model
torch.save(ctc_asr.state_dict(), "./versions/CTC_ASR/ctc_asr_4_100.pth")