https://ai.google.dev/gemma/docs/core/huggingface_text_full_finetune

https://huggingface.co/collections/google/gemma-3-release

https://huggingface.co/datasets/bebechien/MobileGameNPC

In [1]:
import json
from datasets import load_dataset
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

from random import randint
import re

from trl import SFTConfig, SFTTrainer

import matplotlib.pyplot as plt

In [2]:
dataset = load_dataset(
    "csv", name="social-media-conversations", split="train", data_files="conversations.csv"
)
dataset = dataset.shuffle(seed=42)
print(dataset.column_names)

['timestamp', 'comment_id', 'comment_body', 'parent_text']


In [3]:
def format_conversation(sample):
    return {
        "messages": [
            {"role": "user", "content": sample["parent_text"]},
            {"role": "assistant", "content": sample["comment_body"]},
        ]
    }


# Convert dataset to conversational format
dataset = dataset.map(format_conversation, remove_columns=dataset.features, batched=False)

# Split dataset into 80% training samples and 20% test samples
dataset = dataset.train_test_split(test_size=0.2, shuffle=False)

In [4]:
n_example_train = 2
# Print formatted user prompt
print(json.dumps(dataset["train"][n_example_train]["messages"], indent=2))

[
  {
    "content": "Yeah, at 65 a fatal heart attack, stroke, pulmonary embolism, etc. aren't out of the realm of possibility.",
    "role": "user"
  },
  {
    "content": "They aren't out of the realm of possibility at any age. It's just different levels of probability.",
    "role": "assistant"
  }
]


In [5]:
n_example_test = 3004
print(json.dumps(dataset["test"][n_example_test]["messages"], indent=2))

[
  {
    "content": "600 mph is 965.61 km/h",
    "role": "user"
  },
  {
    "content": "good bot",
    "role": "assistant"
  }
]


In [6]:
print(f"len(dataset['train']): {len(dataset['train'])}")
print(f"len(dataset['test']): {len(dataset['test'])}")

len(dataset['train']): 29572
len(dataset['test']): 7393


In [7]:
# base_model = "google/gemma-3-4b-it"
base_model = "google/gemma-3-12b-it"
# base_model = "google/gemma-3-27b-it"
checkpoint_dir = "checkpoints"
learning_rate = 5e-6

In [8]:
torch_dtype = torch.bfloat16

# Define model init arguments
model_kwargs = dict(
    attn_implementation="flash_attention_2",
    torch_dtype=torch_dtype,  # What torch dtype to use, defaults to auto
    device_map="auto",  # Let torch decide how to load the model
)

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(base_model, **model_kwargs)
tokenizer = AutoTokenizer.from_pretrained(base_model)

print(f"Device: {model.device}")
print(f"DType: {model.dtype}")

`torch_dtype` is deprecated! Use `dtype` instead!
    Found GPU0 NVIDIA GB10 which is of cuda capability 12.1.
    Minimum and Maximum cuda capability supported by this version of PyTorch is
    (8.0) - (12.0)
    


Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]

Device: cuda:0
DType: torch.bfloat16


In [None]:
print(f"model.config._attn_implementation: {model.config._attn_implementation}")

'flash_attention_2'

In [10]:
# Load the model and tokenizer into the pipeline
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

# Load a sample from the test dataset
test_sample = dataset["test"][n_example_test]

# Convert a test example into a prompt with the Gemma template
prompt = pipe.tokenizer.apply_chat_template(
    test_sample["messages"][:1], tokenize=False, add_generation_prompt=True
)
outputs = pipe(prompt, max_new_tokens=256, disable_compile=True)

# Extract the user query and original answer
print(f"Question:\n{test_sample['messages'][0]['content']}\n")
print(f"Original Answer:\n{test_sample['messages'][1]['content']}\n")
print(f"Generated Answer (base model):\n{outputs[0]['generated_text'][len(prompt):].strip()}")

Device set to use cuda:0


Question:
600 mph is 965.61 km/h

Original Answer:
good bot

Generated Answer (base model):
That's correct! 600 miles per hour (mph) is indeed equal to approximately 965.61 kilometers per hour (km/h).

Here's the calculation:

* 1 mile = 1.60934 kilometers
* 600 miles = 600 * 1.60934 kilometers = 965.604 kilometers
* 600 miles per hour = 965.604 kilometers per hour

Rounding that gives you 965.61 km/h.


In [None]:
torch_dtype = model.dtype
evaluations = 10
eval_steps = len(dataset['train']) // evaluations

args = SFTConfig(
    output_dir=checkpoint_dir,  # directory to save and repository id
    max_length=512,  # max sequence length for model and packing of the dataset
    packing=True,  # Groups multiple samples in the dataset into a single sequence
    num_train_epochs=1,  # number of training epochs
    per_device_train_batch_size=1,  # batch size per device during training
    gradient_checkpointing=True,
    optim="adamw_torch_fused",  # use fused adamw optimizer
    logging_steps=1,  # log every step
    save_strategy="no",  # save checkpoint never
    eval_strategy="steps",
    eval_steps=eval_steps,
    learning_rate=learning_rate,  # learning rate
    fp16=True if torch_dtype == torch.float16 else False,  # use float16 precision
    bf16=True if torch_dtype == torch.bfloat16 else False,  # use bfloat16 precision
    lr_scheduler_type="constant",  # use constant learning rate scheduler
    push_to_hub=False,  # do not push model to hub
    report_to="tensorboard",  # report metrics to tensorboard
    dataset_kwargs={
        "add_special_tokens": False,  # Template with special tokens
        "append_concat_token": True,  # Add EOS token as separator token between examples
    },
)

In [12]:
# Create Trainer object
trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=dataset['train'],
    eval_dataset=dataset['test'],
    processing_class=tokenizer,
)

Tokenizing train dataset:   0%|          | 0/29572 [00:00<?, ? examples/s]

Packing train dataset:   0%|          | 0/29572 [00:00<?, ? examples/s]

Tokenizing eval dataset:   0%|          | 0/7393 [00:00<?, ? examples/s]

Packing eval dataset:   0%|          | 0/7393 [00:00<?, ? examples/s]

The model is already on multiple devices. Skipping the move to device specified in `args`.


In [None]:
# if packing is enabled, the number of training steps is different
# we need to set the eval_steps accordingly
training_steps = trainer.args.num_train_epochs * len(trainer.get_train_dataloader())
eval_steps = training_steps // evaluations
trainer.args.eval_steps = eval_steps
print(f"eval_steps: {trainer.args.eval_steps}")

eval_steps: 754


In [None]:
# Start training, the model will be automatically saved to the Hub and the output directory
trainer.train()

In [None]:
# Save the final model again to the Hugging Face Hub
trainer.save_model()

In [None]:
# Access the log history
log_history = trainer.state.log_history

# Extract training / validation loss
train_losses = [log["loss"] for log in log_history if "loss" in log]
epoch_train = [log["epoch"] for log in log_history if "loss" in log]
eval_losses = [log["eval_loss"] for log in log_history if "eval_loss" in log]
epoch_eval = [log["epoch"] for log in log_history if "eval_loss" in log]

# Plot the training loss
plt.plot(epoch_train, train_losses, label="Training Loss")
plt.plot(epoch_eval, eval_losses, label="Validation Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Training and Validation Loss per Epoch")
plt.legend()
plt.grid(True)
plt.show()

In [None]:
del model, tokenizer, pipe, trainer
torch.cuda.empty_cache()

In [None]:
model_id = checkpoint_dir

# Load Model
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype="auto", device_map="auto", attn_implementation="eager"
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

In [None]:
# Load the model and tokenizer into the pipeline
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)


def test(test_sample):
    # Convert as test example into a prompt with the Gemma template
    prompt = pipe.tokenizer.apply_chat_template(
        test_sample["messages"][:1], tokenize=False, add_generation_prompt=True
    )
    outputs = pipe(prompt, max_new_tokens=256, disable_compile=True)

    # Extract the user query and original answer
    print(f"Question:\n{test_sample['messages'][0]['content']}")
    print(f"Original Answer:\n{test_sample['messages'][1]['content']}")
    print(f"Generated Answer:\n{outputs[0]['generated_text'][len(prompt):].strip()}")
    print("-" * 80)


# Test with an unseen dataset
test(dataset['test'][n_example_test])