In [None]:
# Install Unsloth and the training stack. Colab needs a hand-picked package
# set, so branch on whether any environment variable name contains "COLAB_".
import os, re
if "COLAB_" not in "".join(os.environ.keys()):
    # Not Colab: a regular install that lets pip resolve dependencies.
    !pip install unsloth
else:
    # Do this only in Colab notebooks! Otherwise use pip install unsloth
    # Keep only the leading run of digits and dots from the torch version
    # string (anything after it is dropped) and use it to pick the xformers pin.
    import torch; v = re.match(r"[0-9\.]{3,}", str(torch.__version__)).group(0)
    xformers = "xformers==" + ("0.0.32.post2" if v == "2.8.0" else "0.0.29.post3")
    # --no-deps installs exactly the listed packages without their dependencies.
    !pip install --no-deps bitsandbytes accelerate {xformers} peft trl triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf "datasets>=3.4.1,<4.0.0" "huggingface_hub>=0.34.0" hf_transfer
    !pip install --no-deps unsloth
# Both branches: pin transformers and trl to fixed versions.
!pip install transformers==4.56.2
!pip install --no-deps trl==0.22.2
# Set TorchDynamo's recompile limit to 64.
import torch; torch._dynamo.config.recompile_limit = 64;

In [None]:
# Upgrade timm to its latest release; --no-deps leaves its dependencies
# (including the versions pinned in the install cell) untouched.
# NOTE(review): presumably needed by the Gemma 3n model code — confirm.
!pip install --no-deps --upgrade timm

In [None]:
import os
import yaml
import logging
from datetime import datetime

# YAML config
# The file is looked up in the working directory. A bare file name is used
# because a ".\" prefix is only a path separator on Windows; on Linux/Colab
# it would be treated as part of the file name. Any failure (missing file,
# invalid YAML) propagates unchanged: nothing below can run without `config`.
with open("config.yaml", "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

# Logger
# One log file per run, named after the start time and placed inside
# config["log_dir"]. The path is joined (not concatenated) so a log_dir
# without a trailing separator still resolves to a file inside the directory,
# and the directory is created up front because the file handler will not
# create it.
log_dir = config["log_dir"]
if log_dir:
    os.makedirs(log_dir, exist_ok=True)
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
log_file = os.path.join(log_dir, f"{timestamp}.log")

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(name)s - %(funcName)s - %(message)s",
    filename=log_file,
    filemode="w",
    # basicConfig silently does nothing when the root logger already has
    # handlers (re-running this cell, or a host that pre-configures logging).
    # force=True replaces them so the log file is always the one set here.
    force=True,
)
logger = logging.getLogger(__name__)

logger.info("Config file and logger setup completed.")

In [None]:
from unsloth import FastModel
import torch

# Load the pretrained checkpoint and its tokenizer. Failures are logged with
# the full traceback and then re-raised, so the notebook stops here instead
# of continuing without `model` / `tokenizer`.
try:
    model, tokenizer = FastModel.from_pretrained(
        model_name="unsloth/gemma-3n-E4B-it",
        dtype=None,  # None for auto detection
        max_seq_length=config["max_seq_length"],  # Choose any for long context
        load_in_4bit=True,  # 4 bit quantization to reduce memory
        full_finetuning=False,  # False: full finetuning is not enabled
    )
except Exception as e:
    # logger.exception logs at ERROR level and appends the traceback.
    logger.exception("Model and tokenizer creation failed: %s", e)
    raise
else:
    logger.info("Model and tokenizer created.")

In [None]:
# Which parts of the network receive trainable adapters.
layer_selection = dict(
    finetune_vision_layers=False,  # Turn off for just text
    finetune_language_layers=True,
    finetune_attention_modules=True,
    finetune_mlp_modules=True,
)

# LoRA hyper-parameters; rank and alpha come from the YAML config.
lora_options = dict(
    r=config["r"],  # Larger = higher accuracy, but might overfit
    lora_alpha=config["lora_alpha"],  # alpha == r
    lora_dropout=0,
    bias="none",
    random_state=42,
)

# Rebind `model` to the PEFT-wrapped model returned by Unsloth.
model = FastModel.get_peft_model(model, **layer_selection, **lora_options)

In [None]:
from unsloth.chat_templates import get_chat_template

# Rebind `tokenizer` to the result of applying the "gemma-3" chat template.
tokenizer = get_chat_template(tokenizer, chat_template="gemma-3")

In [None]:
from datasets import load_dataset

# Load the first 10,000 rows of the "train" split from the dataset named in
# the config. Failures are logged with the full traceback and re-raised so
# later cells never run without `dataset`.
try:
    dataset = load_dataset(config["dataset_path"], split="train[:10000]")
except Exception as e:
    # logger.exception logs at ERROR level and appends the traceback.
    logger.exception("Dataset loading failed: %s", e)
    raise
else:
    logger.info("Dataset successfully loaded.")

In [None]:
from unsloth.chat_templates import standardize_data_formats
# Rebind `dataset` to the standardised version. Presumably this normalises the
# records into the "conversations" layout that formatting_prompts_func reads —
# confirm against the Unsloth documentation.
dataset = standardize_data_formats(dataset)

In [None]:
def formatting_prompts_func(examples):
    """Render every conversation of a batch into one chat-template string.

    Args:
        examples: Batch mapping with a "conversations" entry holding one
            conversation per example.

    Returns:
        A dict with a single "text" key: the list of rendered conversations,
        in input order, each with a leading "<bos>" removed if present.

    Relies on the notebook-level `tokenizer` for `apply_chat_template`.
    """
    rendered = []
    for conversation in examples["conversations"]:
        prompt = tokenizer.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=False,
        )
        # Strip the template's leading "<bos>" from the stored text.
        rendered.append(prompt.removeprefix("<bos>"))
    return {"text": rendered}


# Run formatting_prompts_func over the dataset in batches (batched=True) and
# rebind `dataset` to the mapped result; the function returns a "text" field
# for each batch.
dataset = dataset.map(formatting_prompts_func, batched=True)

In [None]:
from trl import SFTTrainer, SFTConfig

# Hyper-parameters for the supervised fine-tuning run. Output directory and
# step budget come from the YAML config; the rest is fixed here.
training_args = SFTConfig(
    output_dir=config["train_output_dir"],
    dataset_text_field="text",
    # Batch of 1 with 4 accumulation steps: GA is used to mimic a larger batch.
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    warmup_steps=5,
    # num_train_epochs = 1, # Set this for 1 full training run
    max_steps=config["max_steps"],
    learning_rate=2e-4,  # Reduce to 2e-5 for long training runs
    logging_steps=1,
    optim="adamw_8bit",
    weight_decay=0.01,
    lr_scheduler_type="linear",
    seed=42,
    report_to="none",
    # Step-based checkpointing, every 10 steps.
    save_strategy="steps",
    save_steps=10,
)

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    eval_dataset=None,  # Can set up evaluation
    args=training_args,
)

In [None]:
# # Use train_on_responses_only to compute the loss only on the assistant outputs and ignore the loss on the user's inputs.
# from unsloth.chat_templates import train_on_responses_only
# trainer = train_on_responses_only(
#     trainer,
#     instruction_part = "<start_of_turn>user\n",
#     response_part = "<start_of_turn>model\n",
# )

In [None]:
# Run fine-tuning. The start is logged before the call and the completion
# after it, so the log timestamps bracket the actual training time. Failures
# are logged with the full traceback and re-raised.
logger.info("Training started.")
try:
    trainer_stats = trainer.train()
except Exception as e:
    # logger.exception logs at ERROR level and appends the traceback.
    logger.exception("Training failed: %s", e)
    raise
else:
    logger.info("Training finished.")

In [None]:
from transformers import TextStreamer

query = "query"

# Single-turn chat holding one user message with a text part.
user_turn = {"role": "user", "content": [{"type": "text", "text": query}]}
messages = [user_turn]

# Tokenise through the chat template and move the tensors to the GPU.
inputs = tokenizer.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,  # Must add for generation
).to("cuda")

# Stream the output as it is generated; skip_prompt=True leaves the prompt out.
streamer = TextStreamer(tokenizer, skip_prompt=True)
_ = model.generate(
    **inputs,
    streamer=streamer,
    max_new_tokens=64,  # Increase for longer outputs
    temperature=1.0,
    top_p=0.95,
    top_k=64,
)

In [None]:
# Saves the LoRA adapters.
# Disabled by default: change the guard to True to write the model and
# tokenizer files to the location named by config["saved_model_name"].
if False:
    model.save_pretrained(config["saved_model_name"])
    tokenizer.save_pretrained(config["saved_model_name"])

# Load
# Disabled by default: change the guard to True to reload what was saved above.
if False:
    from unsloth import FastModel
    model, tokenizer = FastModel.from_pretrained(
        model_name=config["saved_model_name"],
        # Same context length as the initial load, so the reloaded model is
        # not configured differently from the one that was trained.
        max_seq_length=config["max_seq_length"],
        load_in_4bit=True,
    )

In [None]:
# GGUF export. Disabled by default: change the guard to True to run it.
# The log call passes the name as a logging argument rather than nesting
# double quotes inside an f-string, which is a syntax error before Python 3.12
# and would stop this cell from parsing even while the guard is False.
if False:
    try:
        model.save_pretrained_gguf(
            config["saved_model_name"],
            tokenizer,
            quantization_method="Q8_0",  # For now only Q8_0, BF16, F16 supported
        )
    except Exception as e:
        # logger.exception logs at ERROR level and appends the traceback.
        logger.exception("Saving model failed: %s", e)
        raise
    else:
        logger.info("Model saved to %s", config["saved_model_name"])