In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.optim.lr_scheduler import ExponentialLR, ReduceLROnPlateau

from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter

from torch.cuda.amp import autocast, GradScaler

import torchaudio
import torchaudio.functional as TAF
import torchaudio.transforms as T

from utils.dataset import CommonVoice
from utils.audio_utils import plot_waveform, play_audio
from utils.batch_utls import Collator

from datetime import datetime

In [None]:
from transformers import PreTrainedTokenizerFast

In [None]:
import os
import random
import pkbar

In [None]:
seed = 0

# Seed every RNG the notebook touches: Python's `random`, torch's global RNG,
# and a dedicated torch.Generator that is later handed to the DataLoader.
random.seed(seed)
torch.manual_seed(seed)

g = torch.Generator()
g.manual_seed(seed)

In [None]:
# Record library versions (one per line) for reproducibility.
print(torch.__version__, torchaudio.__version__, sep='\n')

In [None]:
# Make CUDA kernel launches synchronous so errors surface at the failing call (debug aid; slows training).
# NOTE(review): only takes effect if set before CUDA is initialised in this process.
os.environ.update({'CUDA_LAUNCH_BLOCKING': '1'})

In [None]:
# Prefer the GPU whenever one is visible to torch.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# Trained tokenizer definition (JSON produced by the `tokenizers` library).
tokenizer_file = 'data/tokenizer/trained_tokenizer.json'

# Common Voice 8.0 English corpus root and its audio clip folder.
datasetPATH = 'data/external/cv-corpus-8.0-2022-01-19/en/'
clipsPATH = os.path.join(datasetPATH, 'clips')

In [None]:
# Build the tokenizer only once per kernel session; re-running this cell reuses the
# existing object. A plain membership test replaces the old try/except/finally, whose
# `finally` clause ran even when construction failed and then masked the real error
# with a NameError on `tokenizer`.
if 'tokenizer' not in globals():
    tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_file)

special_tokens_dict = {'pad_token': '[PAD]',
                       'sep_token': '[SEP]',
                       'mask_token': '[MASK]'}

# Registers the special tokens on the tokenizer (see HF `add_special_tokens` docs).
tokenizer.add_special_tokens(special_tokens_dict)

# The padding token doubles as the CTC blank symbol.
blank_token = "[PAD]"
blank_token_id = tokenizer.vocab[blank_token]

# Vocabulary size including the special tokens just added.
vocab_size = len(tokenizer)

In [None]:
# Training split of Common Voice, tokenised with the tokenizer built above (mono audio).
train_data = CommonVoice(
    dataset_path=datasetPATH,
    split_type='train',
    tokenizer=tokenizer,
    out_channels=1,
)
train_data

In [None]:
class Encoder(nn.Module):
    """Thin wrapper around ``torchaudio.models.Conformer`` that returns only the features.

    Args:
        input_dim: feature size of each input frame (Conformer ``input_dim``).
        num_heads: attention heads per Conformer layer.
        ffn_dim: hidden size of the Conformer feed-forward modules.
        num_layers: number of Conformer layers.
        depthwise_conv_kernel_size: kernel size of the depthwise convolution.
        dropout: dropout probability inside the Conformer.
        **args: any other keyword arguments are accepted and silently IGNORED —
            pass the parameters above by their exact names.
    """

    def __init__(self,
                 input_dim: int = 128,
                 num_heads: int = 4,
                 ffn_dim: int = 128,
                 num_layers: int = 4,
                 depthwise_conv_kernel_size: int = 31,
                 dropout: float = 0.3,
                 **args):

        super(Encoder, self).__init__()

        self.model = torchaudio.models.Conformer(input_dim = input_dim,
                                                 num_heads = num_heads,
                                                 ffn_dim = ffn_dim,
                                                 num_layers = num_layers,
                                                 depthwise_conv_kernel_size = depthwise_conv_kernel_size,
                                                 dropout = dropout)

    def forward(self, x: torch.Tensor, x_len: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` and return the Conformer output, discarding the returned lengths.

        Per the torchaudio docs the Conformer expects ``x`` as (batch, time, input_dim)
        and ``x_len`` as (batch,) valid lengths.
        """
        # Call the module itself (not `.forward`) so registered hooks run.
        x, _ = self.model(x, x_len)

        return x

In [None]:
class LSTMDecoder(nn.Module):
    """Recurrent head mapping a feature sequence to per-step vocabulary scores.

    Despite the class name, the recurrent core is a batch-first ``nn.GRU``. Its outputs
    are projected with a linear layer followed by a GLU.

    Args:
        input_dim: feature size of each input step.
        hidden_size: GRU hidden size.
        num_layers: number of stacked GRU layers.
        bidirectional: run the GRU in both directions. The projection input width is
            doubled to match.
        output_dim: size of the output scores (the vocabulary size); required.
        padding_idx: accepted for interface compatibility; currently unused.
        **args: extra keyword arguments are accepted and ignored.

    Raises:
        ValueError: if ``output_dim`` is not given.
    """

    def __init__(self,
                 input_dim: int = 128,
                 hidden_size: int = 256,
                 num_layers: int = 2,
                 bidirectional: bool = False,
                 output_dim: int = None,
                 padding_idx: int = None,
                 **args):

        super(LSTMDecoder, self).__init__()

        if output_dim is None:
            raise ValueError("Please specify the output size of the vocab.")

        directions = 2 if bidirectional else 1

        # `bidirectional` must reach the GRU, otherwise its output width is `hidden_size`
        # while `ffn` below expects `hidden_size * 2`.
        self.model = nn.GRU(input_size = input_dim,
                            hidden_size = hidden_size,
                            num_layers = num_layers,
                            batch_first = True,
                            bidirectional = bidirectional)

        # GLU halves its input's last dimension, so project to 2 * output_dim to end up
        # with exactly `output_dim` scores per step.
        self.ffn = nn.Linear(in_features = hidden_size * directions, out_features = output_dim * 2)

    def forward(self, x: torch.Tensor, hidden_state: torch.Tensor = None) -> torch.Tensor:
        """Run the GRU over ``x`` and project every step to ``output_dim`` scores.

        Args:
            x: input sequence; the GRU is batch-first, so (batch, time, input_dim).
            hidden_state: optional initial GRU hidden state (e.g. carried over from a
                previous call); ``None`` lets the GRU start from zeros.

        Returns:
            ``(outputs, hidden_state)``: outputs of shape (batch, time, output_dim) and the
            final GRU hidden state.
        """
        outputs, hidden_state = self.model(x, hidden_state)

        outputs = F.glu(self.ffn(outputs), dim = -1)

        return outputs, hidden_state

In [None]:
class Model(nn.Module):
    """Conformer ``Encoder`` followed by an ``LSTMDecoder`` head.

    Args:
        encoder_input_dim: feature size fed to the encoder; also used as the decoder's
            input size (the encoder preserves the feature size).
        encoder_num_heads, encoder_ffn_dim, encoder_num_layers,
        encoder_depthwise_conv_kernel_size: forwarded to ``Encoder``.
        decoder_hidden_size, decoder_num_layers, bidirectional_decoder: forwarded to
            ``LSTMDecoder``.
        vocab_size: decoder output size (``LSTMDecoder`` raises ValueError if None).
        padding_idx: forwarded to ``LSTMDecoder``.
        sos_token_id: stored on the instance; not used by ``forward``.
    """

    def __init__(self, encoder_input_dim: int = 128,
                encoder_num_heads: int = 4,
                encoder_ffn_dim: int = 128,
                encoder_num_layers: int = 4,
                encoder_depthwise_conv_kernel_size: int = 31,
                decoder_hidden_size:int = 128,
                decoder_num_layers: int = 2,
                bidirectional_decoder: bool = False,
                vocab_size: int = None,
                padding_idx: int = None,
                sos_token_id: int = None):

        super(Model, self).__init__()

        # `num_layers` was previously dropped here, so the encoder always used its default.
        self.encoder = Encoder(input_dim = encoder_input_dim,
                               num_heads = encoder_num_heads,
                               ffn_dim = encoder_ffn_dim,
                               num_layers = encoder_num_layers,
                               depthwise_conv_kernel_size = encoder_depthwise_conv_kernel_size)

        self.decoder = LSTMDecoder(input_dim = encoder_input_dim,
                                   hidden_size = decoder_hidden_size,
                                   num_layers = decoder_num_layers,
                                   bidirectional = bidirectional_decoder,
                                   output_dim = vocab_size,
                                   padding_idx = padding_idx)

        self.sos_token_id = sos_token_id

    def forward(self, x: torch.Tensor, x_lens: torch.Tensor):
        """Encode ``x`` and run the decoder loop.

        Args:
            x: 3-D input; dimension 1 is used as the loop length (presumably
                (batch, time, features) — TODO confirm).
            x_lens: valid lengths passed to the encoder.

        Returns:
            ``(encoder_outputs, decoded)``. ``decoded`` stacks, for each of the ``x.shape[1]``
            iterations, the top-1 decoder score per position.

        NOTE(review): every iteration re-runs the decoder over the *whole* encoder output
        (same input each pass), carrying the hidden state across passes, and it collects
        top-1 *values*, not token indices. For CTC, the usual approach is a single decoder
        pass followed by ``log_softmax``. TODO confirm the intended decoding strategy.
        """
        decoded = []

        _, max_len, _ = x.shape  # batch, max sequence length, feature size

        encoder_outputs = self.encoder(x, x_lens)

        decoder_hidden_state = None  # the first pass starts from a zero hidden state
        for _ in range(max_len):
            decoder_output, decoder_hidden_state = self.decoder(x = encoder_outputs,
                                                                hidden_state = decoder_hidden_state)

            topv, _ = decoder_output.topk(1)

            decoded.append(topv)

        return encoder_outputs, torch.stack(decoded)

In [None]:
# Hyper-parameters for the model classes above (keys match `Model.__init__` argument names).
model_params = {
    'encoder_input_dim': 128,
    'encoder_num_heads': 4,
    'encoder_ffn_dim': 128,
    'encoder_num_layers': 4,
    'decoder_num_layers': 1,
    'decoder_hidden_size': 64,
    'padding_idx': tokenizer.pad_token_id,
    'sos_token_id': tokenizer.bos_token_id
}

collator = Collator(tokenizer)
BATCH_SIZE = 16
NUM_WORKERS = 6

# Pinned host memory only helps host-to-GPU copies, so enable it only when training on CUDA.
# The seeded generator `g` and the collator's worker_init_fn keep shuffling and workers reproducible.
train_loader = DataLoader(train_data,
                          batch_size = BATCH_SIZE,
                          collate_fn = collator,
                          shuffle = True,
                          pin_memory = (device == 'cuda'),
                          num_workers = NUM_WORKERS,
                          worker_init_fn = collator.seed_worker,
                          generator = g)

fp16 = False
# With enabled=False the scaler is a pass-through, so it is harmless when mixed precision is off.
scaler = GradScaler(enabled = fp16)

In [None]:
# Encoder takes un-prefixed argument names and its **args silently swallows anything else.
# Splatting `model_params` (whose keys are `encoder_*`) would therefore ignore every value,
# so map the keys explicitly.
encoder = Encoder(input_dim = model_params['encoder_input_dim'],
                  num_heads = model_params['encoder_num_heads'],
                  ffn_dim = model_params['encoder_ffn_dim'],
                  num_layers = model_params['encoder_num_layers']).to(device)

# Per-frame linear projection from encoder features to vocabulary logits (the CTC head).
decoder = nn.Linear(model_params['encoder_input_dim'], out_features = vocab_size, bias = False).to(device)

In [None]:
## CTC loss is computed on per-frame log-probabilities: encoder features -> linear projection -> log_softmax.

## Decoding (turning those log-probabilities into text) is kept separate from the encoder / CTC training path.

In [None]:
base_log_dir = 'logs/'
RUN_COMMENT = 'first_try_custom_tokenizer'

start_time = datetime.now()
# Filesystem-safe timestamp: no ':' (invalid on Windows) and no spaces.
start_time_fmt = start_time.strftime("%d-%m-%Y_%H-%M-%S")

# SummaryWriter ignores `comment` when `log_dir` is given explicitly, so the run comment
# is folded into the directory name instead.
run_log_dir = os.path.join(base_log_dir, f"{start_time_fmt}_{RUN_COMMENT}")

writer = SummaryWriter(log_dir=run_log_dir)

In [None]:
EPOCHS = 10
LEARNING_RATE = 3e-5

num_batches = len(train_loader)

enc_optim = torch.optim.AdamW(encoder.parameters(), lr = LEARNING_RATE)
dec_optim = torch.optim.AdamW(decoder.parameters(), lr = LEARNING_RATE)

# ReduceLROnPlateau needs the monitored metric in `.step(metric)`. It is stepped once per
# epoch below with the mean training CTC loss.
enc_scheduler = ReduceLROnPlateau(enc_optim, mode = 'min')
dec_scheduler = ReduceLROnPlateau(dec_optim, mode = 'min')

# The padding token doubles as the CTC blank symbol.
loss_fn = nn.CTCLoss(blank = tokenizer.pad_token_id)

n_iter = 0  # global optimisation step, used as the TensorBoard x-axis

encoder.train()
decoder.train()

for epoch in range(EPOCHS):

    kbar = pkbar.Kbar(target = num_batches, epoch = epoch, num_epochs = EPOCHS, width = 8, always_stateful = False)

    epoch_loss_sum = 0.0
    epoch_batches = 0

    for idx, batch in enumerate(train_loader):

        # Clear BOTH optimizers every step; otherwise gradients accumulate across batches.
        enc_optim.zero_grad(set_to_none = True)
        dec_optim.zero_grad(set_to_none = True)

        sentences = batch['sentences'].to(device)
        sentence_lengths = batch['sentence_lengths'].to(device, dtype = torch.int32)

        melspecs = batch['melspecs'].to(device)
        melspecs_lengths = batch['melspecs_lengths'].to(device, dtype = torch.int32)

        # Swap the last two axes so time comes before the mel bins.
        # NOTE(review): assumes the collator yields (batch, n_mels, time) — confirm.
        melspecs = torch.transpose(melspecs, -1, -2)

        # Forward pass. autocast(enabled=False) is a no-op, so one code path serves both modes.
        with autocast(enabled = fp16):
            encoder_outputs = encoder(melspecs, melspecs_lengths)
            encoder_outputs = encoder_outputs.transpose(1, 0)  # CTCLoss wants (time, batch, classes)
            decoder_outputs = F.log_softmax(decoder(encoder_outputs), dim = -1)

        # CTC loss in float32, outside autocast, with int32 lengths.
        ctc_loss = loss_fn(log_probs = decoder_outputs.float(),
                           targets = sentences,
                           input_lengths = melspecs_lengths,
                           target_lengths = sentence_lengths)

        # Backward/step run outside autocast, as recommended for AMP.
        if fp16:
            scaler.scale(ctc_loss).backward()
            scaler.step(enc_optim)
            scaler.step(dec_optim)
            scaler.update()
        else:
            ctc_loss.backward()
            enc_optim.step()
            dec_optim.step()

        loss_value = ctc_loss.detach().cpu().item()
        epoch_loss_sum += loss_value
        epoch_batches += 1

        writer.add_scalar('CTC Loss/train', loss_value, n_iter)
        n_iter += 1

        kbar.update(idx, values = [("ctc_loss", loss_value)])

    # Plateau schedulers react to the epoch-level loss (skipped if the loader was empty).
    if epoch_batches:
        epoch_loss = epoch_loss_sum / epoch_batches
        writer.add_scalar('CTC Loss/train_epoch', epoch_loss, epoch)
        enc_scheduler.step(epoch_loss)
        dec_scheduler.step(epoch_loss)

In [None]:
sample = train_data[1]
melspec = sample['melspec'].to(device)

# Valid length = size of the last axis (presumably time). Built as int32 on the same
# device, matching the lengths used during training (previously a float tensor).
melspec_len = torch.tensor([melspec.shape[-1]], dtype = torch.int32, device = device)

# Add a batch axis, then swap axes 1 and 2. If `melspec` is (n_mels, time), this yields
# (1, time, n_mels) as the encoder expects. TODO confirm sample['melspec'] is 2-D.
melspec = melspec.unsqueeze(0)
melspec = melspec.transpose(2, 1)

In [None]:
# Evaluation mode disables dropout (the Encoder's default dropout is 0.3). The previous
# modes are restored afterwards so a later training run is unaffected.
encoder_was_training = encoder.training
decoder_was_training = decoder.training
encoder.eval()
decoder.eval()

try:
    with torch.no_grad():
        y_preds = encoder(melspec, melspec_len)

        y_preds = y_preds.transpose(1, 0)  # same (time, batch, features) layout as in training

        y_preds = decoder(y_preds)
        y_preds = F.log_softmax(y_preds, dim = -1)

        y_preds = y_preds.transpose(0, 1)  # back to (batch, time, vocab)

        # Greedy CTC decoding: best token per frame, collapse consecutive repeats
        # (unique_consecutive flattens the batch of one), then drop the blank symbol,
        # which is the pad token used as `blank` in the CTC loss.
        y_preds = y_preds.argmax(dim = -1)
        y_preds = torch.unique_consecutive(y_preds)
        y_preds = y_preds[y_preds != tokenizer.pad_token_id]
finally:
    encoder.train(encoder_was_training)
    decoder.train(decoder_was_training)

In [None]:
# Ground-truth transcript, for comparison with the decoded prediction.
reference_sentence = sample['sentence']
reference_sentence

In [None]:
# Decode the greedy CTC prediction. Special tokens (including the [PAD] token used as the
# CTC blank) are skipped so they do not appear in the transcript.
tokenizer.decode(y_preds, skip_special_tokens = True)