## Prerequisites

Before starting the fine-tuning process, make sure the following prerequisites are in place:

1. **GPU**: [gemma-2b](https://huggingface.co/google/gemma-2b) can be fine-tuned on a T4 (available on the free tier of Google Colab), while [gemma-7b](https://huggingface.co/google/gemma-7b) requires an A100 GPU.
2. **Python Packages**: Make sure the necessary Python packages are installed. The install commands are listed at the start of Step 2.

Let's begin by checking if your GPU is correctly detected:

In [None]:
!nvidia-smi

Thu Mar  7 11:34:29 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla T4                       Off | 00000000:00:04.0 Off |                    0 |
| N/A   43C    P0              27W /  70W |    141MiB / 15360MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

## Step 2 - Model loading
We'll load the model with 4-bit quantization (the QLoRA setup) to reduce memory usage.


In [None]:
!pip3 install -q -U bitsandbytes==0.42.0
!pip3 install -q -U peft==0.8.2
!pip3 install -q -U trl==0.7.10
!pip3 install -q -U accelerate==0.27.1
!pip3 install -q -U datasets==2.17.0
!pip3 install -q -U transformers==4.38.0


In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)
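
To get a feel for why 4-bit loading matters on a 15 GiB T4, here is a rough back-of-the-envelope sketch. It only counts weight storage and assumes every parameter is stored at the given width. In practice some layers stay in higher precision (the printed model later in this notebook shows the embedding layer is not a `Linear4bit` module), and activations, optimizer state, and quantization constants are ignored. The helper `weight_gib` is a hypothetical name used for illustration; the parameter count is derived from the totals printed in Step 4.

```python
# Rough weight-memory estimate only; not a measurement.
# Base parameter count = total minus LoRA parameters, as reported in Step 4.
BASE_PARAMS = 2_584_619_008 - 78_446_592


def weight_gib(num_params: int, bits_per_param: float) -> float:
    """Approximate size of the weights in GiB at a given storage width."""
    return num_params * bits_per_param / 8 / 1024**3


fp16_gib = weight_gib(BASE_PARAMS, 16)  # half-precision storage
nf4_gib = weight_gib(BASE_PARAMS, 4)    # 4-bit storage (idealized)

print(f"16-bit weights: {fp16_gib:.2f} GiB")
print(f" 4-bit weights: {nf4_gib:.2f} GiB")
```

The idealized ratio is 4x; the real saving is smaller because not every tensor is quantized.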

Next, we specify the model ID and load the model with the quantization configuration defined above.

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# if you are using google colab

# import os
# from google.colab import userdata
# os.environ["HF_TOKEN"] = userdata.get('HF_TOKEN')

In [None]:
from huggingface_hub import notebook_login
notebook_login()

In [None]:
# model_id = "google/gemma-7b-it"
# model_id = "google/gemma-7b"
model_id = "google/gemma-2b-it"
# model_id = "google/gemma-2b"

model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config, device_map={"":0})
tokenizer = AutoTokenizer.from_pretrained(model_id, add_eos_token=True)

In [None]:
def get_completion(query: str, model, tokenizer) -> str:
  device = "cuda:0"

  prompt_template = """
  <start_of_turn>user
  Below is an instruction that describes a task. Write a response that appropriately completes the request.
  {query}
  <end_of_turn>\n<start_of_turn>model


  """
  prompt = prompt_template.format(query=query)

  encodeds = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)

  model_inputs = encodeds.to(device)


  generated_ids = model.generate(**model_inputs, max_new_tokens=1000, do_sample=True, pad_token_id=tokenizer.eos_token_id)
  # decoded = tokenizer.batch_decode(generated_ids)
  decoded = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
  return (decoded)

In [None]:
result = get_completion(query="code the fibonacci series in python using reccursion", model=model, tokenizer=tokenizer)
print(result)

A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.



  user
  Below is an instruction that describes a task. Write a response that appropriately completes the request.
  code the fibonacci series in python using reccursion
  
model


  ```python
def fibonacci(n):
    if n == 0:
        return 0
    elif n == 1:
        return 1
    else:
        return fibonacci(n-1) + fibonacci(n-2)


print(fibonacci(10))
```


## Step 3 - Load dataset for finetuning

### Let's Load the Dataset

For this tutorial, we will fine-tune Gemma 2B Instruct (`google/gemma-2b-it`) for text-to-SQL generation.

The code below loads the `Clinton/Text-to-sql-v1` dataset. It follows the Alpaca style of instructions, like this [dataset](https://huggingface.co/datasets/TokenBender/code_instructions_122k_alpaca_style) curated by [TokenBender (e/xperiments)](https://twitter.com/4evaBehindSOTA), which is an excellent data source for fine-tuning models for code generation. An Alpaca-style record resembles the following (the SQL dataset names the answer field `response` rather than `output`):

```json
{
  "instruction": "Create a function to calculate the sum of a sequence of integers.",
  "input": "[1, 2, 3, 4, 5]",
  "output": "# Python code def sum_sequence(sequence): sum = 0 for num in sequence: sum += num return sum"
}
```

In [None]:
from datasets import load_dataset

dataset = load_dataset("Clinton/Text-to-sql-v1", split="train")
dataset

Downloading readme:   0%|          | 0.00/118 [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/635M [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]

Dataset({
    features: ['instruction', 'input', 'response', 'source', 'text'],
    num_rows: 262208
})

In [None]:
df = dataset.to_pandas()
df.head(10)

Unnamed: 0,instruction,input,response,source,text
0,Name the home team for carlton away team,CREATE TABLE table_name_77 (\n home_team VA...,SELECT home_team FROM table_name_77 WHERE away...,sql_create_context,Below are sql tables schemas paired with instr...
1,what will the population of Asia be when Latin...,"CREATE TABLE table_22767 (\n ""Year"" real,\n...","SELECT ""Asia"" FROM table_22767 WHERE ""Latin Am...",wikisql,Below are sql tables schemas paired with instr...
2,How many faculty members do we have for each g...,"CREATE TABLE Student (\n StuID INTEGER,\n ...","SELECT Sex, COUNT(*) FROM Faculty GROUP BY Sex...",nvbench,Below are sql tables schemas paired with instr...
3,List the record of 0-1 from the table?,CREATE TABLE table_14656147_2 (\n week VARC...,SELECT week FROM table_14656147_2 WHERE record...,sql_create_context,Below are sql tables schemas paired with instr...
4,"Which silver has a Gold smaller than 12, a Ran...",CREATE TABLE table_name_24 (\n silver VARCH...,SELECT silver FROM table_name_24 WHERE gold < ...,sql_create_context,Below are sql tables schemas paired with instr...
5,When did Samsung Electronics Co LTD make the G...,"CREATE TABLE table_47482 (\n ""Company name""...","SELECT ""Date"" FROM table_47482 WHERE ""Company ...",wikisql,Below are sql tables schemas paired with instr...
6,what are the early morning flights from BOSTON...,"CREATE TABLE time_interval (\n period text,...",SELECT DISTINCT flight.flight_id FROM airport_...,atis,Below are sql tables schemas paired with instr...
7,Name the most 3 credits,CREATE TABLE table_148535_2 (\n Id VARCHAR\n),SELECT MIN(3 AS _credits) FROM table_148535_2,sql_create_context,Below are sql tables schemas paired with instr...
8,What is every yellow jersey entry for the dist...,"CREATE TABLE table_3791 (\n ""Year"" text,\n ...","SELECT ""Yellow jersey"" FROM table_3791 WHERE ""...",wikisql,Below are sql tables schemas paired with instr...
9,"In what years was there a rank lower than 9, u...",CREATE TABLE table_name_63 (\n years VARCHA...,SELECT years FROM table_name_63 WHERE matches ...,sql_create_context,Below are sql tables schemas paired with instr...


In [None]:
df['text'][0]

'Below are sql tables schemas paired with instruction that describes a task. Using valid SQLite, write a response that appropriately completes the request for the provided tables. ### Instruction: Name the home team for carlton away team ### Input: CREATE TABLE table_name_77 (\n    home_team VARCHAR,\n    away_team VARCHAR\n) ### Response: SELECT home_team FROM table_name_77 WHERE away_team = "carlton"'

Instruction fine-tuning: prepare the dataset in a "prompt" format that the model can learn from:
1. the function `generate_prompt` takes the instruction, the optional input, and the response, and builds a prompt
2. shuffle the dataset
3. tokenize the dataset

### Formatting the Dataset

Now, let's format the dataset in the required [Gemma instruction format](https://huggingface.co/google/gemma-7b-it).

> Many tutorials and blogs skip over this part, but I feel this is a really important step.

```
<start_of_turn>user What is your favorite condiment? <end_of_turn>
<start_of_turn>model Well, I'm quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavor to whatever I'm cooking up in the kitchen!<end_of_turn>
```
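
As a minimal sketch of the layout shown above, the helper below wraps one user/model exchange in the turn markers. The function name `format_gemma_example` and the short sample answer are hypothetical, and the spacing simply mirrors the example in this notebook rather than any official template.

```python
def format_gemma_example(user_text: str, model_text: str) -> str:
    """Wrap one user/model exchange in Gemma turn markers.

    The spacing mirrors the example shown in this notebook.
    """
    return (
        f"<start_of_turn>user {user_text} <end_of_turn>\n"
        f"<start_of_turn>model {model_text}<end_of_turn>"
    )


example = format_gemma_example(
    "What is your favorite condiment?",
    "Fresh lemon juice.",
)
print(example)
```

Keeping the formatting in one small function makes it easier to use exactly the same layout at training time and at inference time.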

You can use the following code to process the dataset and add a `prompt` column in the correct format:

In [None]:
def generate_prompt(data_point):
    """Generate the training text from a prompt prefix, task instruction, (context info), and answer.

    :param data_point: dict: Data point
    :return: str: prompt in the Gemma turn format
    """
    prefix_text = 'Below are sql tables schemas paired with instruction that describes a task. Using valid SQLite, write a response that appropriately completes the request for the provided tables\n\n'
    # Samples with additional context info.
    if data_point['input']:
        text = f"""<start_of_turn>user {prefix_text} {data_point["instruction"]} here are the inputs {data_point["input"]} <end_of_turn>\n<start_of_turn>model {data_point["response"]} <end_of_turn>"""
    # Samples without additional context.
    else:
        text = f"""<start_of_turn>user {prefix_text} {data_point["instruction"]} <end_of_turn>\n<start_of_turn>model {data_point["response"]} <end_of_turn>"""
    return text

# add the "prompt" column in the dataset
text_column = [generate_prompt(data_point) for data_point in dataset]
dataset = dataset.add_column("prompt", text_column)

Next, we tokenize the data so the model can process it.


In [None]:
dataset = dataset.shuffle(seed=1234)  # Shuffle dataset here
dataset = dataset.map(lambda samples: tokenizer(samples["prompt"]), batched=True)

Map:   0%|          | 0/262208 [00:00<?, ? examples/s]

Split the dataset into 80% for training and 20% for testing:

In [None]:
dataset = dataset.train_test_split(test_size=0.2)
train_data = dataset["train"]
test_data = dataset["test"]
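
The row counts that appear later in this notebook (209766 training rows and 52442 test rows) can be reproduced with a little arithmetic. The sketch below assumes the test count is the requested fraction rounded up and the remainder goes to training, which is the scikit-learn convention; `split_sizes` is a hypothetical helper, not part of `datasets`.

```python
import math


def split_sizes(num_rows: int, test_size: float) -> tuple[int, int]:
    """Return (train_rows, test_rows) for a fractional test size.

    Assumption: the test count is rounded up, the rest is used for training.
    """
    n_test = math.ceil(num_rows * test_size)
    return num_rows - n_test, n_test


train_rows, test_rows = split_sizes(262_208, 0.2)
print(train_rows, test_rows)
```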

### After formatting, we should get something like this

```json
{
"text":"<start_of_turn>user Create a function to calculate the sum of a sequence of integers. here are the inputs [1, 2, 3, 4, 5] <end_of_turn>
<start_of_turn>model # Python code def sum_sequence(sequence): sum = 0 for num in sequence: sum += num return sum <end_of_turn>",
"instruction":"Create a function to calculate the sum of a sequence of integers",
"input":"[1, 2, 3, 4, 5]",
"output":"# Python code def sum_sequence(sequence): sum = 0 for num in sequence: sum += num return sum",
"prompt":"<start_of_turn>user Create a function to calculate the sum of a sequence of integers. here are the inputs [1, 2, 3, 4, 5] <end_of_turn>
<start_of_turn>model # Python code def sum_sequence(sequence): sum = 0 for num in sequence: sum += num return sum <end_of_turn>"

}
```

When using SFT (**[Supervised Fine-tuning Trainer](https://huggingface.co/docs/trl/main/en/sft_trainer)**) for fine-tuning, we pass only the `prompt` column of the dataset (via `dataset_text_field`).

In [None]:
print(test_data)

Dataset({
    features: ['instruction', 'input', 'response', 'source', 'text', 'prompt', 'input_ids', 'attention_mask'],
    num_rows: 52442
})


## Step 4 - Apply LoRA
Here comes the magic of PEFT! We prepare the quantized model for training with the `prepare_model_for_kbit_training` method, then wrap it with low-rank adapters (LoRA) using the `get_peft_model` utility function.

In [None]:
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training, get_peft_model
model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)

In [None]:
print(model)

GemmaForCausalLM(
  (model): GemmaModel(
    (embed_tokens): Embedding(256000, 2048, padding_idx=0)
    (layers): ModuleList(
      (0-17): 18 x GemmaDecoderLayer(
        (self_attn): GemmaAttention(
          (q_proj): Linear4bit(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear4bit(in_features=2048, out_features=256, bias=False)
          (v_proj): Linear4bit(in_features=2048, out_features=256, bias=False)
          (o_proj): Linear4bit(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): GemmaRotaryEmbedding()
        )
        (mlp): GemmaMLP(
          (gate_proj): Linear4bit(in_features=2048, out_features=16384, bias=False)
          (up_proj): Linear4bit(in_features=2048, out_features=16384, bias=False)
          (down_proj): Linear4bit(in_features=16384, out_features=2048, bias=False)
          (act_fn): GELUActivation()
        )
        (input_layernorm): GemmaRMSNorm()
        (post_attention_layernorm): GemmaRMSNorm()
      )
 

In [None]:
import bitsandbytes as bnb
def find_all_linear_names(model):
  """Collect the names of all 4-bit linear layers so LoRA can target them."""
  cls = bnb.nn.Linear4bit  # layer type created by the 4-bit quantization config
  lora_module_names = set()
  for name, module in model.named_modules():
    if isinstance(module, cls):
      # keep only the last path component, e.g. "q_proj"
      names = name.split('.')
      lora_module_names.add(names[0] if len(names) == 1 else names[-1])
  # the output head is not adapted (needed for 16-bit)
  lora_module_names.discard('lm_head')
  return list(lora_module_names)

In [None]:
modules = find_all_linear_names(model)
print(modules)

['q_proj', 'gate_proj', 'k_proj', 'down_proj', 'v_proj', 'o_proj', 'up_proj']


In [None]:
from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    r=64,
    lora_alpha=32,
    target_modules=modules,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

model = get_peft_model(model, lora_config)

In [None]:
trainable, total = model.get_nb_trainable_parameters()
print(f"Trainable: {trainable} | total: {total} | Percentage: {trainable/total*100:.4f}%")

Trainable: 78446592 | total: 2584619008 | Percentage: 3.0351%
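
The trainable count above can be checked by hand. For a linear layer with `in_features` inputs and `out_features` outputs, a rank-`r` LoRA adapter adds two matrices of shapes `r x in_features` and `out_features x r`, so `r * (in_features + out_features)` parameters. The sketch below applies that formula to the layer shapes printed earlier; the names `LINEAR_SHAPES` and `lora_params` are hypothetical and used only for this illustration.

```python
LORA_RANK = 64   # r in the LoraConfig above
NUM_LAYERS = 18  # decoder layers in the printed model

# (in_features, out_features) for each targeted module, from the printed model
LINEAR_SHAPES = {
    "q_proj": (2048, 2048),
    "k_proj": (2048, 256),
    "v_proj": (2048, 256),
    "o_proj": (2048, 2048),
    "gate_proj": (2048, 16384),
    "up_proj": (2048, 16384),
    "down_proj": (16384, 2048),
}


def lora_params(in_features: int, out_features: int, r: int) -> int:
    """Parameters added by one adapter: A is r x in, B is out x r."""
    return r * (in_features + out_features)


per_layer = sum(lora_params(i, o, LORA_RANK) for i, o in LINEAR_SHAPES.values())
trainable_estimate = per_layer * NUM_LAYERS
share = trainable_estimate / 2_584_619_008 * 100

print(f"Per layer: {per_layer} | all layers: {trainable_estimate} | share: {share:.4f}%")
```

Because the count scales linearly with `r`, halving the rank would halve the number of trainable parameters.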


## Step 5 - Run the training!

Setting the training arguments:
* For demonstration purposes, we run only a few steps (100) to show how this integration works with existing tools in the Hugging Face ecosystem.

In [None]:
# import transformers

# tokenizer.pad_token = tokenizer.eos_token


# trainer = transformers.Trainer(
#     model=model,
#     train_dataset=train_data,
#     eval_dataset=test_data,
#     args=transformers.TrainingArguments(
#         per_device_train_batch_size=1,
#         gradient_accumulation_steps=4,
#         warmup_steps=0.03,
#         max_steps=100,
#         learning_rate=2e-4,
#         fp16=True,
#         logging_steps=1,
#         output_dir="outputs_gemma_text2sql_finetuned_test",
#         optim="paged_adamw_8bit",
#         save_strategy="epoch",
#     ),
#     data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
# )


### Fine-Tuning with QLoRA and Supervised Fine-Tuning

We're ready to fine-tune our model using QLoRA. For this tutorial, we'll use the `SFTTrainer` from the `trl` library for supervised fine-tuning. Make sure the `trl` library is installed, as shown in the install cell in Step 2.

In [None]:
#new code using SFTTrainer
import transformers

from trl import SFTTrainer

tokenizer.pad_token = tokenizer.eos_token
torch.cuda.empty_cache()

trainer = SFTTrainer(
    model=model,
    train_dataset=train_data,
    eval_dataset=test_data,
    dataset_text_field="prompt",
    peft_config=lora_config,
    args=transformers.TrainingArguments(
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        warmup_ratio=0.03,  # fraction of total steps used for warmup (warmup_steps expects an integer)
        max_steps=100,
        learning_rate=2e-4,
        logging_steps=1,
        output_dir="outputs",
        optim="paged_adamw_8bit",
        save_strategy="epoch",
    ),
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)



Map:   0%|          | 0/209766 [00:00<?, ? examples/s]

Map:   0%|          | 0/52442 [00:00<?, ? examples/s]



## Let's start training

In [None]:
model.config.use_cache = False  # silence the warnings. Please re-enable for inference!
trainer.train()



Step,Training Loss
1,5.0645
2,4.5541
3,4.4269
4,3.1988
5,2.4931
6,1.7602
7,1.8814
8,1.6711
9,1.6665
10,1.5509


TrainOutput(global_step=100, training_loss=1.0868981036543846, metrics={'train_runtime': 1041.2109, 'train_samples_per_second': 0.384, 'train_steps_per_second': 0.096, 'total_flos': 1519980356517888.0, 'train_loss': 1.0868981036543846, 'epoch': 0.0})
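
The numbers in `TrainOutput` follow from the training arguments. The sketch below redoes the arithmetic, assuming a single GPU so that the effective batch size is the per-device batch size times the gradient accumulation steps; the variable names are chosen for illustration.

```python
per_device_train_batch_size = 1
gradient_accumulation_steps = 4
max_steps = 100
train_runtime = 1041.2109  # seconds, from TrainOutput
train_rows = 209_766       # size of the training split

# Assumption: one GPU, so no extra data-parallel factor
effective_batch = per_device_train_batch_size * gradient_accumulation_steps
samples_seen = effective_batch * max_steps

samples_per_second = samples_seen / train_runtime
steps_per_second = max_steps / train_runtime
epoch_fraction = samples_seen / train_rows

print(f"samples seen: {samples_seen}")
print(f"samples/s: {samples_per_second:.3f} | steps/s: {steps_per_second:.3f}")
print(f"fraction of one epoch: {epoch_fraction:.4f}")
```

With only 400 samples seen, the run covers a tiny fraction of the training split, which is why the reported epoch rounds to 0.0.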

## Share adapters on the 🤗 Hub

In [None]:
new_model = "/content/drive/MyDrive/gemma-text2sql-Instruct-Finetune-test_2"  # Google Drive path where the fine-tuned adapter is saved

In [None]:
trainer.model.save_pretrained(new_model)

In [None]:
base_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.float16,
    device_map={"": 0},
)
# Load the saved adapter on top of the base model, then fold it into the base weights
merged_model = PeftModel.from_pretrained(base_model, new_model)
merged_model = merged_model.merge_and_unload()

# Save the merged model
merged_model.save_pretrained("/content/drive/MyDrive/merged_model_2",safe_serialization=True)
tokenizer.save_pretrained("/content/drive/MyDrive/merged_model_2")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
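
Conceptually, merging a standard LoRA adapter adds the scaled low-rank update to each base weight matrix: `W_merged = W + (lora_alpha / r) * B @ A`. The toy sketch below shows this with plain Python lists. The matrices are made up for illustration, and `alpha=1` with rank 2 is chosen so the scale is 0.5, the same ratio as `lora_alpha=32` and `r=64` above. It illustrates the formula only and is not how PEFT implements the merge internally.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def merge_lora(W, A, B, alpha):
    """Return W + (alpha / r) * B @ A, where r is the number of rows of A."""
    r = len(A)
    scale = alpha / r
    delta = matmul(B, A)
    return [
        [w + scale * d for w, d in zip(w_row, d_row)]
        for w_row, d_row in zip(W, delta)
    ]


W = [[1.0, 0.0], [0.0, 1.0]]  # toy base weight (out x in)
A = [[1.0, 0.0], [0.0, 1.0]]  # toy LoRA A (r x in), r = 2
B = [[2.0, 0.0], [0.0, 4.0]]  # toy LoRA B (out x r)

merged = merge_lora(W, A, B, alpha=1)
print(merged)
```

After merging, the adapter matrices are no longer needed at inference time, since their effect lives in the merged weights.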

In [None]:
from huggingface_hub import notebook_login
notebook_login()

In [None]:
merged_model.push_to_hub("gemma-text2sql-Instruct-Finetune-test")
tokenizer.push_to_hub("gemma-text2sql-Instruct-Finetune-test")

## Test the Fine-tuned Model

In [None]:
!pip install chromadb



In [None]:
json_data =  {
        "text": "List all movie titles.",
        "orgId": "30092022-173259674-cc766ff9-031b-49c3-b58e-b615372fe654",
        "userId": "067-89c3-b66e-b3452ef41",
        "oldQuery": "",
        "tableMetadata": [
            {
                "DataTableId": "movie_table_id",
                "DataTableName": "movie",
                "TableDescription": "Table for movie information",
                "PrimaryKey": [
                    "movie_id"
                ],
                "ForeignKeys": [],
                "Columns": [
                    {
                        "Name": "movie_id",
                        "ColumnDescription": "Unique identifier for the movie.",
                        "ElementType": "SERIAL PRIMARY KEY"
                    },
                    {
                        "Name": "title",
                        "ColumnDescription": "Title of the movie.",
                        "ElementType": "VARCHAR(200) NOT NULL"
                    },
                    {
                        "Name": "release_date",
                        "ColumnDescription": "Release date of the movie.",
                        "ElementType": "DATE"
                    },
                    {
                        "Name": "genre",
                        "ColumnDescription": "Genre of the movie.",
                        "ElementType": "VARCHAR(100)"
                    },
                    {
                        "Name": "director",
                        "ColumnDescription": "Director of the movie.",
                        "ElementType": "VARCHAR(100)"
                    },
                    {
                        "Name": "rating",
                        "ColumnDescription": "Rating of the movie.",
                        "ElementType": "DECIMAL(3, 1)"
                    }
                ]
            },
            {
                "DataTableId": "actor_table_id",
                "DataTableName": "actor",
                "TableDescription": "Table for actor information",
                "PrimaryKey": [
                    "actor_id"
                ],
                "ForeignKeys": [],
                "Columns": [
                    {
                        "Name": "actor_id",
                        "ColumnDescription": "Unique identifier for the actor.",
                        "ElementType": "SERIAL PRIMARY KEY"
                    },
                    {
                        "Name": "first_name",
                        "ColumnDescription": "First name of the actor.",
                        "ElementType": "VARCHAR(50) NOT NULL"
                    },
                    {
                        "Name": "last_name",
                        "ColumnDescription": "Last name of the actor.",
                        "ElementType": "VARCHAR(50) NOT NULL"
                    },
                    {
                        "Name": "gender",
                        "ColumnDescription": "Gender of the actor.",
                        "ElementType": "VARCHAR(10)"
                    },
                    {
                        "Name": "date_of_birth",
                        "ColumnDescription": "Date of birth of the actor.",
                        "ElementType": "DATE"
                    }
                ]
            },
            {
                "DataTableId": "movie_actor_table_id",
                "DataTableName": "movie_actor",
                "TableDescription": "Table for movie-actor relationships",
                "PrimaryKey": [
                    "movie_id",
                    "actor_id"
                ],
                "ForeignKeys": [
                    "movie_id REFERENCES movie(movie_id)",
                    "actor_id REFERENCES actor(actor_id)"
                ],
                "Columns": [
                    {
                        "Name": "movie_id",
                        "ColumnDescription": "ID of the movie.",
                        "ElementType": "INT"
                    },
                    {
                        "Name": "actor_id",
                        "ColumnDescription": "ID of the actor.",
                        "ElementType": "INT"
                    }
                ]
            },
            {
                "DataTableId": "movie_review_table_id",
                "DataTableName": "movie_review",
                "TableDescription": "Table for movie reviews",
                "PrimaryKey": [
                    "review_id"
                ],
                "ForeignKeys": [
                    "movie_id REFERENCES movie(movie_id)"
                ],
                "Columns": [
                    {
                        "Name": "review_id",
                        "ColumnDescription": "Unique identifier for the review.",
                        "ElementType": "SERIAL PRIMARY KEY"
                    },
                    {
                        "Name": "movie_id",
                        "ColumnDescription": "ID of the movie.",
                        "ElementType": "INT"
                    },
                    {
                        "Name": "reviewer_name",
                        "ColumnDescription": "Name of the reviewer.",
                        "ElementType": "VARCHAR(100) NOT NULL"
                    },
                    {
                        "Name": "rating",
                        "ColumnDescription": "Rating given in the review.",
                        "ElementType": "DECIMAL(3, 1)"
                    },
                    {
                        "Name": "review_text",
                        "ColumnDescription": "Text of the review.",
                        "ElementType": "TEXT"
                    },
                    {
                        "Name": "review_date",
                        "ColumnDescription": "Date of the review.",
                        "ElementType": "DATE"
                    }
                ]
            },
            {
                "DataTableId": "movie_award_table_id",
                "DataTableName": "movie_award",
                "TableDescription": "Table for movie awards",
                "PrimaryKey": [
                    "award_id"
                ],
                "ForeignKeys": [
                    "movie_id REFERENCES movie(movie_id)"
                ],
                "Columns": [
                    {
                        "Name": "award_id",
                        "ColumnDescription": "Unique identifier for the award.",
                        "ElementType": "SERIAL PRIMARY KEY"
                    },
                    {
                        "Name": "movie_id",
                        "ColumnDescription": "ID of the movie.",
                        "ElementType": "INT"
                    },
                    {
                        "Name": "award_name",
                        "ColumnDescription": "Name of the award.",
                        "ElementType": "VARCHAR(200) NOT NULL"
                    },
                    {
                        "Name": "award_category",
                        "ColumnDescription": "Category of the award.",
                        "ElementType": "VARCHAR(100)"
                    },
                    {
                        "Name": "award_year",
                        "ColumnDescription": "Year the award was received.",
                        "ElementType": "INT"
                    }
                ]
            },
            {
                "DataTableId": "production_company_table_id",
                "DataTableName": "production_company",
                "TableDescription": "Table for movie production companies",
                "PrimaryKey": [
                    "company_id"
                ],
                "ForeignKeys": [],
                "Columns": [
                    {
                        "Name": "company_id",
                        "ColumnDescription": "Unique identifier for the production company.",
                        "ElementType": "SERIAL PRIMARY KEY"
                    },
                    {
                        "Name": "company_name",
                        "ColumnDescription": "Name of the production company.",
                        "ElementType": "VARCHAR(200) NOT NULL"
                    },
                    {
                        "Name": "headquarters_location",
                        "ColumnDescription": "Location of the production company's headquarters.",
                        "ElementType": "VARCHAR(200)"
                    },
                    {
                        "Name": "founding_date",
                        "ColumnDescription": "Date the production company was founded.",
                        "ElementType": "DATE"
                    }
                ]
            },
            {
                "DataTableId": "movie_production_company_table_id",
                "DataTableName": "movie_production_company",
                "TableDescription": "Table for movie-production company relationships",
                "PrimaryKey": [
                    "movie_id",
                    "company_id"
                ],
                "ForeignKeys": [
                    "movie_id REFERENCES movie(movie_id)",
                    "company_id REFERENCES production_company(company_id)"
                ],
                "Columns": [
                    {
                        "Name": "movie_id",
                        "ColumnDescription": "ID of the movie.",
                        "ElementType": "INT"
                    },
                    {
                        "Name": "company_id",
                        "ColumnDescription": "ID of the production company.",
                        "ElementType": "INT"
                    }
                ]
            },
            {
                "DataTableId": "movie_location_table_id",
                "DataTableName": "movie_location",
                "TableDescription": "Table for movie locations",
                "PrimaryKey": [
                    "location_id"
                ],
                "ForeignKeys": [
                    "movie_id REFERENCES movie(movie_id)"
                ],
                "Columns": [
                    {
                        "Name": "location_id",
                        "ColumnDescription": "Unique identifier for the location.",
                        "ElementType": "SERIAL PRIMARY KEY"
                    },
                    {
                        "Name": "movie_id",
                        "ColumnDescription": "ID of the movie.",
                        "ElementType": "INT"
                    },
                    {
                        "Name": "location_name",
                        "ColumnDescription": "Name of the location.",
                        "ElementType": "VARCHAR(200) NOT NULL"
                    },
                    {
                        "Name": "filming_date",
                        "ColumnDescription": "Date when filming took place at the location.",
                        "ElementType": "DATE"
                    }
                ]
            },
            {
                "DataTableId": "movie_genre_table_id",
                "DataTableName": "movie_genre",
                "TableDescription": "Table for movie genres",
                "PrimaryKey": [
                    "genre_id"
                ],
                "ForeignKeys": [
                    "movie_id REFERENCES movie(movie_id)"
                ],
                "Columns": [
                    {
                        "Name": "genre_id",
                        "ColumnDescription": "Unique identifier for the genre.",
                        "ElementType": "SERIAL PRIMARY KEY"
                    },
                    {
                        "Name": "movie_id",
                        "ColumnDescription": "ID of the movie.",
                        "ElementType": "INT"
                    },
                    {
                        "Name": "genre_name",
                        "ColumnDescription": "Name of the genre.",
                        "ElementType": "VARCHAR(100) NOT NULL"
                    }
                ]
            },
            {
                "DataTableId": "movie_language_table_id",
                "DataTableName": "movie_language",
                "TableDescription": "Table for movie languages",
                "PrimaryKey": [
                    "language_id"
                ],
                "ForeignKeys": [
                    "movie_id REFERENCES movie(movie_id)"
                ],
                "Columns": [
                    {
                        "Name": "language_id",
                        "ColumnDescription": "Unique identifier for the language.",
                        "ElementType": "SERIAL PRIMARY KEY"
                    },
                    {
                        "Name": "movie_id",
                        "ColumnDescription": "ID of the movie.",
                        "ElementType": "INT"
                    },
                    {
                        "Name": "language_name",
                        "ColumnDescription": "Name of the language.",
                        "ElementType": "VARCHAR(50) NOT NULL"
                    }
                ]
            },
            {
                "DataTableId": "movie_character_table_id",
                "DataTableName": "movie_character",
                "TableDescription": "Table for movie characters",
                "PrimaryKey": [
                    "character_id"
                ],
                "ForeignKeys": [
                    "movie_id REFERENCES movie(movie_id)",
                    "actor_id REFERENCES actor(actor_id)"
                ],
                "Columns": [
                    {
                        "Name": "character_id",
                        "ColumnDescription": "Unique identifier for the character.",
                        "ElementType": "SERIAL PRIMARY KEY"
                    },
                    {
                        "Name": "movie_id",
                        "ColumnDescription": "ID of the movie.",
                        "ElementType": "INT"
                    },
                    {
                        "Name": "character_name",
                        "ColumnDescription": "Name of the character.",
                        "ElementType": "VARCHAR(100) NOT NULL"
                    },
                    {
                        "Name": "actor_id",
                        "ColumnDescription": "ID of the actor who portrayed the character.",
                        "ElementType": "INT"
                    }
                ]
            },
            {
                "DataTableId": "user_table_id",
                "DataTableName": "user",
                "TableDescription": "Table for user information",
                "PrimaryKey": [
                    "user_id"
                ],
                "ForeignKeys": [],
                "Columns": [
                    {
                        "Name": "user_id",
                        "ColumnDescription": "Unique identifier for the user.",
                        "ElementType": "SERIAL PRIMARY KEY"
                    },
                    {
                        "Name": "username",
                        "ColumnDescription": "Username of the user.",
                        "ElementType": "VARCHAR(50) UNIQUE NOT NULL"
                    },
                    {
                        "Name": "email",
                        "ColumnDescription": "Email of the user.",
                        "ElementType": "VARCHAR(100) UNIQUE NOT NULL"
                    },
                    {
                        "Name": "password",
                        "ColumnDescription": "Password of the user.",
                        "ElementType": "VARCHAR(100) NOT NULL"
                    }
                ]
            },
            {
                "DataTableId": "user_review_table_id",
                "DataTableName": "user_review",
                "TableDescription": "Table for user reviews",
                "PrimaryKey": [
                    "user_review_id"
                ],
                "ForeignKeys": [
                    "user_id REFERENCES user(user_id)",
                    "movie_id REFERENCES movie(movie_id)"
                ],
                "Columns": [
                    {
                        "Name": "user_review_id",
                        "ColumnDescription": "Unique identifier for the user review.",
                        "ElementType": "SERIAL PRIMARY KEY"
                    },
                    {
                        "Name": "user_id",
                        "ColumnDescription": "ID of the user.",
                        "ElementType": "INT"
                    },
                    {
                        "Name": "movie_id",
                        "ColumnDescription": "ID of the movie.",
                        "ElementType": "INT"
                    },
                    {
                        "Name": "rating",
                        "ColumnDescription": "Rating given in the review.",
                        "ElementType": "DECIMAL(3, 1)"
                    },
                    {
                        "Name": "review_text",
                        "ColumnDescription": "Text of the review.",
                        "ElementType": "TEXT"
                    },
                    {
                        "Name": "review_date",
                        "ColumnDescription": "Date of the review.",
                        "ElementType": "DATE"
                    }
                ]
            },
            {
                "DataTableId": "user_rating_table_id",
                "DataTableName": "user_rating",
                "TableDescription": "Table for user ratings",
                "PrimaryKey": [
                    "user_rating_id"
                ],
                "ForeignKeys": [
                    "user_id REFERENCES user(user_id)",
                    "movie_id REFERENCES movie(movie_id)"
                ],
                "Columns": [
                    {
                        "Name": "user_rating_id",
                        "ColumnDescription": "Unique identifier for the user rating.",
                        "ElementType": "SERIAL PRIMARY KEY"
                    },
                    {
                        "Name": "user_id",
                        "ColumnDescription": "ID of the user.",
                        "ElementType": "INT"
                    },
                    {
                        "Name": "movie_id",
                        "ColumnDescription": "ID of the movie.",
                        "ElementType": "INT"
                    },
                    {
                        "Name": "rating",
                        "ColumnDescription": "Rating given by the user.",
                        "ElementType": "DECIMAL(3, 1)"
                    }
                ]
            },
            {
                "DataTableId": "user_subscription_table_id",
                "DataTableName": "user_subscription",
                "TableDescription": "Table for user subscriptions",
                "PrimaryKey": [
                    "subscription_id"
                ],
                "ForeignKeys": [
                    "user_id REFERENCES user(user_id)"
                ],
                "Columns": [
                    {
                        "Name": "subscription_id",
                        "ColumnDescription": "Unique identifier for the subscription.",
                        "ElementType": "SERIAL PRIMARY KEY"
                    },
                    {
                        "Name": "user_id",
                        "ColumnDescription": "ID of the user.",
                        "ElementType": "INT"
                    },
                    {
                        "Name": "start_date",
                        "ColumnDescription": "Start date of the subscription.",
                        "ElementType": "DATE"
                    },
                    {
                        "Name": "end_date",
                        "ColumnDescription": "End date of the subscription.",
                        "ElementType": "DATE"
                    }
                ]
            },
            {
                "DataTableId": "user_preference_table_id",
                "DataTableName": "user_preference",
                "TableDescription": "Table for user preferences",
                "PrimaryKey": [
                    "preference_id"
                ],
                "ForeignKeys": [
                    "user_id REFERENCES user(user_id)",
                    "genre_id REFERENCES movie_genre(genre_id)"
                ],
                "Columns": [
                    {
                        "Name": "preference_id",
                        "ColumnDescription": "Unique identifier for the user preference.",
                        "ElementType": "SERIAL PRIMARY KEY"
                    },
                    {
                        "Name": "user_id",
                        "ColumnDescription": "ID of the user.",
                        "ElementType": "INT"
                    },
                    {
                        "Name": "genre_id",
                        "ColumnDescription": "ID of the preferred genre.",
                        "ElementType": "INT"
                    }
                ]
            },
            {
                "DataTableId": "user_watchlist_table_id",
                "DataTableName": "user_watchlist",
                "TableDescription": "Table for user watchlist",
                "PrimaryKey": [
                    "watchlist_id"
                ],
                "ForeignKeys": [
                    "user_id REFERENCES user(user_id)",
                    "movie_id REFERENCES movie(movie_id)"
                ],
                "Columns": [
                    {
                        "Name": "watchlist_id",
                        "ColumnDescription": "Unique identifier for the watchlist item.",
                        "ElementType": "SERIAL PRIMARY KEY"
                    },
                    {
                        "Name": "user_id",
                        "ColumnDescription": "ID of the user.",
                        "ElementType": "INT"
                    },
                    {
                        "Name": "movie_id",
                        "ColumnDescription": "ID of the movie.",
                        "ElementType": "INT"
                    }
                ]
            },
            {
                "DataTableId": "user_history_table_id",
                "DataTableName": "user_history",
                "TableDescription": "Table for user history",
                "PrimaryKey": [
                    "history_id"
                ],
                "ForeignKeys": [
                    "user_id REFERENCES user(user_id)",
                    "movie_id REFERENCES movie(movie_id)"
                ],
                "Columns": [
                    {
                        "Name": "history_id",
                        "ColumnDescription": "Unique identifier for the user history entry.",
                        "ElementType": "SERIAL PRIMARY KEY"
                    },
                    {
                        "Name": "user_id",
                        "ColumnDescription": "ID of the user.",
                        "ElementType": "INT"
                    },
                    {
                        "Name": "movie_id",
                        "ColumnDescription": "ID of the movie.",
                        "ElementType": "INT"
                    },
                    {
                        "Name": "watch_date",
                        "ColumnDescription": "Date when the movie was watched.",
                        "ElementType": "DATE"
                    }
                ]
            },
            {
                "DataTableId": "award_category_table_id",
                "DataTableName": "award_category",
                "TableDescription": "Table for movie awards categories",
                "PrimaryKey": [
                    "category_id"
                ],
                "ForeignKeys": [],
                "Columns": [
                    {
                        "Name": "category_id",
                        "ColumnDescription": "Unique identifier for the award category.",
                        "ElementType": "SERIAL PRIMARY KEY"
                    },
                    {
                        "Name": "category_name",
                        "ColumnDescription": "Name of the award category.",
                        "ElementType": "VARCHAR(200) NOT NULL"
                    }
                ]
            },
            {
                "DataTableId": "award_nomination_table_id",
                "DataTableName": "award_nomination",
                "TableDescription": "Table for award nominations",
                "PrimaryKey": [
                    "nomination_id"
                ],
                "ForeignKeys": [
                    "award_id REFERENCES movie_award(award_id)",
                    "category_id REFERENCES award_category(category_id)",
                    "movie_id REFERENCES movie(movie_id)"
                ],
                "Columns": [
                    {
                        "Name": "nomination_id",
                        "ColumnDescription": "Unique identifier for the award nomination.",
                        "ElementType": "SERIAL PRIMARY KEY"
                    },
                    {
                        "Name": "award_id",
                        "ColumnDescription": "ID of the award.",
                        "ElementType": "INT"
                    },
                    {
                        "Name": "category_id",
                        "ColumnDescription": "ID of the award category.",
                        "ElementType": "INT"
                    },
                    {
                        "Name": "movie_id",
                        "ColumnDescription": "ID of the movie.",
                        "ElementType": "INT"
                    },
                    {
                        "Name": "nominee",
                        "ColumnDescription": "Name of the nominee.",
                        "ElementType": "VARCHAR(200) NOT NULL"
                    }
                ]
            }
        ]
}


In [None]:
import chromadb
def tables(json_data):
    """Build one CREATE TABLE-style description per table and a map of table name -> foreign keys."""
    # Initialise both results up front so the function also returns cleanly
    # when "tableMetadata" is empty.
    create_table_list = []
    table_info_dict = {}

    for table_json in json_data["tableMetadata"]:
        table_name = table_json["DataTableName"]
        columns = table_json["Columns"]
        table_description = table_json["TableDescription"]
        primary_key = table_json["PrimaryKey"]
        foreign_keys = table_json["ForeignKeys"]
        table_info_dict[table_name] = foreign_keys

        # Describe the table in a CREATE TABLE-like format for the prompt
        # (this is context for the model, not executable SQL).
        create_table_sql = f"CREATE TABLE {table_name} COMMENT {table_description} IMPORTANT PRIMARY_KEY : {primary_key} , FOREIGN_KEY : {foreign_keys})("
        for column in columns:
            column_name = column["Name"]
            column_description = column["ColumnDescription"]
            element_type = column["ElementType"]
            create_table_sql += f"{column_name} COMMENT {column_description} {element_type}, "
        create_table_sql = create_table_sql.rstrip(", ") + ");"
        create_table_list.append(create_table_sql)

    return create_table_list, table_info_dict

def vector_db(create_table_list, user_question):
    """Index the table descriptions in Chroma and return the three most similar to the question."""
    client = chromadb.Client()

    documents = []
    metadatas = []
    ids = []
    for i, table in enumerate(create_table_list):
        ids.append(str(i + 1))
        metadatas.append({"source": table})
        documents.append(table)

    collection = client.get_or_create_collection("embeddings")

    # upsert adds new ids and overwrites existing ones, so it covers both an
    # empty collection and one that was populated by an earlier call.
    collection.upsert(documents=documents, metadatas=metadatas, ids=ids)

    similar_tables = collection.query(query_texts=[user_question], n_results=3)
    results = similar_tables["documents"][0]
    print(results)

    return results

def foreign_key_tables(table_schema, create_table_list, foreign_key_dict):
    """Return the retrieved table descriptions plus those of the tables they reference."""
    related_tables = set()

    for schema_entry in table_schema:
        # Always keep the retrieved table itself.
        related_tables.add(schema_entry)
        for table_name, foreign_keys in foreign_key_dict.items():
            # Substring match: this also fires when table_name only appears
            # somewhere inside the entry, not just as the table being described.
            if table_name in schema_entry:
                for fk in foreign_keys:
                    # "col REFERENCES other(col)" -> "other"
                    referenced_table = fk.split("REFERENCES")[1].strip().split("(")[0].strip()
                    table_id = f"CREATE TABLE {referenced_table}"
                    # Add the description of every table matching the referenced name.
                    for table_entry in create_table_list:
                        if table_id in table_entry:
                            related_tables.add(table_entry)

    return list(related_tables)
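
The foreign-key expansion above works on the full description strings. As a minimal sketch of the same idea on bare table names, the snippet below parses the `REFERENCES` target out of each foreign-key string and adds it to the selection. The helper names `referenced_table` and `expand_with_references` are hypothetical and only used for illustration; the sample foreign keys are the ones listed for `user_review` in the metadata.

```python
def referenced_table(foreign_key: str) -> str:
    """Return the table named after REFERENCES in 'col REFERENCES table(col)'."""
    target = foreign_key.split("REFERENCES")[1].strip()
    return target.split("(")[0].strip()


def expand_with_references(selected, foreign_key_dict):
    """Return the selected table names plus every table they reference, sorted."""
    expanded = set(selected)
    for name in selected:
        for fk in foreign_key_dict.get(name, []):
            expanded.add(referenced_table(fk))
    return sorted(expanded)


# Hypothetical excerpt of the name -> foreign keys map returned by tables().
foreign_key_dict = {
    "movie": [],
    "user": [],
    "user_review": [
        "user_id REFERENCES user(user_id)",
        "movie_id REFERENCES movie(movie_id)",
    ],
}

print(expand_with_references(["user_review"], foreign_key_dict))
# ['movie', 'user', 'user_review']
```

Matching on exact names avoids the substring ambiguity between, for example, `movie` and `movie_genre`, at the cost of needing the table name rather than the description string as the lookup key.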

In [None]:
import os
print(os.listdir('/content/drive/MyDrive/merged_model'))


['config.json', 'generation_config.json', 'model-00002-of-00002.safetensors', 'model.safetensors.index.json', 'tokenizer_config.json', 'special_tokens_map.json', 'tokenizer.model', 'tokenizer.json']


In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)
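
To get a feel for why 4-bit loading matters on a T4, here is a rough back-of-the-envelope sketch of the memory taken by the weights alone. The parameter count of 2.5 billion is an assumption for a gemma-2b sized model, and the estimate ignores quantization constants, layers that stay in higher precision, activations, and the KV cache, so real usage will be higher.

```python
def weight_memory_gib(num_params: float, bits_per_param: float) -> float:
    """Approximate memory for the weights only, in GiB."""
    return num_params * bits_per_param / 8 / 1024**3


num_params = 2.5e9  # assumption: roughly the size of a gemma-2b checkpoint

for label, bits in [("bfloat16", 16), ("nf4 (4-bit)", 4)]:
    print(f"{label}: {weight_memory_gib(num_params, bits):.2f} GiB")
```

Under these assumptions the 4-bit weights take about a quarter of the space of the bfloat16 weights.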

In [None]:
model_id = "gemma-text2sql-Instruct-Finetune-test"
# model_id = "google/gemma-2b"

model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config, device_map={"":0})
tokenizer = AutoTokenizer.from_pretrained(model_id, add_eos_token=True)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.



In [None]:
def get_completion(user_question,table_schema, model, tokenizer) -> str:
  device = "cuda:0"

  prompt_template = f"""
          <start_of_turn>user
            -- As a PostgreSQL SQL Database expert, you are required to provide answers to the questions based on the provided table schema ({table_schema}).
            -- Adhere strictly to the following constraints while generating SQL queries:

            -- Strict Constraints:
            -- 1. STRICTLY only generate one query.
            -- 2. Thoroughly examine table schemas, including relationships, attributes, primary keys, and foreign keys.
            -- 3. Carefully observe what kind of join is needed according to the user question and the table_schema provided.
            -- 4. Correctly join tables with appropriate column names.
            -- 5. Determine the type and number of JOIN queries needed based on user questions.
            -- 6. Correct any errors in old queries according to the user question and table schema.
            -- 7. End each SQL query with a semicolon.
            -- 8. STRICTLY avoid using '*' in SELECT statements -- Very Important
            -- 9. Use only table and column names as provided in the table schema.
            -- 10. Ensure all queries are valid PostgreSQL SQL queries.
            -- 11. Provide responses only in the form of SQL queries.
            -- 12. STRICTLY check the table names before giving the query and provide the correct table name.
            -- 13. STRICTLY use JOIN only if required.
            -- 14. Don't perform unnecessary JOINs.

            -- Given user question: {user_question}

            -- Format of Answer:
            -- Postgres SQL
          <end_of_turn>\n<start_of_turn>model
          """
  # Tokenize the prompt and move the tensors to the GPU.
  encodeds = tokenizer(prompt_template, return_tensors="pt", add_special_tokens=True)
  model_inputs = encodeds.to(device)

  # Sampling is enabled, so repeated calls can return different queries.
  generated_ids = model.generate(**model_inputs, max_new_tokens=1000, do_sample=True, pad_token_id=tokenizer.eos_token_id)
  # The decoded text contains the prompt followed by the model's answer.
  decoded = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
  return decoded
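
Because the whole sequence is decoded, the string returned by `get_completion` starts with the prompt and only then continues with the model's answer. A minimal sketch of one way to pull out just the query is shown below. The helper `extract_sql` is hypothetical, and it assumes that after the special tokens are stripped the turn marker `model` is left on a line of its own and that the answer ends at the first semicolon.

```python
def extract_sql(decoded: str, marker: str = "model") -> str:
    """Return the text after the last line that is exactly `marker`,
    trimmed to the first semicolon (if there is one)."""
    lines = decoded.splitlines()
    start = 0
    for idx, line in enumerate(lines):
        if line.strip() == marker:
            start = idx + 1  # the answer begins on the line after the marker
    completion = "\n".join(lines[start:]).strip()
    end = completion.find(";")
    return completion[: end + 1] if end != -1 else completion


# Made-up decoded text in the shape printed by the notebook.
sample = """
          user
            -- Given user question: List all movie titles.

model
          SELECT title FROM movie;
          trailing text
"""
print(extract_sql(sample))
# SELECT title FROM movie;
```

An alternative that avoids string matching is to slice `generated_ids[0]` from the prompt length onward before decoding, so only the newly generated tokens are turned into text.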

In [None]:
create_table_list,foreign_key_dict = tables(json_data)

In [None]:
user_question = "List all movie titles."
table_schema = vector_db(create_table_list,user_question)
foreign_key_tables_result = foreign_key_tables(table_schema,create_table_list,foreign_key_dict)
print(foreign_key_tables_result)


["CREATE TABLE movie COMMENT Table for movie information IMPORTANT PRIMARY_KEY : ['movie_id'] , FOREIGN_KEY : [])(movie_id COMMENT Unique identifier for the movie. SERIAL PRIMARY KEY, title COMMENT Title of the movie. VARCHAR(200) NOT NULL, release_date COMMENT Release date of the movie. DATE, genre COMMENT Genre of the movie. VARCHAR(100), director COMMENT Director of the movie. VARCHAR(100), rating COMMENT Rating of the movie. DECIMAL(3, 1));", "CREATE TABLE movie_genre COMMENT Table for movie genres IMPORTANT PRIMARY_KEY : ['genre_id'] , FOREIGN_KEY : ['movie_id REFERENCES movie(movie_id)'])(genre_id COMMENT Unique identifier for the genre. SERIAL PRIMARY KEY, movie_id COMMENT ID of the movie. INT, genre_name COMMENT Name of the genre. VARCHAR(100) NOT NULL);", "CREATE TABLE movie_character COMMENT Table for movie characters IMPORTANT PRIMARY_KEY : ['character_id'] , FOREIGN_KEY : ['movie_id REFERENCES movie(movie_id)', 'actor_id REFERENCES actor(actor_id)'])(character_id COMMENT Un

In [None]:
result = get_completion(user_question,table_schema, model, tokenizer)
print(result)

A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.



          user
            -- As a PostgreSQL SQL Database expert, you are required to provide answers to the questions based on the provided table schema (["CREATE TABLE movie COMMENT Table for movie information IMPORTANT PRIMARY_KEY : ['movie_id'] , FOREIGN_KEY : [])(movie_id COMMENT Unique identifier for the movie. SERIAL PRIMARY KEY, title COMMENT Title of the movie. VARCHAR(200) NOT NULL, release_date COMMENT Release date of the movie. DATE, genre COMMENT Genre of the movie. VARCHAR(100), director COMMENT Director of the movie. VARCHAR(100), rating COMMENT Rating of the movie. DECIMAL(3, 1));", "CREATE TABLE movie_genre COMMENT Table for movie genres IMPORTANT PRIMARY_KEY : ['genre_id'] , FOREIGN_KEY : ['movie_id REFERENCES movie(movie_id)'])(genre_id COMMENT Unique identifier for the genre. SERIAL PRIMARY KEY, movie_id COMMENT ID of the movie. INT, genre_name COMMENT Name of the genre. VARCHAR(100) NOT NULL);", "CREATE TABLE movie_character COMMENT Table for movie characters IMP