# Step 1: Fine-tune the base model

First step: fine-tune the base model with LoRA adapters so it learns to respond at different strictness levels.

In [3]:
# Import stuff
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForLanguageModeling
from peft import LoraConfig, get_peft_model, TaskType
from datasets import Dataset
import json

# Pick the compute device. transformers' Trainer places the model on CUDA when
# it is available, so prefer CUDA here too; otherwise fall back to Apple MPS,
# then CPU. Keeping this in sync avoids model/input device mismatches later.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

ModuleNotFoundError: No module named 'torch'

In [None]:
# Config -- base checkpoint, input data file, and where training output is written
model_name = "Qwen/Qwen2.5-0.5B-Instruct"   # checkpoint id passed to from_pretrained
dataset_path = "sft_dataset.json"           # SFT examples, read in the data-loading cell
output_dir = "./strictbot_sft_model"        # used for checkpoints and the final save

# LoRA hyperparameters: adapter rank and scaling factor
lora_r, lora_alpha = 32, 64
# Layer names that receive LoRA adapters; they must match module names in the base model.
target_modules = [
    "q_proj", "v_proj", "k_proj", "o_proj",
    "gate_proj", "up_proj", "down_proj",
]

In [None]:
# Load the tokenizer and the base causal-LM weights for `model_name`.
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Batched tokenization needs a pad token; fall back to EOS when none is defined.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(model_name)
# On Apple Silicon, move the weights to MPS up front; otherwise leave the model as loaded.
model = model.to(device) if device.type == "mps" else model

print("Model loaded!")

In [None]:
# Load training data. The file is expected to hold a JSON list of records with
# "prompt" and "response" fields (format_conversation reads those keys).
# Decode explicitly as UTF-8 so non-ASCII text loads identically on every
# platform instead of depending on the OS locale's default encoding.
with open(dataset_path, "r", encoding="utf-8") as f:
    dataset_raw = json.load(f)

print(f"Loaded {len(dataset_raw)} examples")

# Render each example in Qwen's ChatML format (<|im_start|> / <|im_end|> markers)
# using the tokenizer's own chat template. If the base model changes, this must
# follow whatever template that model was trained with.
def format_conversation(example):
    """Return one prompt/response pair rendered through the tokenizer's chat template.

    `example` must provide "prompt" (the user turn) and "response" (the
    assistant turn). The result is a plain string (tokenize=False) with no
    trailing generation prompt, because the assistant answer is already included.
    """
    user_turn = dict(role="user", content=example["prompt"])
    assistant_turn = dict(role="assistant", content=example["response"])
    return tokenizer.apply_chat_template(
        [user_turn, assistant_turn],
        tokenize=False,
        add_generation_prompt=False,
    )

# Wrap every rendered conversation in a one-column ("text") record and build the dataset.
formatted_data = [dict(text=rendered) for rendered in map(format_conversation, dataset_raw)]

train_dataset = Dataset.from_list(formatted_data)
print(f"Formatted {len(train_dataset)} examples correctly.")

In [None]:
# Wrap the loaded causal LM with LoRA adapters, using rank/alpha/targets from the config cell.
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    target_modules=target_modules,
    r=lora_r,
    lora_alpha=lora_alpha,
    lora_dropout=0.1,
    bias="none",  # bias terms are not trained
)

model = get_peft_model(model, lora_config)
# Report trainable (adapter) vs. total parameter counts.
model.print_trainable_parameters()

In [None]:
# Tokenize data
def tokenize_function(examples, max_length=512):
    """Tokenize a batch of rendered conversations for causal-LM training.

    Args:
        examples: a batched `datasets` slice whose "text" column holds the
            chat-template strings.
        max_length: longest sequence kept; longer texts are truncated.

    Returns:
        The tokenizer's batch encoding as plain lists (typically input_ids and
        attention_mask).

    Padding and "labels" are deliberately left out here. The
    DataCollatorForLanguageModeling(mlm=False) used by the Trainer does both:
    - it pads each batch dynamically to its longest example;
    - it builds labels from input_ids, with padding positions set to -100 so
      the loss ignores them.
    Padding every row to max_length only wastes compute. Any labels produced
    here would be overwritten by the collator anyway.
    """
    return tokenizer(
        examples["text"],
        truncation=True,
        max_length=max_length,
    )

# Tokenize the dataset in batches; drop the raw "text" column so only token columns remain.
tokenized_dataset = train_dataset.map(
    tokenize_function,
    remove_columns=train_dataset.column_names,
    batched=True,
)

# Causal-LM collator (mlm=False): pads each batch and derives labels from input_ids.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

print("Data tokenized!")

In [None]:
# Training: per-device batch of 4 with 8 gradient-accumulation steps per optimizer update.
training_args = TrainingArguments(
    output_dir=output_dir,
    # optimisation schedule
    num_train_epochs=3,
    learning_rate=2e-4,
    warmup_ratio=0.1,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,
    # logging / checkpointing
    logging_steps=5,
    save_steps=50,
    report_to=[],  # no external experiment trackers
    # mixed precision (fp16 / bf16) disabled
    fp16=False,
    bf16=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=tokenized_dataset,
)

print("Starting training...")
trainer.train()

# Save the trained model (with a PEFT-wrapped model this is typically the adapter)
# and the tokenizer into the same directory so it can be reloaded on its own.
print("Saving model...")
trainer.save_model(output_dir)
tokenizer.save_pretrained(output_dir)
print("Done!")

In [None]:
# Quick test: sample a reply to one question. The prompt is built with the same
# chat template the training data was rendered with (format_conversation).
# The previous hand-written "<|user|> ... <|end|>" prompt did not match the
# format the model was tuned on.
messages = [{"role": "user", "content": "What is 2+9"}]
# add_generation_prompt=True appends the assistant-turn header so the model answers.
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
# The template already contains its special tokens, so don't let the tokenizer add more.
# Send inputs to wherever the model's weights actually live (Trainer may have moved them).
inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(
    next(model.parameters()).device
)
model.eval()
with torch.no_grad():
    gen = model.generate(**inputs, max_new_tokens=64, do_sample=True, top_p=0.9)
print("Sample output:")
print(tokenizer.decode(gen[0], skip_special_tokens=False))