RTX 6000 Ada <br>
Training takes approximately 16 hours on this GPU.

In [None]:
# Gemma-2 classes (Gemma2ForSequenceClassification, Gemma2Config) need transformers>=4.42.3.
# bitsandbytes / accelerate / peft back the 4-bit checkpoint, device_map="auto" and LoRA used below.
!pip install -U "transformers>=4.42.3" bitsandbytes accelerate peft

In [None]:
# datasets: CSV loading and Dataset.map; scikit-learn: log_loss / accuracy_score metrics.
!pip install -qq datasets
!pip install -qq scikit-learn

In [None]:
import ast
import copy
import os
from dataclasses import dataclass
from typing import Optional

import numpy as np
import torch
from datasets import Dataset
from transformers import (
    BitsAndBytesConfig,
    Gemma2ForSequenceClassification,
    GemmaTokenizerFast,
    Gemma2Config,
    PreTrainedTokenizerBase,
    EvalPrediction,
    Trainer,
    TrainingArguments,
    DataCollatorWithPadding,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, TaskType
from sklearn.metrics import log_loss, accuracy_score

### Configurations

In [None]:
@dataclass
class Config:
    """Hyper-parameters shared by the data pipeline, LoRA setup and Trainer cells."""

    output_dir: str = "output"
    checkpoint: str = "unsloth/gemma-2-9b-it-bnb-4bit"
    max_length: int = 3120  # token budget per example (prompt + both responses)
    n_splits: int = 10  # number of round-robin folds over the competition train set
    fold_idx: int = 0  # which fold is held out for evaluation
    # NOTE(review): "adamw_hf" is transformers' legacy AdamW; newer transformers
    # releases deprecate it in favour of "adamw_torch" - confirm before upgrading.
    optim_type: str = "adamw_hf"
    per_device_train_batch_size: int = 2
    gradient_accumulation_steps: int = 6
    per_device_eval_batch_size: int = 8
    n_epochs: int = 1
    freeze_layers: int = 0  # decoder layers with index < freeze_layers get no LoRA adapters
    lr: float = 1e-4
    warmup_steps: int = 20
    lora_r: int = 128
    # None means "same as lora_r". It is resolved in __post_init__ so overriding
    # lora_r (e.g. Config(lora_r=64)) keeps alpha tied to it. A class-level
    # `lora_r * 1` default is evaluated once and would stay 128 regardless.
    lora_alpha: Optional[float] = None
    lora_dropout: float = 0
    lora_bias: str = "none"

    def __post_init__(self) -> None:
        """Default ``lora_alpha`` to the (possibly overridden) ``lora_r``."""
        if self.lora_alpha is None:
            self.lora_alpha = self.lora_r


config = Config()

#### Training Arguments

In [None]:
# Stage-1 training arguments; batching, optimiser and schedule come from `config`.
training_args_0 = TrainingArguments(
    # Output / logging
    output_dir="output-first",
    overwrite_output_dir=True,
    report_to="none",
    logging_steps=10,
    # Schedule and batching (accumulation multiplies the per-device train batch)
    num_train_epochs=config.n_epochs,
    per_device_train_batch_size=config.per_device_train_batch_size,
    gradient_accumulation_steps=config.gradient_accumulation_steps,
    per_device_eval_batch_size=config.per_device_eval_batch_size,
    # Optimiser and precision
    optim=config.optim_type,
    learning_rate=config.lr,
    warmup_steps=config.warmup_steps,
    fp16=True,
    # Evaluate once per epoch; checkpoint every 1800 steps
    eval_strategy="epoch",
    save_strategy="steps",
    save_steps=1800,
)

#### LoRA config

In [None]:
# LoRA adapter configuration.
# The decoder depth is read from the checkpoint's config instead of being
# hard-coded (the previous literal 42 matches gemma-2-9b). This keeps
# layers_to_transform covering every layer >= config.freeze_layers if
# config.checkpoint is switched to another Gemma-2 size.
lora_config = LoraConfig(
    r=config.lora_r,
    lora_alpha=config.lora_alpha,
    # Attention projections (q/k/v/o) plus the MLP gate projection.
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj"],
    # Layers below config.freeze_layers get no adapters; max(..., 0) keeps a
    # negative value meaning "adapt every layer", as before.
    layers_to_transform=list(
        range(
            max(config.freeze_layers, 0),
            Gemma2Config.from_pretrained(config.checkpoint).num_hidden_layers,
        )
    ),
    lora_dropout=config.lora_dropout,
    bias=config.lora_bias,
    task_type=TaskType.SEQ_CLS,
)

### Instantiate the tokenizer & model

In [None]:
# Tokenizer matching the checkpoint, configured for right padding and a
# trailing <eos> on every encoded sequence.
tokenizer = GemmaTokenizerFast.from_pretrained(config.checkpoint)
tokenizer.padding_side = "right"
tokenizer.add_eos_token = True  # append <eos> when encoding

In [None]:
# Load the checkpoint as a 3-label sequence classifier (labels 0/1/2, as built
# by CustomTokenizer), letting accelerate place the weights.
model = Gemma2ForSequenceClassification.from_pretrained(
    config.checkpoint,
    device_map="auto",
    torch_dtype=torch.float16,
    num_labels=3,
)
# No KV cache is needed while training.
model.config.use_cache = False
# Prepare the k-bit model for training, then wrap it with the LoRA adapters.
model = get_peft_model(prepare_model_for_kbit_training(model), lora_config)
model  # notebook display of the wrapped model

In [None]:
# Report how many parameters the LoRA wrapping leaves trainable vs. the total.
model.print_trainable_parameters()

### Instantiate the dataset

In [None]:
class CustomTokenizer:
    r"""``Dataset.map(batched=True)`` callable that turns arena rows into model inputs.

    The ``prompt``, ``response_a`` and ``response_b`` columns hold string-encoded
    lists of conversation turns (e.g. ``'["hi", null]'``). A row becomes
    ``"<prompt>: " + turns`` followed by ``"\n\n<response_a>: " + turns`` and
    ``"\n\n<response_b>: " + turns``, where each turn is written as
    ``"\n<tag> <i>: <text>"``.

    When that text is longer than ``max_length`` tokens, each section gets a
    token budget proportional to its share of the total length. Within a
    section, every turn is likewise cut to its proportional share.

    Labels: 0 if ``winner_model_a`` is truthy, else 1 if ``winner_model_b`` is
    truthy, else 2 (presumably a tie).
    """

    class _NullToEmpty(ast.NodeTransformer):
        """Rewrite the bare name ``null`` to ``""``; reject any other name."""

        def visit_Name(self, node: ast.Name) -> ast.AST:
            if node.id == "null":
                return ast.copy_location(ast.Constant(value=""), node)
            raise ValueError(f"unexpected name {node.id!r} in list literal")

    def __init__(self, tokenizer: "PreTrainedTokenizerBase", max_length: int) -> None:
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __call__(self, batch: dict) -> dict:
        """Tokenize one batch and attach labels.

        Args:
            batch: column name -> list of values, as passed by
                ``Dataset.map(batched=True)``. Must contain prompt, response_a,
                response_b, winner_model_a and winner_model_b.

        Returns:
            The tokenizer's output for the batch plus a ``labels`` list.
        """
        texts = []
        for raw_p, raw_a, raw_b in zip(batch["prompt"], batch["response_a"], batch["response_b"]):
            p = "<prompt>: " + self.process_prompt(raw_p)
            r_a = "\n\n<response_a>: " + self.process_response(raw_a)
            r_b = "\n\n<response_b>: " + self.process_response(raw_b)
            full_text = p + r_a + r_b

            if len(self.tokenizer.tokenize(full_text)) <= self.max_length:
                texts.append(full_text)
                continue

            # Too long: budget each section in proportion to its token count.
            len_p = len(self.tokenizer.tokenize(p))
            len_r_a = len(self.tokenizer.tokenize(r_a))
            len_r_b = len(self.tokenizer.tokenize(r_b))
            total = len_p + len_r_a + len_r_b
            keep_p = int((len_p / total) * self.max_length)
            keep_r_a = int((len_r_a / total) * self.max_length)
            keep_r_b = int((len_r_b / total) * self.max_length)

            texts.append(
                self.truncate_parts(raw_p, keep_p, "<prompt>: ", "prompt")
                + self.truncate_parts(raw_a, keep_r_a, "\n\n<response_a>: ", "response")
                + self.truncate_parts(raw_b, keep_r_b, "\n\n<response_b>: ", "response")
            )

        # Hard truncation to max_length stays as a safety net (decode/re-tokenize
        # round trips and special tokens can still add a few tokens).
        tokenized = self.tokenizer(texts, max_length=self.max_length, truncation=True, padding=False)

        labels = [
            0 if a_win else 1 if b_win else 2
            for a_win, b_win in zip(batch["winner_model_a"], batch["winner_model_b"])
        ]
        return {**tokenized, "labels": labels}

    def truncate_parts(self, parts_text, max_length_for_segment, prefix, tag):
        """Rebuild one section so that it fits a token budget.

        Args:
            parts_text: raw list-literal string holding the section's turns.
            max_length_for_segment: token budget for the whole section, prefix included.
            prefix: section header emitted verbatim in front of the turns.
            tag: turn label ("prompt" or "response").

        Returns:
            ``prefix`` followed by the turns. Each turn is cut to its proportional
            share of the budget left after the prefix, decoded back to text, and
            the turns are joined with single spaces.
        """
        parts = self._parse_parts(parts_text)
        tokenized_parts = [
            self.tokenizer.tokenize(f"\n{tag} {i+1}: {part}") for i, part in enumerate(parts)
        ]

        # The prefix is always emitted in full, so only the remainder goes to the
        # turns. Otherwise every section overshoots by its prefix length and the
        # final hard truncation silently removes the tail of response_b.
        budget = max(max_length_for_segment - len(self.tokenizer.tokenize(prefix)), 0)
        lengths = [len(tokens) for tokens in tokenized_parts]
        total = sum(lengths)
        if total > 0:
            shares = [(n / total) * budget for n in lengths]
        else:
            shares = [budget / len(parts) for _ in lengths]

        pieces = []
        for tokens, share in zip(tokenized_parts, shares):
            ids = self.tokenizer.convert_tokens_to_ids(tokens[: int(share)])
            pieces.append(self.tokenizer.decode(ids))
        return prefix + " ".join(pieces)

    @classmethod
    def _parse_parts(cls, text: str):
        """Parse a stored list literal, mapping the bare name ``null`` to ``""``.

        This replaces ``eval(text, {"null": ""})``. ``eval`` executes arbitrary
        expressions from the dataset file, because builtins stay reachable through
        its globals. Only Python literals plus ``null`` are accepted here; any
        other expression raises ``ValueError`` (malformed text: ``SyntaxError``).
        """
        tree = ast.parse(text, mode="eval")
        return ast.literal_eval(cls._NullToEmpty().visit(tree))

    def process_prompt(self, text: str) -> str:
        r"""Format a stored prompt list as ``"\nprompt 1: ...\nprompt 2: ..."``."""
        parts = self._parse_parts(text)
        return "".join(f"\nprompt {i+1}: {part}" for i, part in enumerate(parts))

    def process_response(self, text: str) -> str:
        r"""Format a stored response list as ``"\nresponse 1: ...\nresponse 2: ..."``."""
        parts = self._parse_parts(text)
        return "".join(f"\nresponse {i+1}: {part}" for i, part in enumerate(parts))

In [None]:
# Batched map function: tokenizer outputs plus labels, capped at config.max_length tokens.
encode = CustomTokenizer(max_length=config.max_length, tokenizer=tokenizer)

In [None]:
# Extra labelled conversations (stage-1 training data), encoded with `encode`.
ds_extra = Dataset.from_csv(
    "my_input/lmsys-additional-33k-labelled-conversations/lmsys-33k-deduplicated.csv"
).map(encode, batched=True)

In [None]:
# Competition train set, encoded with `encode`; split into folds further down.
ds = Dataset.from_csv("my_input/lmsys-chatbot-arena/train.csv").map(encode, batched=True)

### Compute metrics

We compute the log-loss used on the leaderboard (LB), with accuracy as an auxiliary metric.

In [None]:
def compute_metrics(eval_preds: EvalPrediction) -> dict:
    """Compute leaderboard log-loss and accuracy from Trainer evaluation output.

    Args:
        eval_preds: predictions and label ids gathered by the Trainer. The
            predictions are assumed to be raw logits of shape
            (n_samples, n_classes) - TODO confirm if the model output changes.

    Returns:
        ``{"acc": accuracy, "log_loss": multi-class log-loss of softmax probabilities}``.
    """
    logits = eval_preds.predictions
    labels = eval_preds.label_ids
    probs = torch.from_numpy(logits).float().softmax(-1).numpy()
    # Pass the full class list explicitly: otherwise sklearn infers the classes
    # from y_true and raises when the eval subset happens to miss one of them.
    loss = log_loss(y_true=labels, y_pred=probs, labels=list(range(probs.shape[-1])))
    acc = accuracy_score(y_true=labels, y_pred=logits.argmax(-1))
    return {"acc": acc, "log_loss": loss}

### Split

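Here, train and eval are split by row position: row `i` is held out in fold `i % config.n_splits` (10 folds by default) and used for training in every other fold.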

In [None]:
# Round-robin K-fold split on row position: fold k holds out every
# n_splits-th row starting at k and trains on all remaining rows.
folds = [
    (
        [row for row in range(len(ds)) if row % config.n_splits != k],
        list(range(k, len(ds), config.n_splits)),
    )
    for k in range(config.n_splits)
]

In [None]:
# Row indices of the configured fold (train_idx is consumed in stage 2).
train_idx, eval_idx = folds[config.fold_idx]

# Stage 1: train on the extra labelled conversations only and evaluate on the
# held-out fold of the competition train set.
trainer = Trainer(
    model=model,
    args=training_args_0,
    train_dataset=ds_extra,
    eval_dataset=ds.select(eval_idx),
    tokenizer=tokenizer,
    data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
    compute_metrics=compute_metrics,
)
trainer.train()

In [None]:
# Stage-2 training arguments: batching as in stage 1, but a fixed lower
# learning rate and denser logging.
training_args_1 = TrainingArguments(
    # Output / logging
    output_dir="output-second",
    overwrite_output_dir=True,
    report_to="none",
    logging_steps=2,
    # Schedule and batching
    num_train_epochs=config.n_epochs,
    per_device_train_batch_size=config.per_device_train_batch_size,
    gradient_accumulation_steps=config.gradient_accumulation_steps,
    per_device_eval_batch_size=config.per_device_eval_batch_size,
    # Optimiser and precision (learning rate / warm-up set here, not from config)
    optim=config.optim_type,
    learning_rate=1.5e-5,
    warmup_steps=20,
    fp16=True,
    # Evaluate once per epoch; checkpoint every 1550 steps
    eval_strategy="epoch",
    save_strategy="steps",
    save_steps=1550,
)

In [None]:
# Stage 2: continue training the same LoRA model on the competition train
# folds, evaluating on the same held-out fold.
trainer = Trainer(
    model=model,
    args=training_args_1,
    train_dataset=ds.select(train_idx),
    eval_dataset=ds.select(eval_idx),
    tokenizer=tokenizer,
    data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
    compute_metrics=compute_metrics,
)
trainer.train()