# Finetuning LLaMa + Text-to-SQL 

This walkthrough shows you how to fine-tune LLaMa-7B on a Text-to-SQL dataset, and then use it for inference against
any database of structured data using LlamaIndex.

**NOTE**: This code is taken and adapted from Modal's `doppel-bot` repo: https://github.com/modal-labs/doppel-bot.

**NOTE**: Much of the code lives in the underlying Python scripts in the `src` directory. We encourage you to take a look!

### Setup

**NOTE**: You will need to set up a Modal account and token in order to use this notebook.

In [None]:
!pip install -r requirements.txt

### Load Training Data for Finetuning LLaMa

We load data from `b-mc2/sql-create-context` on Hugging Face: https://huggingface.co/datasets/b-mc2/sql-create-context.

This dataset consists of tuples of natural language queries, create table statements, and ground-truth SQL queries. This is the dataset that we use to finetune our SQL model.
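
To make the record shape concrete, here is a minimal sketch of how one such tuple could be turned into a prompt and a completion. The field names (`question`, `context`, `answer`) are assumed from the dataset card, the sample values are made up, and `PROMPT_TEMPLATE` and `build_example` are hypothetical; the template actually used for training is defined in the `src` scripts.

```python
# One record in the shape described above. The field names are assumed from
# the dataset card; the values are made up for illustration.
sample = {
    "question": "How many cities are in Canada?",
    "context": "CREATE TABLE city_stats (city_name VARCHAR, population INTEGER, country VARCHAR)",
    "answer": 'SELECT COUNT(*) FROM city_stats WHERE country = "Canada"',
}

# Hypothetical prompt template; the real one lives in the src scripts.
PROMPT_TEMPLATE = (
    "You are a text-to-SQL model.\n"
    "### Schema:\n{context}\n"
    "### Question:\n{question}\n"
    "### SQL:\n"
)


def build_example(record: dict) -> tuple[str, str]:
    """Return (prompt, completion) strings for one dataset record."""
    prompt = PROMPT_TEMPLATE.format(
        context=record["context"], question=record["question"]
    )
    return prompt, record["answer"]


prompt, completion = build_example(sample)
print(prompt + completion)
```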

In [8]:
data_dir = "data_sql"

!modal run src.load_data_sql --data-dir {data_dir}

[2K[32m✓[0m Initialized. [38;5;249mView app at [0m[4;38;5;249mhttps://modal.com/apps/ap-cD5MNkDR86bbjkHJ86BAdo[0m
[2K[34m⠋[0m Initializing...
[2K[34m⠹[0m Creating objects...objects...
[38;5;244m├── [0m[34m⠋[0m Creating download_models...
[38;5;244m└── [0m[34m⠋[0m Creating mount /Users/jerryliu/Programming/modal_finetune_sql/src: 
[2K[1A[2K[1A[2K[1A[2K[34m⠴[0m Creating objects...
[38;5;244m├── [0m[34m⠸[0m Creating download_models...
[2K[1A[2K[1A[2K[34m⠏[0m Creating objects...s/jerryliu/Programming/modal_finetune_sql/src
[38;5;244m├── [0m[32m🔨[0m Created download_models.
[2K[1A[2K[1A[2K[34m⠹[0m Creating objects...s/jerryliu/Programming/modal_finetune_sql/src
[38;5;244m├── [0m[32m🔨[0m Created download_models.
[2K[1A[2K[1A[2K[34m⠴[0m Creating objects...s/jerryliu/Programming/modal_finetune_sql/src
[38;5;244m├── [0m[32m🔨[0m Created download_models.
[38;5;244m├── [0m[32m🔨[0m Created mount /Users/jerryliu/Programming/modal

### Run Finetuning Script

We run our finetuning script on the loaded dataset.
The finetuning script contains the following components:
- We split the dataset into training and validation splits.
- We format each split into input/output tuples of token IDs. The labels are identical to the inputs, so the loss is computed over the full sequence (prompt included), not just over the generated portion.
- We use `LoraConfig` from `peft` for efficient fine-tuning.
- We use `transformers.Trainer` to actually run the training process.
- If a valid `WANDB_PROJECT` is specified, along with the relevant secret in Modal, then we will log results to wandb.
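
The first two steps in the list above can be sketched in plain Python. This is a minimal illustration, not the code in `src`: `train_val_split`, `to_features`, the 10% validation fraction, and the stand-in token ids are all assumptions made for the example.

```python
import random


def train_val_split(records, val_fraction=0.1, seed=42):
    """Shuffle a copy of the records and split off a validation set."""
    shuffled = list(records)
    random.Random(seed).shuffle(shuffled)
    n_val = max(1, int(len(shuffled) * val_fraction))
    return shuffled[n_val:], shuffled[:n_val]


def to_features(token_ids):
    """Build one causal-LM training example whose labels mirror the inputs."""
    return {
        "input_ids": list(token_ids),
        "attention_mask": [1] * len(token_ids),
        "labels": list(token_ids),  # loss covers prompt and completion alike
    }


# Stand-in token ids; a real run would get these from the model's tokenizer.
records = [[i, i + 1, i + 2] for i in range(20)]
train, val = train_val_split(records)
features = [to_features(ids) for ids in train]
print(len(train), len(val), features[0]["labels"] == features[0]["input_ids"])
```

Copying `input_ids` into `labels` is what makes the loss cover the prompt as well as the SQL. A common alternative is to set the prompt positions in `labels` to `-100`, which the `transformers` causal-LM loss ignores.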

We use Modal to spin up an A100 to run our finetuning code. 

In [None]:
data_dir = "data_sql"

!modal run src.finetune_sql --data-dir {data_dir}

### Evaluation

We provide a basic evaluation script over sample data from `sql-create-context` so that you can see for yourself how well the finetuned model performs vs. the baseline model.

In [17]:
!modal run src.eval_sql::main

[2K[32m✓[0m Initialized. [38;5;249mView app at [0m[4;38;5;249mhttps://modal.com/apps/ap-JS10MTVAhat3jK4K8QwKyv[0m
[2K[34m⠋[0m Initializing...
[2K[34m⠼[0m Creating objects...objects...
[38;5;244m├── [0m[34m⠋[0m Creating download_models...
[38;5;244m└── [0m[34m⠋[0m Creating mount /Users/jerryliu/Programming/modal_finetune_sql/src: 
[2K[1A[2K[1A[2K[1A[2K[34m⠧[0m Creating objects...
[38;5;244m├── [0m[34m⠸[0m Creating download_models...
[38;5;244m└── [0m[34m⠸[0m Creating mount /Users/jerryliu/Programming/modal_finetune_sql/src: 
[2K[1A[2K[1A[2K[1A[2K[34m⠋[0m Creating objects...
[38;5;244m├── [0m[34m⠦[0m Creating download_models...
[38;5;244m└── [0m[34m⠦[0m Creating mount /Users/jerryliu/Programming/modal_finetune_sql/src: 
[2K[1A[2K[1A[2K[1A[2K[34m⠸[0m Creating objects...
[38;5;244m├── [0m[34m⠏[0m Creating download_models...
[38;5;244m└── [0m[34m⠏[0m Creating mount /Users/jerryliu/Programming/modal_finetune_sql/src: 
[

### Integrate Model with LlamaIndex

Once finetuning finishes, the checkpoints and model binary are stored in a model directory (`/vol/data_sql` by default).

We can now use this model in LlamaIndex for text-to-SQL applications.

Specifically, we provide an interface that lets you point to any SQLite database file and run natural language queries over it. We first create and dump a sample `cities.db` file containing (city, population, country) tuples. We then run inference over this file.
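
The training data pairs each question with `CREATE TABLE` statements, so at inference time the model needs the schema of the target database as context. LlamaIndex handles that step internally; the sketch below only illustrates the idea with the standard-library `sqlite3` module. `read_schema` is a hypothetical helper, and the in-memory table stands in for a real database file.

```python
import sqlite3


def read_schema(conn: sqlite3.Connection) -> str:
    """Return the CREATE TABLE statements stored in a SQLite database."""
    rows = conn.execute(
        "SELECT sql FROM sqlite_master "
        "WHERE type = 'table' AND sql IS NOT NULL ORDER BY name"
    ).fetchall()
    return "\n".join(sql for (sql,) in rows)


# In-memory stand-in for a database file such as cities.db.
conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE city_stats ("
    "city_name VARCHAR(16) PRIMARY KEY, "
    "population INTEGER, "
    "country VARCHAR(16) NOT NULL)"
)
schema = read_schema(conn)
conn.close()
print(schema)
```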

#### Create sample db

In [9]:
# imports for building the sample database
from sqlalchemy import (
    create_engine,
    MetaData,
    Table,
    Column,
    String,
    Integer,
    select,
    column,
)

In [10]:
db_file = "cities.db"
engine = create_engine(f"sqlite:///{db_file}")
metadata_obj = MetaData()

In [11]:
# create city SQL table
table_name = "city_stats"
city_stats_table = Table(
    table_name,
    metadata_obj,
    Column("city_name", String(16), primary_key=True),
    Column("population", Integer),
    Column("country", String(16), nullable=False),
)
metadata_obj.create_all(engine)

In [12]:
# insert sample rows
from sqlalchemy import insert

rows = [
    {"city_name": "Toronto", "population": 2930000, "country": "Canada"},
    {"city_name": "Tokyo", "population": 13960000, "country": "Japan"},
    {"city_name": "Chicago", "population": 2679000, "country": "United States"},
    {"city_name": "Seoul", "population": 9776000, "country": "South Korea"},
]
# engine.begin() opens a single transaction and commits it on exit
with engine.begin() as connection:
    for row in rows:
        connection.execute(insert(city_stats_table).values(**row))
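
As a sanity check on the sample rows, it helps to know what a correct answer looks like before asking the model. The sketch below rebuilds the same four rows in an in-memory SQLite database and runs one hand-written query for "which city has the highest population"; `expected_sql` is our own reference query, not model output, and the model may phrase its SQL differently.

```python
import sqlite3

# Same rows as the sample table, rebuilt in memory so cities.db is untouched.
sample_rows = [
    ("Toronto", 2930000, "Canada"),
    ("Tokyo", 13960000, "Japan"),
    ("Chicago", 2679000, "United States"),
    ("Seoul", 9776000, "South Korea"),
]

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE city_stats ("
    "city_name TEXT PRIMARY KEY, population INTEGER, country TEXT NOT NULL)"
)
conn.executemany("INSERT INTO city_stats VALUES (?, ?, ?)", sample_rows)

# Hand-written reference query for the question asked in the next section.
expected_sql = "SELECT city_name FROM city_stats ORDER BY population DESC LIMIT 1"
top_city = conn.execute(expected_sql).fetchone()[0]
conn.close()
print(top_city)  # Tokyo
```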

#### Run Inference

In [15]:
query = "Which city has the highest population?"

!modal run src.inference_sql_llamaindex::main --query '{query}' --sqlite-file-path {db_file} --model-dir "data_sql" --use-finetuned-model True

[2K[32m✓[0m Initialized. [38;5;249mView app at [0m[4;38;5;249mhttps://modal.com/apps/ap-BjYd5um2Efz70MftYxlyzE[0m
[2K[34m⠋[0m Initializing...
[2K[34m⠹[0m Creating objects...objects...
[38;5;244m├── [0m[34m⠋[0m Creating download_models...
[38;5;244m└── [0m[34m⠋[0m Creating mount /Users/jerryliu/Programming/modal_finetune_sql/src: 
[2K[1A[2K[1A[2K[1A[2K[34m⠴[0m Creating objects...
[38;5;244m├── [0m[34m⠸[0m Creating download_models...
[38;5;244m└── [0m[34m⠸[0m Creating mount /Users/jerryliu/Programming/modal_finetune_sql/src: 
[2K[1A[2K[1A[2K[1A[2K[34m⠏[0m Creating objects...
[38;5;244m├── [0m[34m⠦[0m Creating download_models...
[2K[1A[2K[1A[2K[34m⠹[0m Creating objects...s/jerryliu/Programming/modal_finetune_sql/src
[38;5;244m├── [0m[32m🔨[0m Created download_models.
[2K[1A[2K[1A[2K[34m⠴[0m Creating objects...s/jerryliu/Programming/modal_finetune_sql/src
[38;5;244m├── [0m[32m🔨[0m Created download_models.
[2K[1A[2K

In [16]:
# You can also run the original (non-finetuned) model to compare results.
# Note: in our run, this command threw an error.
!modal run src.inference_sql_llamaindex::main --query '{query}' --sqlite-file-path {db_file} --model-dir "data_sql" --use-finetuned-model False

[2K[32m✓[0m Initialized. [38;5;249mView app at [0m[4;38;5;249mhttps://modal.com/apps/ap-1n1QmIubsUDYdMnVNCTGzh[0m
[2K[34m⠋[0m Initializing...
[2K[34m⠹[0m Creating objects...objects...
[38;5;244m├── [0m[34m⠋[0m Creating download_models...
[38;5;244m└── [0m[34m⠋[0m Creating mount /Users/jerryliu/Programming/modal_finetune_sql/src: 
[2K[1A[2K[1A[2K[1A[2K[34m⠴[0m Creating objects...
[38;5;244m├── [0m[34m⠸[0m Creating download_models...
[38;5;244m└── [0m[34m⠸[0m Creating mount /Users/jerryliu/Programming/modal_finetune_sql/src: 
[2K[1A[2K[1A[2K[1A[2K[34m⠏[0m Creating objects...
[38;5;244m├── [0m[32m🔨[0m Created download_models.
[2K[1A[2K[1A[2K[34m⠹[0m Creating objects...s/jerryliu/Programming/modal_finetune_sql/src
[38;5;244m├── [0m[32m🔨[0m Created download_models.
[2K[1A[2K[1A[2K[34m⠴[0m Creating objects...s/jerryliu/Programming/modal_finetune_sql/src
[38;5;244m├── [0m[32m🔨[0m Created download_models.
[2K[1A[2K[1A

### (Optional) Download Model

If you want to download the model weights for your own use, just run the following script.

In [None]:
from src.download_weights import main

main("out_model", model_dir="data_sql")