<a href="https://colab.research.google.com/github/Michael-L-i-1/CS231N-Final-Project/blob/main/GRPO.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install hf_xet
!pip install trl
!pip install peft
!pip install flash-attn
!pip install -U bitsandbytes

In [None]:
from transformers import AutoProcessor, AutoModelForVision2Seq

import torch
# Use the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-500M-Instruct")
model = AutoModelForVision2Seq.from_pretrained(
    "HuggingFaceTB/SmolVLM-500M-Instruct",
    torch_dtype=torch.bfloat16,
    # flash-attention is only requested on CUDA; CPU runs use eager attention
    _attn_implementation="flash_attention_2" if DEVICE == "cuda" else "eager",
).to(DEVICE)
# Last expression of the cell (shows the model repr). Uses DEVICE instead of a
# hard-coded 'cuda' so the cell also runs on CPU-only runtimes.
model.to(DEVICE)

In [None]:
from google.colab import drive
# Mount Google Drive under /content/drive (the dataset paths below live there).
drive.mount('/content/drive')

In [None]:
from PIL import Image
from transformers.image_utils import load_image
import os
from tqdm.notebook import tqdm

import json
import torch
from torch.utils.data import Dataset
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM
from peft import LoraConfig, get_peft_model
from trl import GRPOTrainer, GRPOConfig
import copy


In [None]:
# Dataset locations on the mounted Google Drive: the full set and a "mini" set,
# each consisting of a metadata JSON file and an images folder.
base_drive_path = '/content/drive/My Drive/CS231N Colabs/dataset'
json_file_path = os.path.join(base_drive_path, 'metadata.json')
images_folder   = os.path.join(base_drive_path, 'images')
mini_json_file_path = os.path.join(base_drive_path, 'mini_metadata.json')
mini_images_folder   = os.path.join(base_drive_path, 'mini_images')

# Instruction sent to the model together with every image.
# NOTE(review): adjacent string literals are joined without a separator, so the
# prompt reads "...(provide name only).You should..." (no space after the
# period). The inference cell uses the same text; change both together if at all.
QUESTION = (
    "Given the diagram, list the labels of the circles in order "
    "from leftmost to rightmost (provide name only)."
    "You should have all the names included."
)

# Hugging Face model id used for both the processor and the base model.
MODEL_NAME = "HuggingFaceTB/SmolVLM-500M-Instruct"

# Use the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# GRPO Model

(NOTE: Evaluate performance on a truncated dataset to see how much training data we actually need.)

In [None]:
# Processor and base model that the LoRA policy is built on.
processor = AutoProcessor.from_pretrained(MODEL_NAME, use_fast=True)

# Load the weights and place them explicitly on DEVICE. `device_map="auto"` is
# intentionally not used: a model dispatched by accelerate should not be moved
# with `.to()` afterwards, and automatic sharding across several devices is not
# what a single-device training run wants.
base_model = AutoModelForVision2Seq.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16,
).to(DEVICE)

# Reference model: a deep copy of the base model taken before any adapters are
# attached, switched to eval mode.
# NOTE(review): no later cell in this notebook references `ref_model`, so the
# copy only costs memory unless it is used elsewhere — confirm before keeping.
ref_model = copy.deepcopy(base_model).eval()

In [None]:
# LoRA configuration: train small low-rank adapters instead of the full model
# to reduce memory usage.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    # adapters are attached to the modules named q_proj and v_proj
    target_modules=["q_proj", "v_proj"],
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
)

# Wrap the base model with the adapters; the wrapped model is the policy that
# gets trained.
policy = get_peft_model(base_model, lora_config)

# Generate by sampling rather than greedy decoding.
policy.generation_config.do_sample = True

In [None]:
# dataset class
class CircleVlmPromptSet(Dataset):
    """Prompt dataset for the circle-ordering task.

    Each metadata entry is read for two keys: ``"image_path"`` (joined onto
    ``images_folder``) and ``"order"`` (a list of label strings that is joined
    with ``", "`` to form the gold answer).

    Args:
        meta_json_path: Path of the JSON metadata file (a list of entries).
        processor: Processor providing ``apply_chat_template``, ``__call__``
            and a ``tokenizer`` attribute.
        question: Instruction text placed next to the image in the prompt.
        base_folder: Stored on the instance; not used by the methods here.
        images_folder: Folder that ``entry["image_path"]`` is relative to.
    """

    def __init__(self, meta_json_path, processor, question, base_folder, images_folder):
        # JSON is UTF-8 by specification, so do not depend on the platform's
        # default text encoding when reading the metadata.
        with open(meta_json_path, 'r', encoding='utf-8') as f:
            self.entries = json.load(f)
        self.processor = processor
        self.question = question
        self.base_folder = base_folder
        self.images_folder = images_folder

    def __len__(self):
        """Return the number of metadata entries."""
        return len(self.entries)

    def __getitem__(self, idx):
        """Build the prompt, model inputs and gold labels for entry ``idx``.

        Returns:
            A dict with the chat-formatted ``"prompt"`` string, the processor
            outputs ``"input_ids"``, ``"attention_mask"`` and ``"pixel_values"``
            (each with its leading dimension squeezed), the tokenized gold
            answer as ``"labels"`` and the gold answer text as ``"labels_str"``.
        """
        entry = self.entries[idx]
        img_rel = entry["image_path"]
        img_full = os.path.join(self.images_folder, img_rel)
        # Close the file handle deterministically; convert() returns a new,
        # fully loaded image that stays valid after the file is closed.
        with Image.open(img_full) as raw_image:
            image = raw_image.convert("RGB")

        # single-turn user message: one image followed by the question text
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": self.question}
                ]
            }
        ]
        prompt_text = self.processor.apply_chat_template(
            messages, add_generation_prompt=True
        )

        # tokenize prompt_text + image into input_ids, attention_mask, pixel_values
        model_inputs = self.processor(
            text=prompt_text,
            images=[image],
            return_tensors="pt",
            padding=False,
            truncation=True
        )

        # build ground-truth labels from metadata
        gold_answer = ", ".join(entry["order"])
        label_ids = self.processor.tokenizer(
            gold_answer, return_tensors="pt"
        ).input_ids.squeeze(0)

        # return all fields
        return {
            "prompt": prompt_text,
            "input_ids": model_inputs["input_ids"].squeeze(0),
            "attention_mask": model_inputs["attention_mask"].squeeze(0),
            "pixel_values": model_inputs["pixel_values"].squeeze(0),
            "labels": label_ids,
            "labels_str": gold_answer
        }

def vlm_collate(batch):
    """Collate a list of dataset items into a single batch dict.

    Args:
        batch: List of dicts with the keys ``"prompt"``, ``"input_ids"``,
            ``"attention_mask"``, ``"pixel_values"``, ``"labels"`` and
            ``"labels_str"``.

    Returns:
        A dict with the same keys: ``"prompt"`` and ``"labels_str"`` stay lists
        of strings, ``"pixel_values"`` is stacked, ``"input_ids"`` and
        ``"attention_mask"`` are padded by the tokenizer, and ``"labels"`` is
        padded with -100.
    """
    batch_prompts = [b["prompt"] for b in batch]
    batch_labels_str = [b["labels_str"] for b in batch]

    # stack pixel_values (torch.stack requires every item to have the same shape)
    pixel_values = torch.stack([b["pixel_values"] for b in batch])

    # pad the prompt fields: input_ids, attention_mask -> each (B, seq_len).
    # "labels" is deliberately NOT passed here: tokenizer.pad only pads the
    # model input fields, so gold answers of different lengths would stay
    # ragged and fail the tensor conversion.
    padded = processor.tokenizer.pad(
        {
            "input_ids": [b["input_ids"] for b in batch],
            "attention_mask": [b["attention_mask"] for b in batch],
        },
        return_tensors="pt"
    )

    # pad labels directly with -100 so padded positions do not count toward a
    # loss; padding with -100 (instead of masking pad_token_id afterwards)
    # cannot hide a genuine token that happens to share the pad id
    labels = torch.nn.utils.rnn.pad_sequence(
        [b["labels"] for b in batch],
        batch_first=True,
        padding_value=-100,
    )

    # build the final batch dict
    batch_dict = {
        "prompt": batch_prompts,
        "input_ids": padded["input_ids"],
        "attention_mask": padded["attention_mask"],
        "pixel_values": pixel_values,
        "labels": labels,
        "labels_str": batch_labels_str
    }
    return batch_dict


# Full training set: metadata.json + images/
train_dataset = CircleVlmPromptSet(
    json_file_path,
    processor,
    QUESTION,
    base_folder=base_drive_path,
    images_folder=images_folder,
)

# Small subset: mini_metadata.json + mini_images/
mini_train_dataset = CircleVlmPromptSet(
    mini_json_file_path,
    processor,
    QUESTION,
    base_folder=base_drive_path,
    images_folder=mini_images_folder,
)

In [None]:
# Sanity check: show which images folder the mini dataset reads from.
print(mini_train_dataset.images_folder)

In [None]:
def reward_func(prompts, completions, completion_ids=None, **kwargs):
    """Score each completion against its gold left-to-right label order.

    Reward = 100 - 10 * |length mismatch| - sum |pred_idx - gold_idx|
             - 15 * (number of unsupported names); higher is better and the
    value may become negative.

    A predicted name is matched to the earliest still-unmatched occurrence of
    that name in the gold list, so a perfect answer always scores 100 even when
    the gold list repeats a label. A predicted name that is absent from the
    gold list, or that appears more often than the gold list contains it, has
    nothing left to match and is charged the hallucination penalty.

    Args:
        prompts: Prompts the completions were generated from (not used).
        completions: Generated strings; only the text after the last
            ``"Assistant:"`` marker is parsed as a comma-separated name list.
        completion_ids: Token ids of the completions (not used). Optional so
            the function also works with callers that do not pass it.
        **kwargs: Must carry ``labels_str``, the comma-separated gold answers
            aligned with ``completions``.

    Returns:
        A list of float rewards, one per (completion, gold answer) pair.
    """
    MAX_REWARD = 100
    PEN_LEN = 10
    PEN_HALLUCINATION = 15
    gold_list_all = kwargs.get("labels_str", [])
    rewards = []

    for comp, gold_str in zip(completions, gold_list_all):
        # parse lists
        trial = [n.strip() for n in comp.split("Assistant:")[-1].split(",") if n.strip()]
        correct = [n.strip() for n in gold_str.split(",") if n.strip()]

        # base reward = 100 - length penalty
        reward = MAX_REWARD - abs(len(correct) - len(trial)) * PEN_LEN
        if not correct:                       # degenerate case: nothing to match
            rewards.append(float(MAX_REWARD if not trial else max(0, reward)))
            continue

        # name -> gold positions not matched yet, in left-to-right order
        open_positions = {}
        for g_idx, g_name in enumerate(correct):
            open_positions.setdefault(g_name, []).append(g_idx)

        for t_idx, t_name in enumerate(trial):
            remaining = open_positions.get(t_name)
            if remaining:
                # order penalty: distance to the matched gold position
                reward -= abs(t_idx - remaining.pop(0))
            else:
                # hallucination penalty: unknown name or surplus repeat
                reward -= PEN_HALLUCINATION

        rewards.append(float(reward))
    return rewards

In [None]:
grpo_cfg = GRPOConfig(
    # --- GRPO / generation ---
    num_generations=8,
    max_completion_length=64,
    beta=0.001,
    temperature=1,
    top_k=5,
    top_p=0.8,

    # --- optimisation ---
    per_device_train_batch_size=16,
    gradient_accumulation_steps=2,
    num_train_epochs=1,
    optim="adamw_8bit",
    bf16=True,

    # --- data handling: keep the extra dataset columns ---
    remove_unused_columns=False,

    # --- logging / outputs ---
    report_to=["wandb"],
    logging_steps=1,
    logging_dir="SmolVLM_logs",
    output_dir="SmolVLM_output",
)

In [None]:
# Build the GRPO trainer around the LoRA policy and the mini dataset.
trainer = GRPOTrainer(
    model=policy,
    args=grpo_cfg,
    train_dataset=mini_train_dataset,
    # `reward_func` is the reward defined in this notebook; the previously
    # referenced `simple_reward_func` is not defined anywhere in it.
    reward_funcs=reward_func,
)

In [None]:
import wandb

# Authenticate with Weights & Biases (the training config reports to "wandb").
wandb.login()

In [None]:
# Start a W&B run in the "GRPO" project, then launch training.
wandb.init(project="GRPO")
trainer.train()

In [None]:
# Save the trained model and the processor side by side so the folder can be
# reloaded with from_pretrained().
trainer.save_model("SmolVLM_finetuned/")
processor.save_pretrained("SmolVLM_finetuned/")

In [None]:
from PIL import Image
from transformers import AutoProcessor, AutoModelForVision2Seq

# load the saved model and processor
# (this rebinds the notebook-level `processor` and `model` names)
model_path = "SmolVLM_finetuned/"
processor = AutoProcessor.from_pretrained(model_path)
model = AutoModelForVision2Seq.from_pretrained(model_path)
# inference only: switch to eval mode
# NOTE(review): no explicit .to(device) here, so the model stays wherever
# from_pretrained placed it — confirm this is the intended device.
model.eval()

In [None]:
# TESTING FINETUNED MODEL

# Path of the image to run the fine-tuned model on.
image_path = '/content/test.png'

image = Image.open(image_path).convert("RGB")

# Same instruction text that was used for training.
prompt = (
    "Given the diagram, list the labels of the circles in order "
    "from leftmost to rightmost (provide name only)."
    "You should have all the names included."
)

# single-turn user message: one image followed by the instruction
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": prompt}
        ]
    }
]
prompt_text = processor.apply_chat_template(messages, add_generation_prompt=True)

inputs = processor(text=prompt_text, images=image, return_tensors="pt")

# move inputs to the device the model lives on
inputs = {k: v.to(model.device) for k, v in inputs.items()}

# generate a prediction (sampling, so repeated runs can differ)
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=64, do_sample=True, temperature=1.0)

# decode the output
predicted_text = processor.batch_decode(output, skip_special_tokens=True)[0]

print("Predicted order:", predicted_text)