# Fine-tune Gemma for Text to SQL in Keras using LoRA

This notebook demonstrates how to fine-tune Gemma-2B to produce a SQL query from an input consisting of a plain-English question and a table schema. In other words, the model will convert a question asked in natural language, like the following:

```
“How many customers bought Camembert in the month of August?”
```

together with a table schema that looks like this:
```sql
CREATE TABLE purchases (
    purchase_id INT PRIMARY KEY,
    purchase_date DATE,
    customer_id INT,
    product_name VARCHAR(128)
);
```

Into a SQL query that can be run on the `purchases` table to get the actual result:
```sql
SELECT COUNT(DISTINCT customer_id) AS num_customers
FROM purchases
WHERE product_name = 'camembert'
AND EXTRACT(MONTH FROM purchase_date) = 8;
```
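To make "can be run on the `purchases` table" concrete, here is a minimal sketch that executes the query with Python's built-in `sqlite3` against a few made-up rows (hypothetical data, not from any dataset). SQLite has no `EXTRACT`, so this sketch swaps in `strftime('%m', ...)` cast to an integer; otherwise the query is the one above.

```python
import sqlite3

# In-memory database using the purchases schema from the example above.
conn = sqlite3.connect(":memory:")
conn.execute(
    """CREATE TABLE purchases (
        purchase_id INT PRIMARY KEY,
        purchase_date DATE,
        customer_id INT,
        product_name VARCHAR(128)
    )"""
)

# Hypothetical rows, made up purely for illustration.
rows = [
    (1, "2023-08-02", 101, "camembert"),
    (2, "2023-08-15", 101, "camembert"),  # same customer twice in August
    (3, "2023-08-20", 102, "camembert"),
    (4, "2023-07-30", 103, "camembert"),  # July: should not be counted
    (5, "2023-08-05", 104, "brie"),       # different product
]
conn.executemany("INSERT INTO purchases VALUES (?, ?, ?, ?)", rows)

# SQLite dialect: strftime('%m', ...) returns the month as a zero-padded string,
# so it is cast to INTEGER before comparing with 8 (stands in for EXTRACT).
query = """
SELECT COUNT(DISTINCT customer_id) AS num_customers
FROM purchases
WHERE product_name = 'camembert'
AND CAST(strftime('%m', purchase_date) AS INTEGER) = 8
"""
num_customers = conn.execute(query).fetchone()[0]
print(num_customers)  # 2
```

`COUNT(DISTINCT customer_id)` is what makes customer 101 count once despite two August purchases.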

We will be using this [Text to SQL Dataset](https://huggingface.co/datasets/knowrohit07/know_sql) in the tuning process.

> Note: Whilst this notebook primarily shows how to fine-tune for the Text to SQL task, the approach can easily be adapted to tune Gemma-2B for other tasks.

In [1]:
import numpy as np 
import pandas as pd 

import os

In [2]:
pd.set_option('display.max_colwidth', None)

In [3]:
import warnings
warnings.filterwarnings("ignore")

In [4]:
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

/kaggle/input/gemma/keras/gemma_2b_en/2/config.json
/kaggle/input/gemma/keras/gemma_2b_en/2/tokenizer.json
/kaggle/input/gemma/keras/gemma_2b_en/2/metadata.json
/kaggle/input/gemma/keras/gemma_2b_en/2/model.weights.h5
/kaggle/input/gemma/keras/gemma_2b_en/2/assets/tokenizer/vocabulary.spm
/kaggle/input/inputs/know_sql_val3ign.json
/kaggle/input/data-assistants-with-gemma/submission_categories.txt
/kaggle/input/data-assistants-with-gemma/submission_instructions.txt


In [5]:
!pip install -q -U keras-nlp
# Quote the requirement: an unquoted `>=3` is parsed by the shell as an output redirect.
!pip install -q -U "keras>=3"

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
tensorflow 2.15.0 requires keras<2.16,>=2.15.0, but you have keras 3.0.5 which is incompatible.
tensorflowjs 4.16.0 requires packaging~=23.1, but you have packaging 21.3 which is incompatible.

In [6]:
import os

# The backend must be selected before Keras is imported.
os.environ["KERAS_BACKEND"] = "jax"  # Or "torch" or "tensorflow".
# Avoid memory fragmentation on JAX backend.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00"

In [7]:
import keras
import keras_nlp

2024-02-24 16:15:38.127399: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-02-24 16:15:38.127534: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-02-24 16:15:38.272198: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


> In the Kaggle notebook editor, use the Add Model button in the right-hand panel to attach the `gemma_2b_en` model (9.34 GB).

In [8]:
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
gemma_lm.summary()

Attaching 'config.json' from model 'keras/gemma/keras/gemma_2b_en/2' to your Kaggle notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_2b_en/2' to your Kaggle notebook...
Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_2b_en/2' to your Kaggle notebook...
Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_2b_en/2' to your Kaggle notebook...
Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_2b_en/2' to your Kaggle notebook...
normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.


In [9]:
!pip install datasets



In [10]:
import pandas as pd

# Read JSON file into a DataFrame
df = pd.read_json('/kaggle/input/inputs/know_sql_val3ign.json')

# Write the DataFrame to a Parquet file
df.to_parquet('output.parquet')


In [11]:
raw_df = pd.read_parquet("output.parquet")

In [12]:
raw_df.head()


| | question | context | answer |
|---|---|---|---|
| 0 | how many district with incumbent being lindy boggs | CREATE TABLE table_1341586_19 (district VARCHAR, incumbent VARCHAR) | SELECT COUNT(district) FROM table_1341586_19 WHERE incumbent = "Lindy Boggs" |
| 1 | what's the result with candidates being billy tauzin (d) unopposed | CREATE TABLE table_1341586_19 (result VARCHAR, candidates VARCHAR) | SELECT result FROM table_1341586_19 WHERE candidates = "Billy Tauzin (D) Unopposed" |
| 2 | how many candidates with result being retired to run for u. s. senate republican hold | CREATE TABLE table_1341586_19 (candidates VARCHAR, result VARCHAR) | SELECT COUNT(candidates) FROM table_1341586_19 WHERE result = "Retired to run for U. S. Senate Republican hold" |
| 3 | what's the result with district being louisiana 2 | CREATE TABLE table_1341586_19 (result VARCHAR, district VARCHAR) | SELECT result FROM table_1341586_19 WHERE district = "Louisiana 2" |
| 4 | who is the the candidates with first elected being 1977 | CREATE TABLE table_1341586_19 (candidates VARCHAR, first_elected VARCHAR) | SELECT candidates FROM table_1341586_19 WHERE first_elected = 1977 |


In [13]:
def clean(text):
    return text.replace(u'\xa0', u' ').strip()

def prompt_fn(row):
    template = "Question:\n{question}\nContext:\n{context}\n\nAnswer:\n{answer}"
    prompt = template.format(
      question=clean(row['question']), context=clean(row['context']), answer=clean(row['answer'])
    )
    return prompt
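The same template serves both training and inference: with the answer filled in it is a training example, and with an empty answer the prompt ends right after `Answer:\n`, so whatever the model continues with is the SQL. The standalone copy below uses plain dicts instead of pandas rows (an adaptation for illustration) to show the inference-time prompt.

```python
def clean(text):
    # Replace non-breaking spaces with regular spaces and trim the ends.
    return text.replace("\xa0", " ").strip()

def prompt_fn(row):
    template = "Question:\n{question}\nContext:\n{context}\n\nAnswer:\n{answer}"
    return template.format(
        question=clean(row["question"]),
        context=clean(row["context"]),
        answer=clean(row["answer"]),
    )

# Inference-time prompt: an empty answer leaves the text ending in "Answer:\n".
inference_prompt = prompt_fn({
    "question": "how many district with incumbent being lindy boggs",
    "context": "CREATE TABLE table_1341586_19 (district VARCHAR, incumbent VARCHAR)",
    "answer": "",
})
print(repr(inference_prompt[-8:]))  # 'Answer:\n'
```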

In [14]:
LIMIT = 2000
data = raw_df[:LIMIT].apply(prompt_fn,axis=1).values.tolist()

print(data[0])

Question:
how many district with incumbent being lindy boggs
Context:
CREATE TABLE table_1341586_19 (district VARCHAR, incumbent VARCHAR)

Answer:
SELECT COUNT(district) FROM table_1341586_19 WHERE incumbent = "Lindy Boggs"


In [15]:
# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
gemma_lm.summary()
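LoRA keeps a pretrained weight matrix W (d×k) frozen and learns a low-rank update B·A, with B of shape d×r and A of shape r×k, so each adapted matrix adds only r·(d + k) trainable parameters instead of d·k. `enable_lora(rank=4)` adds such adapters to a subset of the backbone's layers and leaves the original weights frozen. The arithmetic below uses an illustrative 2048×2048 matrix (an assumed size, not read from the Gemma config) to show why rank 4 is cheap; `gemma_lm.summary()` reports the actual trainable-parameter count.

```python
def lora_trainable_params(d, k, r):
    """Trainable parameters LoRA adds to one d x k weight matrix at rank r."""
    # B is d x r and A is r x k; the original W (d x k) stays frozen.
    return d * r + r * k

def full_params(d, k):
    """Parameters updated if the whole d x k matrix were fine-tuned."""
    return d * k

d, k, r = 2048, 2048, 4  # illustrative sizes, not taken from the Gemma config
lora = lora_trainable_params(d, k, r)
full = full_params(d, k)
print(lora, full, f"{lora / full:.2%}")  # 16384 4194304 0.39%
```

Raising the rank increases adapter capacity linearly in r, at a linear cost in trainable parameters.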

In [16]:
# Limit the input sequence length to 512 (to control memory usage).
gemma_lm.preprocessor.sequence_length = 512
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(data, epochs=1, batch_size=1)

2000/2000 ━━━━━━━━━━━━━━━━━━━━ 2629s 1s/step - loss: 0.2184 - sparse_categorical_accuracy: 0.6938


<keras.src.callbacks.history.History at 0x7e683429fe50>
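AdamW decouples weight decay from the Adam gradient step: each weight is shrunk directly in proportion to the learning rate and `weight_decay`, and `exclude_from_weight_decay(var_names=["bias", "scale"])` turns that shrinkage off for bias and layer-norm scale variables. The scalar sketch below is a simplified, textbook-form illustration of one update, assuming β₁ = 0.9, β₂ = 0.999, ε = 1e-7 and applying the decay before the Adam step; it is not Keras's actual implementation.

```python
import math

def adamw_step(theta, grad, m, v, t, lr=5e-5, weight_decay=0.01,
               beta_1=0.9, beta_2=0.999, eps=1e-7, decay=True):
    """One AdamW update for a single scalar weight (illustrative sketch)."""
    # Decoupled weight decay: shrink the weight directly, independent of grad.
    if decay:
        theta = theta - lr * weight_decay * theta
    # Adam first/second moment estimates with bias correction.
    m = beta_1 * m + (1 - beta_1) * grad
    v = beta_2 * v + (1 - beta_2) * grad ** 2
    m_hat = m / (1 - beta_1 ** t)
    v_hat = v / (1 - beta_2 ** t)
    theta = theta - lr * m_hat / (math.sqrt(v_hat) + eps)
    return theta, m, v

# With a zero gradient, only the decay term moves the weight.
decayed, _, _ = adamw_step(1.0, 0.0, 0.0, 0.0, t=1)
excluded, _, _ = adamw_step(1.0, 0.0, 0.0, 0.0, t=1, decay=False)
print(decayed, excluded)
```

With the notebook's `learning_rate=5e-5` and `weight_decay=0.01`, the per-step decay factor is 1 − 5e-7, so its effect on a single epoch of 2000 steps is small.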

In [17]:
print(prompt_fn(raw_df.loc[LIMIT]))

Question:
What is the SFC when the specific impulse is 453?
Context:
CREATE TABLE table_15944_5 (sfc_in_g__kn·s_ VARCHAR, specific_impulse__s_ VARCHAR)

Answer:
SELECT sfc_in_g__kn·s_ FROM table_15944_5 WHERE specific_impulse__s_ = 453


In [18]:
row = {
  "question": "What kind of competition was it at San Siro at 18:30 GMT?",
  "context": "CREATE TABLE table_name_60 (competition VARCHAR, ground VARCHAR, time VARCHAR)",
  "answer": ""
}

prompt = prompt_fn(row)
print(gemma_lm.generate(prompt, max_length=256))
print('Expected: SELECT competition FROM table_name_60 WHERE ground = "san siro" AND time = "18:30 gmt"')

Question:
What kind of competition was it at San Siro at 18:30 GMT?
Context:
CREATE TABLE table_name_60 (competition VARCHAR, ground VARCHAR, time VARCHAR)

Answer:
SELECT competition FROM table_name_60 WHERE ground = "San Siro" AND time = "18:30 GMT"
Expected: SELECT competition FROM table_name_60 WHERE ground = "san siro" AND time = "18:30 gmt"
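`generate` returns the prompt followed by the continuation, and the output above differs from the expected query only in the letter case of the string literals. One possible way to score such outputs is sketched below: strip the prompt prefix, keep the first non-empty line of the completion, and compare case-insensitively with whitespace collapsed. The helper names, the normalization rule, and the `generated` string (a reconstruction of the output shown above) are assumptions for illustration. Case-insensitive matching is lenient, because string comparisons in SQL can be case-sensitive depending on the database.

```python
def extract_sql(generated, prompt):
    """Return the first non-empty line the model produced after the prompt."""
    completion = generated[len(prompt):] if generated.startswith(prompt) else generated
    lines = [line for line in completion.strip().splitlines() if line.strip()]
    return lines[0].strip() if lines else ""

def normalized_match(predicted, expected):
    """Lenient exact match: ignore letter case and repeated whitespace."""
    def norm(s):
        return " ".join(s.lower().split())
    return norm(predicted) == norm(expected)

# Prompt built with the notebook's template; generated text reconstructed from the output above.
prompt = (
    "Question:\nWhat kind of competition was it at San Siro at 18:30 GMT?\n"
    "Context:\nCREATE TABLE table_name_60 (competition VARCHAR, ground VARCHAR, time VARCHAR)\n"
    "\nAnswer:\n"
)
generated = prompt + 'SELECT competition FROM table_name_60 WHERE ground = "San Siro" AND time = "18:30 GMT"'
expected = 'SELECT competition FROM table_name_60 WHERE ground = "san siro" AND time = "18:30 gmt"'

sql = extract_sql(generated, prompt)
print(sql == expected, normalized_match(sql, expected))  # False True
```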
