<a href="https://colab.research.google.com/github/Ishan-Kumar2/examples/blob/mt_example/Examples/Machine_Translation_using_PyTorch_Ignite.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# installing appropriate modules
%%capture
!pip install git+https://github.com/huggingface/transformers.git@master;
!pip install git+https://github.com/huggingface/datasets.git@master;
!pip install git+https://github.com/pytorch/ignite.git@master;
!pip install sentencepiece;

In [None]:
## Uncomment if using TPU
# setup TPU environment
# import os
# assert os.environ['COLAB_TPU_ADDR'], 'Make sure to select TPU from Edit > Notebook settings > Hardware accelerator'

# VERSION = "nightly"
# !curl -q https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
# !python pytorch-xla-env-setup.py --version $VERSION > /dev/null

In [None]:
from pathlib import Path
from datasets import load_dataset
from transformers import MBartForConditionalGeneration, MBartTokenizer
from datetime import datetime

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import random_split
from torch.cuda.amp import GradScaler, autocast

import nltk
import ignite
import ignite.distributed as idist
from ignite.contrib.engines import common
from ignite.metrics import Loss, RougeN, Bleu, Accuracy, Loss
from ignite.utils import manual_seed, setup_logger
from ignite.engine import Engine, Events
from ignite.contrib.handlers import PiecewiseLinear
from ignite.handlers import Checkpoint, global_step_from_engine

In [None]:
# Run configuration: AMP toggle, seeding, training length, batch size,
# output directory and the Hugging Face checkpoint/tokenizer names.
# NOTE(review): `checkpoint_every` is not read anywhere in the visible notebook.
config = dict(
    with_amp=True,
    seed=126,
    num_epochs=5,
    batch_size=1,
    output_path_='/content',
    checkpoint_every=100,
    model_name="facebook/mbart-large-cc25",
    tokenizer_name="facebook/mbart-large-cc25",
)

# Data configuration: translation direction (keys of each example's
# "translation" dict), tokenized sequence length, and how many examples
# each split exposes.
dataset_configs = dict(
    source_language='en',
    target_language='de',
    max_length=12,
    train_dataset_length=100000,
    validation_dataset_length=100,
)

## Preparing data

We will be using the German–English (`de-en`) configuration of the `news_commentary` dataset for this example, translating from English into German.

In [None]:
# Download the English-German News Commentary corpus, then shuffle every split
# with the configured seed so the example order is reproducible.
from datasets import load_dataset

raw_datasets = load_dataset("news_commentary", 'de-en')
dataset = raw_datasets.shuffle(seed=config["seed"])

In [None]:
# news_commentary only ships a "train" split, so hold out 30% of it for validation.
# Passing the seed makes the split reproducible across runs (train_test_split
# otherwise draws a fresh random split every time). New names keep the cell
# idempotent: re-running it no longer re-splits an already-split dataset.
full_train_split = dataset["train"]
train_val_splits = full_train_split.train_test_split(test_size=0.3, seed=config["seed"])
train_dataset, validation_dataset = train_val_splits["train"], train_val_splits["test"]

In [None]:
# Display one raw training example. TransformerDataset below reads
# example['translation'][<language code>], so expect a dict of that shape.
train_dataset[0]

In [None]:
# Show one raw example and the size of each split.
print("Example of a Datapoint")
print(train_dataset[0])
print("Lengths")
for split_label, split_data in (("Train Set", train_dataset), ("Val Set", validation_dataset)):
    print(f"\t {split_label} - {len(split_data)}")

In [None]:
# mBART tokenizer for English source text ("en_XX") and German target text ("de_DE"),
# matching dataset_configs' en -> de direction.
tokenizer = MBartTokenizer.from_pretrained(config["tokenizer_name"], src_lang="en_XX", tgt_lang="de_DE")

In [None]:
class TransformerDataset(torch.utils.data.Dataset):
    """Map-style dataset that tokenizes translation pairs on the fly.

    Args:
        data: indexable collection whose items look like
            ``{"translation": {<language code>: <sentence>, ...}}``.
        src_text_id: key of the source sentence inside ``item["translation"]``.
        tgt_text_id: key of the target sentence inside ``item["translation"]``.
        tokenizer: callable tokenizer returning ``input_ids``/``attention_mask``
            and providing an ``as_target_tokenizer()`` context manager.
        max_length: every sequence is padded/truncated to exactly this many tokens.
        len: number of examples to expose (the first ``len`` items of ``data``).
            Values larger than ``len(data)`` are clamped, because indices past the
            end of ``data`` would raise once the sampler reached them. ``None``
            exposes all of ``data``.
    """

    def __init__(self, data, src_text_id, tgt_text_id, tokenizer, max_length, len=None):
        import builtins  # the ``len`` parameter shadows the builtin inside this method

        self.data = data
        self.src_text_id = src_text_id
        self.tgt_text_id = tgt_text_id
        self.tokenizer = tokenizer
        self.max_length = max_length
        available = builtins.len(data)
        self.len = available if len is None else max(0, min(len, available))

    def __getitem__(self, idx):
        """Return fixed-length source ids, source attention mask and target ids (1-D tensors)."""
        pair = self.data[idx]['translation']
        # The tokenizer is called on a one-element batch; squeeze(0) drops that batch dim.
        src_enc = self.tokenizer([str(pair[self.src_text_id])], max_length=self.max_length,
                                 padding='max_length', truncation=True)
        # Targets are encoded while the tokenizer is in its target mode.
        with self.tokenizer.as_target_tokenizer():
            tgt_enc = self.tokenizer([str(pair[self.tgt_text_id])], max_length=self.max_length,
                                     padding='max_length', truncation=True)

        return {
            "src_input_ids": torch.tensor(src_enc['input_ids']).squeeze(0),
            "src_attention_mask": torch.tensor(src_enc['attention_mask']).squeeze(0),
            "tgt": torch.tensor(tgt_enc['input_ids']).squeeze(0),
        }

    def __len__(self):
        return self.len

In [None]:

# Wrap the raw splits so each example is tokenized on access.
# NOTE(review): this rebinds `train_dataset`; re-running this cell alone would wrap the
# already-wrapped dataset, so re-run the split cell first.
_src_lang = dataset_configs["source_language"]
_tgt_lang = dataset_configs["target_language"]
_max_len = dataset_configs["max_length"]
train_dataset = TransformerDataset(
    train_dataset, _src_lang, _tgt_lang, tokenizer, _max_len, dataset_configs["train_dataset_length"]
)
val_dataset = TransformerDataset(
    validation_dataset, _src_lang, _tgt_lang, tokenizer, _max_len, dataset_configs["validation_dataset_length"]
)

## Initializing the model, trainer and evaluator

In [None]:
# Create Trainer 
def create_trainer(model, optimizer, criterion, with_amp, train_sampler, logger, accumulation_steps=8):
    """Build the ignite training engine for seq2seq fine-tuning.

    Args:
        model: seq2seq model (here MBartForConditionalGeneration, possibly wrapped
            by ``idist.auto_model``) whose output exposes ``"logits"``.
        optimizer: optimizer over ``model``'s parameters.
        criterion: token-level loss applied to flattened logits vs flattened targets.
        with_amp: enable ``torch.cuda.amp`` autocast and GradScaler.
        train_sampler: forwarded to ``common.setup_common_training_handlers``.
        logger: logger attached to the trainer.
        accumulation_steps: iterations whose gradients are accumulated before one
            optimizer step. Each step's loss is divided by this value.

    Returns:
        An ignite ``Engine`` whose step output is ``{"batch loss": float}``.
    """
    device = idist.device()
    scaler = GradScaler(enabled=with_amp)

    def train_step(engine, batch):
        src_ids = batch["src_input_ids"]
        src_attention_mask = batch["src_attention_mask"]
        tgt = batch["tgt"]

        if src_ids.device != device:
            src_ids = src_ids.to(device, non_blocking=True, dtype=torch.long)
            src_attention_mask = src_attention_mask.to(device, non_blocking=True, dtype=torch.long)
            tgt = tgt.to(device, non_blocking=True, dtype=torch.long)

        model.train()

        with autocast(enabled=with_amp):
            # Teacher forcing: with ``labels`` the model builds its decoder inputs by
            # shifting the *target* right. Without labels/decoder_input_ids, BART-family
            # models derive decoder inputs from ``input_ids`` (the source), which does not
            # train translation. The model's own internal loss is ignored; ``criterion``
            # (which skips padding) is used instead.
            y = model(input_ids=src_ids, attention_mask=src_attention_mask, labels=tgt)
            y_pred = y['logits']
            y_pred = y_pred.reshape(-1, y_pred.size(-1))
            loss = criterion(y_pred, tgt.contiguous().view(-1)) / accumulation_steps

        scaler.scale(loss).backward()

        # Step only every `accumulation_steps` iterations; gradients accumulate in between.
        if engine.state.iteration % accumulation_steps == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        return {
            "batch loss": loss.item(),
        }

    trainer = Engine(train_step)
    trainer.logger = logger

    metric_names = [
        "batch loss",
    ]
    common.setup_common_training_handlers(
        trainer=trainer,
        train_sampler=train_sampler,
        output_names=metric_names,
        clear_cuda_cache=False,
        with_pbars=True,
    )
    return trainer

In [None]:
# Let's now setup evaluator engine to perform model's validation and compute metrics
def create_evaluator(model, tokenizer, metrics, with_amp, tag="val"):
    """Build an ignite engine that translates batches with ``generate`` and scores them.

    Args:
        model: seq2seq model, possibly wrapped by ``idist.auto_model``
            (DataParallel/DDP); the underlying module is used for ``generate``.
        tokenizer: tokenizer used to decode ids. Its ``tgt_lang`` code is the
            decoder start token.
        metrics: mapping name -> ignite Metric to attach. Metric objects keep
            per-engine state, so pass fresh instances to each evaluator.
        with_amp: run generation under ``torch.cuda.amp`` autocast.
        tag: label shown on the progress bar.

    Returns:
        An ignite ``Engine`` whose step output is ``(preds, refs)``. ``preds`` is a list
        of token lists and ``refs`` is a list of single-reference lists, which is the
        input format of ignite's ``Bleu``.
    """
    from ignite.contrib.handlers import ProgressBar

    device = idist.device()
    # mBART decodes starting from the target-language code token. This was previously
    # hard-coded to "ro_RO" although the data and tokenizer target German.
    tgt_lang_token_id = tokenizer.convert_tokens_to_ids(tokenizer.tgt_lang)

    def ids_to_clean_text(generated_ids):
        gen_text = tokenizer.batch_decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
        return [text.strip() for text in gen_text]

    @torch.no_grad()
    def evaluate_step(engine, batch):
        model.eval()

        src_ids = batch["src_input_ids"]
        src_attention_mask = batch["src_attention_mask"]
        tgt = batch["tgt"]

        if src_ids.device != device:
            src_ids = src_ids.to(device, non_blocking=True, dtype=torch.long)
            src_attention_mask = src_attention_mask.to(device, non_blocking=True, dtype=torch.long)
            tgt = tgt.to(device, non_blocking=True, dtype=torch.long)

        # DataParallel/DDP wrappers do not expose `generate`; use the wrapped module.
        generator = getattr(model, "module", model)
        with autocast(enabled=with_amp):
            y_pred = generator.generate(
                src_ids,
                attention_mask=src_attention_mask,
                decoder_start_token_id=tgt_lang_token_id,
            )

        preds = [text.split() for text in ids_to_clean_text(y_pred)]
        refs = [[text.split()] for text in ids_to_clean_text(tgt)]

        if engine.state.iteration % 20 == 0:
            print("Preds : ", preds)
            print("Target : ", refs)

        return preds, refs

    evaluator = Engine(evaluate_step)

    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    if idist.get_rank() == 0:
        ProgressBar(desc=f"Evaluation ({tag})", persist=False).attach(evaluator)

    return evaluator


In [None]:
def initialize():
    """Create the model, optimizer and loss used for fine-tuning.

    The model and optimizer are adapted to the current distributed setup via
    ``idist.auto_model`` / ``idist.auto_optim``.

    Returns:
        ``(model, optimizer, criterion)``. ``criterion`` is a summed cross-entropy
        that ignores padding-token positions.
    """
    pretrained = MBartForConditionalGeneration.from_pretrained(config["model_name"])
    # The base learning rate is scaled linearly with the number of processes.
    learning_rate = 5e-5 * idist.get_world_size()

    model = idist.auto_model(pretrained)
    optimizer = idist.auto_optim(optim.AdamW(model.parameters(), lr=learning_rate))

    pad_id = train_dataset.tokenizer.pad_token_id
    criterion = nn.CrossEntropyLoss(ignore_index=pad_id, reduction='sum')

    # TODO: an LR scheduler (ignite PiecewiseLinear warmup/decay) was sketched here; it
    # needs an iterations-per-epoch value and warm-up settings that `config` does not define.
    return model, optimizer, criterion


In [None]:
def get_dataloaders(train_dataset, val_dataset, batch_size=None, num_workers=2):
    """Create train/validation loaders adapted to the distributed config (nccl, gloo, xla-tpu).

    Args:
        train_dataset: dataset for training; shuffled, last incomplete batch dropped.
        val_dataset: dataset for validation; kept in order.
        batch_size: batch size passed to ``idist.auto_dataloader``. ``None`` (default)
            reads ``config["batch_size"]`` when the function is called. A default of
            ``config["batch_size"]`` would be frozen when this cell runs and would
            silently ignore later edits to ``config``.
        num_workers: DataLoader worker processes.

    Returns:
        ``(train_loader, val_loader)``.
    """
    if batch_size is None:
        batch_size = config["batch_size"]

    train_loader = idist.auto_dataloader(
        train_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True, drop_last=True,
    )

    val_loader = idist.auto_dataloader(
        val_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False,
    )
    return train_loader, val_loader


In [None]:
def log_metrics(logger, epoch, elapsed, tag, metrics):
    """Log one multi-line ``logger.info`` record: a header line, then one tab-indented
    ``name: value`` line per metric."""
    metric_lines = []
    for metric_name, metric_value in metrics.items():
        metric_lines.append(f"\t{metric_name}: {metric_value}")
    report = "\n".join(metric_lines)
    header = f"\nEpoch {epoch} - Evaluation time (seconds): {elapsed:.2f} - {tag} metrics:\n "
    logger.info(header + report)

In [None]:
def training(local_rank):
    """Per-process training entry point launched by ``idist.Parallel``.

    Seeds the process and builds loaders, model, trainer and a validation evaluator.
    It runs validation at the end of every epoch, trains for ``config["num_epochs"]``
    and logs the trainer's final metrics. Any exception is logged, then re-raised.

    Args:
        local_rank: process-local rank supplied by ``idist.Parallel``; used for the logger.
    """
    rank = idist.get_rank()
    manual_seed(config["seed"] + rank)

    logger = setup_logger(name="NMT", distributed_rank=local_rank)

    train_loader, val_loader = get_dataloaders(train_dataset, val_dataset)
    model, optimizer, criterion = initialize()

    trainer = create_trainer(model, optimizer, criterion, config["with_amp"], train_loader.sampler, logger)

    # Metric objects hold per-engine state, so never share one instance between engines.
    # A train-set evaluator that shared this dict was built but never run; if one is added
    # back, give it its own Bleu instance.
    metrics = {
        "bleu": Bleu(ngram=4, smooth="smooth1"),
    }
    evaluator = create_evaluator(model, tokenizer, metrics, config["with_amp"], tag="val")

    @trainer.on(Events.EPOCH_COMPLETED(every=1))
    def run_validation(engine):
        epoch = engine.state.epoch
        state = evaluator.run(val_loader)
        log_metrics(logger, epoch, state.times["COMPLETED"], "Validation", state.metrics)

    if rank == 0:
        now = datetime.now().strftime("%Y%m%d-%H%M%S")
        folder_name = f"Translation_Model_backend-{idist.backend()}-{idist.get_world_size()}_{now}"
        output_path = Path(config["output_path_"]) / folder_name
        output_path.mkdir(parents=True, exist_ok=True)
        # NOTE(review): output_path is only logged; no checkpoints are written to it yet.
        logger.info(f"Output path: {output_path}")

    try:
        state = trainer.run(train_loader, max_epochs=config["num_epochs"])
        # `epoch` only exists inside run_validation; report the trainer's final epoch.
        log_metrics(logger, state.epoch, state.times["COMPLETED"], "Training", state.metrics)
    except Exception:
        logger.exception("")
        raise

In [None]:
def run(backend=None, nproc_per_node=None):
    """Launch ``training`` through ``idist.Parallel``.

    Args:
        backend: distributed backend such as "nccl", "gloo" or "xla-tpu".
            ``None`` (default) runs without a distributed configuration in this process.
        nproc_per_node: processes to spawn per node when a backend is given.
    """
    with idist.Parallel(backend=backend, nproc_per_node=nproc_per_node) as parallel:
        parallel.run(training)


run()

In [None]:
!nvidia-smi

In [None]:
# Let's see how the model translates an English sentence into German.
# NOTE: `training()` runs inside `idist.Parallel` and neither returns nor saves the
# fine-tuned weights, so this cell loads the checkpoint named in `config["model_name"]`.
# Point `from_pretrained` at a saved fine-tuned checkpoint to inspect training results.
# The original cell called an undefined `translator` on Hindi text, which this en->de
# setup does not handle.
inputs = "Highlight the last entered event"

demo_device = idist.device()
demo_model = MBartForConditionalGeneration.from_pretrained(config["model_name"]).to(demo_device).eval()
encoded = tokenizer(inputs, return_tensors="pt").to(demo_device)
with torch.no_grad():
    generated_ids = demo_model.generate(
        **encoded,
        # Decoding starts from the target-language code, as in the evaluator.
        decoder_start_token_id=tokenizer.convert_tokens_to_ids(tokenizer.tgt_lang),
    )
translation = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
print(translation)