https://ai.google.dev/gemma/docs/core/huggingface_text_full_finetune

https://huggingface.co/collections/google/gemma-3-release

https://huggingface.co/datasets/bebechien/MobileGameNPC

In [None]:
import json
from datasets import load_dataset
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

from random import randint
import re

from trl import SFTConfig, SFTTrainer

import matplotlib.pyplot as plt

In [None]:
# Index of the row that is printed as a sanity check below
n_example = 4

# Read the local CSV file as a single "train" split
dataset = load_dataset(
    "csv",
    name="social-media-conversations",
    data_files="conversations.csv",
    split="train",
)

# Shuffle with a fixed seed so the row order is the same on every run
dataset = dataset.shuffle(seed=42)

# Show the available columns and one raw row
print(dataset.column_names)
print(json.dumps(dataset[n_example], indent=2))

In [None]:
def format_conversation(sample):
    """Build a two-turn chat record from one dataset row.

    Args:
        sample: Mapping that provides the keys ``"parent_text"`` and
            ``"comment_body"``.

    Returns:
        A dict with a single ``"messages"`` key holding a user turn (the
        parent text) followed by an assistant turn (the comment body).
    """
    user_turn = {"role": "user", "content": sample["parent_text"]}
    assistant_turn = {"role": "assistant", "content": sample["comment_body"]}
    return {"messages": [user_turn, assistant_turn]}


# Rewrite every row into the conversational "messages" layout (dropping the
# original columns), then hold out the last 20% of rows as the test split.
# shuffle=False keeps the order produced by the earlier seeded shuffle.
dataset = dataset.map(
    format_conversation,
    remove_columns=dataset.features,
    batched=False,
).train_test_split(test_size=0.2, shuffle=False)

# Preview one formatted training conversation
print(json.dumps(dataset["train"][n_example]["messages"], indent=2))

In [None]:
# Base checkpoint to fine-tune. Keep exactly one `base_model` assignment
# active; the commented-out lines are the other model sizes to swap in.
# base_model = "google/gemma-3-4b-it"
base_model = "google/gemma-3-12b-it"
# base_model = "google/gemma-3-27b-it"
# Passed to SFTConfig as output_dir and reused later as the path to reload the saved model
checkpoint_dir = "checkpoints"
# Passed to SFTConfig as learning_rate
learning_rate = 5e-5

In [None]:
# Precision used when loading the weights
torch_dtype = torch.bfloat16

# Keyword arguments forwarded to from_pretrained
model_kwargs = {
    # Use "flash_attention_2" when running on Ampere or newer GPU
    "attn_implementation": "eager",
    # Load the weights in the dtype chosen above
    "torch_dtype": torch_dtype,
    # Device placement is chosen automatically
    "device_map": "auto",
}

# Load the base model and its matching tokenizer
model = AutoModelForCausalLM.from_pretrained(base_model, **model_kwargs)
tokenizer = AutoTokenizer.from_pretrained(base_model)

# Report where the model ended up and in which precision
print(f"Device: {model.device}")
print(f"DType: {model.dtype}")

In [None]:
# Wrap the loaded model and tokenizer in a text-generation pipeline
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

# Pick one conversation from the test split
n_example_test = 4
test_sample = dataset["test"][n_example_test]

# Render only the user turn with the chat template, ending with the
# generation prompt so the model continues as the assistant
prompt = pipe.tokenizer.apply_chat_template(
    test_sample["messages"][:1],
    tokenize=False,
    add_generation_prompt=True,
)
outputs = pipe(prompt, max_new_tokens=256, disable_compile=True)

# Show the user query, the reference answer, and the model's answer
print(f"Question:\n{test_sample['messages'][0]['content']}\n")
print(f"Original Answer:\n{test_sample['messages'][1]['content']}\n")
# The pipeline output starts with the prompt, so slice it off before printing
base_answer = outputs[0]["generated_text"][len(prompt):].strip()
print(f"Generated Answer (base model):\n{base_answer}")

In [None]:
# Match the trainer's mixed-precision flags to the dtype the model was loaded in
torch_dtype = model.dtype

args = SFTConfig(
    # --- output, logging, evaluation ---
    output_dir=checkpoint_dir,  # where trainer.save_model() writes
    save_strategy="no",  # no intermediate checkpoints
    push_to_hub=False,  # keep everything local
    report_to="tensorboard",  # metrics go to TensorBoard
    logging_steps=1,  # log every optimizer step
    eval_strategy="epoch",  # evaluate at the end of each epoch
    # --- data ---
    max_length=512,  # maximum sequence length
    packing=False,  # do not pack several samples into one sequence
    dataset_kwargs={
        "add_special_tokens": False,  # Template with special tokens
        "append_concat_token": True,  # Add EOS token as separator token between examples
    },
    # --- optimization ---
    num_train_epochs=1,
    per_device_train_batch_size=1,
    gradient_checkpointing=False,
    optim="adamw_torch_fused",  # fused AdamW optimizer
    learning_rate=learning_rate,
    lr_scheduler_type="constant",  # constant learning rate
    # --- precision: at most one flag is enabled, depending on the model dtype ---
    fp16=torch_dtype == torch.float16,
    bf16=torch_dtype == torch.bfloat16,
)

In [None]:
# Create Trainer object: trains `model` on the train split and evaluates on
# the test split, using the settings collected in `args`
trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=dataset['train'],
    eval_dataset=dataset['test'],
    processing_class=tokenizer,
)

In [None]:
# Start training. The SFTConfig sets save_strategy="no" and push_to_hub=False,
# so no checkpoint is written and nothing is uploaded during this call;
# the model is saved explicitly afterwards.
trainer.train()

In [None]:
# Save the final model to the trainer's output directory (`checkpoint_dir`).
# This is the only save of the run, and push_to_hub=False means nothing is
# uploaded to the Hugging Face Hub.
trainer.save_model()

In [None]:
# Access the log history recorded by the trainer
log_history = trainer.state.log_history

# Extract training / validation loss together with the epoch each was logged at.
# Training records are those with a "loss" key, evaluation records those with "eval_loss".
train_losses = [log["loss"] for log in log_history if "loss" in log]
epoch_train = [log["epoch"] for log in log_history if "loss" in log]
eval_losses = [log["eval_loss"] for log in log_history if "eval_loss" in log]
epoch_eval = [log["epoch"] for log in log_history if "eval_loss" in log]

# Plot the training loss
plt.plot(epoch_train, train_losses, label="Training Loss")
# Evaluation runs once per epoch, so a one-epoch run yields a single point.
# A line through one point draws nothing, hence the explicit marker.
plt.plot(epoch_eval, eval_losses, marker="o", label="Validation Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Training and Validation Loss per Epoch")
plt.legend()
plt.grid(True)
plt.show()

In [None]:
import gc

# Drop the notebook's references to the training objects so their memory can
# be reclaimed before the fine-tuned model is loaded again
del model, tokenizer, pipe, trainer
# Objects caught in reference cycles are only freed by a collector pass; run it
# first so empty_cache() can return their memory as well
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Reload from the directory that was given to the trainer as its output_dir
model_id = checkpoint_dir

# Load Model
# torch_dtype="auto" and device_map="auto" leave dtype and device placement to the library
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype="auto", device_map="auto", attn_implementation="eager"
)
# The tokenizer is read from the same directory as the model
tokenizer = AutoTokenizer.from_pretrained(model_id)

In [None]:
# Load the current `model` and `tokenizer` into a text-generation pipeline;
# the `test` function below reads this `pipe` as a global
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)


def test(test_sample):
    """Generate an answer for one conversation and print it beside the reference.

    Uses the module-level ``pipe`` for templating and generation.

    Args:
        test_sample: Mapping with a ``"messages"`` list whose first entry is
            the user turn and whose second entry is the reference answer;
            each entry provides a ``"content"`` key.
    """
    messages = test_sample["messages"]

    # Render only the user turn with the chat template, ending with the
    # generation prompt so the model continues as the assistant
    chat_prompt = pipe.tokenizer.apply_chat_template(
        messages[:1],
        tokenize=False,
        add_generation_prompt=True,
    )
    generations = pipe(chat_prompt, max_new_tokens=256, disable_compile=True)

    # Show the user query, the reference answer, and the generated answer
    print(f"Question:\n{messages[0]['content']}")
    print(f"Original Answer:\n{messages[1]['content']}")
    # The generated text starts with the prompt, so slice it off
    completion = generations[0]["generated_text"][len(chat_prompt):].strip()
    print(f"Generated Answer:\n{completion}")
    print("-" * 80)


# Run every conversation of the held-out test split through `test`, which
# prints the question, the reference answer, and the generated answer
for item in dataset['test']:
    test(item)