# Instruct Fine-Tuning Gemma for Japanese Language
This notebook demonstrates the fine-tuning of the Gemma model on Japanese datasets. We will explore the workflow from data loading and preprocessing to model fine-tuning and evaluation.

**Key Steps:**
1. Set up environment variables for Kaggle and Weights & Biases (wandb).
2. Load and preprocess the Japanese dataset.
3. Set up model parallelism for TPU utilization.
4. Fine-tune the Gemma model using LoRA (Low-Rank Adaptation).
5. Evaluate model performance before and after fine-tuning.


##### You can view the fine-tuning logs here: [link](https://wandb.ai/this-is-the-way-2005-independent/fine-tuning-gemma2_instruct_9b_ja)

#### Device:
We used a TPU VM v3-8 from Kaggle.
#### Base model:
We started from a version of gemma2_9b_en that had already been fine-tuned on Japanese data. Model: [link](https://www.kaggle.com/models/mahdiseddigh/gemma2/keras/gemma2_9b_ja)

### My Gemma2 cookbook:
I created this repo, where I upload all my notebooks related to working with Gemma models. Check it out:
https://github.com/Mhdaw/Gemma2

## Step 0: Installing the Required Libraries and Frameworks
To ensure that all necessary libraries and frameworks are installed, run the following commands:

In [None]:
!pip install -q -U keras-nlp keras datasets kagglehub keras_hub 
!pip install -q -U tensorflow-text
# Install tensorflow-cpu so tensorflow does not attempt to access the TPU.
!pip install -q -U tensorflow-cpu
!pip install -q -U wandb

In [None]:
import jax
jax.devices()

## Step 1: Set Up Environment Variables
We will configure the environment variables required for:
- Kaggle API access
- Weights & Biases for tracking experiments
- Keras backend selection (JAX) and XLA memory allocation.


In [None]:
import os
# Set the environment variables for Kaggle and Weights & Biases.
# Replace the placeholder strings below, or read each value from a secret store:
#   Kaggle:  from kaggle_secrets import UserSecretsClient; UserSecretsClient().get_secret("KAGGLE_USERNAME")
#   Colab:   from google.colab import userdata; userdata.get("KAGGLE_USERNAME")
#   Jupyter: import getpass; getpass.getpass("Enter your KAGGLE_USERNAME: ")
os.environ["KAGGLE_USERNAME"] = "your-username"
os.environ["KAGGLE_KEY"] = "kaggle-api-key"
os.environ["WANDB_API_KEY"] = "wandb-api-key"
os.environ["KERAS_BACKEND"] = "jax"
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00"
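The comments in the cell above list several ways to supply the credentials. As a minimal sketch of how to avoid hard-coding them, the helper below reads a value from the process environment first and only falls back to an interactive prompt when it is missing. `get_secret` and the `DEMO_SECRET` variable are hypothetical names used for illustration; they are not part of Kaggle, Colab, or wandb.

```python
import getpass
import os


def get_secret(name, prompt_if_missing=True):
    """Return a credential from the environment, optionally prompting for it.

    Hypothetical helper for illustration; it is not part of any library.
    """
    value = os.environ.get(name)
    if value:
        return value
    if not prompt_if_missing:
        raise KeyError(f"{name} is not set")
    # getpass hides the typed value so it does not appear in the notebook output.
    value = getpass.getpass(f"Enter your {name}: ")
    os.environ[name] = value
    return value


# Demo with a dummy variable so that no real credential is involved.
os.environ["DEMO_SECRET"] = "dummy-value"
print(get_secret("DEMO_SECRET", prompt_if_missing=False))
```

In the notebook you would call it as `get_secret("KAGGLE_KEY")` and so on, which keeps the keys out of the saved notebook.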

In [None]:
import tensorflow as tf
import keras
import keras_nlp
from datasets import load_dataset
import itertools
import wandb
from wandb.integration.keras import WandbMetricsLogger

## Step 2: Load and Explore the Japanese Dataset
We are using the Japanese instruction data from the `CohereForAI/aya_dataset` dataset.

**Subtasks:**
- Load training and validation datasets.
- Extract sample data for exploration.
- Limit dataset size for efficient experimentation.


Since we want to instruction fine-tune the Gemma 2 9b model to adapt it to the Japanese language, we need a good amount of high-quality Japanese instructions and responses. For that, we use the Japanese subset of the Aya dataset, which is a multilingual instruction dataset.

You can look into it on Hugging Face: [Link](https://huggingface.co/datasets/CohereForAI/aya_dataset)  

**Dataset Summary (from the original dataset page):**  
The Aya Dataset is a multilingual instruction fine-tuning dataset curated by an open-science community via Aya Annotation Platform from Cohere For AI. The dataset contains a total of 204k human-annotated prompt-completion pairs along with the demographics data of the annotators.
This dataset can be used to train, finetune, and evaluate multilingual LLMs.

Curated by: Contributors of Aya Open Science Initiative.

Language(s): 65 languages (71 including dialects & scripts).

License: Apache 2.0

In [None]:
def load_specific_data(target_language="English"):
    aya_dataset = load_dataset("CohereForAI/aya_dataset")
    selected_dataset = aya_dataset.filter(lambda x: x['language'] == target_language)
    return selected_dataset

def generate_text_data(selected_dataset):
    Data = []
    for example in selected_dataset["train"]:
        instruction = example["inputs"]
        response = example["targets"]
        template = "Instruction:\n{instruction}\n\nResponse:\n{response}"
        Data.append(template.format(**{"instruction": instruction, "response": response}))

    test_data = []
    for example in selected_dataset["test"]:
        instruction = example["inputs"]
        response = example["targets"]
        template = "Instruction:\n{instruction}\n\nResponse:\n{response}"
        test_data.append(template.format(**{"instruction": instruction, "response": response}))
    return Data, test_data
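
The two loops in `generate_text_data` apply the same template to each split. A minimal sketch of the same formatting step as a reusable function is shown below, run on a toy record so it needs no download. `format_example`, `format_split`, and `toy_records` are hypothetical names; the field names `inputs` and `targets` are the ones used in the cell above.

```python
TEMPLATE = "Instruction:\n{instruction}\n\nResponse:\n{response}"


def format_example(example):
    """Render one record with `inputs` and `targets` fields into the prompt template."""
    return TEMPLATE.format(instruction=example["inputs"], response=example["targets"])


def format_split(records):
    """Format every record of one dataset split."""
    return [format_example(record) for record in records]


# Toy record standing in for one row of the dataset.
toy_records = [{"inputs": "1+1は？", "targets": "2です。"}]
print(format_split(toy_records)[0])
```

With the real data this would be called once per split, for example `format_split(selected_dataset["train"])`.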

In [None]:
subset = load_specific_data("Japanese")
train_text_data, test_text_data= generate_text_data(subset)

In [None]:
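# The test split has no examples for our target language, so we hold out part of
# the training data for validation instead. Note that training uses everything
# except the last 200 examples while validation uses only the last 20, so 180
# examples are left unused.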
train_text_data, test_text_data = train_text_data[:-200], train_text_data[-20:]
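
For comparison, here is a minimal sketch of a hold-out split that keeps every example in exactly one of the two parts. `split_holdout` is a hypothetical helper and the toy list stands in for `train_text_data`; it is an alternative under those assumptions, not the split used for the logged run.

```python
def split_holdout(examples, num_val):
    """Split a list into disjoint train and validation parts.

    The validation part is the last `num_val` items; everything else is training data.
    """
    if not 0 < num_val < len(examples):
        raise ValueError("num_val must be between 1 and len(examples) - 1")
    return examples[:-num_val], examples[-num_val:]


toy = [f"example-{i}" for i in range(10)]
train_part, val_part = split_holdout(toy, 2)
print(len(train_part), len(val_part))
```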

In [None]:
# Check the first example to ensure loading is correct
print("First training example:", train_text_data[0],"\n")
print("First validation example:", test_text_data[0])
print(f'\ntraining length:{len(train_text_data)}')

## Step 3: Data Preprocessing
The text data will be converted into TensorFlow datasets for training and validation. Key preprocessing steps include:
- Creating TensorFlow datasets from plain-text lists.
- Shuffling and batching training data for optimized input.
- Optional text cleaning (if needed).


In [None]:
batch_size = 4

# Convert the lists of text data to TensorFlow datasets
train_data = tf.data.Dataset.from_tensor_slices(train_text_data)
val_data = tf.data.Dataset.from_tensor_slices(test_text_data)

# Preprocess each text sample
def preprocess_text(text):
    return tf.convert_to_tensor(text, dtype=tf.string)

# Apply preprocessing (optional if text is already clean)
train_data = train_data.map(preprocess_text)
val_data = val_data.map(preprocess_text)

# Shuffle and batch the training data
train_data = train_data.shuffle(buffer_size=1000).batch(batch_size)
val_data = val_data.batch(batch_size)
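
To make the effect of `shuffle(buffer_size=...)` followed by `batch(...)` more concrete, the framework-free sketch below mimics the idea in plain Python: items pass through a fixed-size buffer and are emitted in random order, then grouped into batches. It is a simplified illustration, not tf.data's actual implementation, and `buffered_shuffle` and `make_batches` are hypothetical names.

```python
import random


def buffered_shuffle(items, buffer_size, seed=None):
    """Yield items in a randomized order using a bounded buffer."""
    rng = random.Random(seed)
    buffer = []
    for item in items:
        buffer.append(item)
        if len(buffer) > buffer_size:
            # Emit one random element so the buffer stays at buffer_size.
            idx = rng.randrange(len(buffer))
            buffer[idx], buffer[-1] = buffer[-1], buffer[idx]
            yield buffer.pop()
    # Input exhausted: emit the remaining elements in random order.
    rng.shuffle(buffer)
    yield from buffer


def make_batches(items, batch_size):
    """Group items into lists of batch_size; the last batch may be smaller."""
    batch = []
    for item in items:
        batch.append(item)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch


batches = list(make_batches(buffered_shuffle(range(10), buffer_size=4, seed=0), batch_size=4))
print(batches)
```

A buffer smaller than the dataset only shuffles locally, which is why a buffer of 1000 gives a fuller shuffle on a small dataset than on a large one.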

## Step 4: Model Parallelism for Efficient Training and Loading the model
We configure model parallelism using TPUs to handle the large-scale Gemma model. Key components:
- **Device Mesh:** A mapping of TPU devices.
- **Layout Map:** Specifies the sharding strategy for different layers.
- **Model Loading:** With the distribution set, the model is loaded with its weights sharded across the devices.


In [None]:
# Create a device mesh with (1, 8) shape so that the weights are sharded across
# all 8 TPUs.
device_mesh = keras.distribution.DeviceMesh(
    (1, 8),
    ["batch", "model"],
    devices=keras.distribution.list_devices(),
)

model_dim = "model"

layout_map = keras.distribution.LayoutMap(device_mesh)

# Weights that match 'token_embedding/embeddings' will be sharded on 8 TPUs
layout_map["token_embedding/embeddings"] = (model_dim, None)
# Regex to match against the query, key and value matrices in attention layers
layout_map["decoder_block.*attention.*(query|key|value)/kernel"] = (model_dim, None, None)
layout_map["decoder_block.*attention_output/kernel"] = (model_dim, None, None)
layout_map["decoder_block.*ffw_gating.*/kernel"] = (None, model_dim)
layout_map["decoder_block.*ffw_linear/kernel"] = (model_dim, None)

model_parallel = keras.distribution.ModelParallel(
    layout_map=layout_map,
    batch_dim_name="batch",
)

keras.distribution.set_distribution(model_parallel)
model_id = "/kaggle/input/gemma2/keras/gemma2_9b_ja/1" # Or /kaggle/input/m/keras/gemma2/keras/gemma2_instruct_2b_en/2
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset(model_id)
gemma_lm.summary()
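
To see what a layout such as `(model_dim, None)` means for memory, the sketch below computes the per-device shape of a weight under a given layout. It assumes each sharded axis divides evenly across the mesh axis, and the tensor shapes are illustrative placeholders (check `gemma_lm.summary()` for the real ones). `shard_shape` is a hypothetical helper, not part of `keras.distribution`.

```python
def shard_shape(shape, layout, mesh_sizes):
    """Per-device shape of a tensor sharded according to `layout`.

    `layout` has one entry per tensor axis: the name of a mesh axis to shard
    along, or None to keep that axis whole on every device.
    """
    if len(shape) != len(layout):
        raise ValueError("layout must have one entry per tensor axis")
    local = []
    for dim, axis_name in zip(shape, layout):
        if axis_name is None:
            local.append(dim)
        else:
            n = mesh_sizes[axis_name]
            if dim % n:
                raise ValueError(f"axis of size {dim} is not divisible by {n}")
            local.append(dim // n)
    return tuple(local)


mesh_sizes = {"batch": 1, "model": 8}  # mirrors the (1, 8) device mesh above
# Illustrative shapes only, chosen to show the arithmetic.
print(shard_shape((256000, 3584), ("model", None), mesh_sizes))
print(shard_shape((3584, 14336), (None, "model"), mesh_sizes))
```

Each device then stores one eighth of every sharded weight, which is what makes a 9B-parameter model fit on a v3-8.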

## Step 5: Model Overview
We initialize the Gemma model for fine-tuning and explore its architecture.

### Key Model Parameters:
- **Model ID:** Pretrained Gemma version for transfer learning.
- **LoRA:** Enable Low-Rank Adaptation for fine-tuning.
- **Sequence Length:** Adjusted for task requirements.


## Step 6: Evaluate Model Performance Before Fine-Tuning
Before training, test the model on a set of prompts to benchmark its initial performance. This helps us compare improvements after fine-tuning.


In [None]:
template = "Instruction:\n{instruction}\n\nResponse:\n{response}"

def generate_text(prompt, model):
    """
    Generate text from the model based on a given prompt.
    """
    sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
    model.compile(sampler=sampler)
    output = model.generate(prompt, max_length=512)
    return output
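
`TopKSampler(k=5)` restricts each sampling step to the five most likely tokens. The plain-Python sketch below illustrates that idea on a toy list of logits; it is a conceptual illustration rather than the keras_nlp implementation, and `top_k_sample` and `toy_logits` are hypothetical names.

```python
import math
import random


def top_k_sample(logits, k, rng):
    """Sample one token index from the k highest-scoring logits."""
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    # Softmax weights over the kept logits only (shifted by the max for stability).
    largest = max(logits[i] for i in top)
    weights = [math.exp(logits[i] - largest) for i in top]
    return rng.choices(top, weights=weights, k=1)[0]


rng = random.Random(2)
toy_logits = [0.1, 2.5, -1.0, 3.0, 0.7, 1.9, -0.3]
samples = [top_k_sample(toy_logits, k=3, rng=rng) for _ in range(20)]
# Only indices 3, 1 and 5 (the three largest logits) can ever be drawn.
print(sorted(set(samples)))
```

A small `k` keeps generations focused while still allowing some variety between runs.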

In [None]:
# Sample prompt to check performance before and after fine-tuning
test_prompts = [
    "こんにちは！今日は कैसाですか？最近学んだ面白いことを教えてください。", # Greeting and request for recent information
    "イタリアのルネサンスの歴史について何を知っていますか？芸術と科学への影響を説明してもらえますか？", # Request for historical knowledge and cultural impact
    "秋の風景についての短い詩を日本語で書いてください。", # Request for poetic creativity
    "人工知能がどのように機能するのか、そして日本で最も一般的な用途は何なのかを簡単な言葉で説明してください。", # Request for technical explanation and geographical context
    "もし誰かが「手に余る」と言ったら、それは何を意味するでしょうか？どのような状況でこの表現を使うことができるでしょうか？", # Request for interpretation of an idiomatic expression
]

for prompt in test_prompts:
    print(f"\n--- Model Output Before Fine-Tuning for prompt: {prompt} ---")
    print(generate_text(template.format(instruction=prompt, response=""), gemma_lm))
    print("\n")

## Step 7: Fine-Tuning the Gemma Model with LoRA
We apply LoRA to enable efficient parameter updates during fine-tuning. Key configurations include:
- Optimizer: AdamW with weight decay for transformer models.
- Metrics: Sparse Categorical Accuracy.
- LoRA Rank: Defines the dimensionality of updates.

We use Weights & Biases to monitor training progress and metrics.


In [None]:
LoRA_rank = 8  # You can modify this (e.g. 2, 4, 8, ...).
# Enable LoRA on the backbone with the chosen rank.
gemma_lm.backbone.enable_lora(rank=LoRA_rank)
gemma_lm.summary()
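
To get a feel for how the rank affects the number of trainable parameters, recall that LoRA freezes a weight matrix of shape `(d_in, d_out)` and trains two small factors of shapes `(d_in, rank)` and `(rank, d_out)`. The sketch below does that arithmetic for one matrix; the dimensions are illustrative placeholders, not values read from the model, and the function names are hypothetical.

```python
def lora_trainable_params(d_in, d_out, rank):
    """Trainable parameters added by LoRA factors A (d_in x rank) and B (rank x d_out)."""
    return rank * (d_in + d_out)


def full_params(d_in, d_out):
    """Parameters of the frozen dense weight itself."""
    return d_in * d_out


d_in, d_out = 3584, 4096  # illustrative sizes only
for rank in (2, 4, 8):
    added = lora_trainable_params(d_in, d_out, rank)
    share = added / full_params(d_in, d_out)
    print(f"rank={rank}: {added} trainable params ({share:.3%} of the full matrix)")
```

The count grows linearly with the rank, so doubling the rank doubles the trainable parameters in each adapted layer.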

In [None]:
gemma_lm.preprocessor.sequence_length = 512
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.02,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

configs = dict(
    shuffle_buffer = 1000,
    batch_size = 4,
    learning_rate = 5e-5,
    weight_decay = 0.02,
    sequence_length = 512,
    epochs = 20
)

wandb.init(project = "fine-tuning-gemma2_instruct_9b_ja",
    config=configs
)

## Step 8: Training the Gemma Model
We train the Gemma language model on `train_data` and evaluate it on `val_data`. To save time and computation, we use a small number of epochs (20). If you have more time and compute available, go ahead and increase this!

In [None]:
# Fit the model
history = gemma_lm.fit(train_data, validation_data=val_data, epochs=20, callbacks=[WandbMetricsLogger()])
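
To estimate how long a run will take before launching it, it helps to know the number of optimizer steps. A minimal sketch of that arithmetic follows; the example count of 1000 is a hypothetical placeholder (use `len(train_text_data)` in the notebook), and `training_steps` is a hypothetical helper.

```python
import math


def training_steps(num_examples, batch_size, epochs):
    """Return (steps per epoch, total steps), keeping the final partial batch."""
    steps_per_epoch = math.ceil(num_examples / batch_size)
    return steps_per_epoch, steps_per_epoch * epochs


# 1000 is a placeholder dataset size, not the real number of Japanese examples.
steps_per_epoch, total_steps = training_steps(num_examples=1000, batch_size=4, epochs=20)
print(steps_per_epoch, total_steps)
```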

## Step 9: Evaluate Model Performance After Fine-Tuning
Finally, evaluate the fine-tuned model using the same prompts as earlier. Compare the responses to assess improvements in quality and relevance.


In [None]:
test_prompts = [
    "こんにちは！今日は कैसाですか？最近学んだ面白いことを教えてください。", # Greeting and request for recent information
    "イタリアのルネサンスの歴史について何を知っていますか？芸術と科学への影響を説明してもらえますか？", # Request for historical knowledge and cultural impact
    "秋の風景についての短い詩を日本語で書いてください。", # Request for poetic creativity
    "人工知能がどのように機能するのか、そして日本で最も一般的な用途は何なのかを簡単な言葉で説明してください。", # Request for technical explanation and geographical context
    "もし誰かが「手に余る」と言ったら、それは何を意味するでしょうか？どのような状況でこの表現を使うことができるでしょうか？", # Request for interpretation of an idiomatic expression
]

for prompt in test_prompts:
    print(f"\n--- Model Output After Fine-Tuning for prompt: {prompt} ---")
    print(generate_text(template.format(instruction=prompt, response=""), gemma_lm))
    print("\n")

#### Comparing the outputs before and after fine-tuning, you can see that the model's generations have improved for our target language.
Note: this model was fine-tuned from a base Gemma model that had already been adapted to Japanese, using instruction and response text in the target language. Because the instruction dataset is small and we limited the LoRA rank and the number of epochs to save computation, expect some randomness and other imperfections in its answers.

## Step 10: Uploading the Fine-Tuned Model to Kaggle
Here we upload the final fine-tuned model to Kaggle Models so everyone can use it!
We save the model to /kaggle/tmp because it is larger than the Kaggle notebook output directory allows.

In [None]:
tmp_model_dir = "/kaggle/tmp/gemma2_instruct_9b_ja"  # Use /kaggle/tmp
preset_dir = "gemma2_instruct_9b_ja"
os.makedirs(tmp_model_dir, exist_ok=True)
gemma_lm.save_to_preset(tmp_model_dir)

print(f"Model saved to: {tmp_model_dir}")

In [None]:
import kagglehub
import keras_hub
if "KAGGLE_USERNAME" not in os.environ or "KAGGLE_KEY" not in os.environ:
    kagglehub.login()

model_version = 1
kaggle_username = kagglehub.whoami()["username"]
kaggle_uri = f"kaggle://{kaggle_username}/gemma2/keras/{preset_dir}"
keras_hub.upload_preset(kaggle_uri, tmp_model_dir)
print("Done!")
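
The upload URI above follows the pattern `kaggle://<username>/<model>/<framework>/<variation>`. A small sketch that assembles and sanity-checks such a handle is shown below; `kaggle_preset_uri` is a hypothetical helper, not a kagglehub or keras_hub function.

```python
def kaggle_preset_uri(username, model, variation, framework="keras"):
    """Build a handle of the form kaggle://<username>/<model>/<framework>/<variation>."""
    parts = [username, model, framework, variation]
    if any(not part or "/" in part for part in parts):
        raise ValueError("each component must be a non-empty string without '/'")
    return "kaggle://" + "/".join(parts)


print(kaggle_preset_uri("your-username", "gemma2", "gemma2_instruct_9b_ja"))
```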

# Inference
This section shows how to load the fine-tuned model from Kaggle and use it.

**For inference, we only need to load the fine-tuned model from Kaggle into our notebook, as follows:**

For more information, see the [documentation](https://keras.io/api/keras_nlp/models/gemma/gemma_causal_lm/).

Specifically:

A preset is a directory of configs, weights and other file assets used to save and load a pre-trained model. The preset can be passed as one of:
* 1. a built-in preset identifier like 'bert_base_en'
* 2. a Kaggle Models handle like 'kaggle://user/bert/keras/bert_base_en'
* 3. a Hugging Face handle like 'hf://user/bert_base_en'
* 4. a path to a local preset directory like './bert_base_en'
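
As a rough illustration of the four forms a preset can take, the sketch below classifies a preset string by its prefix. This is a simplification for reading purposes, not how the library resolves presets (for instance, a relative directory name without `./` would be labeled as built-in here), and `preset_kind` is a hypothetical helper.

```python
def preset_kind(preset):
    """Classify a preset string by its prefix (simplified, for illustration only)."""
    if preset.startswith("kaggle://"):
        return "kaggle"
    if preset.startswith("hf://"):
        return "huggingface"
    if preset.startswith(("./", "../", "/")):
        return "local"
    return "builtin"


for example in (
    "bert_base_en",
    "kaggle://user/bert/keras/bert_base_en",
    "hf://user/bert_base_en",
    "./bert_base_en",
):
    print(example, "->", preset_kind(example))
```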

**Inference step by step:**
* 1. Load the fine-tuned model from Kaggle Models.
* 2. After the model is successfully loaded, you can use it to generate text in the target language.
* Good luck :)

In [None]:
final_model_id = "kaggle://mahdiseddigh/gemma2/keras/gemma2_instruct_9b_ja"
finetuned_gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset(final_model_id)
finetuned_gemma_lm.summary()

In [None]:
test_prompt = "your prompt here"  # Replace with your own prompt.
print("\n--- Fine-tuned Model Output ---")
print(generate_text(template.format(instruction=test_prompt, response=""), finetuned_gemma_lm))

# Conclusion
This notebook showcased the complete workflow for fine-tuning the Gemma model on a Japanese instruction dataset. We highlighted:
- Dataset preparation
- Model architecture and parallelism
- Fine-tuning with LoRA
- Performance evaluation pre- and post-training