In [None]:
import huggingface_hub
hf_token = '' # put your User Access Tokens here
# ابتدا login کنید
huggingface_hub.login(token=hf_token)

# سپس وضعیت ورود را بررسی کنید
!hf auth whoami

[1muser: [0m AM-Nateghi


In [3]:
from transformers import AutoModelForCausalLM, AutoTokenizer, DataCollatorWithPadding
from peft import prepare_model_for_kbit_training
import torch

cptk = "google/gemma-3-1b-pt"
from transformers import BitsAndBytesConfig
tokenizer = AutoTokenizer.from_pretrained(cptk)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)
model = AutoModelForCausalLM.from_pretrained(
    cptk,
    quantization_config=bnb_config,
    device_map="auto",
    dtype=torch.float16,
)
model = prepare_model_for_kbit_training(model)

model.config.use_cache = False # for Gradient checkpointing
model.gradient_checkpointing_enable()

In [4]:
from peft import LoraConfig, get_peft_model, TaskType

lora_conf = LoraConfig(
    r=16,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.1,
    bias='none',
    task_type=TaskType.CAUSAL_LM
)
model = get_peft_model(model, lora_conf)

In [5]:
from datasets import load_dataset

dataset = load_dataset(
    "json", data_files={"train": "assets/qa_train.json", "test": "assets/qa_test.json"}
)


def formatter(batch):
    tokened = tokenizer(batch["input"], truncation=True, max_length=1024, padding="max_length")
    tokened["labels"] = tokenizer(
        batch["output"], truncation=True, max_length=1024, padding="max_length"
    )["input_ids"]

    return tokened


tokenized_dataset = dataset.map(formatter, batched=True)

Map: 100%|██████████| 893/893 [00:00<00:00, 1105.40 examples/s]
Map: 100%|██████████| 48/48 [00:00<00:00, 847.26 examples/s]


In [9]:
torch.cuda.empty_cache()

In [7]:
# اگر مدل فاقد pad_token است، آن را به eos-tokenنسبت دهید:
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

data_collator = DataCollatorWithPadding(tokenizer=tokenizer, padding=True, return_tensors="pt")

In [8]:
from trl import SFTTrainer
from transformers import TrainingArguments, EarlyStoppingCallback

callbacks = [EarlyStoppingCallback(early_stopping_patience=8)]

training_args = TrainingArguments(
    output_dir="./gemma_qlora",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    num_train_epochs=3,
    learning_rate=2e-4,
    fp16=True,
    logging_steps=5,
    eval_steps=50,
    save_steps=50,
    save_total_limit=2,
    metric_for_best_model="loss",
    save_strategy='steps',
    eval_strategy="steps",
    gradient_checkpointing=True,
    load_best_model_at_end=True,
)

trainer = SFTTrainer(
    model=model,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["test"],
    args=training_args,
    callbacks=callbacks,
    data_collator=data_collator,
)
trainer.train()

Truncating train dataset: 100%|██████████| 893/893 [00:00<00:00, 3520.56 examples/s]
Truncating eval dataset: 100%|██████████| 48/48 [00:00<00:00, 7550.50 examples/s]
It is strongly recommended to train Gemma3 models with the `eager` attention implementation instead of `sdpa`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`.


Step,Training Loss,Validation Loss,Entropy,Num Tokens,Mean Token Accuracy
50,8.592,8.59368,8.605148,204800.0,0.003605
100,7.6398,7.497564,8.122779,409600.0,0.003768
150,6.9972,6.932827,8.382818,614400.0,0.884918
200,6.4841,6.545558,8.340026,819200.0,0.884897
250,6.4078,6.42226,8.184191,1020928.0,0.004114
300,6.282,6.259655,8.315868,1225728.0,0.885529
350,6.0789,6.15626,8.287569,1430528.0,0.885427
400,6.1869,6.082537,8.276319,1635328.0,0.88563
450,6.0841,6.033286,8.233993,1837056.0,0.004623
500,5.9675,6.007616,8.26259,2041856.0,0.885569


TrainOutput(global_step=672, training_loss=6.764415201686678, metrics={'train_runtime': 1644.686, 'train_samples_per_second': 1.629, 'train_steps_per_second': 0.409, 'total_flos': 1.1511753488990208e+16, 'train_loss': 6.764415201686678, 'epoch': 3.0})