## Installs and imports

In [1]:
# %pip install -q -U transformers trl peft bitsandbytes numpy==1.26.4 pandas==2.2.2 torch==2.4.0 datasets wandb
# %pip install -q -U transformers==4.44.0 trl==0.9.6 peft==0.12.0 bitsandbytes numpy==1.26.4 pandas==2.2.2 torch==2.4.0 datasets wandb

In [2]:
# %pip install -qqq flash-attn

In [3]:
import textwrap
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from trl import DPOTrainer, DPOConfig
from peft import LoraConfig, prepare_model_for_kbit_training, PeftModel
from datasets import Dataset
import pandas as pd
import numpy as np
import ast


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
print("cuda version:", torch.version.cuda)

# Define SEED for reproducibility
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

Using device: cuda
cuda version: 11.8


## Helper and device configuration

In [4]:
if torch.cuda.get_device_capability()[0] >= 8:
    attn_implementation = "flash_attention_2"
    torch_dtype = torch.bfloat16
    print("Using flash_attention_2")
else:
    attn_implementation = "eager"
    torch_dtype = torch.float16
    print("Using eager")


Using flash_attention_2


In [5]:
def print_trainable_params(model):
    total_params = 0
    trainable_params  = 0
    for name, param in model.named_parameters():
        total_params += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || total params: {total_params} || trainable%: {100 * trainable_params / total_params}"
    )

## Load the therapist model and tokenizer


In [6]:
from huggingface_hub import notebook_login

# log in to the Hugging Face hub (required for private datasets/models)
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [8]:
# Load the 7b llama-2 model
therapist_model_id = "meta-llama/Llama-2-7b-hf"

# Define LORA config and quantization config 
##################################################
lora_config = LoraConfig(
    r=16, # 16, 256
    lora_alpha=16, # 16, 128
    lora_dropout=0.05,
    target_modules=['k_proj', 'q_proj', 'v_proj', 'o_proj', 'up_proj', 'down_proj', 'gate_proj'],
    bias="none",
    task_type="CAUSAL_LM",
)

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch_dtype,
    bnb_4bit_quant_type="nf4"
)
##################################################


# Load Tokenizer 
##################################################
tokenizer = AutoTokenizer.from_pretrained(therapist_model_id, trust_remote_code=True, device_map="auto")
# set the chat template to include <|im_start|> and <|im_end|> tokens
tokenizer.chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"

# # add the <|im_start|> and <|im_end|> tokens to the tokenizer vocab and set them as the bos and eos tokens
# tokenizer.add_tokens(['<|im_start|>', '<|im_end|>'])
# tokenizer.bos_token = '<|im_start|>'
# tokenizer.eos_token = '<|im_end|>'

# set the pad token to the eos token to avoid issues with padding
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"  # Fix weird overflow issue with fp16 training
##################################################


# Load model, quantized
##################################################
base_model = AutoModelForCausalLM.from_pretrained(
    therapist_model_id,
    quantization_config=quantization_config,
    trust_remote_code=True,
    device_map="auto",
    attn_implementation=attn_implementation
)

base_model.config.use_cache = False
# base_model.resize_token_embeddings(len(tokenizer)) # Resize model embeddings to include new tokens
# base_model.config.eos_token_id = tokenizer.eos_token_id # Set EOS token ID in config to correctly attend to EOS tokens

##################################################




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [9]:
############################################ Old Adapters ############################################
# # Look-Ahead=0 PartialEval
# therapist_first_adapter_id = "LBK95/Llama-2-7b-hf-DPO-PartialEval_ET0.1_MT1.2_1-5_V.1.0_Filtered0.1_V1.0"
# therapist_second_adapter_id = "LBK95/Llama-2-7b-hf-DPO-PartialEval_ET0.1_MT1.2_1-5_V.1.0_Filtered0.1_V2.0"
# therapist_third_adapter_id = "LBK95/Llama-2-7b-hf-DPO-PartialEval_ET0.1_MT1.2_1-5_V.1.0_Filtered0.1_V3.0"

# # Look-Ahead=5 FullEval
# therapist_first_adapter_id = "LBK95/Llama-2-7b-hf-DPO-FullEval_LookAhead5_TTree1.2_TT0.7_TP0.7_TE0.1_Filtered0.1_V1.0"
# therapist_second_adapter_id = "LBK95/Llama-2-7b-hf-DPO-FullEval_LookAhead5_TTree1.2_TT0.7_TP0.7_TE0.1_Filtered0.1_V2.0"
# therapist_third_adapter_id = "LBK95/Llama-2-7b-hf-DPO-FullEval_LookAhead5_TTree1.2_TT0.7_TP0.7_TE0.1_Filtered0.1_V3.0"

# # Look-Ahead=3 FullEval
# therapist_first_adapter_id = "LBK95/Llama-2-7b-hf-DPO-LookAhead3_FullEval_TTree1.4_TLoop0.7_TEval0.2_Filter0.2_V1.0"
# therapist_second_adapter_id = "LBK95/Llama-2-7b-hf-DPO-LookAhead3_FullEval_TTree1.4_TLoop0.7_TEval0.2_Filter0.2_V2.0"
# therapist_third_adapter_id = "LBK95/Llama-2-7b-hf-DPO-LookAhead3_FullEval_TTree1.4_TLoop0.7_TEval0.2_Filter0.2_V3.0"

# ############################################ New Adapters ############################################

lookAhead = 0
therapist_first_adapter_id = f"LBK95/Llama-2-7b-hf-DPO-LookAhead-{lookAhead}_TTree1.4_TT0.9_TP0.7_TE0.2_V1"
# therapist_second_adapter_id = f"LBK95/Llama-2-7b-hf-DPO-LookAhead-{lookAhead}_TTree1.4_TT0.9_TP0.7_TE0.2_V2"
# therapist_third_adapter_id = f"LBK95/Llama-2-7b-hf-DPO-LookAhead-{lookAhead}_TTree1.4_TT0.9_TP0.7_TE0.2_V3"


# ############################################ New Adapters ############################################
# add first adapter
base_model = PeftModel.from_pretrained(base_model, therapist_first_adapter_id)
# merge first adapter and unload
base_model = base_model.merge_and_unload()
print("Model loaded with first adapter")
print("Adapter ID: ", therapist_first_adapter_id)

# # add second adapter
# base_model = PeftModel.from_pretrained(base_model, therapist_second_adapter_id)
# # merge second adapter and unload
# base_model = base_model.merge_and_unload()
# print("Model loaded with second adapter")
# print("Adapter ID: ", therapist_second_adapter_id)

# # add third adapter
# base_model = PeftModel.from_pretrained(base_model, therapist_third_adapter_id)
# # merge third adapter and unload
# base_model = base_model.merge_and_unload()
# print("Model loaded with third adapter")
# print("Adapter ID: ", therapist_third_adapter_id)
############################################ Add Adapters ############################################


# Prepare model for KBIT training
base_model = prepare_model_for_kbit_training(base_model) # Prepare model for KBIT training

print_trainable_params(base_model)

adapter_config.json:   0%|          | 0.00/727 [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


adapter_model.safetensors:   0%|          | 0.00/160M [00:00<?, ?B/s]



Model loaded with first adapter
Adapter ID:  LBK95/Llama-2-7b-hf-DPO-LookAhead-0_TTree1.4_TT0.9_TP0.7_TE0.2_V1
trainable params: 0 || total params: 3500412928 || trainable%: 0.0


## Load the data and preprocess it

In [10]:
# Method to Convert string representations of lists to actual lists
def convert_string_to_list(df):
    df["messages"] = df["messages"].apply(ast.literal_eval)
    df["conversation"] = df["conversation"].apply(ast.literal_eval)
    df["winning_scores_list"] = df["winning_scores_list"].apply(ast.literal_eval)
    df["losing_scores_list"] = df["losing_scores_list"].apply(ast.literal_eval)
    df["winning_scores_avg_list"] = df["winning_scores_avg_list"].apply(ast.literal_eval)
    df["losing_scores_avg_list"] = df["losing_scores_avg_list"].apply(ast.literal_eval)
    return df

# Method to load the preference trees
def load_preference_trees(data_path, start_index=0, end_index=96):
    preference_trees_list = []
    for i in range(start_index, end_index):
        with open(os.path.join(data_path, f"pref_data_{i}.csv"), "r") as f:
            pref_tree = pd.read_csv(f)
            pref_tree["tree_index"] = i
            preference_trees_list.append(pref_tree)
    # Concatenate all the preference trees into a single dataframe
    preference_trees_df = pd.concat(preference_trees_list, ignore_index=True)
    num_of_rows_before = preference_trees_df.shape[0]
    #get only rows with NaN values
    preference_trees_NaN = preference_trees_df[preference_trees_df.isna().any(axis=1)]
    # Drop rows with missing values
    preference_trees_df = preference_trees_df.dropna()
    num_of_rows_after = preference_trees_df.shape[0]
    num_rows_removed = num_of_rows_before - num_of_rows_after
    # Convert string representations of lists to actual lists
    preference_trees_df = convert_string_to_list(preference_trees_df)
    return preference_trees_df, preference_trees_list, num_rows_removed, preference_trees_NaN

# Method to add prompts column to the dataframe (aplly chat tamplet to the messages)
def add_prompts_column(df, tokenizer):
    prompt_list = [tokenizer.apply_chat_template(message, add_generation_prompt=True, tokenize=False) for message in df["messages"]]
    df["prompt"] = prompt_list
    return df

# Method to add prompt length column to the dataframe
def add_prompt_length_column(df, tokenizer):
    prompt_length_list = [len(tokenizer.encode(prompt)) for prompt in df["prompt"]]
    df["prompt_length"] = prompt_length_list
    return df

# Method to add conversation length (number of turns) column to the dataframe
def add_conversation_length_column(df):
    conversation_length_list = [len(conversation) for conversation in df["conversation"]]
    df["conversation_length"] = conversation_length_list
    return df

# Method to add eos token to the end of the responses
def add_end_token_to_responses(df, tokenizer):
    df["winning_response"] = [response + tokenizer.eos_token for response in df["winning_response"]]
    df["losing_response"] = [response + tokenizer.eos_token for response in df["losing_response"]]
    return df

# Method to remove rows where the winning response is the same as the losing response
def remove_duplicate_responses(df):
    df_without_duplicates = df[df["winning_response"] != df["losing_response"]]
    df_with_duplicates = df[df["winning_response"] == df["losing_response"]]
    return df_without_duplicates, df_with_duplicates

# Method to create the dpo_prefernce_data (Dict with keys: "prompt", "chosen", "rejected")
def create_preference_data(preference_trees_df, score_threshold=0.2,
                           prompt_column_name="prompt", winning_response_column_name="winning_response", losing_response_column_name="losing_response", 
                           winning_score_column_name="winning_score_final", losing_score_column_name="losing_score_final", min_score=0.0, max_score=5.0):
    # update the number of rows before dropping rows not within the min_score and max_score
    num_of_rows_before = preference_trees_df.shape[0]
    # drop rows where the winning score or losing score is less than min_score or greater than max_score
    preference_trees_df = preference_trees_df[(preference_trees_df[winning_score_column_name] >= min_score) & (preference_trees_df[winning_score_column_name] <= max_score)]
    preference_trees_df = preference_trees_df[(preference_trees_df[losing_score_column_name] >= min_score) & (preference_trees_df[losing_score_column_name] <= max_score)]
    # print the number of rows before and after dropping rows not within the min_score and max_score
    print(f"Number of rows in total: {num_of_rows_before}")
    print(f"Number of rows removed due to min_score or max_score: {num_of_rows_before - preference_trees_df.shape[0]}")
    print(f"Number of rows remaining: {preference_trees_df.shape[0]}")

    # update the number of rows before thresholding
    num_of_rows_before = preference_trees_df.shape[0] 
    # drop rows where the winning score < losing score + threshold
    preference_data = preference_trees_df[preference_trees_df[winning_score_column_name] >= preference_trees_df[losing_score_column_name] + score_threshold]
    num_of_rows_after = preference_data.shape[0]
    # print the number of rows before and after thresholding
    print(f"Number of rows in total: {num_of_rows_before}")
    print(f"Number of rows removed due to threshold: {num_of_rows_before - num_of_rows_after}")
    print(f"Number of rows remaining: {num_of_rows_after}")

    # Create the dpo_dataset_dict (Preference Data) (Dict with keys: "prompt", "chosen", "rejected")
    dpo_dataset_dict = {
    "prompt": preference_data[prompt_column_name].tolist(),
    "chosen": preference_data[winning_response_column_name].tolist(),
    "rejected": preference_data[losing_response_column_name].tolist(),
    }
    # Create the dpo_dataset from the dpo_dataset_dict
    dpo_dataset = Dataset.from_dict(dpo_dataset_dict)
    
    return dpo_dataset, preference_data

# Method to print the conversation with word wrapping
def print_conversation(conversation, max_width=80):
    """
    Print the conversation with word wrapping.

    Parameters:
        - conversation: A list of strings representing the conversation. (Therapist and Patient messages alternately, starting with the Therapist)
        - max_width: The maximum width for word wrapping. Default is 80.
    """
    for i, message in enumerate(conversation):
        role = "[THERAPIST]" if i % 2 == 0 else "[PATIENT]"
        print(f"{role}: \n{textwrap.fill(message, width=max_width)} \n")

# Method to get only the final conversations for each tree index
def get_df_for_final_conversations_for_each_tree_index(df):
    final_conversations_list = []
    for tree_index in df["tree_index"].unique():
        final_conversations = df[(df["tree_index"] == tree_index) & (df["conversation_length"] == df[df["tree_index"] == tree_index]["conversation_length"].max())]
        final_conversations_list.append(final_conversations)
    final_conversations_df = pd.concat(final_conversations_list, ignore_index=True)
    return final_conversations_df

In [12]:
# data_path = "LLM_DATA/Conversation_Trees/LookAhead3_FullEval_TTree1.4_TLoop0.7_TEval0.2_V4.0"
data_path = f"LLM_DATA/Conversation_Trees/LookAhead_{lookAhead}/TTree1.4_TT0.9_TP0.7_TE0.2_V2"

data_path = f"/content/drive/MyDrive/{data_path}"

# Load the preference trees
preference_trees_df, preference_trees_list, num_rows_removed, preference_trees_NaN = load_preference_trees(data_path, start_index=0, end_index=96)
print(f"Number of rows removed: {num_rows_removed}")
# Add the <im_end> token to the winning and losing responses
# preference_trees_df = add_end_token_to_responses(preference_trees_df, tokenizer)
# Add prompts column to the dataframe
preference_trees_df = add_prompts_column(preference_trees_df, tokenizer)
# Add prompt length column to the dataframe
preference_trees_df = add_prompt_length_column(preference_trees_df, tokenizer)
# Add conversation length column to the dataframe
preference_trees_df = add_conversation_length_column(preference_trees_df)
# Remove rows where the winning response is the same as the losing response
preference_trees_df, duplicate_responses_df = remove_duplicate_responses(preference_trees_df)
display(preference_trees_df.columns)
display(preference_trees_df.head())

Number of rows removed: 0


Index(['conversation', 'messages', 'winning_response', 'losing_response',
       'winning_scores_list', 'losing_scores_list', 'winning_scores_avg_list',
       'losing_scores_avg_list', 'winning_score_final', 'losing_score_final',
       'winning_conversation', 'losing_conversation', 'tree_index', 'prompt',
       'prompt_length', 'conversation_length'],
      dtype='object')

Unnamed: 0,conversation,messages,winning_response,losing_response,winning_scores_list,losing_scores_list,winning_scores_avg_list,losing_scores_avg_list,winning_score_final,losing_score_final,winning_conversation,losing_conversation,tree_index,prompt,prompt_length,conversation_length
0,"[My name is David, and I'm a counselor, can yo...","[{'role': 'system', 'content': 'You are a moti...","Great, James! I am so glad to hear that you ha...","James, let me help you to better understand th...","[[3, 2, 3, 2, 1], [3, 2, 2, 2, 2, 2, 2, 2, 1, ...","[[2, 2, 3, 1, 1], [3, 2, 1, 2, 2, 2, 2, 2, 2, ...","[2.2, 2.0588235294117645]","[1.8, 1.8823529411764706]",2.129412,1.841176,"[""My name is David, and I'm a counselor, can y...","[""My name is David, and I'm a counselor, can y...",0,<|im_start|>system\nYou are a motivational int...,318,2
1,"[My name is David, and I'm a counselor, can yo...","[{'role': 'system', 'content': 'You are a moti...","Excellent, I appreciate that. It's great that ...",Thank you. I'll keep that information in mind ...,"[[3, 2, 3, 2, 2], [4, 3, 3, 3, 4, 3, 4, 4, 3, ...","[[3, 2, 4, 2, 1], [4, 3, 2, 2, 3, 2, 3, 3, 3, ...","[2.4, 3.4705882352941178]","[2.4, 2.6470588235294117]",2.935294,2.523529,"[""My name is David, and I'm a counselor, can y...","[""My name is David, and I'm a counselor, can y...",0,<|im_start|>system\nYou are a motivational int...,477,4
2,"[My name is David, and I'm a counselor, can yo...","[{'role': 'system', 'content': 'You are a moti...","Great, it sounds like that's a goal you can ma...",James! Thank you. You are going to face hard t...,"[[4, 3, 4, 3, 3], [4, 3, 4, 3, 4, 3, 4, 3, 4, ...","[[4, 3, 4, 3, 2], [4, 3, 4, 3, 4, 3, 3, 3, 3, ...","[3.4, 3.588235294117647]","[3.2, 3.3529411764705883]",3.494118,3.276471,"[""My name is David, and I'm a counselor, can y...","[""My name is David, and I'm a counselor, can y...",0,<|im_start|>system\nYou are a motivational int...,684,6
3,"[My name is David, and I'm a counselor, can yo...","[{'role': 'system', 'content': 'You are a moti...",Awesome! Great choice! Now let's take a look a...,"Great, this sounds like a great start to our c...","[[4, 4, 5, 4, 4], [4, 3, 3, 3, 4, 3, 3, 3, 3, ...","[[4, 3, 4, 3, 2], [4, 3, 3, 3, 4, 3, 4, 4, 3, ...","[4.2, 3.2941176470588234]","[3.2, 3.4705882352941178]",3.747059,3.335294,"[""My name is David, and I'm a counselor, can y...","[""My name is David, and I'm a counselor, can y...",0,<|im_start|>system\nYou are a motivational int...,767,8
4,"[My name is David, and I'm a counselor, can yo...","[{'role': 'system', 'content': 'You are a moti...",Great! You mentioned before that you started s...,"Outstanding, it's fantastic, it's great to hea...","[[4, 4, 5, 4, 4], [4, 3, 4, 3, 4, 3, 4, 3, 3, ...","[[4, 3, 4, 3, 3], [4, 3, 4, 3, 4, 3, 3, 3, 4, ...","[4.2, 3.5294117647058822]","[3.4, 3.588235294117647]",3.864706,3.494118,"[""My name is David, and I'm a counselor, can y...","[""My name is David, and I'm a counselor, can y...",0,<|im_start|>system\nYou are a motivational int...,860,10


In [13]:
# Create the preference data for the model
dpo_dataset, preference_trees_filtered_df = create_preference_data(preference_trees_df, score_threshold=0.1)
# shuffle the dataset
dpo_dataset = dpo_dataset.shuffle(seed=42)
# split the dataset into training and validation sets
dpo_dataset = dpo_dataset.train_test_split(test_size=0.01)
print("dpo_dataset length: ", len(dpo_dataset))
display(dpo_dataset["train"][0])
display(dpo_dataset["test"][0])

Number of rows in total: 1552
Number of rows removed due to min_score or max_score: 0
Number of rows remaining: 1552
Number of rows in total: 1552
Number of rows removed due to threshold: 511
Number of rows remaining: 1041
dpo_dataset length:  2


{'prompt': '<|im_start|>system\nYou are a motivational interviewing counselor named David. You partner with the patient to understand his problems. In your answer, please avoid repetitions and unnecessary loops in the conversation. In your answer, please avoid repeating expressions of gratitude or similar sentiments multiple times if you\'ve already expressed them during the conversation. You should only end the session when at least one of the following conditions is met. If you need to end the session, write "SESSION ENDED" followed by the condition number: 1. If you believe that you have provided the appropriate treatment to the patient and have nothing else to advise in the current session.2. When time is up.<|im_end|>\n<|im_start|>user\n<|im_end|>\n<|im_start|>assistant\nMy name is David, and I\'m a counselor, can you start by telling me a little bit about yourself and why you are here?<|im_end|>\n<|im_start|>user\nHi David, my name is Emma, and I\'m 27 years old. I\'m here becaus

{'prompt': '<|im_start|>system\nYou are a motivational interviewing counselor named David. You partner with the patient to understand his problems. In your answer, please avoid repetitions and unnecessary loops in the conversation. In your answer, please avoid repeating expressions of gratitude or similar sentiments multiple times if you\'ve already expressed them during the conversation. You should only end the session when at least one of the following conditions is met. If you need to end the session, write "SESSION ENDED" followed by the condition number: 1. If you believe that you have provided the appropriate treatment to the patient and have nothing else to advise in the current session.2. When time is up.<|im_end|>\n<|im_start|>user\n<|im_end|>\n<|im_start|>assistant\nMy name is David, and I\'m a counselor, can you start by telling me a little bit about yourself and why you are here?<|im_end|>\n<|im_start|>user\nHello David, I\'m actually James, a 61-year-old male. I\'ve been s

In [14]:
dpo_dataset["train"].shape, dpo_dataset["test"].shape

((1030, 3), (11, 3))

# Training algorithms
- DPO
- ORPO
- Soon: PPO, KTO

## DPO

In [None]:
# DPO_model_id = "LBK95/Llama-2-7b-hf-DPO-LookAhead3_FullEval_TTree1.4_TLoop0.7_TEval0.2_Filter0.2_V4.0"
DPO_model_id = f"LBK95/Llama-2-7b-hf-DPO-LookAhead-{lookAhead}_TTree1.4_TT0.9_TP0.7_TE0.2_V2"


# Training arguments
training_args = DPOConfig(
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=2,
    gradient_checkpointing=True,
    learning_rate=5e-5,
    lr_scheduler_type="cosine", # other options: "linear", "cosine", "cosine_with_restarts"
    #max_steps=1000,
    num_train_epochs=3,
    save_strategy="steps",
    evaluation_strategy="steps",
    save_steps=0.25,
    logging_steps=1,
    eval_steps=0.1,
    output_dir=DPO_model_id,
    optim="paged_adamw_32bit",
    warmup_steps=10,
    bf16=True,
    report_to="wandb",
    push_to_hub=True,
    remove_unused_columns=False,
)

# Create DPO trainer
dpo_trainer = DPOTrainer(
    base_model,
    None,
    args=training_args,
    train_dataset=dpo_dataset["train"],
    eval_dataset=dpo_dataset["test"],
    tokenizer=tokenizer,
    peft_config=lora_config,
    beta=0.1,
    max_prompt_length=2048,
    max_length=2176,
)



In [None]:
dpo_trainer.model.print_trainable_parameters()

In [None]:
dpo_trainer.train()

# push the trained model and tokenizer to the hub
dpo_trainer.push_to_hub()



## ORPO

In [None]:
# ORPO_model_id = "LBK95/Llama-2-7b-hf-eval_threapist-ORPO-version-1"

# # Training arguments
# orpo_args = ORPOConfig(
#     per_device_train_batch_size=2,
#     per_device_eval_batch_size=2,
#     gradient_accumulation_steps=2,
#     gradient_checkpointing=True,
#     learning_rate=8e-6,
#     optim="paged_adamw_8bit",
#     #max_steps=1000,
#     num_train_epochs=1,
#     save_strategy="steps",
#     evaluation_strategy="steps",
#     save_steps=0.25,
#     logging_steps=1,
#     eval_steps=0.2,
#     output_dir=ORPO_model_id,
#     warmup_steps=10,
#     bf16=True,
#     report_to="wandb",
#     push_to_hub=True,
#     remove_unused_columns=False,
#     lr_scheduler_type="linear",
#     max_length=2048,
#     max_prompt_length=1024,
#     beta=0.1,
# )

# # Create DPO trainer
# orpo_trainer = ORPOTrainer(
#     model=base_model,
#     args=orpo_args,
#     train_dataset=dpo_dataset["train"],
#     eval_dataset=dpo_dataset["test"],
#     tokenizer=tokenizer,
#     peft_config=lora_config,
# )



In [None]:
# orpo_trainer.model.print_trainable_parameters()

In [None]:
# orpo_trainer.train()

# # push the trained model and tokenizer to the hub
# orpo_trainer.push_to_hub()