In [1]:
import os
from dotenv import load_dotenv
load_dotenv()

os.environ["KAGGLE_USERNAME"] = os.getenv("KAGGLE_USERNAME")
os.environ["KAGGLE_KEY"] = os.getenv("KAGGLE_KEY")

In [2]:
# select a backend
os.environ["KERAS_BACKEND"] = "jax" # or torch or tensorflow
# avoid memory fragmentation on JAX backend
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00"

In [6]:
# import packages
import keras
import keras_nlp

  from .autonotebook import tqdm as notebook_tqdm


In [7]:
!wget -O databricks-dolly-15k.jsonl "https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl"

zsh:1: command not found: wget


In [None]:
import json
data = []
with open("databricks-dolly-15k.jsonl") as file:
    for line in file:
        features = json.loads(line)
        # filter out examples with context, to keep it simple.
        if features["context"]:
            continue
        # format the entire example a single string.
        template = "Instruction:\n{instruction}\n\nResponse:\n{response}"
        data.append(template.format(**features))
# only use 1000 training examples, to keep it fast.
data = data[:1000]

In [9]:
# load model
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
gemma_lm.summary()

Downloading from https://www.kaggle.com/api/v1/models/keras/gemma/keras/gemma_2b_en/2/download/config.json...


100%|██████████| 555/555 [00:00<00:00, 418kB/s]


In [None]:
prompt = template.format(
    instruction="What should I do on a trip to Europe?",
    response=""
)
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=256))

In [None]:
prompt = template.format(
    instruction="Explain the process of photosynthesis in a way that a child could understand.",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))

In [None]:
# LoRA Fine-tuning
# enable LoRA for the model and set the rank to 4
gemma_lm.backbone.enable_lora(rank=4)
gemma_lm.summary()

In [None]:
# limit the input sequence length to 512 (to control memory usage).
gemma_lm.preprocessor.sequence_length = 512
# use adamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01
)
# exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics = [keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(data, epochs=1, batch_size=1)

In [None]:
# inference after fine tuning
prompt = template.format(
    instruction="What should I do on a trip to Europe?",
    response="",
)
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=256))

In [None]:
prompt = template.format(
    instruction="Explain the process of photosynthesis in a way that a child could understand.",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))