Based on Google's guide "Fine-Tune Gemma using Hugging Face Transformers and QLoRA": https://ai.google.dev/gemma/docs/core/huggingface_text_finetune_qlora

In [1]:
import json

import torch
from datasets import load_dataset
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoModelForImageTextToText,
    AutoTokenizer,
    BitsAndBytesConfig,
)

In [2]:
# Placeholder system prompt; the conversion step below does not inject it by default.
system_prompt = "Some system prompt"
# Index of the row printed as a spot-check after each dataset transformation.
n_example = 4


def create_conversation(sample, system_prompt=None):
    """Convert one CSV row into a two-turn chat example.

    Args:
        sample: Mapping holding at least the ``parent_text`` column (used as
            the user turn) and the ``comment_body`` column (used as the
            assistant reply). Other keys are ignored.
        system_prompt: Optional instruction to prepend to the user turn,
            separated by a blank line. Gemma 3 has no dedicated system turn,
            so the instruction is folded into the first user message instead
            (this mirrors how Gemma 3's chat template handles a system
            message -- verify against ``tokenizer.chat_template``). ``None``
            or ``""`` (the default) leaves the user turn untouched, which is
            what ``dataset.map(create_conversation)`` gets.

    Returns:
        ``{"messages": [user_message, assistant_message]}`` where each message
        is a ``{"role": ..., "content": ...}`` dict.
    """
    user_content = sample["parent_text"]
    if system_prompt:
        user_content = f"{system_prompt}\n\n{user_content}"
    return {
        "messages": [
            {"role": "user", "content": user_content},
            {"role": "assistant", "content": sample["comment_body"]},
        ]
    }


# Read conversations.csv as a single "train" split, then shuffle it with a
# fixed seed so every rerun sees the same row order.
dataset = load_dataset(
    "csv",
    name="csv-for-gemma3-hf",
    data_files="conversations.csv",
    split="train",
).shuffle(seed=42)

# Keep the raw CSV column names around; the conversion cell removes them.
column_names_orig = dataset.column_names
print(column_names_orig)

# Spot-check one raw row.
print(json.dumps(dataset[n_example], indent=2))

['timestamp', 'comment_id', 'comment_body', 'parent_text']
{
  "timestamp": "2021-08-25 16:32:18 UTC",
  "comment_id": "hab6nka",
  "comment_body": "Pfft, just turn the volume way up, it'll blow dry it.",
  "parent_text": "Can confirm, shoulder length hair is not conducive to wearing headphones for the 4 hours after I shower."
}


In [3]:
# Convert each raw row into the chat "messages" format and drop the raw CSV
# columns. The guard makes this cell safe to re-run: after the first run the
# raw columns no longer exist, so mapping again would raise (KeyError inside
# create_conversation / invalid remove_columns) instead of doing nothing.
if "messages" not in dataset.column_names:
    dataset = dataset.map(create_conversation, remove_columns=column_names_orig, batched=False)
print(json.dumps(dataset[n_example], indent=2))

Map:   0%|          | 0/43898 [00:00<?, ? examples/s]

{
  "messages": [
    {
      "content": "Can confirm, shoulder length hair is not conducive to wearing headphones for the 4 hours after I shower.",
      "role": "user"
    },
    {
      "content": "Pfft, just turn the volume way up, it'll blow dry it.",
      "role": "assistant"
    }
  ]
}


In [4]:
# Major CUDA compute capability of the current GPU; the model-loading cell uses
# the same ">= 8" test to choose bfloat16 over float16.
major_capability, _minor_capability = torch.cuda.get_device_capability()
major_capability

8

In [5]:
# Hugging Face model id. Other QAT checkpoints of the same family:
#   "google/gemma-3-4b-it-qat-q4_0-unquantized"
#   "google/gemma-3-27b-it-qat-q4_0-unquantized"
model_id = "google/gemma-3-1b-it-qat-q4_0-unquantized"

# Pick the model class from the checkpoint's own config instead of comparing
# against a single hard-coded id. Text-only Gemma 3 checkpoints (such as 1B)
# have no ``vision_config`` and load with AutoModelForCausalLM. The multimodal
# sizes carry a vision tower and need AutoModelForImageTextToText. Previously,
# any other text-only id (e.g. the non-QAT "google/gemma-3-1b-it") silently got
# the multimodal class.
model_config = AutoConfig.from_pretrained(model_id)
if getattr(model_config, "vision_config", None) is None:
    model_class = AutoModelForCausalLM
else:
    model_class = AutoModelForImageTextToText

# GPUs with compute capability >= 8 (Ampere or newer) use bfloat16;
# older GPUs fall back to float16.
torch_dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16

# Keyword arguments forwarded to ``model_class.from_pretrained``.
model_kwargs = {
    # Eager attention; "flash_attention_2" is an option on Ampere or newer GPUs.
    "attn_implementation": "eager",
    # dtype selected in the cell above (bfloat16 or float16).
    "torch_dtype": torch_dtype,
    # Let Transformers/Accelerate place the layers on the available devices.
    "device_map": "auto",
}

# bitsandbytes 4-bit NF4 quantization with double (nested) quantization to cut
# model size and memory use. Compute and storage reuse the same dtype as above.
model_kwargs["quantization_config"] = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch_dtype,
    bnb_4bit_quant_storage=torch_dtype,
)

# Instantiate the (4-bit quantized) model with the settings assembled above.
model = model_class.from_pretrained(
    model_id,
    **model_kwargs,
)
# The instruction-tuned tokenizer ships the official Gemma chat template.
tokenizer = AutoTokenizer.from_pretrained(model_id)
