<a href="https://colab.research.google.com/github/Ishan-Kumar2/examples/blob/mt_example/Examples/Machine_Translation_using_PyTorch_Ignite.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# installing appropriate modules
%%capture
!pip install git+https://github.com/huggingface/transformers.git@master;
!pip install git+https://github.com/huggingface/datasets.git@master;
!pip install git+https://github.com/pytorch/ignite.git@master;
!pip install sentencepiece;

In [None]:
## Uncomment if using TPU
# setup TPU environment
# import os
# assert os.environ['COLAB_TPU_ADDR'], 'Make sure to select TPU from Edit > Notebook settings > Hardware accelerator'

# VERSION = "nightly"
# !curl -q https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
# !python pytorch-xla-env-setup.py --version $VERSION > /dev/null

In [None]:
import torch

In [None]:
from pathlib import Path
from datasets import load_dataset
from transformers import MBartForConditionalGeneration, MBartTokenizer
from datetime import datetime
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import random_split
from torch.cuda.amp import GradScaler, autocast

import nltk
import ignite
import ignite.distributed as idist
from ignite.contrib.engines import common
from ignite.metrics import Loss, RougeN, Bleu, Accuracy, Loss
from ignite.utils import manual_seed, setup_logger
from ignite.engine import Engine, Events
from ignite.contrib.handlers import PiecewiseLinear
from ignite.handlers import Checkpoint, global_step_from_engine

In [None]:
# Configs
# Training switches: mixed precision on/off, base random seed (each process adds its rank
# in training()), and number of epochs.
with_amp, seed, num_epochs = True, 126, 5

# Root directory under which each run's output folder is created.
output_path_ = '/content'
# Checkpoint frequency (iterations); not referenced by the visible training code yet.
checkpoint_every = 100

# Pretrained checkpoint for both model and tokenizer; the commented alternative in earlier
# versions of this cell was "facebook/mbart-large-cc25".
model_name = tokenizer_name = "sshleifer/tiny-mbart"

## Preparing data

We will use the `stas/wmt16-en-ro-pre-processed` dataset (pre-processed WMT16 English–Romanian sentence pairs) for this example.

In [None]:
from datasets import load_dataset

# Pre-processed WMT16 English-Romanian pairs; reused from the local HF datasets cache when present.
dataset = load_dataset(path="stas/wmt16-en-ro-pre-processed")

Reusing dataset wmt16_en_ro_pre_processed (/root/.cache/huggingface/datasets/wmt16_en_ro_pre_processed/enro/1.1.0/c4093132d2665734cbb5098992e5cdf3cdbd807b80a5913a456ab7cb8c34ab2b)


  0%|          | 0/3 [00:00<?, ?it/s]

In [None]:
# Show one raw training pair, then the size of every split.
print("Example of a Datapoint", dataset['train'][0], sep="\n")
print("Lengths")
print("\n".join(
    template.format(len(dataset[split]))
    for template, split in (
        ("\t Train Set - {}", "train"),
        ("\t Val Set - {}", "validation"),
        ("\t Test Set - {}", "test"),
    )
))

Example of a Datapoint
{'translation': {'en': 'Membership of Parliament: see Minutes', 'ro': 'Componenţa Parlamentului: a se vedea procesul-verbal'}}
Lengths
	 Train Set - 610320
	 Val Set - 1999
	 Test Set - 1999


In [None]:
# mBART is multilingual: declare the source (English) and target (Romanian) language codes.
tokenizer = MBartTokenizer.from_pretrained(
    tokenizer_name,
    src_lang="en_XX",
    tgt_lang="ro_RO",
)

In [None]:
class TransformerDataset(torch.utils.data.Dataset):
    """Map-style dataset of tokenised (source, target) translation pairs.

    Each row of ``data`` is expected to look like
    ``{'translation': {src_text_id: <text>, tgt_text_id: <text>}}`` (e.g. the WMT16 rows
    printed earlier). Only indices ``0 .. len - 1`` are exposed, so ``len`` can restrict
    training to a prefix of ``data``.

    Args:
        data: indexable collection of rows supporting ``len()``.
        src_text_id: key of the source text inside ``row['translation']`` (e.g. ``'en'``).
        tgt_text_id: key of the target text inside ``row['translation']`` (e.g. ``'ro'``).
        tokenizer: callable tokenizer that also provides an ``as_target_tokenizer()``
            context manager.
        max_length: every sequence is padded/truncated to exactly this many tokens.
        len: number of rows to expose. ``None`` (default) exposes all of ``data``.

    Raises:
        ValueError: if ``len`` is negative or larger than ``len(data)``. This check happens
            at construction instead of as an ``IndexError`` in the middle of an epoch.
    """

    def __init__(self, data, src_text_id, tgt_text_id, tokenizer, max_length, len=None):
        import builtins  # the builtin ``len`` is shadowed by the parameter of the same name

        available = builtins.len(data)
        if len is None:
            len = available
        elif not 0 <= len <= available:
            raise ValueError(f"len must be between 0 and {available}, got {len}")

        self.data = data
        self.src_text_id = src_text_id
        self.tgt_text_id = tgt_text_id
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.len = len

    def __getitem__(self, idx):
        """Return padded ``src_input_ids``, ``src_attention_mask`` and ``tgt`` 1-D tensors."""
        # Index the row once; indexing an HF dataset presumably decodes the row each time.
        pair = self.data[idx]['translation']
        src_text = [str(pair[self.src_text_id])]
        tgt_text = [str(pair[self.tgt_text_id])]

        src_enc = self.tokenizer(src_text, max_length=self.max_length, padding='max_length', truncation=True)
        # Target text is tokenised inside the tokenizer's target-language mode.
        with self.tokenizer.as_target_tokenizer():
            tgt_enc = self.tokenizer(tgt_text, max_length=self.max_length, padding='max_length', truncation=True)

        # The tokenizer was given a batch of one text; squeeze(0) drops that batch axis.
        return {
            "src_input_ids": torch.tensor(src_enc['input_ids']).squeeze(0),
            "src_attention_mask": torch.tensor(src_enc['attention_mask']).squeeze(0),
            "tgt": torch.tensor(tgt_enc['input_ids']).squeeze(0),
        }

    def __len__(self):
        return self.len

In [None]:
# Every sequence is padded/truncated to 24 tokens. Training exposes only the first 300,000
# pairs of the train split; validation uses all 1,999 validation pairs.
train_dataset = TransformerDataset(
    dataset['train'], 'en', 'ro', tokenizer, max_length=24, len=300000
)
val_dataset = TransformerDataset(
    dataset['validation'], 'en', 'ro', tokenizer, max_length=24, len=1999
)

## Initializing the model and trainer

In [None]:
# Create Trainer 
def create_trainer(model, optimizer, criterion, with_amp, train_sampler, logger,
                   output_path=None, save_every_iters=100):
    """Build the ignite training ``Engine`` for seq2seq fine-tuning.

    Args:
        model: seq2seq model called as ``model(input_ids, attention_mask, labels=...)``;
            its output must contain a ``'logits'`` entry.
        optimizer: optimizer stepped once per iteration.
        criterion: token-level loss applied to flattened logits and flattened target ids.
        with_amp: enable ``autocast`` and ``GradScaler`` loss scaling.
        train_sampler: sampler of the training loader, handed to
            ``common.setup_common_training_handlers``.
        logger: logger assigned to the trainer.
        output_path: optional directory for checkpoints of trainer/model/optimizer.
            ``None`` (default) configures no checkpointing.
        save_every_iters: checkpoint frequency in iterations; used only when
            ``output_path`` is given.

    Returns:
        The trainer ``Engine``. Each iteration outputs ``{"batch loss": float}``.
    """
    device = idist.device()
    scaler = GradScaler(enabled=with_amp)

    def train_step(engine, batch):
        src_ids = batch["src_input_ids"]
        src_attention_mask = batch["src_attention_mask"]
        tgt = batch["tgt"]

        if src_ids.device != device:
            src_ids = src_ids.to(device, non_blocking=True, dtype=torch.long)
            src_attention_mask = src_attention_mask.to(device, non_blocking=True, dtype=torch.long)
            tgt = tgt.to(device, non_blocking=True, dtype=torch.long)

        model.train()

        with autocast(enabled=with_amp):
            # ``labels`` is passed presumably so the model derives its decoder inputs from
            # the target. The optimised loss is ``criterion`` (in this notebook a
            # CrossEntropyLoss ignoring pad ids), not the model's built-in loss.
            logits = model(src_ids, src_attention_mask, labels=tgt)["logits"]
            loss = criterion(logits.view(-1, logits.size(-1)), tgt.contiguous().view(-1))

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        return {
            "batch loss": loss.item(),
        }

    trainer = Engine(train_step)
    trainer.logger = logger

    # Checkpointing is optional. Ignite requires an output location whenever ``to_save``
    # is given, so the objects are only registered when ``output_path`` is provided.
    checkpoint_kwargs = {}
    if output_path is not None:
        checkpoint_kwargs = {
            "to_save": {"trainer": trainer, "model": model, "optimizer": optimizer},
            "save_every_iters": save_every_iters,
            "output_path": output_path,
        }

    common.setup_common_training_handlers(
        trainer=trainer,
        train_sampler=train_sampler,
        output_names=["batch loss"],
        clear_cuda_cache=False,
        with_pbars=True,
        **checkpoint_kwargs,
    )
    return trainer

In [None]:
# Let's now setup evaluator engine to perform model's validation and compute metrics
def create_evaluator(model, tokenizer, metrics, with_amp, tag="val", debug_every=0):
    """Build an ignite ``Engine`` that translates batches and feeds ``metrics``.

    Each step generates output forced to start with the ``ro_RO`` language code. It returns
    ``(preds, refs)``, where ``preds`` is a list of whitespace-token lists and ``refs`` is a
    list of single-reference lists (``[[tokens], ...]``).

    Args:
        model: seq2seq model providing ``generate``. It may be wrapped by
            ``idist.auto_model`` in ``DataParallel`` / ``DistributedDataParallel``.
        tokenizer: used to decode ids and to look up the ``ro_RO`` language code.
        metrics: mapping ``name -> ignite Metric`` attached to the evaluator.
        with_amp: accepted for symmetry with ``create_trainer``; currently unused.
        tag: label shown in the progress-bar description.
        debug_every: if > 0, print predictions, targets and an NLTK corpus BLEU every
            ``debug_every`` iterations. ``0`` (default) disables this debug output.

    Returns:
        The evaluator ``Engine``.
    """
    device = idist.device()

    def ids_to_clean_text(generated_ids):
        """Decode ids to text, dropping special tokens and surrounding whitespace."""
        gen_text = tokenizer.batch_decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
        return [text.strip() for text in gen_text]

    @torch.no_grad()
    def evaluate_step(engine, batch):
        model.eval()

        src_ids = batch["src_input_ids"]
        src_attention_mask = batch["src_attention_mask"]
        tgt = batch["tgt"]

        if src_ids.device != device:
            src_ids = src_ids.to(device, non_blocking=True, dtype=torch.long)
            src_attention_mask = src_attention_mask.to(device, non_blocking=True, dtype=torch.long)
            tgt = tgt.to(device, non_blocking=True, dtype=torch.long)

        # DataParallel/DDP wrappers (created by idist.auto_model) do not expose ``generate``;
        # call it on the wrapped module instead.
        if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
            generator = model.module
        else:
            generator = model

        # Pass the source mask so padded positions are explicitly masked during generation.
        y_pred = generator.generate(
            src_ids,
            attention_mask=src_attention_mask,
            forced_bos_token_id=tokenizer.lang_code_to_id["ro_RO"],
        )
        preds = [text.split() for text in ids_to_clean_text(y_pred)]
        refs = [[text.split()] for text in ids_to_clean_text(tgt)]

        if debug_every and engine.state.iteration % debug_every == 0:
            print("Preds : ", preds)
            print("Target : ", refs)
            print(nltk.translate.bleu_score.corpus_bleu(refs, preds))

        return preds, refs

    evaluator = Engine(evaluate_step)

    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    if idist.get_rank() == 0:
        common.ProgressBar(desc=f"Evaluation ({tag})", persist=False).attach(evaluator)

    return evaluator


In [None]:
def initialize(learning_rate=5e-5):
    """Create the model, optimizer and loss used for fine-tuning.

    Args:
        learning_rate: base AdamW learning rate, multiplied by the world size. The
            previously hard-coded 5e-2 is orders of magnitude too large for fine-tuning
            a pretrained transformer with AdamW; 5e-5 is a conventional value.

    Returns:
        ``(model, optimizer, criterion)``. The model and optimizer are adapted to the
        distributed configuration by ``idist.auto_model`` / ``idist.auto_optim``.
        ``criterion`` is a summed cross-entropy that ignores the tokenizer's pad id.
    """
    model = MBartForConditionalGeneration.from_pretrained(model_name)
    lr = learning_rate * idist.get_world_size()
    model = idist.auto_model(model)
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    optimizer = idist.auto_optim(optimizer)
    # Targets are padded to max_length, so pad positions must not contribute to the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=train_dataset.tokenizer.pad_token_id, reduction='sum')
    return model, optimizer, criterion


In [None]:
def get_dataloaders(train_dataset, val_dataset, batch_size=120, num_workers=2):
    """Wrap both datasets in loaders adapted to the distributed config (nccl, gloo, xla-tpu).

    Returns:
        ``(train_loader, val_loader)``. Training batches are shuffled and the last
        incomplete batch is dropped; validation keeps its order and every sample.
    """
    shared_kwargs = {"batch_size": batch_size, "num_workers": num_workers}
    train_loader = idist.auto_dataloader(train_dataset, **shared_kwargs, shuffle=True, drop_last=True)
    val_loader = idist.auto_dataloader(val_dataset, **shared_kwargs, shuffle=False)
    return train_loader, val_loader


In [None]:
def log_metrics(logger, epoch, elapsed, tag, metrics):
    """Log one evaluation summary: epoch, elapsed seconds, and one tab-indented line per metric."""
    metric_lines = (f"\t{name}: {value}" for name, value in metrics.items())
    header = f"\nEpoch {epoch} - Evaluation time (seconds): {elapsed:.2f} - {tag} metrics:\n "
    logger.info(header + "\n".join(metric_lines))

In [None]:
def training(local_rank):
    """Per-process training entry point, launched by ``idist.Parallel`` in ``run()``.

    Seeds the process, builds loaders, model, trainer and a validation evaluator that
    reports BLEU after every epoch, creates a timestamped output folder on rank 0, and
    runs training for ``num_epochs`` epochs.

    Args:
        local_rank: process rank supplied by the launcher; used for logger setup.
    """
    rank = idist.get_rank()
    manual_seed(seed + rank)  # distinct but reproducible seed for every process

    logger = setup_logger(name="NMT", distributed_rank=local_rank)

    train_loader, val_loader = get_dataloaders(train_dataset, val_dataset)
    model, optimizer, criterion = initialize()

    trainer = create_trainer(model, optimizer, criterion, with_amp, train_loader.sampler, logger)

    # The evaluator returns (tokenised predictions, [[tokenised reference]]) per batch,
    # which is the input format of ignite's Bleu metric. Metric instances hold state, so
    # they are attached to this single evaluator only.
    metrics = {"bleu": Bleu(ngram=4, smooth="smooth1", average="micro")}
    evaluator = create_evaluator(model, tokenizer, metrics, with_amp, tag="val")

    @trainer.on(Events.EPOCH_COMPLETED(every=1))
    def run_validation(engine):
        epoch = trainer.state.epoch
        state = evaluator.run(val_loader)
        log_metrics(logger, epoch, state.times["COMPLETED"], "Validation", state.metrics)

    if rank == 0:
        now = datetime.now().strftime("%Y%m%d-%H%M%S")
        folder_name = f"Translation_Model_backend-{idist.backend()}-{idist.get_world_size()}_{now}"
        output_path = Path(output_path_) / folder_name
        output_path.mkdir(parents=True, exist_ok=True)
        logger.info(f"Output path: {output_path}")

    try:
        trainer.run(train_loader, max_epochs=num_epochs)
    except Exception:
        logger.exception("")
        raise

In [None]:
def run():
    """Launch ``training`` through ignite's ``idist.Parallel`` launcher.

    No backend and no process count are requested (both ``None``), so ``training`` runs in
    the current process (the log shows "in 1 processes").
    """
    launcher = idist.Parallel(backend=None, nproc_per_node=None)
    with launcher as parallel:
        parallel.run(training)


run()

2021-08-30 06:03:28,751 ignite.distributed.launcher.Parallel INFO: - Run '<function training at 0x7f2b0185bef0>' in 1 processes
2021-08-30 06:03:28,793 ignite.distributed.auto.auto_dataloader INFO: Use data loader kwargs for dataset '<__main__.Transforme': 
	{'batch_size': 120, 'num_workers': 2, 'shuffle': True, 'drop_last': True, 'pin_memory': True}
2021-08-30 06:03:28,794 ignite.distributed.auto.auto_dataloader INFO: Use data loader kwargs for dataset '<__main__.Transforme': 
	{'batch_size': 120, 'num_workers': 2, 'shuffle': False, 'pin_memory': True}
2021-08-30 06:03:33,032 NMT INFO: Output path: /content/Translation_Model_backend-None-1_20210830-060333
2021-08-30 06:03:33,035 NMT INFO: Engine run starting with max_epochs=5.


 20%|##        | 1/5 [00:00<?, ?it/s]

[1/2500]   0%|           [00:00<?]

In [None]:
# Inspect the GPU(s) visible to this runtime (only works on an NVIDIA GPU runtime).
!nvidia-smi

In [None]:
# Qualitative check: translate one English validation sentence into Romanian.
# NOTE: `training()` keeps its model local, so the fine-tuned weights are not available here.
# This loads the checkpoint named by `model_name`; load a saved fine-tuned checkpoint
# instead to inspect the trained model.
demo_device = idist.device()
translation_model = MBartForConditionalGeneration.from_pretrained(model_name).to(demo_device).eval()

example = dataset["validation"][0]["translation"]
encoded = tokenizer([example["en"]], return_tensors="pt", padding=True).to(demo_device)
with torch.no_grad():
    generated_ids = translation_model.generate(
        **encoded, forced_bos_token_id=tokenizer.lang_code_to_id["ro_RO"]
    )
translation = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

print("Source   :", example["en"])
print("Reference:", example["ro"])
print("Predicted:", translation)