In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.optim.lr_scheduler import ExponentialLR, ReduceLROnPlateau, StepLR
from torch.optim import RAdam

from torch.utils.data import Dataset, DataLoader

from torch.nn.utils import clip_grad_norm_

from torch.cuda.amp import autocast, GradScaler

from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, PackedSequence

import torchaudio

from src.model.model import *

from src.utils.dataset import CommonVoice
from src.utils.audio_utils import plot_waveform, play_audio
from src.utils.collate import Collator
from src.utils.tokenizer import get_tokenizer

from transformers import PreTrainedTokenizerFast

from datetime import datetime

In [2]:
from typing import Tuple, List, Dict

In [3]:
from src.utils.misc import get_summary, get_writer
from src.utils.grad_flow import *

In [4]:
import os
import random
import pkbar

import math

In [5]:
seed = 0

# Seed each RNG used in this notebook: Python's `random`, torch's global RNG, and a dedicated
# generator `g` that is handed to the DataLoader for reproducible shuffling.
# NOTE(review): numpy is not seeded here; worker-level seeding is delegated to collator.seed_worker.
random.seed(seed)
torch.manual_seed(seed)
g = torch.Generator().manual_seed(seed)

In [6]:
# Runtime switches.
# NOTE(review): CUDA_LAUNCH_BLOCKING=1 makes CUDA kernel launches synchronous (handy for debugging
# tracebacks, slow for training). Forcing TOKENIZERS_PARALLELISM=true while the DataLoader forks
# 12 workers is what HF tokenizers warns may deadlock — confirm both settings are intended.
os.environ.update({'CUDA_LAUNCH_BLOCKING': '1', 'TOKENIZERS_PARALLELISM': 'true'})

# Use the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [7]:
# Input locations, relative to the notebook's working directory.
tokenizer_file = 'data/tokenizer/trained_tokenizer.json'    # trained tokenizer definition
dataset_dir = 'data/external/cv-corpus-8.0-2022-01-19/en/'  # Common Voice corpus, English

# Training subset: a small sample for quick iteration; swap in the full trimmed split for real runs.
trimmed_train_path = 'data/internal/sample_train.tsv'
# trimmed_train_path = 'data/internal/train_trimmed.tsv'

In [8]:
tokenizer = get_tokenizer(tokenizer_file_path=tokenizer_file)

# Resolve the special-token ids by their literal names in the tokenizer vocabulary.
blank_token_id, bos_token_id, eos_token_id = (
    tokenizer.vocab[token_name] for token_name in ("[BLANK]", "[BOS]", "[EOS]")
)

# NOTE(review): for HF tokenizers `vocab_size` excludes tokens added after training while
# `len(tokenizer)` includes them — confirm every special id above is < vocab_size.
vocab_size = tokenizer.vocab_size

In [9]:
# Training data comes from an explicit subset file; the dev split is loaded by name.
train_data = CommonVoice(
    dataset_dir = dataset_dir,
    subset_path = trimmed_train_path,
    tokenizer = tokenizer,
    out_channels = 1,
)
# train_data = CommonVoice(dataset_dir = dataset_dir, subset_name = 'train', tokenizer = tokenizer, out_channels = 1)

dev_data = CommonVoice(
    dataset_dir = dataset_dir,
    subset_name = 'dev',
    tokenizer = tokenizer,
    out_channels = 1,
)

# NOTE(review): the summary for train_data reports "None.tsv" because it was built from subset_path;
# the dataset repr appears to report only subset_name. dev_data is not used by the visible cells.
print(train_data, dev_data, sep = '\n')


    CommonVoice Dataset
    -------------------
    
    Loading None.tsv from /home/ashim/Projects/DeepSpeech/data/external/cv-corpus-8.0-2022-01-19/en directory.
        
    Number of Examples: 4192
    
    Args:
        Sampling Rate: 16000
        Output Channels: 1
    

    CommonVoice Dataset
    -------------------
    
    Loading dev.tsv from /home/ashim/Projects/DeepSpeech/data/external/cv-corpus-8.0-2022-01-19/en directory.
        
    Number of Examples: 16326
    
    Args:
        Sampling Rate: 16000
        Output Channels: 1
    


In [10]:
# Hyper-parameters forwarded to Model(**model_params).
# NOTE(review): sos/eos ids come from tokenizer.bos_token_id / eos_token_id, not from the
# bos_token_id / eos_token_id looked up via tokenizer.vocab earlier; they agree only if the
# tokenizer registers [BOS]/[EOS] as its bos/eos tokens — confirm.
model_params = dict(
    # encoder front-end and conformer stack
    encoder_input_size = 80,
    conformer_num_heads = 4,
    conformer_ffn_size = 512,
    conformer_num_layers = 16,
    conformer_conv_kernel_size = 31,
    # recurrent encoder
    encoder_rnn_hidden_size = 1024,
    encoder_rnn_num_layers = 1,
    encoder_rnn_bidirectional = True,
    # attention decoder
    decoder_embedding_size = 300,
    decoder_hidden_size = 1024,
    decoder_num_layers = 1,
    decoder_attn_size = 144,
    dropout = 0.3,
    # vocabulary / special tokens
    padding_idx = tokenizer.pad_token_id,
    sos_token_id = tokenizer.bos_token_id,
    eos_token_id = tokenizer.eos_token_id,
    vocab_size = vocab_size,
    # runtime
    batch_first = True,
    device = device,
)

In [11]:
collator = Collator(tokenizer)
BATCH_SIZE = 128

# Shuffled training loader. The seeded generator `g` and collator.seed_worker are passed so that
# shuffling (and presumably per-worker randomness) is reproducible across runs.
train_loader = DataLoader(
    train_data,
    batch_size = BATCH_SIZE,
    collate_fn = collator,
    shuffle = True,
    pin_memory = False,
    num_workers = 12,
    worker_init_fn = collator.seed_worker,
    generator = g,
)

# Mixed precision is off. The scaler is enabled only when fp16 is, so in fp32 runs it is a
# pass-through (and does not warn on CPU-only machines).
# NOTE(review): train_step does not use the scaler yet.
fp16 = False
scaler = GradScaler(enabled = fp16)

In [12]:
# Build the model from the hyper-parameters above, then move its parameters/buffers to `device`.
model = Model(**model_params)
model = model.to(device)



In [13]:
# get_summary(encoder, dataloader = train_loader)

In [14]:
## CTC loss should be computed after the encoder outputs the probabilities

## Decoding part is usually decoupled from encoding part

In [15]:
# Logging writer (from src.utils.misc.get_writer) rooted at logs/; the training loop calls
# add_text / add_scalar / add_figure on it.
base_log_dir = 'logs/'
writer = get_writer(base_log_dir = base_log_dir)
# writer = None

In [16]:
def predict_one_batch(model: torch.nn.Module, batch: Dict, max_len: int = 50, collapse_repeats: bool = True):
    """Run a no-grad forward pass on one batch and decode greedy (argmax) predictions.

    The reference ``sentences`` are fed to the model exactly as in ``train_step``, so the
    predictions are teacher-forced, not free-running inference.

    Relies on the notebook globals ``device`` and ``tokenizer``.

    Args:
        model: model called as ``model(melspecs, melspecs_lengths, sentences, sentence_lengths)``.
        batch: collated batch with keys 'melspecs', 'melspecs_lengths', 'sentences', 'sentence_lengths'.
        max_len: currently unused; kept for backward compatibility.
        collapse_repeats: if True, merge consecutive duplicate token ids *within each sequence*
            (CTC-style). Pass False for attention-decoder outputs, where repeated tokens are legitimate.

    Returns:
        (y_true, y_pred): lists of decoded reference and predicted strings, one per batch item.
        y_true is decoded as-is, so it may still contain pad-token text.
    """
    was_training = model.training
    model.eval()

    # Same layout change as train_step: swap the last two axes, (..., n_mels, time) -> (..., time, n_mels).
    # No squeeze(0): it would silently drop the batch dimension of a size-1 batch.
    melspecs = batch['melspecs'].to(device).transpose(-1, -2)
    melspecs_lengths = batch['melspecs_lengths'].to(device, dtype = torch.int32)

    sentences = batch['sentences'].to(device)
    # Same dtype as train_step uses for the same tensor.
    sentence_lengths = batch['sentence_lengths'].to(device, dtype = torch.int32)

    try:
        with torch.no_grad():
            predicted_tensor = model(melspecs, melspecs_lengths, sentences, sentence_lengths)
            y_ids = predicted_tensor.argmax(dim = -1)
            if collapse_repeats:
                # Per-row collapse: unique_consecutive(dim=1) on the whole batch would only merge
                # time steps where *every* sequence repeats, which is not what CTC collapsing means.
                y_ids = [torch.unique_consecutive(row) for row in y_ids]
            y_pred = tokenizer.batch_decode(y_ids)
    finally:
        # Leave the model in whatever mode the caller had it in.
        model.train(was_training)

    y_true = tokenizer.batch_decode(sentences)
    return y_true, y_pred

In [17]:
def train_step(batch: Dict[str, torch.Tensor], n_iter: int, MAX_NORM: float = 0.5, plot_gradients: bool = True, plot_every: int = 10):
    """Run one optimisation step (forward, loss, backward, clip, step) on a single batch.

    Relies on the notebook globals ``model``, ``optimizer``, ``criterion``, ``vocab_size``,
    ``device`` and ``plot_grad_flow_v2``.

    Args:
        batch: one collated batch (a dict) with keys 'sentences', 'sentence_lengths',
            'melspecs' and 'melspecs_lengths'.
        n_iter: global step counter; only used to decide when to plot gradient flow.
        MAX_NORM: max total gradient norm passed to ``clip_grad_norm_``.
        plot_gradients: whether to build a gradient-flow figure at all.
        plot_every: build the figure when ``n_iter % plot_every == 0`` (default: every 10 steps).

    Returns:
        (loss, grad_flow_fig): the loss as a Python float, and the figure or None when not plotted.
    """
    model.train()
    optimizer.zero_grad(set_to_none = True)

    sentences = batch['sentences'].to(device)
    sentence_lengths = batch['sentence_lengths'].to(device, dtype = torch.int32)

    melspecs = batch['melspecs'].to(device)
    melspecs_lengths = batch['melspecs_lengths'].to(device, dtype = torch.int32)

    # Swap the last two axes: (..., n_mels, time) -> (..., time, n_mels).
    melspecs = torch.transpose(melspecs, -1, -2)

    # Call the module (not .forward) so any registered hooks run.
    predicted_tensor = model(melspecs, melspecs_lengths, sentences, sentence_lengths)

    # Flatten for the token-level loss. Targets skip the leading <sos> token, so the model is assumed
    # to emit one prediction per remaining target position — TODO confirm against Model.forward.
    # NOTE(review): criterion is NLLLoss, which expects log-probabilities from the model.
    y_pred = predicted_tensor.reshape(-1, vocab_size)
    y_true = sentences[:, 1:].reshape(-1)

    loss = criterion(input = y_pred, target = y_true)
    loss.backward()

    # Gradient-flow figure is taken before clipping, i.e. it shows the raw gradients.
    if plot_gradients and n_iter % plot_every == 0:
        grad_flow_fig = plot_grad_flow_v2(model.named_parameters())
    else:
        grad_flow_fig = None

    ## Gradient Clipping for exploding gradients
    clip_grad_norm_(model.parameters(), max_norm = MAX_NORM)
    optimizer.step()

    ## Step the schedulers
    # batch_end_scheduler.step()

    return loss.item(), grad_flow_fig

In [18]:
%time
try:
    samples
except NameError as e:
    samples = next(iter(train_loader))

CPU times: user 2 µs, sys: 1 µs, total: 3 µs
Wall time: 3.34 µs


In [19]:
## Optimisation settings
EPOCHS = 200
lr = 0.01       # initial learning rate
MAX_NORM = 0.5  # max total gradient norm used for clipping

num_batches = len(train_loader)

# RAdam over every model parameter (alternatives kept for reference).
# optimizer = torch.optim.SGD(model.parameters(), lr = lr)
# optim = torch.optim.RAdam(model.parameters())
optimizer = torch.optim.RAdam(model.parameters(), lr = lr)

# Epoch-level scheduler: reduce the LR once the monitored loss stops decreasing (patience = 2).
epoch_end_scheduler = ReduceLROnPlateau(optimizer, mode = 'min', patience = 2)

# NOTE(review): a second scheduler is attached to the same optimizer, but its .step() is commented
# out in the training loop — confirm it is still wanted.
cawr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0 = 2, T_mult = 2)

# batch_end_scheduler = StepLR(optimizer, 1.0, gamma=0.95)

# Token-level loss ignoring padded targets.
# NOTE(review): NLLLoss expects log-probabilities, so the model is assumed to end in log_softmax.
# criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
# criterion = nn.CTCLoss(blank = tokenizer.vocab['[BLANK]'],
#                      zero_infinity=True)
criterion = nn.NLLLoss(ignore_index = tokenizer.pad_token_id)
# criterion = nn.CTCLoss(blank = tokenizer.vocab['[BLANK]'], 
#                      zero_infinity=True)

In [20]:
# with torch.no_grad():
#     for idx, batch in enumerate(train_loader):
#         batch = batch
#         sentences = batch['sentences'].to(device)
#         sentence_lengths = batch['sentence_lengths'].to(device, dtype = torch.int32)

#         melspecs = batch['melspecs'].to(device)
#         melspecs_lengths = batch['melspecs_lengths'].to(device, dtype = torch.int32)

#         melspecs = torch.transpose(melspecs, -1, -2) ## Changing to (batch, channel, time, n_mels) from (batch, channel, n_mels, time)

#         predicted_tensor = model.forward(melspecs, melspecs_lengths, sentences, sentence_lengths)

#         y_pred = predicted_tensor.reshape(-1, vocab_size)
#         y_true = sentences[:, 1:].reshape(-1)  ##Skip the <sos> token from the targets
        
#         break

In [21]:
n_iter = 0
# Log the teacher-forced sample predictions every PREDICT_EVERY steps. 1 reproduces the previous
# behaviour (every step), which costs one extra forward pass over `samples` per training step.
PREDICT_EVERY = 1


def _strip_pad(text: str) -> str:
    """Remove pad-token text from a decoded sentence (no-op if none is present)."""
    return text.replace(tokenizer.pad_token, '').strip()


def _perplexity(nll: float) -> float:
    """Return exp(nll), mapping overflow from a diverged loss to +inf instead of raising."""
    try:
        return math.exp(nll)
    except OverflowError:
        return float('inf')


y_true, _ = predict_one_batch(model, samples)
writer.add_text('y_true sentence', _strip_pad(y_true[0]), global_step = n_iter)

for epoch in range(EPOCHS):

    kbar = pkbar.Kbar(target = num_batches, epoch = epoch, num_epochs=EPOCHS, width = 8, always_stateful=False)
    epoch_loss_sum, epoch_batches = 0.0, 0

    for idx, batch in enumerate(train_loader):

        # train_step zeroes the gradients itself, so no separate zero_grad call is needed here.
        loss, grad_flow_fig = train_step(batch, n_iter, MAX_NORM = MAX_NORM, plot_gradients = True)
        epoch_loss_sum += loss
        epoch_batches += 1

        ## Write how the fixed sample batch is being predicted (predict_one_batch uses no_grad).
        if n_iter % PREDICT_EVERY == 0:
            sample_true, sample_pred = predict_one_batch(model, samples)
            # sample_true[0] is a str: strip pad-token text from it (indexing it with a
            # tensor-style mask only ever selected a single character).
            sample_true_0 = _strip_pad(sample_true[0])
            writer.add_text('sentence predictions', f'true sentence: {sample_true_0}, predicted sentence: {sample_pred[0]}', global_step = n_iter)

        writer.add_scalar('CE Loss/train', loss, n_iter)

        if grad_flow_fig is not None:
            writer.add_figure('Average Gradients/Model', grad_flow_fig, global_step = n_iter, close = True)

        # The progress bar expects the number of completed steps, so it reaches `target` on the last batch.
        kbar.update(idx + 1, values = [("loss", loss), ("PPL", _perplexity(loss))])

        n_iter += 1

    ## At epoch end: step the plateau scheduler on the epoch's mean training loss rather than
    ## on the loss of whichever batch happened to come last.
    # cawr_scheduler.step() ##cosine annealing with warm restarts
    if epoch_batches:
        epoch_end_scheduler.step(epoch_loss_sum / epoch_batches)

    print("\n")

Epoch: 1/200

Epoch: 2/200

Epoch: 3/200

Epoch: 4/200

Epoch: 5/200

Epoch: 6/200

Epoch: 7/200

Epoch: 8/200

Epoch: 9/200

Epoch: 10/200

Epoch: 11/200

Epoch: 12/200

Epoch: 13/200

Epoch: 14/200

Epoch: 15/200

Epoch: 16/200

Epoch: 17/200

Epoch: 18/200

Epoch: 19/200

Epoch: 20/200

Epoch: 21/200

Epoch: 22/200

Epoch: 23/200

Epoch: 24/200

Epoch: 25/200

Epoch: 26/200

Epoch: 27/200

Epoch: 28/200

Epoch: 29/200

Epoch: 30/200

Epoch: 31/200

Epoch: 32/200

Epoch: 33/200

Epoch: 34/200

Epoch: 35/200

Epoch: 36/200

Epoch: 37/200

Epoch: 38/200

Epoch: 39/200

Epoch: 40/200

Epoch: 41/200

Epoch: 42/200

Epoch: 43/200

Epoch: 44/200

Epoch: 45/200

Epoch: 46/200

Epoch: 47/200

Epoch: 48/200

Epoch: 49/200

Epoch: 50/200

Epoch: 51/200

Epoch: 52/200

Epoch: 53/200

Epoch: 54/200

Epoch: 55/200

Epoch: 56/200

Epoch: 57/200

Epoch: 58/200

Epoch: 59/200

Epoch: 60/200

Epoch: 61/200

Epoch: 62/200

Epoch: 63/200

Epoch: 64/200

Epoch: 65/200

Epoch: 66/200

Epoch: 67/200

Epoc