## Setup

### Configure your Kaggle credentials

In [None]:
import os
from dotenv import load_dotenv

# Copy KEY=VALUE pairs from a local `.env` file into `os.environ`
# (variables already present in the environment are not overridden by default).
load_dotenv()

# The original `os.get(...)` does not exist and raised AttributeError; a missing
# variable would also have raised a confusing TypeError when assigned into
# os.environ. Validate the Kaggle credentials up front instead. `load_dotenv`
# has already placed them in os.environ, so no re-assignment is needed.
for _kaggle_var in ("KAGGLE_USERNAME", "KAGGLE_KEY"):
    if not os.getenv(_kaggle_var):
        raise RuntimeError(
            f"{_kaggle_var} is not set; add it to your environment or to a .env file."
        )

### Install dependencies

In [None]:
# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
!pip install -q -U keras-nlp
!pip install -q -U "keras>=3"
!pip install jax jaxlib
!pip install --upgrade tensorflow keras ml-dtypes

### Select a backend

Keras is a high-level, multi-framework deep learning API designed for simplicity and ease of use. Using Keras 3, you can run workflows on one of three backends: TensorFlow, JAX, or PyTorch.

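Let's configure the backend for JAX. The `KERAS_BACKEND` environment variable must be set before Keras is imported, so run this cell before the imports below.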

In [None]:
# Select the Keras backend; this must happen before `import keras`.
# Use "torch" or "tensorflow" instead of "jax" to switch backends.
os.environ.update(
    {
        "KERAS_BACKEND": "jax",
        # Fraction of device memory the JAX/XLA client may claim up front;
        # the original setup uses this to avoid memory fragmentation on JAX.
        "XLA_PYTHON_CLIENT_MEM_FRACTION": "1.00",
    }
)

### Import packages

Import Keras and KerasNLP.

In [None]:
import keras
import keras_nlp

## Load Dataset

In [None]:
from datasets import load_dataset

# Prompt layout shared by the training examples and the inference prompts below.
# Defined once at top level (not inside the loop) so later cells that use
# `template` never depend on the loop body having executed.
template = "Instruction:\n{instruction}\n\nResponse:\n{response}"

# Load the Sinhala-QA dataset.
ds = load_dataset("Suchinthana/Sinhala-QA-Translate")

# Render each training row into one instruction/response string for fine-tuning.
data = [
    template.format(
        instruction=example["Question"],
        response=example["TranslatedAnswer"],
    )
    for example in ds["train"]
]
assert data, "The training split is empty; nothing to fine-tune on."

## Load Model

In [None]:
# KerasNLP preset to fine-tune.
GEMMA_PRESET = "gemma2_2b_en"
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset(GEMMA_PRESET)
# Print the layer and parameter breakdown before LoRA is enabled.
gemma_lm.summary()

## Inference before fine-tuning

In [None]:
# Baseline answer from the un-tuned model. The Sinhala instruction roughly asks
# who painted the Mona Lisa. The response slot is left empty for the model to fill.
prompt = template.format(instruction="මෝනාලීසා චිත්‍රය අදින ලද්දේ", response="")

# Top-k sampling (k=5) with a fixed seed, intended to make samples repeatable.
sampler = keras_nlp.samplers.TopKSampler(seed=2, k=5)
gemma_lm.compile(sampler=sampler)

baseline_output = gemma_lm.generate(prompt, max_length=256)
print(baseline_output)

## LoRA Fine-tuning

In [None]:
# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
gemma_lm.summary()

In [None]:
# Limit the input sequence length to 256 (to control memory usage).
gemma_lm.preprocessor.sequence_length = 256
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=1e-4,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(data, epochs=20, batch_size=1)

## Inference after fine-tuning

In [None]:
# Same prompt and sampler settings as the pre-fine-tuning check, so the two
# outputs can be compared directly.
prompt = template.format(instruction="මෝනාලීසා චිත්‍රය අදින ලද්දේ", response="")

# Re-compiling with only a sampler switches the model to generation settings.
sampler = keras_nlp.samplers.TopKSampler(seed=2, k=5)
gemma_lm.compile(sampler=sampler)

tuned_output = gemma_lm.generate(prompt, max_length=256)
print(tuned_output)