In [None]:
# Import the modules first
import torch
from datasets import load_dataset, concatenate_datasets
from sklearn.model_selection import train_test_split
from unsloth import FastLanguageModel

In [None]:
# Loader configuration shared with the trainer cell further down.
max_seq_length = 4096  # maximum sequence length (tokens) given to the loader and the trainer
load_in_4bit = False   # load the weights without 4-bit quantization


# Load the base model and tokenizer in bfloat16.
#
# full_finetuning is False because the following cell attaches LoRA adapters with
# FastLanguageModel.get_peft_model; Unsloth treats full fine-tuning and LoRA as
# alternative modes, so requesting both is contradictory.
# NOTE(review): with full_finetuning=True Unsloth appears to skip get_peft_model
# (the LoRA settings would be silently ignored) — confirm against the installed
# version. If full fine-tuning is what you want, set this back to True and remove
# the get_peft_model cell instead.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="baidu/ERNIE-4.5-0.3B-Base-PT",
    max_seq_length=max_seq_length,
    load_in_4bit=load_in_4bit,
    dtype=torch.bfloat16,
    full_finetuning=False,
)

In [None]:
# Names of the projection layers that receive LoRA adapters.
lora_target_modules = [
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj",
    "gate_proj",
    "up_proj",
    "down_proj",
]

# Wrap the loaded model with LoRA adapters (rank 16, alpha 64, no bias terms,
# Unsloth gradient checkpointing, no rank-stabilized LoRA, no LoftQ).
model = FastLanguageModel.get_peft_model(
    model,
    target_modules=lora_target_modules,
    r=16,
    lora_alpha=64,
    bias="none",
    use_rslora=False,
    loftq_config=None,
    use_gradient_checkpointing="unsloth",
)

In [None]:
# Hub identifiers of the three source datasets; only the "train" split is used.
dataset_names = [
    "TeichAI/gemini-3-flash-preview-1000x",
    "TeichAI/claude-haiku-4.5-1700x",
    "TeichAI/gemini-2.5-flash-lite-2509-preview-1000x",
]

# Load them in order, keeping the individual handles available by name.
dataset_1, dataset_2, dataset_3 = [
    load_dataset(name, split="train") for name in dataset_names
]

# Stack the three datasets into a single training corpus.
dataset = concatenate_datasets([dataset_1, dataset_2, dataset_3])


# Prompt layout used for every training example: system, user and assistant
# turns, each wrapped in <|im_start|>role ... <|im_end|> markers.
chat_template = """<|im_start|>system
{}<|im_end|>
<|im_start|>user
{}<|im_end|>
<|im_start|>assistant
{}<|im_end|>"""

# Fixed system prompt inserted into every formatted example.
SYSTEM_PROMPT = "You are a helpful assistant that provides detailed and accurate responses based on the user's input."


def _extract_turn(messages):
    """Return (user_text, assistant_text) for the first user/assistant exchange.

    Scans `messages` for the first entry whose "role" is "user" and then the
    first later entry whose "role" is "assistant". Returns None when no such
    pair exists (for example when the entries carry no "role" key).
    """
    user_text = None
    for message in messages:
        role = message.get("role")
        if user_text is None:
            if role == "user":
                user_text = message["content"]
        elif role == "assistant":
            return user_text, message["content"]
    return None


def apply_chat_template(row):
    """Format one dataset row into a single training string.

    Args:
        row: mapping with a "messages" list of dicts, each holding "content"
            and, normally, "role".

    Returns:
        {"text": <formatted prompt>} with the fixed system prompt, the user
        message and the assistant reply filled into `chat_template`.

    The turns are selected by role, so a leading system message (or any other
    extra entry) no longer shifts the user/assistant texts into the wrong
    slots. When no user -> assistant pair can be found by role, the original
    positional layout (index 0 = user, index 1 = assistant) is used.
    """
    messages = row["messages"]

    turn = _extract_turn(messages)
    if turn is None:
        # No usable role information: keep the positional behaviour.
        turn = (messages[0]["content"], messages[1]["content"])
    input_text, target_text = turn

    return {"text": chat_template.format(SYSTEM_PROMPT, input_text, target_text)}


# Render every row with apply_chat_template and drop the original columns,
# leaving only the columns the mapping function returns.
raw_columns = dataset.column_names
dataset = dataset.map(apply_chat_template, remove_columns=raw_columns)

In [None]:
# Hold out 10% of the formatted examples for evaluation (fixed seed).
# The result is bound to a new name instead of overwriting `dataset`: the split
# returns a DatasetDict, which has no train_test_split method, so rebinding
# `dataset` would make this cell fail when it is run a second time.
dataset_splits = dataset.train_test_split(test_size=0.1, seed=42)
train_dataset = dataset_splits["train"]
eval_dataset = dataset_splits["test"]

In [None]:
from trl import SFTConfig, SFTTrainer
from transformers import EarlyStoppingCallback

# Training hyperparameters. Checkpoints and evaluations both happen every 50
# steps, the 3 most recent checkpoints are kept, and the checkpoint with the
# lowest eval_loss is reloaded when training ends.
# NOTE(review): warmup_steps=100 and eval_steps=50 are absolute step counts; with
# an effective batch of 32 * 2 packed sequences per device, confirm the run has
# enough optimizer steps for the warmup to finish and for early stopping
# (patience 5 evaluations) to be able to trigger.
training_args = SFTConfig(
    # Output / logging
    output_dir="outputs",
    report_to="none",
    logging_steps=10,
    # Batch size and schedule
    per_device_train_batch_size=32,
    gradient_accumulation_steps=2,
    num_train_epochs=5,
    warmup_steps=100,
    # Optimizer
    optim="adamw_8bit",
    learning_rate=5e-5,
    weight_decay=0.01,
    lr_scheduler_type="cosine",
    # Evaluation and checkpointing
    eval_strategy="steps",
    eval_steps=50,
    save_strategy="steps",
    save_steps=50,
    save_total_limit=3,
    # Best-checkpoint selection
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
)

# Stop once the monitored metric has not improved for 5 consecutive evaluations.
early_stopping = EarlyStoppingCallback(early_stopping_patience=5)

# Supervised fine-tuning on the "text" column, with examples packed into
# sequences of up to max_seq_length tokens.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    packing=True,
    args=training_args,
    callbacks=[early_stopping],
)

In [None]:
# Report the GPU at index 0 and the peak memory reserved so far, in GB
# (1 GB = 1024**3 bytes here), rounded to three decimals.
bytes_per_gb = 1024 ** 3

gpu_stats = torch.cuda.get_device_properties(0)
max_memory = round(gpu_stats.total_memory / bytes_per_gb, 3)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / bytes_per_gb, 3)

print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
print(f"{start_gpu_memory} GB of memory reserved.")

In [None]:
# Run the fine-tuning loop; the value returned by train() is stored in
# trainer_stats.
trainer_stats = trainer.train()

In [None]:
FastLanguageModel.for_inference(model)

from transformers import TextStreamer

system_prompt = "You are a helpful assistant that provides detailed and accurate responses based on the user's input."
user_prompt = "Make a python script for fizz buzz problem."

# Build the prompt with the same `chat_template` the training text was rendered
# with, so inference sees the format the model was tuned on. Everything before
# the last "{}" placeholder is the system turn, the user turn and the opening of
# the assistant turn. (tokenizer.apply_chat_template would use the tokenizer's
# own built-in template, which is not guaranteed to match the training format.)
generation_prefix = chat_template.rsplit("{}", 1)[0]
prompt = generation_prefix.format(system_prompt, user_prompt)

inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

text_streamer = TextStreamer(tokenizer, skip_prompt=True)
_ = model.generate(
    **inputs,
    streamer=text_streamer,
    max_new_tokens=512,
    pad_token_id=tokenizer.eos_token_id,
    # temperature / top_p only take effect when sampling is enabled.
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
    # Training examples end the assistant turn with "<|im_end|>", so stop there
    # instead of running on until max_new_tokens. stop_strings requires the
    # tokenizer to be passed to generate().
    stop_strings=["<|im_end|>"],
    tokenizer=tokenizer,
)

In [None]:
# Write the model and the tokenizer into the same output directory.
output_dir = "finetuned_model"
for artifact in (model, tokenizer):
    artifact.save_pretrained(output_dir)