<a href="https://colab.research.google.com/github/amit-chaubey/ft-LoRA-Gemma-4bit/blob/main/Finetune_LoRA_gemma_3n_E4B_it.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install datasets bitsandbytes trl huggingface-hub accelerate safetensors pandas matplotlib

In [2]:
import os
import torch
from datasets import load_dataset
from peft import get_peft_model, LoraConfig, prepare_model_for_kbit_training
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from trl import SFTTrainer, SFTConfig

In [4]:
bnb_config = BitsAndBytesConfig(
    load_in_4bit = True,
    bnb_4bit_quant_type= "nf4",
    bnb_4bit_compute_dtype= torch.float32,
    bnb_4bit_use_double_quant= True
    )
repo = "google/gemma-3n-E4B-it"
model = AutoModelForCausalLM.from_pretrained(repo, quantization_config= bnb_config, device_map= "cuda:0")


config.json:   0%|          | 0.00/4.15k [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/171k [00:00<?, ?B/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/4.97G [00:00<?, ?B/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/3.08G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/2.66G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/215 [00:00<?, ?B/s]

In [5]:
print(model.get_memory_footprint()/1024/1024)

8299.145645141602


In [6]:
model

Gemma3nForConditionalGeneration(
  (model): Gemma3nModel(
    (vision_tower): TimmWrapperModel(
      (timm_model): MobileNetV5Encoder(
        (conv_stem): ConvNormAct(
          (conv): Conv2dSame(3, 64, kernel_size=(3, 3), stride=(2, 2), bias=False)
          (bn): RmsNormAct2d(
            (drop): Identity()
            (act): GELU(approximate='none')
          )
        )
        (blocks): Sequential(
          (0): Sequential(
            (0): EdgeResidual(
              (conv_exp): Conv2dSame(64, 256, kernel_size=(3, 3), stride=(2, 2), bias=False)
              (bn1): RmsNormAct2d(
                (drop): Identity()
                (act): GELU(approximate='none')
              )
              (aa): Identity()
              (se): Identity()
              (conv_pwl): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (bn2): RmsNormAct2d(
                (drop): Identity()
                (act): Identity()
              )
              (drop_path): Identi

In [7]:
model = prepare_model_for_kbit_training(model)

config = LoraConfig(
    r = 16, #. rank of LoRA - [4-16]
    bias = "none", # ["all", "lora_only"] - for train bias term
    lora_alpha = 32, # scalling factor
    lora_dropout = 0.07, # prevent overfit- used for regularisation
    target_modules = ["query_key_value", "o_proj", "qkv_proj", "gate_up_proj", "down_proj"],
    task_type = "CAUSAL_LM"

)

model = get_peft_model(model, config)
model

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): Gemma3nForConditionalGeneration(
      (model): Gemma3nModel(
        (vision_tower): TimmWrapperModel(
          (timm_model): MobileNetV5Encoder(
            (conv_stem): ConvNormAct(
              (conv): Conv2dSame(3, 64, kernel_size=(3, 3), stride=(2, 2), bias=False)
              (bn): RmsNormAct2d(
                (drop): Identity()
                (act): GELU(approximate='none')
              )
            )
            (blocks): Sequential(
              (0): Sequential(
                (0): EdgeResidual(
                  (conv_exp): Conv2dSame(64, 256, kernel_size=(3, 3), stride=(2, 2), bias=False)
                  (bn1): RmsNormAct2d(
                    (drop): Identity()
                    (act): GELU(approximate='none')
                  )
                  (aa): Identity()
                  (se): Identity()
                  (conv_pwl): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
       

In [8]:
print(model.get_memory_footprint()/1024/1024)

14421.863479614258


In [9]:
print(model.get_base_model)

<bound method PeftModel.get_base_model of PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): Gemma3nForConditionalGeneration(
      (model): Gemma3nModel(
        (vision_tower): TimmWrapperModel(
          (timm_model): MobileNetV5Encoder(
            (conv_stem): ConvNormAct(
              (conv): Conv2dSame(3, 64, kernel_size=(3, 3), stride=(2, 2), bias=False)
              (bn): RmsNormAct2d(
                (drop): Identity()
                (act): GELU(approximate='none')
              )
            )
            (blocks): Sequential(
              (0): Sequential(
                (0): EdgeResidual(
                  (conv_exp): Conv2dSame(64, 256, kernel_size=(3, 3), stride=(2, 2), bias=False)
                  (bn1): RmsNormAct2d(
                    (drop): Identity()
                    (act): GELU(approximate='none')
                  )
                  (aa): Identity()
                  (se): Identity()
                  (conv_pwl): Conv2d(256, 128, kernel_size=

In [10]:
print(model.get_memory_footprint()/1e6)

15122.41992


In [11]:
trainable_params, total_params = model.get_nb_trainable_parameters()
percentage = (trainable_params / total_params) * 100

print(f"Trainable Parameters: {trainable_params:,}")
print(f"Total Parameters: {total_params:,}")
print(f"Percentage Trainable: {percentage:.2f}%")

Trainable Parameters: 12,615,680
Total Parameters: 7,862,593,808
Percentage Trainable: 0.16%
