
# Fine-tuning Gemma to understand Dutch numerals 

## Introduction

Here I demonstrate fine-tuning the Gemma large language model with Low-Rank Adaptation (LoRA), along the lines described in this article:
<a target="_blank" href="https://ai.google.dev/gemma/docs/lora_tuning"><img src="https://ai.google.dev/static/site-assets/images/docs/notebook-site-button.png" height="15" width="15" />Fine-tune Gemma models in Keras using LoRA</a>.

As an example task I use Dutch language understanding: converting Dutch numerals written out in words into digits. Gemma performs poorly on this out of the box. Not only are the words Dutch, but Dutch numerals put the units before the tens, and several number words are strung together into a single word in various ways.

For example, the number 34 is written as vierendertig (vier = 4, en = and, dertig = 30), and 483 is written as vierhonderddrieentachtig (vier = 4, honderd = 100, drie = 3, en = and, tachtig = 80). Note that there is no 'en' between 'vierhonderd' and 'drieentachtig'.
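The word order described above can be captured in a few rules. The helper below is a hypothetical sketch (it is not part of this notebook, and the dataset was not necessarily generated this way) that spells out 0-999 the way the dataset does: the unit comes before the tens, joined by `en`; the hundreds are prefixed directly without `en`; the diaeresis is omitted (standard spelling is *drieëntachtig*); and 100-199 use `eenhonderd`, as in the test prompt `eenhonderdachtenveertig`.

```python
# Hypothetical helper: spell out an integer 0-999 as a Dutch numeral in the
# style of this notebook's dataset (no diaeresis, "eenhonderd" for 100-199).
UNITS = ["nul", "een", "twee", "drie", "vier", "vijf", "zes", "zeven", "acht", "negen"]
TEENS = ["tien", "elf", "twaalf", "dertien", "veertien", "vijftien",
         "zestien", "zeventien", "achttien", "negentien"]
TENS = {2: "twintig", 3: "dertig", 4: "veertig", 5: "vijftig",
        6: "zestig", 7: "zeventig", 8: "tachtig", 9: "negentig"}


def below_hundred(n: int) -> str:
    """Spell out 1-99: the unit comes first, joined to the tens by 'en'."""
    if n < 10:
        return UNITS[n]
    if n < 20:
        return TEENS[n - 10]
    tens, unit = divmod(n, 10)
    if unit == 0:
        return TENS[tens]
    return UNITS[unit] + "en" + TENS[tens]  # 34 -> vier + en + dertig


def to_dutch(n: int) -> str:
    """Spell out 0-999; no 'en' between the hundreds and the rest."""
    if not 0 <= n <= 999:
        raise ValueError("only 0-999 is supported")
    if n == 0:
        return "nul"
    hundreds, rest = divmod(n, 100)
    head = UNITS[hundreds] + "honderd" if hundreds else ""
    tail = below_hundred(rest) if rest else ""
    return head + tail


for n in (34, 483, 924, 17):
    print(n, to_dutch(n))
```

Going the other way, from words back to digits, is exactly the task the model has to learn.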


## The Code

### Install dependencies

Install Keras and KerasNLP. This only needs to be done once, so the commands are commented out here.

In [1]:
# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
#!pip install -q -U keras-nlp
#!pip install -q -U "keras>=3"

### Select a backend

Keras is a high-level, multi-framework deep learning API designed for simplicity and ease of use. Using Keras 3, you can run workflows on one of three backends: TensorFlow, JAX, or PyTorch.

For this demo, we configure the backend for JAX.

In [2]:
import os

os.environ["KERAS_BACKEND"] = "jax"  # Or "torch" or "tensorflow".
# Let JAX preallocate all GPU memory to avoid memory fragmentation.
# Both variables must be set before Keras is imported.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00"

### Import packages

Import Keras and KerasNLP.

In [3]:
import keras
import keras_nlp

### Load data

In [4]:
import json

data = []      # training strings: prompt plus expected response
testdata = []  # test prompts: response left blank for the model to fill in

# Format each example as a single string
template = "Instruction:\n{instruction}\n\nResponse:\n{response}"
testtemplate = "Instruction:\n{instruction}\n\nResponse:\n"

with open("/kaggle/input/meergetallen/meergetallendata.jsonl") as file:
    for line in file:
        features = json.loads(line)
        data.append(template.format(**features))
        testdata.append(testtemplate.format(**features))

print('all cases', len(data))
# Use the first 100 examples for training
data = data[:100]
# Use the first 23 prompts for testing (these overlap the training examples)
testdata = testdata[:23]

print('first 100 cases', len(data))
print(data[10])

all cases 1000
first 100 cases 100
Instruction:
vierhonderddrieentachtig

Response:
483
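Each JSON line becomes two strings: a training string with the answer after `Response:`, and a test prompt that stops right after `Response:`. Because `data[:100]` and `testdata[:23]` are both cut from the start of the same file, the 23 test prompts are also training examples. The sketch below uses a few records copied from examples in this notebook and an illustrative split size; it shows the same formatting, with the test prompts taken from records that are not used for training.

```python
import json

# A few records in the notebook's JSONL format, copied from examples shown here.
lines = [
    '{"instruction": "zeshonderdzeven", "response": "607"}',
    '{"instruction": "zeventien", "response": "17"}',
    '{"instruction": "negenendertig", "response": "39"}',
    '{"instruction": "achthonderd", "response": "800"}',
]

template = "Instruction:\n{instruction}\n\nResponse:\n{response}"
testtemplate = "Instruction:\n{instruction}\n\nResponse:\n"

records = [json.loads(line) for line in lines]

# Hold out the last records so that no test prompt appears in training.
n_train = 3  # illustrative; the notebook trains on 100 examples
train = [template.format(**r) for r in records[:n_train]]
# str.format ignores the unused "response" key in the test template.
test_prompts = [testtemplate.format(**r) for r in records[n_train:]]
test_answers = [r["response"] for r in records[n_train:]]

print(train[0])
print(test_prompts, test_answers)
```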


### Load model

We create a model using the Keras implementation of `GemmaCausalLM`, an end-to-end Gemma model for causal language modeling. A causal language model predicts the next token based on previous tokens.
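To make the idea concrete, here is a toy sketch of causal decoding: at each step the model looks at the tokens so far and proposes the next one, which is appended, until an end token appears or a length cap is reached. The two lookup tables are invented stand-ins for a model, tokens are whole words rather than Gemma's subword tokens, decoding is greedy rather than the top-k sampling used later, and the cap counts prompt and generated tokens together, which is assumed to match how `max_length` is used with `generate` below.

```python
from typing import Callable, List


def generate(next_token: Callable[[List[str]], str], prompt: List[str],
             max_length: int, end_token: str = "<eos>") -> List[str]:
    """Greedy causal decoding: repeatedly append the predicted next token.

    Stops when the model emits `end_token` or when the whole sequence
    (prompt + generated tokens) reaches `max_length` tokens.
    """
    tokens = list(prompt)
    while len(tokens) < max_length:
        token = next_token(tokens)  # depends only on the tokens seen so far
        if token == end_token:
            break
        tokens.append(token)
    return tokens


# Two toy "models" that look only at the last token (hypothetical tables).
untrained = {"Response:": "I'm", "I'm": "sorry", "sorry": "I'm"}
finetuned = {"Response:": "737", "737": "<eos>"}

prompt = ["Instruction:", "zevenhonderdzevenendertig", "Response:"]
print(generate(lambda t: untrained[t[-1]], prompt, max_length=8))
print(generate(lambda t: finetuned[t[-1]], prompt, max_length=8))
```

The untrained table rambles until it hits the cap, much like the responses before fine-tuning, while the fine-tuned table stops right after the answer. A plausible reason the fine-tuned Gemma also stops cleanly is that its training strings end with an end-of-sequence token added by the preprocessor.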

Create the model using the `from_preset` method:

In [5]:
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_2b_en")
gemma_lm.summary()

normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.


The `from_preset` method instantiates the model from a preset architecture and weights. In the code above, the string `"gemma2_2b_en"` specifies the preset: the 2B variant of Gemma 2, which has about 2.6 billion parameters.



## Inference before fine-tuning

In this section, we query the model with some prompts from the test data to see how it responds.

### Prompts from the first few test examples

Ask the model to convert the Dutch numerals in the first three test prompts into digits.

In [6]:
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)

for prompt in testdata[:3]:
    print('*', prompt, '*')
    print('infer', gemma_lm.generate(prompt, max_length=100))

# sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
# gemma_lm.compile(sampler=sampler)
# print(gemma_lm.generate(prompt, max_length=256))

* Instruction:
zeshonderdzeven

Response:
 *
infer Instruction:
zeshonderdzeven

Response:
210421

1.
I'm sorry to have kept you waiting.

2.
I'm very sorry to hear you had a bad day.

3.
I'm very sorry to hear that you were in such a bad mood.

4.
I'm sorry I didn't get to see you.

5.
Sorry for not answering the phone.
* Instruction:
negenhonderdtwintig

Response:
 *
infer Instruction:
negenhonderdtwintig

Response:
ninetyninety

I am not sure what you are referring to. I think you are asking about the number 19.

If you are asking about the number nine, it would be written in numerals as nine.

If you are asking about the number nineteen, it would be written as nineteen.

I am not sure what you are asking about, but I think you are referring to the number nineteen
* Instruction:
zeventien

Response:
 *
infer Instruction:
zeventien

Response:
zeventien

Instruction:
zeventientien

Response:
zeventientientien

Instruction:
zestien

Response:
zestien

Instruction:
zeventientientientien

Before fine-tuning, the model answers with wrong numbers, unrelated English sentences, or new Instruction/Response pairs, but never with the correct number.
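Sampling also plays a part in these responses. The notebook uses `TopKSampler(k=5, seed=2)`: at each step only the five highest-scoring tokens are candidates, and one of them is drawn at random according to their renormalised probabilities. Below is a minimal pure-Python sketch of that rule with made-up logits; it is not KerasNLP's implementation.

```python
import math
import random


def top_k_sample(logits, k, rng):
    """Sample a token index from the k highest-scoring logits.

    All other tokens get zero probability; the k survivors are
    renormalised with a softmax before sampling.
    """
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    m = max(logits[i] for i in top)  # subtract the max for numerical stability
    weights = [math.exp(logits[i] - m) for i in top]
    return rng.choices(top, weights=weights, k=1)[0]


rng = random.Random(2)                            # fixed seed, like seed=2 above
logits = [2.0, 1.5, 0.3, -1.0, 4.0, 0.1, -3.0]   # hypothetical scores for 7 tokens
draws = [top_k_sample(logits, k=5, rng=rng) for _ in range(1000)]
print(sorted(set(draws)))  # only the 5 best-scoring indices can ever appear
```

With k=1 the rule reduces to greedy decoding. When the model is confident, as it appears to be after fine-tuning, nearly all the probability sits on one token, so sampling rarely changes the answer.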

## LoRA fine-tuning

To generate better responses, we fine-tune the model with Low-Rank Adaptation (LoRA). The data file contains 1000 correct examples of the form
`{"instruction": "negenhonderdvierentwintig" , "response": "924" }`; we train on the first 100 of them.

LoRA freezes the original weights of the LLM and learns a low-rank update for selected weight matrices: each update is the product of two small trainable matrices whose inner dimension is the LoRA rank. The rank therefore controls how expressive and precise the fine-tuning adjustments can be.
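In matrix terms, for a frozen weight W of shape d×k, LoRA learns A (d×r) and B (r×k) and adds x·A·B to the layer's output x·W. The following is a minimal pure-Python sketch with tiny made-up matrices; it leaves out the scaling factor some LoRA implementations apply to the update, and it does not show which Gemma layers KerasNLP adapts. B starts at zero, as in the original LoRA paper, so the adapted layer initially behaves exactly like the pretrained one.

```python
def matmul(a, b):
    """Multiply matrices stored as lists of rows: (n x m) @ (m x p)."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b)))
             for j in range(len(b[0]))]
            for i in range(len(a))]


def lora_forward(x, w, a, b):
    """Return x W + x A B: frozen base output plus the low-rank LoRA update.

    Shapes: x (n x d), W (d x k) frozen, A (d x r) and B (r x k) trainable.
    """
    base = matmul(x, w)
    update = matmul(matmul(x, a), b)  # passes through an r-dimensional bottleneck
    return [[base[i][j] + update[i][j] for j in range(len(base[0]))]
            for i in range(len(base))]


W = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # frozen pretrained weight, d=3, k=2
A = [[0.5], [-0.5], [1.0]]                # trainable, d x r with rank r=1
x = [[1.0, 2.0, 3.0]]                     # one input row

B = [[0.0, 0.0]]                          # trainable, r x k, initialised to zero
print(lora_forward(x, W, A, B))           # identical to x W while B is zero
B = [[0.2, -0.4]]                         # pretend training has updated B
print(lora_forward(x, W, A, B))
```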

A higher rank means more detailed changes are possible, but also means more trainable parameters. A lower rank means less computational overhead, but potentially less precise adaptation. We use a LoRA rank of 4. 
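The trade-off is easy to quantify: the added matrices hold r·(d + k) parameters, against d·k in the frozen weight, so the trainable count grows linearly with the rank. The sketch below uses an illustrative 2304×2304 matrix rather than a shape taken from Gemma; the total that `summary()` reports depends on how many layers are adapted.

```python
def lora_params(d: int, k: int, rank: int) -> int:
    """Trainable parameters LoRA adds to one d x k weight: A (d x r) + B (r x k)."""
    return rank * (d + k)


d = k = 2304  # illustrative square weight matrix
full = d * k
for rank in (1, 4, 16, 64):
    added = lora_params(d, k, rank)
    print(f"rank {rank:>2}: {added:>7,} trainable vs {full:,} frozen "
          f"({100 * added / full:.2f}%)")
```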



In [7]:
# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
gemma_lm.summary()

Enabling LoRA reduces the number of trainable parameters from about 2.6 billion to about 2.9 million, roughly 0.1% of the model's weights.

In [8]:
training = True

if training:

    # Limit the input sequence length to 27 tokens to control memory use.
    gemma_lm.preprocessor.sequence_length = 27
    # Use AdamW (a common optimizer for transformer models).
    optimizer = keras.optimizers.AdamW(
        learning_rate=5e-5,
        weight_decay=0.01,
    )
    # Exclude layernorm and bias terms from decay.
    optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])
    
    gemma_lm.compile(
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        optimizer=optimizer,
        weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
    )
    gemma_lm.fit(data, epochs=20, batch_size=1)

Epoch 1/20
100/100 ━━━━━━━━━━━━━━━━━━━━ 65s 391ms/step - loss: 3.7904 - sparse_categorical_accuracy: 0.2582
Epoch 2/20
100/100 ━━━━━━━━━━━━━━━━━━━━ 40s 214ms/step - loss: 1.8345 - sparse_categorical_accuracy: 0.6163
Epoch 3/20
100/100 ━━━━━━━━━━━━━━━━━━━━ 23s 229ms/step - loss: 1.2319 - sparse_categorical_accuracy: 0.7359
Epoch 4/20
100/100 ━━━━━━━━━━━━━━━━━━━━ 25s 246ms/step - loss: 1.0427 - sparse_categorical_accuracy: 0.7571
Epoch 5/20
100/100 ━━━━━━━━━━━━━━━━━━━━ 27s 269ms/step - loss: 0.8995 - sparse_categorical_accuracy: 0.7584
Epoch 6/20
100/100 ━━━━━━━━━━━━━━━━━━━━ 26s 261ms/step - loss: 0.7620 - sparse_categorical_accuracy: 0.7699
Epoch 7/20
100/100 ━━━━━━━━━━━━━━━━━━━━ 25s 253ms/step - loss: 0.6022 - sparse_categorical_accuracy: 0.7811
Epoch 
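The training cell uses AdamW, which applies weight decay separately from the gradient step, and excludes bias and scale (layer-norm) variables from that decay. A scalar sketch of one textbook AdamW step shows what the exclusion means. The beta and epsilon values are assumed common defaults, and Keras's own implementation may differ in small details such as where epsilon enters.

```python
import math


def adamw_step(p, g, m, v, t, lr=5e-5, weight_decay=0.01,
               beta1=0.9, beta2=0.999, eps=1e-8, decay=True):
    """One textbook AdamW update for a scalar parameter p with gradient g.

    Returns the new (p, m, v). Weight decay is decoupled: it shrinks p
    directly instead of being added to the gradient. Set decay=False for
    parameters excluded from weight decay (e.g. biases, norm scales).
    """
    m = beta1 * m + (1 - beta1) * g      # first-moment estimate
    v = beta2 * v + (1 - beta2) * g * g  # second-moment estimate
    m_hat = m / (1 - beta1 ** t)         # bias correction
    v_hat = v / (1 - beta2 ** t)
    update = m_hat / (math.sqrt(v_hat) + eps)
    if decay:
        update += weight_decay * p       # decoupled weight decay
    return p - lr * update, m, v


# First step (t=1) from p=1.0 with gradient 0.5, using the notebook's lr and decay.
p_decay, _, _ = adamw_step(1.0, 0.5, 0.0, 0.0, t=1)
p_nodecay, _, _ = adamw_step(1.0, 0.5, 0.0, 0.0, t=1, decay=False)
print(p_decay, p_nodecay)
```

On the first step the Adam part moves the parameter by roughly the learning rate against the gradient, and weight decay adds an extra pull of `lr * weight_decay * p` toward zero; excluded variables skip that pull.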

In [9]:
# Save the fine-tuned model (already done in an earlier run, so commented out).
# gemma_lm.save('gemma_dutchnumerals.keras')

# To download the saved model 
from IPython.display import FileLink 
# Provide a download link 
FileLink('./gemma_dutchnumerals.keras')


## Inference after fine-tuning
After fine-tuning, responses follow the instruction provided in the prompt.

### Prompt with numeral for 737

In [10]:
infer = True

if infer:
    # Reuse the fine-tuned model that is still in memory.
    gemma_lmft = gemma_lm
    # To reuse a previously saved model instead, replace the line above with:
    # gemma_lmft = keras.models.load_model('/kaggle/input/d/drj19461/gemma-dutch-numerals-model/gemma_dutchnumerals.keras')

    gemma_lmft.summary()

    # Use the same top-k sampler as before fine-tuning.
    sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
    gemma_lmft.compile(sampler=sampler)

    # Build a prompt with an empty response for the model to complete.
    prompt = template.format(
        instruction="zevenhonderdzevenendertig?",
        response="",
    )

    print(gemma_lmft.generate(prompt, max_length=100))

Instruction:
zevenhonderdzevenendertig?

Response:
737


The model now recognizes the number correctly.

### Prompts with test data

In [11]:
for prompt in testdata:
    print('*', prompt, '*')
    print('infer', gemma_lmft.generate(prompt, max_length=100))




* Instruction:
zeshonderdzeven

Response:
 *
infer Instruction:
zeshonderdzeven

Response:
607
* Instruction:
negenhonderdtwintig

Response:
 *
infer Instruction:
negenhonderdtwintig

Response:
920
* Instruction:
zeventien

Response:
 *
infer Instruction:
zeventien

Response:
17
* Instruction:
vierhonderdzesenzeventig

Response:
 *
infer Instruction:
vierhonderdzesenzeventig

Response:
476
* Instruction:
zeshonderdzesentwintig

Response:
 *
infer Instruction:
zeshonderdzesentwintig

Response:
626
* Instruction:
driehonderddrie

Response:
 *
infer Instruction:
driehonderddrie

Response:
303
* Instruction:
vierhonderdtweeenveertig

Response:
 *
infer Instruction:
vierhonderdtweeenveertig

Response:
442
* Instruction:
achthonderd

Response:
 *
infer Instruction:
achthonderd

Response:
800
* Instruction:
eenhonderdachtenveertig

Response:
 *
infer Instruction:
eenhonderdachtenveertig

Response:
148
* Instruction:
negenendertig

Response:
 *
infer Instruction:
negenendertig

Response:
39
* 

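All ten responses shown above match the expected numbers. Keep in mind that `testdata` holds the first 23 examples of the file, which are also among the 100 training examples, so this shows that the model has learned its training data; prompts taken from the remaining 900 examples would be a fairer test of generalization.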
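Checking ten answers by eye works, but a held-out evaluation needs a score. Here is a minimal sketch of an exact-match metric, assuming each generation echoes the prompt and puts the answer on the first line after `Response:`, as in the outputs above. The example strings are copied from outputs in this notebook, mixing responses from before and after fine-tuning.

```python
def extract_answer(generated: str) -> str:
    """Return the first non-empty text line after the first 'Response:' marker."""
    _, _, after = generated.partition("Response:\n")
    lines = after.strip().splitlines()
    return lines[0].strip() if lines else ""


def exact_match_accuracy(outputs, expected):
    """Fraction of generations whose extracted answer equals the expected digits."""
    hits = sum(extract_answer(o) == e for o, e in zip(outputs, expected))
    return hits / len(expected)


outputs = [
    "Instruction:\nzeshonderdzeven\n\nResponse:\n607",
    "Instruction:\nnegenhonderdtwintig\n\nResponse:\nninetyninety\n\nI am not sure",
    "Instruction:\nzeventien\n\nResponse:\n17",
]
expected = ["607", "920", "17"]
print(exact_match_accuracy(outputs, expected))
```

In the notebook, `outputs` would be the strings returned by `gemma_lmft.generate` for held-out prompts, and `expected` the matching `response` fields.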

## Conclusion

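This example applies LoRA fine-tuning to a Gemma model using KerasNLP.

After fine-tuning on 100 training examples, the model converts every Dutch numeral shown above into the correct number, including the prompt for 737. Because the test prompts overlap with the training data, an evaluation on held-out examples is still needed to confirm that it handles Dutch numerals under 1000 in general.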



# License
Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.