
# Fine-tuning Gemma to understand Dutch numerals 

## Introduction

Here I demonstrate fine-tuning the Gemma large language model using Low Rank Adaptation (LoRA), along the lines described in this article:
<a target="_blank" href="https://ai.google.dev/gemma/docs/lora_tuning"><img src="https://ai.google.dev/static/site-assets/images/docs/notebook-site-button.png" height="15" width="15" />Fine-tune Gemma models in Keras using LoRA</a>.

My example is a Dutch language understanding task: interpreting Dutch numerals written out in words. Gemma performs poorly on this fairly complex problem. Not only are the words in Dutch, but Dutch numerals also put their parts in a different order than English ones, and several words can be strung together in various ways.

For example, the number 34 is written as vierendertig (vier = 4, en = and, dertig = 30). The number 483 is written as vierhonderddrieentachtig (vier = 4, honderd = 100, drie = 3, en = and, tachtig = 80, with no 'en' between 'vierhonderd' and 'drieentachtig').
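The composition rule can be sketched in a few lines of Python. The function `dutch_numeral` below is a hypothetical helper, not part of the original notebook. It assumes the spelling seen in the dataset samples (an explicit `een` before `honderd`, and no diaeresis in forms such as `drieentachtig`) and only covers 1 to 999.

```python
UNITS = ["", "een", "twee", "drie", "vier", "vijf", "zes", "zeven", "acht", "negen"]
TEENS = {
    10: "tien", 11: "elf", 12: "twaalf", 13: "dertien", 14: "veertien",
    15: "vijftien", 16: "zestien", 17: "zeventien", 18: "achttien", 19: "negentien",
}
TENS = {
    2: "twintig", 3: "dertig", 4: "veertig", 5: "vijftig",
    6: "zestig", 7: "zeventig", 8: "tachtig", 9: "negentig",
}


def dutch_numeral(n: int) -> str:
    """Spell 1..999 in Dutch, using the dataset's spelling (assumed)."""
    if not 1 <= n <= 999:
        raise ValueError("only 1..999 is supported in this sketch")
    hundreds, rest = divmod(n, 100)
    # Hundreds come first, with no 'en' after them.
    words = UNITS[hundreds] + "honderd" if hundreds else ""
    if rest == 0:
        return words
    if rest < 10:
        return words + UNITS[rest]
    if rest < 20:
        return words + TEENS[rest]
    tens, unit = divmod(rest, 10)
    # Units precede tens, joined by 'en': 34 -> vier + en + dertig.
    if unit:
        words += UNITS[unit] + "en"
    return words + TENS[tens]


print(dutch_numeral(34))   # vierendertig
print(dutch_numeral(483))  # vierhonderddrieentachtig
```

The units-before-tens order is what makes the task hard for a model trained mostly on English, where the digits appear in reading order.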


## Background

Gemma is a family of lightweight, state-of-the-art open models built from the same research and technology used to create the Gemini models.

Large Language Models (LLMs) like Gemma have been shown to be effective at a variety of NLP tasks. An LLM is first pre-trained on a large corpus of text in a self-supervised fashion. Pre-training helps LLMs learn general-purpose knowledge, such as statistical relationships between words. An LLM can then be fine-tuned with domain-specific data to perform downstream tasks (such as sentiment analysis).

LLMs are extremely large (on the order of billions of parameters). Full fine-tuning, which updates all the parameters in the model, is not required for most applications because typical fine-tuning datasets are much smaller than the pre-training datasets.

[Low Rank Adaptation (LoRA)](https://arxiv.org/abs/2106.09685) is a fine-tuning technique which greatly reduces the number of trainable parameters for downstream tasks by freezing the weights of the model and inserting a smaller number of new weights into the model. This makes training with LoRA much faster and more memory-efficient, and produces smaller model weights (a few hundred MBs), all while maintaining the quality of the model outputs.
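The idea can be illustrated with a toy layer in plain Python. This is a minimal sketch of the technique as described in the LoRA paper, not the Keras implementation. The matrix names `W`, `A`, `B` and the tiny dimensions are assumptions chosen for illustration, using a row-vector convention in which `A` projects down to the rank and `B` projects back up.

```python
import random


def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def add(a, b):
    """Add two matrices of the same shape element by element."""
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


def lora_forward(x, W, A, B):
    """Frozen path x @ W plus the trainable low-rank path (x @ A) @ B."""
    return add(matmul(x, W), matmul(matmul(x, A), B))


random.seed(0)
d_in, d_out, rank = 8, 8, 2

W = [[random.gauss(0, 1) for _ in range(d_out)] for _ in range(d_in)]  # frozen
A = [[random.gauss(0, 0.01) for _ in range(rank)] for _ in range(d_in)]  # trainable
B = [[0.0] * d_out for _ in range(rank)]  # trainable, starts at zero
x = [[random.gauss(0, 1) for _ in range(d_in)]]

# With B at zero the adapted layer reproduces the frozen layer exactly.
base = matmul(x, W)
print(lora_forward(x, W, A, B) == base)

# Changing a single entry of B shifts the output; W itself is never modified.
B[0][0] = 1.0
tuned = lora_forward(x, W, A, B)

frozen_count = d_in * d_out
trainable_count = rank * (d_in + d_out)
print(frozen_count, trainable_count)
```

Only `A` and `B` would receive gradient updates during training, which is why the saved adapter weights stay small.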



## The Code

### Install dependencies

Install Keras, KerasNLP, and other dependencies.

In [3]:
# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
%pip install -q -U keras_nlp

%pip install jax jaxlib

%pip install -q -U "keras>=3"

Note: you may need to restart the kernel to use updated packages.




### Select a backend

Keras is a high-level, multi-framework deep learning API designed for simplicity and ease of use. Using Keras 3, you can run workflows on one of three backends: TensorFlow, JAX, or PyTorch.

For this demo, configure the backend for JAX.

In [1]:
import os

os.environ["KERAS_BACKEND"] = "jax"  # Or "torch" or "tensorflow".
# Avoid memory fragmentation on JAX backend.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"

### Import packages

Import Keras and KerasNLP.

In [2]:
import keras
import keras_nlp

RuntimeError: This version of jaxlib was built using AVX instructions, which your CPU and/or operating system do not support. You may be able work around this issue by building jaxlib from source.

### Load data

In [4]:
import json

# Format each example as a single prompt/response string.
template = "Instruction:\n{instruction}\n\nResponse:\n{response}"

data = []
with open("./kaggle/input/meergetallen/meergetallendata.jsonl", encoding="utf-8") as file:
    for line in file:
        # Each line is a JSON object with "instruction" and "response" fields.
        features = json.loads(line)
        data.append(template.format(**features))

print('cases', len(data))
# Use 100 training examples
data = data[:100]

print('cases', len(data))
print(data[10])

cases 1000
cases 100
Instruction:
vierhonderddrieentachtig

Response:
483
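For readers who want to build a similar file, the sketch below writes a few records in the JSON Lines layout that the loading code expects and reads them back through the same template. The field names `instruction` and `response` follow from the `template.format(**features)` call. Storing the response as a string, and the file name `sample.jsonl`, are assumptions made for this example.

```python
import json
import os
import tempfile

# A few records copied from the training examples shown in this notebook.
records = [
    {"instruction": "zeventien", "response": "17"},
    {"instruction": "negenendertig", "response": "39"},
    {"instruction": "vierhonderddrieentachtig", "response": "483"},
]

template = "Instruction:\n{instruction}\n\nResponse:\n{response}"

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "sample.jsonl")  # hypothetical file name

    # One JSON object per line.
    with open(path, "w", encoding="utf-8") as f:
        for record in records:
            f.write(json.dumps(record) + "\n")

    # Read the file back and format each record as a training string.
    with open(path, encoding="utf-8") as f:
        examples = [template.format(**json.loads(line)) for line in f]

print(len(examples))
print(examples[2])
```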


### Load model

We create a model using the Keras implementation of `GemmaCausalLM`, an end-to-end Gemma model for causal language modeling. A causal language model predicts the next token based on previous tokens.

Create the model using the `from_preset` method:

In [5]:
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_2b_en")
gemma_lm.summary()

NameError: name 'keras_nlp' is not defined

The `from_preset` method instantiates the model from a preset architecture and weights. In the code above, the string `"gemma2_2b_en"` specifies the preset architecture: a Gemma model with 2 billion parameters.



## Inference before fine tuning

In this section, you will query the model with various prompts to see how it responds.

### Prompt with numeral for 35

Query the model with the Dutch numeral for the number 35.

In [10]:
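# Build a prompt with an empty response for the model to complete.
prompt = template.format(
    instruction="vijfendertig?",
    response="",
)

# Debug output: a sanity check, the prompt, and the first training examples.
print('ik', 'ben')

print('*', prompt, '*')

for d in range(0, 37):
    print('*', data[d], '*')

# Sample from the 5 most likely tokens at each step; the seed makes runs repeatable.
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=256))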

ik ben
* Instruction:
vijfendertig?

Response:
 *
* Instruction:
zeshonderdzeven

Response:
607 *
* Instruction:
negenhonderdtwintig

Response:
920 *
* Instruction:
zeventien

Response:
17 *
* Instruction:
vierhonderdzesenzeventig

Response:
476 *
* Instruction:
zeshonderdzesentwintig

Response:
626 *
* Instruction:
driehonderddrie

Response:
303 *
* Instruction:
vierhonderdtweeenveertig

Response:
442 *
* Instruction:
achthonderd

Response:
800 *
* Instruction:
eenhonderdachtenveertig

Response:
148 *
* Instruction:
negenendertig

Response:
39 *
* Instruction:
vierhonderddrieentachtig

Response:
483 *
* Instruction:
zevenhonderdnegenenzeventig

Response:
779 *
* Instruction:
tweehonderdvijftig

Response:
250 *
* Instruction:
eenhonderdtachtig

Response:
180 *
* Instruction:
eenhonderdzeventien

Response:
117 *
* Instruction:
negenhonderdachtennegentig

Response:
998 *
* Instruction:
achthonderdvijfenzestig

Response:
865 *
* Instruction:
zevenenvijftig

Response:
57 *
* Instruction:
a

NameError: name 'keras_nlp' is not defined

The model responds with a variety of text, but not with the correct number.

## LoRA fine-tuning

To get better responses from the model, fine-tune the model with Low Rank Adaptation (LoRA) using the Dutch numeral examples loaded above.

The LoRA rank determines the dimensionality of the trainable matrices that are added to the original weights of the LLM. It controls the expressiveness and precision of the fine-tuning adjustments.

A higher rank means more detailed changes are possible, but also means more trainable parameters. A lower rank means less computational overhead, but potentially less precise adaptation.

This tutorial uses a LoRA rank of 4. In practice, begin with a relatively small rank (such as 4, 8, 16). This is computationally efficient for experimentation. Train your model with this rank and evaluate the performance improvement on your task. Gradually increase the rank in subsequent trials and see if that further boosts performance.
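To get a feel for how the rank drives the number of trainable parameters, the sketch below counts them for a single weight matrix. The 2048 x 2048 layer size is a hypothetical value chosen for illustration, not a figure taken from the Gemma architecture, and the helper names are made up for this example.

```python
def full_params(d_in: int, d_out: int) -> int:
    """Parameters in one dense weight matrix of shape (d_in, d_out)."""
    return d_in * d_out


def lora_params(d_in: int, d_out: int, rank: int) -> int:
    """Parameters in the two LoRA factors: (d_in, rank) and (rank, d_out)."""
    return rank * (d_in + d_out)


d_in = d_out = 2048  # hypothetical layer size

for rank in (4, 8, 16):
    trainable = lora_params(d_in, d_out, rank)
    share = trainable / full_params(d_in, d_out)
    print(f"rank={rank:2d}  trainable={trainable:6d}  share of full matrix={share:.2%}")
```

The count grows linearly with the rank, so doubling the rank doubles both the adapter size and the cost of updating it.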

In [None]:
# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
gemma_lm.summary()

Note that enabling LoRA reduces the number of trainable parameters significantly (from 2.6 billion to 2.9 million).

In [None]:
# Limit the input sequence length to 256 (to control memory usage).
gemma_lm.preprocessor.sequence_length = 256
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(data, epochs=20, batch_size=1)

### Note on mixed precision fine-tuning on NVIDIA GPUs

Full precision is recommended for fine-tuning. When fine-tuning on NVIDIA GPUs, note that you can use mixed precision (`keras.mixed_precision.set_global_policy('mixed_bfloat16')`) to speed up training with minimal effect on training quality. Mixed precision fine-tuning does consume more memory, so it is useful only on larger GPUs.


For inference, half-precision (`keras.config.set_floatx("bfloat16")`) will work and save memory while mixed precision is not applicable.

In [None]:
# Uncomment the line below if you want to enable mixed precision training on GPUs
# keras.mixed_precision.set_global_policy('mixed_bfloat16')

## Inference after fine-tuning
After fine-tuning, responses follow the instruction provided in the prompt.

### Prompt with numeral for 737

In [None]:
prompt = template.format(
    instruction="zevenhonderdzevenendertig?",
    response="",
)
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=256))

The model now recognizes the number correctly.

### Prompt with question about numeral for 73

In [None]:
prompt = template.format(
    instruction="Wat is drieenzeventig?",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))

The model does not recognize the number in this question.
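Rather than judging single prompts by eye, one could score a batch of generations. The helpers below are a hypothetical sketch, not part of the original notebook. They assume each generated string contains the prompt followed by the completion, as in the template used here. The two sample strings are made-up placeholders to exercise the helpers, not real model output.

```python
import re


def extract_number(generated: str):
    """Return the first integer after the 'Response:' marker, or None."""
    _, marker, answer = generated.partition("Response:\n")
    if not marker:
        return None
    match = re.search(r"\d+", answer)
    return int(match.group()) if match else None


def exact_match_accuracy(outputs, expected):
    """Fraction of outputs whose extracted number equals the expected one."""
    hits = sum(extract_number(o) == e for o, e in zip(outputs, expected))
    return hits / len(expected)


# Placeholder strings for illustration only (not actual model output).
outputs = [
    "Instruction:\nzevenhonderdzevenendertig?\n\nResponse:\n737",
    "Instruction:\nWat is drieenzeventig?\n\nResponse:\n(placeholder text without a number)",
]
expected = [737, 73]

print(exact_match_accuracy(outputs, expected))  # 0.5
```

Scoring examples that were not among the 100 used for training would give a fairer picture than reusing training prompts.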

## Conclusion

This tutorial covered LoRA fine-tuning on a Gemma model using KerasNLP. Check out the following docs next:

* Learn how to [generate text with a Gemma model](https://ai.google.dev/gemma/docs/get_started).
* Learn how to perform [distributed fine-tuning and inference on a Gemma model](https://ai.google.dev/gemma/docs/distributed_tuning).
* Learn how to [use Gemma open models with Vertex AI](https://cloud.google.com/vertex-ai/docs/generative-ai/open-models/use-gemma).
* Learn how to [fine-tune Gemma using KerasNLP and deploy to Vertex AI](https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_gemma_kerasnlp_to_vertexai.ipynb).

## License

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.