In [None]:
!pip install -qq trl

In [None]:
import torch
import gc
import pandas as pd
from tqdm.auto import tqdm
from datasets import load_dataset, Dataset as HFDataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model
from trl import DPOTrainer,DPOConfig

# =============================================================================
# 1. Configuration — every knob for the iterative-DPO run lives in this cell
# =============================================================================

# --- Model and data -----------------------------------------------------------
MODEL_NAME = "Qwen/Qwen2-1.5B-Instruct"
DATASET_NAME = "gsm8k"

# --- Iteration schedule ------------------------------------------------------
NUM_ITERATIONS = 16            # generate -> judge -> DPO rounds
SAMPLES_PER_ITERATION = 1024   # training questions drawn per round
EVAL_DATASET_SIZE = 128        # held-out test questions scored after each round

# --- DPO optimisation ----------------------------------------------------------
DPO_EPOCHS = 1                     # passes over each round's preference pairs
BATCH_SIZE = 4                     # per-device DPO micro-batch
GRADIENT_ACCUMULATION_STEPS = 32   # pairs per optimizer step = BATCH_SIZE * this
LEARNING_RATE = 1e-5
DPO_BETA = 0.2                     # passed as `beta` to DPOConfig

# --- Generation and sequence lengths (tokens) ------------------------------------
GENERATION_BATCH_SIZE = 32   # prompts per model.generate call
MAX_LENGTH = 1024            # intended overall DPO sequence limit
MAX_PROMPT_LENGTH = 512      # prompt truncation length when tokenizing for generation
# NOTE(review): MAX_TARGET_LENGTH (max_new_tokens at generation) exceeds MAX_LENGTH,
# so very long completions may be truncated during DPO training — confirm intended.
MAX_TARGET_LENGTH = 2048

# The helper functions use model.device; this flag is informational.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# GSM8K ('main' config): the train split feeds the DPO rounds, the test split
# feeds evaluation. Loaded in that order.
dataset, test_dataset = (
    load_dataset(DATASET_NAME, 'main', split=split_name)
    for split_name in ('train', 'test')
)

model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, device_map='auto')
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Reuse EOS as the padding token when the tokenizer does not define one.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

# Generation appends tokens on the right, so batched prompts are left-padded.
tokenizer.padding_side = 'left'

# LoRA adapters on the attention (q/v/k/o) and MLP (gate/up/down) projections.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    target_modules=[
        "q_proj", "v_proj", "k_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

In [None]:
from google.colab import userdata
from openai import OpenAI
from pydantic import BaseModel
from typing import List

# The key is pulled from Colab's secret store, never hard-coded in the notebook.
OPENAI_API_KEY = userdata.get('OPENAI_API_KEY')

# Single shared client used by the judge (match_answers) and hint oracle below.
client = OpenAI(
    api_key=OPENAI_API_KEY,
)

# Structured-output schema for match_answers: the judge replies with one boolean.
# Deliberately no class docstring — pydantic copies a class docstring into the
# JSON schema's `description`, which would change the schema sent to the API.
class answer_verifier(BaseModel):
    do_answers_match: bool  # judge's verdict: does the assistant's final answer match the real answer?

# Structured-output schema for get_hints_from_oracle.
# Deliberately no class docstring — pydantic would put it into the JSON schema
# sent to the API as the `description`.
class make_hints(BaseModel):
    hints: List[str]  # hints meant to steer the model toward the real answer
    was_assistant_answer_correct:bool  # oracle's verdict on the assistant's original answer

def match_answers(answer, gold_answer, judge_model='gpt-4.1'):
    """Ask an LLM judge whether `answer` reaches the same final answer as `gold_answer`.

    Only final answers are compared; the reasoning is explicitly disregarded
    (see `instructions`).

    Args:
        answer: text produced by the local model.
        gold_answer: reference solution (the dataset's `answer` field).
        judge_model: OpenAI model used as the judge.

    Returns:
        bool: the judge's verdict. If the response cannot be parsed into
        `answer_verifier` (e.g. the judge refused), a warning is printed and
        False is returned, so an unverifiable answer is never counted correct.
    """
    instructions = '''You will recieve the answer to a problem from a virtual assistant and the real answer. You will output True or False depending on whether the virtual assistant final answer is correct. (disregard the reasoning, all we care about is if the final answers match)'''
    prompt = f'''Virtual assistant answer: {answer}
    Real answer: {gold_answer}

    Do the final answers match? Return True or False'''
    response = client.responses.parse(
        model=judge_model,
        instructions=instructions,
        input=prompt,
        text_format=answer_verifier,
    )
    verdict = response.output_parsed
    # output_parsed is None when no structured output was produced (e.g. refusal);
    # previously this surfaced as an AttributeError deep inside a training run.
    if verdict is None:
        print('WARNING: judge returned no parsable verdict; counting the answer as incorrect')
        return False
    return verdict.do_answers_match

def get_hints_from_oracle(question, answer, gold_answer, oracle_model='gpt-4.1'):
    """Ask an LLM oracle for hints toward the real answer and a correctness verdict.

    Args:
        question: the problem text.
        answer: the local model's (candidate "rejected") answer.
        gold_answer: reference solution.
        oracle_model: OpenAI model used as the oracle.

    Returns:
        tuple[list[str], bool]: (hints, was_assistant_answer_correct).

    Raises:
        RuntimeError: if the response could not be parsed into `make_hints`
            (e.g. a refusal). There is no safe default here: guessing the
            correctness flag could create a pair whose "rejected" side is correct.
    """
    instructions = '''You will recieve a virtual assistant answer and a model answer. Your job is to provide a list of hints to strongly guide the model towards the correct answer, while preveting any errors it has made'''
    prompt = f'''Quesiton:{question}
    Virtual assistant answer: {answer}
    Real answer: {gold_answer}

    Please provide a list of hints to strongly guide the model towards the correct answer, while preventing any errors it has made'''

    response = client.responses.parse(
        model=oracle_model,
        instructions=instructions,
        input=prompt,
        text_format=make_hints,
    )
    parsed = response.output_parsed
    if parsed is None:
        raise RuntimeError(
            'Hint oracle returned no parsable output (possibly a refusal); '
            'cannot determine whether the assistant answer was correct.'
        )
    return parsed.hints, parsed.was_assistant_answer_correct


In [None]:
def format_prompt(question, hints=None, chat_tokenizer=None):
    """Render a question (optionally with oracle hints) as a chat-templated prompt string.

    Args:
        question: the problem text.
        hints: optional hints. A list/tuple of strings is rendered one per line
            as "- hint"; a plain string is inserted as-is. None or empty means
            no "Hints:" section at all.
        chat_tokenizer: tokenizer whose chat template is applied; defaults to
            the notebook's global `tokenizer`.

    Returns:
        The untokenized prompt text, ending with the assistant generation
        prompt (add_generation_prompt=True).
    """
    if chat_tokenizer is None:
        chat_tokenizer = tokenizer

    if not hints:
        user_content = f"Question: {question}\nAnswer:"
    elif isinstance(hints, str):
        user_content = f"Question: {question}\nHints: {hints}\nAnswer:"
    else:
        # Oracle hints arrive as List[str]; interpolating the list directly would
        # show the model a Python repr like "['...', '...']".
        hint_lines = "\n".join(f"- {hint}" for hint in hints)
        user_content = f"Question: {question}\nHints:\n{hint_lines}\nAnswer:"

    messages = [
        {'role': 'system', 'content': 'Solve the problem step by step. Effectively utilize any hints given to you. Clearly state your answer'},
        {'role': 'user', 'content': user_content},
    ]

    return chat_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)


def batch_generate(model, device, prompts, temp=0.7):
    """Sample one completion per prompt and return only the newly generated text.

    Args:
        model: model exposing `generate` (here the PEFT-wrapped causal LM).
        device: device the tokenized inputs are moved to.
        prompts: list of already chat-templated prompt strings.
        temp: sampling temperature (sampling with top_p=0.95 is always used).

    Returns:
        list[str]: decoded completions with prompt tokens stripped, one per
        prompt, in input order. Empty input returns an empty list.
    """
    if not prompts:
        return []

    # Batched decoder-only generation needs left padding (all prompts end in the
    # same column). Enforce it here instead of relying on whatever the caller last
    # set on the shared tokenizer, and truncate from the left so an over-long
    # prompt loses its beginning rather than the trailing assistant header.
    saved_sides = (tokenizer.padding_side, tokenizer.truncation_side)
    tokenizer.padding_side = 'left'
    tokenizer.truncation_side = 'left'
    try:
        inputs = tokenizer(
            prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=MAX_PROMPT_LENGTH,
            add_special_tokens=False,  # the chat template already inserted special tokens
        ).to(device)
    finally:
        tokenizer.padding_side, tokenizer.truncation_side = saved_sides

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=MAX_TARGET_LENGTH,
            do_sample=True,
            top_p=0.95,
            temperature=temp,
            pad_token_id=tokenizer.pad_token_id,
        )

    # Every row shares the same (padded) prompt width, so slice it off uniformly.
    completion_tokens = outputs[:, inputs['input_ids'].shape[1]:]
    return tokenizer.batch_decode(completion_tokens, skip_special_tokens=True)

def _generate_in_batches(model, device, prompts, desc=None, **gen_kwargs):
    """Run batch_generate over `prompts` in GENERATION_BATCH_SIZE chunks and concatenate the results."""
    completions = []
    for start in tqdm(range(0, len(prompts), GENERATION_BATCH_SIZE), desc=desc):
        batch = prompts[start:start + GENERATION_BATCH_SIZE]
        completions += batch_generate(model, device, batch, **gen_kwargs)
    return completions


def create_dynamic_dpo_dataset(model, dataset_subset, verbose=False):
    """Build one round of DPO preference pairs with the current model and the oracle.

    Steps:
      1. Sample an answer to every question from the hint-free prompt
         (batch_generate's default temperature) -> candidate "rejected".
      2. Ask get_hints_from_oracle whether it was correct, and for hints.
      3. Only for questions answered incorrectly, sample again from the
         hint-augmented prompt at temperature 0.2 -> candidate "chosen".
      4. Keep a pair only when match_answers confirms the chosen answer.

    The stored 'prompt' is the hint-free prompt; the hints only influence how
    'chosen' was generated.

    Args:
        model: model used for generation; switched to eval mode.
        dataset_subset: dataset exposing 'question' and 'answer' columns.
        verbose: print a sample pair and accuracy statistics.

    Returns:
        list[dict]: records with keys 'prompt', 'chosen', 'rejected' in
        question order (possibly empty).
    """
    model.eval()
    device = model.device

    # list() so positional indexing works whatever column type the dataset returns.
    questions = list(dataset_subset['question'])
    answers = list(dataset_subset['answer'])
    prompts = [format_prompt(q) for q in questions]

    print('Generating initial (rejected) answers...')
    rejected_answers = _generate_in_batches(model, device, prompts, desc='rejected')

    print('Generating augmented prompts...')
    augmented_prompts = []
    rejected_correctness = []
    for q, r, a in tqdm(zip(questions, rejected_answers, answers), total=len(questions)):
        hints, correct = get_hints_from_oracle(q, r, a)
        rejected_correctness.append(correct)
        augmented_prompts.append(format_prompt(q, hints))

    # Only questions the model got wrong can yield a pair, so only those are
    # regenerated and verified (previously all were regenerated and discarded).
    wrong_idx = [i for i, ok in enumerate(rejected_correctness) if not ok]

    print('Generating augmented (chosen) answers...')
    chosen_answers = _generate_in_batches(
        model, device, [augmented_prompts[i] for i in wrong_idx], desc='chosen', temp=0.2
    )

    print('Verifying answers...')
    preference_data = []
    for i, chosen in tqdm(zip(wrong_idx, chosen_answers), total=len(wrong_idx)):
        if match_answers(chosen, answers[i]):
            preference_data.append({'prompt': prompts[i], 'chosen': chosen, 'rejected': rejected_answers[i]})

    if len(preference_data) == 0:
        print("NO PAIRS WERE CREATED")
    elif verbose:
        print(f'CREATED {len(preference_data)} pairs')
        sample = preference_data[0]
        print('+++++++SAMPLE++++++++')
        print(f"Prompt: {sample['prompt']}")
        print('=====================')
        print(f"Chosen: {sample['chosen']}")
        print('=====================')
        print(f"Rejected: {sample['rejected']}")
        print('++++++++++++++++++++++')

    if verbose:
        # Reported even when no pairs were made; guarded against empty denominators.
        n_questions = len(questions)
        n_wrong = len(wrong_idx)
        training_accuracy = (n_questions - n_wrong) / n_questions if n_questions else float('nan')
        accepted_accuracy = len(preference_data) / n_wrong if n_wrong else float('nan')
        print(f'Training_accuracy = {training_accuracy}')
        print(f'Accepted_accuracy = {accepted_accuracy}')

    return preference_data







In [None]:
# Fixed held-out evaluation set: the same seeded sample of test questions every
# round, formatted once so all evaluations see identical prompts. min() mirrors the
# training-subset selection so an oversized EVAL_DATASET_SIZE cannot index past the split.
eval_dataset = test_dataset.shuffle(seed=42).select(range(min(EVAL_DATASET_SIZE, len(test_dataset))))
eval_questions = [format_prompt(q) for q in eval_dataset['question']]
eval_answers = list(eval_dataset['answer'])

In [None]:
def evaluate_model(model, verbose=False, questions=None, answers=None):
    """Return the fraction of evaluation questions `model` answers correctly.

    Answers are generated in GENERATION_BATCH_SIZE batches at temperature 0.01
    and judged by match_answers.

    Args:
        model: model passed to batch_generate; switched to eval mode.
        verbose: print the first two question/answer/verdict triples.
        questions: chat-formatted prompts; defaults to the global `eval_questions`.
        answers: reference answers aligned with `questions`; defaults to the
            global `eval_answers`.

    Returns:
        float: accuracy in [0, 1], or float('nan') when there are no questions.

    Raises:
        ValueError: if `questions` and `answers` differ in length (zip would
            otherwise silently drop the extras).
    """
    if questions is None:
        questions = eval_questions
    if answers is None:
        answers = eval_answers
    if len(questions) != len(answers):
        raise ValueError(f'Got {len(questions)} questions but {len(answers)} answers')

    model.eval()
    device = model.device
    print(f'Evaluating model on {len(questions)} questions')
    if len(questions) == 0:
        return float('nan')

    model_answers = []
    for start in tqdm(range(0, len(questions), GENERATION_BATCH_SIZE)):
        batch = questions[start:start + GENERATION_BATCH_SIZE]
        model_answers += batch_generate(model, device, batch, temp=0.01)

    correct = 0
    for idx, (q, ma, a) in enumerate(zip(questions, model_answers, answers)):
        is_correct = match_answers(ma, a)
        if verbose and idx < 2:
            print(f'++++++++QUESTION {idx+1}+++++++++\n')
            print(f'Question: {q}')
            print('=========')
            print(f'Generated answer: {ma}')
            print('=========')
            print(f'Correct answer: {a}')
            print('=========')
            print(f'Is Correct?: {is_correct}')
            print('=========')
        if is_correct:
            correct += 1

    return correct / len(questions)





In [None]:
# Iterative DPO: each round draws fresh training questions, builds preference pairs
# with the current policy plus the oracle, runs a DPO pass, and re-evaluates.
print('\nStarting iterative DPO training...')

eval_accuracy = evaluate_model(model, verbose=True)
print('BASELINE ACCURACY')
print(f'Evaluation accuracy: {eval_accuracy}')

for iteration in range(NUM_ITERATIONS):
    print(f'\nStarting iteration {iteration+1}/{NUM_ITERATIONS}')

    # A different, seeded subset of the training split every round.
    dataset_subset = dataset.shuffle(seed=42 + iteration).select(
        range(min(SAMPLES_PER_ITERATION, len(dataset)))
    )

    tokenizer.padding_side = 'left'  # generation requires left padding
    preference_dataset_list = create_dynamic_dpo_dataset(model, dataset_subset, verbose=True)
    print(f"Generated {len(preference_dataset_list)} preference pairs")

    if preference_dataset_list:
        train_dataset = HFDataset.from_pandas(pd.DataFrame(preference_dataset_list))

        tokenizer.padding_side = 'right'
        training_args = DPOConfig(
            output_dir=f'./idpo_output_iteration{iteration+1}',
            learning_rate=LEARNING_RATE,
            per_device_train_batch_size=BATCH_SIZE,
            gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
            num_train_epochs=DPO_EPOCHS,
            # Pairs per round <= SAMPLES_PER_ITERATION and each optimizer step
            # consumes BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS pairs (1024 / 128 = 8
            # steps max with the defaults), so logging every 10 steps never fired.
            logging_steps=1,
            remove_unused_columns=False,  # keep prompt/chosen/rejected for DPO
            optim='adamw_torch',
            save_strategy='no',
            lr_scheduler_type='cosine',
            warmup_ratio=0.1,
            beta=DPO_BETA,
            # Wire the config-cell length limits into training (previously commented out).
            max_length=MAX_LENGTH,
            max_prompt_length=MAX_PROMPT_LENGTH,
        )

        # model is a PeftModel; with ref_model=None TRL uses the adapter-disabled
        # base model as the reference (per the TRL PEFT docs).
        dpo_trainer = DPOTrainer(
            model=model,
            ref_model=None,
            args=training_args,
            train_dataset=train_dataset,
            processing_class=tokenizer,
        )

        dpo_trainer.train()
        model = dpo_trainer.model

        tokenizer.padding_side = 'left'
        eval_accuracy = evaluate_model(model, verbose=True)
        print(f'Evaluation accuracy: {eval_accuracy}')

        del dpo_trainer, train_dataset
    else:
        # An empty list becomes a column-less DataFrame that DPOTrainer cannot
        # tokenize; the policy is unchanged, so skip training and re-evaluation.
        print(f'Skipping DPO update for iteration {iteration+1}: no preference pairs')

    del preference_dataset_list
    # Collect garbage first so freed tensors return to PyTorch's caching allocator,
    # then hand that cached memory back to the GPU.
    gc.collect()
    torch.cuda.empty_cache()
