# Setup

In [None]:
# Mount Google Drive: every data / adapter path used in this notebook lives under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')
# NOTE(review): package versions are unpinned, so a later re-run may install different releases.
!pip install unsloth
!pip install transformers peft accelerate safetensors
import sys
import importlib  # used later to reload med_dpo_loss after editing it on Drive
# Make the Drive project folder importable so the custom loss module below can be found.
sys.path.append('/content/drive/MyDrive/DPO/DPO on Colab')
import med_dpo_loss

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import (
    prepare_model_for_kbit_training,
    PeftModel,
    LoraConfig
)
import torch
from torch.nn import Linear
import json
from torch.utils.data import Dataset, DataLoader

# Sanity check: later cells hard-code "cuda", so this notebook needs a GPU runtime.
print(torch.cuda.is_available())
print(torch.cuda.get_device_name(0))

# Define Model

In [None]:


# 1) 4-bit NF4 quantization config with FP16 compute
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,   # matmuls in FP16
)

# 2) Load the base model & tokenizer
base = AutoModelForCausalLM.from_pretrained(
    "unsloth/gemma-3-1b-it-unsloth-bnb-4bit",
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,  # NOTE(review): runs code from the Hub repo; drop if not required
)
tokenizer = AutoTokenizer.from_pretrained(
    "unsloth/gemma-3-1b-it-unsloth-bnb-4bit", use_fast=True
)

# 3) Prepare the quantized model for k-bit training (PEFT helper; see its docs
#    for exactly which dtypes / flags it changes)
model = prepare_model_for_kbit_training(base)

# 4) Load the *pre-trained* LoRA adapter from disk, registered as "default"
model = PeftModel.from_pretrained(
    model,
    "/content/drive/MyDrive/DPO/DPO on Colab/lora_adapter",
    adapter_name="default",
    torch_dtype=torch.float16,
)

# 5) Define a *new* LoRA adapter to train on top of the pre-trained one
new_lora_cfg = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
# Register the adapter under a name so it can be referred to later.
model.add_adapter("new_task", new_lora_cfg)

# add_adapter only registers "new_task"; it does not make it active. Without the
# next line the forward pass would run only the "default" adapter and the new
# weights would never receive gradients. Activate both so "new_task" is stacked
# on top of the pre-trained adapter.
model.base_model.set_adapter(["default", "new_task"])

# 6) Freeze *everything* except the new adapter's weights. This must come after
#    set_adapter, which may toggle requires_grad on the adapters it activates.
for n, p in model.named_parameters():
    p.requires_grad = "new_task" in n

assert any(p.requires_grad for p in model.parameters()), \
    "No 'new_task' parameters found - the new adapter was not attached."

# 7) Forward pre-hook that makes a Linear's input match its weight dtype
def cast_to_weight_dtype(module, inputs):
    """Cast the first positional input of ``module`` to ``module.weight.dtype``.

    Meant for ``nn.Module.register_forward_pre_hook``.

    Args:
        module: the layer being called; must have a ``weight`` tensor.
        inputs: tuple of positional arguments passed to ``module``.

    Returns:
        ``None`` (inputs left untouched) when there is nothing to cast: no
        positional args, a non-tensor first arg, a non-floating weight, or
        matching dtypes. Otherwise, the full positional-args tuple with only
        the first element cast. Any extra positional args are preserved.
    """
    if not inputs:
        # Layer was called with keyword arguments only; nothing to cast.
        return None
    x = inputs[0]
    wd = module.weight.dtype
    if not wd.is_floating_point or not torch.is_tensor(x) or x.dtype == wd:
        return None
    return (x.to(wd),) + tuple(inputs[1:])

# Attach the cast hook to every plain nn.Linear whose weight is floating point.
# Layers with non-float weights (presumably the quantized ones - verify) are skipped.
float_linears = [
    module
    for module in model.modules()
    if isinstance(module, Linear) and module.weight.dtype.is_floating_point
]
for module in float_linears:
    module.register_forward_pre_hook(cast_to_weight_dtype)

model.train()


# Test implementation: forward pass, loss, backward pass

In [None]:
# 8) Count total vs trainable params
total = sum(p.numel() for p in model.parameters())
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total params:     {total:,}")
print(f"Trainable params: {trainable:,} ({100*trainable/total:.4f}%)")
assert trainable > 0, "No trainable parameters - check the adapter freeze loop."

# 9) Dummy forward + next-token cross-entropy + backward to confirm gradients flow
prompt  = "Test gradient flow"
inputs  = tokenizer(prompt, return_tensors="pt").to(model.device)
logits  = model(**inputs).logits                              # [batch, seq, vocab]
shifted = logits[..., :-1, :].reshape(-1, logits.size(-1))   # predictions at positions 0..L-2
labels  = inputs["input_ids"][..., 1:].reshape(-1)            # the tokens those positions should predict

loss = torch.nn.functional.cross_entropy(shifted, labels)
print("loss.requires_grad?", loss.requires_grad)

loss.backward()

# loss.requires_grad can be True even if none of the trainable weights took part
# in the forward pass, so inspect the trainable tensors themselves.
trainable_named = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
got_grad = [n for n, p in trainable_named if p.grad is not None]
assert got_grad, "Backward ran, but no trainable parameter received a gradient."
print(f"Backward Success: {len(got_grad)}/{len(trainable_named)} trainable tensors received gradients")

# Clear the dummy gradients; they are not part of training.
model.zero_grad(set_to_none=True)

# Load Dataset

In [4]:


class JSONLDataset(Dataset):
    """Preference-pair dataset read from a JSON Lines file.

    Each non-blank line must be a JSON object with the keys ``prompt``,
    ``chosen_response``, ``rejected_response``, ``chosen_scores`` and
    ``rejected_scores``. The two score entries are mappings containing at
    least the keys in ``SCORE_KEYS``.
    """

    # Fixed criterion order, so index i means the same thing in every score tensor.
    SCORE_KEYS = ('accuracy', 'safety', 'explanation_depth')

    def __init__(self, filepath):
        """Load every record of ``filepath`` into memory.

        Blank or whitespace-only lines are skipped instead of making
        ``json.loads`` raise.
        """
        with open(filepath, 'r', encoding='utf-8') as f:
            self.samples = [json.loads(line) for line in f if line.strip()]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return one pair as prompt+response strings plus float score tensors.

        Returns:
            dict with ``prompt_chosen_response`` / ``prompt_rejected_response``
            (str) and ``chosen_scores`` / ``rejected_scores`` (1-D float
            tensors of length ``len(SCORE_KEYS)``).
        """
        sample = self.samples[idx]
        prompt = sample['prompt']
        # NOTE(review): the response is appended with no separator, so text such
        # as "...for your choice.D. ..." runs together - confirm this is intended.
        return {
            'prompt_chosen_response': prompt + sample['chosen_response'],
            'prompt_rejected_response': prompt + sample['rejected_response'],
            'chosen_scores': self._scores(sample['chosen_scores']),
            'rejected_scores': self._scores(sample['rejected_scores']),
        }

    @classmethod
    def _scores(cls, scores):
        # Pull the criteria in SCORE_KEYS order into a float tensor.
        return torch.tensor([scores[k] for k in cls.SCORE_KEYS], dtype=torch.float)

In [6]:

train_data = JSONLDataset(
    '/content/drive/MyDrive/DPO/DPO on Colab/gemma3_dpo_scored_data.jsonl'
)
# NOTE(review): shuffle=True with no seed set, so batch order differs per run.
train_dataloader = DataLoader(train_data, shuffle=True, batch_size=1)

# Show a single collated batch to sanity-check keys and value types.
for sample in train_dataloader:
    print(sample)
    break

{'prompt_chosen_response': ['Question: A 44-year-old man comes to the physician for a follow-up examination. Ten months ago, he was diagnosed with HIV infection and appropriate antiretroviral therapy was initiated. Physical examination shows no abnormalities. Laboratory studies show increased viral load despite ongoing treatment. His pharmacotherapy is switched to a new combination drug regimen including an agent that binds to glycoprotein 41. The expected effect of this drug is most likely due to inhibition of which of the following?\n\nOptions:\nA. Viral particle assembly\nB. Viral docking and attachment to host cells\nC. Viral genome transcription\nD. Viral fusion and entry into host cells\n\nChoose the best answer and provide a step-by-step explanation for your choice.D. Viral fusion and entry into host cells\nExplanation: The agent that binds to glycoprotein 41 inhibits viral fusion and entry into host cells, preventing HIV from successfully infecting new cells, which is crucial i

In [12]:
# Inspect one record straight from the dataset, before DataLoader collation:
# the text fields are plain strings here (the collated batch above wraps them in lists).
ex = train_data[0]

print(ex)

{'prompt_chosen_response': "Question: A 5-year-old girl is brought to the physician because of a 2-day history of redness and foreign body sensation in both eyes. She has not had vision loss. Her mother reports that she has also had violent coughing spells followed by a high-pitched inspiratory sound during this time. For the past week, she has had low-grade fevers and a runny nose. Her only vaccinations were received at birth. Her temperature is 37.7°C (99.9°F). Examination shows conjunctival hemorrhage and petechiae. Oropharyngeal examination shows no abnormalities. Which of the following is the most appropriate pharmacotherapy?\n\nOptions:\nA. Topical azithromycin\nB. Oral azithromycin\nC. Artificial tears\nD. Topical tobramycin\n\nChoose the best answer and provide a step-by-step explanation for your choice.B. Oral azithromycin\nExplanation: The child's symptoms, including conjunctival hemorrhage and cough with stridor, suggest a viral infection like measles, for which oral azithro

# Training Loop

In [None]:

# Re-import the custom loss so edits to med_dpo_loss.py on Drive take effect
# without restarting the kernel.
importlib.reload(med_dpo_loss)
from med_dpo_loss import MedDPOLoss
from tqdm.auto import tqdm

# dtype of the first model parameter, and the device the model reports.
dtype, device = next(model.parameters()).dtype, model.device
print(device)


def _tokenize_texts(tokenizer, texts, device, max_length):
    """Tokenize a batch of strings (padded + truncated) and move every tensor to ``device``."""
    encoded = tokenizer(
        texts,
        return_tensors='pt',
        padding=True,
        truncation=True,
        max_length=max_length,
    )
    return {key: value.to(device) for key, value in encoded.items()}


def train(model, tokenizer, dataloader, optimizer, epochs=1, log_every=50, max_length=2048):
    """Fine-tune ``model`` on chosen/rejected pairs with ``MedDPOLoss``.

    Args:
        model: causal LM whose forward output has a ``.logits`` attribute.
        tokenizer: tokenizer used to encode the prompt+response strings.
        dataloader: yields batches with the keys ``prompt_chosen_response``,
            ``prompt_rejected_response``, ``chosen_scores`` and
            ``rejected_scores`` (as produced by ``JSONLDataset``).
        optimizer: optimizer over the parameters being trained.
        epochs: number of passes over ``dataloader``.
        log_every: print and record the running mean loss every this many batches.
        max_length: token truncation length for each encoded sequence.

    Returns:
        list[float]: running mean loss, recorded every ``log_every`` batches
        and once at the end of each epoch.
    """
    model.train()
    loss_fn = MedDPOLoss()
    # Send inputs wherever the model lives instead of hard-coding 'cuda'.
    device = model.device

    losses_50 = []
    total_loss = 0.0
    step = 0

    for epoch in range(epochs):
        # Iterate over the tqdm wrapper itself so the progress bar advances.
        loop = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}", leave=False)
        for batch in loop:
            optimizer.zero_grad()

            chosen_inputs = _tokenize_texts(tokenizer, batch['prompt_chosen_response'], device, max_length)
            rejected_inputs = _tokenize_texts(tokenizer, batch['prompt_rejected_response'], device, max_length)
            chosen_rewards = batch['chosen_scores'].to(device)
            rejected_rewards = batch['rejected_scores'].to(device)

            chosen_logits = model(**chosen_inputs).logits
            rejected_logits = model(**rejected_inputs).logits

            # NOTE(review): only logits reach the loss (no attention mask or
            # prompt/response boundary), so padding and prompt tokens are handled
            # however MedDPOLoss handles them - confirm in med_dpo_loss.
            per_example_loss = loss_fn(chosen_logits, rejected_logits,
                                       chosen_rewards, rejected_rewards)
            loss = per_example_loss.mean()

            loss.backward()
            optimizer.step()

            loss_value = loss.item()
            total_loss += loss_value
            step += 1
            loop.set_postfix(batch=step, loss=loss_value, refresh=False)

            # Log after updating the totals so the mean covers exactly `step` batches.
            if step % log_every == 0:
                running_mean = total_loss / step
                print(f"through {step} pairs: loss = {running_mean}")
                losses_50.append(running_mean)

        # Guard against an empty dataloader (would otherwise divide by zero).
        if step:
            losses_50.append(total_loss / step)

    return losses_50

# Hand AdamW only the parameters meant to train (the "new_task" adapter).
# Params without gradients are skipped in step() anyway; this avoids tracking
# frozen tensors and fails fast if nothing is trainable.
trainable_params = [p for p in model.parameters() if p.requires_grad]
assert trainable_params, "No trainable parameters found - check the model setup cell."
optimizer = torch.optim.AdamW(trainable_params)  # NOTE(review): default lr (1e-3); consider tuning for LoRA

loss_history = train(model, tokenizer, train_dataloader, optimizer)

# Inference: load the model if needed


In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel

# 1) Re–build the exact same 4-bit / FP16 base you used for training
# 1) Recreate the base model with the same 4-bit quantization settings as the
#    training cell: NF4 weights, double quantization, FP16 compute.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)
base = AutoModelForCausalLM.from_pretrained(
    "unsloth/gemma-3-1b-it-unsloth-bnb-4bit",
    device_map="auto",
    quantization_config=bnb_config,
    torch_dtype=torch.float16,
    trust_remote_code=True,
)

# 2) Wrap the base in PEFT with the trained DPO adapter, registered as "new_task".
# NOTE(review): the training cell attaches "new_task" on top of a pre-trained
# adapter loaded from ".../lora_adapter", but this cell loads a single adapter
# directory. Confirm which adapter's weights live in this folder, and whether
# the pre-trained adapter must also be loaded for inference.
# NOTE(review): PEFT adapters are normally written with `save_pretrained(...)`;
# confirm how this folder was produced and under which adapter name.
model = PeftModel.from_pretrained(
    base,
    "/content/drive/MyDrive/DPO/DPO on Colab/gemma_med_dpo_adapter",
    adapter_name="new_task",
    torch_dtype=torch.float16,
)

# 3) Load the tokenizer from the adapter folder.
#    Assumes the tokenizer was saved alongside the adapter - TODO confirm.
tokenizer = AutoTokenizer.from_pretrained(
    "/content/drive/MyDrive/DPO/DPO on Colab/gemma_med_dpo_adapter",
    use_fast=True,
)

# 4) Switch to eval mode before generating.
model.eval()


KeyboardInterrupt: 

# Generate Test Responses

In [None]:

import json

import torch
from tqdm import tqdm

def perform_inference(model, tokenizer, prompts, batch_size=16, max_new_tokens=128):
    """Generate a completion for every prompt, processing them in batches.

    Args:
        model: generation-capable model (has ``generate``, ``eval`` and ``device``).
        tokenizer: tokenizer matching ``model``; must define ``pad_token_id``
            for batched padding.
        prompts: list of prompt strings.
        batch_size: number of prompts per ``generate`` call.
        max_new_tokens: upper bound on generated tokens per prompt.

    Returns:
        list[dict]: one ``{"prompt": ..., "response": ...}`` per prompt, in
        input order. ``response`` holds only the newly generated text.
    """
    # Use the model where it already lives instead of moving it: the model may
    # be quantized / dispatched with device_map="auto".
    device = model.device
    model.eval()
    results = []

    # Batched decoder-only generation continues from the last position, so pad
    # on the left. Restore the caller's tokenizer setting afterwards.
    original_padding_side = tokenizer.padding_side
    tokenizer.padding_side = "left"
    try:
        for start in tqdm(range(0, len(prompts), batch_size)):
            batch_prompts = prompts[start:start + batch_size]

            # Tokenize and pad to the longest sequence in the batch
            inputs = tokenizer(batch_prompts, return_tensors="pt", padding=True, truncation=True).to(device)

            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=False,  # deterministic generation (change if you want randomness)
                    pad_token_id=tokenizer.pad_token_id,
                )

            # generate() returns prompt + continuation and may stop before
            # max_new_tokens, so strip exactly the (padded) prompt length.
            prompt_len = inputs["input_ids"].shape[1]
            decoded_outputs = tokenizer.batch_decode(outputs[:, prompt_len:], skip_special_tokens=True)

            results.extend(
                {"prompt": prompt, "response": response}
                for prompt, response in zip(batch_prompts, decoded_outputs)
            )
    finally:
        tokenizer.padding_side = original_padding_side

    return results

def load_prompts_from_jsonl(file_path):
    """Return the ``'prompt'`` field of every non-blank line of a JSONL file, in file order.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
        KeyError: if a record has no ``'prompt'`` key.
    """
    with open(file_path, 'r') as handle:
        stripped_lines = (raw.strip() for raw in handle)
        return [json.loads(text)['prompt'] for text in stripped_lines if text]

def save_results_to_json(results, output_file):
    """Serialize ``results`` to ``output_file`` as one JSON document, overwriting the file."""
    with open(output_file, mode='w') as sink:
        json.dump(results, fp=sink)

model_name = "unsloth/gemma-3-1b-it-unsloth-bnb-4bit"  # NOTE(review): not referenced below
prompts_file, output_file = (
    "/content/drive/MyDrive/DPO/DPO on Colab/gemma3_sft_test_results.jsonl",
    "/content/drive/MyDrive/DPO/DPO on Colab/gemma3_dpo_inference_results.json",
)

prompts = load_prompts_from_jsonl(prompts_file)
results = perform_inference(model, tokenizer, prompts)
# NOTE(review): saving is disabled, so `results` exists only in kernel memory.
# Uncomment to write it to `output_file` (overwrites any existing file).
#save_results_to_json(results, output_file)

In [14]:

# Quick qualitative check: continue a generic prompt for up to 50 new tokens.
prompt = "Once upon a time"
inputs = tokenizer(prompt, return_tensors="pt")
inputs = inputs.to(model.device)
with torch.no_grad():
    out = model.generate(max_new_tokens=50, **inputs)

print(tokenizer.decode(out[0], skip_special_tokens=True))



Once upon a time, in a world filled with shimmering rivers and towering trees, lived a tiny firefly named Flicker. He was incredibly curious, and his light, a faint and flickering orange glow, was barely enough to illuminate his own little patch of forest.

The
