# LoRA (Finetuning)

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google-deepmind/gemma/blob/main/colabs/lora_finetuning.ipynb)

This is an example of fine-tuning Gemma with LoRA. It's best to read the [finetuning](https://gemma-llm.readthedocs.io/en/latest/finetuning.html) colab first, since this one builds on it.

See the [LoRA sampling](https://gemma-llm.readthedocs.io/en/latest/lora_sampling.html) tutorial if you just want to do inference with LoRA.


In [1]:
!pip install -q gemma


In [2]:
# Common imports
import os
import optax
import treescope

# Gemma imports
from kauldron import kd
from gemma import gm

By default, JAX does not use the full GPU memory, but this can be overridden. See [GPU memory allocation](https://docs.jax.dev/en/latest/gpu_memory_allocation.html):

In [3]:
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00"
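As a minimal sketch of the same setting with basic validation: JAX reads this variable when it initializes the GPU backend, so it has to be set before the first JAX computation runs. The helper name `set_gpu_mem_fraction` is hypothetical and not part of JAX or Gemma.

```python
import os


def set_gpu_mem_fraction(fraction: float) -> str:
    """Sets XLA_PYTHON_CLIENT_MEM_FRACTION (hypothetical helper, not a JAX API)."""
    if not 0.0 < fraction <= 1.0:
        raise ValueError(f"fraction must be in (0, 1], got {fraction}")
    value = f"{fraction:.2f}"
    os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = value
    return value


# Call this before the first JAX computation initializes the GPU backend.
set_gpu_mem_fraction(1.0)
```

If other processes share the GPU, a lower fraction leaves room for them.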


## Config updates

If you're familiar with the [finetuning](https://gemma-llm.readthedocs.io/en/latest/finetuning.html) tutorial, switching to LoRA requires only three changes to the trainer.

For an end-to-end example, see the
[lora.py](https://github.com/google-deepmind/gemma/tree/main/examples/lora.py) config.

### 1. Model

Wrap the model in `gm.nn.LoRA`. This applies model surgery that replaces all linear and other compatible layers with LoRA layers.

In [4]:
model = gm.nn.LoRA(
    rank=4,
    model=gm.nn.Gemma3_4B(tokens="batch.input", text_only=True),
)

Internally, this uses the [`gemma.peft`](https://github.com/google-deepmind/gemma/blob/main/gemma/peft) mini-library to perform model surgery.
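To make the idea of a LoRA layer concrete, here is a minimal pure-Python sketch of the general technique, not the `gemma.peft` implementation. It assumes the convention from the original LoRA paper: a frozen weight `w` plus a low-rank update `lora_a @ lora_b`, with `lora_b` initialized to zero. The optional scaling factor is omitted, and all names and sizes are illustrative.

```python
import random


def matmul(a, b):
    """Multiplies two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def lora_linear(x, w, lora_a, lora_b):
    """Computes x @ w + (x @ lora_a) @ lora_b."""
    base = matmul(x, w)
    delta = matmul(matmul(x, lora_a), lora_b)
    return [[p + q for p, q in zip(r1, r2)] for r1, r2 in zip(base, delta)]


rng = random.Random(0)
d_in, d_out, rank = 64, 64, 4  # Illustrative sizes, not Gemma's.

w = [[rng.gauss(0, 1) for _ in range(d_out)] for _ in range(d_in)]  # Frozen.
lora_a = [[rng.gauss(0, 0.01) for _ in range(rank)] for _ in range(d_in)]  # Trained.
lora_b = [[0.0] * d_out for _ in range(rank)]  # Trained, starts at zero.
x = [[rng.gauss(0, 1) for _ in range(d_in)]]

# With lora_b at zero, the wrapped layer starts out identical to the frozen one.
assert lora_linear(x, w, lora_a, lora_b) == matmul(x, w)

full_params = d_in * d_out           # Weights in the frozen matrix.
lora_params = rank * (d_in + d_out)  # Weights in the two low-rank factors.
print(f"trainable: {lora_params} of {full_params + lora_params}")
```

The saving grows with layer size: the low-rank factors scale with `d_in + d_out`, while the frozen matrix scales with `d_in * d_out`.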

### 2. Checkpoint

Wrap the init transform in `gm.ckpts.SkipLoRA`. The wrapper is required because the param structure differs with and without LoRA: the pretrained checkpoint contains no LoRA weights.

Only the pretrained weights are loaded from the checkpoint; the LoRA weights keep their random initialization.

In [5]:
init_transform = gm.ckpts.SkipLoRA(
    wrapped=gm.ckpts.LoadCheckpoint(
        path=gm.ckpts.CheckpointPath.GEMMA3_4B_IT,
    ),
)

Note: If you're loading the weights directly with `gm.ckpts.load_params`, you can use `peft.split_params` and `peft.merge_params` instead. See [LoRA sampling](https://gemma-llm.readthedocs.io/en/latest/lora_sampling.html) for an example.
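The structure mismatch can be illustrated with plain nested dicts. This is a sketch of the split-then-merge idea only. The helpers `split_lora` and `merge` are hypothetical stand-ins, the `"lora"` key and tree layout are assumptions, and the real `peft.split_params` / `peft.merge_params` signatures are not reproduced here.

```python
def split_lora(params):
    """Splits a nested dict into (original, lora) trees based on the 'lora' key."""
    original, lora = {}, {}
    for key, value in params.items():
        if key == "lora":
            lora[key] = value
        elif isinstance(value, dict):
            sub_original, sub_lora = split_lora(value)
            original[key] = sub_original
            if sub_lora:
                lora[key] = sub_lora
        else:
            original[key] = value
    return original, lora


def merge(original, lora):
    """Recursively merges the LoRA tree back into the original tree."""
    merged = dict(original)
    for key, value in lora.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge(merged[key], value)
        else:
            merged[key] = value
    return merged


# Placeholder strings stand in for arrays.
init_params = {"layer_0": {"q_proj": {"w": "random", "lora": {"a": "random", "b": "random"}}}}
checkpoint = {"layer_0": {"q_proj": {"w": "pretrained"}}}  # No LoRA entries.

original, lora = split_lora(init_params)
# The checkpoint matches the non-LoRA part of the tree, so it can replace it.
restored = merge(checkpoint, lora)
print(restored)
```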

### 3. Optimizer

Finally, we add a mask to the optimizer (with `kd.optim.partial_updates`), so only the LoRA weights are trained.

In [6]:
optimizer = kd.optim.partial_updates(
    optax.adafactor(learning_rate=0.005),
    # We only optimize the LoRA weights. The rest of the model is frozen.
    mask=kd.optim.select("lora"),
)
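The effect of such a mask can be sketched in pure Python. Plain SGD is swapped in for Adafactor to keep the example short, and selecting leaves by a `"lora"` key in a nested dict is an assumption about the param layout. The actual matching logic of `kd.optim.select` is not reproduced, and both helper names are hypothetical.

```python
def lora_mask(params, inside_lora=False):
    """Returns a tree of booleans: True for every leaf under a 'lora' key."""
    mask = {}
    for key, value in params.items():
        selected = inside_lora or key == "lora"
        mask[key] = lora_mask(value, selected) if isinstance(value, dict) else selected
    return mask


def masked_sgd(params, grads, mask, lr):
    """Applies an SGD step where mask is True; other leaves are left unchanged."""
    updated = {}
    for key, value in params.items():
        if isinstance(value, dict):
            updated[key] = masked_sgd(value, grads[key], mask[key], lr)
        else:
            updated[key] = value - lr * grads[key] if mask[key] else value
    return updated


# Scalars stand in for weight arrays.
params = {"q_proj": {"w": 1.0, "lora": {"a": 0.5, "b": 0.0}}}
grads = {"q_proj": {"w": 1.0, "lora": {"a": 1.0, "b": 1.0}}}

mask = lora_mask(params)
new_params = masked_sgd(params, grads, mask, lr=0.25)
print(new_params)  # "w" keeps its value even though its gradient is non-zero.
```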

## Training

### Data pipeline

As in the [finetuning](https://gemma-llm.readthedocs.io/en/latest/finetuning.html) example, we recreate the tokenizer:

In [7]:
tokenizer = gm.text.Gemma3Tokenizer()

tokenizer.encode('This is an example sentence', add_bos=True)

[<_Gemma3SpecialTokens.BOS: 2>, 2094, 563, 614, 2591, 13315]

And the data pipeline:

In [8]:
ds = kd.data.py.Tfds(
    name='mtnt/en-fr',
    split='train',
    shuffle=True,
    batch_size=8,
    transforms=[
        # Create the model inputs/targets/loss_mask.
        gm.data.Seq2SeqTask(
            # Select which field from the dataset to use.
            # https://www.tensorflow.org/datasets/catalog/mtnt
            in_prompt='src',
            in_response='dst',
            # Output batch is {'input': ..., 'target': ..., 'loss_mask': ...}
            out_input='input',
            out_target='target',
            out_target_mask='loss_mask',
            tokenizer=tokenizer,
            # Padding parameters
            max_length=200,
            truncate=True,
        ),
    ],
)

ex = ds[0]

treescope.show(ex)



Dataset mtnt downloaded and prepared to /root/tensorflow_datasets/mtnt/en-fr/1.0.0. Subsequent calls will reuse this data.
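The `input`, `target` and `loss_mask` fields produced by the pipeline above follow the usual next-token-prediction layout. The sketch below shows that layout in pure Python under stated assumptions: the token ids are made up, the helper names are hypothetical, and details of the real `gm.data.Seq2SeqTask` (special tokens, padding id) may differ.

```python
def pad_or_truncate(values, max_length, fill):
    """Cuts `values` to `max_length`, or right-pads it with `fill`."""
    values = values[:max_length]
    return values + [fill] * (max_length - len(values))


def make_seq2seq_fields(prompt, response, max_length, pad_id=0):
    """Builds input/target/loss_mask from prompt and response token ids."""
    assert prompt, "prompt must contain at least one token"
    sequence = prompt + response
    # Position i of `inputs` is used to predict position i of `targets`.
    inputs = sequence[:-1]
    targets = sequence[1:]
    # Only response tokens contribute to the loss; prompt tokens are masked out.
    loss_mask = [0] * (len(prompt) - 1) + [1] * len(response)
    return {
        "input": pad_or_truncate(inputs, max_length, pad_id),
        "target": pad_or_truncate(targets, max_length, pad_id),
        "loss_mask": pad_or_truncate(loss_mask, max_length, 0),
    }


# Made-up token ids, for illustration only.
example = make_seq2seq_fields(prompt=[2, 10, 11, 12], response=[20, 21, 1], max_length=8)
for name, values in example.items():
    print(name, values)
```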


We can decode an example from the batch to inspect the model input and check it is properly formatted:

In [9]:
text = tokenizer.decode(ex['input'][0])

print(text)

<start_of_turn>user
*From Russia with love light snowing yesterday and -12C in the morning today*<end_of_turn>
<start_of_turn>model
* Bons baisers de Russie avec une neige légère hier et une température de -12°C ce matin*
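The decoded text above shows the turn structure the model is trained on. As a sketch, a hypothetical helper that reproduces that layout could look like this. The library applies this formatting itself, so the function is only for illustration.

```python
def format_turns(prompt: str, response: str = "") -> str:
    """Formats a user prompt and a model response with the turn markers shown above."""
    return (
        f"<start_of_turn>user\n{prompt}<end_of_turn>\n"
        f"<start_of_turn>model\n{response}"
    )


print(format_turns("Hello!", "Bonjour !"))
```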


### Trainer

We then create the trainer, reusing the `model`, `init_transform` and `optimizer` created above:

In [10]:
trainer = kd.train.Trainer(
    seed=42,  # The seed of enlightenment
    workdir='/tmp/ckpts',  # TODO(epot): Make the workdir optional by default
    # Dataset
    train_ds=ds,
    # Model
    model=model,
    init_transform=init_transform,
    # Training parameters
    num_train_steps=500,
    train_losses={
        "loss": kd.losses.SoftmaxCrossEntropyWithIntLabels(
            logits="preds.logits",
            labels="batch.target",
            mask="batch.loss_mask",
        ),
    },
    optimizer=optimizer,
)
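The loss configured above combines integer-label cross-entropy with the `loss_mask` from the data pipeline. The following pure-Python sketch shows the general computation. Averaging over the masked positions is an assumption, and Kauldron's exact normalization may differ.

```python
import math


def masked_softmax_xent(logits, labels, mask):
    """Mean cross-entropy with integer labels over positions where mask is 1."""
    total, count = 0.0, 0
    for row, label, keep in zip(logits, labels, mask):
        if not keep:
            continue
        # Numerically stable log-sum-exp.
        peak = max(row)
        log_z = peak + math.log(sum(math.exp(v - peak) for v in row))
        total += log_z - row[label]
        count += 1
    return total / max(count, 1)


# Three positions with a vocabulary of two; the first position is masked out.
logits = [[0.0, 0.0], [2.0, 0.0], [0.0, 5.0]]
labels = [0, 0, 0]
mask = [0, 1, 1]
loss = masked_softmax_xent(logits, labels, mask)
print(loss)
```

Masked positions (prompt tokens and padding) do not affect the result, whatever their logits are.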

Training can be launched with the `.train()` method.

Note that the trainer, like the model, is immutable: it stores neither the state nor the params. Instead, `.train()` returns the state containing the trained parameters.

In [11]:
state, aux = trainer.train()

Disabling pygrain multi-processing (unsupported in colab).




Starting training loop at step 0


train:   0%|          | 0/501 [00:00<?, ?it/s]

XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 1694498816 bytes.

## Checkpointing

In [None]:
# TODO(epot): Doc on:
# * saving only LoRA params
# * Fuse params

## Evaluation

Here, we only perform a qualitative evaluation by sampling a prompt.

For more thorough evaluation:

* See the [sampling](https://gemma-llm.readthedocs.io/en/latest/sampling.html) tutorial for details on running inference.
* To add evals during training, see the Kauldron [evaluator](https://kauldron.readthedocs.io/en/latest/eval.html) documentation.


In [None]:
sampler = gm.text.ChatSampler(
    model=model,
    params=state.params,
    tokenizer=tokenizer,
)

We test a sentence, using the same formatting used during fine-tuning:

In [None]:
sampler.chat("I'm feeling happy today!")

"Je me sens heureux aujourd'hui !"

The model correctly translated our prompt to French!