# Finetuning

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google-deepmind/gemma/blob/main/colabs/finetuning.ipynb)

This is an example of fine-tuning Gemma. For an example of how to run a pre-trained Gemma model, see the [sampling](https://gemma-llm.readthedocs.io/en/latest/sampling.html) tutorial.

To fine-tune Gemma, we use the [kauldron](https://kauldron.readthedocs.io/en/latest/) library, which abstracts away most of the boilerplate (checkpoint management, training loop, evaluation, metric reporting, sharding, ...).


In [4]:
!pip install -q gemma


In [5]:
!pip install PyPDF2  # For PDF reading

Collecting PyPDF2
  Downloading pypdf2-3.0.1-py3-none-any.whl.metadata (6.8 kB)
Downloading pypdf2-3.0.1-py3-none-any.whl (232 kB)
Installing collected packages: PyPDF2
Successfully installed PyPDF2-3.0.1


In [2]:
!pip install kauldron

Collecting kauldron
  Downloading kauldron-1.2.1-py3-none-any.whl.metadata (3.3 kB)
Collecting clu (from kauldron)
  Downloading clu-0.0.12-py3-none-any.whl.metadata (1.9 kB)
Collecting grain (from kauldron)
  Downloading grain-0.2.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (15 kB)
Collecting jaxtyping (from kauldron)
  Downloading jaxtyping-0.3.0-py3-none-any.whl.metadata (7.0 kB)
Collecting lark (from kauldron)
  Downloading lark-1.2.2-py3-none-any.whl.metadata (1.8 kB)
Collecting mediapy (from kauldron)
  Downloading mediapy-1.2.2-py3-none-any.whl.metadata (4.8 kB)
Collecting ml_collections (from kauldron)
  Downloading ml_collections-1.0.0-py3-none-any.whl.metadata (22 kB)
Collecting tfds-nightly (from kauldron)
  Downloading tfds_nightly-4.9.8.dev202503240044-py3-none-any.whl.metadata (11 kB)
Collecting xmanager (from kauldron)
  Downloading xmanager-0.6.0-py3-none-any.whl.metadata (12 kB)
Collecting wadler-lindig>=0.1.3 (from jaxtyping->kauldron)
  Down

In [6]:
# Common imports
import os
import optax
import treescope

# Gemma imports
from kauldron import kd
from gemma import gm

import tensorflow_datasets as tfds

import PyPDF2

By default, JAX does not use the full GPU memory, but this can be overridden. See [GPU memory allocation](https://docs.jax.dev/en/latest/gpu_memory_allocation.html):

In [7]:
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"
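The same setting can be wrapped in a small helper. This is a minimal sketch: the name `set_gpu_memory_fraction` is hypothetical, and only the `XLA_PYTHON_CLIENT_MEM_FRACTION` variable comes from the JAX documentation linked above. It assumes the variable is set before JAX initializes its GPU backend, which happens lazily the first time a device is used.

```python
import os


def set_gpu_memory_fraction(fraction: float) -> str:
    """Hypothetical helper: store the fraction in the env var read by XLA."""
    if not 0.0 < fraction <= 1.0:
        raise ValueError(f"fraction must be in (0, 1], got {fraction}")
    value = f"{fraction:.2f}"
    os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = value
    return value


# Same effect as the cell above.
set_gpu_memory_fraction(1.0)
```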

## Data pipeline

First create the tokenizer, as it's required in the data pipeline.

In [8]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [9]:
tokenizer = gm.text.Gemma3Tokenizer()

tokenizer.encode('This is an example sentence', add_bos=True)

[<_Gemma3SpecialTokens.BOS: 2>, 2094, 563, 614, 2591, 13315]

In [10]:
def extract_text_from_pdf(pdf_path):
    """Return the concatenated text of every page in the PDF at `pdf_path`."""
    with open(pdf_path, 'rb') as pdf_file:
        pdf_reader = PyPDF2.PdfReader(pdf_file)
        # Join the per-page text; fall back to '' for pages with no extractable text.
        return "".join(page.extract_text() or "" for page in pdf_reader.pages)

In [11]:
pdf_dir = 'drive/MyDrive/pdfs'
data = []
for filename in os.listdir(pdf_dir):
    if filename.endswith(".pdf"):
        pdf_path = os.path.join(pdf_dir, filename)
        text = extract_text_from_pdf(pdf_path)
        data.append({"text": text})

In [12]:
def pdf_data_generator():
    for item in data:  # 'data' is the list of {"text": ...} dicts built above
        yield item

In [16]:
ds = kd.data.py.DataSource(
    # NOTE: this passes a generator object, which has no len(); see the
    # TypeError raised by trainer.train() later in this notebook.
    data_source=pdf_data_generator(),
    shuffle=True,
    batch_size=8,  # Adjust as needed
    length=len(data),  # Set to the number of PDF files
    transforms=[
        kd.data.Elements(keep=["text"]),
        # Add other relevant transformations for text data
    ],
)

First we need a data pipeline. Multiple pipelines are supported including:

* [HuggingFace](https://kauldron.readthedocs.io/en/latest/api/kd/data/py/HuggingFace.html)
* [TFDS](https://kauldron.readthedocs.io/en/latest/api/kd/data/py/Tfds.html)
* [Json](https://kauldron.readthedocs.io/en/latest/api/kd/data/py/Json.html)
* ...

It's quite simple to add your own data, or to create mixtures from multiple sources. See the [pipeline documentation](https://kauldron.readthedocs.io/en/latest/data_py.html).

We use `transforms` to customize the data pipeline. This includes:

* Tokenizing the inputs (with `gm.data.Tokenize`)
* Creating the model inputs (with `gm.data.Seq2SeqTask`)
* Adding padding (with `gm.data.Pad`) (required to batch inputs with different lengths)
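To make the steps above concrete, here is a toy, pure-Python sketch of what such transforms produce for a single example. It is not the implementation of `gm.data.Seq2SeqTask` or `gm.data.Pad`: the function name `make_seq2seq_example`, the `pad_id=0` default and the token ids are all made up for illustration, and the outputs are flat lists rather than batched arrays.

```python
def make_seq2seq_example(prompt_ids, response_ids, max_length, pad_id=0):
    """Toy version of the steps: build input/target/mask, truncate, pad."""
    sequence = list(prompt_ids) + list(response_ids)
    # Next-token prediction: the target at position i is the token at i + 1.
    inputs = sequence[:-1]
    targets = sequence[1:]
    # Mark response tokens, then shift by one so the mask lines up with `targets`.
    is_response = [False] * len(prompt_ids) + [True] * len(response_ids)
    loss_mask = is_response[1:]

    # Truncate to max_length, then pad so every example has the same length.
    inputs = inputs[:max_length]
    targets = targets[:max_length]
    loss_mask = loss_mask[:max_length]
    n_pad = max_length - len(inputs)
    return {
        "input": inputs + [pad_id] * n_pad,
        "target": targets + [pad_id] * n_pad,
        "loss_mask": loss_mask + [False] * n_pad,  # Padding never counts in the loss.
    }


example = make_seq2seq_example(
    prompt_ids=[2, 10, 11],    # Made-up token ids, not real Gemma tokens.
    response_ids=[20, 21, 1],
    max_length=8,
)
print(example)
```

Because every example is padded to the same `max_length`, examples of different original lengths can be stacked into one batch.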

Note that in practice, you can combine multiple transforms into a higher level transform. See the `gm.data.ContrastiveTask()` transform in the [DPO example](https://github.com/google-deepmind/gemma/tree/main/examples/dpo.py) for an example.

Here, we try [mtnt](https://www.tensorflow.org/datasets/catalog/mtnt), a small translation dataset. The dataset structure is `{'src': ..., 'dst': ...}`.

In [None]:
# ds = kd.data.py.Tfds(
#     name='mtnt/en-fr',
#     split='train',
#     shuffle=True,
#     batch_size=8,
#     transforms=[
#         # Create the model inputs/targets/loss_mask.
#         gm.data.Seq2SeqTask(
#             # Select which field from the dataset to use.
#             # https://www.tensorflow.org/datasets/catalog/mtnt
#             in_prompt='src',
#             in_response='dst',
#             # Output batch is {'input': ..., 'target': ..., 'loss_mask': ...}
#             out_input='input',
#             out_target='target',
#             out_target_mask='loss_mask',
#             tokenizer=tokenizer,
#             # Padding parameters
#             max_length=200,
#             truncate=True,
#         ),
#     ],
# )

# ex = ds[0]

# treescope.show(ex)

Disabling pygrain multi-processing (unsupported in colab).
{
    'input': i64[8 200],
    'loss_mask': bool_[8 200 1],
    'target': i64[8 200 1],
}


We can decode an example from the batch to inspect the model input. We see that the `<start_of_turn>` / `<end_of_turn>` tokens were correctly added to follow Gemma's dialog format.

In [None]:
text = tokenizer.decode(ex['input'][0])

print(text)

<start_of_turn>user
Would love any other tips from anyone, but specially from someone who’s been where I’m at.<end_of_turn>
<start_of_turn>model
J'apprécierais vraiment d'autres astuces, mais particulièrement par quelqu'un qui était était déjà là où je me trouve.
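
The layout of the decoded text above can be reproduced with plain string formatting. This is only a sketch of the visible format: the helper `format_turns` is hypothetical, and it mirrors the decoded model input exactly as printed, which shows no closing `<end_of_turn>` after the model's reply.

```python
def format_turns(prompt: str, response: str) -> str:
    """Hypothetical helper mirroring the decoded model input shown above."""
    return (
        f"<start_of_turn>user\n{prompt}<end_of_turn>\n"
        f"<start_of_turn>model\n{response}"
    )


formatted = format_turns("Hello!", "Bonjour !")
print(formatted)
```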


## Trainer

The [kauldron](https://kauldron.readthedocs.io/en/latest/) trainer lets you train Gemma simply by providing a dataset, a model, a loss and an optimizer.

The dataset, model and losses are connected through a system of `key` strings. For more information, see the [key documentation](https://kauldron.readthedocs.io/en/latest/intro.html#keys-and-context).

Each key starts with a registered prefix. Common prefixes include:

* `batch`: The output of the dataset (after all transformations). Here our batch is `{'input': ..., 'loss_mask': ..., 'target': ...}`
* `preds`: The output of the model. For Gemma models, this is `gm.nn.Output(logits=..., cache=...)`
* `params`: Model parameters (can be used to add a weight decay loss, or monitor the params norm in metrics)
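
As a rough mental model, a key such as `batch.input` can be read as a dotted path into a context that holds the batch, the predictions and the params. The sketch below is a toy illustration of that idea, not Kauldron's implementation: `resolve_key` and the contents of `context` are hypothetical.

```python
from types import SimpleNamespace


def resolve_key(context: dict, key: str):
    """Toy resolver: follow a dotted path through dict items or attributes."""
    prefix, *path = key.split(".")
    value = context[prefix]
    for name in path:
        value = value[name] if isinstance(value, dict) else getattr(value, name)
    return value


context = {
    # Placeholder values standing in for real arrays.
    "batch": {"input": [[2, 10, 11]], "target": [[10, 11, 20]]},
    "preds": SimpleNamespace(logits="<logits array>", cache=None),
}

print(resolve_key(context, "batch.input"))
print(resolve_key(context, "preds.logits"))
```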






In [17]:
model = gm.nn.Gemma3_4B(
    tokens="batch.input",
)

In [18]:
loss = kd.losses.SoftmaxCrossEntropyWithIntLabels(
    logits="preds.logits",
    labels="batch.target",
    mask="batch.loss_mask",
)
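
For intuition about what a masked softmax cross-entropy computes, here is a pure-Python sketch on toy values. It is not the code behind `kd.losses.SoftmaxCrossEntropyWithIntLabels`; in particular, averaging over the unmasked positions is an assumption about the normalization, chosen here because it is a common convention.

```python
import math


def masked_cross_entropy(logits, labels, mask):
    """Mean of -log softmax(logits)[label] over positions where mask is True."""
    total, count = 0.0, 0
    for row, label, keep in zip(logits, labels, mask):
        if not keep:
            continue  # Masked positions (prompt, padding) are ignored.
        # log-sum-exp, shifted by the max for numerical stability.
        row_max = max(row)
        log_z = row_max + math.log(sum(math.exp(x - row_max) for x in row))
        total += log_z - row[label]
        count += 1
    return total / max(count, 1)


toy_loss = masked_cross_entropy(
    logits=[[0.0, 0.0], [5.0, -5.0], [1.0, 2.0]],
    labels=[0, 1, 1],
    mask=[True, False, True],  # The second position does not contribute.
)
print(toy_loss)
```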

We then create the trainer:

In [19]:
trainer = kd.train.Trainer(
    seed=42,  # The seed of enlightenment
    workdir='/tmp/ckpts',  # TODO(epot): Make the workdir optional by default
    # Dataset
    train_ds=ds,
    # Model
    model=model,
    init_transform=gm.ckpts.LoadCheckpoint(  # Load the weights from the pretrained checkpoint
        path=gm.ckpts.CheckpointPath.GEMMA3_4B_IT,
    ),
    # Training parameters
    num_train_steps=300,
    train_losses={"loss": loss},
    optimizer=optax.adafactor(learning_rate=1e-3),
)

Training can be launched with the `.train()` method.

Note that the trainer, like the model, is immutable, so it stores neither the state nor the params. Instead, the state containing the trained parameters is returned.
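
The pattern can be illustrated with frozen dataclasses. This is a toy sketch of the general idea only: `ToyTrainer` and `ToyState` are hypothetical and unrelated to Kauldron's classes, and unlike the real `trainer.train()` (which takes no arguments here and returns `(state, aux)`), the toy version takes an explicit initial state.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class ToyState:
    step: int
    params: tuple


@dataclasses.dataclass(frozen=True)
class ToyTrainer:
    num_train_steps: int
    learning_rate: float

    def train(self, init: ToyState) -> ToyState:
        """Return a new state; neither the trainer nor `init` is modified."""
        state = init
        for _ in range(self.num_train_steps):
            # Pretend the gradient equals the parameter (minimizes 0.5 * p**2).
            new_params = tuple(p - self.learning_rate * p for p in state.params)
            state = dataclasses.replace(
                state, step=state.step + 1, params=new_params
            )
        return state


toy_trainer = ToyTrainer(num_train_steps=3, learning_rate=0.5)
init_state = ToyState(step=0, params=(8.0,))
final_state = toy_trainer.train(init_state)
print(init_state, final_state)
```

The trained values only exist in the returned object, which is why the result of `.train()` has to be captured.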

In [20]:
state, aux = trainer.train()

TypeError: object of type 'generator' has no len()
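
The `TypeError` above is consistent with `data_source` receiving the generator returned by `pdf_data_generator()`: generators support neither `len()` nor indexing. Below is a minimal sketch of a random-access alternative. It assumes, without confirmation from this notebook, that `kd.data.py.DataSource` accepts any object exposing `__len__` and `__getitem__`; the class `PdfTextSource` and the sample records are hypothetical.

```python
class PdfTextSource:
    """Hypothetical random-access wrapper: supports len() and integer indexing."""

    def __init__(self, records):
        # Materialize the records so they can be counted and indexed.
        self._records = list(records)

    def __len__(self):
        return len(self._records)

    def __getitem__(self, index):
        return self._records[index]


def record_generator(records):
    yield from records


records = [{"text": "first document"}, {"text": "second document"}]

# A generator reproduces the error message seen above.
try:
    len(record_generator(records))
except TypeError as e:
    error_message = str(e)

# Wrapping the same records gives an object with a length and indexing.
source = PdfTextSource(record_generator(records))
print(error_message, len(source), source[1])
```

Since `data` is already a list, passing it directly in place of the generator may be enough. Note also that the model and loss above read `batch.input`, `batch.target` and `batch.loss_mask`, while the PDF pipeline only keeps a `text` field, so a tokenizing transform would still be needed before training.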

## Checkpointing

To save the model params, you can either:

* Activate checkpointing in the trainer by adding:

  ```python
  trainer = kd.train.Trainer(
      workdir='/tmp/my_experiment/',
      checkpointer=kd.ckpts.Checkpointer(
          save_interval_steps=500,
      ),
      ...
  )
  ```

  This will also save the optimizer, step, dataset state, ...


* Manually save the trained params:

  ```python
  gm.ckpts.save_params(state.params, '/tmp/my_ckpt/')
  ```

## Evaluation

Here, we only perform a qualitative evaluation by sampling a prompt.

For more info on evals:

* See the [sampling](https://gemma-llm.readthedocs.io/en/latest/sampling.html) tutorial for more info on running inference.
* To add evals during training, see the Kauldron [evaluator](https://kauldron.readthedocs.io/en/latest/eval.html) documentation.


In [None]:
sampler = gm.text.ChatSampler(
    model=model,
    params=state.params,
    tokenizer=tokenizer,
)

We test a sentence, using the same formatting used during fine-tuning:

In [None]:
sampler.chat('Hello! My next holidays are in Paris.')

'Salut ! Mes vacances suivantes seront à Paris.'

The model correctly translated our prompt to French!

## Next steps

To fine-tune outside of Colab, you can look at our [examples/](https://github.com/google-deepmind/gemma/tree/main/examples/) folder for more complex trainer configs, including LoRA, DPO and sharding.