In [44]:
!pip install datasets bitsandbytes trl==0.12.1 transformers peft huggingface-hub accelerate safetensors pandas matplotlib numpy==1.26.4



In [45]:
import os
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    #AutoPeftModelForCausalLM, # Removed from transformers
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    pipeline,
    logging,
)
from trl import SFTTrainer, SFTConfig
# from trl.trainer.utils import DataCollatorForCompletionOnlyLM
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training, get_peft_model, AutoPeftModelForCausalLM # Added to peft
from huggingface_hub import notebook_login
from trl import SFTTrainer, SFTConfig, setup_chat_format, DataCollatorForCompletionOnlyLM

In [46]:
support = torch.cuda.is_bf16_supported(including_emulation=False)
calculate_dtype = torch.bfloat16 if support else torch.float32
bnb_config = BitsAndBytesConfig(
    load_in_4bit = True,
    bnb_4bit_quant_type= "nf4",
    bnb_4bit_compute_dtype= calculate_dtype, #calculate_dtype can be bf16 or float32- use bf16 if supported
    bnb_4bit_use_double_quant= True
    )
repo = "google/gemma-3-270m"
model = AutoModelForCausalLM.from_pretrained(repo, quantization_config= bnb_config, device_map= "cuda:0")

In [47]:
print(model.get_memory_footprint()/1024/1024)

367.92016792297363


In [48]:
model

Gemma3ForCausalLM(
  (model): Gemma3TextModel(
    (embed_tokens): Gemma3TextScaledWordEmbedding(262144, 640, padding_idx=0)
    (layers): ModuleList(
      (0-17): 18 x Gemma3DecoderLayer(
        (self_attn): Gemma3Attention(
          (q_proj): Linear4bit(in_features=640, out_features=1024, bias=False)
          (k_proj): Linear4bit(in_features=640, out_features=256, bias=False)
          (v_proj): Linear4bit(in_features=640, out_features=256, bias=False)
          (o_proj): Linear4bit(in_features=1024, out_features=640, bias=False)
          (q_norm): Gemma3RMSNorm((256,), eps=1e-06)
          (k_norm): Gemma3RMSNorm((256,), eps=1e-06)
        )
        (mlp): Gemma3MLP(
          (gate_proj): Linear4bit(in_features=640, out_features=2048, bias=False)
          (up_proj): Linear4bit(in_features=640, out_features=2048, bias=False)
          (down_proj): Linear4bit(in_features=2048, out_features=640, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm

In [49]:
model = prepare_model_for_kbit_training(model)

config = LoraConfig(
    r = 8, #. rank of LoRA - [4-16]
    bias = "none", # ["all", "lora_only"] - for train bias term
    lora_alpha = 16, # scalling factor
    lora_dropout = 0.10, # prevent overfit- used for regularisation
    target_modules=["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"],
    task_type = "CAUSAL_LM"

)

model = get_peft_model(model, config)
model

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): Gemma3ForCausalLM(
      (model): Gemma3TextModel(
        (embed_tokens): Gemma3TextScaledWordEmbedding(262144, 640, padding_idx=0)
        (layers): ModuleList(
          (0-17): 18 x Gemma3DecoderLayer(
            (self_attn): Gemma3Attention(
              (q_proj): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=640, out_features=1024, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.1, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=640, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=1024, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
      

In [50]:
print(model.get_memory_footprint()/1024/1024)

695.2690448760986


In [51]:
print(model.get_base_model)

<bound method PeftModel.get_base_model of PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): Gemma3ForCausalLM(
      (model): Gemma3TextModel(
        (embed_tokens): Gemma3TextScaledWordEmbedding(262144, 640, padding_idx=0)
        (layers): ModuleList(
          (0-17): 18 x Gemma3DecoderLayer(
            (self_attn): Gemma3Attention(
              (q_proj): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=640, out_features=1024, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.1, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=640, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=1024, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (l

In [52]:
print(model.get_memory_footprint()/1e6)

729.042434


In [53]:
trainable_params, total_params = model.get_nb_trainable_parameters()
percentage = (trainable_params / total_params) * 100

print(f"Trainable Parameters: {trainable_params:,}")
print(f"Total Parameters: {total_params:,}")
print(f"Percentage Trainable: {percentage:.2f}%")

Trainable Parameters: 1,898,496
Total Parameters: 269,996,672
Percentage Trainable: 0.70%


In [54]:
dataset = load_dataset("sweatSmile/buddha-taught-qa", split="train")
dataset

Repo card metadata block was not found. Setting CardData to empty.


Dataset({
    features: ['question', 'answer'],
    num_rows: 699
})

In [55]:
dataset[0]

{'question': 'Who is referred to as the Fully-Enlightened One in the text?',
 'answer': 'The Buddha is referred to as the Fully-Enlightened One.'}

In [56]:
dataset = dataset.rename_column("question", "prompt")
dataset = dataset.rename_column("answer", "completion")
dataset

Dataset({
    features: ['prompt', 'completion'],
    num_rows: 699
})

In [57]:
dataset[0]

{'prompt': 'Who is referred to as the Fully-Enlightened One in the text?',
 'completion': 'The Buddha is referred to as the Fully-Enlightened One.'}

In [58]:
messages = [
    {"role": "user", "content": dataset[0]['prompt']},
    {"role": "assistant", "content": dataset[0]['completion']}
]
messages

[{'role': 'user',
  'content': 'Who is referred to as the Fully-Enlightened One in the text?'},
 {'role': 'assistant',
  'content': 'The Buddha is referred to as the Fully-Enlightened One.'}]

In [59]:
tokenizer = AutoTokenizer.from_pretrained(repo)
tokenizer.pad_token = tokenizer.eos_token # Set pad token to be the same as eos token
tokenizer.chat_template = "{% for message in messages %}{% if message['role'] == 'user' %}{{ '<start_of_turn>user\n' + message['content'] + '<end_of_turn>\n' }}{% elif message['role'] == 'system' %}{{ '<start_of_turn>system\n' + message['content'] + '<end_of_turn>\n' }}{% elif message['role'] == 'assistant' %}{{ '<start_of_turn>model\n' + message['content'] + '<end_of_turn>\n' }}{% endif %}{% if loop.last and add_generation_prompt %}{{ '<start_of_turn>model\n' }}{% endif %}{% endfor %}"
tokenizer.chat_template

"{% for message in messages %}{% if message['role'] == 'user' %}{{ '<start_of_turn>user\n' + message['content'] + '<end_of_turn>\n' }}{% elif message['role'] == 'system' %}{{ '<start_of_turn>system\n' + message['content'] + '<end_of_turn>\n' }}{% elif message['role'] == 'assistant' %}{{ '<start_of_turn>model\n' + message['content'] + '<end_of_turn>\n' }}{% endif %}{% if loop.last and add_generation_prompt %}{{ '<start_of_turn>model\n' }}{% endif %}{% endfor %}"

In [60]:
print(tokenizer.apply_chat_template(messages, tokenize=False))

<start_of_turn>user
Who is referred to as the Fully-Enlightened One in the text?<end_of_turn>
<start_of_turn>model
The Buddha is referred to as the Fully-Enlightened One.<end_of_turn>



In [65]:
# min_effective_batch_size = 8
# lr = 3e-4
# max_seq_length = 64
# collator_fn = None
# packing = (collator_fn is None)
# steps = 20
# num_train_epochs = 10

# Optimized parameters for better training stability and performance
min_effective_batch_size = 6  # Slightly reduced from 8 for better gradient updates
lr = 2e-4  # Reduced from 3e-4 for more stable training
max_seq_length = 64  # Keep at 64
collator_fn = None
packing = (collator_fn is None)
steps = 15  # Slightly more frequent than original 20
num_train_epochs = 8  # Reduced from 10 to prevent overfitting
warmup_ratio = 0.05  # Small warmup for learning rate stability

sft_config = SFTConfig(
    output_dir = '/content/drive/MyDrive/google/gemma-3-270m-ada',
    packing = packing,
    max_seq_length = max_seq_length,
    gradient_checkpointing = True,
    gradient_checkpointing_kwargs = {'use_reentrant': False},
    gradient_accumulation_steps = 3,  # Slightly increased
    per_device_train_batch_size = min_effective_batch_size,
    auto_find_batch_size = True,
    num_train_epochs = num_train_epochs,
    learning_rate = lr,
    lr_scheduler_type = "cosine",  # Better than linear decay
    warmup_ratio = warmup_ratio,  # Gradual learning rate warmup
    weight_decay = 0.01,  # Light regularization
    max_grad_norm = 1.0,  # Gradient clipping for stability
    report_to = 'wandb',
    logging_dir = '/content/drive/MyDrive/google/gemma-3-270m/logs',
    logging_strategy = 'steps',
    logging_steps = steps,
    save_strategy = 'steps',
    save_steps = steps,
    save_total_limit = 2,  # Keep only 2 best checkpoints
    fp16 = True,  # Mixed precision for memory efficiency
)

trainer = SFTTrainer(
    model = model,
    train_dataset = dataset,
    processing_class = tokenizer,
    data_collator = collator_fn,
    args = sft_config,
)

trainer.train()

Generating train split: 0 examples [00:00, ? examples/s]



Step,Training Loss
15,2.3855
30,2.296
45,2.12
60,2.121
75,2.0286
90,1.9574
105,1.8023
120,1.8316
135,1.7146
150,1.6971


TrainOutput(global_step=256, training_loss=1.8118571266531944, metrics={'train_runtime': 373.4104, 'train_samples_per_second': 12.019, 'train_steps_per_second': 0.686, 'total_flos': 176172906184704.0, 'train_loss': 1.8118571266531944, 'epoch': 8.0})

In [66]:
trainer.save_model('Gemma-3-270m-4bit')

In [67]:
reloaded_model = AutoPeftModelForCausalLM.from_pretrained('Gemma-3-270m-4bit')
reloaded_model

The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): Gemma3ForCausalLM(
      (model): Gemma3TextModel(
        (embed_tokens): Gemma3TextScaledWordEmbedding(262145, 640, padding_idx=0)
        (layers): ModuleList(
          (0-17): 18 x Gemma3DecoderLayer(
            (self_attn): Gemma3Attention(
              (q_proj): lora.Linear(
                (base_layer): Linear(in_features=640, out_features=1024, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.1, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=640, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=1024, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              

In [68]:
merged_model = reloaded_model.merge_and_unload()
merged_model

Gemma3ForCausalLM(
  (model): Gemma3TextModel(
    (embed_tokens): Gemma3TextScaledWordEmbedding(262145, 640, padding_idx=0)
    (layers): ModuleList(
      (0-17): 18 x Gemma3DecoderLayer(
        (self_attn): Gemma3Attention(
          (q_proj): Linear(in_features=640, out_features=1024, bias=False)
          (k_proj): Linear(in_features=640, out_features=256, bias=False)
          (v_proj): Linear(in_features=640, out_features=256, bias=False)
          (o_proj): Linear(in_features=1024, out_features=640, bias=False)
          (q_norm): Gemma3RMSNorm((256,), eps=1e-06)
          (k_norm): Gemma3RMSNorm((256,), eps=1e-06)
        )
        (mlp): Gemma3MLP(
          (gate_proj): Linear(in_features=640, out_features=2048, bias=False)
          (up_proj): Linear(in_features=640, out_features=2048, bias=False)
          (down_proj): Linear(in_features=2048, out_features=640, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): Gemma3RMSNorm((640,), eps

In [69]:
from huggingface_hub import notebook_login
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [75]:
from huggingface_hub import HfApi
import os

# Set your token directly
HF_TOKEN = "Token_here"

# Define your folder path (update if different)
folder_path = "/content/Gemma-3-270m-4bit"  # Change this if needed

# Define your model repo name
repo_id = "sweatSmile/Gemma-3-270m-Buddha-QA"

api.upload_folder(
    folder_path="/content/Gemma-3-270m-4bit",
    repo_id="sweatSmile/Gemma-3-270m-Buddha-QA",
    repo_type="model",
    commit_message="Upload Gemma-3-270m fine-tuned on Buddha QA dataset"
)

Processing Files (0 / 0)                : |          |  0.00B /  0.00B            

New Data Upload                         : |          |  0.00B /  0.00B            

  ...Gemma-3-270m-4bit/training_args.bin: 100%|##########| 5.69kB / 5.69kB            

  ...t/Gemma-3-270m-4bit/tokenizer.model: 100%|##########| 4.69MB / 4.69MB            

  ...270m-4bit/adapter_model.safetensors: 100%|##########| 7.63MB / 7.63MB            

  ...nt/Gemma-3-270m-4bit/tokenizer.json: 100%|##########| 33.4MB / 33.4MB            

CommitInfo(commit_url='https://huggingface.co/sweatSmile/Gemma-3-270m-Buddha-QA/commit/9f16a67f2aad6f49d51614a5918b24499f7aaddc', commit_message='Upload Gemma-3-270m fine-tuned on Buddha QA dataset', commit_description='', oid='9f16a67f2aad6f49d51614a5918b24499f7aaddc', pr_url=None, repo_url=RepoUrl('https://huggingface.co/sweatSmile/Gemma-3-270m-Buddha-QA', endpoint='https://huggingface.co', repo_type='model', repo_id='sweatSmile/Gemma-3-270m-Buddha-QA'), pr_revision=None, pr_num=None)