In [None]:
from google.colab import drive
drive.mount('/content/drive')


In [None]:
!mkdir -p /content/drive/MyDrive/Self-Rewarding-LLM
%cd /content/drive/MyDrive/Self-Rewarding-LLM


In [None]:
!mkdir -p src data checkpoints
!touch requirements.txt


In [None]:
%%writefile requirements.txt
torch>=2.1
transformers>=4.41
datasets>=2.19
peft>=0.11
tqdm
pandas
trl


In [None]:
!pip install -r requirements.txt


In [None]:
!pip install trl

In [None]:
%%writefile src/common.py
import json
import os
def read_jsonl(path):
    """Lazily yield one parsed JSON object per non-blank line of ``path``.

    The file is read as UTF-8. Lines that contain only whitespace are
    skipped. Because this is a generator, the file stays open until the
    generator is exhausted or closed.
    """
    with open(path, "r", encoding="utf-8") as handle:
        non_blank = (raw for raw in handle if raw.strip())
        yield from map(json.loads, non_blank)

def write_jsonl(path, rows):
    """Write ``rows`` to ``path`` as JSON Lines (one JSON object per line, UTF-8).

    Args:
        path: Destination file. Missing parent directories are created.
            A bare filename (no directory part) is written to the current
            working directory.
        rows: Iterable of JSON-serialisable objects.

    Non-ASCII characters are written as-is (``ensure_ascii=False``).
    """
    parent = os.path.dirname(path)
    # os.makedirs("") raises FileNotFoundError, so only create a directory
    # when the path actually has one.
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        for r in rows:
            f.write(json.dumps(r, ensure_ascii=False) + "\n")


In [None]:
%%writefile src/lora.py
from peft import LoraConfig, get_peft_model, TaskType
def add_lora(model, r=16, alpha=32, dropout=0.05):
    """Wrap ``model`` with LoRA adapters and return the resulting PEFT model.

    Args:
        model: Causal language model to adapt.
        r: LoRA rank.
        alpha: LoRA scaling factor (``lora_alpha``).
        dropout: Dropout applied inside the LoRA layers.

    Adapters are attached to the modules named ``q_proj``, ``k_proj``,
    ``v_proj``, ``o_proj`` and ``gate_proj``; no bias terms are trained.
    The trainable-parameter summary is printed as a side effect.
    """
    adapted_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj"]
    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        target_modules=adapted_modules,
        r=r,
        lora_alpha=alpha,
        lora_dropout=dropout,
        bias="none",
    )
    peft_model = get_peft_model(model, lora_config)
    peft_model.print_trainable_parameters()
    return peft_model


In [None]:
%%writefile src/tokenizer.py
from transformers import AutoTokenizer

def load_tokenizer(name):
    """Load the slow tokenizer for ``name`` and configure it for chat-style SFT.

    The tokenizer's chat template is replaced with a header-based template:
    every message is rendered as
    ``<|start_header_id|>{role}<|end_header_id|>\\n\\n{content}<|eot_id|>``.
    An empty assistant header is appended only when the caller asks for a
    generation prompt, or when the conversation does not already end with an
    assistant message. Previously it was appended unconditionally, so a
    training text that already contained the assistant reply ended with a
    dangling, empty assistant turn.

    The EOS and PAD tokens are both set to ``<|eot_id|>`` and padding is
    applied on the right.

    NOTE(review): the marker strings are assigned regardless of ``name``;
    whether they exist in that tokenizer's vocabulary is not checked here.
    Confirm they resolve to proper token ids for the model in use.
    """
    tok = AutoTokenizer.from_pretrained(name, use_fast=False)

    # Set the chat template manually
    tok.chat_template = """<|begin_of_text|>
{%- for message in messages %}
{{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' }}
{{- message['content'] + '<|eot_id|>' }}
{%- endfor %}
{%- if add_generation_prompt or (not messages) or (messages[-1]['role'] != 'assistant') %}
{{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }}
{%- endif %}
"""
    tok.eos_token = "<|eot_id|>"
    tok.pad_token = tok.eos_token
    tok.padding_side = "right"

    return tok

def to_chat(tokenizer, system, user, assistant=None):
    """Render a system/user(/assistant) exchange with the tokenizer's chat template.

    Args:
        tokenizer: Object exposing ``apply_chat_template``.
        system: System-prompt text.
        user: User message text.
        assistant: Optional assistant reply. When given, the rendered text is
            a complete conversation (e.g. for training). When omitted, the
            text is a prompt for generation.

    Returns:
        The rendered conversation as a string (not tokenized).
    """
    messages = [{"role": "system", "content": system},
                {"role": "user", "content": user}]
    has_reply = assistant is not None
    if has_reply:
        messages.append({"role": "assistant", "content": assistant})
    # Without the generation prompt, templates that honour this flag leave
    # the assistant turn unopened, so the model has to emit the role marker
    # itself and it ends up in the decoded response.
    return tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=not has_reply,
    )



In [None]:
!pip install -q datasets


In [None]:
from datasets import load_dataset
import pandas as pd

# Pull the complete training split; quality filtering can be added later.
dataset = load_dataset("OpenAssistant/oasst1", split="train")

# Keep only assistant messages that reply to some parent message.
filtered = dataset.filter(
    lambda msg: msg["role"] == "assistant" and msg["parent_id"] is not None
)

# Lookup table from message id to message text, used to resolve each parent.
id_to_text = {msg["message_id"]: msg["text"] for msg in dataset}

# Pair every kept reply with its parent's text; replies whose parent text is
# missing or empty are dropped.
examples = [
    {"prompt": id_to_text[reply["parent_id"]], "response": reply["text"]}
    for reply in filtered
    if id_to_text.get(reply["parent_id"], None)
]

# Persist the first 2000 pairs as JSON Lines (raise the cap, e.g. to 10_000+, for more data).
df = pd.DataFrame(examples[:2000])
df.to_json("data/sft_openassistant.jsonl", orient="records", lines=True)


In [None]:
%%writefile src/sft_train.py
import argparse
import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW
from datasets import load_dataset
from transformers import AutoModelForCausalLM, get_linear_schedule_with_warmup
from tqdm import tqdm

import sys
import os
sys.path.append("/content/drive/MyDrive/Self-Rewarding-LLM")


from src.tokenizer import load_tokenizer, to_chat
from src.lora import add_lora



SYSTEM_PROMPT = (
    'Respond to the following user query in a comprehensive and detailed way. '
    'But first write down your internal thoughts. This must include your draft response '
    'and its evaluation. After this, write your final response after "<R>".'
)

def collate(tokenizer, batch, max_length=1024):
    """Tokenize a batch of SFT examples into model inputs with loss labels.

    Each example's ``prompt`` and ``response`` are rendered into one chat
    string via ``to_chat`` (using ``SYSTEM_PROMPT``). The strings are then
    tokenized together, padded to the longest sequence in the batch and
    truncated to ``max_length`` tokens.

    Args:
        tokenizer: Tokenizer callable returning ``input_ids`` and
            ``attention_mask`` tensors.
        batch: Sequence of mappings with ``"prompt"`` and ``"response"`` keys.
        max_length: Maximum number of tokens kept per sequence.

    Returns:
        The tokenizer output with an added ``labels`` tensor: a copy of
        ``input_ids`` in which padded positions are set to ``-100``.
    """
    texts = [
        to_chat(tokenizer, SYSTEM_PROMPT, x["prompt"], x["response"])
        for x in batch
    ]
    enc = tokenizer(
        texts,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors="pt"
    )
    # -100 is the ignore index of the cross-entropy loss. Without this mask
    # the model is also trained to predict padding, which matters as soon as
    # batch_size > 1. The mask comes from attention_mask rather than from the
    # pad token id, so a genuine end-of-sequence token is still learned even
    # when pad and eos share an id.
    enc["labels"] = enc["input_ids"].masked_fill(enc["attention_mask"] == 0, -100)
    return enc

def main():
    """Command-line entry point: LoRA supervised fine-tuning on a JSON dataset.

    Parses the CLI arguments, loads the tokenizer and base model, wraps the
    model with LoRA adapters, runs ``--epochs`` passes over ``--dataset`` with
    AdamW and a linear warmup/decay schedule, then saves the model and the
    tokenizer to ``--out``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument('--base_model', default="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
    ap.add_argument("--dataset", required=True)
    ap.add_argument("--out", default="checkpoints/sft")
    ap.add_argument("--batch_size", type=int, default=1)
    ap.add_argument("--lr", type=float, default=2e-4)
    ap.add_argument("--epochs", type=int, default=1)
    ap.add_argument("--max_length", type=int, default=1024)
    args = ap.parse_args()

    # Batches are moved to this device; the model itself is placed by device_map="auto".
    device = "cuda" if torch.cuda.is_available() else "cpu"

    tokenizer = load_tokenizer(args.base_model)
    # Half precision on GPU, full precision on CPU.
    model = AutoModelForCausalLM.from_pretrained(
        args.base_model,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        device_map="auto"
    )
    model = add_lora(model)
    model.train()

    # The "json" loader exposes the file's records under the "train" split.
    ds = load_dataset("json", data_files=args.dataset)["train"]
    loader = DataLoader(
        ds,
        batch_size=args.batch_size,
        shuffle=True,
        collate_fn=lambda b: collate(tokenizer, b, args.max_length)
    )

    optimizer = AdamW(model.parameters(), lr=args.lr)
    # One optimizer step per batch, so total steps = batches per epoch * epochs.
    num_training_steps = len(loader) * args.epochs
    # Warm up for 10% of the run, capped at 50 steps.
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=min(50, num_training_steps // 10),
        num_training_steps=num_training_steps
    )

    pbar = tqdm(range(num_training_steps))
    for epoch in range(args.epochs):
        for batch in loader:
            batch = {k: v.to(device) for k, v in batch.items()}
            # The batch carries "labels", so the model output includes the loss.
            outputs = model(**batch)
            loss = outputs.loss
            loss.backward()

            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            pbar.set_description(f"epoch {epoch} loss {loss.item():.4f}")
            pbar.update(1)

    # NOTE(review): `model` is the object returned by add_lora; what exactly
    # save_pretrained writes for it (adapter only vs. full weights) should be
    # confirmed against the loaders that later read this directory.
    model.save_pretrained(args.out)
    tokenizer.save_pretrained(args.out)

if __name__ == "__main__":
    main()



In [None]:
from huggingface_hub import login

login()


In [None]:
!python src/sft_train.py \
  --base_model TinyLlama/TinyLlama-1.1B-Chat-v1.0 \
  --dataset data/sft_openassistant.jsonl \
  --out checkpoints/sft



In [None]:
%%writefile src/generate_completions.py

import sys
import os
import json
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM

# Setup Project Path
project_root = '/content/drive/MyDrive/Self-Rewarding-LLM'
if project_root not in sys.path:
    sys.path.append(project_root)

from src.tokenizer import to_chat

# Configuration
BASE_MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
MODEL_PATH = os.path.join(project_root, "checkpoints/sft")
DATA_PATH = os.path.join(project_root, "data/sft_openassistant.jsonl")
OUTPUT_PATH = os.path.join(project_root, "data/sft_completions_100.jsonl") # Changed output file name
MAX_NEW_TOKENS = 256
MAX_LENGTH = 2048
SYSTEM_PROMPT = "Respond to the following user query in a helpful way."
# NOTE(review): exactly one completion is sampled per input row. Downstream
# preference building groups rows by prompt and needs at least two rows per
# prompt, so pairs can only come from prompts that occur more than once in
# the input file. Confirm this is intended.


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

print("Loading model and tokenizer...")

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_NAME)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Half precision only on GPU; fall back to float32 on CPU, matching the
# dtype selection used by the SFT training script.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto"
)
print("Model and tokenizer loaded successfully.")

try:
    with open(DATA_PATH, "r", encoding="utf-8") as f:
        examples = [json.loads(line.strip()) for line in f if line.strip()]

    # Only the first 100 rows are used.
    examples = examples[:100]

    print(f"Loaded and reduced to {len(examples)} prompts.")

except FileNotFoundError:
    print(f"[ERROR] Data file not found at {DATA_PATH}. Please check the path.")
    sys.exit(1)

results = []
num_failed = 0
with torch.no_grad():
    for i, ex in enumerate(tqdm(examples, desc="Generating Completions")):
        # Best effort per example: a failure is reported and counted, and the
        # remaining prompts are still processed.
        try:
            prompt = ex.get("prompt")
            if not prompt:
                continue

            chat_input = to_chat(tokenizer, SYSTEM_PROMPT, prompt)
            inputs = tokenizer(chat_input, return_tensors="pt", truncation=True, max_length=MAX_LENGTH)
            inputs = {k: v.to(device) for k, v in inputs.items()}

            if inputs["input_ids"].shape[1] == 0:
                continue

            output_ids = model.generate(
                **inputs,
                max_new_tokens=MAX_NEW_TOKENS,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )

            # generate() returns prompt + continuation; keep only the new tokens.
            input_length = inputs["input_ids"].shape[1]
            generated_ids = output_ids[0][input_length:]
            decoded = tokenizer.decode(generated_ids, skip_special_tokens=True)

            results.append({
                "prompt": prompt,
                "completion": decoded.strip()
            })

        except Exception as e:
            num_failed += 1
            print(f"[ERROR] Generation failed for example #{i}: {e}")

if num_failed:
    print(f"[WARN] {num_failed} of {len(examples)} examples failed and were skipped.")

# Make sure the destination directory exists before writing.
os.makedirs(os.path.dirname(OUTPUT_PATH), exist_ok=True)
with open(OUTPUT_PATH, "w", encoding="utf-8") as f:
    for item in results:
        f.write(json.dumps(item, ensure_ascii=False) + "\n")

print(f"✅ Saved {len(results)} completions to {OUTPUT_PATH}")

In [None]:
!python src/generate_completions.py


In [None]:
%%writefile /content/drive/MyDrive/Self-Rewarding-LLM/src/score_completions.py
import argparse
import json
import os
from typing import List, Dict

import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForSequenceClassification


def read_jsonl(path: str) -> List[Dict]:
    """Read a UTF-8 JSON Lines file and return its records as a list.

    Each line is stripped of surrounding whitespace; lines that are empty
    after stripping are ignored.
    """
    with open(path, "r", encoding="utf-8") as f:
        stripped_lines = (raw.strip() for raw in f)
        return [json.loads(text) for text in stripped_lines if text]


def write_jsonl(path: str, rows: List[Dict]):
    """Write ``rows`` to ``path`` as JSON Lines (one JSON object per line, UTF-8).

    Missing parent directories are created. A bare filename (no directory
    part) is written to the current working directory. Non-ASCII characters
    are written as-is (``ensure_ascii=False``).
    """
    parent = os.path.dirname(path)
    # os.makedirs("") raises FileNotFoundError, so skip it for bare filenames.
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        for r in rows:
            f.write(json.dumps(r, ensure_ascii=False) + "\n")


def build_pair_text(prompt: str, completion: str) -> str:
    """Format a prompt/completion pair as a single two-turn dialogue string.

    The result is ``"Human: <prompt>"`` and ``"Assistant: <completion>"``
    separated by one blank line.
    """
    dialogue_layout = "Human: {}\n\nAssistant: {}"
    return dialogue_layout.format(prompt, completion)


def main():
    """Command-line entry point: score prompt/completion rows with a reward model.

    Reads ``--in_file`` (JSON Lines with ``prompt`` and ``completion`` keys),
    scores every row with the sequence-classification model ``--rm_model`` in
    batches of ``--batch_size``, stores the result under a ``score`` key on
    each row and writes all rows to ``--out_file``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--in_file", required=True, help="JSONL with {'prompt','completion'} records")
    parser.add_argument("--out_file", required=True, help="Where to save the scored JSONL")
    parser.add_argument("--rm_model", default="OpenAssistant/reward-model-deberta-v3-large-v2",
                        help="Reward model name (HF hub path)")
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--max_length", type=int, default=1024)
    args = parser.parse_args()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    print(f"Loading reward model: {args.rm_model}")
    tokenizer = AutoTokenizer.from_pretrained(args.rm_model)
    model = AutoModelForSequenceClassification.from_pretrained(args.rm_model).to(device)
    model.eval()

    print(f"Reading: {args.in_file}")
    rows = read_jsonl(args.in_file)

    # NOTE(review): each row is flattened into one "Human: ... Assistant: ..."
    # string. Confirm this matches the input format the chosen reward model
    # was trained on (some reward models expect the prompt and the answer as
    # a tokenizer text pair instead).
    texts = [build_pair_text(r["prompt"], r["completion"]) for r in rows]

    all_scores: List[float] = []
    with torch.no_grad():
        for i in tqdm(range(0, len(texts), args.batch_size), desc="Scoring"):
            batch_texts = texts[i:i + args.batch_size]
            enc = tokenizer(
                batch_texts,
                padding=True,
                truncation=True,
                max_length=args.max_length,
                return_tensors="pt"
            ).to(device)

            logits = model(**enc).logits  # presumably (batch, 1) for a single-output reward model
            if logits.shape[-1] == 1:
                # Single output: use the raw logit as the score.
                scores = logits.squeeze(-1).tolist()
            else:
                # Several outputs: convert the logits to probabilities ...
                probs = torch.softmax(logits, dim=-1)
                # ... and use the probability of the LAST class as the score
                # (assumes the last label is the "good" one; verify per model).
                scores = probs[:, -1].tolist()

            all_scores.extend(scores)

    # Every row must have received exactly one score.
    assert len(all_scores) == len(rows)

    # Attach scores & save
    for r, s in zip(rows, all_scores):
        r["score"] = float(s)

    write_jsonl(args.out_file, rows)
    print(f"✅ Wrote {len(rows)} scored rows to {args.out_file}")


if __name__ == "__main__":
    main()


In [None]:
!python /content/drive/MyDrive/Self-Rewarding-LLM/src/score_completions.py \
  --in_file /content/drive/MyDrive/Self-Rewarding-LLM/data/sft_completions_100.jsonl \
  --out_file /content/drive/MyDrive/Self-Rewarding-LLM/data/sft_completions_scored.jsonl \
  --rm_model OpenAssistant/reward-model-deberta-v3-large-v2 \
  --batch_size 8


In [None]:
%%writefile /content/drive/MyDrive/Self-Rewarding-LLM/src/build_preferences.py
import json
import argparse
import os
from collections import defaultdict
from typing import List, Dict


def read_jsonl(path: str) -> List[Dict]:
    """Load every non-blank line of a UTF-8 JSON Lines file into a list."""
    records: List[Dict] = []
    with open(path, "r", encoding="utf-8") as f:
        for raw in f:
            if not raw.strip():
                # Whitespace-only lines carry no record.
                continue
            records.append(json.loads(raw))
    return records


def write_jsonl(path: str, rows: List[Dict]):
    """Write ``rows`` to ``path`` as JSON Lines (one JSON object per line, UTF-8).

    Missing parent directories are created. A bare filename (no directory
    part) is written to the current working directory. Non-ASCII characters
    are written as-is (``ensure_ascii=False``).
    """
    parent = os.path.dirname(path)
    # os.makedirs("") raises FileNotFoundError, so skip it for bare filenames.
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        for row in rows:
            f.write(json.dumps(row, ensure_ascii=False) + "\n")


def main():
    """Command-line entry point: turn scored completions into preference pairs.

    Reads ``--in_file`` (JSON Lines with ``prompt``, ``completion`` and
    ``score`` keys), groups the rows by prompt and, for every prompt with at
    least two completions, emits one ``{"prompt", "chosen", "rejected"}``
    record where ``chosen`` is the highest-scored completion and ``rejected``
    the lowest-scored one. The records are written to ``--out_file``.

    A prompt is skipped when:
      * it has fewer than two completions,
      * its best and worst scores are equal (the ordering would be arbitrary,
        so the pair carries no preference signal), or
      * the chosen and rejected texts are identical after stripping.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--in_file", required=True, help="Input JSONL with prompt, completion, score")
    parser.add_argument("--out_file", required=True, help="Output preference JSONL")
    args = parser.parse_args()

    rows = read_jsonl(args.in_file)

    # Group by prompt
    prompt_map = defaultdict(list)
    for row in rows:
        prompt_map[row["prompt"]].append(row)

    pref_rows = []
    for prompt, completions in prompt_map.items():
        if len(completions) < 2:
            continue

        # Sort by score descending
        sorted_comps = sorted(completions, key=lambda x: x["score"], reverse=True)
        best, worst = sorted_comps[0], sorted_comps[-1]

        # A tie means the reward model expressed no preference; emitting the
        # pair anyway would label one of them "rejected" at random.
        if best["score"] == worst["score"]:
            continue

        chosen = best["completion"]
        rejected = worst["completion"]

        if chosen.strip() == rejected.strip():
            continue

        pref_rows.append({
            "prompt": prompt,
            "chosen": chosen,
            "rejected": rejected,
        })

    write_jsonl(args.out_file, pref_rows)
    print(f"✅ Wrote {len(pref_rows)} preference pairs to {args.out_file}")


if __name__ == "__main__":
    main()


In [None]:
!python /content/drive/MyDrive/Self-Rewarding-LLM/src/build_preferences.py \
  --in_file /content/drive/MyDrive/Self-Rewarding-LLM/data/sft_completions_scored.jsonl \
  --out_file /content/drive/MyDrive/Self-Rewarding-LLM/data/dpo_prefs.jsonl


In [None]:
%%writefile /content/drive/MyDrive/Self-Rewarding-LLM/src/dpo_train_manual.py
import argparse
import json
import math
import os
from typing import Dict

import torch
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
from tqdm import tqdm

from transformers import AutoModelForCausalLM, AutoTokenizer, get_linear_schedule_with_warmup

import sys
import os

project_root = '/content/drive/MyDrive/Self-Rewarding-LLM'
if project_root not in sys.path:
    sys.path.append(project_root)


from src.lora import add_lora
from src.tokenizer import to_chat

# Define the base model name for reliability
BASE_MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

SYSTEM_PROMPT = (
    'Respond to the following user query in a comprehensive and detailed way. '
    'But first write down your internal thoughts. This must include your draft response '
    'and its evaluation. After this, write your final response after "<R>".'
)

class PreferenceDataset(Dataset):
    """In-memory dataset of preference records loaded from a JSON Lines file.

    Only records that contain all of the keys ``prompt``, ``chosen`` and
    ``rejected`` are kept; blank lines and incomplete records are dropped.
    Items are returned as the parsed dicts, unchanged.
    """

    def __init__(self, path: str):
        required_keys = ("prompt", "chosen", "rejected")
        with open(path, "r", encoding="utf-8") as f:
            parsed = (json.loads(raw) for raw in f if raw.strip())
            self.rows = [
                record for record in parsed
                if all(key in record for key in required_keys)
            ]

    def __len__(self):
        """Number of complete preference records."""
        return len(self.rows)

    def __getitem__(self, idx):
        """Return the record stored at position ``idx``."""
        return self.rows[idx]


def _encode_chat(tokenizer, prompt, answer, device, max_length):
    """Render prompt + answer with the chat template and tokenize it onto ``device``."""
    text = to_chat(tokenizer, SYSTEM_PROMPT, prompt, answer)
    return tokenizer(text, return_tensors="pt", truncation=True, max_length=max_length).to(device)


def _sequence_logprob(model, enc):
    """Return the summed token log-probability of one encoded sequence.

    The model is called with ``labels`` equal to the input ids and its loss is
    taken to be the MEAN negative log-likelihood over the ``seq_len - 1``
    positions that have a preceding context (the first token is never
    predicted). The sum is therefore ``-loss * (seq_len - 1)``; multiplying by
    ``seq_len`` would inflate every value by ``seq_len / (seq_len - 1)``, a
    factor that differs between the chosen and the rejected sequence.

    Returns a 0-dim tensor. It carries gradients when called outside
    ``torch.no_grad()`` on a trainable model. Sequences with fewer than two
    tokens have nothing to predict and yield a constant zero.
    """
    input_ids = enc["input_ids"]
    n_predicted = input_ids.shape[1] - 1
    if n_predicted < 1:
        return torch.zeros((), device=input_ids.device)
    out = model(**enc, labels=input_ids.clone())
    return -out.loss * n_predicted


def total_logprob(model, tokenizer, prompt, answer, device, max_length=1024):
    """Summed log-probability of the rendered prompt + answer, without gradients.

    Args:
        model: Causal LM returning ``.loss`` when called with ``labels``.
        tokenizer: Tokenizer used for both chat rendering and encoding.
        prompt: User prompt text.
        answer: Assistant answer text.
        device: Device the encoded inputs are moved to.
        max_length: Truncation length in tokens.

    Returns:
        A Python float. The log-probability covers the whole rendered
        sequence (system prompt, user prompt and answer), and is ``0.0`` for
        sequences shorter than two tokens.
    """
    enc = _encode_chat(tokenizer, prompt, answer, device, max_length)
    with torch.no_grad():
        return _sequence_logprob(model, enc).item()


def main():
    """Command-line entry point: DPO fine-tuning of a LoRA policy against a frozen reference.

    For every preference pair the loss is
    ``-logsigmoid(beta * ((pi_c - pi_r) - (ref_c - ref_r)))`` where each term
    is the summed sequence log-probability under the policy (``pi``) or the
    reference model (``ref``) for the chosen (``c``) or rejected (``r``)
    answer. One optimizer step is taken per DataLoader batch, using the mean
    loss over all pairs in that batch.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--ref_model", required=True)
    parser.add_argument("--policy_model", required=True)
    parser.add_argument("--prefs", required=True)
    parser.add_argument("--out", required=True)
    parser.add_argument("--beta", type=float, default=0.1)
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--lr", type=float, default=5e-5)
    parser.add_argument("--max_length", type=int, default=1024)
    parser.add_argument("--fp16", action="store_true")
    args = parser.parse_args()

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # FIX: Load the tokenizer from the original base model
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_NAME, use_fast=False)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    dtype = torch.float16 if args.fp16 and torch.cuda.is_available() else torch.float32

    ref_model = AutoModelForCausalLM.from_pretrained(args.ref_model, torch_dtype=dtype, device_map="auto")
    ref_model.eval()

    policy_model = AutoModelForCausalLM.from_pretrained(args.policy_model, torch_dtype=dtype, device_map="auto")
    policy_model = add_lora(policy_model)
    policy_model.train()

    ds = PreferenceDataset(args.prefs)
    dl = DataLoader(ds, batch_size=args.batch_size, shuffle=True)

    optimizer = AdamW(policy_model.parameters(), lr=args.lr)
    total_steps = len(dl) * args.epochs
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=min(100, total_steps // 10), num_training_steps=total_steps)

    pbar = tqdm(total=total_steps)
    beta = args.beta

    for epoch in range(args.epochs):
        for batch in dl:
            policy_model.train()

            # The default collate turns a list of records into a dict of
            # lists; walk over every pair so that batch_size > 1 does not
            # silently drop all but the first example.
            pairs = list(zip(batch["prompt"], batch["chosen"], batch["rejected"]))
            batch_loss = 0.0

            for prompt, chosen, rejected in pairs:
                # Reference model log probabilities (no gradients, plain floats)
                lp_ref_c = total_logprob(ref_model, tokenizer, prompt, chosen, device, args.max_length)
                lp_ref_r = total_logprob(ref_model, tokenizer, prompt, rejected, device, args.max_length)

                # Policy model log probabilities (with gradients)
                chosen_enc = _encode_chat(tokenizer, prompt, chosen, device, args.max_length)
                rejected_enc = _encode_chat(tokenizer, prompt, rejected, device, args.max_length)
                lp_pi_c = _sequence_logprob(policy_model, chosen_enc)
                lp_pi_r = _sequence_logprob(policy_model, rejected_enc)

                # DPO loss calculation
                diff_pi = lp_pi_c - lp_pi_r
                diff_ref = lp_ref_c - lp_ref_r
                adv = diff_pi - diff_ref

                # Scale by the batch size so the accumulated gradient is the
                # gradient of the batch-mean loss.
                pair_loss = -torch.nn.functional.logsigmoid(beta * adv) / len(pairs)
                # Degenerate (shorter than two tokens) sequences produce a
                # constant with no graph; there is nothing to back-propagate.
                if pair_loss.requires_grad:
                    pair_loss.backward()
                batch_loss += pair_loss.item()

            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            pbar.set_description(f"epoch {epoch} dpo_loss {batch_loss:.4f}")
            pbar.update(1)

    pbar.close()

    os.makedirs(args.out, exist_ok=True)
    policy_model.save_pretrained(args.out)
    tokenizer.save_pretrained(args.out)
    print(f"✅ Saved DPO policy to {args.out}")


if __name__ == "__main__":
    main()

In [None]:
!python /content/drive/MyDrive/Self-Rewarding-LLM/src/dpo_train_manual.py \
  --ref_model /content/drive/MyDrive/Self-Rewarding-LLM/checkpoints/sft \
  --policy_model /content/drive/MyDrive/Self-Rewarding-LLM/checkpoints/sft \
  --prefs /content/drive/MyDrive/Self-Rewarding-LLM/data/dpo_prefs.jsonl \
  --out /content/drive/MyDrive/Self-Rewarding-LLM/checkpoints/dpo \
  --epochs 3 \
  --beta 0.1 \
  --lr 5e-5


## Test the model

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import sys
import os

# --- Setup Paths ---
project_root = '/content/drive/MyDrive/Self-Rewarding-LLM'
if project_root not in sys.path:
    sys.path.append(project_root)
from src.tokenizer import to_chat

# --- Configuration ---
BASE_MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
DPO_MODEL_PATH = os.path.join(project_root, "checkpoints/dpo") # Path to your final model
SYSTEM_PROMPT = "Respond to the following user query in a helpful way."

# --- GPU / Device Setup ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# --- Load Final DPO Model and Tokenizer ---
print("Loading final DPO model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_NAME)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Half precision only on GPU; fall back to float32 when running on CPU,
# matching the dtype selection used by the SFT training script.
model = AutoModelForCausalLM.from_pretrained(
    DPO_MODEL_PATH,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto"
)
print("Model loaded successfully.")

# --- Run Inference ---
# You can change this prompt to ask the model anything you want
prompt = "What are the main advantages of using a self-rewarding language model?"

# Format and tokenize the prompt
chat_input = to_chat(tokenizer, SYSTEM_PROMPT, prompt)
inputs = tokenizer(chat_input, return_tensors="pt").to(model.device)

# Generate a response
print("\nGenerating response...")
with torch.no_grad():
    output_ids = model.generate(
        **inputs,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        # Passed explicitly, as in the completion-generation script.
        pad_token_id=tokenizer.pad_token_id,
    )

# Decode and print the result. generate() returns prompt + continuation,
# so the prompt tokens are sliced off first.
input_length = inputs["input_ids"].shape[1]
generated_ids = output_ids[0][input_length:]
response = tokenizer.decode(generated_ids, skip_special_tokens=True)

print("-" * 30)
print(f"Prompt:\n{prompt}")
print("-" * 30)
print(f"Generated Response:\n{response.strip()}")
print("-" * 30)

## Compare the model before and after DPO training

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import sys
import os
import json
import google.generativeai as genai

# --- 1. CONFIGURE YOUR GEMINI API KEY ---
# You can get a key from Google AI Studio. The key is read from the Colab
# secret GOOGLE_API_KEY, or from the GOOGLE_API_KEY environment variable when
# Colab secrets are unavailable. It is never stored in the notebook itself.
try:
    from google.colab import userdata
    # Attempt to load the key from Colab secrets
    API_KEY = userdata.get('GOOGLE_API_KEY')
    print("Gemini API Key loaded from Colab secrets.")
except Exception as key_error:
    # Covers running outside Colab (ImportError) as well as a missing or
    # inaccessible secret. NOTE(review): Colab appears to report those with
    # its own exception types rather than KeyError, hence the broad catch.
    API_KEY = os.environ.get("GOOGLE_API_KEY", "")
    if API_KEY:
        print("Gemini API Key loaded from the GOOGLE_API_KEY environment variable.")
    else:
        print(
            "⚠️ No Gemini API key found "
            f"({type(key_error).__name__}). Add a Colab secret or set the "
            "GOOGLE_API_KEY environment variable; Gemini requests will fail until then."
        )

genai.configure(api_key=API_KEY)
gemini_model = genai.GenerativeModel('gemini-2.5-flash')


# --- Setup Project Paths ---
project_root = '/content/drive/MyDrive/Self-Rewarding-LLM'
if project_root not in sys.path:
    sys.path.append(project_root)
from src.tokenizer import to_chat

# --- Configuration ---
BASE_MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
SFT_MODEL_PATH = os.path.join(project_root, "checkpoints/sft")
DPO_MODEL_PATH = os.path.join(project_root, "checkpoints/dpo")

# --- GPU / Device Setup ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Half precision only on GPU; fall back to float32 when running on CPU,
# matching the dtype selection used by the SFT training script.
MODEL_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32

# --- Load Tokenizer & Models ---
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_NAME)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print("\nLoading SFT Model (Before DPO)...")
sft_model = AutoModelForCausalLM.from_pretrained(SFT_MODEL_PATH, torch_dtype=MODEL_DTYPE, device_map="auto")
print("SFT Model loaded.")

print("\nLoading DPO Model (After DPO)...")
dpo_model = AutoModelForCausalLM.from_pretrained(DPO_MODEL_PATH, torch_dtype=MODEL_DTYPE, device_map="auto")
print("DPO Model loaded.")


def generate_response(
    model,
    tokenizer,
    prompt,
    system_prompt="Respond to the following user query in a helpful way.",
    max_new_tokens=256,
    temperature=0.7,
    top_p=0.9,
):
    """Sample one response to ``prompt`` from ``model``.

    Args:
        model: Causal LM exposing ``generate`` and ``device``.
        tokenizer: Tokenizer used for chat rendering, encoding and decoding.
        prompt: User message text.
        system_prompt: System message placed before the user message.
        max_new_tokens: Upper bound on generated tokens. Longer answers are
            cut off mid-sentence, so raise this when comparing long-form
            outputs.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling threshold.

    Returns:
        The generated continuation only (prompt tokens removed), decoded
        without special tokens and stripped of surrounding whitespace.
    """
    chat_input = to_chat(tokenizer, system_prompt, prompt)
    inputs = tokenizer(chat_input, return_tensors="pt").to(model.device)
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            # Passed explicitly, as in the completion-generation script.
            pad_token_id=tokenizer.pad_token_id,
        )
    # generate() returns prompt + continuation; keep only the new tokens.
    input_length = inputs["input_ids"].shape[1]
    generated_ids = output_ids[0][input_length:]
    return tokenizer.decode(generated_ids, skip_special_tokens=True).strip()


def get_gemini_judgment(prompt, response_a, response_b):
    """Ask the Gemini judge which of two responses to ``prompt`` is better.

    The responses are presented only as "Response A" and "Response B". The
    judge is not told which model produced which response, so knowing the
    training stage cannot influence the verdict.

    Args:
        prompt: The original user prompt.
        response_a: First candidate response.
        response_b: Second candidate response.

    Returns:
        A ``(winner, justification)`` tuple. ``winner`` is the judge's
        ``"winner"`` field (``"Unknown"`` when absent). On any failure
        (request error, no JSON object in the reply, invalid JSON) the tuple
        is ``("Error", <error message>)``.
    """
    judge_prompt_template = f"""
You are an impartial and expert AI assistant evaluator. Your task is to compare two responses to a user's prompt and determine which one is better based on helpfulness, accuracy, and structure.

**[USER'S ORIGINAL PROMPT]**
{prompt}

---

**[RESPONSE A]**
{response_a}

---

**[RESPONSE B]**
{response_b}

---

**[INSTRUCTIONS]**
Compare Response A and Response B. Which response is better?

Please provide your final verdict in a JSON format with two keys:
1. "winner": A string, either "Response A", "Response B", or "Tie".
2. "justification": A concise, one or two-sentence explanation for your choice.

Do not add any other text outside of the JSON block.
"""
    try:
        print("Sending request to Gemini...")
        response = gemini_model.generate_content(judge_prompt_template)
        # Take the outermost {...} span so that code fences or stray prose
        # around the JSON object do not break parsing.
        raw_reply = response.text.strip()
        start = raw_reply.find("{")
        end = raw_reply.rfind("}")
        if start == -1 or end < start:
            raise ValueError("no JSON object found in the judge's reply")
        judgment = json.loads(raw_reply[start:end + 1])
        return judgment.get("winner", "Unknown"), judgment.get("justification", "No justification provided.")
    except Exception as e:
        print(f"An error occurred while getting judgment from Gemini: {e}")
        return "Error", str(e)


# --- Run Comparison ---
# One fixed prompt is sent to both models, both answers are printed, and the
# Gemini judge is then asked to pick the better one.
test_prompt = "Explain the plot of the movie Inception as if you were talking to a five-year-old. Focus on the main idea of dreams within dreams without getting lost in the details."

print("\n" + "="*50)
print(f"PROMPT: {test_prompt}")
print("="*50 + "\n")

# Generate from SFT model
sft_response = generate_response(sft_model, tokenizer, test_prompt)
print(f"--- RESPONSE from SFT Model (Before DPO) ---\n{sft_response}\n")

# Generate from DPO model
dpo_response = generate_response(dpo_model, tokenizer, test_prompt)
print(f"--- RESPONSE from DPO Model (After DPO) ---\n{dpo_response}\n")

# --- Get Gemini's Judgment ---
# The SFT answer is always passed as "Response A" and the DPO answer as
# "Response B". The judge is only called when both answers are non-empty.
if sft_response and dpo_response:
    print("\n" + "="*50)
    print("⚖️  ASKING GEMINI FOR ITS JUDGMENT...")
    print("="*50)
    winner, justification = get_gemini_judgment(test_prompt, sft_response, dpo_response)
    print(f"\nGemini's Verdict: **{winner}**")
    print(f"Justification: {justification}")
else:
    print("Skipping Gemini judgment because one or both responses were empty.")

Output:

Gemini API Key loaded from Colab secrets.
Using device: cuda

Loading SFT Model (Before DPO)...
SFT Model loaded.

Loading DPO Model (After DPO)...
DPO Model loaded.

==================================================
PROMPT: Explain the plot of the movie Inception as if you were talking to a five-year-old. Focus on the main idea of dreams within dreams without getting lost in the details.
==================================================

--- RESPONSE from SFT Model (Before DPO) ---
<|assistant|>
Sure, I'd be happy to help you with that! Inception is a movie about a group of people who are tasked with infiltrating a dream and stealing information from it. The dream is filled with various different worlds, each with its own rules and conventions, and the characters are trying to break into these worlds in order to steal information.

The main idea of dreams within dreams is that a dream is a form of reality, and the characters in Inception are trying to break into different dreams in order to gain access to information that they need for their mission. This is similar to the idea that dreams can be a way of accessing different parts of the mind and exploring different aspects of consciousness.

Inception is a movie about the power of the mind and the ability of dreams to shape reality. It is a movie that encourages us to think about the concept of dreams and how they can be used to access different parts of the mind.

I hope that helps! Let me know if you have any other questions or need any further assistance.

--- RESPONSE from DPO Model (After DPO) ---
<|assistant|>
As a five-year-old, I would love to talk about the plot of Inception with you. Inception is a thrilling and mind-bending movie that explores the concept of dreams within dreams.

The plot follows a group of experts, led by the mysterious Cobb (Leonardo DiCaprio), who are hired by a wealthy client, Adrian (Joseph Gordon-Levitt), to create a dream within a dream. The goal is to unlock a secret hidden within the dream, which will give Adrian an advantage in a major corporate deal.

The dreams themselves are intricately designed, with each one leading to the next, and each one is filled with intricate details and twists that keep the audience on the edge of their seats. The team of experts must navigate through the dreams and uncover clues that will lead them to the secret.

As the team progresses deeper into the dreams, they start to experience hallucinations and dreams within dreams, leading to the creation of a new, more complex reality. The climax of the movie sees the team breaking free from the dream


==================================================
⚖️  ASKING GEMINI FOR ITS JUDGMENT...
==================================================
Sending request to Gemini...

Gemini's Verdict: **Response B**
Justification: Response B provides a more age-appropriate and engaging narrative summary of the plot, while Response A uses overly complex language and abstract concepts for a five-year-old.