# Finetuning Gemma2 for multilingual translation

The goal of this notebook submission for the [Google - Unlock Global Communication with Gemma](https://www.kaggle.com/competitions/gemma-language-tuning) competition is to demonstrate how to finetune Gemma2 for the task of translating between multiple languages (currently 8), using a curated bible dataset.



# Bible Dataset

Here we load a multilingual dataset of full-text Bibles, which we'll use for finetuning on translation tasks in various languages.

I'm using the Multilingual Full Text Bible dataset which can be found [here](https://www.kaggle.com/datasets/jordanyoung993/multi-lingual-full-text-bible). 

I created this dataset using [pybible](https://github.com/Jordan-M-Young/pybible), a custom-built Python library that I also authored.

This library wraps the API exposed by https://scripture.api.bible/, a Bible repository website. The library compiles the Bible texts, which are then merged by verse.
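The merge step is not shown in this notebook, so here is a hypothetical helper, `merge_by_verse`, that produces the shape the notebook reads from `merged_bibles.json`: each verse ID maps to a list of one-entry `{language_label: text}` dicts (the shape implied by how `generate_sentence_pairs` iterates over `data`). It is a sketch, not pybible's actual implementation. The language labels and texts below are placeholders; the notebook's `.split()[0].replace(",", "")` suggests the real labels contain more than just the language name.

```python
def merge_by_verse(bibles: dict) -> dict:
    """Turn {language: {verse_id: text}} into {verse_id: [{language: text}, ...]}."""
    merged = {}
    for language, verses in bibles.items():
        for verse_id, text in verses.items():
            # A verse missing from one translation simply gets fewer entries
            merged.setdefault(verse_id, []).append({language: text})
    return merged


# Hypothetical per-language inputs with placeholder texts
bibles = {
    "English": {"GEN.1.1": "English text of GEN.1.1", "GEN.1.2": "English text of GEN.1.2"},
    "Italian": {"GEN.1.1": "Italian text of GEN.1.1"},
}
merged = merge_by_verse(bibles)
print(merged["GEN.1.1"])
print(merged["GEN.1.2"])
```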

The following list shows which languages this dataset currently supports:

- English
- Polish
- Czech
- Thai
- Persian
- Italian
- Gujarati
- Swahili



# Why the Bible?

Aside from any artistic and historical merits, the Bible is a lexically diverse text containing ~14,000 unique words (in English). It is also highly structured: it is organized into books, chapters, and, at the most granular level, verses. Verses are the atoms that compose a Bible, and each specific verse, regardless of Bible version, contains essentially the same nugget of information. This makes the Bible a useful tool, as its verses provide discrete chunks of text that convey more or less the same semantic information regardless of language. Thus we use the Bible here to finetune Gemma2 on translation tasks.
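Because every verse carries a stable ID (the notebook looks up `data['GEN.1.1']`), the book/chapter/verse hierarchy can be recovered from the key itself, which is handy for grouping or splitting the data by book. The sketch below assumes every key follows the `BOOK.CHAPTER.VERSE` pattern seen in `GEN.1.1`; the other IDs (including the `EXO` book code) are illustrative.

```python
from collections import Counter


def parse_verse_id(verse_id: str) -> tuple:
    """Split an ID such as 'GEN.1.1' into (book, chapter, verse)."""
    book, chapter, verse = verse_id.split(".")
    return book, int(chapter), int(verse)


# Illustrative IDs; assumes the BOOK.CHAPTER.VERSE pattern holds for all keys
verse_ids = ["GEN.1.1", "GEN.1.2", "GEN.2.1", "EXO.1.1"]
parsed = [parse_verse_id(v) for v in verse_ids]
verses_per_book = Counter(book for book, _, _ in parsed)
print(parsed[0])
print(verses_per_book)
```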


# Evaluation

Aside from qualitative assessments, this notebook also quantitatively evaluates the responses of the base model and the finetuned model. I will use the following metrics to do this:
- Jaccard similarity
- [BLEU](https://en.wikipedia.org/wiki/BLEU)
- [METEOR](https://en.wikipedia.org/wiki/METEOR)
- [NIST](https://en.wikipedia.org/wiki/NIST_(metric))
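Jaccard similarity is the simplest metric in this list: it compares the *sets* of whitespace-separated tokens, |A ∩ B| / |A ∪ B|, so it ignores word order and repeated words, unlike BLEU and NIST. A minimal worked example with made-up sentences (returning 0.0 when both sets are empty is a choice made for this sketch):

```python
def jaccard_similarity(set_a: set, set_b: set) -> float:
    """|A intersect B| / |A union B|, defined as 0.0 when both sets are empty."""
    union = set_a | set_b
    return len(set_a & set_b) / len(union) if union else 0.0


reference = "in the beginning god created the heaven and the earth".split()
prediction = "in the beginning god made the earth".split()

# Shared tokens: in, the, beginning, god, earth (5); union has 9 distinct tokens
score = jaccard_similarity(set(prediction), set(reference))
print(round(score, 3))  # 0.556
```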

In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed.
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
import os

import numpy as np  # linear algebra
import pandas as pd  # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory.
# List every file under the input directory:
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/); it is preserved as output when you create a version using "Save & Run All".
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session.

In [None]:
# By Gabriel Preda https://www.kaggle.com/code/gpreda/fine-tuning-gemma-2-model-using-lora-and-keras/notebook
# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
!pip install -q -U keras-nlp
# Quote the version specifier: unquoted, the shell treats ">" as an output redirect
!pip install -q -U "keras>=3"
!pip install -q -U kagglehub
!pip install tqdm
!pip install -U nltk

In [None]:
# get Bible text data
import json
import tqdm

with open("/kaggle/input/multi-lingual-full-text-bible/merged_bibles.json", 'r', encoding='utf-8') as jfile:
    data = json.load(jfile)

In [None]:
# check data loaded...

data['GEN.1.1']


In [None]:
#generate sentence pair rows
def generate_sentence_pairs(data: dict):
    forward_pairs = []
    reverse_pairs = []
    for verse_id, verse_texts in data.items():
        for idx, verse in enumerate(verse_texts):
            for jdx, verse2 in enumerate(verse_texts):
                if idx >= jdx:
                    continue
                
                lang_a = list(verse.keys())[0]
                text_a = list(verse.values())[0]
                
                lang_b = list(verse2.keys())[0]
                text_b = list(verse2.values())[0]
                
                row_a = [verse_id, lang_a,text_a, lang_b, text_b]
                row_b =  [verse_id, lang_b, text_b, lang_a, text_a]
                
                forward_pairs.append(row_a)
                reverse_pairs.append(row_b)
                
    return forward_pairs, reverse_pairs



forward_pairs, reverse_pairs = generate_sentence_pairs(data)

print("Forward Pair:",forward_pairs[0])
print("Reverse Pair:",reverse_pairs[0])
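For a verse available in n languages, the nested loops in `generate_sentence_pairs` emit one forward pair per unordered language pair, i.e. n(n-1)/2 pairs (28 for the eight languages here), plus the mirrored reverse pair. The sketch below is an equivalent rewrite using `itertools.combinations`, checked on a hypothetical toy verse; it is not the notebook's code. One consequence worth noting: if every verse has all eight languages, the `forward_pairs[:300]` training slice covers only about 300 / 28 ≈ 11 verses, and each language pair appears in only one direction (whichever language comes first in that verse's list).

```python
from itertools import combinations


def generate_sentence_pairs(data: dict) -> tuple:
    """Same output as the nested-loop version, written with itertools.combinations."""
    forward_pairs, reverse_pairs = [], []
    for verse_id, verse_texts in data.items():
        # Each entry is a one-item dict {language_label: text}
        entries = [next(iter(v.items())) for v in verse_texts]
        # combinations() yields (i, j) with i < j, matching the idx < jdx filter
        for (lang_a, text_a), (lang_b, text_b) in combinations(entries, 2):
            forward_pairs.append([verse_id, lang_a, text_a, lang_b, text_b])
            reverse_pairs.append([verse_id, lang_b, text_b, lang_a, text_a])
    return forward_pairs, reverse_pairs


# Hypothetical verse with eight placeholder languages
toy = {"GEN.1.1": [{f"Lang{i}": f"text {i}"} for i in range(8)]}
fwd, rev = generate_sentence_pairs(toy)
print(len(fwd))  # 28 == 8 * 7 / 2
```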



In [None]:
import gc

train_queries = []
for pair in forward_pairs[:300]:
    
    language_a = pair[1].split()[0].replace(",","")
    language_b = pair[3].split()[0].replace(",","")
    text_a = pair[2]
    text_b = pair[4]
    query = f"Translate: {text_a}\nFrom {language_a} to {language_b}\nResponse:\n{text_b}"
    train_queries.append(query)
    
# del(data)
gc.collect()
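The training prompt and the test prompt (built further down) share the same template up to `Response:`; training prompts append the target text, while test prompts stop there so that `generate` continues in the learned format. A sketch that factors this into two hypothetical helpers, `language_name` and `build_prompt` (the label strings are made up; only the template and the label parsing mirror the notebook):

```python
from typing import Optional


def language_name(label: str) -> str:
    """Mirror the notebook's parsing: first word of the label, commas removed."""
    return label.split()[0].replace(",", "")


def build_prompt(text_a: str, language_a: str, language_b: str,
                 text_b: Optional[str] = None) -> str:
    """Training prompts include the target text; test prompts stop at 'Response:'."""
    prompt = f"Translate: {text_a}\nFrom {language_a} to {language_b}\nResponse:"
    return prompt if text_b is None else f"{prompt}\n{text_b}"


# Hypothetical labels and texts, for illustration only
src = language_name("English, hypothetical version")
tgt = language_name("Italian, hypothetical version")
train_prompt = build_prompt("Hello", src, tgt, "Ciao")
test_prompt = build_prompt("Hello", src, tgt)
print(train_prompt)
print(test_prompt)
```

Keeping both variants in one function makes it harder for the train and test templates to drift apart.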

In [None]:
import os

os.environ["KERAS_BACKEND"] = "jax"  # you can also use tensorflow or torch
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00"  # avoid memory fragmentation on JAX backend
os.environ["JAX_PLATFORMS"] = ""
import keras
import keras_nlp

import kagglehub

# Read your own Kaggle credentials from Kaggle Secrets (add them via the notebook's Secrets add-on first)
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
# kagglehub picks up these environment variables when authenticating (needed for model_upload below)
os.environ["KAGGLE_KEY"] = user_secrets.get_secret("KAGGLE_KEY")
os.environ["KAGGLE_USERNAME"] = user_secrets.get_secret("username")

keras.utils.set_random_seed(777)


# Compiling the model

In [None]:

gemma_causal_lm = keras_nlp.models.GemmaCausalLM.from_preset('gemma2_2b_en')
gemma_causal_lm.backbone.enable_lora(rank=4)


gemma_causal_lm.preprocessor.sequence_length = 100

gemma_causal_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=keras.optimizers.Adam(learning_rate=0.001),
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
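`enable_lora(rank=4)` freezes the pretrained weights and trains small low-rank adapters instead. The sketch below shows the generic LoRA formulation in NumPy for a single hypothetical 2048×2048 dense layer: the output is `x @ W` plus a low-rank update `(x @ A) @ B`, with `B` zero-initialised so training starts from the pretrained model. It illustrates the technique and its parameter count; it is not taken from the KerasNLP source, and Keras' actual initialisation and choice of adapted layers may differ.

```python
import numpy as np

# Hypothetical square dense layer; rank=4 matches enable_lora(rank=4) above
d_in, d_out, rank = 2048, 2048, 4
rng = np.random.default_rng(0)

W = rng.standard_normal((d_in, d_out)).astype(np.float32)          # frozen pretrained weight
A = (0.01 * rng.standard_normal((d_in, rank))).astype(np.float32)  # trainable, d_in x r
B = np.zeros((rank, d_out), dtype=np.float32)                      # trainable, r x d_out, starts at zero

x = rng.standard_normal((3, d_in)).astype(np.float32)  # a small batch of activations
y = x @ W + (x @ A) @ B  # LoRA forward pass: frozen path plus low-rank update

full_params = d_in * d_out           # parameters updated by full finetuning of this layer
lora_params = rank * (d_in + d_out)  # parameters updated by LoRA
print(full_params, lora_params, lora_params / full_params)  # 4194304 16384 0.00390625
```

The LoRA parameter count grows linearly with the rank, so doubling the rank doubles the trainable parameters per adapted layer while still staying far below full finetuning.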

# Building a small test set.

In [None]:
# generate a small test set: the first 7 English-source pairs from index 500 onward
count = 0
test_queries = []
test_answers = []
for pair in forward_pairs[500:]:
    language_a = pair[1].split()[0].replace(",", "")
    if language_a != "English":
        continue
    language_b = pair[3].split()[0].replace(",", "")
    text_a = pair[2]
    text_b = pair[4]
    # Unlike the training prompts, the test prompt stops at "Response:" so the model supplies the answer
    query = f"Translate: {text_a}\nFrom {language_a} to {language_b}\nResponse:"
    test_queries.append(query)
    test_answers.append(text_b)
    count += 1

    if count == 7:
        break



In [None]:
#generate translation completions with base model

responses = gemma_causal_lm.generate(
            test_queries,
            max_length=100)

base_predict_answers = []
for response in responses:
    
    answer = response.split("Response:")[-1].replace("\n","")
    base_predict_answers.append(answer)
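The parsing above keeps everything after the *last* `Response:` and joins all remaining lines. If the model keeps generating past its answer (for example, by starting a new `Translate:` turn before reaching `max_length`), that extra text ends up in the scored prediction. A possible alternative, sketched as a hypothetical helper `extract_answer`, keeps only the first line after the first `Response:`; whether it helps depends on what the model actually generates.

```python
def extract_answer(response: str) -> str:
    """Keep only the first line generated after the first 'Response:' marker."""
    # maxsplit=1 keeps the text after the FIRST marker (the whole string if absent)
    completion = response.split("Response:", 1)[-1].strip()
    # Drop anything after the first line, e.g. an extra 'Translate:' turn
    return completion.split("\n", 1)[0].strip()


# Hypothetical generation that runs on into a second turn
sample = (
    "Translate: Hi\nFrom English to Italian\nResponse:\nCiao\n"
    "Translate: Bye\nFrom English to Italian\nResponse:\nArrivederci"
)
print(extract_answer(sample))  # Ciao
```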

# Collecting evaluation metrics for base model

In [None]:
from nltk.translate import nist_score, bleu_score

def jaccard_similarity(set_a: set, set_b: set) -> float:
    union = set_a | set_b
    # Guard against two empty token sets (e.g. an empty model response)
    return len(set_a & set_b) / len(union) if union else 0.0


base_evaluation_data = []
chencherry = bleu_score.SmoothingFunction()
for idx, target_answer in enumerate(test_answers):
    # NLTK signatures are (references, hypothesis): the target is the reference
    reference = target_answer.split()
    hypothesis = base_predict_answers[idx].split()
    z = jaccard_similarity(set(hypothesis), set(reference))
    try:
        x = nist_score.sentence_nist([reference], hypothesis)
        y = bleu_score.sentence_bleu([reference], hypothesis, smoothing_function=chencherry.method1)
        print("Jaccard Similarity", z, "NIST", x, "BLEU", y)
    except Exception:
        # NIST can raise (e.g. ZeroDivisionError) on empty or very short hypotheses; score as 0
        x, y = 0, 0
        print("Jaccard Similarity", z)
    base_evaluation_data.append({"query": test_queries[idx], "bleu": y, "nist": x, "jacc": z})


In [None]:
# qualitative look at base model performance
for idx, target_answer in enumerate(test_answers):
    print("TARGET_RESPONSE:",target_answer)
    print("PREDICT_RESPONSE:",base_predict_answers[idx],"\n\n")

# Finetuning the model

In [None]:
#train model
hist = gemma_causal_lm.fit(train_queries, epochs=2, batch_size=5)

In [None]:
os.makedirs("../tmp", exist_ok=True)
gemma_causal_lm.save('../tmp/gemma2_multilingual_translation_task_v1', zipped=False)

MODEL_SLUG = "gemma2"
VARIATION_SLUG = 'multilingual_translation'
kagglehub.model_upload(
  handle = f"jordanyoung993/{MODEL_SLUG}/keras/{VARIATION_SLUG}",
  local_model_dir = "../tmp/gemma2_multilingual_translation_task_v1",
  version_notes = 'Update 2024-10-15')


In [None]:
import matplotlib.pyplot as plt

# hist.history records one value per epoch, not per batch
plt.plot(list(range(len(hist.history['loss']))), hist.history['loss'])
plt.xlabel("Epoch")
plt.ylabel("Loss")

In [None]:
plt.plot(list(range(len(hist.history['loss']))), hist.history['sparse_categorical_accuracy'])
plt.xlabel("Epoch")
plt.ylabel("Sparse Categorical Accuracy")

In [None]:
    
responses = gemma_causal_lm.generate(
            test_queries,
            max_length=100)

finetuned_predict_answers = []
for response in responses:
    answer = response.split("Response:")[-1].replace("\n","")
    finetuned_predict_answers.append(answer)


# Collecting evaluation metrics for finetuned model



In [None]:
from nltk.translate import nist_score, bleu_score

finetuned_evaluation_data = []

# evaluate the finetuned model on the translation task
# (reuses jaccard_similarity and chencherry from the base-model evaluation cell)
for idx, target_answer in enumerate(test_answers):
    # NLTK signatures are (references, hypothesis): the target is the reference
    reference = target_answer.split()
    hypothesis = finetuned_predict_answers[idx].split()
    z = jaccard_similarity(set(hypothesis), set(reference))
    try:
        x = nist_score.sentence_nist([reference], hypothesis)
        y = bleu_score.sentence_bleu([reference], hypothesis, smoothing_function=chencherry.method1)
        print("Jaccard Similarity", z, "NIST", x, "BLEU", y)
    except Exception:
        # NIST can raise (e.g. ZeroDivisionError) on empty or very short hypotheses; score as 0
        x, y = 0, 0
        print("Jaccard Similarity", z)
    finetuned_evaluation_data.append({"query": test_queries[idx], "bleu": y, "nist": x, "jacc": z})

# print(finetuned_evaluation_data)

In [None]:
# qualitative analysis of trained model performance
for idx, target_answer in enumerate(test_answers):
    print("TARGET_RESPONSE:",target_answer)
    print("PREDICT_RESPONSE:",finetuned_predict_answers[idx],"\n\n")

# Results

In [None]:
import matplotlib.pyplot as plt
import numpy as np

languages = []
for query in base_evaluation_data:
    query = query['query']
    language = query.split("\n")[1]
    language = language.replace("From","").replace(" to ","-")
    languages.append(language.replace("English-",""))
    
    

jaccard_scores = {
    "base":[],
    "fine":[]
}

bleu_scores = {
    "base":[],
    "fine":[]
}

nist_scores = {
    "base":[],
    "fine":[]
}

    
for idx, query in enumerate(base_evaluation_data):
    base_jacc = query['jacc']
    base_bleu = query['bleu']
    base_nist = query['nist']
    
    fine_jacc = finetuned_evaluation_data[idx]['jacc']
    fine_bleu = finetuned_evaluation_data[idx]['bleu']
    fine_nist = finetuned_evaluation_data[idx]['nist']
    
    jaccard_scores['base'].append(round(base_jacc,2))
    bleu_scores['base'].append(round(base_bleu,2))
    nist_scores['base'].append(round(base_nist,2))

    jaccard_scores['fine'].append(round(fine_jacc,2))
    bleu_scores['fine'].append(round(fine_bleu,2))
    nist_scores['fine'].append(round(fine_nist,2))
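Because each language pair has a single test sample, a per-model average across the seven pairs gives a compact (if still very noisy) summary. A sketch using a hypothetical helper `summarize`, which accepts dicts shaped like `jaccard_scores`, `bleu_scores` and `nist_scores` above; the example values are the Jaccard numbers reported in the Discussion tables, not new results.

```python
def summarize(scores: dict) -> dict:
    """Average each model's per-language scores and report finetuned minus base."""
    means = {model: sum(values) / len(values) for model, values in scores.items()}
    means["delta"] = means["fine"] - means["base"]
    return means


# Jaccard values copied from the results table in the Discussion section
jaccard_table = {
    "base": [0.07, 0.0, 0.0, 0.0, 0.04, 0.0, 0.0],
    "fine": [0.46, 0.42, 0.0, 0.27, 0.04, 0.31, 0.45],
}
summary = summarize(jaccard_table)
print({name: round(value, 3) for name, value in summary.items()})
```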

    


In [None]:
def plot_barchart(x_labels: list, data: dict, metric: str, y_lim: float = 1):

    x = np.arange(len(x_labels))  # the label locations
    width = 0.25  # the width of the bars
    multiplier = 0

    fig, ax = plt.subplots(layout='constrained')

    for attribute, measurement in data.items():
        offset = width * multiplier
        rects = ax.bar(x + offset, measurement, width, label=attribute)
        ax.bar_label(rects, padding=3)
        multiplier += 1

    # Add some text for labels, title and custom x-axis tick labels, etc.
    ax.set_ylabel(f"{metric}")
    ax.set_title(f"{metric} By Translation")
    ax.set_xticks(x + width * (len(data) - 1) / 2, x_labels)  # centre each tick under its group of bars
    ax.legend(loc='upper left', ncols=3)
    ax.set_ylim(0, y_lim)

    plt.show()
    
plot_barchart(languages, jaccard_scores, "Jaccard Similarity")

In [None]:
plot_barchart(languages, bleu_scores, "BLEU",y_lim=0.5)

In [None]:
plot_barchart(languages, nist_scores, "NIST", y_lim=3.0)

In [None]:
print(nist_scores)

# Discussion

Based on the data compiled during the evaluation of each model, finetuning Gemma2 on the translation task appears to improve the model's ability to translate from English into some languages. Large improvements are observed for Italian, Polish, and Persian, and moderate improvements for Swahili and Gujarati. Little change, and perhaps some regression, is seen for Czech, while Thai scores zero on every metric for both models. The compiled data are reproduced in the tables below.

Qualitatively, the model seems to have improved in all languages. The base model had trouble generating responses in most of the target languages and returned a poor translation only for Italian. The finetuned model's translations, while not amazing, were at least in the target languages and contained some of the correct words, which suggests the model is in fact learning something about the languages.

# Jaccard Similarity

| Model | English-Italian | English-Swahili  | English-Thai| English-Gujarati| English-Czech| English-Polish| English-Persian|
| --- | --- | --- | --- | --- | --- | --- | --- |
|Base|0.07|0.0|0.0|0.0|0.04|0.0|0.0|
|Finetuned|0.46|0.42|0.0|0.27|0.04|0.31|0.45|


# BLEU

| Model | English-Italian | English-Swahili  | English-Thai| English-Gujarati| English-Czech| English-Polish| English-Persian|
| --- | --- | --- | --- | --- | --- | --- | --- |
|Base|0.02|0.0|0.0|0.0|0.01|0.0|0.0|
|Finetuned|0.26|0.03|0.0|0.03|0.01|0.11|0.06|


# NIST

| Model | English-Italian | English-Swahili  | English-Thai| English-Gujarati| English-Czech| English-Polish| English-Persian|
| --- | --- | --- | --- | --- | --- | --- | --- |
|Base|0.19|0.0|0.0|0.0|0.16|0.0|0.0|
|Finetuned|1.86|0.22|0.0|0.19|0.14|1.11|1.65|


It should be noted that the data displayed here have several limitations: (1) only a small subset of the full-text Bible dataset (~300 samples) was used for training; (2) only a small subset of the dataset was used for testing, with a single sample per language pair, so the current results are not statistically meaningful; and (3) because written Thai does not put spaces between words, the whitespace tokenization used by all three metrics splits a Thai sentence into very few tokens, which makes it difficult to measure the finetuned model's efficacy on Thai.
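One possible workaround for limitation (3) is to compare character n-grams instead of whitespace tokens, which needs no Thai word segmenter; character-level metrics such as chrF build on the same idea. A minimal sketch using character-bigram Jaccard (the helper names and the two greetings are illustrative, not part of this notebook's evaluation):

```python
def char_ngrams(text: str, n: int = 2) -> set:
    """Character n-grams with spaces removed, so no word segmenter is needed."""
    chars = text.replace(" ", "")
    return {chars[i:i + n] for i in range(len(chars) - n + 1)}


def jaccard(set_a: set, set_b: set) -> float:
    union = set_a | set_b
    return len(set_a & set_b) / len(union) if union else 0.0


reference = "สวัสดีครับ"  # "hello" with the polite male particle
prediction = "สวัสดีค่ะ"  # "hello" with the polite female particle

# Whitespace tokenization sees each string as one token, so the overlap is zero
word_score = jaccard(set(prediction.split()), set(reference.split()))
# Character bigrams credit the shared greeting
char_score = jaccard(char_ngrams(prediction), char_ngrams(reference))
print(word_score, round(char_score, 3))
```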

In any case, the data show some improvement in the model's ability to translate after finetuning; given a larger training set, these improvements would likely be much more pronounced. I also used a small LoRA rank (4) when configuring the model; increasing the rank, and with it the number of trainable parameters, would likely improve translation performance further. Stay tuned for updates to this notebook.

--Jordan
