https://deepmind.google/models/gemma/gemma-3/

https://huggingface.co/collections/google/gemma-3-release

In [None]:
import json
from datasets import load_dataset
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    pipeline,
    BitsAndBytesConfig,
    AutoProcessor,
)
from peft import LoraConfig, PeftModel

from random import randint

from trl import SFTConfig, SFTTrainer

from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Read the conversations CSV as a single "train" split, shuffle it with a
# fixed seed, and show which columns are available.
dataset = load_dataset(
    "csv",
    name="social-media-conversations",
    data_files="conversations.csv",
    split="train",
).shuffle(seed=42)
print(dataset.column_names)

In [None]:
# Walk the dataset once and record the character length of every
# comment_body and of the parent_text it replies to.
comment_lengths = []
parent_lengths = []
for sample in tqdm(dataset):
    for column, bucket in (
        ("comment_body", comment_lengths),
        ("parent_text", parent_lengths),
    ):
        bucket.append(len(sample[column]))

In [None]:
# Overlaid histograms of comment and parent lengths on a log-scaled x axis.
plt.figure(figsize=(16, 9))
# Log-spaced bins need a strictly positive lower edge: an empty text has
# length 0 and log10(0) is -inf, which would turn the bin edges into NaN.
# Clamp the lower edge to 1 character; zero-length texts then fall outside
# the bins, which a log axis could not display anyway.
shortest = max(1, min(min(comment_lengths), min(parent_lengths)))
longest = max(shortest, max(max(comment_lengths), max(parent_lengths)))
bins = np.logspace(np.log10(shortest), np.log10(longest), 100)
plt.grid(True)
ax = plt.gca()
# Draw the grid behind the bars instead of on top of them.
ax.set_axisbelow(True)
plt.hist(
    comment_lengths, bins=bins, edgecolor='black', label='Comment Lengths', alpha=0.5, color='C0'
)
plt.hist(
    parent_lengths, bins=bins, edgecolor='black', label='Parent Lengths', alpha=0.5, color='C1'
)
plt.xscale('log')
plt.title('Histogram of Comment and Parent Lengths', fontsize=20)
plt.xlabel('Length (characters)', fontsize=16)
plt.ylabel('Frequency', fontsize=16)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
plt.legend(fontsize=16)
plt.show()

In [None]:
def format_conversation(sample):
    """Turn one dataset row into a two-turn chat conversation.

    Args:
        sample: Mapping with the keys ``"parent_text"`` and ``"comment_body"``.

    Returns:
        A dict with a single ``"messages"`` key holding the parent text as
        the user turn followed by the comment body as the assistant turn.
    """
    user_turn = {"role": "user", "content": sample["parent_text"]}
    assistant_turn = {"role": "assistant", "content": sample["comment_body"]}
    return {"messages": [user_turn, assistant_turn]}


# Convert dataset to conversational format.
# remove_columns takes column names, so pass the name list explicitly rather
# than the Features mapping; every original column is dropped and only
# "messages" remains.
dataset = dataset.map(format_conversation, remove_columns=dataset.column_names, batched=False)
# Split 80/20 without shuffling again, so the split follows the current row order.
dataset = dataset.train_test_split(test_size=0.2, shuffle=False)

In [None]:
# Show one formatted training conversation as pretty-printed JSON.
n_example_train = 2
train_example = dataset["train"][n_example_train]
print(json.dumps(train_example["messages"], indent=2))

In [None]:
# Show one formatted test conversation as pretty-printed JSON.
# Row 3004 is the preferred example, but the index is clamped to the last row
# so a test split with fewer than 3005 rows does not raise IndexError.
n_example_test = min(3004, len(dataset["test"]) - 1)
print(json.dumps(dataset["test"][n_example_test]["messages"], indent=2))

In [None]:
# Report how many conversations ended up in each split.
print(
    f"len(dataset['train']): {len(dataset['train'])}",
    f"len(dataset['test']): {len(dataset['test'])}",
    sep="\n",
)

In [None]:
# Training mode: True trains LoRA adapters, False updates all weights.
use_lora = False

# best for full finetuning: 1e-5
# best for LoRA: 2e-5
learning_rate = 1e-5

# Model to fine-tune; uncomment one of the alternatives to switch size.
# base_model = "google/gemma-3-4b-it"
base_model = "google/gemma-3-12b-it"
# base_model = "google/gemma-3-27b-it"

# Directory the trainer writes to and the later cells load from.
checkpoint_dir = "checkpoints"

# Summarise the run configuration.
print()
print(f"Training model: {base_model}")
print(f"Learning rate: {learning_rate}")
print("Using LoRA" if use_lora else "Using full finetuning")
print()

In [None]:
torch_dtype = torch.bfloat16

# Arguments passed to from_pretrained in both training modes.
model_kwargs = {
    "attn_implementation": "flash_attention_2",
    "torch_dtype": torch_dtype,
    "device_map": "auto",
}

if use_lora:
    # LoRA mode loads the base weights 4-bit quantized (NF4 with double
    # quantization), using the same dtype for compute and quant storage.
    model_kwargs["quantization_config"] = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type='nf4',
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch_dtype,
        bnb_4bit_quant_storage=torch_dtype,
    )

    # Adapters on all linear layers; lm_head and embed_tokens are listed in
    # modules_to_save alongside the adapters.
    peft_config = LoraConfig(
        r=16,
        lora_alpha=16,
        lora_dropout=0.05,
        bias="none",
        target_modules="all-linear",
        task_type="CAUSAL_LM",
        modules_to_save=["lm_head", "embed_tokens"],
    )

print("Loading base model...")
model = AutoModelForCausalLM.from_pretrained(base_model, **model_kwargs)
tokenizer = AutoTokenizer.from_pretrained(base_model)

In [None]:
# Report where the base model lives, its dtype and the attention backend in use.
print(
    f"Device: {model.device}",
    f"DType: {model.dtype}",
    f"model.config._attn_implementation: {model.config._attn_implementation}",
    sep="\n",
)

In [None]:
print("Create pipeline and run one prediction.")
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

# Pick the reference test conversation chosen earlier.
test_sample = dataset["test"][n_example_test]

# Only the user turn goes into the prompt; the chat template adds the
# generation prompt so the model continues as the assistant.
prompt = pipe.tokenizer.apply_chat_template(
    test_sample["messages"][:1],
    tokenize=False,
    add_generation_prompt=True,
)
outputs = pipe(prompt, max_new_tokens=256, disable_compile=True)
# Slice off the leading prompt-length characters to keep only the new text.
base_answer = outputs[0]['generated_text'][len(prompt):].strip()

print(f"Question:\n{test_sample['messages'][0]['content']}\n")
print(f"Original Answer:\n{test_sample['messages'][1]['content']}\n")
print(f"Generated Answer (base model):\n{base_answer}")

In [None]:
torch_dtype = model.dtype
evaluations = 10
# Initial evaluation interval, derived from the number of training examples.
eval_steps = len(dataset['train']) // evaluations

# Settings shared by LoRA and full finetuning.
sft_kwargs = dict(
    output_dir=checkpoint_dir,
    max_length=512,
    packing=True,
    num_train_epochs=1,
    # max_steps=20,
    per_device_train_batch_size=1,
    gradient_checkpointing=True,
    optim="adamw_torch_fused",
    logging_steps=1,
    save_strategy="no",
    eval_strategy="steps",
    eval_steps=eval_steps,
    learning_rate=learning_rate,
    # Mixed-precision flags follow the dtype the model was loaded in.
    fp16=torch_dtype == torch.float16,
    bf16=torch_dtype == torch.bfloat16,
    lr_scheduler_type="constant",
    push_to_hub=False,
    report_to="tensorboard",
    dataset_kwargs={
        "add_special_tokens": False,
        "append_concat_token": True,
    },
)

if use_lora:
    # LoRA runs additionally clip gradients and request a warmup ratio.
    # NOTE(review): warmup_ratio presumably has no effect with the plain
    # "constant" scheduler - confirm whether "constant_with_warmup" was meant.
    sft_kwargs.update(max_grad_norm=0.3, warmup_ratio=0.03)

args = SFTConfig(**sft_kwargs)

In [None]:
# Arguments common to both training modes.
trainer_kwargs = dict(
    model=model,
    args=args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    processing_class=tokenizer,
)
if use_lora:
    # Only the LoRA run hands an adapter configuration to the trainer.
    trainer_kwargs["peft_config"] = peft_config

trainer = SFTTrainer(**trainer_kwargs)

In [None]:
# if packing is enabled, the number of training steps
# is not the same as the number of steps in the dataset
# we need to set the eval_steps accordingly
if trainer.args.max_steps > 0:
    # An explicit max_steps overrides the epoch-based length of the run.
    training_steps = trainer.args.max_steps
else:
    # Evaluation is scheduled in optimizer steps, so count one step per
    # gradient-accumulation group rather than one per batch.
    updates_per_epoch = max(
        1, len(trainer.get_train_dataloader()) // trainer.args.gradient_accumulation_steps
    )
    training_steps = int(trainer.args.num_train_epochs * updates_per_epoch)
# Keep the interval a whole number >= 1: eval_steps below 1 is interpreted as
# a ratio of total steps, so a 0 produced by a short run would yield a zero
# evaluation interval.
eval_steps = max(1, training_steps // evaluations)
trainer.args.eval_steps = eval_steps
print(f"eval_steps: {trainer.args.eval_steps}")

In [None]:
# main training loop: runs the trainer built above with the configured
# SFT arguments and the train/test splits as train/eval datasets
trainer.train()

In [None]:
# Persist the trained model. No path is passed, so this presumably writes to
# the trainer's output_dir (checkpoint_dir), which later cells load from.
print()
print("Saving model...")
trainer.save_model()

In [None]:
# Best-effort: also store the base model's processor next to the checkpoint.
try:
    # this might be needed by some performance evaluation tools
    processor = AutoProcessor.from_pretrained(base_model)
    processor.save_pretrained(checkpoint_dir)
except Exception as e:
    # Any failure while loading or saving the processor is reported and
    # skipped instead of aborting the run.
    print(f"Base model does not have a processor: {e}")
    print("Skipping processor saving.")

In [None]:
# Access the log history
log_history = trainer.state.log_history

# Separate the entries that carry a training loss from those with an eval loss.
train_logs = [entry for entry in log_history if "loss" in entry]
eval_logs = [entry for entry in log_history if "eval_loss" in entry]

# Extract training / validation loss together with the epoch they were logged at.
train_losses = [entry["loss"] for entry in train_logs]
epoch_train = [entry["epoch"] for entry in train_logs]
eval_losses = [entry["eval_loss"] for entry in eval_logs]
epoch_eval = [entry["epoch"] for entry in eval_logs]

# Plot both curves against the epoch.
for epochs, losses, curve_label in (
    (epoch_train, train_losses, "Training Loss"),
    (epoch_eval, eval_losses, "Validation Loss"),
):
    plt.plot(epochs, losses, label=curve_label)
plt.title("Training and Validation Loss per Epoch")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.grid(True)
plt.legend()
plt.show()

In [None]:
# free some memory
import gc

del model, tokenizer, pipe, trainer
# Dropping the names is not enough if the objects are part of reference
# cycles; run the garbage collector first so their tensors are released
# before the CUDA caching allocator is asked to give memory back.
gc.collect()
torch.cuda.empty_cache()

In [None]:
if use_lora:
    # Merge the trained adapters into a freshly loaded bfloat16 base model
    # (no quantization config is passed here) and save the result as a
    # standalone model directory.
    print("Loading base model...")
    model = AutoModelForCausalLM.from_pretrained(
        base_model, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16
    )

    print("Loading LoRA weights...")
    peft_model = PeftModel.from_pretrained(model, checkpoint_dir)
    print("Merging base and LoRA...")
    merged_model = peft_model.merge_and_unload()
    print(f"Merged model DType: {merged_model.dtype}")

    print("Saving merged model...")
    merged_model_dir = "merged_model"
    merged_model.save_pretrained(merged_model_dir, safe_serialization=True)
    # The tokenizer is taken from the checkpoint directory and copied next to
    # the merged weights.
    tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)
    tokenizer.save_pretrained(merged_model_dir)

    # Bind the name up front: if loading the processor fails below, the
    # `del` at the end would otherwise raise NameError in exactly the case
    # the except branch is meant to tolerate.
    processor = None
    try:
        # this might be needed by some performance evaluation tools
        processor = AutoProcessor.from_pretrained(base_model)
        processor.save_pretrained(merged_model_dir)
    except Exception as e:
        print(f"Base model does not have a processor: {e}")
        print("Skipping processor saving.")

    # Release everything used for merging before the model is reloaded.
    del model, peft_model, merged_model, tokenizer, processor
    torch.cuda.empty_cache()

In [None]:
print("Loading saved model for inference...")
# Choose the directory and attention backend once, then load with one call.
if use_lora:
    # flash attention does not seem to work with LoRA for inference
    attn_implementation = "eager"
    inference_dir = merged_model_dir
else:
    attn_implementation = "flash_attention_2"
    inference_dir = checkpoint_dir

model = AutoModelForCausalLM.from_pretrained(
    inference_dir,
    torch_dtype=torch_dtype,
    device_map="auto",
    attn_implementation=attn_implementation,
)
tokenizer = AutoTokenizer.from_pretrained(inference_dir)

pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

In [None]:
# Report where the reloaded model lives, its dtype and the attention backend in use.
print(
    f"Device: {model.device}",
    f"DType: {model.dtype}",
    f"model.config._attn_implementation: {model.config._attn_implementation}",
    sep="\n",
)

In [None]:
def test(test_sample):
    """Print the question, the reference reply and the model's reply for one sample.

    Reads the module-level ``pipe`` (text-generation pipeline) and
    ``use_lora`` flag.

    Args:
        test_sample: Mapping whose ``"messages"`` list holds the user turn at
            index 0 and the reference assistant turn at index 1, each with a
            ``"content"`` entry.
    """
    print(f"Question:\n{test_sample['messages'][0]['content']}")
    print(f"Original Answer:\n{test_sample['messages'][1]['content']}")

    # Convert a test example into a prompt with the Gemma template
    prompt = pipe.tokenizer.apply_chat_template(
        test_sample["messages"][:1], tokenize=False, add_generation_prompt=True
    )
    generation_kwargs = {"max_new_tokens": 256, "disable_compile": True}
    if use_lora:
        # Greedy decoding for the merged LoRA model. temperature / top_k /
        # top_p were previously passed as well, but they only apply when
        # sampling, so with do_sample=False they were ignored (and merely
        # triggered generation warnings).
        generation_kwargs["do_sample"] = False
    outputs = pipe(prompt, **generation_kwargs)
    # Slice off the leading prompt-length characters to keep only the new text.
    print(f"Generated Answer:\n{outputs[0]['generated_text'][len(prompt):].strip()}")
    print("-" * 80)


print("Generate inference from the test set.")
print("-" * 80)
# Run ten randomly drawn test conversations through the model (indices are
# drawn independently, so a sample may repeat).
test_split = dataset['test']
for _ in range(10):
    test(test_split[randint(0, len(test_split) - 1)])