# Finetuning LLM

In [1]:
# Package Import
import os
import sys
import torch
import random
import numpy as np
import pandas as pd

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import prepare_model_for_kbit_training, PeftModel

In [2]:
# Seed every RNG this notebook touches (Python, NumPy, PyTorch CPU/CUDA) so runs are reproducible.
def set_seed(seed):
    """Seed `random`, NumPy and PyTorch with `seed`.

    On CUDA machines also seeds every GPU and forces cuDNN to deterministic
    kernels with autotuning (benchmark mode) disabled.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

set_seed(42)

In [3]:
# Pick the compute device and show GPU status.
# NOTE: `device` is not referenced by any later cell in this notebook.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
!nvidia-smi

Sun Jan 18 21:26:35 2026       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 591.74                 Driver Version: 591.74         CUDA Version: 13.1     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                  Driver-Model | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 5080 ...  WDDM  |   00000000:01:00.0  On |                  N/A |
| N/A   51C    P0             27W /  160W |    1512MiB /  16303MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+----------------------------------------------

In [4]:
# Helper Functions

def add_generation_prompt(tokenizer):
    """Install a Gemma-style chat template on `tokenizer` and return the same tokenizer.

    Despite its name, this function replaces `tokenizer.chat_template` wholesale.
    The template:
      * starts with the tokenizer's BOS token;
      * folds a leading system message into the first user turn;
      * raises if roles do not alternate user/assistant;
      * renders each turn as `<start_of_turn>{user|model}` ... `<end_of_turn>`;
      * wraps assistant turns in `{% generation %}` / `{% endgeneration %}`, the
        markers transformers uses to build assistant-token masks (needed by
        `assistant_only_loss=True` in the SFT config);
      * appends `<start_of_turn>model` plus a newline when `add_generation_prompt` is true.

    NOTE(review): with the trim_blocks/lstrip_blocks Jinja settings transformers
    uses, the line break after the assistant `<end_of_turn>` expression seems to
    be emitted inside the generation span, so assistant turns may render with a
    doubled newline. Verify with `apply_chat_template(..., tokenize=False)`.

    Args:
        tokenizer: a Hugging Face tokenizer; mutated in place.

    Returns:
        The same tokenizer object.
    """
    generation_chat_template = """{{ bos_token }}
{%- if messages[0]['role'] == 'system' -%}
    {%- if messages[0]['content'] is string -%}
        {%- set first_user_prefix = messages[0]['content'] + '\n\n' -%}
    {%- else -%}
        {%- set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' -%}
    {%- endif -%}
    {%- set loop_messages = messages[1:] -%}
{%- else -%}
    {%- set first_user_prefix = "" -%}
    {%- set loop_messages = messages -%}
{%- endif -%}
{%- for message in loop_messages -%}
    {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}
        {{ raise_exception("Conversation roles must alternate user/assistant/user/assistant/...") }}
    {%- endif -%}
    {%- if (message['role'] == 'assistant') -%}
        {%- set role = "model" -%}
    {%- else -%}
        {%- set role = message['role'] -%}
    {%- endif -%}
    {{ '<start_of_turn>' + role + '\n' + (first_user_prefix if loop.first else "") }}
    {%- if message['role'] == 'assistant' -%}
        {% generation %}
        {%- if message['content'] is string -%}
            {{ message['content'] | trim }}
        {%- elif message['content'] is iterable -%}
            {%- for item in message['content'] -%}
                {%- if item['type'] == 'image' -%}
                    {{ '<start_of_image>' }}
                {%- elif item['type'] == 'text' -%}
                    {{ item['text'] | trim }}
                {%- endif -%}
            {%- endfor -%}
        {%- else -%}
            {{ raise_exception("Invalid content type") }}
        {%- endif -%}
        {{ '<end_of_turn>\n' }}
        {% endgeneration %}
    {%- else -%}
        {%- if message['content'] is string -%}
            {{ message['content'] | trim }}
        {%- elif message['content'] is iterable -%}
            {%- for item in message['content'] -%}
                {%- if item['type'] == 'image' -%}
                    {{ '<start_of_image>' }}
                {%- elif item['type'] == 'text' -%}
                    {{ item['text'] | trim }}
                {%- endif -%}
            {%- endfor -%}
        {%- else -%}
            {{ raise_exception("Invalid content type") }}
        {%- endif -%}
        {{ '<end_of_turn>\n' }}
    {%- endif -%}
{%- endfor -%}
{%- if add_generation_prompt -%}
    {{'<start_of_turn>model
'}}
{%- endif -%}"""
    # Overwrites any template the tokenizer shipped with.
    tokenizer.chat_template = generation_chat_template
    return tokenizer


# Define a helper function to load and set up the model and tokenizer
def get_model_tokenizer(model_name, return_model=True, return_tokenizer=True):
    """Load `model_name` as a 4-bit quantized causal LM and/or its tokenizer.

    Args:
        model_name: Hugging Face Hub id or local path.
        return_model: load the model with bitsandbytes NF4 4-bit weights
            (double quantization, bfloat16 compute) and pass it through
            `prepare_model_for_kbit_training`.
        return_tokenizer: load the tokenizer and install the custom chat
            template via `add_generation_prompt`.

    Returns:
        A `(model, tokenizer)` tuple. An element that was not requested is None.
        When both are loaded, the tokenizer's pad/eos ids are forced to 0/1 and
        the eos id is mirrored onto the model and its config (presumably Gemma's
        `<pad>`/`<eos>` ids; confirm before reusing with another model family).
    """
    tokenizer = None
    if return_tokenizer:
        tokenizer = add_generation_prompt(AutoTokenizer.from_pretrained(model_name))

    model = None
    if return_model:
        nf4_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype="bfloat16",
        )
        # NOTE(review): trust_remote_code=True executes code from the model repo;
        # it should be unnecessary for architectures built into transformers.
        quantized = AutoModelForCausalLM.from_pretrained(
            model_name,
            quantization_config=nf4_config,
            trust_remote_code=True,
            low_cpu_mem_usage=True,
        )
        model = prepare_model_for_kbit_training(quantized)

    if return_model and return_tokenizer:
        tokenizer.pad_token_id, tokenizer.eos_token_id = 0, 1
        model.eos_token_id = model.config.eos_token_id = tokenizer.eos_token_id

    return model, tokenizer

def apply_adapter(model, adapter_name):
    """Attach the PEFT adapter stored at `adapter_name` to `model`.

    `adapter_name` is a local directory or Hub id; it is passed to PEFT as the
    adapter location, not as PEFT's own `adapter_name` label.

    Returns:
        The resulting `PeftModel` wrapping `model`.
    """
    return PeftModel.from_pretrained(model, adapter_name, device_map="auto")

In [5]:
# Output directory for this assignment's artifacts (created if missing)
OUTPUT_DIR = "output/hw7/"
os.makedirs(name=OUTPUT_DIR, exist_ok=True)

In [7]:
# Authenticate to the Hugging Face Hub with the HF_TOKEN environment variable, when it is set.
from huggingface_hub import login

hf_token = os.getenv("HF_TOKEN")
if not hf_token:
    print("HF_TOKEN environment variable not set. Skipping HuggingFace login.")
else:
    login(token=hf_token, new_session=False)
    print("Logged in to HuggingFace Hub")

Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


Logged in to HuggingFace Hub


# Phase 1: SFT - Instruction Tuning

In [8]:
from tqdm import tqdm
import gradio as gr
from peft import LoraConfig, get_peft_model
from datasets import load_dataset, DatasetDict, concatenate_datasets
from trl import SFTTrainer, SFTConfig, clone_chat_template

In [None]:
# Load the Model and Tokenizer
base_model_name = "google/gemma-3-4b-pt"
reference_chat_template_name = "google/gemma-3-4b-it"
model, tokenizer = get_model_tokenizer(base_model_name)
# Set up the chat format: take the chat template / special tokens from the
# instruction-tuned reference and update the model's embeddings to match.
model, tokenizer, added_tokens = clone_chat_template(model, tokenizer, reference_chat_template_name)
# clone_chat_template replaces tokenizer.chat_template with the reference template,
# which has no {% generation %} markers. SFTConfig(assistant_only_loss=True) needs
# those markers to find assistant tokens, so re-install the custom template.
tokenizer = add_generation_prompt(tokenizer)
# Prepare again for k-bit training, presumably so any embeddings touched by
# clone_chat_template are cast/prepared like the rest of the model.
model = prepare_model_for_kbit_training(model)

tokenizer_config.json: 0.00B [00:00, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


tokenizer.model:   0%|          | 0.00/4.69M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/33.4M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/35.0 [00:00<?, ?B/s]

special_tokens_map.json: 0.00B [00:00, ?B/s]

config.json: 0.00B [00:00, ?B/s]

Exception in thread Thread-8 (_readerthread):
Traceback (most recent call last):
  File [35m"c:\ProgramData\anaconda3\envs\pytorch\Lib\threading.py"[0m, line [35m1043[0m, in [35m_bootstrap_inner[0m
    [31mself.run[0m[1;31m()[0m
    [31m~~~~~~~~[0m[1;31m^^[0m
  File [35m"c:\ProgramData\anaconda3\envs\pytorch\Lib\site-packages\ipykernel\ipkernel.py"[0m, line [35m772[0m, in [35mrun_closure[0m
    [31m_threading_Thread_run[0m[1;31m(self)[0m
    [31m~~~~~~~~~~~~~~~~~~~~~[0m[1;31m^^^^^^[0m
  File [35m"c:\ProgramData\anaconda3\envs\pytorch\Lib\threading.py"[0m, line [35m994[0m, in [35mrun[0m
    [31mself._target[0m[1;31m(*self._args, **self._kwargs)[0m
    [31m~~~~~~~~~~~~[0m[1;31m^^^^^^^^^^^^^^^^^^^^^^^^^^^^^[0m
  File [35m"c:\ProgramData\anaconda3\envs\pytorch\Lib\subprocess.py"[0m, line [35m1615[0m, in [35m_readerthread[0m
    buffer.append([31mfh.read[0m[1;31m()[0m)
                  [31m~~~~~~~[0m[1;31m^^[0m
  File [35m"<frozen co

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.96G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/3.64G [00:00<?, ?B/s]

In [None]:
# (Optional)Chat with the Model Before SFT
def chat_interface(message, history):
    """Gradio callback: greedy reply from the base model (no chat template yet).

    The prompt is the fixed system prompt, the earlier turns and the new message
    joined by newlines. `history` may be a list of (user, assistant) pairs or a
    list of {"role", "content"} dicts, depending on the Gradio chat format.

    Returns:
        Only the newly generated text, with special tokens removed.
    """
    SYSTEM_PROMPT = "You are a helpful assistant."
    # Separate turns with newlines; plain concatenation fused sentences together.
    turns = [SYSTEM_PROMPT]
    for turn in history:
        texts = [turn.get("content")] if isinstance(turn, dict) else list(turn)
        turns.extend(text for text in texts if isinstance(text, str) and text)
    turns.append(message)
    prompt = "\n".join(turns)

    # Get the model response
    model.eval()
    with torch.no_grad():
        inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=True).to(model.device)
        out = model.generate(
            **inputs,
            max_new_tokens=64,
            do_sample=False,
            eos_token_id=tokenizer.convert_tokens_to_ids(["<eos>", "<end_of_turn>"])
        )
        # input_ids has shape (1, prompt_len). Slice by sequence length:
        # len(inputs["input_ids"]) is the batch size (1) and left the prompt in the reply.
        prompt_len = inputs["input_ids"].shape[-1]
        response = tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True).strip()

    return response

# Create the Gradio interface
iface = gr.ChatInterface(
    fn=chat_interface,
    title="Gemma 3 4b Chat",
    description="Chat with the Gemma model.",
    examples=[
        ["Where is the capital of France?"],
        ["Who is Julius Caesar?"],
    ],
)

iface.launch(debug=False)

In [None]:
# Load and Preprocess Dataset
ds = load_dataset("jaxon3062/smoltalk-gemma3-1024", "filtered-rich")

NUM_PROC = 4
MAX_TOKEN_LENGTH = 512

ds_filtered = DatasetDict({
    "train": ds["train"],
    "test": ds["test"],
})

# Annotate each training conversation with its chat-templated token count.
# Sort longest-first, so the subsampling cell's select(range(n)) takes the longest
# examples. Then keep only conversations shorter than MAX_TOKEN_LENGTH.
ds_filtered["train"] = ds_filtered["train"].map(
    lambda x: {
        "token_length": len(
            tokenizer.apply_chat_template(
                x["messages"],
                tokenize=True,
                add_generation_prompt=False,
                # Request a plain list of ids explicitly: newer transformers releases
                # may return a dict-like BatchEncoding, whose len() counts keys, not tokens.
                return_dict=False,
            )
        )
    },
    num_proc=NUM_PROC
).sort("token_length", reverse=True).filter(lambda x: x["token_length"] < MAX_TOKEN_LENGTH, num_proc=NUM_PROC)
# Keep a small fixed slice of the test split (idx 0-99) for evaluation during training.
ds_filtered["test"] = ds_filtered["test"].filter(lambda x: 0 <= x["idx"] < 100, num_proc=NUM_PROC)

In [None]:
# Show the splits (row counts and columns) after length filtering.
ds_filtered

#### Subsample the dataset for training

Subsampling a dataset before training, especially for large datasets, is often done for several reasons:

1.  **Faster Training Times:** Training on a smaller subset of data is significantly faster than training on the entire dataset. This allows for quicker experimentation and iteration.
2.  **Resource Efficiency:** Training on a smaller dataset requires fewer computational resources (CPU, GPU, memory), which is crucial when working with limited hardware or free tiers on platforms like Colab.
3.  **Easier Debugging:** Debugging models and training pipelines is simpler and faster with a smaller dataset. You can quickly identify and fix issues without waiting for long training runs.
4.  **Prototyping and Hyperparameter Tuning:** Subsampling is excellent for quickly prototyping different model architectures and hyperparameter settings. Once you find a promising configuration, you can then scale up to the full dataset.

**Importance of Data Quality during Subsampling:**

While subsampling provides efficiency, it's vital to ensure that the subsampled data is representative of the original dataset. Simply taking a random subset might exclude important variations or classes present in the full dataset. Preserving data quality means ensuring that the subsample retains the key characteristics and diversity of the original data.

**Toy Example Analogy:**

Imagine you have a bag of colorful marbles (your full dataset). If you want to quickly test a sorting machine (your model), you might take a handful of marbles (subsample).

*   **Bad Subsampling:** If you just randomly grab a handful, you might end up with only red marbles, and your sorting machine won't learn how to sort blue or green marbles. This is like a non-representative subsample.
*   **Better Subsampling:** A better approach would be to make sure your handful has a few marbles of each color present in the original bag. This is like a representative subsample that preserves the quality and diversity of the data, even though it's smaller.

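In real datasets, this means considering factors like class distribution, feature ranges, and other relevant characteristics when creating a subsample for training. Note that the default cell below does not do this. It takes the first `n_samples` rows of the training split, which was sorted by `token_length` in descending order, so it selects the longest conversations under `MAX_TOKEN_LENGTH` rather than a representative sample.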

In [None]:
# Sample the top n samples of the (length-sorted, longest-first) training split.
# Set any value >= 1; values larger than the training split are clamped to its
# size instead of raising, so n_samples (and the run name) reflect what is used.
n_samples = 100
n_samples = min(n_samples, len(ds_filtered["train"]))
ds_sub = DatasetDict({
    "train": ds_filtered["train"].select(range(n_samples)),
    "test": ds_filtered["test"],
})

# Advanced(optional): sample the dataset by custom approaches

In [None]:
# Show the subsampled splits used for training/evaluation.
ds_sub

In [None]:
# List the distinct leaf-module names in the model (candidates for `target_modules` below).
# Listing every full module path would dump hundreds of entries; names like
# "q_proj" are what LoraConfig matches on.
sorted({
    name.rsplit(".", 1)[-1]
    for name, module in model.named_modules()
    if name and next(module.children(), None) is None
})

In [None]:
# Set up the model with PEFT
# TODO: Try different Lora parameters

# Lora rank: set any number you want; recommend 2, 4, 8, 16, 32, ...
LORA_RANK = 8

# Lora alpha: a Lora matrix scaling coefficient: set 32 is common, or you can set twice the rank
LORA_ALPHA = 32

# Modules to apply Lora: check module names you want in the previous cell
# You can check available modules by running the above optional cell to list them
# Or you can choose from this list: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
target_modules = ["q_proj", "k_proj", "v_proj"]

# Lora dropout: set 0-0.2 to prevent overfit
LORA_DROPOUT = 0.05

# Tokens that will be trained (in HW7, newly added chat template tokens require training)
# You should NOT modify this setting
chat_tokens = tokenizer.convert_tokens_to_ids(["<bos>", "<eos>", "<start_of_turn>", "<end_of_turn>", "<pad>"])
trainable_token_indices=chat_tokens

# You are NOT REQUIRED TO modify the code below
lora_cfg = LoraConfig(
  r=LORA_RANK,
  lora_alpha=LORA_ALPHA,
  target_modules=target_modules,
  trainable_token_indices=trainable_token_indices,
  lora_dropout=LORA_DROPOUT,
  bias="none", task_type="CAUSAL_LM"
)

# Dry run: wrap the model once to report how many parameters LoRA will train, then
# strip the adapters again. SFTTrainer applies lora_cfg itself via peft_config.
peft_model = get_peft_model(model, lora_cfg)
peft_model.print_trainable_parameters()
# Bind the unwrapped base model explicitly instead of leaving unload() as the last
# expression, which would print the whole architecture as cell output. Then drop
# the wrapper so it no longer holds references to the model.
model = peft_model.unload()
del peft_model

## Training with SFTTrainer

In [None]:
# TODO: Modify training hyperparameters
EPOCH = 1   # 1 ~ 5: passes over the training subset
BATCH_SIZE = 4   # 2 ~ 64: effective batch size, reached via gradient accumulation in the training cell
LR = "5e-4"  # kept as a string because it is embedded in the run name; converted with float() for training

In [None]:
# Modify the code below with caution.
# You can modify them, but make sure you know what you are doing.

MINI_BATCH_SIZE = 2
# Sequence-length limit passed to SFTConfig. It equals the dataset filter threshold
# (MAX_TOKEN_LENGTH), so the run name below reports the length actually used.
MODEL_MAX_LENGTH = MAX_TOKEN_LENGTH

# Set the run name you like.
# We recommend to set something that reminds you your training settings. Such as:
run_name = f"gemma3-4b-chat_lora-rk{LORA_RANK}-a{LORA_ALPHA}_l{MODEL_MAX_LENGTH}_bs{BATCH_SIZE}_lr{LR}-{n_samples}_ep{EPOCH}"
output_dir = os.path.join(OUTPUT_DIR, run_name)
adapter_output_dir = output_dir + "_adapter"

# The 4-bit model computes in bfloat16 (bnb_4bit_compute_dtype in get_model_tokenizer).
# Use bf16 mixed precision when the GPU supports it; fall back to fp16 otherwise.
use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

# Ref: https://huggingface.co/docs/trl/sft_trainer
print("Setting up SFTConfig")
args = SFTConfig(
    per_device_train_batch_size=MINI_BATCH_SIZE,
    # Effective batch = MINI_BATCH_SIZE * accumulation steps; never let the steps drop to 0.
    gradient_accumulation_steps=max(1, BATCH_SIZE // MINI_BATCH_SIZE),
    num_train_epochs=EPOCH,
    bf16=use_bf16,
    fp16=not use_bf16,
    output_dir=output_dir,
    max_length=MODEL_MAX_LENGTH,
    gradient_checkpointing=True,
    optim="paged_adamw_8bit",
    lr_scheduler_type="cosine_with_min_lr",
    lr_scheduler_kwargs={
        "min_lr": 1e-6,
        "num_cycles": 0.5,
    },
    warmup_ratio=0.1,
    learning_rate=float(LR),
    save_strategy="epoch",
    # "none" disables logging integrations; with None, transformers v4 enables every
    # installed one. Set to "wandb" to log to Weights & Biases.
    report_to="none",
    run_name=run_name,
    logging_steps=1,
    # Loss only on assistant tokens; requires {% generation %} markers in the chat template.
    assistant_only_loss=True,
)

# KV caching is not needed during training.
model.config.use_cache = False

print("Setting up SFTTrainer")
trainer = SFTTrainer(
    model=model,
    processing_class=tokenizer,
    args=args,
    train_dataset=ds_sub["train"],
    eval_dataset=ds_sub["test"],
    peft_config=lora_cfg,
)


print("Starting training...")
trainer.train()

# Save only the LoRA adapter; the evaluation cell reloads it on top of the base model.
# No merged copy is built here: merging into the 4-bit base can introduce rounding
# error, and the result was never used while it kept the model alive on the GPU.
trainer.save_model(adapter_output_dir)
print("Training completed and model saved.")

In [None]:
# Clean up objects to make space for inference.
# Drop every global that may still reference the training-time model (trainer,
# PEFT wrapper, merged copy). One surviving reference keeps its GPU memory allocated.
# pop(..., None) keeps this cell re-runnable when some names are already gone.
import gc

for _stale_name in ("trainer", "merged_model", "peft_model", "model", "tokenizer"):
    globals().pop(_stale_name, None)
del _stale_name
gc.collect()              # free unreachable objects (and their tensors) first...
torch.cuda.empty_cache()  # ...then release the now-unused cached CUDA blocks

In [None]:
# Evaluate on test set

# Load the base model (only if it is not already in memory) and apply the SFT adapter.
ADAPTER_PATH = adapter_output_dir
try:
    # At notebook top level locals() is globals(), so checking globals() is enough.
    if "model" not in globals():
        if "base_model_name" not in globals():
            base_model_name = "jaxon3062/gemma-3-4b-pt-chat"
        model, tokenizer = get_model_tokenizer(base_model_name)
    model = apply_adapter(model, ADAPTER_PATH)
except Exception as err:
    # Keep the original error as the cause instead of swallowing it.
    raise ValueError("Cannot load model from adapter. This may caused by invalid adapter path.") from err

# Load evaluation dataset
ds_eval = load_dataset("jaxon3062/genai-ml-2025-hw7-eval", "short-50", split="test")

In [None]:
# Show the evaluation split.
ds_eval

In [None]:
# Inference test set: greedy decoding, one example at a time; results saved to CSV.
model.eval()
eos_token_ids = tokenizer.convert_tokens_to_ids(["<eos>", "<end_of_turn>"])
responses = []
with torch.inference_mode():
    for item in tqdm(ds_eval):
        # Drop the last message (presumably the reference answer) and let the model answer.
        messages = item["messages"][:-1]
        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        # The chat template already emits <bos>; don't add a second one.
        inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
        outputs = model.generate(
            **inputs,
            # Greedy decoding. temperature/top_p are not passed because they are
            # ignored (and only trigger warnings) when do_sample=False.
            do_sample=False,
            max_new_tokens=512,
            repetition_penalty=1.1,
            no_repeat_ngram_size=3,
            eos_token_id=eos_token_ids,
            use_cache=True
        )
        response = tokenizer.decode(outputs[0], skip_special_tokens=False)
        new_row = {
            "idx": item["idx"],
            "prompt": prompt,
            "response": response,
            "answer": response.split("<start_of_turn>model")[-1].strip().split("<end_of_turn>")[0].strip(),
        }
        print(new_row["response"])
        responses.append(new_row)

test_inference_df = pd.DataFrame(responses)
test_inference_df.to_csv(os.path.join(OUTPUT_DIR, "test_inference_result.csv"), index=False)

In [None]:
# Chat with the Model After SFT
def chat_interface(message, history):
    """Gradio callback: sampled reply from the fine-tuned model via its chat template.

    Builds a message list from a fixed system prompt, the earlier turns and the
    new user message. `history` may be a list of (user, assistant) pairs or a
    list of {"role", "content"} dicts, depending on the Gradio chat format.

    Returns:
        The text of the model's new turn (up to `<end_of_turn>`).
    """
    SYSTEM_PROMPT = "You are a helpful assistant."
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    for turn in history:
        if isinstance(turn, dict):
            messages.append({"role": turn["role"], "content": turn["content"]})
        else:
            human, assistant = turn
            messages.append({"role": "user", "content": human})
            messages.append({"role": "assistant", "content": assistant})
    # Previously `prompt.append(...)`: `prompt` is only assigned below, so that raised UnboundLocalError.
    messages.append({"role": "user", "content": message})

    # Get the model response
    model.eval()
    with torch.no_grad():
        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        # The chat template already emits <bos>; don't add a second one.
        inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
        out = model.generate(
            **inputs,
            max_new_tokens=128,
            do_sample=True,
            temperature=0.7,
            top_p=0.95,
            repetition_penalty=1.1,
            no_repeat_ngram_size=3,
            eos_token_id=tokenizer.convert_tokens_to_ids(["<eos>", "<end_of_turn>"])
        )
        output = tokenizer.decode(out[0], skip_special_tokens=False).strip()
        # The last model turn is the newly generated one.
        response = output.split("<start_of_turn>model")[-1].strip().split("<end_of_turn>")[0].strip()

    return response

# Create the Gradio interface
iface = gr.ChatInterface(
    fn=chat_interface,
    title="Gemma 3 4b Chat",
    description="Chat with the Gemma model.",
    examples=[
        ["Where is the capital of France?"],
        ["Who is Julius Caesar?"],
    ],
)

iface.launch(debug=False)

In [None]:
# Clean up unused objects to make memory space for RL.
# Drop every global that may still reference an SFT-phase model (adapter-wrapped
# model, PEFT wrapper, merged copy, trainer); a single surviving reference keeps
# its GPU memory allocated. pop(..., None) keeps the cell re-runnable.
import gc

for _stale_name in ("model", "tokenizer", "peft_model", "merged_model", "trainer"):
    globals().pop(_stale_name, None)
del _stale_name
gc.collect()              # free unreachable objects (and their tensors) first...
torch.cuda.empty_cache()  # ...then release the now-unused cached CUDA blocks

# Phase 2: RL - DPO

In [None]:
# Package Import
import os
import sys
import torch
import numpy as np
from peft import prepare_model_for_kbit_training, PeftModel

import json
import math
import gradio as gr
import pandas as pd
from collections import defaultdict
from peft import LoraConfig, get_peft_model
from datasets import load_dataset, DatasetDict, concatenate_datasets, Dataset, IterableDataset
from trl import maybe_apply_chat_template, maybe_extract_prompt, DPOTrainer, DPOConfig
import random
from typing import List, Dict,Any, Callable, Literal, Optional, Union

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BaseImageProcessor,
    DataCollator,
    FeatureExtractionMixin,
    PreTrainedModel,
    PreTrainedTokenizerBase,
    ProcessorMixin,
    Trainer,
    BitsAndBytesConfig
)
from accelerate import PartialState, logging

In [None]:
# Process preference dataset
def load_jsonl(path):
    """Read a UTF-8 JSON Lines file and return its records as a list, skipping blank lines."""
    records = []
    with open(path, encoding="utf-8") as handle:
        for raw_line in handle:
            if not raw_line.strip():
                continue
            records.append(json.loads(raw_line))
    return records

# Resource directory holding the preference data files
RES_DIR = "res/hw7/"
# NOTE: `full_data` is not referenced again in this notebook; the
# DPODatasetGenerator cell re-reads the same file itself.
full_data = load_jsonl(os.path.join(RES_DIR, "preference_train.jsonl"))

In [None]:
# utility function
import re

def data_formulate(data):
    """Render one preference record as a generation prompt.

    Uses the global `tokenizer`'s chat template with a fixed 100-character
    system instruction and `data['question']` as the user turn, and appends the
    generation prompt so the model answers next.

    Returns:
        The prompt as a string (not tokenized).
    """
    system_turn = {"role": "system", "content": "Your entire response must be 100 characters or less."}
    user_turn = {"role": "user", "content": data['question']}
    return tokenizer.apply_chat_template(
        [system_turn, user_turn],
        tokenize=False,
        add_generation_prompt=True,
    )

def extract_assistant_response(text):
    """Return the first assistant reply from a Llama-3-style transcript.

    The reply is the text between the first
    `<|start_header_id|>assistant<|end_header_id|>` header and the next header
    or `<|eot_id|>`, whitespace-trimmed. Returns None when no assistant header
    exists or when extraction fails (the error is printed).
    """
    header = "<|start_header_id|>assistant<|end_header_id|>"
    try:
        if header not in text:
            return None
        after_header = text.split(header)[1]
        reply, _, _ = after_header.partition("<|eot_id|>")
        return reply.strip()
    except Exception as e:
        print(f"Error extracting assistant response: {e}")
        return None

def extract_assistant_response_gemma(text: str) -> str | None:
    if not text:
        return None
    try:
        match = re.search(
            r"<start_of_turn>\s*model\s*([\s\S]*?)(<end_of_turn>|</s>|$)",
            text,
            re.DOTALL | re.UNICODE | re.IGNORECASE
        )
        if match:
            response = match.group(1).strip()
            # 移除多餘 token
            response = re.sub(r"<[^>]+>", "", response).strip()
            return response if response else None
        return None
    except Exception as e:
        print(f"[extract_assistant_response] Error: {e}")
        return None

class DPODatasetGenerator:
    """
    DPO (Direct Preference Optimization) dataset generator.

    Holds raw preference records (dicts with at least 'question', 'food',
    'accept' and 'reject' keys, which the methods below access) and turns them
    into a `Dataset` with 'prompt' / 'chosen' / 'rejected' columns.
    """

    def __init__(self, tokenizer=None):
        # Optional tokenizer; when set, prompts are rendered with its chat template.
        self.tokenizer = tokenizer
        self.raw_data = []

    def load_jsonl(self, filepath: str):
        """Replace `raw_data` with the non-blank records of a JSON Lines file; returns self."""
        self.raw_data = []
        with open(filepath, 'r', encoding='utf-8') as f:
            for line in f:
                if line.strip():
                    self.raw_data.append(json.loads(line))
        # "Loaded N raw records"
        print(f"已載入 {len(self.raw_data)} 筆原始資料")
        return self

    def add_data(self, data_list: List[Dict]):
        """Append records to `raw_data`; returns self for chaining."""
        self.raw_data.extend(data_list)
        # "Added N records, M in total"
        print(f"已添加 {len(data_list)} 筆資料, 總共 {len(self.raw_data)} 筆")
        return self

    def data_formulate(self, data: Dict, system_prompt: str = None) -> str:
        """Build the generation prompt for one record.

        Uses the tokenizer's chat template (system + user turn, generation prompt
        appended) when a tokenizer is set, otherwise a plain System/User/Assistant
        text layout. `system_prompt` defaults to the 100-character-limit instruction.
        """
        if system_prompt is None:
            system_prompt = "Your entire response must be 100 characters or less."

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": data['question']},
        ]

        if self.tokenizer:
            prompt = self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True
            )
        else:
            prompt = f"System: {system_prompt}\nUser: {data['question']}\nAssistant: "

        return prompt

    def prepare_dataset(
        self,
        data_size: int,
        liked_foods: List[str],
        disliked_foods: List[str],
        strategy: str = "food_preference",
        shuffle: bool = True,
        system_prompt: str = None
    ) -> Dataset:
        """Build a class-balanced DPO dataset from the user's liked/disliked foods.

        Records whose 'food' is in `liked_foods` or `disliked_foods` are grouped by
        food, and up to `data_size // number_of_foods` records are sampled per food.
        The result can therefore hold fewer than `data_size` rows when the division
        leaves a remainder or a food has too few records. For liked foods the
        record's 'accept' answer becomes 'chosen' and 'reject' becomes 'rejected';
        for disliked foods the two are swapped. With `shuffle=True` the selected
        rows are shuffled so foods are interleaved instead of grouped. `strategy`
        is currently unused.
        """
        liked = set(liked_foods)
        disliked = set(disliked_foods)

        # Keep only records about foods the user expressed a preference for
        filtered_data = [d for d in self.raw_data if d['food'] in liked or d['food'] in disliked]

        if len(filtered_data) < data_size:
            # "Warning: available data (N) is less than requested (M)"
            print(f"警告: 可用資料 ({len(filtered_data)}) 少於需求 ({data_size})")
            data_size = len(filtered_data)

        grouped = defaultdict(list)
        for d in filtered_data:
            grouped[d['food']].append(d)

        selected_data = []
        # Guard: with no matching records there are no classes to divide by.
        if grouped:
            samples_per_class = data_size // len(grouped)
            for items in grouped.values():
                selected_data.extend(random.sample(items, min(samples_per_class, len(items))))

        if shuffle:
            # Shuffle after selection; otherwise rows stay ordered food by food.
            random.shuffle(selected_data)

        prompt_list, chosen_list, rejected_list = [], [], []

        for data in selected_data:
            if data['food'] in liked:
                chosen, rejected = data['accept'], data['reject']
            elif data['food'] in disliked:
                chosen, rejected = data['reject'], data['accept']
            else:
                # Unreachable given the filter above. Skip before appending the prompt
                # so the three columns always keep the same length.
                continue
            prompt_list.append(self.data_formulate(data, system_prompt))
            chosen_list.append(chosen)
            rejected_list.append(rejected)

        dataset = Dataset.from_dict({
            'prompt': prompt_list,
            'chosen': chosen_list,
            'rejected': rejected_list
        })

        # "Dataset stats: N rows in total. Liked: X foods, disliked: Y foods."
        print(f"資料集統計：共 {len(dataset)} 筆。喜歡：{len(liked_foods)} 類，不喜歡：{len(disliked_foods)} 類。")
        return dataset


In [None]:
# Load the model and tokenizer.
# If the SFT phase above ran in this session, restart the runtime/kernel first, then
# re-run the package-import and Hugging Face login cells before loading the model;
# otherwise GPU memory runs out.
# NOTE(review): later cells also use OUTPUT_DIR (defined in Phase 1) plus RES_DIR and
# the helper functions from the Phase 2 cells, so those cells must be re-run as well.
dpo_model_name = "google/gemma-3-4b-it"

# Fast tokenizer, and the model with its checkpoint dtype ("auto") and automatic device placement
tokenizer = AutoTokenizer.from_pretrained(dpo_model_name, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
    dpo_model_name,
    device_map="auto",
    torch_dtype="auto",
)

In [None]:
# Set experiments parameter
# build generator
generator = DPODatasetGenerator(tokenizer=tokenizer)
generator.load_jsonl(os.path.join(RES_DIR, "preference_train.jsonl"))  # load records from file

# (Optional) tag used to recognize this run in the results file name
set_num = 50

ALL_FOODS = ["蚵仔煎", "滷肉飯", "滷味", "刈包", "豆花", "鍋貼", "炒飯", "臭豆腐", "擔仔麵", "鹹酥雞"]

##########################################################
# TODO
# Change the support ratio to run different experiments
# Support ratio: len(hungyis_liked_foods) / 10
# All foods: ["蚵仔煎", "滷肉飯", "滷味", "刈包", "豆花", "鍋貼", "炒飯", "臭豆腐", "擔仔麵", "鹹酥雞"]
hungyis_liked_foods = ["蚵仔煎", "滷肉飯", "滷味", "刈包", "豆花", "鍋貼", "炒飯", "臭豆腐", "擔仔麵", "鹹酥雞"]
hungyis_disliked_foods = []

# training data size
data_size = 500

# training epoch
DPO_EPOCH = 1
##########################################################
# The two lists must partition ALL_FOODS: no overlap (a food listed in both would
# silently be treated as liked), and together they cover every food.
_liked_set, _disliked_set = set(hungyis_liked_foods), set(hungyis_disliked_foods)
assert _liked_set.isdisjoint(_disliked_set), f"Foods cannot be both liked and disliked: {sorted(_liked_set & _disliked_set)}"
assert set(ALL_FOODS) == _liked_set | _disliked_set, "Liked foods and disliked foods should be complement."

# dataset preparation
train_dataset = generator.prepare_dataset(
    data_size=data_size,
    liked_foods=hungyis_liked_foods,
    disliked_foods=hungyis_disliked_foods,
    shuffle=True
)

# debug: preview a few records (printing 50 full prompts floods the output)
print(train_dataset[:3])

In [None]:
# Inference on the original model (before RL)
# Load the test questions and tag each with the user's preference for its food.
test_data = []
with open(os.path.join(RES_DIR, "preference_test.jsonl"), 'r', encoding='utf-8') as f:
    for idx, line in enumerate(f):
        if not line.strip():
            continue
        data = json.loads(line)
        data['id'] = idx + 1  # 1-based line number in the file

        food_name = data.get('food', '')
        if food_name in hungyis_liked_foods:
            data['preference'] = "like"
        elif food_name in hungyis_disliked_foods:
            data['preference'] = "dislike"
        else:
            data['preference'] = "unknown"

        test_data.append(data)

model.eval()
original_model_response = []
with torch.inference_mode():
    for data in test_data:
        question_id = data['id']
        prompt = data['question']
        print(f'\nQuestion {question_id} ({data["food"]} - {data["preference"]}): {prompt}')

        inputs = data_formulate(data)
        # The Gemma chat template is assumed to emit <bos> itself (the custom template
        # earlier in this notebook does; verify with apply_chat_template(..., tokenize=False)),
        # so don't let the tokenizer prepend a second one.
        encoded = tokenizer(inputs, return_tensors="pt", add_special_tokens=False).to(model.device)
        outputs = model.generate(
            **encoded,
            max_new_tokens=128,
            do_sample=False
        )
        output = tokenizer.batch_decode(outputs)[0]
        output = extract_assistant_response_gemma(output)
        original_model_response.append(output)
        print(output)

In [None]:
# Define Custom DPOTrainer
class HW7DPOTrainer(DPOTrainer):
    """DPOTrainer whose dataset preparation also prints debug samples.

    `_prepare_dataset` runs: extract the prompt (`maybe_extract_prompt`), apply
    the chat template if needed (`maybe_apply_chat_template`), then tokenize with
    `tokenize_row`. Debug output is printed on the main process only, and works
    for both `Dataset` and `IterableDataset`.
    """

    @staticmethod
    def _hw7_first_example(dataset):
        """Return the first example of a Dataset/IterableDataset, or None if it is empty."""
        if isinstance(dataset, Dataset):
            return dataset[0] if len(dataset) > 0 else None
        # IterableDataset is not subscriptable; take the first yielded example.
        return next(iter(dataset), None)

    def _prepare_dataset(
        self,
        dataset: Union[Dataset, IterableDataset],
        processing_class: Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin],
        args: DPOConfig,
        dataset_name: str,
    ) -> Union[Dataset, IterableDataset]:
        # Build the kwargs for the `map` function
        map_kwargs = {}
        if isinstance(dataset, Dataset):  # IterableDataset does not support num_proc nor writer_batch_size
            map_kwargs["num_proc"] = args.dataset_num_proc
            map_kwargs["writer_batch_size"] = 10

        is_main = PartialState().is_main_process

        with PartialState().main_process_first():
            # Extract prompt if needed
            if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
                map_kwargs["desc"] = f"Extracting prompt in {dataset_name} dataset"
            dataset = dataset.map(maybe_extract_prompt, **map_kwargs)

            # Apply the chat template if needed
            if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
                map_kwargs["desc"] = f"Applying chat template to {dataset_name} dataset"
            dataset = dataset.map(
                maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class, "tools": args.tools}, **map_kwargs
            )

            # Debug: show one untokenized example (prompt / chosen / rejected text).
            if is_main:
                print(f"\n\n{'='*20} [DEBUG] Dataset Sample ({dataset_name}) {'='*20}")
                try:
                    sample_data = self._hw7_first_example(dataset)
                    print(json.dumps(sample_data, indent=2, ensure_ascii=False))
                except Exception as e:
                    print(f"[DEBUG] Could not print sample: {e}")

            # Tokenize the dataset
            if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
                map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset"

            dataset = dataset.map(
                self.tokenize_row,
                # NOTE: "prompt" could also be removed here unless something later still needs it.
                remove_columns=["chosen", "rejected"],
                fn_kwargs={
                    "processing_class": processing_class,
                    "max_prompt_length": args.max_prompt_length,
                    "max_completion_length": args.max_completion_length,
                    # for enc-dec, we add the special tokens ([bos_token] + prompt + [eos_token]; completion + [eos_token])
                    "add_special_tokens": False,
                },
                **map_kwargs,
            )

            # Debug: show one tokenized example.
            if is_main:
                print(self._hw7_first_example(dataset))

        return dataset

In [None]:
# Start DPO Training
DPO_BS = 2
DPO_LORA_DROPOUT = 0.1
DPO_LORA_RANK = 16
DPO_LORA_ALPHA = 32
DPO_LR = "2e-5"
# Derive the dropout tag from DPO_LORA_DROPOUT (0.1 -> "01") instead of hard-coding
# "do01", so the run name stays accurate when the dropout changes.
_dropout_tag = str(DPO_LORA_DROPOUT).replace(".", "")
run_name = f"gemma-3-4b-it_r{DPO_LORA_RANK}a{DPO_LORA_ALPHA}_do{_dropout_tag}_lr{DPO_LR}_bs{DPO_BS}_epoch{DPO_EPOCH}" + "_dpo"

# The model was loaded with torch_dtype="auto" (presumably bfloat16 for Gemma 3;
# check model.dtype). Prefer bf16 mixed precision when supported, fp16 otherwise.
dpo_use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

# Set up DPO configuration
dpo_args = DPOConfig(
    per_device_train_batch_size=DPO_BS,
    gradient_accumulation_steps=2,
    num_train_epochs=DPO_EPOCH,
    bf16=dpo_use_bf16,
    fp16=not dpo_use_bf16,
    # Keep checkpoints under OUTPUT_DIR like the SFT run.
    output_dir=os.path.join(OUTPUT_DIR, "dpo_results"),
    # NOTE(review): prompt + completion must fit in 128 tokens; longer pairs are truncated.
    max_length=128,
    optim="paged_adamw_8bit",
    lr_scheduler_type="cosine_with_min_lr",
    lr_scheduler_kwargs={"min_lr": 1e-8},
    warmup_ratio=0.1,
    learning_rate=float(DPO_LR),
    save_strategy="epoch",
    # "none" disables logging integrations; with None, transformers v4 enables every installed one.
    report_to="none",
    logging_steps=1,
    run_name=run_name,
    # DPO specific args
    beta=0.03,
)

# LoRA configuration for DPO; DPOTrainer wraps the model with it via peft_config.
lora_cfg_dpo = LoraConfig(
  r=DPO_LORA_RANK,
  lora_alpha=DPO_LORA_ALPHA,
  target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"],
  lora_dropout=DPO_LORA_DROPOUT, bias="none", task_type="CAUSAL_LM"
)


# KV caching is not needed during training.
model.config.use_cache = False

# Train the model with DPO
dpo_trainer = HW7DPOTrainer(
    model=model,
    args=dpo_args,
    train_dataset=train_dataset,
    processing_class=tokenizer,  # Optional: if you want to handle tokenization
    peft_config=lora_cfg_dpo,
)

dpo_trainer.train()

In [None]:
# Keep the trained PEFT model, write its DPO LoRA adapter under OUTPUT_DIR, then
# fold the adapter into the base weights so `model` is a plain model for inference.
peft_model = dpo_trainer.model
DPO_ADAPTER_PATH = os.path.join(OUTPUT_DIR, f"dpo_{run_name}_adapter")
dpo_trainer.save_model(DPO_ADAPTER_PATH)
model = peft_model.merge_and_unload()

In [None]:
# Inference on the DPO-aligned model, same prompts and greedy decoding as before RL.
aligned_model_response = []
model.eval()
with torch.inference_mode():
    for data in test_data:
        food = data['food']
        prompt = data['question']
        print(f'\nQuestion {food}: {prompt}')
        inputs = data_formulate(data)
        # The chat template is assumed to emit <bos> itself; don't add a second one
        # (matches the pre-RL inference cell).
        encoded = tokenizer(inputs, return_tensors="pt", add_special_tokens=False).to(model.device)
        outputs = model.generate(
            **encoded,
            max_new_tokens=128,
            # Greedy decoding; temperature is not passed because it is ignored when do_sample=False.
            do_sample=False
        )
        output = tokenizer.batch_decode(outputs)[0]
        output = extract_assistant_response_gemma(output)
        print(f'Answer:{output}')
        aligned_model_response.append(output)

In [None]:
# Save model's output result
dir_name = OUTPUT_DIR + "hw7_dpo_results/"
# Nothing else creates this sub-directory; without this, open() below raises FileNotFoundError.
os.makedirs(dir_name, exist_ok=True)
file_name = os.path.join(dir_name, f"hw7_epoch{DPO_EPOCH}_data_size_{data_size}set_{set_num}.json")
output_list = []

# zip() stops at the shortest list; all three are built from test_data in order.
for data, original_response, aligned_response in zip(test_data, original_model_response, aligned_model_response):
    output_list.append({
        # NOTE(review): "id" holds the food name, although test_data also carries a
        # numeric data["id"]; kept as-is in case downstream grading expects it.
        "id": data["food"],
        "prompt": data["question"],
        "preference": data["preference"],
        "original_response": original_response,
        "aligned_response": aligned_response
    })

output_data = {
    "num_epoch": DPO_EPOCH,
    "data_size": data_size,
    "results": output_list
}

with open(file_name, "w", encoding="utf-8") as output_file:
    json.dump(output_data, output_file, indent=4, ensure_ascii=False)

print(f"\n file saved to {file_name}")