# Fine-tune Gemma models in Keras using LoRA

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
!nvidia-smi

Wed Feb 19 20:23:21 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-40GB          Off |   00000000:00:04.0 Off |                    0 |
| N/A   36C    P0             53W /  400W |       0MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                

## Prerequisites

Install the W&B Python SDK and log in:

In [3]:
!pip install wandb -qU


In [5]:
# Log in to your W&B account
import wandb
import random
import math

In [6]:
wandb.login()

wandb: Appending key for api.wandb.ai to your netrc file: /root/.netrc
wandb: Currently logged in as: ad2000x (ad2000x-none) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin


True

### Set environment variables

Set environment variables for `KAGGLE_USERNAME` and `KAGGLE_KEY`.

In [7]:
import os
from google.colab import userdata

# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate for your system.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
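Outside Colab, a quick check like the one below can confirm that both variables are set before any download is attempted. This is a minimal sketch: the helper name `missing_env_vars` is hypothetical and not part of the Kaggle or Colab APIs, and the example mapping is made up.

```python
import os

def missing_env_vars(names, environ=None):
    """Return the names that are unset or empty in the given environment mapping."""
    env = os.environ if environ is None else environ
    return [name for name in names if not env.get(name)]

# Use an explicit mapping so the example does not depend on the real environment.
example_env = {"KAGGLE_USERNAME": "someone", "KAGGLE_KEY": ""}
print(missing_env_vars(["KAGGLE_USERNAME", "KAGGLE_KEY"], example_env))
```

Calling `missing_env_vars(["KAGGLE_USERNAME", "KAGGLE_KEY"])` without a mapping checks the real process environment.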

In [8]:
# !kaggle datasets list

### Install dependencies

Install Keras, KerasNLP, and other dependencies.

In [9]:
# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
!pip install -q -U keras-nlp
!pip install -q -U "keras>=3"

### Select a backend

Keras is a high-level, multi-framework deep learning API designed for simplicity and ease of use. Using Keras 3, you can run workflows on one of three backends: TensorFlow, JAX, or PyTorch.

For this tutorial, configure the backend for JAX.

In [10]:
os.environ["KERAS_BACKEND"] = "jax"  # Or "torch" or "tensorflow".
# Avoid memory fragmentation on JAX backend.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"

### Import packages

Import Keras and KerasNLP.

In [11]:
import keras
import keras_nlp
import gc
from keras.callbacks import EarlyStopping

## Load Dataset

In [12]:
!wget -O databricks-dolly-15k.jsonl https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl


Preprocess the data. This tutorial uses a subset of 1000 training examples to execute the notebook faster. Consider using more training data for higher quality fine-tuning.

In [13]:
import json
# Prompt template shared by the training data and the inference prompts below.
template = "Instruction:\n{instruction}\n\nResponse:\n{response}"

data = []
with open("databricks-dolly-15k.jsonl") as file:
    for line in file:
        features = json.loads(line)
        # Filter out examples with context, to keep it simple.
        if features["context"]:
            continue
        # Format the entire example as a single string.
        data.append(template.format(**features))

# Only use 1000 training examples, to keep it fast.
data = data[:1000]
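To make the filtering and formatting step concrete, the sketch below applies the same template to two made-up records. The record contents are illustrative only and are not taken from the dataset.

```python
import json

template = "Instruction:\n{instruction}\n\nResponse:\n{response}"

# Two hypothetical JSONL lines shaped like databricks-dolly-15k records.
sample_lines = [
    json.dumps({"instruction": "Name a primary color.", "context": "",
                "response": "Red.", "category": "open_qa"}),
    json.dumps({"instruction": "Summarize the passage.", "context": "Some passage.",
                "response": "A summary.", "category": "summarization"}),
]

examples = []
for line in sample_lines:
    features = json.loads(line)
    if features["context"]:  # skip records that carry a context passage
        continue
    # str.format ignores unused keys such as "category".
    examples.append(template.format(**features))

print(len(examples))
print(examples[0])
```

Only the first record survives the filter, and it becomes a single string with the instruction and response on separate lines.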

# Integrate with W&B

In [14]:
from wandb.integration.keras import (
    WandbMetricsLogger,
    # WandbModelCheckpoint,
    # WandbEvalCallback
)

# Model structure comparison

## Original model

In [15]:
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_2b_en")
gemma_lm.summary()

## Inference before fine tuning

In this section, you will query the model with various prompts to see how it responds.

### Europe Trip Prompt

Query the model for suggestions on what to do on a trip to Europe.

In [16]:
prompt = template.format(
    instruction="What should I do on a trip to Europe?",
    response="",
)
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=256))

Instruction:
What should I do on a trip to Europe?

Response:
If you have any special needs, you should contact the embassy of the country that you are visiting.
You should contact the embassy of the country that I will be visiting.

What are my responsibilities when I go on a trip?

Response:
If you are going to Europe, you should make sure to bring all of your documents.
If you are going to Europe, make sure that you have all of your documents.

When do you travel abroad?

Response:
The most common reason to travel abroad is to go to school or work.
The most common reason to travel abroad is to work.

How can I get a visa to Europe?

Response:
If you want to go to Europe and you have a valid visa, you can get a visa from your local embassy.
If you want to go to Europe and you do not have a valid visa, you can get a visa from your local embassy.

When should I go to Europe?

Response:
You should go to Europe when the weather is nice.
You should go to Europe when the weather is bad.

H

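Rather than suggesting activities, the base model drifts into a series of loosely related questions and repetitive, sometimes contradictory answers (for example, advising travel both when the weather is nice and when it is bad), and it keeps generating until it reaches `max_length`.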
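The `TopKSampler(k=5, seed=2)` used in the cell above restricts each generation step to the five highest-scoring tokens and samples among them. The sketch below illustrates that idea in plain Python. It is not the KerasNLP implementation, and the logits are made up.

```python
import math
import random

def top_k_sample(logits, k, rng):
    """Sample an index from the k highest logits, using a softmax over those k only."""
    # Indices of the k largest logits.
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    # Numerically stable softmax weights restricted to the top-k entries.
    m = max(logits[i] for i in top)
    weights = [math.exp(logits[i] - m) for i in top]
    return rng.choices(top, weights=weights, k=1)[0]

logits = [2.0, 0.5, -1.0, 3.0, 0.0, 1.5, -2.0]  # made-up scores for a 7-token vocabulary
rng = random.Random(2)  # a fixed seed makes the draws repeatable
samples = [top_k_sample(logits, k=5, rng=rng) for _ in range(20)]
print(samples)
```

Tokens 2 and 6 have the two lowest scores, so they can never be drawn with `k=5`. Reusing the same seed reproduces the same sequence of draws.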

### ELI5 Photosynthesis Prompt

Prompt the model to explain photosynthesis in terms simple enough for a 5-year-old child to understand.

In [17]:
prompt = template.format(
    instruction="Explain the process of photosynthesis in a way that a child could understand.",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))

Instruction:
Explain the process of photosynthesis in a way that a child could understand.

Response:
Plants need water, air, sunlight, and carbon dioxide. The plant uses water, sunlight, and carbon dioxide to make oxygen and glucose. The process is also known as photosynthesis.

Instruction:
What is the process of photosynthesis in a plant's cells? How is this process similar to and different from the process of cellular respiration?

Response:
The process of photosynthesis in a plant's cell is similar to and different from cellular respiration. In photosynthesis, a plant uses carbon dioxide to make glucose and oxygen. In cellular respiration, a plant cell uses oxygen to break down glucose to make energy and carbon dioxide.

Instruction:
Describe how plants make oxygen and glucose during the process of photosynthesis. Explain how the process of photosynthesis is related to cellular respiration.

Response:
Plants make oxygen and glucose during the process of photosynthesis. The process

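The base model's first answer is short but uses terms such as glucose and carbon dioxide that may be hard for a child to follow, and instead of stopping it goes on to invent further instruction/response pairs.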
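The raw output above echoes the prompt and then continues with further `Instruction:`/`Response:` pairs until `max_length` is reached. When only the first answer is of interest, a small post-processing helper can cut the text at the next marker. This is a minimal sketch: the function name `extract_first_response` is hypothetical, and it assumes the markers appear exactly as in the template.

```python
def extract_first_response(generated, response_marker="Response:\n", next_marker="Instruction:"):
    """Return the text after the first response marker, cut at the next instruction marker."""
    _, _, after = generated.partition(response_marker)
    first, _, _ = after.partition(next_marker)
    return first.strip()

raw = (
    "Instruction:\nExplain X.\n\nResponse:\nX is simple.\n\n"
    "Instruction:\nExplain Y.\n\nResponse:\nY is harder."
)
print(extract_first_response(raw))
```

If the response marker is missing, the helper returns an empty string.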

# LoRA rank = 4

## FP32

In [18]:
# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
gemma_lm.summary()
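LoRA keeps the pretrained weights frozen and trains a low-rank update alongside selected weight matrices. For one `d_in x d_out` matrix, the update is the product of a `d_in x r` matrix and an `r x d_out` matrix, so it adds `r * (d_in + d_out)` trainable parameters. The sketch below works through that arithmetic. The layer shape is an illustrative assumption, not read from the model, and the code does not model which layers `enable_lora` actually adapts.

```python
def lora_param_count(d_in, d_out, rank):
    """Trainable parameters LoRA adds to one d_in x d_out weight: A (d_in x r) plus B (r x d_out)."""
    return rank * (d_in + d_out)

d_in, d_out = 2304, 2048  # illustrative layer shape (assumption)
full = d_in * d_out
for rank in (4, 8, 16):
    added = lora_param_count(d_in, d_out, rank)
    print(f"rank={rank:2d}: {added:,} trainable vs {full:,} frozen ({added / full:.2%})")
```

The trainable count grows linearly with the rank, which is why a small rank such as 4 keeps the fine-tuning footprint low.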

In [19]:
import wandb

wandb.init(
    project="LoRA_gemma_kerasNLP_W&B",  # project name
    name="FP32",   # name for this run
    config={
        "epochs": 20,
        "batch_size": 4,
        "learning_rate": 5e-5,
        "weight_decay": 0.01,
        "lora_rank": 4
    }
)

wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


In [20]:
early_stopping_cb = EarlyStopping(
    monitor="loss",      # training loss; fit() below gets no validation data, so val_loss is unavailable
    patience=2,          # stop after 2 consecutive epochs without improvement
    restore_best_weights=True,  # restore the weights from the best epoch
    verbose=1
)

callbacks = [
    WandbMetricsLogger()  # upload only loss, accuracy
]

all_callbacks = [early_stopping_cb] + callbacks
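As a rough illustration of what `patience=2` means, the sketch below replays a list of per-epoch losses and reports where training would stop. It is a simplified model of the patience rule only and is not the Keras callback: it ignores options such as `min_delta` and `baseline`. The helper name `early_stop_epoch` and the loss values are made up.

```python
def early_stop_epoch(losses, patience):
    """Return the 0-based epoch at which training would stop, or None if it never stops."""
    best = float("inf")
    wait = 0
    for epoch, loss in enumerate(losses):
        if loss < best:
            best, wait = loss, 0  # an improvement resets the counter
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None

print(early_stop_epoch([0.90, 0.80, 0.85, 0.82, 0.70], patience=2))
print(early_stop_epoch([0.90, 0.77, 0.73, 0.71, 0.70], patience=2))
```

The first sequence stops at epoch 3, after two epochs in a row fail to beat 0.80. The second sequence improves every epoch, so it never triggers early stopping.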

In [21]:
# Limit the input sequence length to 256 (to control memory usage).
gemma_lm.preprocessor.sequence_length = 256

# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(
    data,
    epochs=20,
    batch_size=4,
    callbacks=all_callbacks
)

wandb.finish()

Epoch 1/20
250/250 ━━━━━━━━━━━━━━━━━━━━ 113s 277ms/step - loss: 0.8967 - sparse_categorical_accuracy: 0.5189
Epoch 2/20
250/250 ━━━━━━━━━━━━━━━━━━━━ 65s 185ms/step - loss: 0.7720 - sparse_categorical_accuracy: 0.5604
Epoch 3/20
250/250 ━━━━━━━━━━━━━━━━━━━━ 45s 181ms/step - loss: 0.7282 - sparse_categorical_accuracy: 0.5653
Epoch 4/20
250/250 ━━━━━━━━━━━━━━━━━━━━ 45s 181ms/step - loss: 0.7061 - sparse_categorical_accuracy: 0.5764
Epoch 5/20
250/250 ━━━━━━━━━━━━━━━━━━━━ 46s 185ms/step - loss: 0.6951 - sparse_categorical_accuracy: 0.5807
Epoch 6/20
250/250 ━━━━━━━━━━━━━━━━━━━━ 45s 180ms/step - loss: 0.6834 - sparse_categorical_accuracy: 0.5864
Epoch 7/20
250/250 ━━━━━━━━━━━━━━━━━━━━ 45s 180ms/step - loss: 0.6703 - sparse_categorical_accuracy: 0.5927
Epoch

Run history:
epoch/epoch: ▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
epoch/learning_rate: ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch/loss: █▇▆▆▆▅▅▅▅▅▄▄▃▃▃▂▂▂▁▁
epoch/sparse_categorical_accuracy: ▁▂▂▂▂▃▃▃▃▄▄▄▅▅▆▆▇▇██

Run summary:
epoch/epoch: 19.0
epoch/learning_rate: 5e-05
epoch/loss: 0.38173
epoch/sparse_categorical_accuracy: 0.75341
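The `250/250` shown in each progress bar follows directly from the dataset size and batch size used in this notebook (1000 examples, `batch_size=4`). A short sketch of the arithmetic, with a hypothetical helper name:

```python
import math

def steps_per_epoch(num_examples, batch_size):
    """Number of batches needed to see every example once (the last batch may be smaller)."""
    return math.ceil(num_examples / batch_size)

steps = steps_per_epoch(1000, 4)
total_steps = steps * 20  # total optimizer steps when all 20 epochs run
print(steps, total_steps)
```

Doubling the batch size would halve the number of steps per epoch, at the cost of more memory per step.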


## Inference after fine tuning

In [22]:
prompt = template.format(
    instruction="What should I do on a trip to Europe?",
    response="",
)
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=256))

Instruction:
What should I do on a trip to Europe?

Response:
When visiting a new place, it's best to create a rough itinerary that consists of the most important sites and things to do. For example, when in Paris, you should create a rough itinerary that includes the Eiffel Tower, Arc de Triomphe, Musee d'Orsay, Musee de Orsay, Musee de l'Anmilinium and Musee National du Chateau de Versailles. It's also recommended to do a lot of exploring on the streets, to find "chambres d'il y a plus" (hidden gems) that are off the beaten path.


In [23]:
prompt = template.format(
    instruction="Explain the process of photosynthesis in a way that a child could understand.",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))

Instruction:
Explain the process of photosynthesis in a way that a child could understand.

Response:
Chlorophyll foundInstruction:
Chlorophyll found in plants uses the sun's energy to make food for the plant.  The plant puts most of its food making energy into the plant's roots and stems.  The sunlight shining through the plant's leaves causes chlorophyll toInstruction:
Sunlight shining through the leaves causes chlorophyll toInstruction:
Causes chlorophyll toInstruction:
Use chlorophyll toInstruction:
Use chlorophyllInstruction:
Absorbs the sun's energy and converts it into chemical energy using carbon dioxide and water.  The plant releases the chemical energy in the form of plant Instruction:
Releases the chemical energy in the form of plant Instruction:
Plants Instruction:
Plants Instruction:
PlantsInstruction:
PlantsInstruction:
PlantsInstruction:
PlantsInstruction:
PlantsInstruction:
PlantsInstruction:


# mixed_bfloat16

In [24]:
keras.mixed_precision.set_global_policy('mixed_bfloat16')
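Under a mixed precision policy, Keras runs computations in a 16-bit type while keeping variables in float32. The two 16-bit formats trade off differently: float16 has more mantissa bits but a narrow range, while bfloat16 keeps the float32 exponent range with fewer mantissa bits. The sketch below illustrates this with the standard library only. The bfloat16 conversion is approximated by truncating the float32 bit pattern, which is a simplification (real hardware typically rounds), and both helper names are hypothetical.

```python
import struct

def to_float16(x):
    """Round x to IEEE half precision; values beyond the float16 range become infinity."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return float("inf") if x > 0 else float("-inf")

def to_bfloat16(x):
    """Approximate bfloat16 by keeping the top 16 bits of the float32 encoding (truncation)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]

for value in (1.001, 100000.0):
    print(value, to_float16(value), to_bfloat16(value))
```

The small value keeps more detail in float16, while the large value overflows float16 but stays representable, to within about 1%, in bfloat16.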

In [25]:
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_2b_en")

In [26]:
wandb.init(
    project="LoRA_gemma_kerasNLP_W&B",  # project name
    name="mixed_bfloat16",   # name for this run
    config={
        "epochs": 20,
        "batch_size": 4,
        "learning_rate": 5e-5,
        "weight_decay": 0.01,
        "lora_rank": 4
    }
)

In [27]:
early_stopping_cb = EarlyStopping(
    monitor="loss",      # training loss; fit() below gets no validation data, so val_loss is unavailable
    patience=2,          # stop after 2 consecutive epochs without improvement
    restore_best_weights=True,  # restore the weights from the best epoch
    verbose=1
)

callbacks = [
    WandbMetricsLogger()  # upload only loss, accuracy
]

all_callbacks = [early_stopping_cb] + callbacks

In [28]:
# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)

In [29]:
# Limit the input sequence length to 256 (to control memory usage).
gemma_lm.preprocessor.sequence_length = 256

# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(
    data,
    epochs=20,
    batch_size=4,
    callbacks=all_callbacks
)

wandb.finish()

Epoch 1/20
250/250 ━━━━━━━━━━━━━━━━━━━━ 100s 234ms/step - loss: 0.8976 - sparse_categorical_accuracy: 0.5189
Epoch 2/20
250/250 ━━━━━━━━━━━━━━━━━━━━ 52s 124ms/step - loss: 0.7754 - sparse_categorical_accuracy: 0.5599
Epoch 3/20
250/250 ━━━━━━━━━━━━━━━━━━━━ 33s 133ms/step - loss: 0.7320 - sparse_categorical_accuracy: 0.5657
Epoch 4/20
250/250 ━━━━━━━━━━━━━━━━━━━━ 31s 125ms/step - loss: 0.7068 - sparse_categorical_accuracy: 0.5768
Epoch 5/20
250/250 ━━━━━━━━━━━━━━━━━━━━ 31s 125ms/step - loss: 0.6956 - sparse_categorical_accuracy: 0.5809
Epoch 6/20
250/250 ━━━━━━━━━━━━━━━━━━━━ 33s 133ms/step - loss: 0.6839 - sparse_categorical_accuracy: 0.5868
Epoch 7/20
250/250 ━━━━━━━━━━━━━━━━━━━━ 31s 125ms/step - loss: 0.6710 - sparse_categorical_accuracy: 0.5927
Epoch

Run history:
epoch/epoch: ▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
epoch/learning_rate: ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch/loss: █▇▆▆▆▅▅▅▅▄▄▄▃▃▃▂▂▂▁▁
epoch/sparse_categorical_accuracy: ▁▂▂▂▂▃▃▃▃▄▄▄▅▅▆▆▇▇██

Run summary:
epoch/epoch: 19.0
epoch/learning_rate: 5e-05
epoch/loss: 0.38286
epoch/sparse_categorical_accuracy: 0.75184


## Inference after fine tuning

In [30]:
prompt = template.format(
    instruction="What should I do on a trip to Europe?",
    response="",
)
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=256))

Instruction:
What should I do on a trip to Europe?

Response:
Europe is one of the most beautiful continents in the world. There are a ton of cities to visit and a lot of different countries to explore. Here's some ideas of what you should do on a trip to Europe.

First, you should decide where you want to go and how long you want to stay. There are 27 countries in Europe so you might want to pick one of the most popular destinations like Paris, London, Berlin, Rome, Barcelona, or Prague. Once you decide on the country, you can look up the currency and exchange rate to that country, and how much money to money exchanges usually charge.

The next thing to do is make a reservation for your accommodations. Most people recommend booking a Airbnb or a hotel, as they are usually a little cheaper than booking sites like Booking.com. When booking an accommodation, you should pick one that's close to a metro or train station so you can easily get around the city. You should also pick something 

In [31]:
prompt = template.format(
    instruction="Explain the process of photosynthesis in a way that a child could understand.",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))

Instruction:
Explain the process of photosynthesis in a way that a child could understand.

Response:
The process of photosynthesis is how plants and some other organisms turn the energy of sunlight into sugar that they can use as food.  Here's how it works:
1) Sunlight enters a plant through tiny holes called stomata, which are located on the underside of leaves.
2) Sunlight is absorbed by specialized cells called chloroplasts.  These are found in the interior of plant cells.
3)  Within the chloroplasts, sunlight is converted into chemical energy in the form of ATP (adenosine triphosphate).  At the same time, chlorophyll molecules capture light in photosynthesis and water molecules (H2O) are oxidized ( "oxygen is removed").  This water is then split into oxygen and hydrogen (H2).  Hydrogen is released as oxygen as the plant photosynthesizes.
4) Carbon dioxide (CO2) from the air is converted into sugar (glucose) with the use of ATP to form an unstable sugar called RuBP (ribose 1,5-bisp

# mixed_float16

In [32]:
keras.mixed_precision.set_global_policy('mixed_float16')

In [33]:
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_2b_en")

In [34]:
wandb.init(
    project="LoRA_gemma_kerasNLP_W&B",  # project name
    name="mixed_float16",   # name for this run
    config={
        "epochs": 20,
        "batch_size": 4,
        "learning_rate": 5e-5,
        "weight_decay": 0.01,
        "lora_rank": 4
    }
)

In [35]:
early_stopping_cb = EarlyStopping(
    monitor="loss",      # training loss; fit() below gets no validation data, so val_loss is unavailable
    patience=2,          # stop after 2 consecutive epochs without improvement
    restore_best_weights=True,  # restore the weights from the best epoch
    verbose=1
)

callbacks = [
    WandbMetricsLogger()  # upload only loss, accuracy
]

all_callbacks = [early_stopping_cb] + callbacks

In [36]:
# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)

In [37]:
# Limit the input sequence length to 256 (to control memory usage).
gemma_lm.preprocessor.sequence_length = 256

# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(
    data,
    epochs=20,
    batch_size=4,
    callbacks=all_callbacks
)

wandb.finish()

Epoch 1/20
250/250 ━━━━━━━━━━━━━━━━━━━━ 102s 243ms/step - loss: 0.8958 - sparse_categorical_accuracy: 0.5419
Epoch 2/20
250/250 ━━━━━━━━━━━━━━━━━━━━ 54s 130ms/step - loss: 0.7728 - sparse_categorical_accuracy: 0.5809
Epoch 3/20
250/250 ━━━━━━━━━━━━━━━━━━━━ 33s 131ms/step - loss: 0.7254 - sparse_categorical_accuracy: 0.5871
Epoch 4/20
250/250 ━━━━━━━━━━━━━━━━━━━━ 35s 139ms/step - loss: 0.7062 - sparse_categorical_accuracy: 0.5953
Epoch 5/20
250/250 ━━━━━━━━━━━━━━━━━━━━ 33s 131ms/step - loss: 0.6948 - sparse_categorical_accuracy: 0.6005
Epoch 6/20
250/250 ━━━━━━━━━━━━━━━━━━━━ 33s 131ms/step - loss: 0.6830 - sparse_categorical_accuracy: 0.6060
Epoch 7/20
250/250 ━━━━━━━━━━━━━━━━━━━━ 35s 139ms/step - loss: 0.6698 - sparse_categorical_accuracy: 0.6117
Epoch

Run history:
epoch/epoch: ▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
epoch/learning_rate: ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch/loss: █▇▆▆▆▅▅▅▅▄▄▄▃▃▃▂▂▂▁▁
epoch/sparse_categorical_accuracy: ▁▂▂▂▂▃▃▃▃▄▄▄▅▅▆▆▇▇██

Run summary:
epoch/epoch: 19.0
epoch/learning_rate: 5e-05
epoch/loss: 0.38352
epoch/sparse_categorical_accuracy: 0.76158
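The three training logs report similar final losses but different step times. The sketch below turns representative steady-state step times from those logs into rough speedup factors relative to FP32. The values are approximate readings from the printed epochs (individual epochs vary by several milliseconds), so treat the result as indicative only.

```python
# Approximate steady-state step times (ms/step) read from the training logs above.
step_ms = {"FP32": 181, "mixed_bfloat16": 125, "mixed_float16": 131}

baseline = step_ms["FP32"]
speedup = {name: round(baseline / ms, 2) for name, ms in step_ms.items()}
for name, factor in speedup.items():
    print(f"{name}: {factor}x")
```

On this A100 run, both mixed precision policies are roughly 1.4x faster per step than FP32.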


## Inference after fine tuning

In [38]:
prompt = template.format(
    instruction="What should I do on a trip to Europe?",
    response="",
)
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=256))

Instruction:
What should I do on a trip to Europe?

Response:
This is a great trip to Europe. You can visit many cities in the different countries. You can visit the museums, parks and architecture. You can also visit some historical places. You can also visit some famous landmarks like the Eiffel Tower in Paris, France and the Accademy in Rome, Italy. You can also take day trips to the surrounding areas.


In [39]:
prompt = template.format(
    instruction="Explain the process of photosynthesis in a way that a child could understand.",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))

Instruction:
Explain the process of photosynthesis in a way that a child could understand.

Response:
Photosynthesis is the process by which plants, some bacteria and some other organisms use the sun's energy to convert carbon dioxide and water into glucose and oxygen.  This process is also known as making "food" from the sun.  The process occurs in a part of the cell called the chloroplast.  Inside the chloroplast are tiny structures known as * * * (* *).  * * * are responsible for capturing the sun's energy and * * * it.  This captured * * * is then * * * by * * * to produce sugar and oxygen.  The oxygen produced during this process is released as * * * into the atmosphere.


In [41]:
!pip freeze

absl-py==1.4.0
accelerate==1.3.0
aiohappyeyeballs==2.4.6
aiohttp==3.11.12
aiosignal==1.3.2
alabaster==1.0.0
albucore==0.0.23
albumentations==2.0.4
ale-py==0.10.1
altair==5.5.0
annotated-types==0.7.0
anyio==3.7.1
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
array_record==0.6.0
arviz==0.20.0
astropy==7.0.1
astropy-iers-data==0.2025.2.10.0.33.26
astunparse==1.6.3
atpublic==4.1.0
attrs==25.1.0
audioread==3.0.1
autograd==1.7.0
babel==2.17.0
backcall==0.2.0
beautifulsoup4==4.13.3
betterproto==2.0.0b6
bigframes==1.36.0
bigquery-magics==0.5.0
bleach==6.2.0
blinker==1.9.0
blis==0.7.11
blosc2==3.1.0
bokeh==3.6.3
Bottleneck==1.4.2
bqplot==0.12.44
branca==0.8.1
CacheControl==0.14.2
cachetools==5.5.1
catalogue==2.0.10
certifi==2025.1.31
cffi==1.17.1
chardet==5.2.0
charset-normalizer==3.4.1
chex==0.1.88
clarabel==0.10.0
click==8.1.8
cloudpathlib==0.20.0
cloudpickle==3.1.1
cmake==3.31.4
cmdstanpy==1.2.5
colorcet==3.1.0
colorlover==0.3.0
colour==0.1.5
community==1.0.0b1
confection==0.1.5
cons==0.4

# Evaluation

In [5]:
!pip install textstat

Collecting textstat
  Downloading textstat-0.7.5-py3-none-any.whl.metadata (15 kB)
Collecting pyphen (from textstat)
  Downloading pyphen-0.17.2-py3-none-any.whl.metadata (3.2 kB)
Collecting cmudict (from textstat)
  Downloading cmudict-1.0.32-py3-none-any.whl.metadata (3.6 kB)
Downloading textstat-0.7.5-py3-none-any.whl (105 kB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 105.3/105.3 kB 2.7 MB/s eta 0:00:00
Downloading cmudict-1.0.32-py3-none-any.whl (939 kB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 939.4/939.4 kB 18.5 MB/s eta 0:00:00
Downloading pyphen-0.17.2-py3-none-any.whl (2.1 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.1/2.1 MB 57.6 MB/s eta 0:00:00
Installing collected packages: pyphen, cmudict, textstat
Successfully installed cmudict-1.0.32 pyphen-0.17.2 textstat-0.7.5


In [13]:
import torch
from transformers import BertTokenizer, BertForMaskedLM
import textstat
import nltk

# may need for further text processing
nltk.download("punkt")

class TextEvaluator:
    def __init__(self):
        """
        Initialize the model
        """
        self.bert_model = BertForMaskedLM.from_pretrained("bert-base-uncased")
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

    def calculate_perplexity(self, text):
        """
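        Compute a pseudo-perplexity with BertForMaskedLM by masking each token in turn.
        BERT is a masked (not autoregressive) language model, so this is not a
        conventional left-to-right perplexity.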
        """
        # Tokenize text
        encodings = self.tokenizer(text, return_tensors="pt")
        input_ids = encodings["input_ids"].squeeze()

        # Initialize variables for perplexity calculation
        total_loss = 0
        total_tokens = len(input_ids)

        # Calculate loss for each position
        for i in range(1, total_tokens-1):
            # Create masked input
            masked_input_ids = input_ids.clone()
            original_token = masked_input_ids[i].item()
            masked_input_ids[i] = self.tokenizer.mask_token_id

            # Get model predictions
            with torch.no_grad():
                outputs = self.bert_model(masked_input_ids.unsqueeze(0))
                logits = outputs.logits

            # Calculate loss for this position
            target = torch.tensor([original_token])
            predicted_logits = logits[0, i]
            loss = torch.nn.functional.cross_entropy(predicted_logits.unsqueeze(0), target)
            total_loss += loss.item()

        # Calculate perplexity
        avg_loss = total_loss / (total_tokens - 2)
        perplexity = torch.exp(torch.tensor(avg_loss)).item()
        return perplexity

    def calculate_readability(self, text):
        """
        Compute the Flesch Reading Ease score (higher means easier to read).
        """
        return textstat.flesch_reading_ease(text)

    def evaluate_text(self, generated_text):
        """
        Call this function to evaluate a generated text
        """
        perplexity = self.calculate_perplexity(generated_text)
        readability = self.calculate_readability(generated_text)

        return {
            "Perplexity (Lower is Better)": round(perplexity, 4),
            "Readability (Flesch Score, Higher is Easier)": round(readability, 2)
        }

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
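The `calculate_perplexity` method above masks one token at a time, averages the cross-entropy of the true token at each masked position, and exponentiates the result. This quantity is often called pseudo-perplexity. The sketch below shows the same arithmetic on made-up probabilities, without any model. The function name and the numbers are illustrative.

```python
import math

def pseudo_perplexity(token_probs):
    """exp of the mean negative log-probability assigned to each masked token."""
    nll = [-math.log(p) for p in token_probs]
    return math.exp(sum(nll) / len(nll))

print(pseudo_perplexity([0.25, 0.25, 0.25]))  # model is as unsure as a 4-way guess
print(pseudo_perplexity([0.9, 0.8, 0.95]))    # confident predictions give a value near 1
```

The score is bounded below by 1, which is reached only when every masked token is predicted with probability 1. Repetitive text can therefore score well even when it is not useful.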


In [20]:
# new management class
class InferenceEvaluator:
    def __init__(self):
        self.evaluator = TextEvaluator()

    def evaluate_text(self, text):
        evaluation_results = self.evaluator.evaluate_text(text)
        print("Evaluation Results:")
        for metric, value in evaluation_results.items():
            print(f"{metric}: {value}")

# evaluate instance
inference_evaluator = InferenceEvaluator()

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


### Original

In [21]:
text1 = """If you have any special needs, you should contact the embassy of the
            country that you are visiting.
            You should contact the embassy of the country that I will be visiting."""
inference_evaluator.evaluate_text(text1)

text2 = """Plants need water, air, sunlight, and carbon dioxide. The plant uses
            water, sunlight, and carbon dioxide to make oxygen and glucose.
            The process is also known as photosynthesis."""
inference_evaluator.evaluate_text(text2)

Evaluation Results:
Perplexity (Lower is Better): 1.3992
Readability (Flesch Score, Higher is Easier): 64.2
Evaluation Results:
Perplexity (Lower is Better): 2.0194
Readability (Flesch Score, Higher is Easier): 53.58
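For reference, the readability number reported above is a Flesch Reading Ease score, which combines average sentence length and average syllables per word. The sketch below applies the standard formula to raw counts. The counts are made up, and `textstat` uses its own heuristics to count sentences and syllables, so its results for real text may differ from a hand count.

```python
def flesch_reading_ease(total_words, total_sentences, total_syllables):
    """Standard Flesch Reading Ease formula computed from raw counts."""
    words_per_sentence = total_words / total_sentences
    syllables_per_word = total_syllables / total_words
    return 206.835 - 1.015 * words_per_sentence - 84.6 * syllables_per_word

# Hypothetical counts: 100 words, 5 sentences, 150 syllables.
print(round(flesch_reading_ease(100, 5, 150), 2))
```

Longer sentences and longer words both lower the score, which is why the degenerate FP32 photosynthesis output scores poorly despite its simple vocabulary.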


### FP32

In [22]:
text1 = """When visiting a new place, it's best to create a rough itinerary that
           consists of the most important sites and things to do.
           For example, when in Paris, you should create a rough itinerary that
           includes the Eiffel Tower, Arc de Triomphe, Musee d'Orsay,
           Musee de Orsay, Musee de l'Anmilinium and Musee National du Chateau de Versailles.
           It's also recommended to do a lot of exploring on the streets,
           to find "chambres d'il y a plus" (hidden gems) that are off the beaten path."""
inference_evaluator.evaluate_text(text1)

text2 = """Chlorophyll foundInstruction:
            Chlorophyll found in plants uses the sun's energy to make food for the plant.
            The plant puts most of its food making energy into the plant's roots and stems.
            The sunlight shining through the plant's leaves causes chlorophyll toInstruction:
            Sunlight shining through the leaves causes chlorophyll toInstruction:
            Causes chlorophyll toInstruction:
            Use chlorophyll toInstruction:
            Use chlorophyllInstruction:
            Absorbs the sun's energy and converts it into chemical energy using carbon
            dioxide and water.  The plant releases the chemical energy in the form of plant Instruction:
            Releases the chemical energy in the form of plant Instruction:
            Plants Instruction:
            Plants Instruction:
            PlantsInstruction:
            PlantsInstruction:
            PlantsInstruction:
            PlantsInstruction:
            PlantsInstruction:
            PlantsInstruction:"""
inference_evaluator.evaluate_text(text2)

Evaluation Results:
Perplexity (Lower is Better): 2.9536
Readability (Flesch Score, Higher is Easier): 51.52
Evaluation Results:
Perplexity (Lower is Better): 1.6566
Readability (Flesch Score, Higher is Easier): 28.17


### FP16

In [25]:
text1 = """This is a great trip to Europe. You can visit many cities in the
            different countries. You can visit the museums, parks and architecture. You can
            also visit some historical places. You can also visit some famous landmarks like
            the Eiffel Tower in Paris, France and the Accademy in Rome, Italy. You can also
            take day trips to the surrounding areas."""
inference_evaluator.evaluate_text(text1)

text2 = """Photosynthesis is the process by which plants, some bacteria and some
            other organisms use the sun's energy to convert carbon dioxide and water into
            glucose and oxygen.  This process is also known as making "food" from the sun.
            The process occurs in a part of the cell called the chloroplast.  Inside the
            chloroplast are tiny structures known as * * * (* *).  * * * are responsible for
            capturing the sun's energy and * * * it.  This captured * * * is then * * * by * * *
            to produce sugar and oxygen.  The oxygen produced during this process is released
            as * * * into the atmosphere."""
inference_evaluator.evaluate_text(text2)

Evaluation Results:
Perplexity (Lower is Better): 2.3916
Readability (Flesch Score, Higher is Easier): 61.12
Evaluation Results:
Perplexity (Lower is Better): 3.1765
Readability (Flesch Score, Higher is Easier): 58.58


### BF16

In [24]:
text1 = """Europe is one of the most beautiful continents in the world.
            There are a ton of cities to visit and a lot of different countries to explore.
            Here's some ideas of what you should do on a trip to Europe.

            First, you should decide where you want to go and how long you want to stay.
            There are 27 countries in Europe so you might want to pick one of the most popular
            destinations like Paris, London, Berlin, Rome, Barcelona, or Prague. Once you
            decide on the country, you can look up the currency and exchange rate to that
            country, and how much money to money exchanges usually charge.

            The next thing to do is make a reservation for your accommodations.
            Most people recommend booking a Airbnb or a hotel, as they are usually a little
            cheaper than booking sites like Booking.com. When booking an accommodation, you
            should pick one that's close to a metro or train station so you can easily get
            around the city. You should also pick something that has good reviews, to ensure
            you won't be disappointed.

            The next thing to do is make a list of restaurants you want to go to and make a
            reservation for them."""
inference_evaluator.evaluate_text(text1)

text2 = """The process of photosynthesis is how plants and some other organisms
            turn the energy of sunlight into sugar that they can use as food.  Here's how it works:
            1) Sunlight enters a plant through tiny holes called stomata, which are located
            on the underside of leaves.
            2) Sunlight is absorbed by specialized cells called chloroplasts.  These are
            found in the interior of plant cells.
            3)  Within the chloroplasts, sunlight is converted into chemical energy in the
            form of ATP (adenosine triphosphate).  At the same time, chlorophyll molecules
            capture light in photosynthesis and water molecules (H2O) are oxidized ( "oxygen
            is removed").  This water is then split into oxygen and hydrogen (H2).  Hydrogen
            is released as oxygen as the plant photosynthesizes.
            4) Carbon dioxide (CO2) from the air is converted into sugar (glucose) with the
            use of ATP to form an unstable sugar called RuBP (ribose 1,5-bisphosphate).
            5) As the process of photosynthesis continues, the unstable RuBP is "fixed" with
            electrons from the hydrogen (H2) molecules from"""
inference_evaluator.evaluate_text(text2)

Evaluation Results:
Perplexity (Lower is Better): 2.8741
Readability (Flesch Score, Higher is Easier): 69.62
Evaluation Results:
Perplexity (Lower is Better): 2.3551
Readability (Flesch Score, Higher is Easier): 46.27


Parts of this code are adapted from material Copyright 2024 Google LLC.

Licensed under the Apache License, Version 2.0 (the "License").

https://colab.research.google.com/github/google/generative-ai-docs/blob/main/site/en/gemma/docs/lora_tuning.ipynb#scrollTo=tuOe1ymfHZPu