In [None]:
!pip install bitsandbytes peft trl --quiet

In [None]:
from datasets import load_dataset

In [None]:
# Raw preference pairs (columns used below: prompt / chosen / rejected), train split only.
data = load_dataset(
    "valerielucro/gsm8k_preference_dataset",
    split="train",
)

In [None]:
# Preprocessing function
def preprocess(data):
    """Format one preference row for TRL's DPOTrainer.

    Wraps the prompt in Mistral's ``[INST] ... [/INST]`` instruction template
    and appends a step-by-step hint. BOS (``<s>``) and EOS (``</s>``) are
    deliberately NOT written into the text: DPOTrainer adds them itself when it
    tokenizes prompt / chosen / rejected. Adding them by hand as well produces
    duplicated special tokens (verify against the pinned TRL version).

    Args:
        data: one dataset row (mapping) with string fields ``prompt``,
            ``chosen`` and ``rejected``.

    Returns:
        A new mapping with every original field, plus the formatted ``prompt``
        and the ``chosen`` / ``rejected`` texts. The input row is not mutated.
    """
    return {
        **data,
        # Mistral template puts a space after [INST] and before [/INST].
        "prompt": "[INST] " + data["prompt"] + "\n do it step by step [/INST]",
        "chosen": data["chosen"],
        "rejected": data["rejected"],
    }

In [None]:
# Format every row with `preprocess`; map returns a new Dataset and leaves `data` as loaded.
train_dataset = data.map(function=preprocess)

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments, pipeline
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model,PeftModel
from trl import DPOTrainer,DPOConfig
import torch
import wandb
from kaggle_secrets import UserSecretsClient

In [None]:
user_secrets = UserSecretsClient()
# Kaggle notebook secrets, looked up by label: "wandb" first, then "HF".
wandb_token, HF_token = (user_secrets.get_secret(label) for label in ("wandb", "HF"))

In [None]:
from huggingface_hub import login

# Authenticate with the Hugging Face Hub; needed for the push_to_hub calls at the end.
login(token=HF_token)

In [None]:
# Run-level settings: the epoch count feeds DPOConfig; both values are echoed in the W&B notes.
num_of_epochs = 1
dataset_size = len(train_dataset)

In [None]:
# Free-text description attached to the W&B run.
# Built without a triple-quoted string, so there is no stray leading newline;
# the dataset size is labelled and "epoch" is pluralised correctly.
notes = (
    "initial DPO test run on sample gsm8k preference dataset of "
    f"{dataset_size} examples and {num_of_epochs} epoch{'' if num_of_epochs == 1 else 's'}"
)

In [None]:
wandb.login(key=wandb_token)

# Open the W&B run that this training session is tracked under.
run = wandb.init(
    project="gsm8k",
    name="test run with DPO",
    job_type="training",
    notes=notes,
)

In [None]:
# Kaggle model-input mount of the base checkpoint (Mistral 7B instruct v0.1, HF format, per the path).
base_model = "/kaggle/input/mistral/pytorch/7b-instruct-v0.1-hf/1"

In [None]:
tokenizer = AutoTokenizer.from_pretrained(base_model)
# Reuse EOS as the padding token (presumably the base tokenizer defines none) and pad on the right.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

In [None]:
# 4-bit NF4 weight quantisation for QLoRA-style training:
# matmuls are computed in bfloat16, and nested (double) quantisation is off.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=False,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

In [None]:
# LoRA adapter: rank 16, alpha 16, dropout 0.1, no bias training.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,
    lora_alpha=16,
    lora_dropout=0.1,
    bias="none",
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",  # attention projections
        "gate_proj", "up_proj", "down_proj",     # MLP projections
    ],
)

In [None]:
# Load the base model quantised with bnb_config; device_map="auto" lets accelerate place the layers.
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    quantization_config=bnb_config,
)
# The generation KV cache is not needed during training.
model.config.use_cache = False
model = prepare_model_for_kbit_training(model)

In [None]:
# DPO hyper-parameters. Checkpoints go to /kaggle/working so they survive the session output.
# NOTE(review): max_length / max_prompt_length are left to DPOTrainer's version-dependent
# defaults; set them explicitly if GSM8K prompts or reasoning chains are being truncated.
training_args = DPOConfig(
    output_dir="/kaggle/working/checkpoints",
    num_train_epochs=num_of_epochs,
    beta=0.1,  # DPO beta: larger values keep the policy closer to the reference model
    per_device_train_batch_size=8,
    save_strategy="steps",
    save_steps=25,
    # The Trainer default of logging_steps=500 can leave a short sample run with
    # few or no loss points in W&B; log frequently instead.
    logging_steps=5,
    # Explicitly target the W&B run opened above rather than relying on report_to's default.
    report_to="wandb",
)

In [None]:
# No ref_model is passed. Presumably (per the TRL docs), when a peft_config is given,
# DPOTrainer uses the base model with the adapter disabled as the reference policy.
trainer = DPOTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
    peft_config=peft_config,
)

In [None]:
# Train, then always close the W&B run. A crash or kernel interrupt during training is
# recorded as a failed run instead of leaving it open (and still re-raised so Run-All stops).
try:
    trainer.train()
except BaseException:
    wandb.finish(exit_code=1)
    raise
else:
    wandb.finish()

In [None]:
# Used both as the local save directory and as the Hub repository name.
fine_tuned_model_name = "mistral_gsm8k_ssl_it1"

In [None]:
# Save the trained model locally; presumably adapter weights only, since training used a peft_config.
trainer.model.save_pretrained(save_directory=fine_tuned_model_name)

In [None]:
# Hub commit message. The epoch count is derived from num_of_epochs instead of hard-coding
# "1 epoch", so it stays correct if the training config changes (identical text for 1 epoch).
commit_message = (
    "initial adapter with DPO on sample gsm8k preference dataset and "
    f"{num_of_epochs} epoch{'' if num_of_epochs == 1 else 's'}"
)

In [None]:
# Publish the model first, then the tokenizer, to the same Hub repo with the same commit message.
for artifact in (trainer.model, tokenizer):
    artifact.push_to_hub(fine_tuned_model_name, commit_message=commit_message)