# Instruct Fine-Tuning Gemma for 14 Languages
This notebook demonstrates instruction fine-tuning of the Gemma model on data in 14 languages. We walk through the workflow from data loading and preprocessing to model fine-tuning and evaluation.

**Key Steps:**
1. Setup environment variables for Kaggle and Weights & Biases (wandb).
2. Load and preprocess the Aya instruct dataset for the 14 languages.
3. Set up model parallelism for TPU utilization.
4. Fine-tune the Gemma model using LoRA (Low-Rank Adaptation).
5. Evaluate model performance before and after fine-tuning.

The languages used:

`Spanish`, `Iranian Persian`, `Japanese`, `Korean`, `Russian`, `German`, `Swedish`, `Simplified Chinese`, `Danish`, `English (American and British)`, `Finnish`, `Italian`, `Dutch`, `Turkish`

##### You can view the fine-tuning logs here: [link](https://wandb.ai/this-is-the-way-2005-independent/fine-tuning-gemma2_2b_instruct_Polyglot)

#### Device:
We used a TPU VM v3-8 on Kaggle.
#### Base model:
We used a version of `gemma2_2b_en` that was already fine-tuned on 54 multilingual datasets.
The model: [link](https://www.kaggle.com/models/mahdiseddigh/gemma2/keras/gemma2_2b_polyglot)

The fine-tuning notebook: [link](https://www.kaggle.com/code/mahdiseddigh/fine-tuning-gemma2-2b-polyglot)

### My Gemma2 cookbook:
I created this repo, where I upload all my notebooks related to working with Gemma models. Check it out:
https://github.com/Mhdaw/Gemma2

## Step 0: Installing the Required Libraries and Frameworks
To ensure that all necessary libraries and frameworks are installed, run the following commands:

In [None]:
!pip install -q -U keras-nlp keras datasets kagglehub keras_hub 
!pip install -q -U tensorflow-text
# Install tensorflow-cpu so tensorflow does not attempt to access the TPU.
!pip install -q -U tensorflow-cpu
!pip install -q -U wandb

In [None]:
import jax
jax.devices()

## Step 1: Setup Environment Variables
We will configure the environment variables required for:
- Kaggle API access
- Weights & Biases for tracking experiments
- Keras backend selection (JAX) and XLA memory allocation on the TPU.


In [None]:
import os

os.environ["KAGGLE_USERNAME"] = "your-kaggle-username"
os.environ["KAGGLE_KEY"] = "your-kaggle-key"
os.environ["WANDB_API_KEY"] = "your-wandb-key"
os.environ["KERAS_BACKEND"] = "jax"
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00"

In [None]:
#print(f"num cpus:{os.cpu_count()}") 

In [None]:
import tensorflow as tf
import keras
import keras_nlp
from datasets import load_dataset
import itertools
import wandb
from wandb.integration.keras import WandbMetricsLogger
import matplotlib.pyplot as plt

## Step 2: Load and Explore the Aya Dataset
We are using the `CohereForAI/aya_dataset` dataset. 

**Subtasks:**
- Load the data for each language, split it into training and validation sets, and then concatenate the per-language splits into combined datasets.
- Extract sample data for exploration.
- Limit dataset size for efficient experimentation.
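The per-language split and concatenation can be sketched on toy data as follows. It mirrors the logic of `load_and_process_aya` below (hold out the last `n_test` examples of each language, skip languages left with no training data, then concatenate); the toy records and the helper names `split_holdout` and `combine` are illustrative only.

```python
def split_holdout(examples, n_test=75):
    """Hold out the last n_test examples (or all of them, if there are fewer)."""
    n = min(n_test, len(examples))
    return examples[:len(examples) - n], examples[len(examples) - n:]

def combine(per_language, n_test=75):
    """Split each language separately, then concatenate into one train and one val list."""
    train, val = [], []
    for lang, examples in per_language.items():
        lang_train, lang_val = split_holdout(examples, n_test)
        if not lang_train:  # same rule as the notebook: skip languages with no training data
            continue
        train.extend(lang_train)
        val.extend(lang_val)
    return train, val

toy = {
    "Spanish": [f"es-{i}" for i in range(5)],
    "Korean": [f"ko-{i}" for i in range(3)],
    "Danish": ["da-0"],  # too small: everything would go to validation, so it is skipped
}
train, val = combine(toy, n_test=2)
print(train)  # ['es-0', 'es-1', 'es-2', 'ko-0']
print(val)    # ['es-3', 'es-4', 'ko-1', 'ko-2']
```

Splitting per language before concatenating guarantees that every language that has training data also appears in the validation set.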


Since we want to instruction fine-tune the Gemma 2 2B model to adapt it to the 14 languages, we need a good amount of high-quality multilingual instructions and responses. For that, we use `aya_dataset`, a multilingual instruction dataset.

You can look into it on Hugging Face: [Link](https://huggingface.co/datasets/CohereForAI/aya_dataset)  

**Dataset Summary (from the original dataset page):**  
The Aya Dataset is a multilingual instruction fine-tuning dataset curated by an open-science community via Aya Annotation Platform from Cohere For AI. The dataset contains a total of 204k human-annotated prompt-completion pairs along with the demographics data of the annotators.
This dataset can be used to train, finetune, and evaluate multilingual LLMs.

Curated by: Contributors of Aya Open Science Initiative.

Language(s): 65 languages (71 including dialects & scripts).

License: Apache 2.0

In [None]:
def load_and_process_aya(languages, max_examples=None):
    """Loads and processes the AYA dataset for multiple languages.

    Args:
        languages: A list of target languages.
        max_examples: Maximum number of examples to load per language. If None, loads all.

    Returns:
        A dictionary where keys are languages and values are dictionaries
        containing 'train' and 'test' lists of processed text data.
        Returns an empty dictionary if there are errors.
    """

    try:
        aya_dataset = load_dataset("CohereForAI/aya_dataset")
    except Exception as e:
        print(f"Error loading AYA dataset: {e}")
        return {}

    all_lang_data = {}

    for lang in languages:
        try:
            print(f"Processing {lang} data...")
            selected_dataset = aya_dataset.filter(lambda x: x['language'] == lang)

            examples = []
            template = "Instruction:\n{instruction}\n\nResponse:\n{response}"
            for example in selected_dataset["train"]:
                examples.append(template.format(instruction=example["inputs"],
                                                response=example["targets"]))

            if max_examples is not None:
                examples = examples[:max_examples]

            # Hold out the last 75 examples (or all of them, for languages with fewer than 75)
            n_test = min(75, len(examples))
            test_data = examples[len(examples) - n_test:]
            train_data = examples[:len(examples) - n_test]
            if len(train_data)>0:
                all_lang_data[lang] = {"train": train_data, "test": test_data}
                print(f"Loaded {len(train_data)} training examples for {lang}")
            else:
                print(f"Not enough data to create train/test split for {lang}")

        except Exception as e:
            print(f"Error processing {lang} data: {e}")
            continue # Skip to the next language

    return all_lang_data

In [None]:
languages_to_load = ["Spanish","Iranian Persian","Japanese","Korean","Russian","German","Swedish",
                     "Simplified Chinese", "Danish", "English", "Finnish","Italian", "Dutch",
                     "Turkish"]
# 14 languages, capped at 1000 examples each (the last 75 per language are held out for validation).
# Note: some languages have fewer than 1000 examples.
loaded_data = load_and_process_aya(languages_to_load, max_examples=1000)

In [None]:
full_data_train = []
full_data_val = []
for lang, data in loaded_data.items():
    full_data_train.extend(data["train"])
    full_data_val.extend(data["test"])
  
print(f"Total train examples: {len(full_data_train)}")
print(f"Total test examples: {len(full_data_val)}")

## Step 3: Data Preprocessing
The text data will be converted into TensorFlow datasets for training and validation. Key preprocessing steps include:
- Creating TensorFlow datasets from plain-text lists.
- Shuffling and batching training data for optimized input.
- Optional text cleaning (if needed).


In [None]:
batch_size = 4

# Convert the lists of text data to TensorFlow datasets
train_data = tf.data.Dataset.from_tensor_slices(full_data_train)
val_data = tf.data.Dataset.from_tensor_slices(full_data_val)

# Preprocess each text sample
def preprocess_text(text):
    return tf.convert_to_tensor(text, dtype=tf.string)

# Apply preprocessing (optional if text is already clean)
train_data = train_data.map(preprocess_text)
val_data = val_data.map(preprocess_text)

# Shuffle and batch the training data
train_data = train_data.shuffle(buffer_size=1000).batch(batch_size)
val_data = val_data.batch(batch_size)
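`full_data_train` is built language by language, so its order is grouped by language, while `shuffle(buffer_size=1000)` only draws from a window of 1,000 elements. The pure-Python sketch below reproduces that buffered-shuffle idea (fill a buffer, emit a random element, refill it from the input) to show that early batches can only contain examples from near the start of the list. It illustrates the documented behavior of `tf.data.Dataset.shuffle`, not its exact implementation; `buffered_shuffle` is a hypothetical helper.

```python
import random

def buffered_shuffle(items, buffer_size, seed=0):
    """Yield items in the order a fixed-size shuffle buffer would produce them."""
    rng = random.Random(seed)
    it = iter(items)
    buffer = []
    # Fill the buffer with the first buffer_size elements
    for item in it:
        buffer.append(item)
        if len(buffer) == buffer_size:
            break
    # Emit a random buffered element and replace it with the next input element
    for item in it:
        idx = rng.randrange(len(buffer))
        yield buffer[idx]
        buffer[idx] = item
    # Drain whatever is left once the input is exhausted
    rng.shuffle(buffer)
    yield from buffer

# Toy data ordered like full_data_train: 3 "languages" x 10 examples each
data = [f"{lang}-{i}" for lang in ("es", "fa", "ja") for i in range(10)]
order = list(buffered_shuffle(data, buffer_size=5))
print(order[:5])  # the first outputs can only come from the first few "es" examples
```

If batches that mix languages are wanted, one option is to shuffle the Python list (for example with `random.shuffle(full_data_train)`) before `from_tensor_slices`, or to use a buffer at least as large as the dataset.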

## Step 4: Model Parallelism for Efficient Training, and Loading the Model
We configure model parallelism using TPUs to handle the large-scale Gemma model. Key components:
- **Device Mesh:** A mapping of TPU devices.
- **Layout Map:** Specifies the sharding strategy for different layers.
- Finally, we load the model sharded across the TPU devices.


## Step 5: Model Overview
We initialize the Gemma model for fine-tuning and explore its architecture.

### Key Model Parameters:
- **Model ID:** Pretrained Gemma version for transfer learning.
- **LoRA:** Enable Low-Rank Adaptation for fine-tuning.
- **Sequence Length:** Adjusted for task requirements.


**Note: the device mesh setup differs between the 9B and 2B models; use the one that matches your model.**
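Whether a layout works for a given model largely comes down to whether each sharded dimension is divisible by the size of the mesh axis it is sharded over (8 on a v3-8 with a `(1, 8)` mesh); JAX typically rejects uneven splits. The sketch below checks that. The helper `check_layout` is hypothetical, and apart from the 2304 hidden size mentioned in the cell below, the kernel shapes are assumptions for illustration, so read the real shapes from `gemma_lm.summary()` before relying on them.

```python
def check_layout(shape, layout, mesh_axes):
    """Return (dim_index, size, axis) for every sharded dimension that does not divide evenly."""
    if len(shape) != len(layout):
        raise ValueError(f"layout {layout} has {len(layout)} entries but shape has {len(shape)} dims")
    bad = []
    for i, (size, axis) in enumerate(zip(shape, layout)):
        if axis is not None and size % mesh_axes[axis] != 0:
            bad.append((i, size, axis))
    return bad

mesh_axes = {"batch": 1, "model": 8}  # the (1, 8) mesh used in this notebook

# Assumed query-kernel shape: (num_heads, hidden_dim, head_dim) with hidden_dim = 2304
print(check_layout((8, 2304, 256), (None, "model", None), mesh_axes))  # []
# A hypothetical dimension that does not split evenly across 8 devices
print(check_layout((4, 2300, 256), (None, "model", None), mesh_axes))  # [(1, 2300, 'model')]
```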

In [None]:
# Create a device mesh with (1, 8) shape so that the weights are sharded across
# all 8 TPUs.
device_mesh = keras.distribution.DeviceMesh(
    (1, 8),
    ["batch", "model"],
    devices=keras.distribution.list_devices(),
)

model_dim = "model"

layout_map = keras.distribution.LayoutMap(device_mesh)

# Assuming the second dimension (2304) of the attention kernels is divisible by 8,
# we shard along that dimension:
layout_map["token_embedding/embeddings"] = (model_dim, None)  # Shard embeddings along model dimension
layout_map["decoder_block.*attention.*(query|key|value)/kernel"] = (None, model_dim, None) # Shard attention kernels along second dimension
layout_map["decoder_block.*attention_output/kernel"] = (None, model_dim, None)  # Shard attention output kernels along second dimension
layout_map["decoder_block.*ffw_gating.*/kernel"] = (None, model_dim)  # Shard FFN gating kernels along second dimension
layout_map["decoder_block.*ffw_linear/kernel"] = (model_dim, None) # Shard FFN linear kernels along first dimension

model_parallel = keras.distribution.ModelParallel(
    layout_map=layout_map,
    batch_dim_name="batch",
)

keras.distribution.set_distribution(model_parallel)
model_id = "/kaggle/input/gemma2/keras/gemma2_2b_polyglot/1" # change this if you want
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset(model_id)
gemma_lm.summary()
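`LayoutMap` keys are regular expressions matched against variable paths, so it is worth checking that each pattern hits the variables you intend and that no path matches two patterns. The sketch below uses Python's `re.search`, which we assume mirrors how Keras looks keys up; the variable paths are hypothetical examples in the style of Gemma's layer names, so print `[v.path for v in gemma_lm.backbone.weights]` to see the real ones.

```python
import re

# The layout keys used in the cell above
patterns = [
    "token_embedding/embeddings",
    "decoder_block.*attention.*(query|key|value)/kernel",
    "decoder_block.*attention_output/kernel",
    "decoder_block.*ffw_gating.*/kernel",
    "decoder_block.*ffw_linear/kernel",
]

# Hypothetical variable paths, for illustration only
paths = [
    "token_embedding/embeddings",
    "decoder_block_0/attention/query/kernel",
    "decoder_block_0/attention/key/kernel",
    "decoder_block_0/attention/attention_output/kernel",
    "decoder_block_0/ffw_gating/kernel",
    "decoder_block_0/ffw_gating_2/kernel",
    "decoder_block_0/ffw_linear/kernel",
    "decoder_block_0/pre_attention_norm/scale",
]

for path in paths:
    hits = [p for p in patterns if re.search(p, path)]
    if not hits:
        status = "no match (left unsharded)"
    elif len(hits) > 1:
        status = f"AMBIGUOUS: {hits}"
    else:
        status = hits[0]
    print(f"{path:52s} -> {status}")
```

Variables that match no key (such as the norm scales here) are simply not sharded by the layout.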

In [None]:
template = "Instruction:\n{instruction}\n\nResponse:\n{response}"

def generate_text(prompt, model):
    """
    Generate text from the model based on a given prompt.
    """
    sampler = keras_nlp.samplers.TopKSampler(k=15, seed=2)
    model.compile(sampler=sampler)
    output = model.generate(prompt, max_length=512)
    return output
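`TopKSampler(k=15)` restricts each decoding step to the 15 most likely next tokens and samples among them in proportion to their probabilities. Below is a from-scratch NumPy sketch of a single step; it illustrates the technique, not the `keras_nlp` code, and the toy logits are made up.

```python
import numpy as np

def top_k_sample(logits, k, rng):
    """Sample one token id from the k highest-scoring logits."""
    logits = np.asarray(logits, dtype=np.float64)
    top_ids = np.argsort(logits)[-k:]              # indices of the k largest logits
    top_logits = logits[top_ids]
    probs = np.exp(top_logits - top_logits.max())  # softmax restricted to the top k
    probs /= probs.sum()
    return int(rng.choice(top_ids, p=probs))

rng = np.random.default_rng(2)
toy_logits = [0.1, 3.0, -1.0, 2.5, 0.0, 2.9]
samples = [top_k_sample(toy_logits, k=3, rng=rng) for _ in range(1000)]
print(sorted(set(samples)))  # a subset of [1, 3, 5]: only the three largest logits can be sampled
```

Smaller `k` makes generations more conservative (`k=1` is greedy decoding), while larger `k` allows more varied but less predictable output.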

## Step 6: Evaluate Model Performance Before Fine-Tuning
Before training, test the model on a set of prompts to benchmark its initial performance. This helps us compare improvements after fine-tuning.


In [None]:
test_prompts_multilingual = [
    # English (American)
    "What's a fun fact about the Grand Canyon?",  # US English
    "Write a short story about a talking dog.", # US English
    # Arabic (Modern Standard)
    "ما هي أقدم مدينة في العالم؟", # What is the oldest city in the world?
    "اكتب جملة عن أهمية القراءة.", # Write a sentence about the importance of reading.
    # Chinese (Simplified)
    "长城有多长？", # How long is the Great Wall?
    "写一个关于猫的短篇故事。", # Write a short story about a cat.
    # Dutch
    "Wat is de hoofdstad van Nederland?", # What is the capital of the Netherlands?
    "Schrijf een korte beschrijving van een molen.", # Write a short description of a windmill.
    # French (European)
    "Quelle est la capitale de la France ?", # What is the capital of France?
    "Écris une courte description de la Tour Eiffel.", # Write a short description of the Eiffel Tower.
    # German
    "Was ist die Hauptstadt von Deutschland?", # What is the capital of Germany?
    "Schreibe eine kurze Beschreibung des Brandenburger Tors.", # Write a short description of the Brandenburg Gate.
    # Italian
    "Qual è la capitale d'Italia?", # What is the capital of Italy?
    "Scrivi una breve descrizione del Colosseo.", # Write a short description of the Colosseum.
    # Japanese
    "日本の首都はどこですか？", # What is the capital of Japan?
    "桜について短い文章を書いてください。", # Please write a short sentence about cherry blossoms.
    # Korean
    "한국의 수도는 어디입니까?", # What is the capital of Korea?
    "한국 음식에 대해 간단히 설명해 주세요.", # Please briefly explain about Korean food.
    # Polish
    "Jaka jest stolica Polski?", # What is the capital of Poland?
    "Napisz krótkie opowiadanie o smoku.", # Write a short story about a dragon.
    # Portuguese (Brazilian)
    "Qual é a capital do Brasil?", # What is the capital of Brazil?
    "Escreva uma breve descrição do Cristo Redentor.", # Write a short description of Christ the Redeemer.
    # Russian
    "Какая столица России?", # What is the capital of Russia?
    "Напишите короткий рассказ о медведе.", # Write a short story about a bear.
    # Spanish (European)
    "¿Cuál es la capital de España?", # What is the capital of Spain?
    "Escribe una breve descripción de la Sagrada Familia.", # Write a short description of the Sagrada Familia.
    # Thai
    "ประเทศไทยมีเมืองหลวงชื่ออะไร", # What is the capital of Thailand?
    "เขียนประโยคสั้นๆ เกี่ยวกับวัดไทย", # Write a short sentence about Thai temples.
    # Turkish
    "Türkiye'nin başkenti neresidir?", # What is the capital of Turkey?
    "Kapadokya hakkında kısa bir açıklama yazın.", # Write a short description about Cappadocia.
    # Ukrainian
    "Яка столиця України?", # What is the capital of Ukraine?
    "Напишіть коротку розповідь про кота.", # Write a short story about a cat.
    #Vietnamese
    "Thủ đô của Việt Nam là gì?", # What is the capital of Vietnam?
    "Hãy viết một câu ngắn về Vịnh Hạ Long.", # Please write a short sentence about Ha Long Bay.
    #Persian
    "پایتخت ایران کجاست؟", #Where is the capital of Iran?
    "یک جمله کوتاه درباره حافظ بنویسید", #Write a short sentence about Hafez
]

for prompt in test_prompts_multilingual:
    print(f"\n--- Model Output Before Fine-tuning for prompt: {prompt} ---") 
    print(generate_text(template.format(instruction=prompt, response=""), gemma_lm))
    print("\n")

## Step 7: Fine-Tuning the Gemma Model with LoRA
We apply LoRA to enable efficient parameter updates during fine-tuning. Key configurations include:
- Optimizer: AdamW with weight decay for transformer models.
- Metrics: Sparse Categorical Accuracy.
- LoRA Rank: Defines the dimensionality of updates.

We use Weights & Biases to monitor training progress and metrics.
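LoRA freezes each original weight matrix W (d × k) and learns a low-rank update B·A, with B (d × r) and A (r × k), so only r(d + k) parameters per matrix are trained instead of d·k. Below is a NumPy sketch of that idea for a 2D kernel. It is not the Keras implementation; the 2304 × 2304 shape is illustrative, the zero-initialized B follows the original LoRA formulation, and some implementations also scale B·A by alpha / r.

```python
import numpy as np

def lora_param_counts(d, k, r):
    """Trainable parameters for one d x k matrix: full fine-tuning vs. a rank-r LoRA update."""
    return d * k, r * (d + k)

rng = np.random.default_rng(0)
d, k, r = 2304, 2304, 8  # illustrative shape; rank 8 as in LoRA_rank below

W = rng.standard_normal((d, k)).astype(np.float32)           # frozen pretrained weight
A = (0.01 * rng.standard_normal((r, k))).astype(np.float32)  # trainable, small random init
B = np.zeros((d, r), dtype=np.float32)                       # trainable, zero init

x = rng.standard_normal((1, d)).astype(np.float32)
# Forward pass: x @ (W + B @ A), computed without materializing the d x k update
y = x @ W + (x @ B) @ A

full, lora = lora_param_counts(d, k, r)
print(full, lora, f"{lora / full:.2%}")  # 5308416 36864 0.69%
```

Because B starts at zero, the adapted model initially behaves exactly like the base model, and the rank r directly trades capacity for trainable parameters.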


In [None]:
LoRA_rank = 8  # you can modify this
# Enable LoRA on the backbone with the chosen rank (e.g. 2, 4, 8, ...).
gemma_lm.backbone.enable_lora(rank=LoRA_rank)
gemma_lm.summary()

In [None]:
gemma_lm.preprocessor.sequence_length = 512
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.02,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

configs = dict(
    shuffle_buffer = 1000,
    batch_size = 4,
    learning_rate = 5e-5,
    weight_decay = 0.02,
    sequence_length = 512,
    epochs = 16
)

wandb.init(project = "fine-tuning-gemma2_2b_instruct_Polyglot",
    config=configs
)
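AdamW differs from Adam with L2 regularization in that the weight decay is applied directly to the parameters ("decoupled") rather than added to the gradient, and `exclude_from_weight_decay` skips that shrinkage for the named variables. Below is a single-scalar sketch of one update step with the hyperparameters used here. The formula follows Loshchilov and Hutter's decoupled weight decay; `adamw_step` is a hypothetical helper, and Keras's implementation may differ in details.

```python
import math

def adamw_step(param, grad, m, v, t, lr=5e-5, weight_decay=0.02,
               beta1=0.9, beta2=0.999, eps=1e-7, decay=True):
    """One AdamW update for a single scalar parameter (t is the 1-based step count)."""
    if decay:  # decoupled weight decay; skipped for excluded names like "bias"/"scale"
        param = param - lr * weight_decay * param
    m = beta1 * m + (1 - beta1) * grad         # first-moment estimate
    v = beta2 * v + (1 - beta2) * grad * grad  # second-moment estimate
    m_hat = m / (1 - beta1 ** t)               # bias correction
    v_hat = v / (1 - beta2 ** t)
    param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v

# With a zero gradient, only the decay term moves the parameter
p, _, _ = adamw_step(1.0, 0.0, 0.0, 0.0, t=1)
print(p)  # approximately 1 - 5e-5 * 0.02 = 0.999999
p_excluded, _, _ = adamw_step(1.0, 0.0, 0.0, 0.0, t=1, decay=False)
print(p_excluded)  # 1.0: an excluded variable is not decayed
```

On the first step, the bias-corrected Adam update moves a parameter by roughly `lr` against the sign of its gradient, regardless of the gradient's magnitude.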

## Step 8: Training the Gemma Model
We train the Gemma language model on `train_data` and evaluate it on `val_data`. To save time and compute, we use a modest number of epochs (16); if you have more time and compute available, go ahead and increase this!
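For a sense of scale: with `max_examples=1000` and 75 examples held out per language, each language contributes at most 925 training and 75 validation examples, so the combined sets hold at most 14 × 925 = 12,950 and 14 × 75 = 1,050 examples. In practice they are smaller, since some languages have fewer than 1,000 examples; the totals printed earlier give the real numbers. A sketch of the resulting step counts with `batch_size = 4` and 16 epochs (`training_steps` is a hypothetical helper):

```python
import math

def training_steps(n_examples, batch_size, epochs):
    """Batches per epoch (the last batch may be partial, as with tf.data's batch()) and total steps."""
    steps_per_epoch = math.ceil(n_examples / batch_size)
    return steps_per_epoch, steps_per_epoch * epochs

n_languages, cap, n_test = 14, 1000, 75
max_train = n_languages * (cap - n_test)  # upper bound; replace with len(full_data_train)
max_val = n_languages * n_test            # upper bound; replace with len(full_data_val)

print(max_train, max_val)                # 12950 1050
print(training_steps(max_train, 4, 16))  # (3238, 51808)
print(training_steps(max_val, 4, 1))     # (263, 263)
```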

In [None]:
# Fit the model
history = gemma_lm.fit(train_data, validation_data=val_data, epochs=16, verbose=0, callbacks=[WandbMetricsLogger()])
print("Training finished....")

## Step 9: Evaluate Model Performance After Fine-Tuning
Finally, evaluate the fine-tuned model using the same prompts as earlier. Compare the responses to assess improvements in quality and relevance.


In [None]:
for prompt in test_prompts_multilingual:
    print(f"\n--- Model Output After Fine-tuning for prompt: {prompt} ---") 
    print(generate_text(template.format(instruction=prompt, response=""), gemma_lm))
    print("\n")

#### Comparing the outputs, you can see that the model's generations have improved for our target languages, some more than others.
Note: this model starts from a Gemma base that was already fine-tuned for 54 languages, and it was then instruction-tuned on target-language instructions and responses using a small dataset (fewer than 10,000 examples), with a limited LoRA rank and number of epochs to save compute. As a result, expect some randomness and inconsistency in its answers.


## Step 10: Uploading the Fine-Tuned Model to Kaggle
Here we upload the final fine-tuned model to Kaggle Models so everyone can use it!
We save the model to `/kaggle/tmp` because it is larger than the Kaggle notebook output directory allows.

In [None]:
tmp_model_dir = "/kaggle/tmp/gemma2_2b_instruct_Polyglot"  # Use /kaggle/tmp
preset_dir = "gemma2_2b_instruct_Polyglot"
os.makedirs(tmp_model_dir, exist_ok=True)
gemma_lm.save_to_preset(tmp_model_dir)

print(f"Model saved to: {tmp_model_dir}")

In [None]:
import kagglehub
import keras_hub
if "KAGGLE_USERNAME" not in os.environ or "KAGGLE_KEY" not in os.environ:
    kagglehub.login()

model_version = 1
kaggle_username = kagglehub.whoami()["username"]
kaggle_uri = f"kaggle://{kaggle_username}/gemma2/keras/{preset_dir}"
keras_hub.upload_preset(kaggle_uri, tmp_model_dir)
print("Done!")

# Inference
Here we show how to load the fine-tuned model from Kaggle and use it.

**For inference, we just need to load the fine-tuned model from Kaggle into the notebook as follows:**

For more info, check out the [GemmaCausalLM docs](https://keras.io/api/keras_nlp/models/gemma/gemma_causal_lm/).

Specifically:

A preset is a directory of configs, weights and other file assets used to save and load a pre-trained model. The preset can be passed as one of:
* 1. a built-in preset identifier like `'bert_base_en'`
* 2. a Kaggle Models handle like `'kaggle://user/bert/keras/bert_base_en'`
* 3. a Hugging Face handle like `'hf://user/bert_base_en'`
* 4. a path to a local preset directory like `'./bert_base_en'`
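The four handle forms above can be told apart mostly by their prefix. The hypothetical helper below is only a rough illustration of that distinction; the library's own resolution logic may check more cases, such as other URI schemes.

```python
import os

def preset_kind(preset):
    """Roughly classify a preset handle into one of the four forms listed above."""
    if preset.startswith("kaggle://"):
        return "kaggle"
    if preset.startswith("hf://"):
        return "huggingface"
    if os.path.isdir(preset) or preset.startswith(("./", "../", "/")):
        return "local"
    return "builtin"

for handle in ["bert_base_en", "kaggle://user/bert/keras/bert_base_en",
               "hf://user/bert_base_en", "./bert_base_en"]:
    print(f"{handle:40s} -> {preset_kind(handle)}")
```

In this notebook, the base `model_id` (`/kaggle/input/...`) is a local preset directory, while `final_model_id` below is a Kaggle Models handle.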

**Inference step by step:**
* 1. Load the fine-tuned model from Kaggle Models.
* 2. Once the model is loaded, use it to generate text in the target languages.
* Good luck :)

In [None]:
final_model_id = "kaggle://mahdiseddigh/gemma2/keras/gemma2_2b_instruct_Polyglot"
finetuned_gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset(final_model_id)
finetuned_gemma_lm.summary()

In [None]:
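# Requires `template` and `generate_text` from the earlier cells.
test_prompt = "¿Cuál es la capital de España?"  # replace with your own prompt
print("\n--- Fine-tuned Model's Output ---")
print(generate_text(template.format(instruction=test_prompt, response=""), finetuned_gemma_lm))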

# Conclusion
This notebook showcased the complete workflow for instruction fine-tuning an already fine-tuned Gemma model on data in 14 languages. We highlighted:
- Dataset preparation
- Model architecture and parallelism
- Fine-tuning with LoRA
- Performance evaluation pre- and post-training