# Fine-tune Gemma models in Keras using LoRA

### Install dependencies

Install Keras, KerasNLP, and other dependencies.

In [2]:
# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
!pip install -q -U keras-nlp
!pip install -q -U "keras>=3"

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
tensorflow-decision-forests 1.8.1 requires wurlitzer, which is not installed.
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
tensorflow-decision-forests 1.8.1 requires wurlitzer, which is not installed.
tensorflow 2.15.0 requires keras<2.16,>=2.15.0, but you have keras 3.3.3 which is incompatible.

### Select a backend

Keras is a high-level, multi-framework deep learning API designed for simplicity and ease of use. Using Keras 3, you can run workflows on one of three backends: TensorFlow, JAX, or PyTorch.

For this tutorial, configure the backend for JAX.

In [3]:
import os

os.environ["KERAS_BACKEND"] = "jax"  # Or "torch" or "tensorflow".
# Avoid memory fragmentation on JAX backend.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00"
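
Keras reads `KERAS_BACKEND` when it is first imported, so the variable has to be set before `import keras` runs. The following is a minimal sketch of a guard for that ordering. `select_backend` and its `loaded_modules` parameter are hypothetical helpers for illustration and are not part of Keras.

```python
import os
import sys

VALID_BACKENDS = ("jax", "tensorflow", "torch")


def select_backend(name, loaded_modules=None):
    """Set KERAS_BACKEND, refusing if Keras has already been imported.

    Hypothetical helper: `loaded_modules` defaults to `sys.modules` and is
    only a parameter so the check can be exercised without importing Keras.
    """
    if loaded_modules is None:
        loaded_modules = sys.modules
    if name not in VALID_BACKENDS:
        raise ValueError(f"Unknown backend {name!r}; expected one of {VALID_BACKENDS}")
    if "keras" in loaded_modules:
        raise RuntimeError("Set the backend before importing keras.")
    os.environ["KERAS_BACKEND"] = name
    return name


# Pass an empty mapping so this demo does not depend on the session state.
select_backend("jax", loaded_modules={})
print(os.environ["KERAS_BACKEND"])
```

In a real notebook you would call `select_backend("jax")` with the default argument, in the first cell, before any Keras import.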

### Import packages

Import Keras and KerasNLP.

In [4]:
import keras
import keras_nlp

2024-05-07 02:04:10.783366: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-05-07 02:04:10.783505: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-05-07 02:04:10.917646: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


## Load Dataset

Preprocess the data. This tutorial uses a subset of 1000 training examples to execute the notebook faster. Consider using more training data for higher quality fine-tuning.

In [5]:
import json

# Format each example as a single string. Defined once, outside the loop,
# because later cells reuse `template` to build prompts.
template = "Instruction:\n{instruction}\n\nResponse:\n{response}"

data = []
with open('/kaggle/input/databricks-dolly-15k/databricks-dolly-15k.jsonl') as file:
    for line in file:
        features = json.loads(line)
        # Filter out examples with context, to keep it simple.
        if features["context"]:
            continue
        data.append(template.format(**features))

# Only use 1000 training examples, to keep it fast.
data = data[:1000]
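
To make the preprocessing concrete, here is a small sketch that applies the same filter and template to two in-memory records instead of the dataset file. The records are made up for illustration and only mimic the JSONL shape that the cell above reads.

```python
import json

template = "Instruction:\n{instruction}\n\nResponse:\n{response}"

# Two made-up records in the same JSONL shape as the dataset file.
sample_lines = [
    json.dumps({"instruction": "Name a primary color.", "context": "",
                "response": "Red.", "category": "open_qa"}),
    json.dumps({"instruction": "Summarize the passage.", "context": "Some passage.",
                "response": "A summary.", "category": "summarization"}),
]

examples = []
for line in sample_lines:
    features = json.loads(line)
    if features["context"]:  # skip records that carry a context passage
        continue
    # Extra keys such as "category" are ignored by str.format.
    examples.append(template.format(**features))

print(examples[0])
```

Only the first record survives the filter, and it becomes one training string with the instruction and response under their respective headers.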

## Load Model

KerasNLP provides implementations of many popular [model architectures](https://keras.io/api/keras_nlp/models/). In this tutorial, you'll create a model using `GemmaCausalLM`, an end-to-end Gemma model for causal language modeling. A causal language model predicts the next token based on previous tokens.

Create the model using the `from_preset` method:

In [6]:
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
gemma_lm.summary()

Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_2b_en/2' to your Kaggle notebook...
Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_2b_en/2' to your Kaggle notebook...
Attaching 'task.json' from model 'keras/gemma/keras/gemma_2b_en/2' to your Kaggle notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_2b_en/2' to your Kaggle notebook...
Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_2b_en/2' to your Kaggle notebook...
Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_2b_en/2' to your Kaggle notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_2b_en/2' to your Kaggle notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_2b_en/2' to your Kaggle notebook...
Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_2b_en/2' to your Kaggle notebook...
Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_2b_en/2' to your Kaggle notebook...
Attaching 'metadata.json' f

The `from_preset` method instantiates the model from a preset architecture and weights. In the code above, the string `"gemma_2b_en"` specifies the preset: a Gemma model with 2 billion parameters.

NOTE: A Gemma model with 7 billion parameters is also available. To run the larger model in Colab, you need access to the premium GPUs available in paid plans. Alternatively, you can perform [distributed tuning on a Gemma 7B model](https://ai.google.dev/gemma/docs/distributed_tuning) on Kaggle or Google Cloud.

## Inference before fine-tuning

In this section, you will query the model with various prompts to see how it responds.


### Europe Trip Prompt

Query the model for suggestions on what to do on a trip to Europe.

In [7]:
prompt = template.format(
    instruction="What should I do on a trip to Europe?",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))

Instruction:
What should I do on a trip to Europe?

Response:
1. Take a trip to Europe.
2. Take a trip to Europe.
3. Take a trip to Europe.
4. Take a trip to Europe.
5. Take a trip to Europe.
6. Take a trip to Europe.
7. Take a trip to Europe.
8. Take a trip to Europe.
9. Take a trip to Europe.
10. Take a trip to Europe.
11. Take a trip to Europe.
12. Take a trip to Europe.
13. Take a trip to Europe.
14. Take a trip to Europe.
15. Take a trip to Europe.
16. Take a trip to Europe.
17. Take a trip to Europe.
18. Take a trip to Europe.
19. Take a trip to Europe.
20. Take a trip to Europe.
21. Take a trip to Europe.
22. Take a trip to Europe.
23. Take a trip to Europe.
24. Take a trip to Europe.
25. Take a trip to


The model does not answer the question. It repeats a recommendation to take a trip to Europe until it reaches the length limit.
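
As the output above shows, the printed text contains the prompt followed by the model's continuation. The following is a minimal sketch for isolating the continuation, assuming the returned string starts with the prompt verbatim. `extract_response` is a hypothetical helper, and the `generated` string is a hand-written stand-in rather than real model output.

```python
template = "Instruction:\n{instruction}\n\nResponse:\n{response}"


def extract_response(generated, prompt):
    """Return the text after the prompt, or the whole string if it does not match."""
    if generated.startswith(prompt):
        return generated[len(prompt):].strip()
    return generated.strip()


prompt = template.format(
    instruction="What should I do on a trip to Europe?",
    response="",
)
# Stand-in for the string returned by generate(); not real model output.
generated = prompt + "1. Take a trip to Europe.\n2. Take a trip to Europe."
print(extract_response(generated, prompt))
```

Leaving `response` empty means the prompt ends right after the `Response:` header, so the model's continuation is read as the answer.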

### ELI5 Photosynthesis Prompt

Prompt the model to explain photosynthesis in terms simple enough for a 5-year-old child to understand.


In [8]:
prompt = template.format(
    instruction="Explain the process of photosynthesis in a way that a child could understand.",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))

Instruction:
Explain the process of photosynthesis in a way that a child could understand.

Response:
Photosynthesis is the process by which plants use the energy from the sun to convert water and carbon dioxide into oxygen and glucose. The process begins with the absorption of light energy by chlorophyll molecules in the leaves of plants. The energy from the light is used to split water molecules into hydrogen and oxygen. The oxygen is released into the atmosphere, while the hydrogen is used to make glucose. The glucose is then used by the plant to make energy and grow.

Explanation:
Photosynthesis is the process by which plants use the energy from the sun to convert water and carbon dioxide into oxygen and glucose. The process begins with the absorption of light energy by chlorophyll molecules in the leaves of plants. The energy from the light is used to split water molecules into hydrogen and oxygen. The oxygen is released into the atmosphere, while the hydrogen is used to make gluc

The response contains words that a child might not understand, such as chlorophyll and glucose.

## LoRA Fine-tuning

To get better responses from the model, fine-tune it with Low-Rank Adaptation (LoRA) using the Databricks Dolly 15k dataset.

The LoRA rank determines the dimensionality of the trainable matrices that are added to the original weights of the LLM. It controls the expressiveness and precision of the fine-tuning adjustments.

A higher rank means more detailed changes are possible, but also means more trainable parameters. A lower rank means less computational overhead, but potentially less precise adaptation.

This tutorial uses a LoRA rank of 4. In practice, begin with a relatively small rank (such as 4, 8, 16). This is computationally efficient for experimentation. Train your model with this rank and evaluate the performance improvement on your task. Gradually increase the rank in subsequent trials and see if that further boosts performance.

In [9]:
# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
gemma_lm.summary()

Note that enabling LoRA reduces the number of trainable parameters significantly (from 2.5 billion to 1.3 million).
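
To see why the trainable count drops so sharply, consider a single dense weight matrix of shape `d_in × d_out`. LoRA freezes that matrix and trains two small factors of shapes `d_in × r` and `r × d_out`, which adds `r * (d_in + d_out)` trainable parameters. The sketch below is back-of-the-envelope arithmetic only. The 2048 × 2048 shape is an illustrative assumption, not a specific Gemma layer, and the function names are hypothetical.

```python
def lora_param_count(d_in, d_out, rank):
    """Trainable parameters added by LoRA factors A (d_in x r) and B (r x d_out)."""
    return rank * (d_in + d_out)


def full_param_count(d_in, d_out):
    """Parameters in the frozen dense weight matrix."""
    return d_in * d_out


# Illustrative shape only; not taken from the Gemma configuration.
d_in, d_out = 2048, 2048
full = full_param_count(d_in, d_out)
for rank in (4, 8, 16):
    added = lora_param_count(d_in, d_out, rank)
    print(f"rank={rank:2d}: {added:,} trainable vs {full:,} frozen ({added / full:.2%})")
```

The added count grows linearly with the rank, which is why starting small and increasing the rank gradually keeps experiments cheap.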

In [10]:
# Limit the input sequence length to 512 (to control memory usage).
gemma_lm.preprocessor.sequence_length = 512
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(data, epochs=1, batch_size=1)

1000/1000 ━━━━━━━━━━━━━━━━━━━━ 803s 781ms/step - loss: 0.4591 - sparse_categorical_accuracy: 0.5237


<keras.src.callbacks.history.History at 0x78bc406f6830>
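
The loss in the training cell is configured with `from_logits=True`, which means the model outputs raw scores and the softmax is applied inside the loss. The following pure-Python sketch shows the per-token computation for intuition only. It is not how Keras implements the loss, and the function name is hypothetical.

```python
import math


def sparse_crossentropy_from_logits(logits, target_index):
    """Negative log-probability of the target class under softmax(logits)."""
    # Subtract the max for numerical stability before exponentiating.
    peak = max(logits)
    log_sum_exp = peak + math.log(sum(math.exp(x - peak) for x in logits))
    return log_sum_exp - logits[target_index]


# Four equally likely tokens: the loss equals ln(4).
print(sparse_crossentropy_from_logits([0.0, 0.0, 0.0, 0.0], 2))
# A confident, correct prediction gives a loss close to zero.
print(sparse_crossentropy_from_logits([10.0, 0.0, 0.0, 0.0], 0))
```

"Sparse" refers to the target being an integer token index rather than a one-hot vector.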

## Inference after fine-tuning

After fine-tuning, responses follow the instruction provided in the prompt.

### Europe Trip Prompt


In [11]:
prompt = template.format(
    instruction="What should I do on a trip to Europe?",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))

Instruction:
What should I do on a trip to Europe?

Response:
There are so many things to do in Europe!  Here are a few ideas:

* Visit the Eiffel Tower in Paris
* Take a river cruise on the Rhine River
* Visit the Vatican in Rome
* Take a train ride through the Swiss Alps
* Visit the Colosseum in Rome
* Visit the Louvre in Paris
* Visit the Palace of Versailles in Paris
* Visit the Tower of London in London
* Visit the Tower of Pisa in Italy
* Visit the Acropolis in Athens
* Visit the Acropolis in Athens
* Visit the Acropolis in Athens
* Visit the Acropolis in Athens
* Visit the Acropolis in Athens
* Visit the Acropolis in Athens
* Visit the Acropolis in Athens
* Visit the Acropolis in Athens
* Visit the Acropolis in Athens
* Visit the Acropolis in Athens
* Visit the Acropolis in Athens
* Visit the Acropolis in Athens
* Visit the Acropolis in Athens
* Visit the Acropolis in Athens
* Visit the Acropolis in Athens
* Visit the Acropolis in Athens
* Visit the Acropolis in Athens
* Visit t

The model now recommends places to visit in Europe, although the list ends up repeating one entry until it reaches the length limit.

### ELI5 Photosynthesis Prompt


In [12]:
prompt = template.format(
    instruction="Explain the process of photosynthesis in a way that a child could understand.",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))

Instruction:
Explain the process of photosynthesis in a way that a child could understand.

Response:
Photosynthesis is the process by which plants use sunlight to convert carbon dioxide and water into glucose and oxygen. The glucose is used by the plant to grow and the oxygen is released into the atmosphere.
