# Civil_Chatbot_Baseline_Ver_4: SFT + QLoRA
## 멋사 상담랜드

In [None]:
# Mount Google Drive so the dataset and output folders under
# /content/drive/MyDrive are reachable from this Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

Requirements

In [None]:
# Uncomment on a fresh runtime to install the training stack.
#!pip install -q transformers datasets accelerate peft bitsandbytes trl

import sys
# Record the interpreter version to make runs easier to reproduce.
print(f"Python version: {sys.version}")

Wandb login

In [2]:
# Import wandb defensively so the name always exists; this cell still requires
# it and fails fast below with an explanatory message when it is missing.
try:
    import wandb
except ImportError:
    wandb = None  # type: ignore
# Colab's secret store, used to read the W&B API key.
from google.colab import userdata

if wandb is None:
    raise ImportError("wandb is not installed. Please install it using the setup cell above.")
# Authenticate with the API key stored as the Colab secret 'WANDBKEY'.
wandb.login(key=userdata.get('WANDBKEY'))
print("Weights & Biases login complete")

Import

In [None]:
from __future__ import annotations

import json
import logging
import os
import random
from copy import deepcopy
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Dict, Iterable, List, Optional

import numpy as np
import torch
from datasets import Dataset, DatasetDict, load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    GenerationConfig,
    TrainingArguments,
)
from tqdm.auto import tqdm

# Optional dependencies. When peft/trl are missing, the names used later are
# still defined so this cell runs; the code that needs them raises a clear
# ImportError at the point of use instead.
try:
    from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
except ImportError:
    LoraConfig = None  # type: ignore
    TaskType = None  # type: ignore

    def get_peft_model(*args, **kwargs):
        """Stand-in for peft.get_peft_model that explains peft must be installed."""
        raise ImportError("peft 패키지가 필요합니다. 위 설치 셀을 실행해 주세요.")

    def prepare_model_for_kbit_training(*args, **kwargs):
        """Stand-in for peft.prepare_model_for_kbit_training that explains peft must be installed."""
        raise ImportError("peft 패키지가 필요합니다. 위 설치 셀을 실행해 주세요.")

try:
    from trl import SFTTrainer, SFTConfig
except ImportError:
    # Stub both names. Previously only SFTTrainer was stubbed, which left
    # SFTConfig undefined (NameError) instead of explicitly unavailable.
    SFTTrainer = None  # type: ignore
    SFTConfig = None  # type: ignore

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("baseline4")

# Seed every RNG this notebook draws from, for reproducible runs.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

# Allow TF32 matmuls on GPUs that support them (faster, slightly lower precision).
torch.backends.cuda.matmul.allow_tf32 = True
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# Google Drive locations. Each path can be overridden through an environment
# variable; the defaults are the original Drive paths. (The former
# "if not exists: fall back to DATA_ROOT / ..." branches pointed at exactly the
# same files, so they could never change anything and were removed.)
PROJECT_ROOT = Path(os.environ.get("CIVIL_CHATBOT_ROOT", "/content/drive/MyDrive/Dev/Civil_Chatbot"))
DATA_ROOT = Path(os.environ.get("CIVIL_CHATBOT_DATA_ROOT", str(PROJECT_ROOT)))

TRAIN_JSONL = Path(os.environ.get("CIVIL_CHATBOT_TRAIN_JSONL", str(DATA_ROOT / "raw_data_messages.jsonl")))
VAL_JSONL = Path(os.environ.get("CIVIL_CHATBOT_VAL_JSONL", str(DATA_ROOT / "raw_val_data_messages.jsonl")))

BASE_MODEL_NAME = "beomi/Llama-3-Open-Ko-8B"
# Trainer checkpoints / final adapter.
OUTPUT_DIR = PROJECT_ROOT / "results" / "baseline4"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Text logs written by the validation preview.
EVAL_LOG_DIR = PROJECT_ROOT / "outputs" / "baseline4_eval"
EVAL_LOG_DIR.mkdir(parents=True, exist_ok=True)

WANDB_PROJECT = os.environ.get("WANDB_PROJECT", "baseline4-korean-chatbot")
WANDB_ENTITY = os.environ.get("WANDB_ENTITY")
WANDB_RUN_NAME_PREFIX = os.environ.get("WANDB_RUN_NAME_PREFIX", "baseline4")

print(f"Train path: {TRAIN_JSONL}")
print(f"Val path:   {VAL_JSONL}")
print(f"Result dir: {OUTPUT_DIR}")
print(f"Eval dir:   {EVAL_LOG_DIR}")
print(f"W&B project: {WANDB_PROJECT}")


## Data Normalization

In [None]:
# Raw turns store their text under a role-named key (e.g. {"user": "..."}).
# This table maps every accepted key to its canonical chat role: the three
# canonical roles map to themselves, and two aliases fold into "assistant".
_CANONICAL_ROLES = ("user", "assistant", "system")
ROLE_KEY_MAP = {role_name: role_name for role_name in _CANONICAL_ROLES}
ROLE_KEY_MAP.update({"assistance": "assistant", "assistant_response": "assistant"})


def sanitize_text(text: Optional[str]) -> str:
    """Return *text* as a stripped string with normalized whitespace.

    - ``None`` becomes ``""``; any other value is converted with ``str()``.
    - Windows (CR LF) and old-Mac (lone CR) line endings both become LF.
      Previously only CR LF was handled, so a stray CR survived into the
      training text.
    - The ideographic space U+3000 becomes a regular space.
    - Leading and trailing whitespace is removed.
    """
    if text is None:
        return ""
    value = (
        str(text)
        .replace("\r\n", "\n")
        .replace("\r", "\n")
        .replace("\u3000", " ")
    )
    return value.strip()


def normalize_turn(turn: Dict[str, str]) -> Optional[Dict[str, str]]:
    """Convert one raw dialogue turn into ``{"role": ..., "content": ...}``.

    Two layouts are accepted:

    1. Role-keyed (the original format), e.g. ``{"user": "..."}``. Keys are
       looked up in ``ROLE_KEY_MAP``. The first mapped key (in dict order)
       whose value is not ``None`` and is non-empty after ``sanitize_text``
       wins.
    2. Standard chat format, ``{"role": "user", "content": "..."}``. This is
       tried only when layout 1 yields nothing, so turns that normalized
       before still normalize exactly the same way. Previously such turns
       were silently dropped.

    Returns:
        The normalized turn, or ``None`` when no usable role/content is found
        (callers drop such turns).
    """
    for key, value in turn.items():
        role = ROLE_KEY_MAP.get(key)
        if role and value is not None:
            cleaned = sanitize_text(value)
            if cleaned:
                return {"role": role, "content": cleaned}

    # Fallback: explicit role/content fields.
    raw_role = turn.get("role")
    if isinstance(raw_role, str):
        role = ROLE_KEY_MAP.get(raw_role.strip().lower())
        if role:
            cleaned = sanitize_text(turn.get("content"))
            if cleaned:
                return {"role": role, "content": cleaned}
    return None


def normalize_messages(messages: Optional[Iterable[Dict[str, str]]]) -> List[Dict[str, str]]:
    """Normalize every turn of a dialogue, dropping turns that yield nothing.

    Args:
        messages: raw turns, or ``None``. A JSON ``null`` "messages" field
            arrives as ``None`` even when callers use ``.get("messages", [])``.
            That previously crashed with TypeError and is now treated as an
            empty dialogue.

    Returns:
        List of ``{"role", "content"}`` dicts in the original order.
    """
    if messages is None:
        return []
    normalized_turns = (normalize_turn(turn) for turn in messages)
    return [turn for turn in normalized_turns if turn]


def is_valid_dialogue(messages: List[Dict[str, str]], min_pairs: int = 1) -> bool:
    """Return True when the dialogue has at least *min_pairs* user AND assistant turns.

    Role counts are checked independently (ordering is not verified). An
    empty dialogue is always invalid.
    """
    if not messages:
        return False
    roles = [entry["role"] for entry in messages]
    return min(roles.count("user"), roles.count("assistant")) >= min_pairs


def count_conversation_pairs(messages: List[Dict[str, str]]) -> int:
    """Return min(#user turns, #assistant turns); other roles are ignored."""
    roles = [entry["role"] for entry in messages]
    return min(roles.count("user"), roles.count("assistant"))


def load_message_datasets(train_path: Path, val_path: Path) -> DatasetDict:
    """Load the train/validation JSONL files with the ``datasets`` JSON loader.

    Args:
        train_path: JSONL file for the "train" split.
        val_path: JSONL file for the "validation" split.

    Returns:
        A ``DatasetDict`` with "train" and "validation" splits.

    Raises:
        FileNotFoundError: if a file is missing (the train file is checked first).
    """
    for label, candidate in (("Train data", train_path), ("Validation data", val_path)):
        if not candidate.exists():
            raise FileNotFoundError(f"{label} not found: {candidate}")
    loaded = load_dataset(
        "json",
        data_files={"train": str(train_path), "validation": str(val_path)},
    )
    logger.info("Loaded dataset: %s", loaded)
    return loaded


def add_normalized_messages(dataset: DatasetDict) -> DatasetDict:
    """Add a "normalized_messages" column built from each row's "messages" field."""
    def _attach_normalized(record):
        raw_turns = record.get("messages", [])
        return {"normalized_messages": normalize_messages(raw_turns)}

    return dataset.map(_attach_normalized, desc="normalize messages")


def filter_dialogues(dataset: DatasetDict, min_pairs: int = 1) -> DatasetDict:
    """Keep only dialogues that pass ``is_valid_dialogue`` with *min_pairs*.

    The progress label says "turns", but the threshold applies to the user
    and assistant role counts separately, i.e. pairs.
    """
    progress_label = f"filter dialogues >= {min_pairs} turns"
    return dataset.filter(
        lambda record: is_valid_dialogue(record["normalized_messages"], min_pairs=min_pairs),
        desc=progress_label,
    )


def summarize_dialogue_stats(dataset: DatasetDict) -> Dict[str, Dict[str, float]]:
    """Compute per-split statistics over the "normalized_messages" column.

    Returns:
        ``{split: {"num_samples", "avg_turns", "median_turns", "avg_pairs",
        "max_turns", "min_turns"}}``. Counts and max/min are ints; averages
        and median are floats. Empty splits are omitted.
    """
    stats: Dict[str, Dict[str, float]] = {}
    for split_name, split in dataset.items():
        # Materialize the column once and reuse it for both passes, instead of
        # indexing the split twice.
        dialogues = list(split["normalized_messages"])
        if not dialogues:
            continue
        lengths = [len(dialogue) for dialogue in dialogues]
        pair_counts = [count_conversation_pairs(dialogue) for dialogue in dialogues]
        stats[split_name] = {
            "num_samples": len(lengths),
            "avg_turns": float(np.mean(lengths)),
            "median_turns": float(np.median(lengths)),
            "avg_pairs": float(np.mean(pair_counts)),
            "max_turns": int(np.max(lengths)),
            "min_turns": int(np.min(lengths)),
        }
    return stats

In [None]:
# Load the raw JSONL splits, normalize each dialogue's turns, and keep only
# dialogues with at least one user and one assistant turn.
raw_datasets = load_message_datasets(TRAIN_JSONL, VAL_JSONL)
processed_datasets = add_normalized_messages(raw_datasets)
processed_datasets = filter_dialogues(processed_datasets, min_pairs=1)

# Print the per-split turn/pair statistics ("data distribution summary").
stats = summarize_dialogue_stats(processed_datasets)
print("데이터 분포 요약:")
for split_name, values in stats.items():
    formatted = ", ".join(f"{key}={value}" for key, value in values.items())
    print(f"- {split_name}: {formatted}")

# Show the first four normalized turns of the first training dialogue.
# NOTE(review): raises IndexError if filtering left the train split empty.
sample_example = processed_datasets["train"][0]
print("샘플 대화 (앞부분):")
for turn in sample_example["normalized_messages"][:4]:
    print(turn)

## Chat Template Alignment

In [None]:
# Load the base model's tokenizer (trust_remote_code permits custom tokenizer code from the Hub repo).
tokenizer = AutoTokenizer.from_pretrained(
    BASE_MODEL_NAME,
    use_fast=True,
    trust_remote_code=True,
)
# Batching needs a pad token; reuse eos when none is defined, falling back to
# the Llama-3 "<|end_of_text|>" token string.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token or "<|end_of_text|>"
# Right padding for training batches.
tokenizer.padding_side = "right"

# Sanity-check the special tokens and whether the tokenizer ships a chat template.
print(f"Tokenizer eos_token: {tokenizer.eos_token!r}")
print(f"Tokenizer pad_token: {tokenizer.pad_token!r}")
print("Chat template 존재 여부:", bool(getattr(tokenizer, "chat_template", None)))

In [None]:
# Jinja chat templates that `set_chat_template` can install on the tokenizer,
# keyed by name.
CHAT_TEMPLATE_REGISTRY = {
    # Llama-3-style headers. Each message renders as
    #   <|start_header_id|>{role}<|end_header_id|>\n{content}<|eot_id|>\n
    # and the generation prompt is exactly the assistant header plus ONE
    # newline: the same prefix an assistant turn has in the training text.
    # The previous version put newlines between Jinja tags. Under HF's
    # trim_blocks rendering those produced the same training text, but a
    # generation prompt ending in "\n\n", i.e. a train/inference mismatch.
    # All whitespace is now written explicitly inside the expressions.
    # NOTE(review): unlike Meta's official template this emits no bos_token
    # and uses one "\n" after the header; kept so existing adapters still match.
    "llama-3": (
        "{% for message in messages %}"
        "{{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n'"
        " + message['content'] | trim + '<|eot_id|>\n' }}"
        "{% endfor %}"
        "{% if add_generation_prompt %}"
        "{{ '<|start_header_id|>assistant<|end_header_id|>\n' }}"
        "{% endif %}"
    ),
    # NOTE: despite the key, this is not the real ChatML format (no
    # <|im_start|>/<|im_end|>); it uses plain [SYSTEM]/[USER]/[ASSISTANT] tags.
    "chatml": """{{ bos_token }}{% for message in messages %}
{% if message['role'] == 'system' %}{{ '[SYSTEM]' + message['content'] | trim + '
' }}{% elif message['role'] == 'user' %}{{ '[USER]' + message['content'] | trim + '
' }}{% elif message['role'] == 'assistant' %}{{ '[ASSISTANT]' + message['content'] | trim + '
' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '[ASSISTANT]' }}{% endif %}""",
    # [INST]-style prompts. The assistant reply directly follows "[/INST]", so
    # the generation prompt is empty; it previously emitted "[INST]", which
    # asked the model to write a new user instruction instead of answering.
    "mistral-instruct": """{{ bos_token }}{% for message in messages %}
{% if message['role'] == 'system' %}{{ '[INST] <<SYS>>\n' + message['content'] | trim + '
<</SYS>>\n' }}{% elif message['role'] == 'user' %}{{ '[INST] ' + message['content'] | trim + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] | trim + '
' }}{% endif %}{% endfor %}""",
}


def set_chat_template(
    tokenizer: AutoTokenizer,
    template_key: Optional[str] = None,
    custom_template: Optional[str] = None,
    save_directory: Optional[Path] = None,
) -> None:
    """Install a chat template on *tokenizer*, optionally saving the tokenizer.

    Args:
        tokenizer: tokenizer whose ``chat_template`` attribute is replaced.
        template_key: key into ``CHAT_TEMPLATE_REGISTRY``; used only when
            *custom_template* is None.
        custom_template: raw Jinja template; takes precedence over *template_key*.
        save_directory: if given, created if needed and passed to
            ``tokenizer.save_pretrained``.

    Raises:
        ValueError: neither *template_key* nor *custom_template* was given.
        KeyError: *template_key* is not in the registry.
    """
    if custom_template is None:
        if template_key is None:
            raise ValueError("template_key 또는 custom_template 중 하나는 반드시 지정해야 합니다.")
        chosen_template = CHAT_TEMPLATE_REGISTRY.get(template_key)
        if chosen_template is None:
            raise KeyError(f"지원하지 않는 template_key: {template_key}")
        log_args = ("chat_template set to '%s'", template_key)
    else:
        chosen_template = custom_template
        log_args = ("chat_template updated from custom_template",)

    tokenizer.chat_template = chosen_template
    if save_directory is not None:
        save_directory.mkdir(parents=True, exist_ok=True)
        tokenizer.save_pretrained(save_directory)
    logger.info(*log_args)


# Show the first 400 characters of the tokenizer's original chat template, if any.
current_template_preview = (tokenizer.chat_template or "<None>")[:400]
print("기존 chat_template (앞부분):", current_template_preview)

# Align the template with the registry's "llama-3" format. This is in memory
# only (save_directory=None); the training text below is rendered with it, and
# it is later persisted via tokenizer.save_pretrained(... "tokenizer_final").
set_chat_template(tokenizer, template_key="llama-3", save_directory=None)

updated_template_preview = (tokenizer.chat_template or "<None>")[:400]
print("적용 후 chat_template (앞부분):", updated_template_preview)
print("사용 가능한 템플릿 키:", list(CHAT_TEMPLATE_REGISTRY.keys()))

## Chat Formatting: Generate Input Text

In [None]:
def add_chat_template_column(
    dataset: Dataset,
    tokenizer: AutoTokenizer,
    column_name: str = "text",
    add_generation_prompt: bool = False,
) -> Dataset:
    """Render each row's "normalized_messages" to a string with the chat template.

    Args:
        dataset: dataset with a "normalized_messages" column.
        tokenizer: tokenizer whose ``apply_chat_template`` does the rendering.
        column_name: name of the new string column.
        add_generation_prompt: append the assistant generation prompt (use
            False for training text that already contains the replies).

    Returns:
        The dataset with *column_name* added (batched ``map``).
    """
    def _render_batch(batch):
        rendered = [
            tokenizer.apply_chat_template(
                conversation,
                tokenize=False,
                add_generation_prompt=add_generation_prompt,
            )
            for conversation in batch["normalized_messages"]
        ]
        return {column_name: rendered}

    progress_label = f"format conversations -> {column_name}"
    return dataset.map(_render_batch, batched=True, desc=progress_label)

In [None]:
# Render every dialogue into a single "text" training string. No generation
# prompt: the assistant replies are already part of the text.
train_sft = add_chat_template_column(processed_datasets["train"], tokenizer, column_name="text", add_generation_prompt=False)
val_sft = add_chat_template_column(processed_datasets["validation"], tokenizer, column_name="text", add_generation_prompt=False)

sft_datasets = DatasetDict({"train": train_sft, "validation": val_sft})
print(sft_datasets)

# Eyeball the first 600 characters of one formatted training sample.
formatted_sample = sft_datasets["train"][0]["text"]
print("포맷된 샘플 텍스트 (앞부분):")
print(formatted_sample[:600])

## QLoRA Setting

In [None]:
def get_quantization_config(load_in_4bit: bool = True) -> BitsAndBytesConfig:
    """Return the bitsandbytes quantization config for model loading.

    With *load_in_4bit* (the default) this is the QLoRA recipe: NF4 weights,
    double quantization and bfloat16 compute. Otherwise it returns a config
    with 4-bit loading disabled.
    """
    if load_in_4bit:
        return BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
    return BitsAndBytesConfig(load_in_4bit=False)


def load_base_model(
    model_name: str,
    quantization_config: Optional[BitsAndBytesConfig] = None,
    for_training: bool = False,
    attn_implementation: Optional[str] = "sdpa",
    device_map: Optional[Dict[str, str]] = None,
):
    """Load a causal LM in bfloat16, optionally quantized and prepared for training.

    Args:
        model_name: Hub id or local path of the checkpoint.
        quantization_config: bitsandbytes config; when given, ``device_map``
            defaults to "auto".
        for_training: disable the KV cache, run k-bit preparation (quantized
            models) and enable gradient checkpointing.
        attn_implementation: passed to ``from_pretrained`` when truthy.
        device_map: explicit device placement (also accepted without quantization).

    Returns:
        The loaded model.
    """
    model_kwargs: Dict[str, object] = {
        "torch_dtype": torch.bfloat16,
    }
    if quantization_config is not None:
        model_kwargs["quantization_config"] = quantization_config
        model_kwargs["device_map"] = device_map or "auto"
    elif device_map is not None:
        model_kwargs["device_map"] = device_map
    if attn_implementation:
        model_kwargs["attn_implementation"] = attn_implementation
    model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)

    if not for_training:
        model.config.use_cache = True
        return model

    # The KV cache is not used in training and conflicts with gradient checkpointing.
    model.config.use_cache = False
    if quantization_config is not None and prepare_model_for_kbit_training is not None:
        # Gradient checkpointing is enabled explicitly below rather than here.
        model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=False)
    if hasattr(model, "gradient_checkpointing_enable"):
        model.gradient_checkpointing_enable()
        # With frozen (k-bit) base weights, the embedding outputs do not require
        # grad. Under reentrant checkpointing, LoRA adapters inside checkpointed
        # layers then receive no gradients. prepare_model_for_kbit_training
        # would install this hook only with use_gradient_checkpointing=True, so
        # install it here.
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
    return model


def create_lora_config(
    r: int = 64,
    alpha: int = 16,
    dropout: float = 0.05,
    bias: str = "none",
    task_type: str = "CAUSAL_LM",
    target_modules: Optional[List[str]] = None,
):
    """Build a ``peft.LoraConfig``.

    Args:
        r: LoRA rank.
        alpha: ``lora_alpha``; adapter updates are scaled by alpha / r.
        dropout: ``lora_dropout``.
        bias: passed to ``LoraConfig.bias``.
        task_type: PEFT task type name, e.g. "CAUSAL_LM". It is now honoured;
            it was previously ignored and CAUSAL_LM was always used.
        target_modules: modules to wrap; defaults to the attention and MLP
            projections listed below.

    Returns:
        The ``LoraConfig``.

    Raises:
        ImportError: peft is not installed.
        ValueError: *task_type* is not a valid ``peft.TaskType`` value.
    """
    if LoraConfig is None:
        # The message was previously mis-encoded ("???").
        raise ImportError("peft.LoraConfig is unavailable. Install the peft package first.")
    if target_modules is None:
        target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
    resolved_task_type = TaskType(task_type) if TaskType is not None else task_type
    return LoraConfig(
        r=r,
        lora_alpha=alpha,
        lora_dropout=dropout,
        bias=bias,
        task_type=resolved_task_type,
        target_modules=target_modules,
    )


def apply_lora(model: AutoModelForCausalLM, lora_config: "LoraConfig"):
    """Wrap *model* with LoRA adapters and print the trainable-parameter summary.

    Returns:
        The ``PeftModel`` returned by ``get_peft_model``.

    Raises:
        ImportError: peft is not available.
    """
    if get_peft_model is None:
        # The message was previously mis-encoded ("???").
        raise ImportError("peft.get_peft_model is unavailable. Install the peft package first.")
    peft_model = get_peft_model(model, lora_config)
    peft_model.print_trainable_parameters()
    return peft_model


def create_sft_trainer(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    train_dataset: Dataset,
    eval_dataset: Optional[Dataset],
    output_dir: Path,
    learning_rate: float = 2e-4,
    warmup_steps: int = 100,
    per_device_train_batch_size: int = 2,
    gradient_accumulation_steps: int = 4,
    num_train_epochs: float = 2.0,
    max_length: int = 512, # max_length 512, 768, 1024..
    logging_steps: int = 25,
    save_steps: int = 500,
    run_name: Optional[str] = None,
    report_to: Optional[List[str]] = None,
    wandb_project: Optional[str] = None,
):
    """Build a TRL ``SFTTrainer`` that trains on the dataset's "text" column.

    Args:
        model: the (PEFT-wrapped) model to train.
        tokenizer: passed as ``processing_class``.
        train_dataset / eval_dataset: datasets with a "text" column. When
            *eval_dataset* is None, periodic evaluation is disabled.
        output_dir: checkpoint directory.
        learning_rate, warmup_steps, per_device_train_batch_size,
        gradient_accumulation_steps, num_train_epochs, logging_steps:
            forwarded to ``SFTConfig``.
        max_length: maximum tokenized sequence length. It is now honoured; it
            was previously hard-coded to 512.
        save_steps: checkpoint interval, also used as the eval interval.
        run_name: run name for the reporting integrations.
        report_to: integrations; defaults to wandb (when importable) plus tensorboard.
        wandb_project: default W&B project if WANDB_PROJECT is not already set.

    Returns:
        The configured ``SFTTrainer``.

    Raises:
        ImportError: trl is not installed.
    """
    if SFTTrainer is None:
        # The message was previously mis-encoded ("???").
        raise ImportError("trl.SFTTrainer is unavailable. Install the trl package first.")
    # Prefer bf16 mixed precision where the GPU supports it, fp16 on other GPUs, none on CPU.
    use_bf16 = bool(torch.cuda.is_available() and getattr(torch.cuda, "is_bf16_supported", lambda: False)())
    use_fp16 = bool(torch.cuda.is_available() and not use_bf16)
    if report_to is None:
        report_to = ["wandb", "tensorboard"] if wandb is not None else ["tensorboard"]
    if wandb is not None and "wandb" in report_to:
        # setdefault: values already exported in the environment take precedence.
        os.environ.setdefault("WANDB_PROJECT", wandb_project or WANDB_PROJECT)
        if WANDB_ENTITY:
            os.environ.setdefault("WANDB_ENTITY", WANDB_ENTITY)

    sft_config = SFTConfig(
        output_dir=str(output_dir),
        num_train_epochs=num_train_epochs,
        per_device_train_batch_size=per_device_train_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        learning_rate=learning_rate,
        warmup_steps=warmup_steps,
        max_steps=-1,

        logging_steps=logging_steps,
        # Trainer rejects eval_strategy="steps" without an eval dataset.
        eval_strategy="steps" if eval_dataset is not None else "no",
        eval_steps=save_steps,
        save_steps=save_steps,
        save_total_limit=2,

        bf16=use_bf16,
        fp16=use_fp16,
        gradient_checkpointing=True,

        optim="paged_adamw_32bit",
        weight_decay=0.05,
        lr_scheduler_type="cosine",

        report_to=report_to,
        run_name=run_name,
        dataloader_pin_memory=False,

        # SFT data preprocessing.
        max_length=max_length,
        dataset_text_field="text",
        packing=False,
    )

    trainer = SFTTrainer(
        model=model,
        processing_class=tokenizer,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        args=sft_config,
    )
    return trainer

## SFTTrainer

In [None]:
import os
# Reduce CUDA memory fragmentation.
# NOTE(review): only effective if set before the CUDA caching allocator is first used.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Single source of truth for the training hyperparameters. The same values
# are logged to W&B and passed to the trainer. Previously W&B recorded batch
# size 2 / grad accumulation 4 while the trainer actually ran with 1 / 8.
train_hparams = {
    "learning_rate": 2e-4,
    "warmup_steps": 100,
    "gradient_accumulation_steps": 8,
    "per_device_train_batch_size": 1,
    "num_train_epochs": 2.0,
    "max_length": 512,
}

# 4-bit NF4 base model + LoRA adapters (QLoRA).
quant_config = get_quantization_config(load_in_4bit=True)
qlora_base_model = load_base_model(
    BASE_MODEL_NAME,
    quantization_config=quant_config,
    for_training=True,
)
lora_config = create_lora_config(r=64, alpha=16, dropout=0.05)
qlora_model = apply_lora(qlora_base_model, lora_config)

wandb_run = None
if wandb is not None:
    wandb_config = dict(train_hparams)
    wandb_run = wandb.init(
        project=WANDB_PROJECT,
        name=f"{WANDB_RUN_NAME_PREFIX}_qlora",
        config=wandb_config,
        reinit=True,
    )
trainer = create_sft_trainer(
    model=qlora_model,
    tokenizer=tokenizer,
    train_dataset=sft_datasets["train"],
    eval_dataset=sft_datasets["validation"],
    output_dir=OUTPUT_DIR,
    logging_steps=50,
    save_steps=500,
    run_name=(wandb_run.name if wandb_run is not None else f"{WANDB_RUN_NAME_PREFIX}_qlora"),
    **train_hparams,
)
trainer.model.generation_config = GenerationConfig.from_model_config(trainer.model.config)

## Train Loop

In [None]:
# Fine-tune. save_model() writes the adapter to the trainer's output_dir
# (OUTPUT_DIR). The tokenizer, including the chat template set earlier, is
# saved next to it so inference can reload the exact same formatting.
trainer.train()
trainer.save_model()
tokenizer.save_pretrained(OUTPUT_DIR / "tokenizer_final")
# Final metrics on the validation split.
# NOTE(review): the W&B run started earlier is not finished here.
metrics = trainer.evaluate()
print(metrics)

## Eval Loop

In [None]:
def _resolve_stop_token_ids(tokenizer: AutoTokenizer, stop_tokens: Iterable[str]) -> List[int]:
    """Return ``eos_token_id`` plus the ids of those *stop_tokens* the tokenizer knows.

    Tokens whose id comes back as a non-int or as the unk id are skipped, so
    the defaults are safe for tokenizers that lack them.
    """
    stop_ids: List[int] = []
    if tokenizer.eos_token_id is not None:
        stop_ids.append(tokenizer.eos_token_id)
    unk_id = getattr(tokenizer, "unk_token_id", None)
    for token in stop_tokens:
        token_id = tokenizer.convert_tokens_to_ids(token)
        if isinstance(token_id, int) and token_id != unk_id and token_id not in stop_ids:
            stop_ids.append(token_id)
    return stop_ids


def _generate_reply(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    history: List[Dict[str, str]],
    generation_kwargs: Dict[str, object],
) -> str:
    """Generate one assistant reply for *history*; return the new text, or "<empty>"."""
    prompt_inputs = tokenizer.apply_chat_template(
        history,
        add_generation_prompt=True,
        return_tensors="pt",
    )
    # apply_chat_template may return a bare tensor of ids or a mapping;
    # normalize both to a dict of tensors on the model's device.
    if isinstance(prompt_inputs, torch.Tensor):
        prompt_inputs = {"input_ids": prompt_inputs.to(model.device)}
    else:
        prompt_inputs = {key: value.to(model.device) for key, value in prompt_inputs.items()}
    if "attention_mask" not in prompt_inputs:
        # Single unpadded prompt: attend to every position.
        prompt_inputs["attention_mask"] = torch.ones_like(prompt_inputs["input_ids"], dtype=torch.long)
    generated_ids = model.generate(**prompt_inputs, **generation_kwargs)
    # generate() returns prompt + continuation; decode only the continuation.
    input_length = prompt_inputs["input_ids"].shape[-1]
    decoded = tokenizer.decode(generated_ids[0][input_length:], skip_special_tokens=True).strip()
    return decoded or "<empty>"


def _format_sample_lines(sample_rank: int, turn_outputs: List[Dict[str, str]]) -> List[str]:
    """Render one sample's question / model answer / reference answer log lines."""
    lines = [f"sample{sample_rank}-"]
    for turn_idx, turn in enumerate(turn_outputs, start=1):
        lines.append(f"질문{turn_idx}: user:{turn['question']}")
        lines.append(f"답변(모델): {turn['model_response']}")
        ground_truth = turn["ground_truth"] or "<no ground truth>"
        lines.append(f"답변(원본): {ground_truth}")
    lines.append("")
    return lines


def run_validation_preview(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    dataset: Dataset,
    num_samples: int = 3,
    max_new_tokens: int = 128,
    temperature: float = 0.3,
    top_p: float = 0.9,
    do_sample: bool = True,
    use_kv_cache: bool = True,
    history_mode: str = "reference",
    inject_reference_responses: bool = False,
    seed: int = 42,
    output_dir: Path = EVAL_LOG_DIR,
    file_prefix: str = "baseline2_eval_loop",
    stop_tokens: Iterable[str] = ("<|eot_id|>",),
) -> Path:
    """Generate replies for a few validation dialogues and save a side-by-side log.

    For each sampled dialogue, the model answers every user turn in order. Its
    reply is logged next to the dataset's reference assistant reply. The log
    is written to ``output_dir / f"{file_prefix}_{timestamp}.txt"`` and printed.

    Args:
        model: model used for ``generate``; put in eval mode.
        tokenizer: renders prompts with its chat template and decodes replies.
        dataset: rows with "normalized_messages" (or raw "messages").
        num_samples: how many dialogues to preview (chosen with *seed*).
        max_new_tokens, temperature, top_p, do_sample: generation settings;
            temperature/top_p are only passed when sampling.
        use_kv_cache: KV-cache setting during generation; the model's original
            ``config.use_cache`` is restored afterwards, even on error.
        history_mode: "reference" builds each prompt from the dataset's own
            earlier turns; "model" feeds the model's earlier replies back in.
        inject_reference_responses: in "model" mode, also append the reference
            reply after the model's reply.
        seed: seed for choosing which samples to preview.
        output_dir, file_prefix: where the log is written.
        stop_tokens: extra tokens that end generation besides eos. The default
            "<|eot_id|>" is the terminator the "llama-3" registry template
            writes after every message (the template itself never emits eos).
            Without it, generation ran on into invented extra turns until
            max_new_tokens.

    Returns:
        Path of the written log file.

    Raises:
        ValueError: invalid *history_mode* or empty *dataset*.
    """
    if history_mode not in {"reference", "model"}:
        raise ValueError("history_mode must be 'reference' or 'model'")
    output_dir.mkdir(parents=True, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_path = output_dir / f"{file_prefix}_{timestamp}.txt"
    indices = list(range(len(dataset)))
    if not indices:
        raise ValueError("Dataset is empty.")
    random.Random(seed).shuffle(indices)
    selected_indices = indices[:num_samples]

    # `pad_token_id or eos_token_id` would wrongly discard a valid pad id of 0.
    pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
    stop_ids = _resolve_stop_token_ids(tokenizer, stop_tokens)
    generation_kwargs: Dict[str, object] = {
        "max_new_tokens": max_new_tokens,
        "do_sample": do_sample,
        "pad_token_id": pad_token_id,
        "eos_token_id": stop_ids or None,
        "use_cache": use_kv_cache,
    }
    if do_sample:
        # temperature/top_p only matter when sampling.
        generation_kwargs["temperature"] = temperature
        generation_kwargs["top_p"] = top_p

    lines: List[str] = []
    model.eval()
    original_use_cache = getattr(model.config, "use_cache", True)
    model.config.use_cache = use_kv_cache
    try:
        with torch.inference_mode():
            for sample_rank, dataset_idx in enumerate(tqdm(selected_indices, desc="validation preview"), start=1):
                sample = dataset[dataset_idx]
                normalized_messages = sample.get("normalized_messages") or normalize_messages(sample.get("messages", []))
                # Conversation the prompts are built from. In "reference" mode
                # it contains the dataset's turns; in "model" mode the model's
                # own replies (plus references if injected).
                history: List[Dict[str, str]] = []
                turn_outputs: List[Dict[str, str]] = []

                for msg in normalized_messages:
                    role = msg["role"]
                    content = msg["content"]
                    if not content:
                        continue
                    if role in ("system", "user"):
                        history.append({"role": role, "content": content})
                    if role == "user":
                        reply = _generate_reply(model, tokenizer, history, generation_kwargs)
                        turn_outputs.append({
                            "question": content,
                            "model_response": reply,
                            "ground_truth": "",
                        })
                        if history_mode == "model":
                            history.append({"role": "assistant", "content": reply})
                    elif role == "assistant":
                        if history_mode == "reference" or inject_reference_responses:
                            history.append({"role": role, "content": content})
                        if turn_outputs:
                            turn_outputs[-1]["ground_truth"] = content

                lines.extend(_format_sample_lines(sample_rank, turn_outputs))
    finally:
        model.config.use_cache = original_use_cache

    # One entry per line. Previously the lines were joined with "" into a
    # single unreadable line.
    output_content = "\n".join(lines).strip() + "\n"
    output_path.write_text(output_content, encoding="utf-8")
    print(f"저장 완료: {output_path}")
    print(output_content)
    return output_path

## Inference

In [None]:
# Import defensively so the availability check below is meaningful. (A bare
# import would raise before the check could run.)
try:
    from peft import PeftModel
except ImportError:
    PeftModel = None  # type: ignore

# Set these to your own paths. adapter_dir is the LoRA checkpoint to evaluate.
adapter_dir = OUTPUT_DIR / "checkpoint-2444"
tokenizer_dir = OUTPUT_DIR / "tokenizer_final"

if PeftModel is None:
    raise ImportError("peft.PeftModel is unavailable. Install the peft package first.")
if not adapter_dir.exists():
    # trainer.save_model() in the training cell writes the final adapter
    # straight into OUTPUT_DIR; use it when the named checkpoint is absent.
    if (OUTPUT_DIR / "adapter_config.json").exists():
        print(f"Warning: {adapter_dir} not found. Using the final adapter in {OUTPUT_DIR}.")
        adapter_dir = OUTPUT_DIR
    else:
        raise FileNotFoundError(f"LoRA adapter directory not found: {adapter_dir}")

using_base_tokenizer = False
if not Path(tokenizer_dir).exists():
    print(f"Warning: {tokenizer_dir} not found. Falling back to base tokenizer {BASE_MODEL_NAME}.")
    tokenizer_dir = BASE_MODEL_NAME
    using_base_tokenizer = True

finetuned_tokenizer = AutoTokenizer.from_pretrained(
    tokenizer_dir,
    use_fast=True,
    trust_remote_code=True,
)
if finetuned_tokenizer.pad_token is None:
    finetuned_tokenizer.pad_token = finetuned_tokenizer.eos_token
if using_base_tokenizer:
    # The chat template used for training was set in memory and saved only
    # into tokenizer_final; re-apply it so prompts match the training format.
    set_chat_template(finetuned_tokenizer, template_key="llama-3")

# Same 4-bit base as in training, with the trained adapter loaded on top (frozen).
finetuned_quant_config = get_quantization_config(load_in_4bit=True)
finetuned_model = load_base_model(
    BASE_MODEL_NAME,
    quantization_config=finetuned_quant_config,
    for_training=False,
)
finetuned_model = PeftModel.from_pretrained(finetuned_model, adapter_dir, is_trainable=False)
finetuned_model.eval()

finetuned_preview_log = run_validation_preview(
    model=finetuned_model,
    tokenizer=finetuned_tokenizer,
    dataset=processed_datasets["validation"],
    num_samples=3,
    max_new_tokens=128,
    temperature=0.7,
    top_p=0.9,
    use_kv_cache=True,
    history_mode="reference",
    seed=SEED,
    file_prefix="baseline4_finetuned_eval",
)
print(f"Fine-tuned model eval log: {finetuned_preview_log}")