In [2]:
%load_ext autoreload
%autoreload 2
import os, sys
sys.path.append("../")
import string
import json
import torch
import numpy as np
import pandas as pd
import pytorch_lightning as pl

from typing import Any
from loguru import logger
from tqdm import tqdm
from torch import nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoModelForSeq2SeqLM
from transformers import MarianTokenizer, MarianMTModel
from pytorch_lightning import LightningDataModule, LightningModule, Trainer, seed_everything
from pytorch_lightning.callbacks import TQDMProgressBar

## Parse src_tgt training data

In [2]:
# Load the Common Voice (nl) training metadata; the file is tab-separated.
filepath = "../data/common_voice_nl/train.tsv"
df = pd.read_table(filepath, sep="\t")

In [3]:
# Attach the STT transcript of every clip to the frame and normalise the
# reference sentences (lower-case, ASCII punctuation removed).
stt_dir = "../data/common_voice_nl/cv_nl_stt"
all_stt = []
for fname in tqdm(df.path.values):
    # Read-only mode is sufficient ("r+" needlessly requires write permission);
    # the encoding is explicit so non-ASCII characters do not depend on the locale.
    with open(f"{stt_dir}/{fname}.json", "r", encoding="utf-8") as f:
        stt_out = json.load(f)
    # Keep only the transcript of the first alternative of the first channel.
    all_stt.append(stt_out['results']['channels'][0]['alternatives'][0]['transcript'])

# Assign positionally: a plain list cannot be misaligned by the frame's index.
df['stt_out'] = all_stt

# Build the translation table once instead of once per row; the .str accessor
# also propagates missing values instead of raising on them.
punctuation_table = str.maketrans('', '', string.punctuation)
df['sentence'] = df.sentence.str.lower().str.translate(punctuation_table)
df.head()

100%|██████████| 29031/29031 [00:00<00:00, 32567.78it/s]


Unnamed: 0,client_id,path,sentence,up_votes,down_votes,age,gender,accents,locale,segment,stt_out
0,da4b6d09a23e8a83f83fec4e302a82c500d2821c4bb4d4...,common_voice_nl_30382934.mp3,een daadwerkelijke keuzevrijheid voor ouderen ...,2,0,,,Nederlands Nederlands,nl,,een daadwerkelijke keuzevrijheid voor ouderen ...
1,da4b6d09a23e8a83f83fec4e302a82c500d2821c4bb4d4...,common_voice_nl_30382935.mp3,elke kandidaatlidstaat moet op zijn eigen meri...,2,0,,,Nederlands Nederlands,nl,,elke kandidaat dit staat moet op zijn eigen wo...
2,da4b6d09a23e8a83f83fec4e302a82c500d2821c4bb4d4...,common_voice_nl_30382936.mp3,het verslag legt sterke nadruk op het nauwe ve...,2,0,,,Nederlands Nederlands,nl,,het verslag legt sterke nadruk op het nauwe ve...
3,da4b6d09a23e8a83f83fec4e302a82c500d2821c4bb4d4...,common_voice_nl_30382937.mp3,wij openen nu het algemeen debat,4,0,,,Nederlands Nederlands,nl,,we openen nu het algemeen debat
4,da4b6d09a23e8a83f83fec4e302a82c500d2821c4bb4d4...,common_voice_nl_30382938.mp3,die fase is gebaseerd op de testcyclus van per...,4,0,,,Nederlands Nederlands,nl,,die fase is gebaseerd op de test van personena...


## Test NMT based Encoder-Decoder

In [3]:
# Use the first CUDA device when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
device

device(type='cuda', index=0)

In [5]:
# Small sample (first `idx` rows) used for the quick inference checks below.
idx = 5
noisy_seq = df["stt_out"].head(idx)
clean_seq = df["sentence"].head(idx)

In [6]:
# def Marianseq_infer(src_seq):
#     logger.info(f"Initial (Noisy) Sequence = {src_seq}")
#     tokenizer_nl_en = MarianTokenizer.from_pretrained("../model/opus-mt-nl-en")
#     nl_en = MarianMTModel.from_pretrained("../model/opus-mt-nl-en")
    
#     tokenizer_en_nl = MarianTokenizer.from_pretrained("../model/opus-mt-en-nl")
#     en_nl = MarianMTModel.from_pretrained("../model/opus-mt-en-nl")

#     encoder_batch = tokenizer_nl_en([src_seq], return_tensors="pt")
#     outputs_interim = nl_en.generate(**encoder_batch)
#     interim_seq = tokenizer_nl_en.batch_decode(outputs_interim, skip_special_tokens=True)[0]
    
#     logger.info(f"Interim Sequence  = {interim_seq}")
    
#     decoder_batch = tokenizer_en_nl([interim_seq], return_tensors="pt")
#     outputs_final = en_nl.generate(**decoder_batch)
#     tgt_seq = tokenizer_en_nl.batch_decode(outputs_final, skip_special_tokens=True)[0]
    
#     logger.info(f"Final Sequence  = {tgt_seq}")
#     return tgt_seq
# Marianseq_infer(noisy_seq)
# logger.info(f"Ground Truth Sequence = {clean_seq}")

In [7]:
def _translate_batch(model_dir, texts):
    """Translate `texts` with the seq2seq checkpoint stored in `model_dir`.

    The tokenizer and model are loaded from `model_dir`, the model is moved to
    the notebook-level `device`, and beam search (5 beams, at most 512 new
    tokens) is used for generation. Returns the decoded strings with special
    tokens removed.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_dir).to(device)

    encoded = tokenizer(texts, return_tensors="pt", padding=True).to(device)
    # Pass the attention mask explicitly so padded positions are never attended
    # to, instead of relying on generate() to infer the mask from the pad id.
    generated_ids = model.generate(
        encoded.input_ids,
        attention_mask=encoded.attention_mask,
        num_beams=5,
        max_new_tokens=512,
    )
    return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)


def seq2seq_infer(src_batch):
    """Round-trip a batch of sentences through the nl->en and en->nl models.

    Args:
        src_batch: list of source (noisy) sentences.

    Returns:
        List of sentences produced by translating `src_batch` with the
        "opus-mt-nl-en" checkpoint and translating that result back with the
        "opus-mt-en-nl" checkpoint. An empty list/tuple yields an empty list.
    """
    logger.info(f"Initial (Noisy) Sequence = {src_batch}")
    # Nothing to translate: skip loading the models altogether.
    if isinstance(src_batch, (list, tuple)) and len(src_batch) == 0:
        return []

    interim_batch = _translate_batch("../model/opus-mt-nl-en", src_batch)

    logger.info(f"Interim Sequence  = {interim_batch}")

    # Each direction loads its own model, so only one has to be alive at a time.
    tgt_batch = _translate_batch("../model/opus-mt-en-nl", interim_batch)

    return tgt_batch

In [8]:
# Run the round-trip translation on the sample and log it next to the references.
noisy_texts = noisy_seq.values.tolist()
reference_texts = clean_seq.values.tolist()
tgt_batch = seq2seq_infer(noisy_texts)
logger.info(f"Final Sequence  = {tgt_batch}")
logger.info(f"Ground Truth Sequence = {reference_texts}")

2022-09-18 03:33:16.787 | INFO     | __main__:seq2seq_infer:2 - Initial (Noisy) Sequence = ['een daadwerkelijke keuzevrijheid voor ouderen daar zouden we werk van moeten maken', 'elke kandidaat dit staat moet op zijn eigen worden beoordeeld', 'het verslag legt sterke nadruk op het nauwe verband tussen de twee', 'we openen nu het algemeen debat', "die fase is gebaseerd op de test van personenauto's"]
2022-09-18 03:33:21.430 | INFO     | __main__:seq2seq_infer:13 - Interim Sequence  = ['a genuine freedom of choice for older people.', 'each candidate this state must be judged on its own', 'the report places strong emphasis on the close link between the two', 'We are now opening the general debate', 'that phase is based on the test of passenger cars']
2022-09-18 03:33:21.557 | INFO     | __main__:<module>:2 - Final Sequence  = ['een echte keuzevrijheid voor ouderen.', 'elke kandidaat deze staat moet worden beoordeeld op zijn eigen', 'In het verslag wordt sterk de nadruk gelegd op het nauwe

## Create a Custom Dataset and a Dataloader

In [9]:
class CustomDataset(Dataset):
    """Map-style dataset pairing each `stt_out` entry with its `sentence` entry.

    Both columns of the given frame are copied into plain Python lists at
    construction time; items are returned as ``{"src": ..., "tgt": ...}`` dicts.
    """

    def __init__(self, df) -> None:
        super().__init__()
        # Source side: values of the `stt_out` column.
        self.src = df["stt_out"].values.tolist()
        # Target side: values of the `sentence` column.
        self.tgt = df["sentence"].values.tolist()

    def __len__(self):
        """Number of samples, taken from the target list."""
        return len(self.tgt)

    def __getitem__(self, index):
        """Return the source/target pair stored at `index`."""
        return {"src": self.src[index], "tgt": self.tgt[index]}

In [10]:
def collate_batch(batch):
    """Split a list of ``{"src", "tgt"}`` samples into two lists of stripped strings.

    Returns a ``(source, target)`` tuple; leading/trailing whitespace is removed
    from every string.
    """
    source, target = [], []
    for sample in batch:
        source.append(sample['src'].strip())
        target.append(sample['tgt'].strip())
    return source, target

## Set-up Pytorch Lightning based Fine-Tuning

In [11]:
# Hyperparams
# Mini-batch size; read as a global by chained_seq2seq.train_dataloader().
batch_size = 16

In [12]:
# a = torch.rand(5, 17, 67028)
# b = torch.rand(5, 17)
# a[-1, :].size()

#### Notes

- Forward pass of en-nl implemented successfully. Can I also use forward pass of the encoder(nl-en)?
- Dataloader: Pass tokenized sequences in the batch instead of raw text
- Implement WER metric tracking
- Implement callbacks
- Implement validation step and a custom generate function
- Move all code to src

In [13]:
class chained_seq2seq(LightningModule):
    """Chain of two translation models: nl->en ("encoder") then en->nl ("decoder").

    A batch of source sentences is first translated with the "opus-mt-nl-en"
    checkpoint via beam search; the resulting text is then fed to the
    "opus-mt-en-nl" checkpoint. Because the first stage goes through
    `generate()` and text decoding, no gradient reaches the first model; only
    the second model receives a training signal.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self.tokenizer_encoder = AutoTokenizer.from_pretrained("../model/opus-mt-nl-en")
        self.encoder = AutoModelForSeq2SeqLM.from_pretrained("../model/opus-mt-nl-en").to(device)

        self.tokenizer_decoder = AutoTokenizer.from_pretrained("../model/opus-mt-en-nl")
        self.decoder = AutoModelForSeq2SeqLM.from_pretrained("../model/opus-mt-en-nl").to(device)
        # Not used by training_step: the loss is taken from the decoder model's
        # own output when labels are supplied. Kept so the attribute stays available.
        self.criterion = nn.CrossEntropyLoss(ignore_index=self.tokenizer_decoder.pad_token_id)

    def _translate_to_interim(self, inputs):
        """Translate `inputs` with the first model and return the decoded strings."""
        encoded = self.tokenizer_encoder(inputs, return_tensors="pt", padding=True)
        # Follow the model's actual device rather than the notebook global.
        encoded = encoded.to(self.encoder.device)

        # Generation is a pure inference step: disable dropout while it runs so the
        # interim text is deterministic, then restore the previous mode.
        was_training = self.encoder.training
        self.encoder.eval()
        try:
            interim_ids = self.encoder.generate(
                encoded.input_ids,
                attention_mask=encoded.attention_mask,
                num_beams=5,
                max_new_tokens=512,
            )
        finally:
            self.encoder.train(was_training)
        return self.tokenizer_encoder.batch_decode(interim_ids, skip_special_tokens=True)

    def forward(self, inputs, tgt_ids=None):
        """Run the chain on a batch of source sentences.

        Args:
            inputs: list of source sentences (strings).
            tgt_ids: optional label ids for the second model. Padding positions
                should be set to -100 so they are excluded from the loss.

        Returns:
            With `tgt_ids`: the second model's output object (loss and logits).
            Without `tgt_ids`: generated token ids from the second model, ready
            for `batch_decode` with that model's tokenizer.
        """
        interim_batch = self._translate_to_interim(inputs)

        # Tokenize once and reuse both the ids and the attention mask.
        decoder_inputs = self.tokenizer_decoder(interim_batch, return_tensors="pt", padding=True)
        decoder_inputs = decoder_inputs.to(self.decoder.device)

        if tgt_ids is None:
            return self.decoder.generate(
                decoder_inputs.input_ids,
                attention_mask=decoder_inputs.attention_mask,
                num_beams=5,
                max_new_tokens=512,
            )

        return self.decoder(
            input_ids=decoder_inputs.input_ids,
            attention_mask=decoder_inputs.attention_mask,
            labels=tgt_ids,
        )

    def training_step(self, batch, batch_idx):
        """Compute the second model's loss for one `(sources, targets)` text batch."""
        src_batch, tgt_batch = batch

        # Labels must live in the vocabulary of the model that computes the loss,
        # so they are tokenized with the second model's tokenizer in target mode.
        labels = self.tokenizer_decoder(
            text_target=tgt_batch, return_tensors="pt", padding=True
        ).input_ids
        # -100 marks positions that the model's built-in loss skips (padding).
        labels = labels.masked_fill(labels == self.tokenizer_decoder.pad_token_id, -100)
        labels = labels.to(self.decoder.device)

        outputs = self.forward(src_batch, labels)
        loss = outputs.loss

        # The batch is made of Python strings, so the batch size is given explicitly.
        self.log_dict({'train_loss': loss}, batch_size=len(src_batch))

        return {'loss': loss}

    def configure_optimizers(self):
        """Adam over all module parameters with a fixed learning rate of 1e-3."""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

    def train_dataloader(self):
        """Build the training DataLoader from the notebook-level `df` and `batch_size`."""
        ds = CustomDataset(df)
        dl = DataLoader(
            ds,
            shuffle=True,  # False, if overfit_pct
            batch_size=batch_size,
            num_workers=12,
            collate_fn=collate_batch)
        print(
            f"dataset:'{'train'}', size:{len(ds)}, batch:{batch_size}, nb_batches:{len(dl)}"
        )
        return dl

    def generate(self):
        """Placeholder; not implemented yet."""
        pass

In [14]:
# Seed the RNGs first so that everything created afterwards is reproducible.
seed_everything(42)
model = chained_seq2seq()

progress_bar = TQDMProgressBar(refresh_rate=10)
n_devices = 1 if torch.cuda.is_available() else None
trainer = pl.Trainer(
    max_epochs=5,
    accelerator="auto",
    callbacks=[progress_bar],
    devices=n_devices,
)

Global seed set to 42
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [15]:
# No dataloaders are passed here: the module's own train_dataloader() supplies the data.
trainer.fit(model)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type             | Params
-----------------------------------------------
0 | encoder   | MarianMTModel    | 79.0 M
1 | decoder   | MarianMTModel    | 79.0 M
2 | criterion | CrossEntropyLoss | 0     
-----------------------------------------------
156 M     Trainable params
1.0 M     Non-trainable params
157 M     Total params
631.849   Total estimated model params size (MB)


dataset:'train', size:29031, batch:16, nb_batches:1815


Training: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


## Evaluation

In [49]:
# A fresh instance is built here, so its weights are the checkpoints loaded in
# __init__, not the object that was passed to trainer.fit() above.
model = chained_seq2seq()
model.eval()
# Tokenizer of the second (en-nl) model, used to decode the final token ids.
tokenizer = AutoTokenizer.from_pretrained("../model/opus-mt-en-nl")

In [65]:
# Inspect the padding token string of the en-nl tokenizer loaded above.
tokenizer.pad_token

'<pad>'

In [50]:
# NOTE(review): this call passes only the input sentences, so it needs a forward()
# that accepts inputs alone and returns token ids that batch_decode can consume.
eval_inputs = noisy_seq.values.tolist()
generated_ids = model(eval_inputs)
tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

['een echte keuzevrijheid voor ouderen.',
 'elke kandidaat deze staat moet worden beoordeeld op zijn eigen',
 'In het verslag wordt sterk de nadruk gelegd op het nauwe verband tussen de twee',
 'Wij openen nu het algemene debat',
 "die fase is gebaseerd op de test van personenauto's"]