# Chaii-QA: PyTorch XLM-R Large TPU ⚡

### This notebook is based on the following notebooks: 🙏
- https://www.kaggle.com/abhishek/roberta-on-steroids-pytorch-tpu-training
- https://www.kaggle.com/philippsinger/xlm-roberta-large-pytorch-pytorch-tpu
- https://www.kaggle.com/rhtsingh/chaii-qa-5-fold-xlmroberta-torch-fit

### If you find this notebook helpful, then please consider upvoting the notebooks it was made from. 🔼

### I am by no means a TPU expert, so if you have recommendations, I am all ears! 👂

### To Do ✅:
#### 1. Add jaccard score 💯
#### 2. Incorporate Weights and Biases for logging ⚖️
#### 3. ??? Leave a suggestion in the comments 💭 
#### 4. Add section to show how to do inference on GPU
#### 5. Do full 5-fold training

# Connect to TPU 🔗

In [None]:
import tensorflow as tf

# Probe for a TPU through TensorFlow's cluster resolver. A ValueError raised by
# the resolver is treated as "no TPU available" and leaves `tpu` as None.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print("Running on TPU ", tpu.cluster_spec().as_dict()["worker"])
except ValueError:
    tpu = None
if tpu:
    # Connect to the resolved cluster and initialise the TPU system before
    # building a TPU distribution strategy on top of it.
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
else:
    # No TPU resolved: fall back to whatever strategy TensorFlow currently has.
    strategy = tf.distribute.get_strategy()
# NOTE(review): `strategy` is not referenced again in the cells shown below,
# which train through torch_xla; presumably this cell only serves as a TPU
# connectivity check -- confirm before removing.

# Install XLA 👨‍💻

In [None]:
# Download the PyTorch/XLA environment setup script (fetched from the `master`
# branch, i.e. not pinned to a commit) and run it for version 1.9, also
# installing the listed apt packages.
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version 1.9 --apt-packages libomp5 libopenblas-dev

# Import Packages 📦

In [None]:
import os

os.environ["XLA_USE_BF16"] = "1"
os.environ["XLA_TENSOR_ALLOCATOR_MAXSIZE"] = "100000000"

import torch
import pandas as pd
from scipy import stats
import numpy as np

import gc

from tqdm import tqdm
from collections import OrderedDict, namedtuple
import torch.nn as nn
from torch.optim import lr_scheduler
import joblib
from joblib import Parallel, delayed

import torch_xla.utils.serialization as xser

import time

import collections
import logging
import transformers
from transformers import (
    AdamW,
    get_linear_schedule_with_warmup,
    get_constant_schedule,
    AutoTokenizer,
    AutoModel,
    AutoConfig,
    get_cosine_schedule_with_warmup,
)
import sys
from sklearn import metrics, model_selection
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import LabelEncoder
from tqdm.notebook import tqdm

from random import shuffle
import random

import re

import warnings
import torch_xla
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
import torch_xla.utils.utils as xu
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.test.test_utils as test_utils
import warnings

warnings.filterwarnings("ignore")


# Set up Weights and Biases ⚖️

In [None]:
# Install / upgrade the Weights & Biases client quietly.
# NOTE(review): the version is not pinned, so re-runs may pick up a different release.
!pip install -U wandb -qq

# Read the API key stored under the secret name "wandb" in Kaggle user secrets,
# so that no credential is hard-coded in the notebook.
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
wandb_key = user_secrets.get_secret("wandb")

import wandb

# Authenticate this session with the retrieved key.
wandb.login(key=wandb_key)

# Configuration 📝

In [None]:
class CFG:
    """Static configuration namespace for the notebook.

    Every attribute whose name does not contain a double underscore is copied
    into the wandb run config by ``_run`` (via ``dir(CFG)``), so adding or
    renaming attributes changes what is logged.
    """
    
    # wandb
    project = "chaii-qa"  # wandb project name passed to wandb.init
    run_name = "tpu-xlm-r-large"  # prefix of the per-fold wandb run name
    
    # model
    model_type = 'xlm_roberta'  # not referenced by the code shown in this notebook
    model_name_or_path = "../input/xlm-roberta-squad2/deepset/xlm-roberta-large-squad2"
    config_name = "../input/xlm-roberta-squad2/deepset/xlm-roberta-large-squad2"

    # tokenizer
    tokenizer_name = "../input/xlm-roberta-squad2/deepset/xlm-roberta-large-squad2"
    max_seq_length = 384  # passed as `max_length` to the tokenizer
    doc_stride = 128  # passed as `stride` to the tokenizer

    # train
    epochs = 1
    per_device_train_batch_size = 8
    per_device_eval_batch_size = 16
    n_folds = 5  # number of stratified folds created over the competition data

    # optimizer
    optimizer_type = 'AdamW'  # informational only; `_run` always constructs AdamW
    learning_rate = 2e-5  # multiplied by the XLA world size in `_run`
    weight_decay = 1e-2
    epsilon = 1e-8  # NOTE(review): not passed to the optimizer in the code shown
    max_grad_norm = 1.0  # NOTE(review): no gradient clipping is applied in the code shown

    # scheduler
    # NOTE(review): `_run` builds a cosine schedule with warmup regardless of
    # this value; confirm which of the two is intended.
    decay_name = 'linear-warmup'
    warmup_ratio = 0.1  # fraction of the training steps used for warmup

    # logging
    logging_steps = 10  # NOTE(review): the train loop hard-codes its own interval of 10

    # evaluate
    output_dir = 'output'  # not referenced by the code shown in this notebook
    seed = 2021  # NOTE(review): `create_folds` hard-codes 2021 rather than reading this

# Setup Data 📊

In [None]:
# Competition data.
train = pd.read_csv("../input/chaii-hindi-and-tamil-question-answering/train.csv")
test = pd.read_csv("../input/chaii-hindi-and-tamil-question-answering/test.csv")
# External question-answering data, stacked into a single frame.
# The row index is not reset here; it is reset after the later concat with `train`.
external_mlqa = pd.read_csv("../input/mlqa-hindi-processed/mlqa_hindi.csv")
external_xquad = pd.read_csv("../input/mlqa-hindi-processed/xquad.csv")
external_train = pd.concat([external_mlqa, external_xquad])


def create_folds(data, num_splits, seed=2021):
    """Assign a stratified fold id to every row of ``data``.

    A ``kfold`` column is added (or overwritten) in place and the same frame
    is returned. Folds are stratified on ``data.language``.

    Args:
        data: DataFrame with a ``language`` column.
        num_splits: Number of folds to create.
        seed: Random state for the shuffled split (defaults to the value that
            was previously hard-coded).

    Returns:
        ``data`` itself, with ``kfold`` holding values in ``[0, num_splits)``.
    """
    data["kfold"] = -1
    kfold_col = data.columns.get_loc("kfold")
    kf = model_selection.StratifiedKFold(
        n_splits=num_splits, shuffle=True, random_state=seed
    )
    for fold_num, (_, valid_idx) in enumerate(
        kf.split(X=data, y=data.language.values)
    ):
        # `split` yields *positional* row indices, so write with `.iloc`.
        # Label-based `.loc` only matched by accident on a default RangeIndex.
        data.iloc[valid_idx, kfold_col] = fold_num
    return data


# Only the competition rows receive a real fold id.
train = create_folds(train, num_splits=CFG.n_folds)
# External rows are tagged with kfold == -1 and given sequential integer ids.
external_train["kfold"] = -1
external_train["id"] = list(np.arange(1, len(external_train) + 1))
# Stack competition and external rows and rebuild a clean 0..n-1 index.
train = pd.concat([train, external_train]).reset_index(drop=True)


def convert_answers(row):
    """Convert an ``(answer_start, answer_text)`` row to the SQuAD answers layout.

    Args:
        row: A two-field row in the order ``answer_start``, ``answer_text``
            (a pandas Series from ``DataFrame.apply(axis=1)``, or any
            tuple/list laid out the same way). Extra trailing fields are ignored.

    Returns:
        ``{"answer_start": [start], "text": [text]}``.
    """
    # Unpack by position instead of `row[0]` / `row[1]`: integer keys on a
    # string-labelled Series rely on a deprecated positional fallback in pandas
    # and are slated to become label lookups.
    answer_start, answer_text, *_ = row
    return {"answer_start": [answer_start], "text": [answer_text]}


# Build the per-row answers dict ({"answer_start": [...], "text": [...]}) that
# `prepare_train_features` reads from each example.
train["answers"] = train[["answer_start", "answer_text"]].apply(convert_answers, axis=1)

# The external frames are already part of `train`; drop the extra references
# and ask the garbage collector to reclaim them.
del external_mlqa
del external_xquad
del external_train
gc.collect();

In [None]:
def prepare_train_features(args, example, tokenizer):
    """Tokenize one QA example into fixed-length, overlapping training features.

    The question/context pair is tokenized with truncation of the context only,
    so a long context produces several windows. Each window becomes one feature
    labelled with the token span of the answer, or with the CLS token index
    when the example has no answer or the answer is not fully inside the window.

    Args:
        args: Object exposing ``max_seq_length`` and ``doc_stride``.
        example: Mapping with ``question``, ``context`` and ``answers``
            (``{"answer_start": [...], "text": [...]}``). It is not modified.
        tokenizer: Callable tokenizer that returns overflowing windows with
            ``input_ids``, ``attention_mask``, ``offset_mapping`` and a
            ``sequence_ids(i)`` method, and exposes ``cls_token_id``.

    Returns:
        A list of dicts with keys ``input_ids``, ``attention_mask``,
        ``offset_mapping``, ``start_position`` and ``end_position``.
    """
    # Strip into a local so the caller's `example` mapping is left untouched.
    question = example["question"].lstrip()
    tokenized_example = tokenizer(
        question,
        example["context"],
        truncation="only_second",
        max_length=args.max_seq_length,
        stride=args.doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    offset_mapping = tokenized_example.pop("offset_mapping")

    # The answer's character span is the same for every window: compute it once.
    answers = example["answers"]
    has_answer = len(answers["answer_start"]) > 0
    if has_answer:
        start_char = answers["answer_start"][0]
        end_char = start_char + len(answers["text"][0])

    features = []
    for i, offsets in enumerate(offset_mapping):
        input_ids = tokenized_example["input_ids"][i]
        feature = {
            "input_ids": input_ids,
            "attention_mask": tokenized_example["attention_mask"][i],
            "offset_mapping": offsets,
        }

        # Default label: CLS index, meaning "no answer in this window".
        cls_index = input_ids.index(tokenizer.cls_token_id)
        start_position = end_position = cls_index

        if has_answer:
            sequence_ids = tokenized_example.sequence_ids(i)

            # First and last token belonging to the context (sequence id 1).
            token_start_index = 0
            while sequence_ids[token_start_index] != 1:
                token_start_index += 1

            token_end_index = len(input_ids) - 1
            while sequence_ids[token_end_index] != 1:
                token_end_index -= 1

            answer_in_window = (
                offsets[token_start_index][0] <= start_char
                and offsets[token_end_index][1] >= end_char
            )
            if answer_in_window:
                # Walk inwards from both ends until the tokens bracket the
                # answer, then step back one token on each side.
                while (
                    token_start_index < len(offsets)
                    and offsets[token_start_index][0] <= start_char
                ):
                    token_start_index += 1
                start_position = token_start_index - 1
                while offsets[token_end_index][1] >= end_char:
                    token_end_index -= 1
                end_position = token_end_index + 1

        feature["start_position"] = start_position
        feature["end_position"] = end_position
        features.append(feature)
    return features

# Dataset

In [None]:
class ChaiiDataset(torch.utils.data.Dataset):
    """Map-style dataset over a list of pre-tokenized feature dicts.

    In ``"train"`` mode every returned field is a ``torch.long`` tensor. In any
    other mode only the token ids and attention mask are tensors, and the
    remaining fields are passed through unchanged from the feature dict.
    """

    # Fields converted to long tensors in train mode, after the two shared ones.
    _TRAIN_TENSOR_KEYS = ("offset_mapping", "start_position", "end_position")

    def __init__(self, features, mode="train"):
        super().__init__()
        self.features = features
        self.mode = mode

    def __len__(self):
        return len(self.features)

    def __getitem__(self, item):
        feature = self.features[item]

        # Token ids and attention mask are tensors in both modes.
        sample = {
            key: torch.tensor(feature[key], dtype=torch.long)
            for key in ("input_ids", "attention_mask")
        }

        if self.mode == "train":
            for key in self._TRAIN_TENSOR_KEYS:
                sample[key] = torch.tensor(feature[key], dtype=torch.long)
            return sample

        # Non-train mode: pass post-processing metadata through as-is.
        sample["offset_mapping"] = feature["offset_mapping"]
        sample["sequence_ids"] = feature["sequence_ids"]
        sample["id"] = feature["example_id"]
        sample["context"] = feature["context"]
        sample["question"] = feature["question"]
        return sample

# Model 

In [None]:
class Model(nn.Module):
    """Extractive-QA model: a pretrained encoder followed by a linear span head.

    The head maps every token's hidden state to two logits (answer start and
    answer end).

    Args:
        modelname_or_path: Identifier or path passed to ``AutoModel.from_pretrained``.
        config: Model config; must expose ``hidden_size``,
            ``hidden_dropout_prob`` and (for ``_init_weights``)
            ``initializer_range``.
    """

    def __init__(self, modelname_or_path, config):
        super().__init__()
        self.config = config
        self.model = AutoModel.from_pretrained(modelname_or_path, config=config)
        self.qa_outputs = nn.Linear(config.hidden_size, 2)
        # Constructed but not applied in `forward`; kept so the module's
        # attributes are unchanged.
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def _init_weights(self, module):
        """Re-initialise a linear layer: normal weights, zero bias."""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()

    def _reinit_layers(self, num_layers):
        """Re-initialise the last ``num_layers`` encoder layers in place.

        A non-positive ``num_layers`` is a no-op.
        """
        # Guard: `layer[-0:]` is the *whole* list, so without this check
        # asking for zero layers would reset every encoder layer.
        if num_layers <= 0:
            return

        encoder = self.model.encoder
        std = encoder.config.initializer_range
        for layer in encoder.layer[-num_layers:]:
            for module in layer.modules():
                if isinstance(module, nn.Linear):
                    module.weight.data.normal_(mean=0.0, std=std)
                    if module.bias is not None:
                        module.bias.data.zero_()
                elif isinstance(module, nn.Embedding):
                    module.weight.data.normal_(mean=0.0, std=std)
                    if module.padding_idx is not None:
                        module.weight.data[module.padding_idx].zero_()
                elif isinstance(module, nn.LayerNorm):
                    module.bias.data.zero_()
                    module.weight.data.fill_(1.0)

    def forward(
        self,
        input_ids,
        attention_mask=None,
    ):
        """Return ``(start_logits, end_logits)``, one value per input token.

        Each returned tensor has the shape of ``input_ids``.
        """
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
        )

        # Only the per-token hidden states are needed. The previously indexed
        # (and unused) `outputs[1]` is dropped so that encoders without a
        # pooled output also work.
        sequence_output = outputs[0]

        qa_logits = self.qa_outputs(sequence_output)

        start_logits, end_logits = qa_logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        return start_logits, end_logits
    

# Loss function

In [None]:
def loss_fn(preds, labels):
    """Mean of the start and end cross-entropy losses.

    Args:
        preds: Pair ``(start_logits, end_logits)``.
        labels: Pair ``(start_targets, end_targets)``; targets equal to ``-1``
            are ignored.

    Returns:
        The average of the two cross-entropy terms.
    """
    start_logits, end_logits = preds
    start_targets, end_targets = labels

    # One criterion instance serves both heads.
    criterion = nn.CrossEntropyLoss(ignore_index=-1)
    start_term = criterion(start_logits, start_targets)
    end_term = criterion(end_logits, end_targets)
    return (start_term + end_term) / 2

# Loss Recorder

In [None]:
class AverageMeter(object):
    """Tracks the latest value, weighted running mean, and min/max seen."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics back to their initial values."""
        self.val = self.avg = self.sum = self.count = self.max = 0
        self.min = 1e5

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the statistics."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
        # max()/min() replace the stored extreme only on a strict improvement.
        self.max = max(self.max, val)
        self.min = min(self.min, val)

# Train

In [None]:
def train_loop_fn(data_loader, model, optimizer, device, num_batches, scheduler=None):
    """Run one training epoch on this process's shard of the data.

    Args:
        data_loader: Iterable of batches (dicts with ``input_ids``,
            ``attention_mask``, ``start_position`` and ``end_position``).
        model: Model returning ``(start_logits, end_logits)``.
        optimizer: Optimizer stepped through ``xm.optimizer_step``.
        device: Device the batch tensors are moved to.
        num_batches: Expected batch count; only sizes the progress bar.
        scheduler: Optional LR scheduler, stepped once per batch.

    The cross-core mean loss is shown on the progress bar, logged to wandb on
    the master ordinal every step, and printed every 10 steps.
    """
    model.train()

    losses = AverageMeter()

    tk0 = tqdm(data_loader, total=num_batches, desc="Training", disable=not xm.is_master_ordinal())
    start_time = time.time()

    for bi, d in enumerate(tk0):
        input_ids = d['input_ids'].to(device)
        attention_mask = d['attention_mask'].to(device)
        targets_start = d['start_position'].to(device)
        # FIX: the end targets were overwritten with the start targets
        # (`targets_end = targets_start.to(device)`), so the end head was
        # trained to predict the answer's start token.
        targets_end = d['end_position'].to(device)

        optimizer.zero_grad()

        outputs_start, outputs_end = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )

        loss = loss_fn((outputs_start, outputs_end), (targets_start, targets_end))

        loss.backward()
        xm.optimizer_step(optimizer)

        # Pull the scalar to the host for logging and averaging.
        loss = loss.detach().item()

        losses.update(loss, input_ids.size(0))

        if scheduler is not None:
            scheduler.step()

        # Average this step's loss across all cores for reporting.
        print_loss = xm.mesh_reduce('loss_reduce', loss, reduce_fn)
        tk0.set_postfix(loss=print_loss)

        if xm.is_master_ordinal():
            wandb.log({"train_loss": print_loss})

        if bi % 10 == 0:
            xm.master_print(
                f"bi={bi}, {time.time()-start_time:<2.2f} - loss:{print_loss}"
            )

    # The former `del loss` / `del print_loss` raised NameError when the loader
    # yielded no batches; the locals are released on return anyway.
    del losses

    gc.collect()

# Eval

In [None]:
def eval_loop_fn(data_loader, model, device):
    """Compute and report the validation loss.

    Args:
        data_loader: Iterable of batches (dicts with ``input_ids``,
            ``attention_mask``, ``start_position`` and ``end_position``).
        model: Model returning ``(start_logits, end_logits)``.
        device: Device the batch tensors are moved to.

    Returns:
        None. The loss, averaged over all cores, is logged to wandb on the
        master ordinal and printed.
    """
    model.eval()

    losses = AverageMeter()

    with torch.no_grad():
        for bi, d in enumerate(data_loader):
            input_ids = d['input_ids'].to(device)
            attention_mask = d['attention_mask'].to(device)
            targets_start = d['start_position'].to(device)
            # FIX: the end targets were overwritten with the start targets,
            # so the reported loss scored the end head against the wrong labels.
            targets_end = d['end_position'].to(device)

            outputs_start, outputs_end = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
            )

            loss = loss_fn((outputs_start, outputs_end), (targets_start, targets_end))
            losses.update(loss.item(), input_ids.size(0))

    # Each process only sees its own shard of the validation set. Average the
    # per-core means so the reported number covers the whole fold, mirroring
    # how the train loss is reduced. Every process must reach this call.
    eval_loss = xm.mesh_reduce('eval_loss_reduce', losses.avg, reduce_fn)

    if xm.is_master_ordinal():
        wandb.log({"eval_loss": eval_loss})
    xm.master_print(f"EVAL loss={eval_loss}")

    return None

# Run

In [None]:
def reduce_fn(vals):
    """Return the arithmetic mean of ``vals``; used as the mesh-reduce function."""
    total = 0
    for value in vals:
        total = total + value
    return total / len(vals)


def _run(fold):
    """Train and validate one fold inside a single XLA process.

    Reads the module-level ``mx`` (wrapped model), ``all_features``
    (per-fold feature lists) and ``CFG``. Every entry of ``all_features``
    except index ``fold`` is used for training; entry ``fold`` is the
    validation set.

    Args:
        fold: Index of the validation fold.

    Returns:
        None.
    """
    xm.master_print(f"Starting fold {fold}")
    device = xm.xla_device()
    model = mx.to(device)
    world_size = xm.xrt_world_size()

    if xm.is_master_ordinal():
        wandb.init(
            project=CFG.project,
            config={x: getattr(CFG, x) for x in dir(CFG) if "__" not in x},
            # FIX: use this function's own argument; the previous f-string read
            # the notebook-global loop variable `k`.
            name=CFG.run_name + f"-fold-{fold}",
        )

    # Flatten every non-validation entry into one training list
    # (linear, unlike repeated list concatenation via `sum(..., [])`).
    train_features = [
        feature
        for i, fold_features in enumerate(all_features)
        if i != fold
        for feature in fold_features
    ]
    train_dataset = ChaiiDataset(train_features)
    valid_dataset = ChaiiDataset(all_features[fold])

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=world_size,
        rank=xm.get_ordinal(),
        shuffle=True,
    )

    train_data_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=CFG.per_device_train_batch_size,
        sampler=train_sampler,
        drop_last=True,
        num_workers=2,
    )

    valid_sampler = torch.utils.data.distributed.DistributedSampler(
        valid_dataset,
        num_replicas=world_size,
        rank=xm.get_ordinal(),
        shuffle=False,
    )

    valid_data_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=CFG.per_device_eval_batch_size,
        sampler=valid_sampler,
        drop_last=False,
        num_workers=0,
    )

    num_train_steps = int(
        len(train_dataset) / CFG.per_device_train_batch_size / world_size * CFG.epochs
    )

    # FIX: materialise the parameters. `named_parameters()` is a one-shot
    # generator; the first comprehension used to exhaust it, leaving the
    # no-decay group empty, so bias and LayerNorm parameters were never handed
    # to the optimizer (and therefore never updated).
    named_params = list(model.named_parameters())
    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    optimizer_parameters = [
        {
            "params": [
                p for n, p in named_params if not any(nd in n for nd in no_decay)
            ],
            "weight_decay": CFG.weight_decay,
        },
        {
            "params": [
                p for n, p in named_params if any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.0,
        },
    ]

    # The learning rate is scaled by the number of participating cores.
    optimizer = AdamW(optimizer_parameters, lr=CFG.learning_rate * world_size)
    # NOTE(review): CFG.decay_name says 'linear-warmup' but a cosine schedule
    # is built here; left unchanged -- confirm which one is intended.
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(num_train_steps * CFG.warmup_ratio),
        num_training_steps=num_train_steps,
    )

    xm.master_print(
        f"num_train_steps = {num_train_steps}, world_size={world_size}"
    )

    num_batches = int(len(train_dataset) / (CFG.per_device_train_batch_size * world_size))

    for epoch in range(CFG.epochs):

        xm.master_print(f"Starting epoch {epoch}")
        # Reseed the sampler so that each epoch sees a different shuffle;
        # without this every epoch replays the same order.
        train_sampler.set_epoch(epoch)

        mp_device_loader = pl.MpDeviceLoader(train_data_loader, device)
        train_loop_fn(
            mp_device_loader,
            model,
            optimizer,
            device,
            num_batches,
            scheduler,
        )

        del mp_device_loader
        gc.collect()

        mp_device_loader = pl.MpDeviceLoader(valid_data_loader, device)
        eval_loop_fn(mp_device_loader, model, device)

        del mp_device_loader
        gc.collect()

    if xm.is_master_ordinal():
        wandb.finish()

In [None]:
def _mp_fn(rank, flags):
    """Per-process entry point for ``xmp.spawn``: train the fold in ``flags["fold"]``.

    ``rank`` is supplied by the spawner and is not used here.
    """
    torch.set_default_tensor_type("torch.FloatTensor")
    _run(flags["fold"])

# Prepare features

In [None]:
config = AutoConfig.from_pretrained(CFG.config_name)
tokenizer = AutoTokenizer.from_pretrained(CFG.tokenizer_name)


def _build_features(rows):
    """Tokenize every row of ``rows`` and return the flat list of features."""
    return [
        feature
        for row in rows.itertuples()
        for feature in prepare_train_features(CFG, row._asdict(), tokenizer)
    ]


# all_features[k] holds the features of fold k for k in [0, n_folds).
all_features = [None]*CFG.n_folds

for k in range(CFG.n_folds):
    print('Preparing fold', k)
    all_features[k] = _build_features(train[train["kfold"] == k])

# FIX: rows tagged kfold == -1 (the external data concatenated into `train`)
# were never featurised, so they were silently dropped from training. They are
# appended as one extra entry *after* the real folds: `_run` trains on every
# entry except the validation index, so they join every training split and
# never serve as a validation fold.
external_rows = train[train["kfold"] == -1]
if len(external_rows) > 0:
    print('Preparing external data')
    all_features.append(_build_features(external_rows))

In [None]:
def _train_and_save_fold(rank, flags):
    """Spawn target: train one fold, then save the weights trained in this process.

    FIX: the checkpoint used to be written by the parent notebook process
    after ``xmp.spawn`` returned. With ``start_method='fork'`` training happens
    in child processes on their own copy of the model, so the parent's
    ``mx._model`` still held the untrained weights and that is what was saved.
    """
    _mp_fn(rank, flags)
    # `mx.to(device)` in `_run` moves this process's copy of the wrapped module
    # in place, so `mx._model` here is the module that was just trained
    # (torch_xla MpModelWrapper behaviour -- verify against the installed version).
    # `xm.save` copies tensors to CPU and writes from the master ordinal only;
    # every process must call it because it synchronises the cores.
    xm.save(mx._model.state_dict(), f"{flags['save_dir']}/pytorch_model.bin")


for k in range(CFG.n_folds):

    # use model wrapper for reducing memory usage across TPU cores
    mx = xmp.MpModelWrapper(Model(CFG.model_name_or_path, config))

    save_dir = f"fold-{k}"
    # The workers write the checkpoint, so the directory must exist up front.
    os.makedirs(save_dir, exist_ok=True)

    FLAGS = {}
    FLAGS["fold"] = k
    FLAGS["save_dir"] = save_dir
    xmp.spawn(_train_and_save_fold, args=(FLAGS,), nprocs=8, start_method='fork')

    tokenizer.save_pretrained(save_dir)
    config.save_pretrained(save_dir)