## Basic Imports

In [None]:
import torch
import transformers
import sklearn
import os
import gc
import re
import random
import time
import sys
import optuna
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import train_test_split
from accelerate import cpu_offload, dispatch_model
from accelerate.utils.modeling import infer_auto_device_map
from tqdm.auto import tqdm
import ctypes
# Handle to the C library; used further down to call malloc_trim() when
# releasing memory after a training run.
libc = ctypes.CDLL("libc.so.6")
# Enable tqdm's progress-bar integration for pandas objects.
tqdm.pandas()

In [None]:
# Notebook display limits: show up to 999 rows and truncate cell text at 99 chars.
pd.set_option("display.max_rows", 999)
pd.set_option("display.max_colwidth", 99)

## TPU Training

In [None]:
from peft import get_peft_config, PeftModel, PeftConfig, get_peft_model, LoraConfig, TaskType
from transformers import (
    LlamaModel, LlamaConfig, LlamaForSequenceClassification, BitsAndBytesConfig,
    AutoTokenizer, AutoModelForCausalLM, AutoConfig, AutoModelForSequenceClassification,
    DataCollatorWithPadding, MistralForSequenceClassification
)

In [None]:
import torch_xla.debug.profiler as xp
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.distributed.parallel_loader as pl
import torch_xla.test.test_utils as test_utils
import torch_xla.experimental.xla_sharding as xs
import torch_xla.runtime as xr
# Switch the XLA runtime to SPMD mode before any XLA tensor is created, and
# fail fast (with a readable message) if the switch did not take effect.
xr.use_spmd()
assert xr.is_spmd(), "torch_xla SPMD mode is not enabled"

In [None]:
import torch_xla.experimental.xla_sharding as xs
from torch_xla.experimental.xla_sharded_tensor import XLAShardedTensor
from torch_xla.experimental.xla_sharding import Mesh
from spmd_util import partition_module

In [None]:
# Ask cuDNN for deterministic kernels and disable its auto-tuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


def seed_everything(seed=42):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Integer seed applied to all three generators (default 42).
    """
    for apply_seed in (random.seed, np.random.seed, torch.manual_seed):
        apply_seed(seed)


# Global seed shared by the data splits in this notebook.
SEED = 666
seed_everything(SEED)

# Load training data

In [None]:
def cv_split(train_data):
    """Assign every row of ``train_data`` to one of 5 stratified folds.

    The fold id (0..4) is written into a ``fold`` column of ``train_data``
    (the frame is modified in place) and the per-fold label counts are printed.

    Args:
        train_data: DataFrame with a ``label`` column used for stratification.

    Returns:
        The same DataFrame, now carrying an integer ``fold`` column.
    """
    N_FOLD = 5
    skf = StratifiedKFold(n_splits=N_FOLD, shuffle=True, random_state=SEED)
    X = train_data.drop(columns="label")
    y = train_data["label"]

    # `skf.split` yields *positional* indices, so collect the fold ids in a
    # positional array instead of writing through label-based `.loc`, which is
    # only correct when the frame has a default RangeIndex.
    folds = np.full(len(train_data), -1, dtype=int)
    for fold, (_, valid_index) in enumerate(skf.split(X, y)):
        folds[valid_index] = fold
    train_data["fold"] = folds

    print(train_data.groupby("fold")["label"].value_counts())
    return train_data

def load_train_data(path='df_all.csv'):
    """Read the training CSV and return a two-column ``text`` / ``label`` frame.

    Args:
        path: Location of the CSV file. It must contain ``text`` and ``labels``
            columns. Defaults to ``'df_all.csv'``.

    Returns:
        DataFrame with columns ``text`` and ``label`` (renamed from ``labels``)
        and a fresh 0..n-1 index. The label distribution is printed as a side
        effect.
    """
    train_data = (
        pd.read_csv(path)[["text", "labels"]]
        .rename(columns={'labels': 'label'})
        .reset_index(drop=True)
    )

    print(f"Train data value counts: {train_data.value_counts('label')}")
    return train_data

In [None]:
# Load the labelled texts once and preview the first rows.
train_data = load_train_data()
display(train_data.head())

Train data value counts: label
0    53958
1    36797
Name: count, dtype: int64


Unnamed: 0,text,label
0,"Cars. Cars have been around since they became famous in the 1900s, when Henry Ford created and ...",0
1,"Transportation is a large necessity in most countries worldwide. With no doubt, cars, buses, an...",0
2,"""America's love affair with it's vehicles seems to be cooling"" says Elisabeth rosenthal. To und...",0
3,How often do you ride in a car? Do you drive a one or any other motor vehicle to work? The stor...,0
4,Cars are a wonderful thing. They are perhaps one of the worlds greatest advancements and techno...,0


# Training the model

In [None]:
class TrainModelTPU():
    """LoRA fine-tuning of an LLM as a single-logit (binary) classifier on TPU.

    The constructor splits ``train_data`` into a training part and a very small
    stratified hold-out part and stores the hyper-parameters. Typical use is
    ``load_model()`` -> ``train_model()`` -> ``save_model()`` -> ``clear_memory()``.

    Args:
        model: Key selecting the base checkpoint; must be a key of ``MODEL_PATHS``.
        train_data: DataFrame with a ``text`` column and a ``label`` column.
        **params: Required keys ``lr`` (learning rate), ``r`` (LoRA rank),
            ``num_epochs`` and ``max_length`` (tokenizer padding/truncation length).
    """

    # Checkpoint directory for every supported base model.
    MODEL_PATHS = {
        "mistral_7b": "/mistral/pytorch/7b-v0.1-hf/1",
        "llama-2_7b": "/llama-2/pytorch/7b-hf/1",
    }

    def __init__(self, model, train_data, **params):
        # Create train and valid dataset (0.05% of the rows are held out)
        self.train_df, self.valid_df = train_test_split(train_data, test_size=0.0005,
                                                        stratify=train_data['label'],
                                                        random_state=SEED)
        self.LR = params['lr']
        self.R = params['r']
        self.NUM_EPOCHS = params['num_epochs']

        # A single output logit, trained with BCEWithLogitsLoss
        self.NUM_LABELS = 1
        self.MAX_LENGTH = params['max_length']
        self.BATCH_SIZE = 16
        self.DEVICE = xm.xla_device()
        self.NUM_WARMUP_STEPS = 0
        # Number of micro-batches whose gradients are summed per optimizer update
        self.GRADIENT_ACCUMULATION_STEPS = 2

        self.MODEL = model

    def _resolve_model_path(self):
        """Return the checkpoint path for ``self.MODEL``.

        Raises:
            ValueError: If ``self.MODEL`` is not a key of ``MODEL_PATHS``.
        """
        try:
            return self.MODEL_PATHS[self.MODEL]
        except KeyError:
            raise ValueError(
                f"Unknown model '{self.MODEL}'; expected one of {sorted(self.MODEL_PATHS)}"
            ) from None

    @staticmethod
    def _roc_auc_or_nan(y_true, y_pred):
        """ROC AUC of ``y_pred`` against ``y_true``; NaN when fewer than two classes are present."""
        if np.unique(y_true).size < 2:
            # roc_auc_score raises when only one class is present
            return float('nan')
        return sklearn.metrics.roc_auc_score(y_true, y_pred)

    # Load pretrained LLM and tokenizer
    def load_model(self):
        """Load the tokenizer and base model, then wrap the model with LoRA adapters.

        Sets ``self.tokenizer`` and ``self.model``.

        Raises:
            ValueError: If ``self.MODEL`` is not a supported model key.
        """
        # Resolve the path first so an unsupported key fails before anything is loaded
        MODEL_PATH = self._resolve_model_path()

        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, use_fast=False)
        # Reuse the EOS token for padding
        self.tokenizer.pad_token = self.tokenizer.eos_token

        # NOTE(review): the Llama class is used for the Mistral checkpoint as well
        # (MistralForSequenceClassification is imported but unused) - confirm this
        # is intentional.
        base_model = LlamaForSequenceClassification.from_pretrained(MODEL_PATH,
                                                                num_labels=self.NUM_LABELS,
                                                                torch_dtype=torch.bfloat16)

        base_model.config.pretraining_tp = 1

        base_model.config.pad_token_id = self.tokenizer.pad_token_id

        # LoRa
        peft_config = LoraConfig(
            r=self.R,
            lora_dropout=0.001,
            bias='none',
            inference_mode=False,
            task_type=TaskType.SEQ_CLS,
            target_modules=['o_proj', 'v_proj'],
        )
        self.model = get_peft_model(base_model, peft_config)
        self.model.print_trainable_parameters()
        print("Complete loading pretrained LLM model")

    def save_model(self):
        """Move the model to CPU and save adapter, tokenizer and trainable weights.

        Everything is written below ``{MODEL}/{MODEL}_TPU/``; the trainable
        parameters are additionally stored as ``model_weights.pth``.
        """
        self.model = self.model.cpu()
        SAVE_PATH = f'{self.MODEL}/{self.MODEL}_TPU/'
        self.model.save_pretrained(SAVE_PATH)
        # Save tokenizer for inference
        self.tokenizer.save_pretrained(SAVE_PATH)
        trainable_weights = {name: param for name, param in self.model.named_parameters()
                             if param.requires_grad}
        torch.save(trainable_weights, SAVE_PATH + 'model_weights.pth')
        print(f"Save the model and tokenizers to {SAVE_PATH}")

    def display_model_layers(self):
        """Display a table of the trainable parameters and print their totals."""
        # Display trainable layers for verification
        trainable_layers = []
        n_trainable_params = 0
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                n_params = param.numel()
                trainable_layers.append({
                    '#param': n_params,
                    'name': name,
                    'dtype': param.data.dtype,
                    'params': param
                })
                n_trainable_params += n_params

        display(pd.DataFrame(trainable_layers))
        print(f"Number of trainable parameters: {n_trainable_params:,} "
              f"Number of trainable layers: {len(trainable_layers)}")

    def create_optimizer_scheduler(self, STEPS_PER_EPOCH):
        """Create the AdamW optimizer and a cosine learning-rate schedule.

        Args:
            STEPS_PER_EPOCH: Number of micro-batches processed per epoch.

        Returns:
            Tuple ``(optimizer, lr_scheduler)``.
        """
        # Optimizer (Adam)
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.LR, weight_decay=0.01)
        # The scheduler is stepped once per optimizer update, i.e. once every
        # GRADIENT_ACCUMULATION_STEPS micro-batches, so its horizon must be counted
        # in updates; otherwise the cosine decay never reaches its end.
        updates_per_epoch = max(1, STEPS_PER_EPOCH // self.GRADIENT_ACCUMULATION_STEPS)
        # Cosine Learning Rate
        lr_scheduler = transformers.get_cosine_schedule_with_warmup(
                                    optimizer=optimizer,
                                    num_warmup_steps=self.NUM_WARMUP_STEPS,
                                    num_training_steps=updates_per_epoch * self.NUM_EPOCHS)
        # Keep any existing optimizer state tensors in float32.
        # NOTE(review): PyTorch optimizers normally create their state lazily on the
        # first step(), so this loop is probably a no-op right after construction.
        for state in optimizer.state.values():
            for k, v in list(state.items()):
                if isinstance(v, torch.Tensor) and v.dtype is not torch.float32:
                    state[k] = v.to(dtype=torch.float32)
        print("Complete creating optimizer and lr scheduler")
        print("optimizer", optimizer)
        print("lr_scheduler", lr_scheduler)
        return optimizer, lr_scheduler

    def partition_mesh(self):
        """Build a ``(1, num_devices, 1)`` device mesh and partition the model over it.

        Returns:
            Tuple ``(num_devices, mesh)``.
        """
        # Number of TPU Nodes
        num_devices = xr.global_runtime_device_count()
        mesh_shape = (1, num_devices, 1)
        print(f'Number_DEVICES: {num_devices}')
        device_ids = np.array(range(num_devices))
        mesh = Mesh(device_ids, mesh_shape, ('dp', 'fsdp', 'mp'))
        partition_module(self.model, mesh)
        return num_devices, mesh

    # Create a training dataset
    def create_dataset(self, N_SAMPLES, INPUT_IDS, ATTENTION_MASKS, GENERATED, mesh):
        """Endless generator of shuffled, sharded batches.

        The trailing ``N_SAMPLES % BATCH_SIZE`` samples are never used. Every pass
        over the data reshuffles the remaining indices.

        Yields:
            Tuple ``(input_ids, attention_mask, labels)`` of tensors on ``self.DEVICE``.
        """
        IDXS = np.arange(N_SAMPLES-(N_SAMPLES%self.BATCH_SIZE))
        while True:
            # Shuffle Indices
            np.random.shuffle(IDXS)
            for idxs in IDXS.reshape(-1, self.BATCH_SIZE):
                input_ids = torch.tensor(INPUT_IDS[idxs]).to(self.DEVICE)
                attention_mask = torch.tensor(ATTENTION_MASKS[idxs]).to(self.DEVICE)
                labels = torch.tensor(GENERATED[idxs]).to(self.DEVICE)
                # Shard Over TPU Nodes
                xs.mark_sharding(input_ids, mesh, (0, 1))
                xs.mark_sharding(attention_mask, mesh, (0, 1))
                xs.mark_sharding(labels, mesh, (0, 1))
                yield input_ids, attention_mask, labels

    # Validate the model
    def valid_model(self):
        """Evaluate the model on the hold-out split.

        Batches come from ``create_dataset``, so they are shuffled and an
        incomplete final batch is dropped. The model is put into eval mode for the
        duration of the evaluation and its previous mode is restored afterwards.

        Returns:
            Dict with ``eval_loss`` (mean batch loss) and ``eval_roc_auc``
            (NaN when the evaluated labels contain fewer than two classes).
        """
        num_devices, mesh = self.partition_mesh()
        N_SAMPLES = len(self.valid_df)
        print(f"Start validating the model with number of sample {N_SAMPLES}")
        # Tokenize Data
        tokens = self.tokenizer(self.valid_df['text'].tolist(),
                                padding='max_length',
                                max_length=self.MAX_LENGTH,
                                truncation=True,
                                return_tensors='np',
                           )

        INPUT_IDS = tokens['input_ids']
        # Attention Masks
        ATTENTION_MASKS = tokens['attention_mask']
        # Generated By AI Label of Texts
        GENERATED = self.valid_df['label'].values.reshape(-1,1).astype(np.float32)
        # Create a valid dataset
        VALID_DATASET = self.create_dataset(N_SAMPLES, INPUT_IDS, ATTENTION_MASKS, GENERATED, mesh)

        LOSS_FN = torch.nn.BCEWithLogitsLoss().to(dtype=torch.float32)
        METRICS = {'loss': [],
                   'auc': {'y_true': [], 'y_pred': []} }
        # Number of full batches
        STEPS = N_SAMPLES // self.BATCH_SIZE

        # Disable dropout while evaluating; restore the caller's mode afterwards
        was_training = self.model.training
        self.model.eval()
        try:
            for step in tqdm(range(STEPS)):
                # Enable inference mode using `no_grad`
                with torch.no_grad():
                    # Get Batch
                    input_ids, attention_mask, labels = next(VALID_DATASET)
                    # Forward Pass
                    outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
                    # Logits Float32
                    logits = outputs.logits.to(dtype=torch.float32)
                    # Loss only (no backward pass during validation)
                    loss = LOSS_FN(logits, labels)
                    METRICS['loss'].append(float(loss))
                    # Flatten to 1-D so labels and scores are flat lists of equal length
                    METRICS['auc']['y_true'] += labels.reshape(-1).tolist()
                    METRICS['auc']['y_pred'] += logits.sigmoid().reshape(-1).tolist()
        finally:
            self.model.train(was_training)

        loss = np.mean(METRICS['loss'])
        roc_auc = self._roc_auc_or_nan(METRICS['auc']['y_true'], METRICS['auc']['y_pred'])
        # Compute and display the validation results
        print(f"Number of validation data {len(self.valid_df)}\n"
              f"µ_loss: {loss: .3f}\n"
              f"µ_auc: {roc_auc:.3f}")
        return {"eval_loss": loss, "eval_roc_auc": roc_auc}

    # Train the model by the fold data
    def train_model(self):
        """Train for ``NUM_EPOCHS`` epochs, validating after every epoch.

        Gradients of ``GRADIENT_ACCUMULATION_STEPS`` consecutive micro-batches are
        accumulated before each optimizer update.

        Returns:
            Mean of the per-epoch validation ROC AUC scores.
        """
        num_devices, mesh = self.partition_mesh()
        print(f'Number_DEVICES: {num_devices}')
        print(f"Total number of train data = {len(self.train_df)}")

        # Work on a copy so the column assignment below does not write into a slice
        train_df = self.train_df[['text', 'label']].copy()
        # NOTE(review): `pre_processing_text` is not defined in the visible part of
        # this notebook, and valid_model() does not apply it - confirm both.
        train_df['text'] = train_df['text'].map(lambda text: pre_processing_text(text))
        N_SAMPLES = len(train_df)
        STEPS_PER_EPOCH = N_SAMPLES // self.BATCH_SIZE
        print(f'BATCH_SIZE: {self.BATCH_SIZE}, N_SAMPLES: {N_SAMPLES}, STEPS_PER_EPOCH: {STEPS_PER_EPOCH}')

        tokens = self.tokenizer(train_df['text'].tolist(),
                                padding='max_length',
                                max_length=self.MAX_LENGTH,
                                truncation=True,
                                return_tensors='np',
                                )
        INPUT_IDS = tokens['input_ids']
        # Attention Masks
        ATTENTION_MASKS = tokens['attention_mask']
        GENERATED = train_df['label'].values.reshape(-1,1).astype(np.float32)
        print(f'INPUT_IDS shape: {INPUT_IDS.shape}\n'
              f'ATTENTION_MASKS shape: {ATTENTION_MASKS.shape}\n'
              f'GENERATED shape: {GENERATED.shape}')

        TRAIN_DATASET = self.create_dataset(N_SAMPLES, INPUT_IDS, ATTENTION_MASKS, GENERATED, mesh)

        optimizer, lr_scheduler = self.create_optimizer_scheduler(STEPS_PER_EPOCH)

        self.model.train()
        LOSS_FN = torch.nn.BCEWithLogitsLoss().to(dtype=torch.float32)
        eval_scores = []
        # Training loop goes through each epoch
        for epoch in tqdm(range(self.NUM_EPOCHS)):
            start = time.time()
            METRICS = {'loss': [],
                       'auc': {'y_true': [], 'y_pred': []} }
            # Start each epoch with clean gradients; this also drops a partial
            # accumulation left over from the previous epoch.
            optimizer.zero_grad()
            for step in range(STEPS_PER_EPOCH):
                # Get Batch
                input_ids, attention_mask, labels = next(TRAIN_DATASET)
                # Forward
                outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
                # Logits Float32
                logits = outputs.logits.to(dtype=torch.float32)
                loss = LOSS_FN(logits, labels)
                # Backward: scale so the accumulated gradient is the mean over
                # the micro-batches of one update.
                (loss / self.GRADIENT_ACCUMULATION_STEPS).backward()
                if (step + 1) % self.GRADIENT_ACCUMULATION_STEPS == 0:
                    optimizer.step()
                    xm.mark_step()
                    lr_scheduler.step()
                    # Reset only after an update; zeroing on every micro-step would
                    # throw away the gradients that are being accumulated.
                    optimizer.zero_grad()
                # Report the unscaled loss of this micro-batch
                METRICS['loss'].append(float(loss))
                METRICS['auc']['y_true'] += labels.reshape(-1).tolist()
                METRICS['auc']['y_pred'] += logits.sigmoid().reshape(-1).tolist()
                # AUC is only defined once both classes have been seen
                if np.unique(METRICS['auc']['y_true']).size == 2:
                    metrics = 'µ_loss: {:.3f}'.format(np.mean(METRICS['loss']))
                    metrics += ', step_loss: {:.3f}'.format(METRICS['loss'][-1])
                    metrics += ', µ_auc: {:.3f}'.format(
                        sklearn.metrics.roc_auc_score(METRICS['auc']['y_true'], METRICS['auc']['y_pred'])
                    )

                    lr = optimizer.param_groups[0]['lr']
                    print('\r'*100, f'{epoch+1:02}/{self.NUM_EPOCHS:02} | {step+1:04}/{STEPS_PER_EPOCH} lr: {lr:.2E}, {metrics}', end='')
            avg_loss = np.mean(METRICS['loss'])
            roc_auc_score = self._roc_auc_or_nan(METRICS['auc']['y_true'],
                                                 METRICS['auc']['y_pred'])
            print(f'\n=== Finish Training Epoch {epoch} with average loss {avg_loss: .5f} '
                  f'ROC Accuracy Score {roc_auc_score:.5f} ===')
            # Validate the model at the end of epochs
            result = self.valid_model()
            eval_roc_auc = float(result['eval_roc_auc'])
            print(f'\n=== Finish Validating the model with evaluated ROC Accuracy Score {eval_roc_auc:.5f}'
                  f'\n Total running time = {time.time() -  start: .1f} seconds ===')
            eval_scores.append(eval_roc_auc)
        return np.mean(eval_scores)

    # Clear the memory
    def clear_memory(self):
        """Drop the model and tokenizer references and release freed memory."""
        del self.model, self.tokenizer
        # Ask the C allocator to release freed heap pages
        libc.malloc_trim(0)
        gc.collect()

# Use Optuna to find the optimal hyper-parameters
Train the model in each trial and save it whenever a trial achieves a new best score.

In [None]:
# Quick look at the data that will be passed to the training functions below.
display(train_data.head(3))

Unnamed: 0,text,label
0,"Cars. Cars have been around since they became famous in the 1900s, when Henry Ford created and ...",0
1,"Transportation is a large necessity in most countries worldwide. With no doubt, cars, buses, an...",0
2,"""America's love affair with it's vehicles seems to be cooling"" says Elisabeth rosenthal. To und...",0


In [None]:
# Best validation score seen so far across all Optuna trials
best_score = -1.0


def objective(trial, model_name, train_data):
    """Optuna objective: train once with a sampled learning rate and score it.

    The learning rate is sampled log-uniformly from [1e-7, 1e-3]; LoRA rank,
    epoch count and max sequence length are fixed. The model is saved whenever
    the trial beats the module-level ``best_score``.

    Args:
        trial: Optuna trial used to sample the learning rate.
        model_name: Model key handed to ``TrainModelTPU``.
        train_data: Training DataFrame handed to ``TrainModelTPU``.

    Returns:
        The evaluation score returned by ``TrainModelTPU.train_model``.
    """
    global best_score

    sampled_lr = trial.suggest_float('learning_rate', 1e-7, 1e-3, log=True)
    trainer = TrainModelTPU(
        model_name,
        train_data,
        lr=sampled_lr,
        r=64,
        num_epochs=1,
        max_length=512,
    )
    trainer.load_model()
    eval_score = trainer.train_model()

    # Save the model if this trial's score beats the current best score
    is_new_best = eval_score > best_score
    if is_new_best:
        best_score = eval_score
        trainer.save_model()

    # Clean up
    trainer.clear_memory()
    del trainer
    return eval_score

def train_model_with_optuna(model_name, train_data, n_trials=10, timeout=600):
    """Run an Optuna study that maximizes the score returned by ``objective``.

    Any existing ``{model_name}_study.db`` file is deleted first, so every call
    starts a fresh study stored in SQLite.

    Args:
        model_name: Model key forwarded to ``objective``.
        train_data: Training DataFrame forwarded to ``objective``.
        n_trials: Maximum number of trials (default 10).
        timeout: Time budget in seconds passed to ``study.optimize`` (default 600).

    Returns:
        Dict with the best hyper-parameters found by the study.
    """
    study_name = f"{model_name}_study"
    study_file = f"{study_name}.db"
    # Delete the study file if it exists
    if os.path.isfile(study_file):
        os.remove(study_file)

    study = optuna.create_study(direction="maximize", study_name=study_name,
                                storage="sqlite:///" + f"{study_file}",
                                load_if_exists=False)

    study.optimize(lambda trial: objective(trial, model_name, train_data),
                   timeout=timeout, n_jobs=1, n_trials=n_trials,
                   show_progress_bar=True, gc_after_trial=True)
    best_params = study.best_params
    print(f"Best parameters: {best_params}")
    # Hand the result back instead of discarding it in an unused local
    return best_params

# Train the model with best parameters

In [None]:
def train_model(model_name, train_data):
    """Train ``model_name`` once with fixed hyper-parameters and save the result.

    Args:
        model_name: Model key handed to ``TrainModelTPU``.
        train_data: Training DataFrame handed to ``TrainModelTPU``.
    """
    # Fixed hyper-parameters for the final run
    trainer = TrainModelTPU(
        model_name,
        train_data,
        lr=5e-5,
        r=64,
        num_epochs=2,
        max_length=512,
    )
    trainer.load_model()
    final_score = trainer.train_model()
    trainer.save_model()

    # Release the model and tokenizer before reporting
    trainer.clear_memory()
    del trainer
    print(f"=== Finish training the model {model_name} with score = {final_score}")

In [None]:
# Fine-tune the Mistral 7B checkpoint on the full training data.
model_name = "mistral_7b"
train_model(model_name, train_data)