# Fine-Tuning

Fine-tuning is the process of continuing to train an existing LLM on additional samples, typically at a reduced learning rate, so that it picks up new knowledge or behavior. The goal is to make the LLM "smarter" without degrading the knowledge instilled by the original training. Many foundation models from vendors such as OpenAI, Google, and Meta support fine-tuning. OpenAI models can be fine-tuned using either the [fine-tuning API](https://platform.openai.com/docs/guides/fine-tuning) or the [fine-tuning dashboard](https://platform.openai.com/finetune). Let's demonstrate by fine-tuning `GPT-4o-mini` to generate SQL queries for a database. We'll use Microsoft's Northwind database, but we'll call it the "Wintellect" database so the LLM can't use anything it learned about Northwind during its training.

Begin by defining a function that uses `GPT-4o` to generate SQL queries for a series of questions. Later, we'll use these queries and the questions they were generated from to fine-tune `GPT-4o-mini`. The prompt here includes the database schema definition:

In [None]:
import re
from openai import OpenAI

OPENAI_API_KEY = 'OPENAI_API_KEY'

def text2sql(text):
    prompt = f'''
        Generate a well-formed SQLite query from the prompt below. Return
        the SQL only. Do not use markdown formatting, and do not use SELECT *.

        PROMPT: {text}
    
        The database targeted by the query is named Wintellect and it contains
        the following tables:

        CREATE TABLE [Categories]
        (
            [CategoryID] INTEGER PRIMARY KEY AUTOINCREMENT,
            [CategoryName] TEXT,
            [Description] TEXT
        )

        CREATE TABLE [Customers]
        (
            [CustomerID] TEXT,
            [CompanyName] TEXT,
            [ContactName] TEXT,
            [ContactTitle] TEXT,
            [Address] TEXT,
            [City] TEXT,
            [Region] TEXT,
            [PostalCode] TEXT,
            [Country] TEXT,
            [Phone] TEXT,
            [Fax] TEXT,
            PRIMARY KEY (`CustomerID`)
        )

        CREATE TABLE [Employees]
        (
            [EmployeeID] INTEGER PRIMARY KEY AUTOINCREMENT,
            [LastName] TEXT,
            [FirstName] TEXT,
            [Title] TEXT,
            [TitleOfCourtesy] TEXT,
            [BirthDate] DATE,
            [HireDate] DATE,
            [Address] TEXT,
            [City] TEXT,
            [Region] TEXT,
            [PostalCode] TEXT,
            [Country] TEXT,
            [HomePhone] TEXT,
            [Extension] TEXT,
            [Notes] TEXT,
            [ReportsTo] INTEGER,
            FOREIGN KEY ([ReportsTo]) REFERENCES [Employees] ([EmployeeID]) 
        )

        CREATE TABLE [Shippers]
        (
            [ShipperID] INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
            [CompanyName] TEXT NOT NULL,
            [Phone] TEXT
        )

        CREATE TABLE [Suppliers]
        (
            [SupplierID] INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
            [CompanyName] TEXT NOT NULL,
            [ContactName] TEXT,
            [ContactTitle] TEXT,
            [Address] TEXT,
            [City] TEXT,
            [Region] TEXT,
            [PostalCode] TEXT,
            [Country] TEXT,
            [Phone] TEXT,
            [Fax] TEXT,
            [HomePage] TEXT
        )

        CREATE TABLE [Products]
        (
            [ProductID] INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
            [ProductName] TEXT NOT NULL,
            [SupplierID] INTEGER,
            [CategoryID] INTEGER,
            [QuantityPerUnit] TEXT,
            [UnitPrice] NUMERIC DEFAULT 0,
            [UnitsInStock] INTEGER DEFAULT 0,
            [UnitsOnOrder] INTEGER DEFAULT 0,
            [ReorderLevel] INTEGER DEFAULT 0,
            [Discontinued] TEXT NOT NULL DEFAULT '0',
            FOREIGN KEY ([CategoryID]) REFERENCES [Categories] ([CategoryID]),
            FOREIGN KEY ([SupplierID]) REFERENCES [Suppliers] ([SupplierID])
        )

        CREATE TABLE [Orders]
        (
            [OrderID] INTEGER PRIMARY KEY AUTOINCREMENT,
            [CustomerID] INTEGER,
            [EmployeeID] INTEGER,
            [OrderDate] DATETIME,
            [ShipperID] INTEGER,
            FOREIGN KEY (EmployeeID) REFERENCES Employees (EmployeeID),
            FOREIGN KEY (CustomerID) REFERENCES Customers (CustomerID),
            FOREIGN KEY (ShipperID) REFERENCES Shippers (ShipperID)
        );

        CREATE TABLE [Order Details]
        (
            [OrderID] INTEGER NOT NULL,
            [ProductID] INTEGER NOT NULL,
            [UnitPrice] NUMERIC NOT NULL DEFAULT 0,
            [Quantity] INTEGER NOT NULL DEFAULT 1,
            [Discount] REAL NOT NULL DEFAULT 0,
            PRIMARY KEY ("OrderID", "ProductID"),
            FOREIGN KEY ([OrderID]) REFERENCES [Orders] ([OrderID]),
            FOREIGN KEY ([ProductID]) REFERENCES [Products] ([ProductID]) 
        )
        '''

    messages = [
        {
            'role': 'system',
            'content': 'You are a database expert who can convert questions into SQL queries'
        },
        {
            'role': 'user',
            'content': prompt
        }
    ]

    client = OpenAI(api_key=OPENAI_API_KEY)
    
    response = client.chat.completions.create(
        model='gpt-4o',
        messages=messages,
        temperature=0
    )

    sql = response.choices[0].message.content

    # Strip markdown characters if present
    pattern = r'^```[\w]*\n|\n```$'
    return re.sub(pattern, '', sql, flags=re.MULTILINE)
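
The last two lines of the function guard against the model wrapping its answer in a Markdown code fence despite being told not to. As a minimal sketch of what that pattern does, the snippet below applies it to a fenced and an unfenced answer. The helper name `strip_fences` is hypothetical and used only for illustration:

```python
import re

# Same pattern as in text2sql: an opening fence with an optional language tag
# at the start of a line, or a closing fence at the end of a line
FENCE_PATTERN = r'^```[\w]*\n|\n```$'

def strip_fences(sql):
    # Hypothetical helper that isolates the cleanup step for illustration
    return re.sub(FENCE_PATTERN, '', sql, flags=re.MULTILINE)

fenced = '```sql\nSELECT FirstName, LastName FROM Employees;\n```'
plain = 'SELECT FirstName, LastName FROM Employees;'

print(strip_fences(fenced))  # fence lines removed
print(strip_fences(plain))   # already clean, returned unchanged
```

Both calls print the bare query, so the cleanup is safe to apply whether or not the model followed the formatting instruction.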

Use the function to generate queries from 20 questions:

In [None]:
questions = [
    "How many employees does Wintellect have?",
    "What are the employees' names?",
    "How long has Nancy Davolio worked at Wintellect?",
    "List products have been discontinued but are currently in stock.",
    "Which products have been discontinued but are currently in stock, how many of each is in stock, and what is the combined value of those products?",
    "List all products that are currently out of stock and the suppliers for those products.",
    "What are Wintellect's most popular products, and how many of each have been sold?",
    "How many orders were placed in 1997?",
    "Which shipper or shippers delivered ikura in 1997?",
    "How many suppliers does Wintellect have, and what are their names?",
    "What countries does Wintellect have customers in, and which country has received the most shipments?",
    "Who were our top 5 customers in 1997 by volume?",
    "What is Wintellect's most expensive product?",
    "List products for which the number on order is greater than the number in stock.",
    "Which customer has spent more money than any other, where are they located, and how much have they spent?",
    "Which employee generated the most revenue in 1997?",
    "List products that are in stock but for which there are no orders.",
    "What country ordered the most aniseed syrup?",
    "Who were our top 5 customers in 1997 by revenue?",
    "Which employee had the most sales in the first half of 1997?"
]

# Generate the queries
queries = [text2sql(question) for question in questions]

# Show the results
for i, question in enumerate(questions):
    print(f'\x1b[31m{question}\x1b[0m')
    print(queries[i])
    print('-' * 40)

How many employees does Wintellect have?
SELECT COUNT(EmployeeID) AS NumberOfEmployees FROM Employees;
----------------------------------------
What are the employees' names?
SELECT FirstName, LastName FROM Employees;
----------------------------------------
How long has Nancy Davolio worked at Wintellect?
SELECT julianday('now') - julianday(HireDate) AS DaysWorked
FROM Employees
WHERE FirstName = 'Nancy' AND LastName = 'Davolio';
----------------------------------------
List products have been discontinued but are currently in stock.
SELECT ProductID, ProductName, SupplierID, CategoryID, QuantityPerUnit, UnitPrice, UnitsInStock, UnitsOnOrder, ReorderLevel
FROM Products
WHERE Discontinued = '1' AND UnitsInStock > 0;
----------------------------------------
Which products have been discontinued but are currently in stock, how many of each is in stock, and what is the combined value of those products?
SELECT 
    ProductName, 
    UnitsInS

The results are accurate because the prompt included details about the database schema. Of course, transmitting the schema definition in every request increases cost and latency. Can we do something about that using fine-tuning?

## Test `GPT-4o-mini`'s ability to generate queries unaided

Before we do any fine-tuning, let's test `GPT-4o-mini`'s ability to generate queries without knowledge of the database schema. Define a function that takes a question as input and returns a SQL query as output:

In [148]:
def text2sqltest(text):
    prompt = f'''
        Generate a well-formed SQLite query targeting the Wintellect database
        from the prompt below. Return the SQL only. Do not use markdown formatting,
        and do not use SELECT *.

        PROMPT: {text}
        '''

    messages = [
        {
            'role': 'system',
            'content': 'You are a database expert who can convert questions into SQL queries'
        },
        {
            'role': 'user',
            'content': prompt
        }
    ]

    client = OpenAI(api_key=OPENAI_API_KEY)
    
    response = client.chat.completions.create(
        model='gpt-4o-mini',
        messages=messages,
        temperature=0
    )

    sql = response.choices[0].message.content

    # Strip markdown characters if present
    pattern = r'^```[\w]*\n|\n```$'
    return re.sub(pattern, '', sql, flags=re.MULTILINE)

Now run five test questions through the function:

In [149]:
test_questions = [
    "Show all products that are out of stock and how many of each are currently on order.",
    "Create a report that shows the customers from each city that has employees in it.",
    "List products for which the number on order exceeds the number currently in stock.",
    "Which employee generated the least revenue in 1997?",
    "Which customer ordered the most tofu?",
]

# Generate the queries
test_queries = [text2sqltest(question) for question in test_questions]

# Show the results
for i, question in enumerate(test_questions):
    print(f'\x1b[31m{question}\x1b[0m')
    print(test_queries[i])
    print('-' * 40)

Show all products that are out of stock and how many of each are currently on order.
SELECT p.product_id, p.product_name, o.quantity_on_order 
FROM products p 
JOIN orders o ON p.product_id = o.product_id 
WHERE p.stock_quantity = 0;
----------------------------------------
Create a report that shows the customers from each city that has employees in it.
SELECT DISTINCT c.CustomerID, c.CustomerName, c.City
FROM Customers c
JOIN Employees e ON c.City = e.City;
----------------------------------------
List products for which the number on order exceeds the number currently in stock.
SELECT ProductID, ProductName, NumberOnOrder, NumberInStock 
FROM Products 
WHERE NumberOnOrder > NumberInStock;
----------------------------------------
Which employee generated the least revenue in 1997?
SELECT employee_id, SUM(revenue) AS total_revenue 
FROM sales 
WHERE strftime('%Y', sale_date) = '1997' 
GROUP BY employee_id 
ORDER BY total_revenue ASC 
LIMIT 1;
------

The results are poor because `GPT-4o-mini` had to guess the database schema. Let's see if we can fix that with fine-tuning.
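
One way to detect guessed table and column names mechanically, rather than by eye, is to compile each query against an empty in-memory copy of the schema. The sketch below is an illustration under assumptions, not part of the original notebook: `check_query` is a hypothetical helper, and `SCHEMA` is a trimmed-down version of the `Products` table from the prompt shown earlier. It uses SQLite's `EXPLAIN` so the query is prepared but no data is needed:

```python
import sqlite3

# Trimmed-down Products table, copied from the schema used in the prompt
SCHEMA = '''
CREATE TABLE [Products]
(
    [ProductID] INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
    [ProductName] TEXT NOT NULL,
    [UnitsInStock] INTEGER DEFAULT 0,
    [UnitsOnOrder] INTEGER DEFAULT 0
);
'''

def check_query(sql, schema=SCHEMA):
    """Return None if the query compiles against the schema, else the error text."""
    conn = sqlite3.connect(':memory:')
    try:
        conn.executescript(schema)
        # EXPLAIN prepares the statement, so unknown tables or columns raise here
        conn.execute(f'EXPLAIN {sql}')
        return None
    except sqlite3.Error as e:
        return str(e)
    finally:
        conn.close()

# A query with guessed column names, as produced without schema knowledge
guessed_query = '''SELECT ProductID, ProductName, NumberOnOrder, NumberInStock
FROM Products
WHERE NumberOnOrder > NumberInStock;'''

# The same question answered with the real column names
good_query = '''SELECT ProductName, UnitsOnOrder, UnitsInStock
FROM Products
WHERE UnitsOnOrder > UnitsInStock;'''

print(check_query(guessed_query))
print(check_query(good_query))
```

A check like this only confirms that a query refers to names that exist. It says nothing about whether the query answers the question correctly.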

## Fine-tune `GPT-4o-mini`

The first step in fine-tuning `GPT-4o-mini` is to save the 20 questions and the SQL queries generated from them by `GPT-4o` in a JSONL file:

In [150]:
import json

prompt = '''
Generate a well-formed SQLite query targeting the Wintellect database
from the prompt below. Return the SQL only. Do not use markdown formatting,
and do not use SELECT *.

PROMPT: {text}
'''

lines = []

# Pair each question with the query GPT-4o generated for it
for question, query in zip(questions, queries):
    messages = [
        { "role": "system", "content": "You are a database expert who can convert questions into SQL queries" },
        { "role": "user", "content": prompt.format(text=question) },
        { "role": "assistant", "content": query }
    ]
    lines.append({ "messages": messages })

# Write one JSON object per line (JSONL)
with open('Data/training_data.jsonl', 'w', encoding='utf-8') as file:
    for line in lines:
        json_line = json.dumps(line, ensure_ascii=False)
        file.write(json_line + '\n')
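
Because a malformed line can cause a fine-tuning job to be rejected during file validation, it may be worth sanity-checking the records first. The following is a minimal sketch, not an official validator: `validate_jsonl` is a hypothetical helper, and the system/user/assistant role sequence it expects reflects how the records are built in this notebook rather than a general requirement of the API:

```python
import json

# Role sequence used by every training record in this notebook (an assumption
# specific to this example, not a rule imposed by the fine-tuning service)
EXPECTED_ROLES = ['system', 'user', 'assistant']

def validate_jsonl(text):
    """Return a list of (line_number, problem) tuples; empty if all lines look fine."""
    problems = []
    for number, raw in enumerate(text.splitlines(), start=1):
        if not raw.strip():
            continue  # Ignore blank lines
        try:
            record = json.loads(raw)
        except json.JSONDecodeError as e:
            problems.append((number, f'invalid JSON: {e.msg}'))
            continue
        messages = record.get('messages') if isinstance(record, dict) else None
        if not isinstance(messages, list) or not all(isinstance(m, dict) for m in messages):
            problems.append((number, 'missing or malformed "messages" list'))
            continue
        roles = [m.get('role') for m in messages]
        if roles != EXPECTED_ROLES:
            problems.append((number, f'unexpected roles: {roles}'))
        elif not all(isinstance(m.get('content'), str) and m['content'].strip() for m in messages):
            problems.append((number, 'empty content'))
    return problems

# One well-formed record and one that lacks the system and assistant messages
good = json.dumps({'messages': [
    {'role': 'system', 'content': 'You are a database expert who can convert questions into SQL queries'},
    {'role': 'user', 'content': 'PROMPT: How many employees does Wintellect have?'},
    {'role': 'assistant', 'content': 'SELECT COUNT(EmployeeID) AS NumberOfEmployees FROM Employees;'},
]})
bad = json.dumps({'messages': [
    {'role': 'user', 'content': "PROMPT: What are the employees' names?"},
]})

print(validate_jsonl(good + '\n' + bad))
```

To check the real file, read its contents with `open('Data/training_data.jsonl', encoding='utf-8').read()` and pass the result to the function. An empty list means no problems were found.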

Show the contents of the file:

In [151]:
with open('Data/training_data.jsonl', 'r', encoding='utf-8') as file:
    for line in file:
        print(line)

{"messages": [{"role": "system", "content": "You are a database expert who can convert questions into SQL queries"}, {"role": "user", "content": "\nGenerate a well-formed SQLite query targeting the Wintellect database\nfrom the prompt below. Return the SQL only. Do not use markdown formatting,\nand do not use SELECT *.\n\nPROMPT: How many employees does Wintellect have?\n"}, {"role": "assistant", "content": "SELECT COUNT(EmployeeID) AS NumberOfEmployees FROM Employees;"}]}

{"messages": [{"role": "system", "content": "You are a database expert who can convert questions into SQL queries"}, {"role": "user", "content": "\nGenerate a well-formed SQLite query targeting the Wintellect database\nfrom the prompt below. Return the SQL only. Do not use markdown formatting,\nand do not use SELECT *.\n\nPROMPT: What are the employees' names?\n"}, {"role": "assistant", "content": "SELECT FirstName, LastName FROM Employees;"}]}

{"messages": [{"role": "system", "content": "You are a database expert 

Upload the file:

In [152]:
client = OpenAI(api_key=OPENAI_API_KEY)

# Open the file in a context manager so the handle is closed after the upload
with open('Data/training_data.jsonl', 'rb') as file:
    train_file = client.files.create(
        file=file,
        purpose='fine-tune'
    )

Start the fine-tuning job. It typically takes a few minutes to complete:

In [153]:
job = client.fine_tuning.jobs.create(
    training_file=train_file.id, 
    model='gpt-4o-mini-2024-07-18', 
    hyperparameters={
        'n_epochs': 5,
        'batch_size': 5,
        'learning_rate_multiplier': 0.2
    }
)

print(job)

FineTuningJob(id='ftjob-yRXkJ0il0gzHj5WxpojAVrF7', created_at=1730660011, error=Error(code=None, message=None, param=None), fine_tuned_model=None, finished_at=None, hyperparameters=Hyperparameters(n_epochs=5, batch_size=5, learning_rate_multiplier=0.2), model='gpt-4o-mini-2024-07-18', object='fine_tuning.job', organization_id='org-PmH6Y7Qm7c5qpbbHU5RIsmQ3', result_files=[], seed=939455719, status='validating_files', trained_tokens=None, training_file='file-yZCUu4G3y8uHnFI30uhFSwCR', validation_file=None, estimated_finish=None, integrations=[], user_provided_suffix=None)
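
To get a feel for what these hyperparameters mean for a dataset this small, here is a rough back-of-the-envelope sketch. It assumes one optimizer step per batch and that a partial final batch counts as a step. The notebook doesn't describe how the service actually schedules training, and `training_steps` is a hypothetical helper:

```python
import math

def training_steps(n_examples, batch_size, n_epochs):
    # Assumption: one step per batch, with a partial final batch counted as a step
    steps_per_epoch = math.ceil(n_examples / batch_size)
    return steps_per_epoch * n_epochs

# 20 training examples, batch_size=5, n_epochs=5 (the values used above)
print(training_steps(n_examples=20, batch_size=5, n_epochs=5))  # 20
```

Under those assumptions, the model sees each of the 20 examples five times over roughly 20 weight updates, each scaled down by the `learning_rate_multiplier` of 0.2.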


Either monitor the fine-tuning job in the dashboard, or check the job status periodically until it changes to "succeeded" or "failed." Again, this typically takes a few minutes.

In [154]:
client.fine_tuning.jobs.retrieve(job.id).status

'succeeded'
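
Rather than re-running the status cell by hand, the check can be wrapped in a polling loop. The sketch below is one possible approach, not part of the original notebook: `wait_for_job`, its parameters, and the default intervals are hypothetical. It takes a callable that returns the current status, so the loop can be exercised here with simulated statuses instead of a live job:

```python
import time

# Statuses after which the job will not change again
TERMINAL_STATUSES = {'succeeded', 'failed', 'cancelled'}

def wait_for_job(get_status, poll_seconds=30, timeout_seconds=3600, sleep=time.sleep):
    """Call get_status() until it returns a terminal status, then return that status."""
    waited = 0
    while True:
        status = get_status()
        if status in TERMINAL_STATUSES:
            return status
        if waited >= timeout_seconds:
            raise TimeoutError(f'job still {status!r} after {waited} seconds')
        sleep(poll_seconds)
        waited += poll_seconds

# Simulated run: a fake status source and a no-op sleep stand in for the API
fake_statuses = iter(['validating_files', 'running', 'running', 'succeeded'])
print(wait_for_job(lambda: next(fake_statuses), sleep=lambda seconds: None))

# With the real job from this notebook, the call would look like this:
# status = wait_for_job(lambda: client.fine_tuning.jobs.retrieve(job.id).status)
```

Passing the status lookup in as a callable keeps the loop independent of the OpenAI client, which makes it easy to test and to reuse for other long-running jobs.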

Once fine-tuning is complete, get the model name:

In [155]:
model = client.fine_tuning.jobs.retrieve(job.id).fine_tuned_model
print(model)

ft:gpt-4o-mini-2024-07-18:personal::APa5THFy


## Test the fine-tuned model's ability to generate queries

Can the fine-tuned model generate syntactically correct queries when the database schema isn't included in the prompt? Rewrite the `text2sqltest` function to use the fine-tuned model rather than `GPT-4o-mini`:

In [156]:
def text2sqltest(text):
    prompt = f'''
        Generate a well-formed SQLite query targeting the Wintellect database
        from the prompt below. Return the SQL only. Do not use markdown formatting,
        and do not use SELECT *.

        PROMPT: {text}
        '''

    messages = [
        {
            'role': 'system',
            'content': 'You are a database expert who can convert questions into SQL queries'
        },
        {
            'role': 'user',
            'content': prompt
        }
    ]

    client = OpenAI(api_key=OPENAI_API_KEY)
    
    response = client.chat.completions.create(
        model=model, # Fine-tuned model
        messages=messages,
        temperature=0
    )

    sql = response.choices[0].message.content

    # Strip markdown characters if present
    pattern = r'^```[\w]*\n|\n```$'
    return re.sub(pattern, '', sql, flags=re.MULTILINE)

Now use the fine-tuned model to generate queries from the test questions:

In [157]:
# Generate the queries
test_queries = [text2sqltest(question) for question in test_questions]

# Show the results
for i, question in enumerate(test_questions):
    print(f'\x1b[31m{question}\x1b[0m')
    print(test_queries[i])
    print('-' * 40)

Show all products that are out of stock and how many of each are currently on order.
SELECT p.ProductID, p.ProductName, p.UnitsOnOrder
FROM Products p
WHERE p.UnitsInStock = 0;
----------------------------------------
Create a report that shows the customers from each city that has employees in it.
SELECT DISTINCT c.CustomerID, c.CompanyName, c.ContactName, c.City
FROM Customers c
JOIN Employees e ON c.City = e.City;
----------------------------------------
List products for which the number on order exceeds the number currently in stock.
SELECT ProductName, UnitsOnOrder, UnitsInStock 
FROM Products 
WHERE UnitsOnOrder > UnitsInStock;
----------------------------------------
Which employee generated the least revenue in 1997?
SELECT e.EmployeeID, e.FirstName, e.LastName, SUM(od.UnitPrice * od.Quantity * (1 - od.Discount)) AS TotalRevenue
FROM Employees e
JOIN Orders o ON e.EmployeeID = o.EmployeeID
JOIN [Order Details] od ON o.OrderID = od.OrderID
WH

The results are much better this time because knowledge of the database schema is "baked in" to the fine-tuned model.