In [None]:


import os
import json
import random
import torch

from datasets import Dataset, DatasetDict
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
)
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig

# ========= CONFIG =========
# Hugging Face Hub id of the base checkpoint that gets fine-tuned.
MODEL_ID = "google/gemma-3-270m-it"

# Q&A source files; load_all_examples() accepts .jsonl and .json paths.
DATA_FILES = ["/content/exercise_q&a.jsonl"]

# Share of the shuffled examples held out as the validation split.
VAL_FRACTION = 0.1


# ========= DATA LOADING =========
def read_jsonl(path):
    """Parse a JSON Lines file into a list of decoded values.

    Every line is stripped of surrounding whitespace; lines that are empty
    after stripping are ignored, and each remaining line is decoded with
    ``json.loads``.

    Args:
        path: Location of the UTF-8 encoded ``.jsonl`` file.

    Returns:
        list: The decoded values, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(path, "r", encoding="utf-8") as handle:
        stripped = (raw.strip() for raw in handle)
        return [json.loads(entry) for entry in stripped if entry]


def read_json(path):
    """Load a JSON document and return the list of records it holds.

    A top-level list is returned unchanged. For a top-level object, the
    values are scanned in order and the first value that is a non-empty
    list whose first element is a dict is returned.

    Args:
        path: Location of the UTF-8 encoded ``.json`` file.

    Returns:
        list: The top-level list, or the first qualifying nested list.

    Raises:
        ValueError: If the document is neither a list nor an object
            containing a qualifying list.
    """
    with open(path, "r", encoding="utf-8") as handle:
        payload = json.load(handle)

    if isinstance(payload, list):
        return payload

    if isinstance(payload, dict):
        candidates = (
            value
            for value in payload.values()
            if isinstance(value, list) and value and isinstance(value[0], dict)
        )
        # Qualifying lists are non-empty, so None can only mean "no match".
        match = next(candidates, None)
        if match is not None:
            return match

    raise ValueError(f"Don't know how to interpret JSON structure in {path}")


def load_all_examples():
    """Collect question/answer pairs from every file listed in ``DATA_FILES``.

    Files that do not exist, have an extension other than ``.jsonl`` /
    ``.json``, or fail to parse are reported with a warning and skipped.
    Each parsed row must be a dict carrying either ``question``/``answer``
    or ``input``/``output`` keys; the latter pair is renamed to the former.
    Rows of any other shape are skipped and counted, and the count is
    printed per file so dropped data is visible.

    Returns:
        list[dict]: Rows of the form ``{"question": ..., "answer": ...}``.

    Raises:
        RuntimeError: If no usable row was found in any file.
    """
    all_rows = []
    for path in DATA_FILES:
        if not os.path.exists(path):
            print(f"Warning: {path} does not exist, skipping.")
            continue

        ext = os.path.splitext(path)[1].lower()
        try:
            if ext == ".jsonl":
                rows = read_jsonl(path)
            elif ext == ".json":
                rows = read_json(path)
            else:
                print(f"Warning: unsupported extension {ext} for {path}, skipping.")
                continue
        except Exception as e:
            print(f"Warning: failed to read {path}: {e}")
            continue

        skipped = 0
        for r in rows:
            # Non-dict rows (strings, numbers, lists) would otherwise raise
            # or be substring-matched by the `in` checks below.
            if not isinstance(r, dict):
                skipped += 1
            elif "question" in r and "answer" in r:
                all_rows.append({"question": r["question"], "answer": r["answer"]})
            elif "input" in r and "output" in r:
                all_rows.append({"question": r["input"], "answer": r["output"]})
            else:
                skipped += 1

        if skipped:
            print(
                f"Warning: skipped {skipped} of {len(rows)} rows in {path} "
                "without question/answer or input/output fields."
            )

    if not all_rows:
        raise RuntimeError("No examples loaded. Check DATA_FILES paths & format.")

    return all_rows


def make_dataset(seed=42):
    """Build the train/validation ``DatasetDict`` from the loaded examples.

    The rows are shuffled with a private ``random.Random(seed)`` instance so
    the split is reproducible across runs and the global ``random`` state is
    left untouched. ``VAL_FRACTION`` of the rows (at least one, and never all
    of them) go to the ``"val"`` split; the rest go to ``"train"``.

    Args:
        seed: Seed for the shuffle. Pass ``None`` to get a different split
            on every call.

    Returns:
        DatasetDict: With keys ``"train"`` and ``"val"``.

    Raises:
        RuntimeError: If fewer than two examples are available, since both
            splits need at least one row.
    """
    rows = load_all_examples()

    n_total = len(rows)
    if n_total < 2:
        raise RuntimeError(
            f"Need at least 2 examples to build train and val splits, got {n_total}."
        )

    random.Random(seed).shuffle(rows)

    # At least one validation row, but always leave one row for training.
    n_val = min(n_total - 1, max(1, int(VAL_FRACTION * n_total)))

    return DatasetDict({
        "train": Dataset.from_list(rows[n_val:]),
        "val": Dataset.from_list(rows[:n_val])
    })


# ========= TOKENIZER & MODEL =========
def make_tokenizer():
    """Load the tokenizer for ``MODEL_ID`` and set it up for right padding.

    If the loaded tokenizer has no pad token, its EOS token is assigned as
    the pad token.

    Returns:
        The tokenizer object returned by ``AutoTokenizer.from_pretrained``.
    """
    tok = AutoTokenizer.from_pretrained(MODEL_ID)
    tok.padding_side = "right"
    if tok.pad_token is None:
        # Reuse EOS so batches can still be padded.
        tok.pad_token = tok.eos_token
    return tok


def make_model(device: str):
    """Load the ``MODEL_ID`` causal language model and move it to ``device``.

    The weights are requested in ``torch.bfloat16`` when ``device`` is
    ``"mps"`` and in ``torch.float32`` for every other device string.

    Args:
        device: Target device name passed to ``model.to``.

    Returns:
        The loaded model, already moved to ``device``.
    """
    if device == "mps":
        weights_dtype = torch.bfloat16
    else:
        weights_dtype = torch.float32

    causal_lm = AutoModelForCausalLM.from_pretrained(MODEL_ID, dtype=weights_dtype)
    causal_lm.to(device)
    return causal_lm


# ========= PROMPT FORMATTING =========
def format_example(example):
    """Render one example as the plain-text training string.

    Args:
        example: Mapping with ``"question"`` and ``"answer"`` entries.

    Returns:
        str: ``"Question: <question>"``, a blank line, then
        ``"Answer: <answer>"``.
    """
    question_part = f"Question: {example['question']}\n\n"
    answer_part = f"Answer: {example['answer']}"
    return question_part + answer_part


def formatting_func(example):
    """Formatting hook handed to the trainer; delegates to ``format_example``.

    Args:
        example: A dataset row with ``"question"`` and ``"answer"`` entries.

    Returns:
        str: The text produced by ``format_example`` for this row.
    """
    rendered = format_example(example)
    return rendered


# ========= MAIN =========
def main():
    """Fine-tune ``MODEL_ID`` with LoRA on the Q&A data and save the adapter.

    Picks a device (MPS, then CUDA, then CPU), builds the dataset, tokenizer
    and model, trains with ``SFTTrainer`` for two epochs with per-epoch
    evaluation and checkpointing, then writes the trained adapter and the
    tokenizer files to ``lora-gemma3-270m-it-adapter``.
    """
    if torch.backends.mps.is_available():
        device = "mps"
    elif torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    print("Using device:", device)

    dataset = make_dataset()
    tokenizer = make_tokenizer()
    model = make_model(device)

    # LoRA adapters on the attention and MLP projection layers.
    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
    )

    sft_config = SFTConfig(
        output_dir="lora-gemma3-270m-it",
        do_train=True,
        do_eval=True,
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        # Effective batch size is 1 * 8 = 8 sequences per optimizer step.
        gradient_accumulation_steps=8,
        num_train_epochs=2.0,
        learning_rate=2e-4,
        weight_decay=0.01,
        warmup_ratio=0.03,
        lr_scheduler_type="cosine",
        logging_steps=10,
        save_strategy="epoch",
        eval_strategy="epoch",
        save_total_limit=2,
        # The string "none" disables experiment trackers. Passing None is
        # resolved to "all installed integrations" by Transformers 4.x, which
        # can start e.g. a wandb login prompt.
        report_to="none",
        max_length=512,
        packing=False,
        # No `use_mps_device` here: that TrainingArguments flag is deprecated
        # and the trainer selects MPS automatically when it is available.
        fp16=False,
        bf16=False,
        remove_unused_columns=False,
    )

    trainer = SFTTrainer(
        model=model,
        args=sft_config,
        peft_config=lora_config,
        train_dataset=dataset["train"],
        eval_dataset=dataset["val"],
        formatting_func=formatting_func,
        processing_class=tokenizer,
    )

    trainer.train()

    # Save the adapter and tokenizer side by side in one directory.
    save_dir = "lora-gemma3-270m-it-adapter"
    trainer.model.save_pretrained(save_dir)
    tokenizer.save_pretrained(save_dir)
    print(f"Saved LoRA adapter to {save_dir}")


# Start training when this cell/script is executed directly (not when the
# code is imported as a module).
if __name__ == "__main__":
    main()


In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import torch


# Directory written by the training cell (adapter weights + tokenizer files).
lora_path = "lora-gemma3-270m-it-adapter"
MODEL_ID = "google/gemma-3-270m-it"

# Pick the accelerator in the same order as training (MPS, CUDA, CPU) so that
# inference on Apple Silicon does not silently fall back to the CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Load base model
base_model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
base_model.to(device)

# Attach the saved LoRA adapter to the base model (weights are not merged).
model = PeftModel.from_pretrained(base_model, lora_path)
model.to(device)
# Trailing semicolon keeps the notebook from printing the full module repr.
model.eval();


In [None]:

# The prompt has to follow the template used for training in format_example()
# ("Question: ...\n\nAnswer: ..."); different casing or a single newline puts
# the input outside the format the adapter was tuned on.
question = "Can you recommend me a full body  HIIT workout?"
prompt = f"Question: {question}\n\nAnswer:"

inputs = tokenizer(prompt, return_tensors="pt").to(device)

# Nucleus sampling; the decoded text below includes the prompt itself.
output = model.generate(
    **inputs,
    max_new_tokens=200,
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
    pad_token_id=tokenizer.eos_token_id,
)

print(tokenizer.decode(output[0], skip_special_tokens=True))
