In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.optim.lr_scheduler import ExponentialLR, ReduceLROnPlateau, StepLR
from torch.optim import RAdam

from torch.utils.data import Dataset, DataLoader

from torch.nn.utils import clip_grad_norm_

from torch.cuda.amp import autocast, GradScaler

from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, PackedSequence

import torchaudio

from utils.model import *
from utils.attention import BadhanauAttention

from utils.dataset import CommonVoice
from utils.audio_utils import plot_waveform, play_audio
from utils.batch_utils import Collator
from utils.tokenizer import get_tokenizer

from transformers import PreTrainedTokenizerFast

from datetime import datetime

In [None]:
from typing import Tuple, List, Dict

In [None]:
from utils.misc import get_summary, get_writer
from utils.grad_flow import *

In [None]:
import os
import random
import pkbar

In [None]:
# One seed shared by Python's `random` module, the global torch RNG and the
# dedicated generator `g` that the training DataLoader uses for shuffling.
seed = 0

random.seed(seed)
torch.manual_seed(seed)

g = torch.Generator()
g.manual_seed(seed)

In [None]:
# NOTE(review): CUDA_LAUNCH_BLOCKING=1 makes CUDA kernel launches synchronous.
# That is useful for locating the op that raised a CUDA error, but it slows
# training; consider dropping it for real runs.
# NOTE(review): forcing TOKENIZERS_PARALLELISM='true' keeps Hugging Face tokenizers
# parallel even inside forked DataLoader workers (num_workers > 0), which the
# library warns may deadlock. Consider 'false' if workers hang.
os.environ.update({
    'CUDA_LAUNCH_BLOCKING': '1',
    'TOKENIZERS_PARALLELISM': 'true',
})

if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# Common Voice corpus root (English release, per the path) and the trained tokenizer file.
dataset_dir = 'data/external/cv-corpus-8.0-2022-01-19/en/'
tokenizer_file = 'data/tokenizer/trained_tokenizer.json'

# Training subset TSV. The sample file is presumably a small slice for quick
# iteration; the full trimmed split is the commented alternative:
# trimmed_train_path = 'data/internal/train_trimmed.tsv'
trimmed_train_path = 'data/internal/sample_train.tsv'

In [None]:
tokenizer = get_tokenizer(tokenizer_file_path=tokenizer_file)

# Special-token ids, looked up by their literal names in the tokenizer vocabulary.
blank_token_id, bos_token_id = (tokenizer.vocab[token] for token in ("[BLANK]", "[BOS]"))

# NOTE(review): on Hugging Face tokenizers `vocab_size` excludes added tokens
# (`len(tokenizer)` includes them). Confirm every special-token id is < vocab_size,
# since this value sizes the model's output layer.
vocab_size = tokenizer.vocab_size

In [None]:
# Training data: the subset listed in `trimmed_train_path`.
train_data = CommonVoice(
    dataset_dir = dataset_dir,
    subset_path = trimmed_train_path,
    tokenizer = tokenizer,
    out_channels = 1,
)
# Full training split instead:
# train_data = CommonVoice(dataset_dir = dataset_dir, subset_name = 'train', tokenizer = tokenizer, out_channels = 1)

# Dev split, selected by name.
# NOTE(review): dev_data is not used by any later cell visible here (no validation loop).
dev_data = CommonVoice(
    dataset_dir = dataset_dir,
    subset_name = 'dev',
    tokenizer = tokenizer,
    out_channels = 1,
)

print(train_data)
print(dev_data)

In [None]:
# Keyword arguments for `Model(**model_params)`.
model_params = dict(
    # Encoder: conformer stack followed by a recurrent layer.
    encoder_input_size = 80,
    conformer_num_heads = 4,
    conformer_ffn_size = 512,
    conformer_num_layers = 16,
    conformer_conv_kernel_size = 31,
    encoder_rnn_hidden_size = 1024,
    encoder_rnn_num_layers = 1,
    encoder_rnn_bidirectional = True,
    # Attention decoder.
    decoder_embedding_size = 300,
    decoder_hidden_size = 1024,
    decoder_num_layers = 1,
    decoder_attn_size = 144,
    # Shared settings and tokenizer-derived ids.
    dropout = 0.3,
    padding_idx = tokenizer.pad_token_id,
    sos_token_id = tokenizer.bos_token_id,
    vocab_size = vocab_size,
    batch_first = True,
    device = device,
)

In [None]:
collator = Collator(tokenizer)
BATCH_SIZE = 256

# The seeded generator `g` fixes the shuffling order. `collator.seed_worker`
# presumably seeds each worker process (see utils.batch_utils).
train_loader = DataLoader(train_data,
                          batch_size = BATCH_SIZE,
                          collate_fn=collator,
                          shuffle=True,
                          pin_memory = False,
                          num_workers = 12,
                          worker_init_fn = collator.seed_worker,
                          generator = g)

# Mixed precision is off by default. Tying the scaler to the flag makes it a
# pass-through when fp16 is False (no loss scaling, and no "CUDA is not
# available" warning on CPU-only machines).
fp16 = False
scaler = GradScaler(enabled=fp16)

In [None]:
# Build the network from `model_params`, then move its parameters/buffers to `device`.
model = Model(**model_params)
model = model.to(device)

In [None]:
# get_summary(model, dataloader = train_loader)

In [None]:
## CTC loss, if used, should be computed on the encoder's output probabilities.

## The decoding stage is usually decoupled from the encoding stage.

In [None]:
# Root directory for run logs, handed to the project helper `get_writer`
# (utils.misc). Judging by later add_text/add_scalar/add_figure calls, it
# presumably returns a TensorBoard SummaryWriter.
base_log_dir = 'logs/'

writer = get_writer(base_log_dir=base_log_dir)
# To run without logging: writer = None
# NOTE(review): make sure the training cell tolerates a None writer before disabling it.

In [None]:
def predict_one_batch(batch: Dict, max_len: int = 50, collapse_repeats: bool = False):
    """Decode the model's greedy (argmax) predictions for one collated batch, for logging.

    The ground-truth ``sentences`` are fed to the model exactly as in ``train_step``.
    The predictions are therefore teacher-forced, not free-running decoding.

    The model is put in eval mode for the forward pass (dropout off, and any
    normalisation running statistics left untouched). It is then restored to
    whatever mode it was in, so calling this between training steps does not
    perturb training.

    Uses the notebook globals ``model``, ``tokenizer`` and ``device``.

    Args:
        batch: Collated batch with keys ``'melspecs'``, ``'melspecs_lengths'``,
            ``'sentences'`` and ``'sentence_lengths'``.
        max_len: Currently unused; kept for interface compatibility.
        collapse_repeats: If True, merge consecutive repeated token ids within each
            sequence before decoding (CTC-style). Defaults to False because the model
            is trained with a per-token NLL loss, where repeated tokens are legitimate.

    Returns:
        ``(y_true, y_pred)``: the decoded reference sentences and the decoded
        predictions, one string per batch element.
    """
    # Same preprocessing as train_step: swap the last two axes (n_mels <-> time).
    # Unlike squeeze(0) + permute(0, 2, 1), this also works for a batch of size 1.
    melspecs = torch.transpose(batch['melspecs'].to(device), -1, -2)
    melspecs_lengths = batch['melspecs_lengths'].to(device, dtype = torch.int32)

    sentences = batch['sentences'].to(device)
    # as_tensor accepts either a tensor or a plain list; int32 matches train_step.
    sentence_lengths = torch.as_tensor(batch['sentence_lengths']).to(device, dtype = torch.int32)

    # NOTE(review): assumes Model.forward differs between train/eval only via
    # dropout/normalisation layers. Confirm in utils.model.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            predicted_tensor = model(melspecs, melspecs_lengths, sentences, sentence_lengths)
            y_ids = predicted_tensor.argmax(dim = -1)
    finally:
        model.train(was_training)

    if collapse_repeats:
        # Collapse per sequence. unique_consecutive(dim=1) on the whole batch would
        # only drop time steps that are identical across *every* row.
        y_ids = [torch.unique_consecutive(row) for row in y_ids]

    y_pred = tokenizer.batch_decode(y_ids)
    y_true = tokenizer.batch_decode(sentences)
    return y_true, y_pred

In [None]:
%time
try:
    samples
except NameError as e:
    samples = next(iter(train_loader))

In [None]:
EPOCHS = 10

# Learning rate, now passed to RAdam explicitly. 1e-3 is RAdam's default, i.e.
# the rate the optimizer was already running at. The earlier `lr = 3e-5` was
# never handed to the optimizer.
lr = 1e-3

MAX_NORM = 0.5  # max total gradient norm used for clipping

num_batches = len(train_loader)  # progress-bar target per epoch

optim = torch.optim.RAdam(model.parameters(), lr = lr)

# Two schedulers share this optimizer:
# - ReduceLROnPlateau is stepped with a loss value and cuts the LR after
#   `patience` epochs without improvement.
# - StepLR multiplies the LR by `gamma` every `step_size` scheduler steps.
scheduler_plateau = ReduceLROnPlateau(optim, mode = 'min', patience = 2)
scheduler_stepLR = StepLR(optim, step_size = 1, gamma = 0.95)

# NLLLoss expects log-probabilities from the model; padded target positions are ignored.
# Alternatives: nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id) on raw logits,
# or nn.CTCLoss(blank=tokenizer.vocab['[BLANK]'], zero_infinity=True).
criterion = nn.NLLLoss(ignore_index=tokenizer.pad_token_id)

In [None]:
def train_step(batch: Dict[str, torch.Tensor], n_iter: int, MAX_NORM: float = 0.5, plot_gradients: bool = True, use_amp: bool = False):
    """Run one optimisation step on a single collated batch.

    Steps:
      1. Teacher-forced forward pass.
      2. ``criterion`` against the target tokens with the leading start token dropped.
      3. Backpropagation.
      4. Optionally, a gradient-flow figure.
      5. Gradient-norm clipping.
      6. Step ``optim`` and ``scheduler_stepLR``.

    Uses the notebook globals ``model``, ``optim``, ``criterion``,
    ``scheduler_stepLR``, ``scaler`` (only when ``use_amp``), ``device`` and
    ``vocab_size``.

    Args:
        batch: Collated batch (a dict) with keys ``'sentences'``,
            ``'sentence_lengths'``, ``'melspecs'`` and ``'melspecs_lengths'``.
        n_iter: Global iteration counter; used to decide when to plot gradients.
        MAX_NORM: Maximum total gradient norm passed to ``clip_grad_norm_``.
        plot_gradients: If True, build a gradient-flow figure every 10th iteration.
        use_amp: If True, run the forward pass and loss under ``autocast`` and
            scale the loss with the global ``scaler`` (mixed precision). The
            default False is plain full-precision training.

    Returns:
        ``(loss, grad_flow_fig)``: the batch loss as a Python float, and the figure
        from ``plot_grad_flow_v2``, or ``None`` when no plot was made.
    """
    from contextlib import nullcontext

    optim.zero_grad(set_to_none = True)

    sentences = batch['sentences'].to(device)
    sentence_lengths = batch['sentence_lengths'].to(device, dtype = torch.int32)

    melspecs = batch['melspecs'].to(device)
    melspecs_lengths = batch['melspecs_lengths'].to(device, dtype = torch.int32)

    # Swap the last two axes: (..., n_mels, time) -> (..., time, n_mels).
    melspecs = torch.transpose(melspecs, -1, -2)

    # Only enter autocast when requested, so the fp32 path stays exactly as before.
    amp_context = autocast() if use_amp else nullcontext()
    with amp_context:
        # Call the module (not .forward) so any registered hooks run.
        predicted_tensor = model(melspecs, melspecs_lengths, sentences, sentence_lengths)

        # NOTE(review): assumes predicted_tensor is (batch, seq_len - 1, vocab_size)
        # log-probabilities, aligned with sentences[:, 1:]. Confirm in utils.model.
        y_pred = predicted_tensor.reshape(-1, vocab_size)
        y_true = sentences[:, 1:].reshape(-1)  # skip the <sos> token in the targets

        loss = criterion(input = y_pred, target = y_true)

    if use_amp:
        scaler.scale(loss).backward()
        # Unscale in place so gradient plotting and clipping see true magnitudes.
        scaler.unscale_(optim)
    else:
        loss.backward()

    # Gradient-flow figure every 10 steps (before clipping, i.e. raw gradients).
    grad_flow_fig = None
    if plot_gradients and n_iter % 10 == 0:
        grad_flow_fig = plot_grad_flow_v2(model.named_parameters())

    # Gradient clipping against exploding gradients.
    clip_grad_norm_(model.parameters(), max_norm = MAX_NORM)

    if use_amp:
        scaler.step(optim)
        scaler.update()
    else:
        optim.step()

    # NOTE(review): stepped once per batch, so StepLR's gamma=0.95 decays the LR
    # every iteration rather than every epoch. Confirm this is intended.
    scheduler_stepLR.step()

    return loss.item(), grad_flow_fig

In [None]:
# Training loop: one train_step per batch. When a writer is available, log the
# loss, the decoded prediction for the fixed probe batch `samples` and (every 10
# steps) the gradient flow.
n_iter = 0
logging_enabled = writer is not None

# Reference transcription of the probe batch, logged once.
y_true, _ = predict_one_batch(samples)
if logging_enabled:
    writer.add_text('y_true sentence', y_true[0], global_step = n_iter)

for epoch in range(EPOCHS):

    # Make sure dropout etc. are active for training at the start of each epoch.
    model.train()

    kbar = pkbar.Kbar(target = num_batches, epoch = epoch, num_epochs=EPOCHS, width = 8, always_stateful=False)
    epoch_losses = []

    for idx, batch in enumerate(train_loader):

        # train_step zeroes the gradients itself. Gradient figures are only built
        # when there is a writer to receive (and close) them.
        loss, grad_flow_fig = train_step(batch, n_iter, MAX_NORM = MAX_NORM, plot_gradients = logging_enabled)
        epoch_losses.append(loss)

        if logging_enabled:
            # Track how the fixed probe batch is decoded as training progresses
            # (predict_one_batch runs under torch.no_grad()).
            _, sample_pred = predict_one_batch(samples)
            writer.add_text('y_pred sentence', sample_pred[0], global_step = n_iter)

            writer.add_scalar('CE Loss/train', loss, n_iter)

            if grad_flow_fig is not None:
                writer.add_figure('Average Gradients/Model', grad_flow_fig, global_step = n_iter, close = True)

        kbar.update(idx, values = [("loss", loss)])

        n_iter += 1

    ## At epoch end: step the plateau scheduler on the mean epoch loss rather than
    ## the noisy loss of the final batch. Guarded so an empty loader does not crash.
    ## NOTE(review): this is training loss; dev_data is loaded but unused, and a
    ## validation loss would be the usual signal here.
    if epoch_losses:
        scheduler_plateau.step(sum(epoch_losses) / len(epoch_losses))

    print("\n")