# Mistral-7b + external training datasets (daigt and slimpajama)

### Massive dataset alert! This notebook processes over 1.3 million texts, so training can take a long time even on TPUs (keep the weekly TPU quota in mind)

This notebook investigates the use of a pretrained LLM to identify texts generated by another LLM.
- `Mistral-7b-v0` and `Llama-2` were employed. `Deberta-v3` models are not currently supported on TPUs because they are not partitioned.
- Fine-tuning the LLM on TPUs reduces training time from several hours (>6 hours) on GPUs to just 43 minutes. Notebook internet access must be enabled to install the necessary libraries for TPU training.
- Fine-tune a large language model (LLM) on a larger external dataset to improve its accuracy, then validate its performance using cross-validation. 
- Use Optuna to identify optimal hyperparameters, including learning rate (lr), for the target model. This Optuna study is saved as a file for future resumption.
- Split the data from `slimpajama`, `daigt-v2-train-dataset` and `train_essay` into training and validation datasets.
- Validate the model at the end of each epoch.
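
The validation step in this notebook reports ROC AUC through `sklearn.metrics.roc_auc_score`. As a rough illustration of what that number measures, the sketch below computes the equivalent pairwise definition in plain Python: the fraction of (generated, human) pairs in which the generated text receives the higher score, with ties counted as half. The function name `pairwise_roc_auc` and the toy scores are hypothetical and are not part of the notebook.

```python
def pairwise_roc_auc(y_true, y_score):
    """ROC AUC as the probability that a positive outranks a negative."""
    pos = [s for t, s in zip(y_true, y_score) if t == 1]
    neg = [s for t, s in zip(y_true, y_score) if t == 0]
    if not pos or not neg:
        raise ValueError("ROC AUC needs at least one sample of each class")
    # Count pairs where the positive scores higher; ties count as half a win
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


# Toy example: label 1 = LLM-generated, label 0 = human-written
labels = [0, 0, 1, 1]
scores = [0.1, 0.4, 0.35, 0.8]
print(pairwise_roc_auc(labels, scores))
```

This also shows why the metric is undefined when a batch contains only one class, which is the case the training loop guards against before printing its running AUC.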

**[Change log]**:
- [Version 188] Resume fine-tuning `Mistral-7b` model (batch_size=`16`, r=64, max_length=1024, epoch=1, lr=5e-5, AdamW and cosine scheduler) and `20000 ~ 100000` texts from `slimpajama` and `train_essay` with training and testing split and batch processing.

- [Version 186] Fine-tune `Mistral-7b` model (batch_size=`16`, r=64, max_length=1024, epoch=1, lr=5e-5, AdamW and cosine scheduler) and 20000 texts from `slimpajama` + all texts from `daigt-v2`,  and `train_essay` with training and testing split and batch processing.
- [Version 176] Fine-tune `Deberta-v3-large` model (batch_size=`16`, r=64, max_length=512, epoch=1, lr=5e-5, AdamW and cosine scheduler) and use three datasets (`daigt-v2`, `slimpajama` and `train_essay`) with training and testing split (testsize = 0.001) and batch processing.
- [Version 175] Fine-tune `Mistral-7b` model (batch_size=`1024`, r=64, max_length=1024, epoch=1, lr=5e-5, AdamW and cosine scheduler) and use three datasets (`daigt-v2`, `slimpajama` and `train_essay`) with training and testing split (testsize = 0.001) and batch processing. This causes **out of memory issues**.
- [Version 164] Fine-tune `Mistral-7b` model with 5-fold CV (GRADIENT_ACCUMULATION_STEPS=`1`, r=`64`, max_length=`1024`, epoch=`1`, lr=5e-5, AdamW and cosine scheduler) and use all the training data from daigt-v2-train-dataset
- [Version 163] Fine-tune `Mistral-7b` model (GRADIENT_ACCUMULATION_STEPS=`2`, r=`64`, max_length=`1024`, epoch=`1`, lr=5e-5, AdamW and cosine scheduler) and use all the training data from daigt-v2-train-dataset
- [Version 162] Fine-tune `Mistral-7b` model (GRADIENT_ACCUMULATION_STEPS=`1`, r=`64`, max_length=`1024`, epoch=`1`, lr=5e-5, AdamW and cosine scheduler) and use all the training data from daigt-v2-train-dataset
- [Version 158] Fine-tune `Mistral-7b` model (r=`128`, max_length=`1024`, epoch=`3`, lr=5e-5, AdamW and cosine scheduler) and use all the training data from daigt-v2-train-dataset
- [Version 157] Fine-tune `Mistral-7b` model (r=`128`, max_length=`1024`, epoch=1, lr=5e-5, AdamW and cosine scheduler) and use all the training data from daigt-v2-train-dataset
- [Version 157] Fine-tune `Mistral-7b` model (r=64, max_length=`1024`, epoch=1, lr=5e-5, AdamW and cosine scheduler) and use all the training data from daigt-v2-train-dataset
- [Version 156] Fine-tune `Mistral-7b` model (r=`64`, max_length=`512`, epoch=1, lr=`5e-5`, AdamW and cosine scheduler) and use all the training data from daigt-v2-train-dataset
- [Version 146] Fine-tune `Mistral-7b` model (train number = 3,000 and validation number = 600, epoch = 1). Training time = 2 hours and 14 minutes.
- [Version 125] Fine-tune `Mistral-7b` model (train number = 1,000 and validation number = 300, epoch = 1). Training time = 1 hour and 15 minutes.
- [Version 128] Fine-tune `Llama-2` model (train number = 1,000 and validation number = 300, epoch = 1). Training time = 1 hour and 12 minutes.
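
The change log varies the LoRA rank `r` between 64 and 128. As a back-of-the-envelope sketch of what that changes, the snippet below counts the adapter parameters added when LoRA targets `o_proj` and `v_proj`: each adapted linear layer gains two low-rank matrices with `r * (d_in + d_out)` parameters in total. The layer dimensions are assumptions about the Mistral-7B architecture (hidden size 4096, `v_proj` output size 1024, 32 decoder layers) and are not read from the notebook; the count also ignores the classification head.

```python
# Assumed Mistral-7B dimensions (not taken from the notebook)
HIDDEN = 4096    # hidden size, also o_proj input/output size
KV_DIM = 1024    # v_proj output size (grouped-query attention)
N_LAYERS = 32    # number of decoder layers


def lora_params(r, d_in, d_out):
    """Parameters in one LoRA adapter: A is (r x d_in), B is (d_out x r)."""
    return r * (d_in + d_out)


def adapter_total(r):
    """Total LoRA parameters when adapting o_proj and v_proj in every layer."""
    per_layer = lora_params(r, HIDDEN, HIDDEN) + lora_params(r, HIDDEN, KV_DIM)
    return per_layer * N_LAYERS


for r in (64, 128):
    print(f"r={r}: {adapter_total(r):,} adapter parameters")
```

Under these assumptions, doubling `r` doubles the number of adapter parameters, which is one reason the larger rank costs more memory and time per step.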

**[References]**:
This notebook is forked from the notebooks by @YUICHI TATENO, @MARK WIJKHUIZEN and @ImperfectKitto
- [PyTorch TPU starter - DeBERTa-v3-large (training)](https://www.kaggle.com/code/tanlikesmath/pytorch-tpu-starter-deberta-v3-large-training)
- [[train]LLM detect AI comp Mistral-7B](https://www.kaggle.com/code/hotchpotch/train-llm-detect-ai-comp-mistral-7b)
- [DAIGT Mistral-7B TPU BFloat16 [Train]](https://www.kaggle.com/code/markwijkhuizen/daigt-mistral-7b-tpu-bfloat16-train) 
- [LLAMA 2 13B on TPU (Training)](https://www.kaggle.com/code/defdet/llama-2-13b-on-tpu-training)

# Install library

In [None]:
# Install packages for inference
!pip install -qq --no-deps /kaggle/input/daigt-pip/peft-0.6.0-py3-none-any.whl
!pip install -qq --no-deps /kaggle/input/daigt-pip/transformers-4.35.0-py3-none-any.whl
!pip install -qq --no-deps /kaggle/input/daigt-pip/tokenizers-0.14.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
!pip install -qq --no-deps /kaggle/input/daigt-pip/optimum-1.14.0-py3-none-any.whl

In [None]:
# Install packages for training on TPUs (notebook internet access must be enabled)
!pip install datasets
!pip install -qq optuna
!pip install -qq sentencepiece==0.1.99 
!pip install -qq torch~=2.1.0 --index-url https://download.pytorch.org/whl/cpu -q # Updating torch since we need the latest version
!pip install -qq torch_xla[tpu]~=2.1.0 -f https://storage.googleapis.com/libtpu-releases/index.html -q
!pip uninstall -qq tensorflow -y # If we don't do this, TF will take over TPU and cause permission error for PT
!cp /kaggle/input/utils-xla/spmd_util.py . # From this repo: https://github.com/HeegyuKim/torch-xla-SPMD

## Basic Imports

In [None]:
import torch, transformers, sklearn, os, gc, re, random, time, sys, optuna
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import train_test_split
from accelerate import cpu_offload, dispatch_model
from accelerate.utils.modeling import infer_auto_device_map
from tqdm.auto import tqdm
from numpy import save
import ctypes
libc = ctypes.CDLL("libc.so.6")
tqdm.pandas()

pd.options.display.max_rows = 999
pd.options.display.max_colwidth = 99

print(f'Torch Version: {torch.__version__}')

## Imports for training on TPUs

In [None]:
## Imports for Transformers and PEFT (Parameter-Efficient Fine-Tuning)
from peft import get_peft_config, PeftModel, PeftConfig, get_peft_model, LoraConfig, TaskType
from transformers import (
    LlamaModel, LlamaConfig, LlamaForSequenceClassification, BitsAndBytesConfig,
    AutoTokenizer, AutoModelForCausalLM, AutoConfig, AutoModelForSequenceClassification,
    DataCollatorWithPadding, MistralForSequenceClassification
) 

## Imports for TPU XLA 
import torch_xla.debug.profiler as xp
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp # We also import mp modules if we wanna use that for some reason
import torch_xla.distributed.parallel_loader as pl
import torch_xla.test.test_utils as test_utils
import torch_xla.experimental.xla_sharding as xs
import torch_xla.runtime as xr
xr.use_spmd() # To enable PyTorch/XLA SPMD execution mode for automatic parallelization
assert xr.is_spmd() == True 

# "experimental" XLA packages
from torch_xla.experimental.xla_sharded_tensor import XLAShardedTensor
from torch_xla.experimental.xla_sharding import Mesh
from spmd_util import partition_module

## Common functions

In [None]:
# Ensure that all operations are deterministic on GPU (if used) for reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# Apply the same seed to Python, NumPy and PyTorch
def seed_everything(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

SEED = 42
seed_everything(SEED)

# Load training data

Training data includes 1,378 sample texts from the competition, 20,450 texts from [DAIGT V2 Train Dataset](https://www.kaggle.com/datasets/thedrcat/daigt-v2-train-dataset), and 1,324,128 texts from the [slimpajama dataset](https://www.kaggle.com/datasets/chg0901/slimpajama-train-chunk1-sel).

Total number of texts = 1,345,956
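
As a quick sanity check on the figures above, the total can be recomputed from the three per-dataset counts. The dictionary keys are illustrative labels only; the counts are the ones stated in this section.

```python
# Per-dataset text counts as stated above
counts = {
    "train_essays": 1_378,
    "daigt_v2": 20_450,
    "slimpajama": 1_324_128,
}

total = sum(counts.values())
print(f"Total number of texts = {total:,}")

# Share of each dataset in the combined pool
for name, n in counts.items():
    print(f"{name}: {n / total:.2%}")
```

The slimpajama texts make up the overwhelming majority of the pool, which is why the loading code only takes a slice of them.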

In [None]:
DEBUG = False # True: train_essays + daigt-v2; False: train_essays + slimpajama
# Cross validation
def cv_split(train_data):
    N_FOLD = 5 # Number of folds
    skf = StratifiedKFold(n_splits=N_FOLD, shuffle=True, random_state=SEED)
    X = train_data.loc[:, train_data.columns != "label"]
    y = train_data.loc[:, train_data.columns == "label"]
    # Split the train into 5 folds
    for fold, (train_index, valid_index) in enumerate(skf.split(X, y)):
        train_data.loc[valid_index, "fold"] = fold

    print(train_data.groupby("fold")["label"].value_counts())
    return train_data

def load_train_data():
    train_df = pd.read_csv("/kaggle/input/llm-detect-ai-generated-text/train_essays.csv", sep=',')
    train_prompts_df = pd.read_csv("/kaggle/input/llm-detect-ai-generated-text/train_prompts.csv", sep=',')

    # Rename the 'generated' column to 'label' and drop the unused 'id' and 'prompt_id' columns
    # Label: 1 indicates generated texts (by LLMs) 
    train_df = train_df.rename(columns={'generated': 'label'})
    train_df = train_df.reset_index(drop=True)
    train_df = train_df[["text", "label"]]
    print(f"Total number of training essay texts {len(train_df)}")
   
    
    if DEBUG:
        # Include external daigt_df data
        daigt_df = pd.read_csv("/kaggle/input/daigt-v2-train-dataset/train_v2_drcat_02.csv", sep=',')
        # Select the texts for the 7 prompts that likely fit the competition test set
        daigt_df = daigt_df[daigt_df['RDizzl3_seven']]
        daigt_df = daigt_df.rename(columns={'generated': 'label'})
        # We only need 'text' and 'label' columns
        daigt_df = daigt_df[["text", "label"]]
        # dropping ALL duplicate values 
        daigt_df.drop_duplicates(subset="text", keep=False, inplace=True) 
        print(f"Total number of external DAIGT sample texts = {len(daigt_df)}")
        train_data = pd.concat([train_df, daigt_df])
    else:
        START = 20000
        END = 100000
        # Include the slimpajama training dataset
        slimpajama_df = pd.read_csv("/kaggle/input/slimpajama-train-chunk1-sel/merged_df_sel150_600_1324132.csv", sep=',')
        # Sample the dataset
        slimpajama_df = slimpajama_df[slimpajama_df['meta'] == 'RedPajamaC4']
        slimpajama_df = slimpajama_df[START:END]
        slimpajama_df = slimpajama_df[["text", "label"]]
        # dropping ALL duplicate values 
        slimpajama_df.drop_duplicates(subset="text", keep=False, inplace=True)
        #display(slimpajama_df)
        print(f"Total number of slimpajama training texts = {len(slimpajama_df)}")
        train_data = pd.concat([train_df, slimpajama_df]) 
        
    train_data.reset_index(inplace=True, drop=True)    
    # print(f"Train data has shape: {train_data.shape}")
    print(f"Train data {train_data.value_counts('label')}") # 1: generated texts 0: human texts
    return train_data
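
The loader calls `drop_duplicates(subset="text", keep=False)`, which removes every copy of a duplicated text rather than keeping one of them. The plain-Python sketch below mirrors that behaviour next to the more common keep-first behaviour so the difference is easy to see. The helper names and the toy list are hypothetical.

```python
from collections import Counter


def drop_all_duplicates(texts):
    """Mirror keep=False: a text that occurs more than once is removed entirely."""
    counts = Counter(texts)
    return [t for t in texts if counts[t] == 1]


def keep_first(texts):
    """Mirror keep='first': the first occurrence of each text survives."""
    seen = set()
    unique = []
    for t in texts:
        if t not in seen:
            seen.add(t)
            unique.append(t)
    return unique


texts = ["essay A", "essay B", "essay A", "essay C"]
print(drop_all_duplicates(texts))  # duplicated "essay A" is gone completely
print(keep_first(texts))           # one copy of "essay A" is kept
```

With `keep=False` the printed dataset size can therefore be smaller than the number of distinct texts.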

In [None]:
train_data = load_train_data()
# Train dataset is too large to do 5 fold CV
# train_data = cv_split(train_data)
display(train_data.head())

# Train the LLM model on TPUs

Preprocessing replaces newlines with spaces. Spelling correction (the commented-out `TextBlob` call below) is currently disabled.

In [None]:
def pre_processing_text(text):
    corrected_text = text.replace('\n', ' ')
    #corrected_text = TextBlob(corrected_text).correct()
    return corrected_text

# Training the model

In [None]:
import transformers

class TrainModelTPU():
    def __init__(self, model, train_data, **params):
        self.train_data = train_data
        self.LR = params['lr'] # Learning rate
        self.R = params['r'] # 'r' value for Lora layer
        self.NUM_EPOCHS = params['num_epochs'] # Training Epoch
        # Fixed parameters
        self.NUM_LABELS = 1 # Single output logit (target 0: human text, 1: LLM-generated text)
        self.MAX_LENGTH = params['max_length']
        self.BATCH_SIZE = params['batch_size']
        self.DEVICE = xm.xla_device() # Initialize TPU Device
        self.NUM_WARMUP_STEPS = 0 # Number of Warmup Steps
        self.GRADIENT_ACCUMULATION_STEPS = 1
        # The model
        self.MODEL = model
        
    # Load pretrained LLM and tokenizer
    def load_model(self):
        if self.MODEL == "mistral_7b":
            MODEL_PATH = "/kaggle/input/mistral/pytorch/7b-v0.1-hf/1"  # Mistral
        elif self.MODEL == "llama-2_7b":
            MODEL_PATH = "/kaggle/input/llama-2/pytorch/7b-hf/1"  # llama
        else:
            raise ValueError(f"Unsupported model: {self.MODEL}")
        # Load the tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, use_fast=False)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        # `bfloat16` halves memory use compared with float32 while keeping the same exponent range
        base_model = LlamaForSequenceClassification.from_pretrained(MODEL_PATH,
                                                                num_labels=self.NUM_LABELS,
                                                                torch_dtype=torch.bfloat16)
        # Tensor-parallel rank used during pretraining; 1 selects the standard (unsliced) linear layers
        base_model.config.pretraining_tp = 1 # 1 is 7b
        # Assign Padding TOKEN
        base_model.config.pad_token_id = self.tokenizer.pad_token_id
        # print(base_model)        
        # LoRa
        peft_config = LoraConfig(
            r=self.R,  # LoRA rank; a larger 'r' adds more trainable parameters
            lora_dropout=0.1,
            bias='none',
            inference_mode=False,
            task_type=TaskType.SEQ_CLS,
            # Only Use Output and Values Projection
            target_modules=['o_proj', 'v_proj'], # layer names for llama 2 model
        )
        # Resume training from a previously saved LoRA adapter.
        # To start a fresh adapter instead, use:
#         self.model = get_peft_model(base_model, peft_config)
        peft_path = '/kaggle/input/mistral-7b-v0-for-llm-detecting-competition/mistral_7b/mistral_7b_TPU'
        self.model = PeftModel.from_pretrained(base_model, peft_path, is_trainable=True)
        # Display Trainable Parameters to make sure we load the model successfully
        self.model.print_trainable_parameters()
        # self.model = self.model.merge_and_unload()
        print("Complete loading pretrained LLM model")
    
    # Save the trained model as output files
    def save_model(self):
        self.model = self.model.cpu()# Move model first on CPU before saving weights
        # Model saving path
        SAVE_PATH = f'/kaggle/working/{self.MODEL}/{self.MODEL}_TPU/'
        # Save the fine-tuned PEFT model (adapter weights and config)
        self.model.save_pretrained(SAVE_PATH, save_adapter=True, save_config=True) 
        # Save tokenizer for inference
        self.tokenizer.save_pretrained(SAVE_PATH)
        # Only saving the newly trained weights
        torch.save(dict([(k,v) for k, v in self.model.named_parameters() if v.requires_grad]), 
                   SAVE_PATH + 'model_weights.pth')
        print(f"Save the model and tokenizers to {SAVE_PATH}")
    
    # Display trainable layers of LLM
    def display_model_layers(self):        
        # Display trainable layers for verification
        trainable_layers = []
        n_trainable_params = 0
        for name, param in self.model.named_parameters():
            # Layer Parameter Count
            n_params = int(torch.prod(torch.tensor(param.shape)))
            # Only Trainable Layers
            if param.requires_grad:
                # Add Layer Information
                trainable_layers.append({
                    '#param': n_params,
                    'name': name,
                    'dtype': param.data.dtype,
                    'params': param
                })
                n_trainable_params += n_params

        display(pd.DataFrame(trainable_layers))
        print(f"Number of trainable parameters: {n_trainable_params:,} "
              f"Number of trainable layers: {len(trainable_layers)}")
    
    def create_optimizer_scheduler(self, STEPS_PER_EPOCH):        
        # Optimizer (AdamW)
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.LR, weight_decay=0.01)
        # Cosine Learning Rate With Warmup
        lr_scheduler = transformers.get_cosine_schedule_with_warmup(
                                    optimizer=optimizer,
                                    num_warmup_steps=self.NUM_WARMUP_STEPS,
                                    num_training_steps=STEPS_PER_EPOCH * self.NUM_EPOCHS)
        # Keep the optimizer's state tensors (e.g., momentum buffers) in float32
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor) and v.dtype is not torch.float32:
                    state[k] = v.to(dtype=torch.float32)
        print("Complete creating optimizer and lr scheduler")
        print("optimizer", optimizer)
        print("lr_scheduler", lr_scheduler)
        return optimizer, lr_scheduler
        
    def partition_mesh(self):
        # Number of TPU Nodes to ensure we can access TPUs and partition the model into mesh
        num_devices = xr.global_runtime_device_count()
        mesh_shape = (1, num_devices, 1)
        print(f'Number_DEVICES: {num_devices}')
        device_ids = np.array(range(num_devices))
        mesh = Mesh(device_ids, mesh_shape, ('dp', 'fsdp', 'mp'))
        partition_module(self.model, mesh)
        return mesh        
        
    # Create a training dataset
    def create_dataset(self, N_SAMPLES, INPUT_IDS, ATTENTION_MASKS, GENERATED, mesh):
        IDXS = np.arange(N_SAMPLES-(N_SAMPLES%self.BATCH_SIZE))
        while True:
            # Shuffle Indices
            # np.random.shuffle(IDXS)
            # Iterate Over All Indices Once
            for idxs in IDXS.reshape(-1, self.BATCH_SIZE):
                input_ids = torch.tensor(INPUT_IDS[idxs]).to(self.DEVICE)
                attention_mask = torch.tensor(ATTENTION_MASKS[idxs]).to(self.DEVICE)
                labels = torch.tensor(GENERATED[idxs]).to(self.DEVICE)
                # Shard Over TPU Nodes
                xs.mark_sharding(input_ids, mesh, (0, 1))
                xs.mark_sharding(attention_mask, mesh, (0, 1))
                xs.mark_sharding(labels, mesh, (0, 1))
                yield input_ids, attention_mask, labels
    
    # Encode texts in batches and return their input ids and attention masks
    def batch_encode(self, texts):
        input_ids = []
        attention_mask = []

        for i in tqdm(range(0, len(texts), self.BATCH_SIZE)):
            batch_texts = texts[i:i+self.BATCH_SIZE]
            inputs = self.tokenizer.batch_encode_plus(batch_texts,
                                                      padding='max_length', # Pad texts to maximum length
                                                      max_length=self.MAX_LENGTH, # Maximum token length
                                                      truncation=True, # Truncate texts if they are too long
                                                      return_tensors='np', # Return Numpy array
                                                      return_attention_mask=True,
                                                      return_token_type_ids=False
                                                     )
            input_ids.extend(inputs['input_ids'])
            attention_mask.extend(inputs['attention_mask'])
        tokens = {'input_ids': np.asarray(input_ids), 'attention_mask': np.asarray(attention_mask)}
        # save to npy file (Too large to save)
        # save('train_data_input_ids.npy', tokens['input_ids'])
        # save('train_data_attention_mask.npy', tokens['attention_mask'])
        return tokens
    
    # Validate the model
    def valid_model(self, valid_df):
        # Compute total samples and number of steps in one epochs
        N_SAMPLES = len(valid_df)       
        print(f"Start validating the model with number of sample {N_SAMPLES}")
        # Tokenize Texts
        tokens = self.tokenizer(valid_df['text'].tolist(), 
                                padding='max_length', # Pad texts to maximum length
                                max_length=self.MAX_LENGTH, # Maximum token length
                                truncation=True, # Truncate texts if they are too long
                                return_tensors='np', # Return Numpy array
                                )
        # Input IDs are the token IDs
        INPUT_IDS = tokens['input_ids']
        # Attention Masks to Ignore Padding Tokens
        ATTENTION_MASKS = tokens['attention_mask']
        # Generated By AI Label of Texts
        GENERATED = valid_df['label'].values.reshape(-1,1).astype(np.float32)
        # Create a valid dataset 
        VALID_DATASET = self.create_dataset(N_SAMPLES, INPUT_IDS, ATTENTION_MASKS, GENERATED, self.mesh)
                
        # Compute the number of batches 
        IDXS = np.array_split(np.arange(N_SAMPLES), max(1, N_SAMPLES // self.BATCH_SIZE))
        LOSS_FN = torch.nn.BCEWithLogitsLoss().to(dtype=torch.float32)
        METRICS = {'loss': [], 
                   'auc': {'y_true': [], 'y_pred': []} }
        STEPS = N_SAMPLES // self.BATCH_SIZE
        for step in range(STEPS):
            # Enable inference mode using `no_grad`
            with torch.no_grad():
                # Get Batch
                input_ids, attention_mask, labels = next(VALID_DATASET)
                 # Forward Pass
                outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
                # Logits Float32
                logits = outputs.logits.to(dtype=torch.float32)
                # Compute Loss (no backward pass during validation)
                loss = LOSS_FN(logits, labels)
                # Update Metrics And Progress Bar
                METRICS['loss'].append(float(loss))
                METRICS['auc']['y_true'] += labels.squeeze().tolist()
                METRICS['auc']['y_pred'] += logits.sigmoid().tolist()
                # print(f"Complete updating metrics for Step {step} in {time.time() - start: .1f} seconds")
        loss = np.mean(METRICS['loss'])
        roc_auc = sklearn.metrics.roc_auc_score(METRICS['auc']['y_true'], METRICS['auc']['y_pred'])
        # Compute and display the validation results
        print(f"Number of validation data {len(valid_df)}\n"
              f"µ_loss: {loss: .3f}\n"
              f"µ_auc: {roc_auc:.3f}")
        return {"eval_loss": loss, "eval_roc_auc": roc_auc}
 
    # Train the model 
    def train_model(self):
        self.mesh = self.partition_mesh()
        print(f"Total number of train data = {len(self.train_data)}")
        # Preprocess the text (use the instance's data rather than the notebook-level global)
        train_data = self.train_data
        train_data['text'] = train_data['text'].map(pre_processing_text)
        # Split the training dataset into training and test data
        X = train_data['text'].tolist()
        y = train_data['label'].tolist()
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=600, random_state=SEED)
        train_df = pd.DataFrame({'text': X_train, 'label': y_train})
        valid_df = pd.DataFrame({'text': X_test, 'label': y_test})
        print(f"Training dataset size = {len(train_df)}")
        print(f"Validate dataset Size = {len(valid_df)}")
    
        # Compute total samples and number of steps in one epochs
        N_SAMPLES = len(train_df)
        # Compute the total steps per epochs
        STEPS_PER_EPOCH = N_SAMPLES // self.BATCH_SIZE
        print(f'BATCH_SIZE: {self.BATCH_SIZE}, N_SAMPLES: {N_SAMPLES}, STEPS_PER_EPOCH: {STEPS_PER_EPOCH}')
               
        # Tokenize Texts
#         tokens = self.tokenizer(train_df['text'].tolist(), 
#                                 padding='max_length', # Pad texts to maximum length
#                                 max_length=self.MAX_LENGTH, # Maximum token length
#                                 truncation=True, # Truncate texts if they are too long
#                                 return_tensors='np', # Return Numpy array
#                                 )
        tokens = self.batch_encode(train_df['text'].tolist())
        print("Complete to tokenize inputs")
        # Input IDs are the token IDs
        INPUT_IDS = tokens['input_ids']
        # Attention Masks to Ignore Padding Tokens
        ATTENTION_MASKS = tokens['attention_mask']
        # Generated By AI Label of Texts
        GENERATED = train_df['label'].values.reshape(-1,1).astype(np.float32)
        print(f'INPUT_IDS shape: {INPUT_IDS.shape}\n'
              f'ATTENTION_MASKS shape: {ATTENTION_MASKS.shape}\n'
              f'GENERATED shape: {GENERATED.shape}')
        
        # Create a train dataset
        TRAIN_DATASET = self.create_dataset(N_SAMPLES, INPUT_IDS, ATTENTION_MASKS, GENERATED, self.mesh)     
        print("Complete creating the datasets")
        # Create optimizer and lr_scheduler
        optimizer, lr_scheduler = self.create_optimizer_scheduler(STEPS_PER_EPOCH)
        
        # Put Model In Train Mode
        self.model.train()
        # Loss Function: MSE on the raw logits; BCEWithLogitsLoss is kept below as an alternative
        # LOSS_FN = torch.nn.BCEWithLogitsLoss().to(dtype=torch.float32)
        LOSS_FN = torch.nn.MSELoss().to(dtype=torch.float32)
        eval_scores = []
        # Training loop goes through each epoch
        for epoch in tqdm(range(self.NUM_EPOCHS)):
            METRICS = {'loss': [], 
                       'auc': {'y_true': [], 'y_pred': []} }
            # Go through each step
            for step in range(STEPS_PER_EPOCH):
                print(f'=== Start Step {step} === ')
                # Zero Out Gradients
                optimizer.zero_grad()
                # Get Batch
                input_ids, attention_mask, labels = next(TRAIN_DATASET)
                # Test the TRAIN_DATASET for debugging first record
                # Forward Pass
                outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
                # Logits Float32
                logits = outputs.logits.to(dtype=torch.float32)
                # Compute Loss
                loss = LOSS_FN(logits, labels)
                # Backward Pass
                loss.backward()
                # Update Weights
                optimizer.step()
                xm.mark_step()
                # Update Learning Rate Scheduler
                lr_scheduler.step()
                # Update Metrics And Progress Bar
                METRICS['loss'].append(float(loss))
                METRICS['auc']['y_true'] += labels.squeeze().tolist()
                METRICS['auc']['y_pred'] += logits.sigmoid().tolist()
                # print(f"Complete updating metrics {METRICS}")
                if np.unique(METRICS['auc']['y_true']).size == 2:
                    metrics = 'µ_loss: {:.3f}'.format(np.mean(METRICS['loss']))
                    metrics += ', step_loss: {:.3f}'.format(METRICS['loss'][-1])
                    metrics += ', µ_auc: {:.3f}'.format(
                        sklearn.metrics.roc_auc_score(METRICS['auc']['y_true'], METRICS['auc']['y_pred'])
                    )

                    lr = optimizer.param_groups[0]['lr']
                    print(f'{epoch:02}/{self.NUM_EPOCHS:02} | {step:04}/{STEPS_PER_EPOCH} lr: {lr:.2E}, {metrics}')
            avg_loss = np.mean(METRICS['loss'])
            roc_auc_score = sklearn.metrics.roc_auc_score(METRICS['auc']['y_true'],
                                                          METRICS['auc']['y_pred'])
            print(f'\n=== Finish Training on Epoch {epoch} ===\n'
                  f'Average Loss = {avg_loss: .5f}, ROC AUC Score = {roc_auc_score:.5f}')
            # Validate the model at the end of epochs
            result = self.valid_model(valid_df) 
            eval_roc_auc = float(result['eval_roc_auc'])
            eval_scores.append(eval_roc_auc)
        avg_score = np.mean(eval_scores)
        print(f'\n=== Finish Training ===\n'
              f'Validate Average ROC AUC Score = {avg_score:.5f}')
        return avg_score

    # Clear the memory
    def clear_memory(self):
        del self.model, self.tokenizer
        libc.malloc_trim(0)
        gc.collect()
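
The scheduler built in `create_optimizer_scheduler` decays the learning rate along a cosine curve after an optional linear warmup. As a minimal sketch of that shape, the function below computes the learning rate at a given step in plain Python, assuming the default half-cosine decay to zero. The function name `cosine_lr` and the step count of 1,000 are hypothetical; the base rate of `5e-5` and zero warmup steps match the settings used in this notebook.

```python
import math


def cosine_lr(step, total_steps, base_lr=5e-5, warmup_steps=0):
    """Learning rate at `step`: linear warmup, then half-cosine decay to zero."""
    if step < warmup_steps:
        # Linear ramp from 0 up to base_lr during warmup
        return base_lr * step / max(1, warmup_steps)
    # Fraction of the post-warmup schedule that has elapsed, in [0, 1]
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


total_steps = 1000  # hypothetical number of optimizer steps
for step in (0, 250, 500, 750, 1000):
    print(f"step {step:4d}: lr = {cosine_lr(step, total_steps):.2e}")
```

With no warmup the rate starts at its maximum, reaches half of it at the midpoint, and falls to zero on the final step, so a resumed run that rebuilds the scheduler restarts from the full learning rate.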

In [None]:
# PyTorch XLA-specific imports
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.debug.metrics as met

from torch.utils.data import DataLoader, Dataset
# Wraps a DataFrame as a PyTorch Dataset
class DebertaDataset(Dataset):    
    def __init__(self, tokenizer, max_length, data):
        super().__init__()
        self.tokenizer = tokenizer
        self.MAX_LENGTH = max_length
        self.data = data

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, index):
        # get data by index
        row = self.data.iloc[index]

        # clean and tokenize
        text = pre_processing_text(row['text'])
        inputs = self.tokenizer(text, truncation=True,
                                max_length=self.MAX_LENGTH,
                                padding="max_length",
                                )
        # Create ids, mask and target tensors
        ids = torch.tensor(inputs['input_ids'], dtype=torch.long)
        mask = torch.tensor(inputs['attention_mask'], dtype=torch.long)
        label = torch.tensor(row['label'], dtype=torch.float32)
    
        # return ids, mask, target for batch
        return {"input_ids" : ids,
                "attention_mask" : mask,
                "label" : label}

class DebertaModelTrainer(TrainModelTPU): 
    def __init__(self, model, train_data, **params):
        TrainModelTPU.__init__(self, model, train_data, **params)
        self.NUM_WORKERS = 8 
        self.NUM_LABELS = 1
        self.NUM_WARMUP_STEPS = 0 # Number of Warmup Steps
        self.STEPS = 50 # Display the progress each 50 steps
        self.DEVICE = xm.xla_device() # Get TPU device 
        
    # Load pretrained  LLM and tokenizer
    def load_model(self):
        start = time.time()
        MODEL_PATH = f"/kaggle/input/huggingfacedebertav3variants/{self.MODEL}"
        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, use_fast=False)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        # `bfloat16` is suitable for deep learning for better convergences during training
        base_model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH,
                                                                        num_labels=self.NUM_LABELS,
                                                                        torch_dtype=torch.bfloat16)
        # No idea why this is needed
        base_model.config.pretraining_tp = 1 # 1 is 7b
        # Assign Padding TOKEN
        base_model.config.pad_token_id = self.tokenizer.pad_token_id
        # print(base_model)
        
        # LoRa
        peft_config = LoraConfig(
            r=self.R,  # Use larger 'r' value increase more parameters during training
            lora_dropout=0.1, # reduce overfitting
            bias='none',
            inference_mode=False,
            task_type=TaskType.SEQ_CLS,
            # Only use the query and value projections
            target_modules=['query_proj', 'value_proj'],
        )
        # Continue training on previous epoch
        # Load the new PEFT model
        self.model = get_peft_model(base_model, peft_config)
        self.model = self.model.merge_and_unload()    
        print(f"Complete loading pretrained LLM model {time.time() - start:.1f} seconds")
    
    def create_data_loader(self, ds):
        try:
            # defining data samplers and loaders 
            data_sampler = torch.utils.data.distributed.DistributedSampler(
                                    ds,
                                    num_replicas=xm.xrt_world_size(), # tell PyTorch how many devices (TPU cores) we are using for training
                                    rank=xm.get_ordinal(), # tell PyTorch which device (core) we are on currently
                                    shuffle=True)
            data_loader = torch.utils.data.DataLoader(ds,
                                                      batch_size=self.BATCH_SIZE,
                                                      sampler=data_sampler,
                                                      drop_last=True,
                                                      num_workers=self.NUM_WORKERS)
            return data_loader
        except Exception as err:
            print("Something went wrong while creating the data loader")
            print(f"Unexpected Error {err}, {type(err)}")
            sys.exit(-1) 
    
    def load_dataset(self):
        print(f"Total number of train data = {len(self.train_data)}")
        train_data = self.train_data
        # Preprocess the text
        train_data['text'] = train_data['text'].map(pre_processing_text)
        
        # Split the training dataset into training and test data
        X = train_data['text'].tolist()
        y = train_data['label'].tolist()
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=600, random_state=SEED)
        train_df = pd.DataFrame({'text': X_train, 'label': y_train})
        valid_df = pd.DataFrame({'text': X_test, 'label': y_test})
        # Create the datasets, passing the tokenizer as an argument
        train_ds = DebertaDataset(self.tokenizer, self.MAX_LENGTH, train_df)
        valid_ds = DebertaDataset(self.tokenizer, self.MAX_LENGTH, valid_df)
        # print(train_tokenized_ds)
        return train_ds, valid_ds
    
    # Evaluate the model with valid dataset and return the loss on testing (valid) dataset
    def eval_loop(self, valid_loader, loss_fn):
        start = time.time()
        self.model.eval() # put model in eval mode 
        eval_targets = []
        eval_outputs = []
        for bi, input_data in enumerate(valid_loader): # enumerate through dataloader
            # put valid data tensors onto TPU device
            for k,v in input_data.items():
                input_data[k] = v.to(self.DEVICE)

            # pass ids to model
            with torch.no_grad(): 
                outputs = self.model(input_ids=input_data['input_ids'],
                                     attention_mask=input_data['attention_mask'])

            # Add the outputs and targets to a list 
            targets = input_data['label'].cpu().detach().tolist()
            outputs = outputs['logits'].cpu().detach().tolist()
            eval_targets.extend(targets)
            eval_outputs.extend(outputs)    
            del targets, outputs
            gc.collect() # delete for memory conservation

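        # calculate loss; flatten both tensors and cast to float so shapes and dtypes match
        loss = loss_fn(torch.tensor(eval_outputs, dtype=torch.float32).view(-1),
                       torch.tensor(eval_targets, dtype=torch.float32).view(-1))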
        # since the loss is on all 8 cores, reduce the loss values and print the average
        loss_reduced = xm.mesh_reduce('loss_reduce',loss, lambda x: sum(x) / len(x)) 
        # print valid loss
        xm.master_print(f'Complete evaluation. Valid loss={loss_reduced} in {time.time() - start:.2f} seconds.')
        return loss_reduced
        
    def train_loop(self, train_loader, optimizer, lr_scheduler, loss_fn):
        self.model.train() # put model in training mode
        # A loop for training the model
        for step, input_data in enumerate(train_loader): # enumerate through the dataloader
            start = time.time()
            # Move all training data tensors onto TPU device
            for k,v in input_data.items():
                input_data[k] = v.to(self.DEVICE)
            # Set gradient to 0
            optimizer.zero_grad()
            # Pass input data
            outputs = self.model(input_ids=input_data['input_ids'],
                                 attention_mask=input_data['attention_mask'])
            logits = outputs.logits.squeeze()
            labels = input_data['label']
            # Compute the loss
            loss = loss_fn(logits, labels)
            # Reduce (aggregate) all the losses on 8 cores
            if step % self.STEPS == 0:
                # since the loss is on all 8 cores, reduce the loss values and print the average
                loss_reduced = xm.mesh_reduce('loss_reduce', loss, lambda x: sum(x) / len(x)) 
                # print will only print once (not from all 8 cores)
                xm.master_print(f'step={step}, loss={loss_reduced}')
            # backpropagate
            loss.backward()
            # Use PyTorch XLA optimizer stepping
            xm.optimizer_step(optimizer)
            lr_scheduler.step()
            xm.master_print(f"=== Finish training loop step = {step} in {time.time() - start:.2f} seconds.")
         
    # Train the model 
    def train_model(self):
        train_dataset, valid_dataset = self.load_dataset()
        # Create Data loader 
        train_loader = self.create_data_loader(train_dataset)
        valid_loader = self.create_data_loader(valid_dataset)
        # PyTorch XLA-specific dataloading
        train_loader = pl.MpDeviceLoader(train_loader, self.DEVICE) 
        valid_loader = pl.MpDeviceLoader(valid_loader, self.DEVICE)
        # use the xmp.MpModelWrapper from PyTorch XLA to save memory when initializing the model
        mx = xmp.MpModelWrapper(self.model)
        self.model = mx.to(self.DEVICE) # put model onto the current TPU core
        print('Complete loading model and dataloader')        
        N_SAMPLES = len(train_dataset)
        STEPS_PER_EPOCH = N_SAMPLES // self.BATCH_SIZE
        NUM_TRAINING_STEPS = STEPS_PER_EPOCH * self.NUM_EPOCHS
        print(f'BATCH_SIZE: {self.BATCH_SIZE}, N_SAMPLES: {N_SAMPLES}, '
              f'STEPS_PER_EPOCH: {STEPS_PER_EPOCH}, NUM_TRAINING_STEPS: {NUM_TRAINING_STEPS}')        
        # ----------------------------------------------------------------
        # Optimizer (AdamW)
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.LR, weight_decay=0.01)
        # Cosine Learning Rate With Warmup
        lr_scheduler = transformers.get_cosine_schedule_with_warmup(
                                        optimizer=optimizer,
                                        num_warmup_steps=self.NUM_WARMUP_STEPS,
                                        num_training_steps=NUM_TRAINING_STEPS)
        # Loss function using basic Binary Cross Entropy
        loss_fn = torch.nn.BCEWithLogitsLoss().to(dtype=torch.float32)
        eval_scores = []
        # Train the model on epochs
        for epoch in range(self.NUM_EPOCHS):
            gc.collect() # I call gc.collect() frequently to hopefully prevent OOM problems
            # Train the model
            self.train_loop(train_loader, optimizer, lr_scheduler, loss_fn)
            gc.collect()
            # call evaluation loop:
            eval_score = self.eval_loop(valid_loader, loss_fn)
            gc.collect()
            eval_scores.append(eval_score)
        avg_score = np.mean(eval_scores)
        print(f'\n=== Finish Training ===\n'
              f'Average validation loss = {avg_score:.5f}')
        return avg_score       
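
A rough way to check the step count computed in `train_model` above: `create_data_loader` shards the data with a `DistributedSampler` and drops the last incomplete batch, so, assuming the training function runs once per TPU core, each core performs fewer optimizer steps than `N_SAMPLES // BATCH_SIZE`. The sketch below is plain Python and not part of the trainer; `steps_per_core` and the dataset size are hypothetical, chosen only for illustration.

```python
import math

def steps_per_core(n_samples, batch_size, world_size, num_epochs=1):
    """Optimizer steps each core runs when the data is sharded across `world_size` cores."""
    # DistributedSampler pads the index list so every replica gets the same number of samples
    samples_per_core = math.ceil(n_samples / world_size)
    # drop_last=True in the DataLoader discards the final incomplete batch on each core
    return (samples_per_core // batch_size) * num_epochs

N_SAMPLES = 100_000   # hypothetical dataset size
BATCH_SIZE = 16       # per-core batch size used for Mistral-7b in this notebook

print("single process:", steps_per_core(N_SAMPLES, BATCH_SIZE, world_size=1))
print("8 TPU cores:   ", steps_per_core(N_SAMPLES, BATCH_SIZE, world_size=8))
```

If a run really is launched on 8 cores, using the per-core count as `num_training_steps` would let the cosine schedule reach the end of its decay. This is something to verify against the actual launch setup, not a change the notebook currently makes.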

# Use Optuna to find the optimal hyper-parameters
Train and Save the model with best parameters 

In [None]:
display(train_data.head(3))

In [None]:
# Start an Optuna study for hyper-parameter tuning
# `train_model()` returns the average validation loss, so lower is better
best_score = float('inf')
# Find the optimal learning rate
def objective(trial, model_name, train_data):
    # Parameters
    params = {
        'lr': trial.suggest_float('learning_rate', 1e-7, 1e-3, log=True),
        'r': 64, # Default: 64
        'num_epochs': 1,
    }
    # Create a trainer
    trainer = TrainModelTPU(model_name, train_data, **params)
    trainer.load_model() # Load the pretrained LLM and tokenizer
    eval_score = trainer.train_model()
    # Save the model if its validation loss beats the current best
    global best_score
    if eval_score < best_score:
        best_score = eval_score
        trainer.save_model()
    # Clean up
    trainer.clear_memory()
    del trainer
    return eval_score  # Minimize the average validation loss

def train_model_with_optuna(model_name, train_data):
    # Create a study to find the optimal hyper-parameters
    study_name = f"{model_name}_study"
    study_file = f"/kaggle/working/{study_name}.db"
    # Delete the study file if it exists
    if os.path.isfile(study_file):
        os.remove(study_file)

    study = optuna.create_study(direction="minimize", study_name=study_name,
                                storage=f"sqlite:///{study_file}", # Storage path of the database keeping the study results
                                load_if_exists=False) # True: resume the study, False: create a new one
    # Set a timeout to avoid running out of quota
    study.optimize(lambda trial: objective(trial, model_name, train_data),
                   timeout=600, n_jobs=1, n_trials=10,
                   show_progress_bar=True, gc_after_trial=True)
    print(f"Best parameters: {study.best_params}")
    return study.best_params # Return the optimal parameters
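
The learning-rate search space in `objective` spans four orders of magnitude (`1e-7` to `1e-3`) with `log=True`, which asks for values spread evenly across decades rather than evenly across the raw range. The sketch below illustrates only that log-uniform idea in plain Python. It is not Optuna's sampler (which also uses the results of earlier trials), and `sample_log_uniform` is a hypothetical helper.

```python
import math
import random

def sample_log_uniform(low, high, rng):
    """Draw a value whose logarithm is uniform on [log(low), log(high)]."""
    return math.exp(rng.uniform(math.log(low), math.log(high)))

rng = random.Random(42)  # fixed seed, only to make the illustration repeatable
samples = [sample_log_uniform(1e-7, 1e-3, rng) for _ in range(10_000)]

# Share of samples per decade: each of the four decades should receive roughly a quarter
for exponent in range(-7, -3):
    lo, hi = 10.0 ** exponent, 10.0 ** (exponent + 1)
    share = sum(lo <= s < hi for s in samples) / len(samples)
    print(f"[{lo:.0e}, {hi:.0e}): {share:.1%}")
```

With a plain uniform draw on the same range, about 90% of the samples would land in the top decade, so very small learning rates would almost never be tried.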

# Train the model with best parameters
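
The trainer pairs AdamW with `transformers.get_cosine_schedule_with_warmup`. To see how the learning rate moves during a run, here is a minimal plain-Python sketch of the multiplier such a schedule applies to the base rate: a linear ramp during warmup, then half a cosine wave down to zero. The function name `cosine_with_warmup` and the step counts are hypothetical, and the formula reflects my understanding of the schedule rather than a copy of the library code.

```python
import math

def cosine_with_warmup(step, num_warmup_steps, num_training_steps, num_cycles=0.5):
    """Multiplier applied to the base learning rate at a given optimizer step."""
    if step < num_warmup_steps:
        # Linear warmup from 0 to 1
        return step / max(1, num_warmup_steps)
    # Fraction of the post-warmup schedule already completed, in [0, 1]
    progress = (step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
    # Half a cosine wave: 1 at progress=0, 0 at progress=1
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * 2.0 * num_cycles * progress)))

BASE_LR = 5e-5       # learning rate used in this notebook
WARMUP_STEPS = 100   # hypothetical value
TOTAL_STEPS = 1000   # hypothetical value

for step in (0, 50, 100, 550, 1000):
    lr = BASE_LR * cosine_with_warmup(step, WARMUP_STEPS, TOTAL_STEPS)
    print(f"step={step:4d}  lr={lr:.2e}")
```

The rate only decays fully if the schedule is stepped `num_training_steps` times, so the total step count passed to the scheduler should match the number of optimizer steps actually run.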

In [None]:
def train_model(model_name, train_data):
    # Parameters
    params = {
        'lr': 5e-5, # learning rate
        'r': 64, # Lora's r value (Default: 64)
        'num_epochs': 1, # number of epochs 
        'batch_size': 128,
        'max_length': 1024
    }
    if model_name == 'mistral_7b':
        params['batch_size'] = 16
        # Create a trainer
        trainer = TrainModelTPU(model_name, train_data, **params)
    elif 'deberta' in model_name:
        params['max_length'] = 512
        trainer = DebertaModelTrainer(model_name, train_data, **params)
    else:
        raise ValueError(f"Unsupported model name: {model_name}")
    trainer.load_model() # Load the pretrained LLM and tokenizer
    eval_score = trainer.train_model()
    # Save the fine-tuned model
    trainer.save_model()
    # Clean up
    trainer.clear_memory()
    del trainer
    print(f"=== Finish training the model {model_name} with score = {eval_score}")

In [None]:
model_name = "mistral_7b" # or "deberta-v3-small"
train_model(model_name, train_data)