# Text generation with the Gemma model

In [1]:
import os

# Free up more GPU memory on the Jax and TensorFlow backends.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00"

In [2]:
!pip install keras-nightly keras-hub-nightly --upgrade -q

In [3]:
import kagglehub

kagglehub.login()

VBox(children=(HTML(value='<center> <img\nsrc=https://www.kaggle.com/static/images/site-logo.png\nalt=\'Kaggle…

Kaggle credentials set.
Kaggle credentials successfully validated.


In [4]:
import keras_hub
import keras

In [16]:
import tensorflow as tf


In [5]:
gemma_lm = keras_hub.models.CausalLM.from_preset(
    "gemma3_1b",
    dtype="float32",
)

Downloading from https://www.kaggle.com/api/v1/models/keras/gemma3/keras/gemma3_1b/3/download/config.json...


100%|██████████| 966/966 [00:00<00:00, 1.94MB/s]


Downloading from https://www.kaggle.com/api/v1/models/keras/gemma3/keras/gemma3_1b/3/download/task.json...


100%|██████████| 3.23k/3.23k [00:00<00:00, 7.41MB/s]


Downloading from https://www.kaggle.com/api/v1/models/keras/gemma3/keras/gemma3_1b/3/download/assets/tokenizer/vocabulary.spm...


100%|██████████| 4.47M/4.47M [00:00<00:00, 17.8MB/s]


Downloading from https://www.kaggle.com/api/v1/models/keras/gemma3/keras/gemma3_1b/3/download/model.weights.h5...


100%|██████████| 1.86G/1.86G [00:26<00:00, 76.1MB/s]


In [6]:
gemma_lm.summary(line_length=80)

In [7]:
gemma_lm.compile(sampler="greedy")

In [8]:
gemma_lm.generate("A piece of advice", max_length=40)

'A piece of advice from a former student of mine:\n\n<blockquote>“I’m not sure if you’ve heard of it, but I’ve been told that the best way to learn'

In [9]:
gemma_lm.generate("How can I make brownies?", max_length=40)

"How can I make brownies?\n\n[User 0001]\n\nI'm trying to make brownies for my son's birthday party. I've never made brownies before."

In [10]:
gemma_lm.generate(
    "The following brownie recipe is easy to make in just a few "
    "steps.\n\nYou can start by",
    max_length=40,
)

'The following brownie recipe is easy to make in just a few steps.\n\nYou can start by melting the butter and sugar in a saucepan over medium heat.\n\nThen add the eggs and vanilla extract'

In [11]:
gemma_lm.generate(
    "Tell me about the 542nd president of the United States.",
    max_length=40,
)

'Tell me about the 542nd president of the United States.\n\nThe 542nd president of the United States was James A. Garfield.\n\nThe 542'

In [12]:
import json

PROMPT_TEMPLATE = """"[instruction]\n{}[end]\n[response]\n"""
RESPONSE_TEMPLATE = """{}[end]"""

dataset_path = keras.utils.get_file(
    origin=(
        "https://hf.co/datasets/databricks/databricks-dolly-15k/"
        "resolve/main/databricks-dolly-15k.jsonl"
    ),
)
data = {"prompts": [], "responses": []}
with open(dataset_path) as file:
    for line in file:
        features = json.loads(line)
        if features["context"]:
            continue
        data["prompts"].append(PROMPT_TEMPLATE.format(features["instruction"]))
        data["responses"].append(RESPONSE_TEMPLATE.format(features["response"]))

Downloading data from https://hf.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl
[1m13085339/13085339[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step


In [13]:
data["prompts"][0]

'"[instruction]\nWhich is a species of fish? Tope or Rope[end]\n[response]\n'

In [14]:
data["responses"][0]

'Tope[end]'

In [17]:
ds = tf.data.Dataset.from_tensor_slices(data).shuffle(2000).batch(2)
val_ds = ds.take(100)
train_ds = ds.skip(100)

In [18]:
preprocessor = gemma_lm.preprocessor

In [19]:
preprocessor.sequence_length = 512

In [21]:
batch = next(iter(train_ds))

In [22]:
x, y, sample_weight = preprocessor(batch)

In [23]:
x["token_ids"].shape

TensorShape([2, 512])

In [24]:
x["padding_mask"].shape

TensorShape([2, 512])

In [25]:
y.shape

TensorShape([2, 512])

In [26]:
sample_weight.shape

TensorShape([2, 512])

In [27]:
x["token_ids"][0, :5], y[0, :5]

(<tf.Tensor: shape=(5,), dtype=int32, numpy=array([     2,  77074,  22768, 236842,    107], dtype=int32)>,
 <tf.Tensor: shape=(5,), dtype=int32, numpy=array([ 77074,  22768, 236842,    107,  24249], dtype=int32)>)

# LORA

In [30]:
from keras import ops

class Linear(keras.Layer):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.kernel = self.add_weight(shape=(input_dim, output_dim))

    def call(self, inputs):
        return ops.matmul(inputs, self.kernel)

In [31]:
class LoraLinear(keras.Layer):
    def __init__(self, input_dim, output_dim, rank):
        super().__init__()
        self.kernel = self.add_weight(
            shape=(input_dim, output_dim), trainable=False
        )
        self.alpha = self.add_weight(shape=(input_dim, rank))
        self.beta = self.add_weight(shape=(rank, output_dim))

    def call(self, inputs):
        frozen = ops.matmul(inputs, self.kernel)
        update = ops.matmul(ops.matmul(inputs, self.alpha), self.beta)
        return frozen + update

In [32]:
gemma_lm.backbone.enable_lora(rank=8)

In [35]:
gemma_lm.backbone.trainable = False
for i in range(gemma_lm.backbone.num_layers):
    layer = gemma_lm.backbone.get_layer(f"decoder_block_{i}")
    layer.attention.key_dense.trainable = True
    layer.attention.key_dense.enable_lora(rank=8)
    layer.attention.query_dense.trainable = True
    layer.attention.query_dense.enable_lora(rank=8)

In [37]:
gemma_lm.summary(line_length=80)

In [38]:
gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=keras.optimizers.Adam(5e-5),
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(train_ds, validation_data=val_ds, epochs=1)

[1m5172/5172[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m478s[0m 75ms/step - loss: 0.3251 - sparse_categorical_accuracy: 0.5437 - val_loss: 0.3095 - val_sparse_categorical_accuracy: 0.5585


<keras.src.callbacks.history.History at 0x7ceecc741d50>