# Install and import libraries

In [None]:
# Install libs
# NOTE(review): sentencepiece and datasets are unpinned while the other packages
# are pinned, so a rerun may pick up newer releases that are incompatible with them.
# NOTE(review): bitsandbytes is installed, but no quantization config is used in the
# visible code (the model is loaded without quantization).
!pip install -qq sentencepiece
!pip install -qq datasets
!pip install -qq peft==0.9.0
!pip install -qq bitsandbytes==0.41.1
!pip install -qq accelerate==0.27.2
!pip install -qq transformers==4.42.3
# CPU build of torch 2.1 plus the matching torch_xla 2.1 TPU wheels (from the libtpu release index).
!pip install -qq torch~=2.1.0 --index-url https://download.pytorch.org/whl/cpu -q 
!pip install -qq torch_xla[tpu]~=2.1.0 -f https://storage.googleapis.com/libtpu-releases/index.html -q
!pip uninstall -qq tensorflow -y # If we don't do this, TF will take over TPU and cause permission error for PT
!cp /kaggle/input/utils-xla/spmd_util.py . # From this repo: https://github.com/HeegyuKim/torch-xla-SPMD

In [None]:
import os
import gc
import re
from time import time
import random
import warnings
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

import torch
import transformers
import torch.nn.functional as F

import torch_xla.debug.profiler as xp
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs
import torch_xla.runtime as xr

# Switch torch_xla into SPMD mode. This is done here, among the imports, so that it
# runs before any XLA device or tensor is created.
# NOTE(review): the Mesh / partition_module imports below are never used in the
# visible code, so no sharding annotations are applied to the model.
xr.use_spmd()

from torch_xla.experimental.xla_sharded_tensor import XLAShardedTensor
from torch_xla.experimental.xla_sharding import Mesh
from spmd_util import partition_module
from torch.cuda.amp import autocast

# Register tqdm's progress-bar helpers (e.g. `progress_apply`) on pandas objects.
tqdm.pandas()

print('Torch Version:', torch.__version__)

In [None]:
import copy
from dataclasses import dataclass
from datasets import Dataset
from transformers import (
    BitsAndBytesConfig,
    Gemma2ForSequenceClassification,
    GemmaTokenizerFast,
    Gemma2Config,
    PreTrainedTokenizerBase, 
    EvalPrediction,
    Trainer,
    TrainingArguments,
    DataCollatorWithPadding,
)
from peft import LoraConfig, get_peft_model, TaskType
from sklearn.metrics import log_loss, accuracy_score

In [None]:
@dataclass
class Config:
    """Paths and hyper-parameters for the LoRA fine-tuning run.

    ``lora_alpha`` defaults to ``2 * lora_r``. It is resolved in ``__post_init__``,
    so overriding ``lora_r`` keeps that ratio unless ``lora_alpha`` is passed
    explicitly.

    Raises:
        ValueError: if ``n_splits < 2`` or ``fold_idx`` is not in ``[0, n_splits)``.
    """

    output_dir: str = "output"
    # Gemma-2 2B instruction-tuned checkpoint (Kaggle Models mount).
    checkpoint: str = '/kaggle/input/gemma-2/transformers/gemma-2-2b-it/2'
    max_length: int = 1024
    n_splits: int = 5
    fold_idx: int = 0  # index of the fold held out for evaluation
    optim_type: str = "adamw_torch_xla"
    per_device_train_batch_size: int = 8
    gradient_accumulation_steps: int = 1
    per_device_eval_batch_size: int = 8
    n_epochs: int = 1
    # LoRA adapters are only added to decoder layers whose index is >= freeze_layers.
    # The total layer count depends on the checkpoint (see its config.json).
    freeze_layers: int = 16
    lr: float = 5e-5
    warmup_steps: int = 128
    lora_r: int = 4
    # None means "2 * lora_r". It is computed per instance, so that Config(lora_r=8)
    # yields 16 instead of the class-level value.
    lora_alpha: "float | None" = None
    lora_dropout: float = 0.05
    lora_bias: str = "none"

    def __post_init__(self) -> None:
        """Validate the fold settings and resolve the default ``lora_alpha``."""
        if self.n_splits < 2:
            raise ValueError(f"n_splits must be >= 2, got {self.n_splits}")
        if not 0 <= self.fold_idx < self.n_splits:
            raise ValueError(
                f"fold_idx must be in [0, {self.n_splits}), got {self.fold_idx}"
            )
        if self.lora_alpha is None:
            self.lora_alpha = self.lora_r * 2


config = Config()

In [None]:
# Hugging Face Trainer arguments. Paths and hyper-parameters come from `config`,
# so the Config dataclass remains the single place to change them.
training_args = TrainingArguments(
    output_dir=config.output_dir,
    overwrite_output_dir=True,
    report_to="none",
    num_train_epochs=config.n_epochs,
    per_device_train_batch_size=config.per_device_train_batch_size,
    gradient_accumulation_steps=config.gradient_accumulation_steps,
    per_device_eval_batch_size=config.per_device_eval_batch_size,
    logging_steps=10,
    eval_strategy="epoch",
    save_strategy="steps",
    save_steps=200,
    optim=config.optim_type,
    # The model weights are loaded directly in bfloat16, so fp16 mixed precision stays off.
    fp16=False,
    learning_rate=config.lr,
    warmup_steps=config.warmup_steps,
    # Informational: the TrainingArguments docs describe this value as passed by the
    # TPU launcher script. The XLA device itself is picked when torch_xla is available.
    # (The deprecated `no_cuda=False` was dropped: it is the default anyway.)
    tpu_num_cores=8,
)

In [None]:
# Read the decoder depth from the checkpoint's own config instead of hard-coding it;
# the layer count differs between Gemma-2 model sizes.
num_hidden_layers = Gemma2Config.from_pretrained(config.checkpoint).num_hidden_layers
if config.freeze_layers >= num_hidden_layers:
    raise ValueError(
        f"freeze_layers={config.freeze_layers} leaves no layer to adapt "
        f"(checkpoint has {num_hidden_layers} layers)"
    )

lora_config = LoraConfig(
    r=config.lora_r,
    lora_alpha=config.lora_alpha,
    # only target self-attention
    target_modules=["q_proj", "k_proj", "v_proj"],
    # Adapters only on layers [freeze_layers, num_hidden_layers); earlier layers stay frozen.
    layers_to_transform=list(range(config.freeze_layers, num_hidden_layers)),
    lora_dropout=config.lora_dropout,
    bias=config.lora_bias,
    task_type=TaskType.SEQ_CLS,
)

In [None]:
# Tokenizer loaded from the same checkpoint as the model.
tokenizer = GemmaTokenizerFast.from_pretrained(config.checkpoint)
# Padding goes on the right-hand side of each sequence.
tokenizer.padding_side = "right"
# Append <eos> to every encoded sequence.
tokenizer.add_eos_token = True

In [None]:
# Load the model without quantization, with bfloat16 weights, on the host CPU first.
# It is moved to the XLA device explicitly below, so no `device_map` is passed:
# accelerate's automatic placement targets CUDA/CPU/disk and may attach offload
# hooks that conflict with a manual `.to(xla_device)`.
# `low_cpu_mem_usage=True` keeps the memory-saving load path that device_map implied.
model = Gemma2ForSequenceClassification.from_pretrained(
    config.checkpoint,
    num_labels=3,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
)

# The generation KV cache is not needed when training a classifier.
model.config.use_cache = False
model = get_peft_model(model, lora_config)

# Move the model to TPU
device = xm.xla_device()
model.to(device)

In [None]:
# Print PEFT's summary of how many parameters remain trainable after adding LoRA.
model.print_trainable_parameters()

In [None]:
# Demo subset size; set to None to use every row of the training CSV.
N_DEMO_ROWS = 100

ds = Dataset.from_csv("/kaggle/input/llm-classification-finetuning/train.csv")
if N_DEMO_ROWS is not None:
    # A plain `range` is one of the index types Dataset.select documents (unlike a
    # torch tensor). min() avoids an IndexError if the CSV has fewer rows.
    ds = ds.select(range(min(N_DEMO_ROWS, len(ds))))

In [None]:
from typing import List, Dict
from transformers import PreTrainedTokenizerBase

class CustomTokenizer:
    """Batched ``Dataset.map`` function that builds the classifier inputs.

    For each row it does three things:
    - decodes the prompt and the two responses;
    - joins them into one text tagged with ``<prompt>:``, ``<response_a>:`` and
      ``<response_b>:`` (sections separated by blank lines);
    - tokenizes that text to a fixed ``max_length`` and attaches an integer label:
      0 = model A wins, 1 = model B wins, 2 = tie.
    """

    def __init__(
        self,
        tokenizer: "PreTrainedTokenizerBase",
        max_length: int
    ) -> None:
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __call__(self, batch: dict) -> dict:
        """Tokenize one batch of rows.

        Args:
            batch: mapping of column name to a list of values. Reads ``prompt``,
                ``response_a``, ``response_b``, ``winner_model_a`` and
                ``winner_model_b``.

        Returns:
            The tokenizer's output fields plus a ``labels`` list.
        """
        prompt = ["<prompt>: " + self.process_text(t) for t in batch["prompt"]]
        response_a = ["\n\n<response_a>: " + self.process_text(t) for t in batch["response_a"]]
        response_b = ["\n\n<response_b>: " + self.process_text(t) for t in batch["response_b"]]
        texts = [p + r_a + r_b for p, r_a, r_b in zip(prompt, response_a, response_b)]

        # NOTE(review): truncation applies to the concatenated text, so a long prompt
        # or response_a can push response_b entirely out of the window.
        # Padding every row to max_length keeps shapes static, which XLA favours.
        # Plain lists are returned: Dataset.map stores them in Arrow either way.
        tokenized = self.tokenizer(
            texts,
            max_length=self.max_length,
            truncation=True,
            padding="max_length",
        )

        labels = [
            0 if a_win else 1 if b_win else 2
            for a_win, b_win in zip(batch["winner_model_a"], batch["winner_model_b"])
        ]
        return {**tokenized, "labels": labels}

    @staticmethod
    def process_text(text: str) -> str:
        """Turn a JSON-encoded list of conversation turns into plain text.

        ``null`` entries become empty strings and the turns are joined with a
        space. Input that is not a JSON list is returned unchanged. ``None`` and
        a bare ``null`` become "". (The previous ``str.replace("null", "")``
        also erased "null" inside real words such as "nullable".)
        """
        import json

        if text is None:
            return ""
        try:
            parsed = json.loads(text)
        except (TypeError, ValueError):
            return str(text)
        if parsed is None:
            return ""
        if not isinstance(parsed, list):
            return str(text)
        joined = " ".join("" if part is None else str(part) for part in parsed)
        # JSON \u escapes can decode to lone UTF-16 surrogates, which cannot be
        # UTF-8 encoded; replace them so the tokenizer receives valid text.
        return joined.encode("utf-8", errors="replace").decode("utf-8")

In [None]:
# Tokenize every row and attach labels. map() keeps the original text columns
# next to the new ones.
encode = CustomTokenizer(tokenizer=tokenizer, max_length=config.max_length)
ds = ds.map(function=encode, batched=True)

In [None]:
def compute_metrics(eval_preds: EvalPrediction) -> dict:
    """Compute accuracy and multi-class log loss for the Trainer evaluation loop.

    Args:
        eval_preds: ``predictions`` holds the raw logits. If the Trainer hands over
            a tuple, the logits are assumed to be its first element.
            ``label_ids`` holds the integer class labels.

    Returns:
        ``{"acc": float, "log_loss": float}``.
    """
    preds = eval_preds.predictions
    if isinstance(preds, tuple):
        preds = preds[0]
    labels = eval_preds.label_ids
    probs = torch.from_numpy(preds).float().softmax(-1).numpy()
    # Pass the full label set: a small eval fold may not contain every class, and
    # log_loss then cannot match y_true to the columns of y_pred and raises.
    class_ids = list(range(probs.shape[-1]))
    loss = log_loss(y_true=labels, y_pred=probs, labels=class_ids)
    acc = accuracy_score(y_true=labels, y_pred=preds.argmax(-1))
    return {"acc": acc, "log_loss": loss}

In [None]:
# Deterministic round-robin K-fold split: row i is held out in fold i % n_splits.
# Each entry is a (train_indices, eval_indices) pair.
folds = [
    (
        [row for row in range(len(ds)) if row % config.n_splits != k],
        list(range(k, len(ds), config.n_splits)),
    )
    for k in range(config.n_splits)
]

In [None]:
train_idx, eval_idx = folds[config.fold_idx]

trainer = Trainer(
    args=training_args,
    model=model,
    tokenizer=tokenizer,
    train_dataset=ds.select(train_idx),
    eval_dataset=ds.select(eval_idx),
    compute_metrics=compute_metrics,
    data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
)

# Train the model. Trainer.train() runs its own loop (forward, loss, backward,
# optimizer step), so no hand-written step function is needed here. The weights
# were loaded in bfloat16, so no autocast wrapper is used.
trainer.train()