In [None]:
%%capture
# Environment setup cell. %%capture silences everything this cell prints.
import os, importlib.util
!pip install --upgrade -qqq uv
# Take the full-install branch when torch is not importable, or when any
# environment variable name contains "COLAB_".
if importlib.util.find_spec("torch") is None or "COLAB_" in "".join(os.environ.keys()):    
    # Pin numpy / pillow to whatever versions are already importable so the
    # install below keeps them; fall back to unpinned names if either import fails.
    # NOTE(review): the bare `except:` also swallows KeyboardInterrupt/SystemExit.
    try: import numpy, PIL; get_numpy = f"numpy=={numpy.__version__}"; get_pil = f"pillow=={PIL.__version__}"
    except: get_numpy = "numpy"; get_pil = "pillow"
    # {get_numpy} / {get_pil} are interpolated from the Python variables set above.
    !uv pip install -qqq \
        "torch>=2.8.0" "triton>=3.4.0" {get_numpy} {get_pil} torchvision bitsandbytes "transformers==4.56.2" \
        "unsloth_zoo[base] @ git+https://github.com/unslothai/unsloth-zoo" \
        "unsloth[base] @ git+https://github.com/unslothai/unsloth" \
        git+https://github.com/triton-lang/triton.git@0add68262ab0a2e33b84524346cb27cbb2787356#subdirectory=python/triton_kernels
# torch is present but unsloth is not: install unsloth only.
elif importlib.util.find_spec("unsloth") is None:
    !uv pip install -qqq unsloth
# Runs on every path: pin transformers and trl, and upgrade tokenizers / unsloth /
# unsloth_zoo, without resolving their dependencies (--no-deps).
!uv pip install --upgrade --no-deps transformers==4.56.2 tokenizers trl==0.22.2 unsloth unsloth_zoo

In [None]:
# Load the previously fine-tuned checkpoint and its tokenizer through Unsloth.
from unsloth import FastLanguageModel
model, tokenizer = FastLanguageModel.from_pretrained(
        model_name = "Aharneish/finetuned_model-1", # YOUR MODEL YOU USED FOR TRAINING
        # NOTE(review): 4098 is an unusual length; 4096 may have been intended - confirm.
        max_seq_length = 4098,
        # presumably None lets Unsloth pick the dtype for the current GPU - confirm
        dtype = None,
        load_in_4bit = True,
    )
# Put the module in training mode; as the last expression of the cell, the
# returned model is also displayed.
model.train()

In [None]:
# Wrap the loaded model with LoRA adapters; `model` is rebound to the wrapped model.
# Rank 8 with alpha 16 (alpha / r = 2) and 5% dropout, applied only to the
# attention projection modules named below.
model = FastLanguageModel.get_peft_model(
    model,
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"]
)


In [None]:
# Sanity check: report how many parameters are trainable after attaching LoRA.
model.print_trainable_parameters()

In [None]:
def formatting_prompts_func(examples):
    """Render each conversation in a batch to a single chat-template string.

    Args:
        examples: Mapping with a "messages" entry holding an iterable of
            conversations; each conversation is passed unchanged to the
            module-level ``tokenizer.apply_chat_template``.

    Returns:
        dict with one key, "text", holding the rendered strings in input order.
    """
    rendered = []
    for conversation in examples["messages"]:
        # Untokenized output, and no trailing generation prompt appended.
        rendered.append(
            tokenizer.apply_chat_template(
                conversation,
                tokenize=False,
                add_generation_prompt=False,
            )
        )
    return {"text": rendered}

from datasets import load_dataset

In [None]:
# Load the "train" split of the ESConv dataset from the Hugging Face Hub;
# this is the source for the preference pairs built below.
dataset2 = load_dataset("Ashokajou51/ESConv_Original", split="train")

In [None]:
# Strategy labels whose turns may be used as the "chosen" side of a DPO pair.
PREFERRED_STRATEGIES = {"Affirmation and Reassurance", "Reflection of feelings"}

# Strategy labels whose turns are used as the "rejected" side of a DPO pair.
WEAK_STRATEGIES = {"Information", "Others", "Question"}

def is_strong_empathy(text):
    """Heuristically decide whether a reply reads as an empathic statement.

    The check is case-insensitive: the reply must have at least 8
    whitespace-separated words and contain at least one of the fixed
    empathy phrases below as a substring.

    Args:
        text: The reply to inspect (str).

    Returns:
        bool: True when both conditions hold, otherwise False.
    """
    lowered = text.lower()

    # Too short to count, whatever phrases it contains.
    if len(lowered.split()) < 8:
        return False

    empathy_markers = (
        "that sounds",
        "it makes sense",
        "i can understand",
        "i'm sorry",
        "it must be",
        "i hear that",
        "that must be",
    )
    for marker in empathy_markers:
        if marker in lowered:
            return True
    return False


def build_dpo_dataset(esconv_dataset, max_pairs=2000):
    """Build DPO preference pairs (prompt / chosen / rejected) from dialogues.

    At most one pair is produced per example. Only turns whose "speaker" is
    "sys" and whose text has at least 5 whitespace-separated words are
    considered:

    * chosen   - first such turn whose strategy is in PREFERRED_STRATEGIES
                 and whose text passes is_strong_empathy.
    * rejected - first such turn whose strategy is in WEAK_STRATEGIES.

    The example's "situation" is used as the prompt. Examples missing a
    dialog, a situation, or either kind of turn are skipped.

    NOTE(review): chosen and rejected may come from different points of the
    conversation, so neither is necessarily a direct reply to the prompt.

    Args:
        esconv_dataset: Iterable of dict-like examples with "dialog" (a list
            of turn dicts) and "situation" (str).
        max_pairs: Upper bound on the number of pairs returned; values <= 0
            yield an empty list.

    Returns:
        list[dict]: dicts with the keys "prompt", "chosen" and "rejected".
    """
    pairs = []
    # Respect the limit before doing any work (previously 0 still gave one pair).
    if max_pairs <= 0:
        return pairs

    for ex in esconv_dataset:
        dialog = ex.get("dialog")
        situation = ex.get("situation")

        if not dialog or not situation:
            continue

        chosen = None
        rejected = None

        for turn in dialog:
            # .get() so a turn without a "speaker" key is skipped, not a KeyError.
            if turn.get("speaker") != "sys":
                continue

            # `or ""` also covers a "text" key that is present but None.
            text = turn.get("text") or ""
            strategy = turn.get("strategy")

            if len(text.split()) < 5:
                continue

            if strategy in PREFERRED_STRATEGIES and is_strong_empathy(text):
                if chosen is None:
                    chosen = text
            elif strategy in WEAK_STRATEGIES:
                if rejected is None:
                    rejected = text

            # Only the first of each kind is used, so stop scanning early.
            if chosen is not None and rejected is not None:
                break

        if chosen is not None and rejected is not None:
            pairs.append({
                "prompt": situation,
                "chosen": chosen,
                "rejected": rejected
            })

        if len(pairs) >= max_pairs:
            break

    return pairs

In [None]:
from datasets import Dataset

# Turn the ESConv dialogues into preference pairs, then into a Dataset.
dpo_pairs = build_dpo_dataset(dataset2)

# Fail here with a clear message rather than with an opaque IndexError on
# dpo_dataset[0] below (or later inside the trainer) when nothing matched.
if not dpo_pairs:
    raise ValueError(
        "build_dpo_dataset produced no pairs; check that the dataset's speaker "
        "and strategy labels match the values the filter expects."
    )

dpo_dataset = Dataset.from_list(dpo_pairs)

print("Number of DPO samples:", len(dpo_dataset))
print(dpo_dataset[0])


In [None]:
# Keep only the three columns used for DPO (prompt / chosen / rejected) and
# drop any other column that may be present.
dpo_dataset = dpo_dataset.remove_columns(
    [c for c in dpo_dataset.column_names if c not in ["prompt", "chosen", "rejected"]]
)

In [None]:
# Display the first preference pair as a quick visual check.
dpo_dataset[0]

In [None]:
from unsloth import is_bfloat16_supported
from trl import DPOConfig

# DPO training configuration.
# Effective batch size per device: 1 sample x 4 accumulation steps = 4 pairs per
# optimizer step.
dpo_config = DPOConfig(
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    learning_rate=2e-6,
    max_steps=200,          # HARD STOP: a positive max_steps overrides num_train_epochs
    num_train_epochs=1,     # only takes effect if max_steps is not positive
    logging_steps=5,
    save_steps=150,         # with 200 steps, one intermediate checkpoint (step 150)
    output_dir="./dpo_out",
    # Choose mixed precision from the hardware instead of hard-coding fp16: the
    # model is loaded with dtype=None, so on bf16-capable GPUs it is expected to
    # be in bfloat16 and the trainer flags have to agree with that.
    fp16=not is_bfloat16_supported(),
    bf16=is_bfloat16_supported(),
    remove_unused_columns=False
)


In [None]:
# DPOTrainer was never imported anywhere in the notebook (only DPOConfig was),
# so import it here to avoid a NameError.
from trl import DPOTrainer

# ref_model=None: no separate reference model is passed.
# NOTE(review): with a PEFT model TRL is expected to use the adapter-disabled
# base weights as the implicit reference - confirm against the TRL docs.
trainer = DPOTrainer(
    model=model,
    ref_model=None,
    args=dpo_config,
    train_dataset=dpo_dataset,
    # `processing_class` is the current TRL name for the former `tokenizer` argument.
    processing_class=tokenizer,
)

In [None]:
# Run DPO training; as the last expression of the cell, the returned result is displayed.
trainer.train()

In [None]:
import os

# SECURITY: an access token was hard-coded in this cell. It must be treated as
# leaked: revoke it in the Hugging Face account settings and supply a fresh one
# through the HF_TOKEN environment variable (or a prior `huggingface-cli login`).
#
# NOTE(review): the first positional parameter of Trainer.push_to_hub is the
# commit message, not a repository id. The destination repository comes from the
# training arguments (hub_model_id, otherwise the output_dir name). The string is
# kept as the commit message so the upload target stays what it was.
trainer.push_to_hub(
    commit_message="Aharneish/finetuned_model-1-dpo",
    token=os.environ.get("HF_TOKEN"),  # None falls back to the cached login, if any
)

In [None]:
# Hub repository of the SFT checkpoint that is reloaded for inference.
SFT_REPO="Aharneish/finetuned_model-1"

In [None]:
# Reload the SFT checkpoint in 4-bit as `base`; this also rebinds `tokenizer`.
# NOTE(review): max_seq_length is 1024 here, whereas the training-time load used
# 4098 and the generation cell requests max_new_tokens=4098 - confirm the
# intended context length.
base, tokenizer = FastLanguageModel.from_pretrained(
    model_name=SFT_REPO,
    max_seq_length=1024,
    load_in_4bit=True
)

In [None]:
from peft import PeftModel

# Attach the DPO adapter from the Hub to the freshly reloaded `base` model.
# The original passed `model`, the in-memory model that already carries the
# just-trained LoRA adapter, which left `base` unused and stacked the downloaded
# adapter on top of the trained one.
dpo_model = PeftModel.from_pretrained(
    base,
    "Aharneish/dpo_out",
)
# Switch to evaluation mode (e.g. disables dropout) before generating.
dpo_model.eval()

In [None]:
# torch is needed by the generation cell (torch.no_grad()).
import torch
# Visible confirmation that the import succeeded.
print("imported torch!")

In [None]:
from transformers import TextStreamer

# 1. Initialize the streamer (skip_prompt=True: the prompt itself is not echoed)
streamer = TextStreamer(tokenizer, skip_prompt=True)

# The prompt is tokenized as raw text; no chat template is applied here.
prompt = "I feel like everything is my fault."
inputs = tokenizer(prompt, return_tensors="pt").to(dpo_model.device)

# 2. Pass the streamer into the generate function
# Gradients are not needed for inference, so generation runs under no_grad.
# Sampling settings: temperature 0.7 with nucleus (top-p) 0.9.
# NOTE(review): max_new_tokens=4098 exceeds the max_seq_length=1024 used when
# `base` was loaded - confirm this is intended.
with torch.no_grad():
    dpo_model.generate(
        **inputs,
        max_new_tokens=4098,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        streamer=streamer  # This enables the live output
    )