In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.optim.lr_scheduler import ExponentialLR, ReduceLROnPlateau
from torch.optim import RAdam

from torch.utils.data import Dataset, DataLoader

from torch.nn.utils import clip_grad_norm_

from torch.cuda.amp import autocast, GradScaler

from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, PackedSequence

import torchaudio
import torchaudio.functional as TAF
import torchaudio.transforms as T

from utils.dataset import CommonVoice
from utils.audio_utils import plot_waveform, play_audio
from utils.batch_utils import Collator
from utils.tokenizer import get_tokenizer

from datetime import datetime

In [2]:
from typing import Tuple

In [3]:
from utils.misc import get_summary, get_writer
from utils.grad_flow import *

In [4]:
from transformers import PreTrainedTokenizerFast

In [5]:
import os
import random
import pkbar

In [6]:
seed = 0 

g = torch.Generator()
g.manual_seed(seed)

torch.manual_seed(seed)

random.seed(seed)

In [7]:
print(torch.__version__)
print(torchaudio.__version__)

1.11.0
0.11.0


In [8]:
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

In [9]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [10]:
dataset_dir = 'data/external/cv-corpus-8.0-2022-01-19/en/'

tokenizer_file = 'data/tokenizer/trained_tokenizer.json'

trimmed_train_path = 'data/internal/sample_train.tsv'
# trimmed_train_path = 'data/internal/train_trimmed.tsv'

In [11]:
tokenizer, blank_token = get_tokenizer(tokenizer_file_path=tokenizer_file)

blank_token_id = tokenizer.vocab[blank_token]
vocab_size = tokenizer.vocab_size

In [12]:
train_data = CommonVoice(dataset_dir = dataset_dir, subset_path = trimmed_train_path, tokenizer = tokenizer, out_channels = 1)
dev_data = CommonVoice(dataset_dir = dataset_dir, subset_name = 'dev', tokenizer = tokenizer, out_channels = 1)

print(train_data)

print(dev_data)


    CommonVoice Dataset
    -------------------
    
    Loading None.tsv from /home/ashim/Projects/DeepSpeech/data/external/cv-corpus-8.0-2022-01-19/en directory.
        
    Number of Examples: 4192
    
    Args:
        Sampling Rate: 16000
        Output Channels: 1
    

    CommonVoice Dataset
    -------------------
    
    Loading dev.tsv from /home/ashim/Projects/DeepSpeech/data/external/cv-corpus-8.0-2022-01-19/en directory.
        
    Number of Examples: 16326
    
    Args:
        Sampling Rate: 16000
        Output Channels: 1
    


In [13]:
class Encoder(nn.Module):
    
    def __init__(self, 
                 encoder_input_dim: int = 80, 
                 num_heads: int = 4, 
                 ffn_dim: int = 80, 
                 num_layers: int = 4, 
                 depthwise_conv_kernel_size: int = 31, 
                 dropout: float = 0.3,
                 **args):
        
        super(Encoder, self).__init__()
        
        self.conformer = torchaudio.models.Conformer(input_dim = encoder_input_dim,
                                                     num_heads = num_heads,
                                                     ffn_dim = ffn_dim,
                                                     num_layers = num_layers,
                                                     depthwise_conv_kernel_size = depthwise_conv_kernel_size,
                                                     dropout = dropout)
                
    def forward(self, x: torch.Tensor, x_lens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        
        x, x_lens = self.conformer.forward(x, x_lens)
        
        return x, x_lens

In [14]:
class RNNDecoder(nn.Module):
    
    def __init__(self, 
                 decoder_input_dim: int = 80, 
                 decoder_hidden_size: int = 256, 
                 num_layers: int = 1, 
                 bidirectional: bool = False, 
                 output_dim: int = None, 
                 **args):
        
        super(RNNDecoder, self).__init__()
        
        if output_dim == None:
            raise ValueError("Please specify the output size of the vocab.")
            
        directions = 2 if bidirectional == True else 1
            
        self.model = nn.GRU(input_size = decoder_input_dim, hidden_size = decoder_hidden_size, num_layers = num_layers, batch_first = False)
        self.ffn = nn.Linear(in_features = decoder_hidden_size * directions, out_features = output_dim)
                                
    def forward(self, x: torch.Tensor, hidden_state: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: 
        """
        Hidden state is needed, either in the form of encoder_hidden_state or decoder_hidden_state
        """
        
        if hidden_state == None:
            outputs, hidden_state = self.model(x)
        
        else:
            outputs, hidden_state = self.model(x, hidden_state)
        
        if isinstance(packed_encoder_outputs, PackedSequence):
            outputs, _ = pad_packed_sequence(outputs)
        
        outputs = self.ffn(outputs)
        
        return outputs, hidden_state

In [15]:
class Model(nn.Module):
    
    def __init__(self, encoder_input_dim: int = 80,
                encoder_num_heads: int = 4, 
                encoder_ffn_dim: int = 144, 
                encoder_num_layers: int = 16, 
                encoder_depthwise_conv_kernel_size: int = 31, 
                decoder_hidden_size:int = 80,
                decoder_num_layers: int = 2,
                bidirectional_decoder: bool = False,
                vocab_size: int = None,
                padding_idx: int = None,
                sos_token_id: int = None):
        
        super(Model, self).__init__()
        
        self.encoder = Encoder(input_dim = encoder_input_dim,
                              num_heads = encoder_num_heads,
                              ffn_dim = encoder_ffn_dim,
                              depthwise_conv_kernel_size = encoder_depthwise_conv_kernel_size)
        
        self.decoder = RNNDecoder(input_dim = encoder_input_dim,
                                  hidden_size = decoder_hidden_size,
                                  num_layers = decoder_num_layers,
                                  bidirectional = bidirectional_decoder,
                                  output_dim = vocab_size)
        
        self.sos_token_id = sos_token_id
        
    def forward(self, x: torch.Tensor, x_lens: torch.Tensor):
        
        decoded = []
        
        bsz, msl, hdz = x.shape ##batch_size, max sequence length, hidden dimension size

        encoder_outputs = self.encoder(x, x_lens)
                
        decoder_inputs = encoder_outputs
        
        ## Start with the <sos> token
        x = torch.LongTensor([self.sos_token_id]).repeat(bsz).reshape(bsz, 1).to(device)

        for t in range(msl):
            
            if t == 0:
                decoder_output, decoder_hidden_state = self.decoder(x = decoder_inputs)            
            else:
                decoder_output, decoder_hidden_state = self.decoder(x = decoder_inputs, hidden_state = decoder_hidden_state)
            
            word = F.log_softmax(decoder_output, dim = -1) ## have to do log_softmax for CTC Loss
            
            topv, topi = decoder_output.topk(1)
            
            x = topv.squeeze().detach()
            
            decoded.append(topv)
            
        return encoder_outputs, torch.stack(decoded)

In [16]:
model_params = {
    'encoder_input_dim': 80,
    'encoder_num_heads': 4,
    'encoder_ffn_dim': 144,
    'encoder_num_layers': 16,
    'decoder_input_dim': 80,
    'decoder_hidden_size': 256,
    'decoder_num_layers': 1,
    'decoder_hidden_size': 320,
    'padding_idx': tokenizer.pad_token_id,
    'sos_token_id': tokenizer.bos_token_id,
    'vocab_size': vocab_size
}

collator = Collator(tokenizer)
BATCH_SIZE = 32

train_loader = DataLoader(train_data, 
                          batch_size = BATCH_SIZE, 
                          collate_fn=collator, 
                          shuffle=True, 
                          pin_memory = True, 
                          num_workers = 8, 
                          worker_init_fn = collator.seed_worker, 
                          generator = g)

fp16 = False
scaler = GradScaler()

In [17]:
encoder = Encoder(**model_params).to(device)
decoder = RNNDecoder(**model_params, output_dim = vocab_size).to(device)

In [18]:
get_summary(encoder, dataloader = train_loader)

Layer (type:depth-idx)                                            Output Shape     Param #
Encoder                                                           --               --
├─Conformer: 1-1                                                  --               --
│    └─ModuleList: 2-1                                            --               --
│    │    └─ConformerLayer: 3-1                                   [92, 32, 80]     74,800
│    │    └─ConformerLayer: 3-2                                   [92, 32, 80]     74,800
│    │    └─ConformerLayer: 3-3                                   [92, 32, 80]     74,800
│    │    └─ConformerLayer: 3-4                                   [92, 32, 80]     74,800
Total params: 221,440
Trainable params: 221,440
Non-trainable params: 0
Total mult-adds (M): 268.89
Input size (MB): 0.94
Forward/backward pass size (MB): 105.51
Params size (MB): 0.89
Estimated Total Size (MB): 107.34

In [19]:
## CTC loss should be computed after the encoder outputs the probabilities

## Decoding part is usually decoupled from encoding part

In [20]:
base_log_dir = 'logs/'
writer = get_writer(base_log_dir=base_log_dir)

## BATCH_FIRST = FALSE

In [21]:
MAX_NORM = 0.5

In [22]:
EPOCHS = 50
lr = 5.0  # learning rate

num_batches = len(train_loader)

# enc_optim = torch.optim.AdamW(encoder.parameters(), lr = lr)
# dec_optim = torch.optim.AdamW(decoder.parameters(), lr = lr)

# enc_optim = RAdam(encoder.parameters(), lr = lr)
# dec_optim = RAdam(decoder.parameters(), lr = lr)

enc_optim = torch.optim.SGD(encoder.parameters(), lr = lr)
dec_optim = torch.optim.SGD(decoder.parameters(), lr = lr)

enc_scheduler_plateau = ReduceLROnPlateau(enc_optim, mode = 'min', patience = 2)
dec_scheduler_plateau = ReduceLROnPlateau(dec_optim, mode = 'min', patience = 2)


enc_scheduler_stepLR = torch.optim.lr_scheduler.StepLR(enc_optim, 1.0, gamma=0.95)
dec_scheduler_stepLR = torch.optim.lr_scheduler.StepLR(dec_optim, 1.0, gamma=0.95)

loss_fn = nn.CTCLoss(blank = blank_token_id, zero_infinity = True)

In [23]:
n_iter = 0

for epoch in range(EPOCHS):
    
    kbar = pkbar.Kbar(target = num_batches, epoch = epoch, num_epochs=EPOCHS, width = 8, always_stateful=False)
    
    enc_optim.zero_grad(set_to_none=True)
    
    for idx, batch in enumerate(train_loader):
        
        # waveforms = batch['waveforms']
        # waveforms_lengths = batch['waveforms_lengths']
        
        sentences = batch['sentences'].to(device)
        sentence_lengths = batch['sentence_lengths'].to(device, dtype= torch.int32)

        melspecs = batch['melspecs'].to(device)
        melspecs_lengths = batch['melspecs_lengths'].to(device, dtype= torch.int32)
        
        melspecs = torch.transpose(melspecs, -1, -2) ## Changing to (batch, channel, time, n_mels) from (batch, channel, n_mels, time)
        
        encoder_outputs, encoder_output_lengths = encoder(melspecs, melspecs_lengths)
        encoder_outputs = encoder_outputs.transpose(1, 0)

        ##encoder_outputs: [seq_len, batch_size, hidden_dim]

        ## Packing is done to ensure model doesn't take the [PAD] token into consideration

        packed_encoder_outputs = pack_padded_sequence(encoder_outputs, 
                                          lengths = encoder_output_lengths.to(device = 'cpu', dtype=torch.int64), 
                                          batch_first = False, 
                                          enforce_sorted = False)

        decoder_outputs, decoder_hidden_state = decoder(packed_encoder_outputs)
        decoder_outputs = F.log_softmax(decoder_outputs, dim = -1)

        ctc_loss = loss_fn(log_probs = decoder_outputs, 
                           targets = sentences, 
                           input_lengths = melspecs_lengths, 
                           target_lengths=sentence_lengths)

        ctc_loss.backward()
        
        ## Plot Gradients every 10 steps
        if n_iter % 10 == 0:
        
            enc_grad_flow_fig = plot_grad_flow_v2(encoder.named_parameters())
            dec_grad_flow_fig = plot_grad_flow_v2(decoder.named_parameters())

        clip_grad_norm_(encoder.parameters(), max_norm = MAX_NORM)
        clip_grad_norm_(decoder.parameters(), max_norm = MAX_NORM)

        enc_optim.step()
        dec_optim.step()

        enc_scheduler_stepLR.step()
        dec_scheduler_stepLR.step()
                    
        writer.add_scalar('CTC Loss/train', ctc_loss.detach().cpu().item(), n_iter)
        writer.add_figure('Average Gradients/Encoder', enc_grad_flow_fig, global_step = n_iter, close = True)
        writer.add_figure('Average Gradients/Decoder', dec_grad_flow_fig, global_step = n_iter, close = True)
        
        kbar.update(idx, values = [("ctc_loss", ctc_loss.detach().cpu().item())])
        
        n_iter += 1
    
    ## At epoch end
    
    enc_scheduler_plateau.step(ctc_loss)
    dec_scheduler_plateau.step(ctc_loss)

    
    print("\n")

Epoch: 1/50

Epoch: 2/50

Epoch: 3/50

Epoch: 4/50

Epoch: 5/50

Epoch: 6/50

Epoch: 7/50

Epoch: 8/50

Epoch: 9/50

Epoch: 10/50

Epoch: 11/50

Epoch: 12/50

Epoch: 13/50

Epoch: 14/50

Epoch: 15/50

Epoch: 16/50

Epoch: 17/50

Epoch: 18/50

Epoch: 19/50

Epoch: 20/50

Epoch: 21/50

Epoch: 22/50

KeyboardInterrupt: 

In [25]:
sample = train_data.__getitem__(1)
melspec = sample['melspec'].to(device)
melspec_len = torch.Tensor([melspec.shape[-1]]).to(device)

melspec= melspec.unsqueeze(0)
melspec = melspec.transpose(2, 1)

In [26]:
with torch.no_grad():
    y_preds, y_lens = encoder(melspec, melspec_len)
    
    y_preds = y_preds.transpose(1, 0)
    
    packed_y_preds = pack_padded_sequence(y_preds, 
                                  lengths = y_lens.to(device = 'cpu', dtype=torch.int64), 
                                  batch_first = False, 
                                  enforce_sorted = False)

    
    y_preds, hidden_state = decoder(packed_y_preds)
    y_preds = F.softmax(y_preds, dim = -1)
    
    y_preds = y_preds.transpose(0,1)

    y_preds = y_preds.argmax(dim = -1)
    y_preds = torch.unique_consecutive(y_preds)

In [27]:
sample['sentence']

'His fishing boat is struck by lightning and explodes into pieces, burns and sinks.'

In [28]:
y_preds

tensor([0], device='cuda:0')

In [29]:
tokenizer.decode(y_preds)

'[BLANK]'