# Transformer training notebook

**Author:** [Giovanni Spadaro](https://giovannispadaro.it/)<br>
**Project:** [tfs_mt](https://github.com/Giovo17/tfs-mt)<br>
**Documentation:** [link](https://giovo17.github.io/tfs-mt)

In [None]:
import os

# Detect the execution platform so later cells can choose paths, secrets and
# dependency installs accordingly.
if os.path.exists('/kaggle'):
    PLATFORM = "KAGGLE"
else:
    try:
        import google.colab  # noqa: F401  (imported only as a Colab probe)
        PLATFORM = "COLAB"
    except ImportError:
        # Catch only import failures: a bare `except:` would also swallow
        # KeyboardInterrupt/SystemExit and unrelated errors.
        PLATFORM = "LOCAL"
print(f"Platform: {PLATFORM}")

In [None]:
# On hosted platforms (Colab/Kaggle), install or upgrade wandb quietly (-q) to the latest release (-U).
# A local environment presumably already provides it.
# NOTE: `!pip` is IPython shell syntax, so this cell only runs inside a notebook kernel.
if PLATFORM == "COLAB" or PLATFORM == "KAGGLE":
    !pip install -qU wandb --progress-bar off

In [None]:
# Standard-library and third-party imports used by the training cells below.
from pprint import pformat
from functools import partial
from datetime import datetime

from torch.optim import AdamW
from torch.nn import CrossEntropyLoss
import ignite.distributed as idist
from ignite.engine import Events
from ignite.handlers import PiecewiseLinear
from ignite.metrics import Bleu, Rouge, Loss
from ignite.utils import manual_seed
from omegaconf import OmegaConf

# Project modules. On Kaggle they are imported from separately named top-level
# modules (presumably uploaded as Kaggle utility scripts — confirm). Everywhere
# else they come from the `tfs_mt` package.
# NOTE(review): the star imports are presumably what bring the helpers used later
# into scope: setup_output_dir, save_config, setup_trainer, setup_evaluator,
# setup_logging, setup_handlers, setup_exp_logging, log_metrics,
# nlp_metric_transform and loss_metric_transform.
# They also appear to provide `torch`, which later cells use without importing it.
# TODO: replace the star imports with explicit ones.
if PLATFORM == "KAGGLE":
    from tfs_mt_architecture.tfs_mt_architecture  import build_model
    from tfs_mt_data_utils.tfs_mt_data_utils import build_data_utils
    from tfs_mt_training_utils.tfs_mt_training_utils import *
else:
    from tfs_mt.architecture  import build_model
    from tfs_mt.data_utils import build_data_utils
    from tfs_mt.training_utils import *
    # Move to the parent directory. Later cells treat os.getcwd() as the project
    # root (base_path, and the config at "tfs_mt/configs/config.yml").
    # NOTE(review): re-running this cell moves up one more directory each time.
    os.chdir("..")

In [None]:
# Device used for single-process (non-distributed) training. In distributed runs,
# `run()` asks Ignite for the per-process device via idist.device() instead.
# Import torch explicitly rather than relying on it leaking through the star
# imports of the previous cell.
import torch

if torch.cuda.is_available():
    device_ = torch.device("cuda:0")
    # Release cached allocator blocks, e.g. left over from a previous run in this kernel.
    torch.cuda.empty_cache()
    print("Using NVIDIA GPU")
    print(f"Number of available GPUs: {torch.cuda.device_count()}")
else:
    device_ = torch.device("cpu")
    print("Using CPU")

In [None]:
# Platform-specific secrets, file locations and training time budget.
os.environ["TOKENIZERS_PARALLELISM"] = "false"  # remove tokenizer parallelism warning

if PLATFORM == "KAGGLE":
    from kaggle_secrets import UserSecretsClient

    user_secrets = UserSecretsClient()
    wandb_api_key = user_secrets.get_secret("WANDB_API_KEY")
    os.environ["WANDB_API_KEY"] = wandb_api_key

    config_path = "/kaggle/input/tfs-mt-config/config.yml"
    base_path = "/kaggle/working"
    output_dir = "/kaggle/working/output"
    cache_ds_path = "/kaggle/working/data"

    # 11.5 hours = 11.5 * 3600 s. The previous value, 414000, was 115 hours.
    time_limit_sec = 41400

elif PLATFORM == "COLAB":
    # Colab paths have not been configured. Fail here with a clear message
    # rather than with a NameError on `config_path` in the next cell.
    raise NotImplementedError(
        "Colab is not configured: set config_path, base_path, output_dir, "
        "cache_ds_path and time_limit_sec for this platform."
    )

else:
    from dotenv import load_dotenv

    load_dotenv()
    wandb_api_key = os.getenv("WANDB_API_KEY")
    # os.environ only accepts str values; assigning None would raise TypeError.
    if wandb_api_key is not None:
        os.environ["WANDB_API_KEY"] = wandb_api_key
    else:
        print("WANDB_API_KEY not found in environment or .env file")

    config_path = os.path.join(os.getcwd(), "tfs_mt/configs/config.yml")
    base_path = os.getcwd()
    output_dir = os.path.join(base_path, "data/output")
    cache_ds_path = os.path.join(base_path, "data")
    time_limit_sec = -1  # presumably "no time limit" for local runs — confirm in setup_handlers

In [None]:
import torch

config = OmegaConf.load(config_path)

config.training_hp.distributed_training = True  # Enable/disable distributed training

# NCCL requires CUDA devices. Fall back to Gloo so that distributed mode does not
# fail at process-group initialisation on CPU-only machines.
if config.training_hp.distributed_training:
    config.backend = "nccl" if torch.cuda.is_available() else "gloo"
else:
    config.backend = "none"

config.base_path = base_path
config.output_dir = output_dir
config.cache_ds_path = cache_ds_path
config.chosen_model_size = "nano"  # nano, small, base
config.time_limit_sec = time_limit_sec

# Unique run name: <base name>_<size>_<yymmdd-HHMM>.
config.model_name = f"{config.model_base_name}_{config.chosen_model_size}_{datetime.now().strftime('%y%m%d-%H%M')}"

In [None]:
def run(local_rank, config, distributed=False):
    """Build data, model and Ignite engines from ``config``, then train and evaluate.

    Can be called directly (single process) or through
    ``idist.Parallel(...).run``, which passes the process-local rank as the
    first argument.

    Args:
        local_rank: Process-local rank supplied by Ignite's launcher. Not used
            directly; the global rank is read from ``idist.get_rank()``.
        config: OmegaConf config. It is mutated in place: ``output_dir``,
            ``num_iters_per_epoch`` and, in distributed mode, the optimizer
            learning rate (multiplied by the world size).
        distributed: Whether this call runs inside an Ignite distributed context.
    """
    # Seed each process (offset by rank in distributed mode) and create the
    # output directory. Only rank 0 writes the config to disk.
    if distributed:
        rank = idist.get_rank()
        manual_seed(config.seed + rank)
        output_dir = setup_output_dir(config, rank)
        config.output_dir = output_dir
        if rank == 0:
            save_config(config, output_dir)
    else:
        rank = 0
        manual_seed(config.seed)
        output_dir = setup_output_dir(config, 0)
        config.output_dir = output_dir
        save_config(config, config.output_dir)

    # Data loaders and tokenizers. The two other return values are not needed here.
    train_dataloader, test_dataloader, _, _, src_tokenizer, tgt_tokenizer = build_data_utils(
        config, return_all=True
    )
    config.num_iters_per_epoch = len(train_dataloader)

    # In distributed mode Ignite selects this process's device. Otherwise use the
    # device chosen earlier in the notebook.
    device = idist.device() if distributed else device_
    # auto_model moves the model to the current device and wraps it for
    # multi-GPU training when applicable.
    model = idist.auto_model(build_model(config, src_tokenizer, tgt_tokenizer))

    if distributed:
        # Scale the LR with the number of processes (presumably to match the
        # larger global batch).
        config.training_hp.optimizer_args.learning_rate *= idist.get_world_size()
    optimizer = idist.auto_optim(
        AdamW(
            model.parameters(),
            lr=config.training_hp.optimizer_args.learning_rate,
            weight_decay=config.training_hp.optimizer_args.weight_decay,
        )
    )
    loss_fn = CrossEntropyLoss(label_smoothing=config.training_hp.loss_label_smoothing).to(device=device)

    # LR schedule, stepped every iteration (see the handler registered below):
    # - linear warmup from 0 to the peak LR over `num_warmup_epochs`;
    # - linear decay back to 0 at the end of `num_epochs`.
    iters_per_epoch = config.num_iters_per_epoch
    peak_lr = config.training_hp.optimizer_args.learning_rate
    lr_scheduler = PiecewiseLinear(
        optimizer,
        param_name="lr",
        milestones_values=[
            (0, 0.0),
            (iters_per_epoch * config.training_hp.num_warmup_epochs, peak_lr),
            (iters_per_epoch * config.training_hp.num_epochs, 0.0),
        ],
    )

    # Metrics attached to the evaluator. BLEU and ROUGE share the same decoding transform.
    nlp_transform = partial(nlp_metric_transform, tgt_tokenizer=tgt_tokenizer)
    metrics = {
        "Bleu": Bleu(ngram=4, smooth="smooth1", output_transform=nlp_transform),
        "Rouge": Rouge(variants=["L", 2], multiref="best", output_transform=nlp_transform),
        "Loss": Loss(loss_fn, output_transform=loss_metric_transform),
    }

    # Engines.
    trainer = setup_trainer(config, model, optimizer, loss_fn, metrics, device, train_dataloader.sampler)
    evaluator = setup_evaluator(config, model, loss_fn, metrics, device)

    # Share one python logger between both engines and record the full configuration.
    logger = setup_logging(config)
    logger.info("Configuration: \n%s", pformat(config))
    trainer.logger = evaluator.logger = logger

    trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler)

    # Checkpointing and other Ignite handlers.
    to_save_train = {
        "model": model,
        "optimizer": optimizer,
        "trainer": trainer,
        "lr_scheduler": lr_scheduler,
    }
    to_save_test = {"model": model}
    ckpt_handler_train, ckpt_handler_test = setup_handlers(trainer, evaluator, config, to_save_train, to_save_test)

    # Experiment tracking, on rank 0 only.
    exp_logger = None
    if rank == 0:
        print(config.model_name)
        exp_logger = setup_exp_logging(config, trainer, optimizer, evaluator)

    # Periodically print training metrics.
    trainer.add_event_handler(
        Events.ITERATION_COMPLETED(every=config.log_every_iters),
        log_metrics,
        tag="train",
    )

    # Evaluate on the test set at the end of every epoch and print the metrics.
    # Ignite events: https://docs.pytorch.org/ignite/generated/ignite.engine.events.Events.html
    @trainer.on(Events.EPOCH_COMPLETED(every=1))
    def _evaluate_and_log():
        evaluator.run(test_dataloader)
        log_metrics(evaluator, "test")

    # Run the evaluator once when training starts, as an early sanity check that it works.
    @trainer.on(Events.STARTED)
    def _sanity_check_evaluator():
        evaluator.run(test_dataloader)

    trainer.run(
        train_dataloader,
        max_epochs=config.training_hp.num_epochs,
    )

    if exp_logger is not None:
        exp_logger.close()

    logger.info("Last training checkpoint name - %s", ckpt_handler_train.last_checkpoint)
    logger.info("Last testing checkpoint name - %s", ckpt_handler_test.last_checkpoint)

In [None]:
# Entry point: without distributed training, call `run` directly in this process
# as local rank 0. Otherwise run it inside an Ignite distributed context using the
# backend selected in the config cell.
if not config.training_hp.distributed_training:
    run(0, config)
else:
    with idist.Parallel(backend=config.backend) as parallel:
        parallel.run(run, config, distributed=True)