<div align="center">
    <img alt="Institut Polytechnique de Paris Logo" width="auto" height="150px" src="https://www.ip-paris.fr/sites/default/files/presse/Charte%20Graphique/Logo%20IP%20Paris%206%20%C3%A9coles%20vertical%20png.png" />
</div>

<div style="text-align: center; font-family: Arial, sans-serif; margin: 20px 0;">
    <h1 style="font-size: 32px; margin-bottom: 10px;">Lab 4: Text to SQL</h1>
    <p style="font-size: 16px; margin: 0;">Authors: 
        <strong>Tim Luka Horstmann</strong> & <strong>William Liaw</strong>
    </p>
</div>

---
**Tim Luka Horstmann**

**Email:** [tim.horstmann@ip-paris.fr](mailto:tim.horstmann@ip-paris.fr)  
**Website:** [horstmann.tech](https://horstmann.tech)


**William Liaw**

**Email:** [william.liaw@telecom-paris.fr](mailto:william.liaw@ip-paris.fr)

In this lab, we'll apply what we learned in the previous labs to build a SQLite database using Ministral 8B. This lab is less guided than the previous ones, so you'll need to refer to your earlier work to complete each part. It also focuses more on prompt engineering: you have to find the best prompt and prompting strategy (system prompt? temperature value? dialog prompt style?).

For this lab, we need to use sqlite3 to execute the generated queries.
Check the doc: https://docs.python.org/3/library/sqlite3.html

<font color='red'>BE CAREFUL: you need to generate SQL queries and then automatically execute them with the sqlite3 connector. DO NOT generate Python code. DO NOT copy and paste the generated query into the connector.</font>

<font color='green'>TIPS: sqlite3 creates a file containing your database. Delete it if you need to reset the database.</font>
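
As a minimal sketch of what "automatically execute" can look like, the snippet below passes SQL strings directly to the `sqlite3` connector. The helper name `run_generated_sql` and the hard-coded strings are illustrative assumptions: in the lab, the strings would come from the LLM. An in-memory database is used here so that no file has to be deleted afterwards.

```python
import sqlite3


def run_generated_sql(conn, sql_query):
    """Execute one SQL string (e.g. produced by an LLM) and return any rows."""
    cursor = conn.cursor()
    cursor.execute(sql_query)
    conn.commit()
    return cursor.fetchall()


# ":memory:" keeps the database in RAM instead of creating a file on disk.
conn = sqlite3.connect(":memory:")

# Stand-ins for model output (assumption: the LLM returns plain SQL strings).
run_generated_sql(conn, "CREATE TABLE demo (id INTEGER PRIMARY KEY, name TEXT);")
run_generated_sql(conn, "INSERT INTO demo (name) VALUES ('Alice'), ('Bob');")
rows = run_generated_sql(conn, "SELECT * FROM demo;")
print(rows)
```

The query string is never edited by hand between generation and execution, which is the constraint stated above.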



Lab overview:

0. Modules installation and model loading.
1. Create tables using llm.
2. Populate tables using llm.
3. Explore our tables using llm.
4. More than one table with llm.

IMPORTANT:
- You must work in pairs. You must submit **ONLY ONE NOTEBOOK** for each pair.
- Do not share your work with other pairs.
- You should not use Copilot, ChatGPT or similar tools. At the very least, remove the prompt ...

<div style="
    background-color: #2C3E50; 
    color: #ECF0F1; 
    font-size: 28px; 
    text-align: center; 
    padding: 20px; 
    border-radius: 15px; 
    box-shadow: 0px 4px 8px rgba(0, 0, 0, 0.2); 
    width: 100%; 
    margin: auto; 
    font-family: Times New Roman, sans-serif;">
    0. Setup

</div>

In [1]:
# !pip install transformers datasets bitsandbytes accelerate

In [2]:
import os
import re
import sqlite3
from itertools import islice

import torch
from tqdm.notebook import tqdm
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          BitsAndBytesConfig, GenerationConfig)

In [3]:
def read_hf_token(file_path):
    """Reads the Hugging Face token from a file."""
    try:
        with open(file_path, "r") as file:
            token = file.read().strip()
            if not token:
                raise ValueError("The token file is empty.")
            return token
    except FileNotFoundError:
        raise FileNotFoundError(f"Token file not found at: {file_path}")
    except Exception as e:
        raise RuntimeError(f"An error occurred while reading the token: {e}")

In [4]:
# File path to the token
hf_token_file = "token.txt"

# Read the token from the file
hf_token = read_hf_token(hf_token_file)
llm_name = "mistralai/Ministral-8B-Instruct-2410"

# We want to use 4bit quantization to save memory
quantization_config = BitsAndBytesConfig(load_in_8bit=False, load_in_4bit=True)

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(
    llm_name, padding_side="left", token=hf_token, cache_dir="/Data/llm"
)

# Prevent some transformers specific issues.
tokenizer.use_default_system_prompt = False
tokenizer.pad_token_id = tokenizer.eos_token_id

# Load LLM.
llm = AutoModelForCausalLM.from_pretrained(
    llm_name,
    quantization_config=quantization_config,
    device_map={"": 0},  # load all the model layers on GPU 0
    torch_dtype=torch.bfloat16,  # float precision
    token=hf_token,
    cache_dir="/Data/llm",
)

# Set LLM on eval mode.
llm.eval()

You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565 - if you loaded a llama tokenizer from a GGUF file you can ignore this message.



MistralForCausalLM(
  (model): MistralModel(
    (embed_tokens): Embedding(131072, 4096)
    (layers): ModuleList(
      (0-35): 36 x MistralDecoderLayer(
        (self_attn): MistralSdpaAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): MistralRotaryEmbedding()
        )
        (mlp): MistralMLP(
          (gate_proj): Linear4bit(in_features=4096, out_features=12288, bias=False)
          (up_proj): Linear4bit(in_features=4096, out_features=12288, bias=False)
          (down_proj): Linear4bit(in_features=12288, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): MistralRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): MistralRMSN

#### Definition of different generation configurations

In [6]:
generation_create_config = GenerationConfig(
    do_sample=False,
    max_new_tokens=256,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
)

generation_insert_config = GenerationConfig(
    do_sample=True,
    max_new_tokens=512,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
)

generation_long_insert_config = GenerationConfig(
    do_sample=True,
    max_new_tokens=1024,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
)

generation_insert_config_more_diverse = GenerationConfig(
    do_sample=True,
    max_new_tokens=512,
    temperature=0.7,
    top_p=0.9,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
)

In [7]:
def generate_batch(prompts, llm=llm, generation_config=generation_create_config):
    """
    Generates responses for a batch of prompts using the LLM.
    """
    # Create turns for all prompts
    turns = [[{"role": "user", "content": prompt}] for prompt in prompts]

    # Tokenize batch of turns
    input_ids = tokenizer.apply_chat_template(
        turns, return_tensors="pt", padding=True
    ).to("cuda")

    # Ensure we don't use gradients to save memory
    with torch.no_grad():
        outputs = llm.generate(input_ids, generation_config)

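    # Recover and decode answers: drop the prompt tokens, then let the
    # tokenizer skip EOS/padding tokens instead of slicing off the last token
    answer_tokens = [output[input_ids.shape[1] :] for output in outputs]
    return [
        tokenizer.decode(tokens, skip_special_tokens=True).strip()
        for tokens in answer_tokens
    ]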


def generate(prompts, llm=llm, generation_config=generation_create_config):
    """
    Generates response(s) for one or more prompts.
    Automatically calls `generate_batch()` if multiple prompts are provided.
    """
    if isinstance(prompts, list):
        return generate_batch(prompts, llm=llm, generation_config=generation_config)
    return generate_batch([prompts], llm=llm, generation_config=generation_config)[0]


def extract_sql(query, mode):
    """
    Extracts SQL query from the generated text, ensuring that only a valid SQL statement is returned.
    This prevents issues caused by hallucinated explanations or Markdown artifacts.
    """
    sql_query = None

    # Remove Markdown code block formatting (e.g., ```sql ... ```)
    query = re.sub(r"```sql|```", "", query, flags=re.MULTILINE).strip()

    # Extract SQL statement based on mode
    if mode == "CREATE":
        sql_query = re.search(
            r"CREATE TABLE\s+\w+\s*\(.*?\);", query, re.DOTALL | re.IGNORECASE
        )
    elif mode == "INSERT":
        sql_query = re.search(
            r"INSERT INTO\s+\w+\s*\(.*?\)\s*VALUES\s*\(.*?\);",
            query,
            re.DOTALL | re.IGNORECASE,
        )
    elif mode == "SELECT":
        sql_query = re.search(
            r"SELECT\s+.*?\s+FROM\s+\w+.*?;", query, re.DOTALL | re.IGNORECASE
        )
    elif mode == "DELETE":
        sql_query = re.search(
            r"(WITH\s+.*? AS .*?DELETE FROM\s+\w+.*?;|DELETE FROM\s+\w+.*?;)",
            query,
            re.DOTALL,
        )

    # If a valid SQL query is found, return it
    if sql_query:
        return sql_query.group(0).strip()

    # If no valid SQL was extracted, return None
    return None


def is_correct_number_of_rows(sql_query, expected_rows):
    """
    Checks how many separate (...) groups appear in the VALUES clause.
    Returns True if it matches expected_rows, otherwise False.
    """
    if not sql_query:
        return False

    # Count occurrences of '(' that appear after the "VALUES" keyword
    match = re.search(r"VALUES\s*(\(.+\))\s*;", sql_query, re.DOTALL | re.IGNORECASE)
    if not match:
        return False

    # The substring with all (row) groups
    rows_str = match.group(1)
    # Count how many times we have an open parenthesis that signals a new row
    row_count = rows_str.count("(")
    return row_count == expected_rows


# General function to generate SQL queries for batch processing
def generate_sql(prompts, mode="CREATE", generation_config=generation_create_config):
    """
    Generates SQL queries using the LLM based on the given prompts.
    Supports batch processing when multiple prompts are provided.
    """
    queries = generate(prompts, generation_config=generation_config)
    if isinstance(prompts, list):
        return [extract_sql(query, mode) for query in queries]
    return extract_sql(queries, mode)


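def generate_sql_with_retry(
    prompts, mode="CREATE", retries=3, generation_config=None, expected_rows=None
):
    """
    Generates one SQL query per prompt, one prompt at a time.
    Retries up to `retries` times when extraction fails or, if `expected_rows`
    is set, when the number of inserted rows does not match.
    Returns a list with one query per prompt (None if every attempt failed).
    """
    results = []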
    for prompt in prompts:
        for _ in range(retries):
            raw_output = generate(prompt, generation_config=generation_config)
            sql_query = extract_sql(raw_output, mode)
            if not sql_query:
                print(f"Retrying: extraction failed for prompt:\n{prompt}\n")
                print(f"Raw output:\n{raw_output}\n")
                continue

            # Check correct row count
            if expected_rows and not is_correct_number_of_rows(
                sql_query, expected_rows
            ):
                print(
                    f"Retrying: row count mismatch. Expected {expected_rows}:\n{sql_query}\n"
                )
                continue

            # If we got here, we have a valid query with correct row count
            results.append(sql_query)
            break
        else:
            # If we never broke out of the for-loop, we failed all retries
            results.append(None)
    return results
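
To illustrate the extraction step in isolation, here is a small self-contained sketch that applies the same `CREATE TABLE` pattern to a made-up, chatty model answer. The sample text is an assumption for illustration, not real model output. It also shows that a statement without a terminating semicolon is not matched by this pattern, which is one way an extraction can fail and trigger a retry.

```python
import re

# Built programmatically so the example can live inside a Markdown code fence.
FENCE = "`" * 3

# Hypothetical model answer: explanation + fenced SQL + trailing remark.
raw_output = (
    "Sure! Here is the query:\n"
    f"{FENCE}sql\n"
    "CREATE TABLE demo (\n    id INTEGER PRIMARY KEY,\n    name TEXT\n);\n"
    f"{FENCE}\n"
    "This creates the table."
)

CREATE_PATTERN = r"CREATE TABLE\s+\w+\s*\(.*?\);"

# Strip the fence markers, then keep only the first CREATE TABLE statement.
cleaned = re.sub(FENCE + "sql|" + FENCE, "", raw_output).strip()
match = re.search(CREATE_PATTERN, cleaned, re.DOTALL | re.IGNORECASE)
sql = match.group(0) if match else None
print(sql)

# The pattern requires ");" at the end, so this statement is rejected.
no_semicolon = "CREATE TABLE demo (id INTEGER PRIMARY KEY)"
rejected = re.search(CREATE_PATTERN, no_semicolon, re.DOTALL | re.IGNORECASE)
print(rejected)
```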

In [8]:
database_name = "database.db"

if os.path.exists(database_name):
    os.remove(database_name)

<div style="
    background-color: #2C3E50; 
    color: #ECF0F1; 
    font-size: 28px; 
    text-align: center; 
    padding: 20px; 
    border-radius: 15px; 
    box-shadow: 0px 4px 8px rgba(0, 0, 0, 0.2); 
    width: 100%; 
    margin: auto; 
    font-family: Times New Roman, sans-serif;">
    1. Create tables using llms
</div>

You need to generate and execute SQL queries to create 3 tables:
- "characters": Id (primary key), Name (str), Age (int), Profession (str).
- "characters20": same as characters.
- "skills": Id (primary key), Name (str), Profession (str).

For example, running ``cursor.execute("""PRAGMA table_info(characters);""").fetchall()`` should return the following result:

```
[
    (0, 'id', 'INTEGER', 0, None, 1),
    (1, 'name', 'TEXT', 1, None, 0),
    (2, 'age', 'INTEGER', 1, None, 0),
    (3, 'profession', 'TEXT', 1, None, 0)
]
```

<font color='red'>BE CAREFUL: SQLite does not support every feature of other SQL dialects. You may need to specify the dialect in your prompt.</font>
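
For reading the expected result above, it may help to recall how SQLite lays out each `PRAGMA table_info` row: `(cid, name, type, notnull, dflt_value, pk)`. The sketch below uses a hypothetical `demo` table (not one of the lab tables) to show that a column declared `NOT NULL` reports `notnull = 1`, while a column without that constraint reports `0`.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE demo ("
    "id INTEGER PRIMARY KEY, "
    "name TEXT NOT NULL, "
    "age INTEGER)"
)

# Each row is (cid, name, type, notnull, dflt_value, pk).
info = conn.execute("PRAGMA table_info(demo);").fetchall()
for column in info:
    print(column)
conn.close()
```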

In [9]:
def create_tables(
    database_name,
    prompts={
        "characters": (
            "Write an SQL query to create a table named 'characters' "
            "with columns: id (INTEGER PRIMARY KEY), name (TEXT), age (INTEGER), "
            "profession (TEXT)."
        ),
        "characters20": (
            "Write an SQL query to create a table named 'characters20' "
            "with columns: id (INTEGER PRIMARY KEY), name (TEXT), age (INTEGER), "
            "profession (TEXT)."
        ),
        "skills": (
            "Write an SQL query to create a table named 'skills' "
            "with columns: id (INTEGER PRIMARY KEY), name (TEXT), profession (TEXT)."
        ),
    },
    commit=True,
    generation_config=generation_create_config,
):
    # Establish database connection
    with sqlite3.connect(database_name) as conn:
        cursor = conn.cursor()

        # Generate and execute SQL queries for table creation
        for table_name, prompt in tqdm(prompts.items(), desc="Creating Tables"):
            sql_query = generate_sql(
                prompt, mode="CREATE", generation_config=generation_config
            )
            if sql_query:
                print(f"Executing query for '{table_name}':\n{sql_query}\n")
                cursor.execute(sql_query)
            else:
                print(f"Failed to generate SQL query for '{table_name}'.")

        # Commit changes
        if commit:
            conn.commit()

        # Verify structure of each created table
        for table_name in prompts.keys():
            schema_info = cursor.execute(f"PRAGMA table_info({table_name});").fetchall()
            print(f"Table Info for '{table_name}':\n{schema_info}\n")

    print("Tables created successfully!")

In [10]:
create_tables(database_name, generation_config=generation_create_config)


The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


Executing query for 'characters':
CREATE TABLE characters (
    id INTEGER PRIMARY KEY,
    name TEXT,
    age INTEGER,
    profession TEXT
);

Executing query for 'characters20':
CREATE TABLE characters20 (
    id INTEGER PRIMARY KEY,
    name TEXT,
    age INTEGER,
    profession TEXT
);

Executing query for 'skills':
CREATE TABLE skills (
    id INTEGER PRIMARY KEY,
    name TEXT,
    profession TEXT
);

Table Info for 'characters':
[(0, 'id', 'INTEGER', 0, None, 1), (1, 'name', 'TEXT', 0, None, 0), (2, 'age', 'INTEGER', 0, None, 0), (3, 'profession', 'TEXT', 0, None, 0)]

Table Info for 'characters20':
[(0, 'id', 'INTEGER', 0, None, 1), (1, 'name', 'TEXT', 0, None, 0), (2, 'age', 'INTEGER', 0, None, 0), (3, 'profession', 'TEXT', 0, None, 0)]

Table Info for 'skills':
[(0, 'id', 'INTEGER', 0, None, 1), (1, 'name', 'TEXT', 0, None, 0), (2, 'profession', 'TEXT', 0, None, 0)]

Tables created successfully!


We also experimented with creating the `characters20` table using the following prompt provided to the language model: `"Write an SQL query to create a table named 'characters20' with the same schema as 'characters'."`. This generated the following SQL code:

```sql
CREATE TABLE characters20 AS
SELECT * FROM characters;
```

While this successfully created the `characters20` table with the expected schema, there was a significant difference: the `id` column, unlike in the original `characters` table, was not defined as a primary key. This difference caused issues in subsequent tasks that relied on the primary key to enforce data integrity and uniqueness.

To resolve this, we decided to use the same prompt that had been used to create the `characters` table. This ensured the `characters20` table had a schema identical to the original, including the proper primary key constraint, thereby avoiding the issues encountered with the initial approach.
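
The loss of the primary key can be reproduced without the LLM. The following sketch, written by hand for illustration, creates a table with `CREATE TABLE ... AS SELECT` in an in-memory database and compares the primary-key columns reported by `PRAGMA table_info`. The helper `pk_columns` is a hypothetical name introduced here.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE characters "
    "(id INTEGER PRIMARY KEY, name TEXT, age INTEGER, profession TEXT)"
)
# Copying via AS SELECT keeps the column names but not the constraints.
conn.execute("CREATE TABLE characters20 AS SELECT * FROM characters")


def pk_columns(table):
    """Return the names of the columns flagged as primary key (pk > 0)."""
    rows = conn.execute(f"PRAGMA table_info({table});").fetchall()
    return [row[1] for row in rows if row[5] > 0]


original_pk = pk_columns("characters")
copy_pk = pk_columns("characters20")
print(original_pk, copy_pk)
conn.close()
```

Without the `INTEGER PRIMARY KEY` declaration, `id` is an ordinary column, so inserts that omit it leave it NULL instead of assigning a unique value.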

<div style="
    background-color: #2C3E50; 
    color: #ECF0F1; 
    font-size: 28px; 
    text-align: center; 
    padding: 20px; 
    border-radius: 15px; 
    box-shadow: 0px 4px 8px rgba(0, 0, 0, 0.2); 
    width: 100%; 
    margin: auto; 
    font-family: Times New Roman, sans-serif;">
    2. Populate tables using llm
</div>

You need to generate and execute SQL queries to fill in "characters" and "characters20":
- For both, the age must be constrained between 18 and 50 (we'll assess whether the constraint is met later).
- For "characters", generate 10 rows per prompt. Apply the prompt 10 times (you should end up with 100 rows).
- For "characters20", generate 20 rows per prompt. Apply the prompt 5 times (you should also get 100 rows at the end).


For example, executing ``cursor.execute("SELECT * FROM characters")`` should give a result like this (with 100 rows and perhaps different values):

```
[
    (1, 'Alice', 25, 'Artist'),
    (2, 'Bob', 35, 'Engineer'),
    ...
    (99, 'Ian', 32, 'Architect'),
    (100, 'Jane', 18, 'Dancer')
]
```

<font color='red'> BE CAREFUL: If your generation configuration doesn't include sampling, you'll always have the same rows.</font>

<font color='green'> BONUS: In section 3, we'll compare the number of duplicated rows between the two methods. Do you have a better strategy for minimizing the number of duplicated rows? Give it a try! (create another table for this purpose) </font>
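 BONUS: In section 3">

As a rough illustration of why sampling matters, the toy sketch below contrasts greedy decoding with temperature and top-p sampling over a made-up next-token distribution. The function names and logits are hypothetical, and this is not how `transformers` implements generation internally. It only mirrors the idea behind `do_sample`, `temperature`, and `top_p`.

```python
import math
import random


def greedy_token(logits):
    """Greedy decoding: always pick the highest-scoring token."""
    return max(logits, key=logits.get)


def sample_token(logits, temperature=1.0, top_p=1.0, rng=random):
    """Sample a token using temperature scaling and nucleus (top-p) filtering."""
    # Temperature-scaled softmax (subtract the max for numerical stability).
    scaled = {tok: value / temperature for tok, value in logits.items()}
    max_val = max(scaled.values())
    exps = {tok: math.exp(v - max_val) for tok, v in scaled.items()}
    total = sum(exps.values())
    probs = sorted(
        ((tok, e / total) for tok, e in exps.items()),
        key=lambda item: item[1],
        reverse=True,
    )

    # Keep the smallest set of top tokens whose cumulative probability
    # reaches top_p, then sample among them in proportion to their weights.
    kept, cumulative = [], 0.0
    for tok, p in probs:
        kept.append((tok, p))
        cumulative += p
        if cumulative >= top_p:
            break
    tokens, weights = zip(*kept)
    return rng.choices(tokens, weights=weights, k=1)[0]


# Made-up scores for the next token (assumption for illustration only).
toy_logits = {"Alice": 2.0, "Bob": 1.5, "Zephyr": 0.5, "Orion": 0.2}
rng = random.Random(0)

greedy = [greedy_token(toy_logits) for _ in range(5)]
sampled = [
    sample_token(toy_logits, temperature=0.7, top_p=0.9, rng=rng)
    for _ in range(200)
]
print(set(greedy), set(sampled))
```

Greedy decoding returns the same token on every call, which is why repeated runs of the same prompt would produce identical rows. With these toy numbers, top-p filtering at 0.9 removes the least likely token from the candidate set.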

In [11]:
populated_tables = {
    "characters": {
        "n_rows": 10,
        "n_runs": 10,
        "generation_config": generation_insert_config,
    },
    "characters20": {
        "n_rows": 20,
        "n_runs": 5,
        "generation_config": generation_insert_config,
    },
    # The tables below are added for the bonus task: a more diverse sampling configuration to reduce repeated rows
    "characters_more_diverse": {
        "n_rows": 10,
        "n_runs": 10,
        "generation_config": generation_insert_config_more_diverse,
    },
    "characters20_more_diverse": {
        "n_rows": 20,
        "n_runs": 5,
        "generation_config": generation_insert_config_more_diverse,
    },
}

In [12]:
# For the bonus task: we create two additional tables to later hold more diverse data:
prompts = {
    "characters_more_diverse": (
        "Write an SQL query to create a table named 'characters_more_diverse' "
        "with columns: id (INTEGER PRIMARY KEY), name (TEXT), age (INTEGER), "
        "profession (TEXT)."
    ),
    "characters20_more_diverse": (
        "Write an SQL query to create a table named 'characters20_more_diverse' "
        "with columns: id (INTEGER PRIMARY KEY), name (TEXT), age (INTEGER), "
        "profession (TEXT)."
    ),
}

create_tables(
    database_name, prompts=prompts, generation_config=generation_create_config
)


Executing query for 'characters_more_diverse':
CREATE TABLE characters_more_diverse (
    id INTEGER PRIMARY KEY,
    name TEXT,
    age INTEGER,
    profession TEXT
);

Executing query for 'characters20_more_diverse':
CREATE TABLE characters20_more_diverse (
    id INTEGER PRIMARY KEY,
    name TEXT,
    age INTEGER,
    profession TEXT
);

Table Info for 'characters_more_diverse':
[(0, 'id', 'INTEGER', 0, None, 1), (1, 'name', 'TEXT', 0, None, 0), (2, 'age', 'INTEGER', 0, None, 0), (3, 'profession', 'TEXT', 0, None, 0)]

Table Info for 'characters20_more_diverse':
[(0, 'id', 'INTEGER', 0, None, 1), (1, 'name', 'TEXT', 0, None, 0), (2, 'age', 'INTEGER', 0, None, 0), (3, 'profession', 'TEXT', 0, None, 0)]

Tables created successfully!


#### Populate tables

_Note: Unlike the code executed so far, we now build the full list of prompts for each table up front and pass it to `generate_sql_with_retry`, which sends the prompts to the LLM one at a time so that each failed extraction can be retried individually. The batch path of `generate` (passing a list of prompts) remains available but is not used here. None of this is strictly required by the lab; we implemented it on top of the other tasks._

In [13]:
def populate_tables(database_name, populated_tables, commit=True):
    # Establish database connection
    conn = sqlite3.connect(database_name)
    cursor = conn.cursor()

    # Common prompt template
    base_populate_prompt = (
        "Generate an SQL INSERT INTO query (no explanations) to add {n_rows} unique rows "
        "to the '{table}' table (age between 18 and 50). "
        "The table schema is as follows: {info}. "
        "Use realistic and creative names, and professions. "
        "Make sure that EXACTLY {n_rows} rows are inserted, no more, no less. "
        "Strictly output 1 valid SQL statement. No additional text or formatting. "
    )

    # Build a schema cache for each table
    table_schema_info = {}
    for table_name in populated_tables:
        schema_info = cursor.execute(f"PRAGMA table_info({table_name});").fetchall()
        table_schema_info[table_name] = ", ".join(
            [f"{col[1]} ({col[2]})" for col in schema_info if col[1] != "id"]
        )

    # Iterate over each table and generate/execute INSERT queries in batches
    failures = 0
    total = sum(table["n_runs"] for table in populated_tables.values())
    pbar = tqdm(total=total)

    for table_name, table_config in populated_tables.items():
        pbar.desc = f"Populating {table_name}"
        pbar.refresh()

        # Prepare batch of prompts
        run_times = table_config.get("n_runs", 10)
        batch_prompts = [
            base_populate_prompt.format(
                n_rows=table_config.get("n_rows", 10),
                table=table_name,
                info=table_schema_info[table_name],
            )
            for _ in range(run_times)
        ]

        # Generate batch of SQL queries
        sql_queries = generate_sql_with_retry(
            batch_prompts,
            mode="INSERT",
            generation_config=table_config.get(
                "generation_config", generation_insert_config
            ),
            expected_rows=None,
        )

        # Execute generated queries
        for sql_query in sql_queries:
            if sql_query:
                try:
                    print(f"Executing query for {table_name}:\n{sql_query}\n")
                    cursor.execute(sql_query)
                except Exception as e:
                    failures += 1
                    print(
                        f"Failed to execute SQL query for {table_name}. Error: {e}.\n"
                    )
            else:
                failures += 1
                print(f"Failed to generate SQL query for {table_name}.\n")
            pbar.update(1)

        # Commit after each table's insertions
        if commit:
            conn.commit()

    conn.close()
    pbar.close()
    print(f"Generated successfully {total - failures}/{total} times!")

In [14]:
populate_tables(database_name, populated_tables)


Retrying: extraction failed for prompt:
Generate an SQL INSERT INTO query (no explanations) to add 10 unique rows to the 'characters' table (age between 18 and 50). The table schema is as follows: name (TEXT), age (INTEGER), profession (TEXT). Use realistic and creative names, and professions. Make sure that EXACTLY 10 rows are inserted, no more, no less. Strictly output 1 valid SQL statement. No additional text or formatting. 

Raw output:
INSERT INTO characters (name, age, profession) VALUES ('Zephyr', 31, 'Musician'), ('Astra', 25, 'Artist'), ('Kael', 45, 'Designer'), ('Orion', 28, 'Writer'), ('Lira', 35, 'Photographer'), ('Cedric', 42, 'Editor'), ('Isolde', 23, 'Performer'), ('Ada', 38, 'Animator'), ('Eamon', 20, 'Graphic Designer'), ('Eirian', 43, 'Writer')

Executing query for characters:
INSERT INTO characters (name, age, profession) VALUES
('Alix Stone', 23, 'Software Engineer'),
('Bryce Miller', 47, 'Chef'),
('Cassandra Lee', 35, 'Actor'),
('Darin Walker', 18, 'Artist'),
('Ele

#### Verify correct execution/population

In [15]:
with sqlite3.connect(database_name) as conn:
    cursor = conn.cursor()
    # Verify the number of rows in each table
    for table_name in populated_tables:
        cursor.execute(f"SELECT * FROM {table_name};")
        output = cursor.fetchall()
        print(f"Table: {table_name}")
        print(f"Number of rows: {len(output)}")
        print(f"5 top rows: {output[:5]}")
        print()

Table: characters
Number of rows: 98
5 top rows: [(1, 'Alix Stone', 23, 'Software Engineer'), (2, 'Bryce Miller', 47, 'Chef'), (3, 'Cassandra Lee', 35, 'Actor'), (4, 'Darin Walker', 18, 'Artist'), (5, 'Eleanor Clark', 38, 'Artist')]

Table: characters20
Number of rows: 90
5 top rows: [(1, 'Ava Stone', 32, 'Engineer'), (2, 'Ethan Rose', 28, ' Artist'), (3, 'Olivia Davis', 19, 'Writer'), (4, 'Noah Carter', 45, 'Chef'), (5, 'Eleanor White', 25, 'Doctor')]

Table: characters_more_diverse
Number of rows: 100
5 top rows: [(1, 'Alice', 28, 'Software Engineer'), (2, 'Bob', 45, 'Musician'), (3, 'Charlie', 32, 'Architect'), (4, 'Diana', 22, 'Photographer'), (5, 'Eli', 50, 'Writer')]

Table: characters20_more_diverse
Number of rows: 89
5 top rows: [(1, 'Ada Lovelace', 30, 'Mathematician'), (2, 'Alvin Ailey', 45, 'Choreographer'), (3, 'Bessie Smith', 40, 'Musician'), (4, 'Carl Sagan', 60, 'Astronomer'), (5, 'Dorothy Day', 50, 'Activist')]



The LLM is occasionally inconsistent: it does not always generate the exact number of rows requested, even with retries, so three of the four tables end up with fewer than 100 rows (98, 90, and 89). The LLM also tends to generate names comprising both a first name and a surname, even though this format is not explicitly requested. Interestingly, it frequently includes the names of well-known figures, such as Albert Einstein and Nikola Tesla, in its outputs.

<div style="
    background-color: #2C3E50; 
    color: #ECF0F1; 
    font-size: 28px; 
    text-align: center; 
    padding: 20px; 
    border-radius: 15px; 
    box-shadow: 0px 4px 8px rgba(0, 0, 0, 0.2); 
    width: 100%; 
    margin: auto; 
    font-family: Times New Roman, sans-serif;">
    3. Explore our tables using llm.

</div>
  
First, you need to generate and execute SQL queries that indicate the number of duplicate rows (without ids) in each character table. To make things easier, we only ask for the count of each group of duplicated rows.

Here is an example of expected results:

```
[(2,), (7,), (5,), (2,), (2,), (3,), (2,), (2,), (2,), (2,), (2,), (2,), (2,)]
```

<font color='green'> BONUS: Generate a query that returns the total count of duplicated rows. You may need to do this in several steps.</font>

Secondly, you need to generate and execute SQL queries that remove duplicate rows. To make things easier, it's not necessary to keep the original of each duplicated row. For example, given the list [a, b, a, c], we ask you to remove every a: [b, c].

<font color='green'> BONUS: Generate a query that deletes duplicates but keeps the original row. [a, b, a, c] -> [a, b, c] </font>

Finally, you need to generate and execute SQL queries that check whether the age constraint is respected.

<font color='red'> BE CAREFUL: Do each step for every characters table you have.</font>


> First, you need to generate and execute SQL queries that indicate the number of duplicate rows (without ids) in each character table. To make things easier, we only ask for the count of each group of duplicated rows.
> 
> Here is an example of expected results: [(2,), (7,), (5,), (2,), (2,), (3,), (2,), (2,), (2,), (2,), (2,), (2,), (2,)]
> 
> <font color='green'> BONUS: Generate a query that returns the total count of duplicated rows. You may need to do this in several steps.</font>

In [16]:
def detect_duplicate_rows_in_characters(
    database_name,
    character_tables=[
        "characters",
        "characters20",
        "characters_more_diverse",
        "characters20_more_diverse",
    ],
):
    # Prompts for listing the counts of each duplicate row (ignoring ID).
    base_duplicate_count_prompt = (
        "Write an SQL query to find all duplicate rows (ignoring the id column) in "
        "the '{table_name}' table. You should group by (name, age, profession) "
        "and only return the count for each group that appears more than once. "
        "Return only the count in each row (e.g. 2, 5, 7, etc.)."
    )
    duplicate_count_prompts = {
        table_name: base_duplicate_count_prompt.format(table_name=table_name)
        for table_name in character_tables
    }

    # Prompt for total duplicated rows (summed across all groups).
    # One could either sum up the counts from the previous query or
    # produce a separate query that calculates the total duplicates in one go
    base_total_duplicate_count_prompt = (
        "Write an SQL query to calculate the total number of duplicated rows in the "
        "'{table_name}' table, ignoring the 'id' column. Rows are considered duplicates if they "
        "have the same combination of values for the columns ('name', 'age', 'profession'). "
        "The query should:\n"
        "1. First, identify duplicate groups by grouping rows with the same ('name', 'age', 'profession') values as duplicate_count.\n"
        "2. Include only groups where the count of rows is greater than 1.\n"
        "3. Sum duplicate_count and return a single number as the result."
    )

    total_duplicate_count_prompts = {
        table_name: base_total_duplicate_count_prompt.format(table_name=table_name)
        for table_name in character_tables
    }

    # Connect to the database
    with sqlite3.connect(database_name) as conn:
        cursor = conn.cursor()

        total = 2 * len(character_tables)
        pbar = tqdm(total=total)
        for table_name in character_tables:
            # 1. Get the counts for each group of duplicates
            duplicates_prompt = duplicate_count_prompts[table_name]
            duplicates_query = generate_sql(duplicates_prompt, mode="SELECT")

            if duplicates_query:
                try:
                    print(f"--- Duplicates in '{table_name}' ---\n")
                    print(f"SQL Query:\n{duplicates_query}\n")
                    cursor.execute(duplicates_query)
                    row_counts = cursor.fetchall()
                    print(f"Result (counts of each group): {row_counts}\n")
                except Exception as e:
                    print(
                        f"Failed to execute duplicates query for '{table_name}'. Error: {e}.\n"
                    )
            else:
                print(f"Failed to generate duplicates query for '{table_name}'.\n")

            pbar.update(1)

            # 2. Bonus: total number of duplicate rows
            total_prompt = total_duplicate_count_prompts[table_name]
            total_query = generate_sql(total_prompt, mode="SELECT")

            if total_query:
                try:
                    print(f"--- Total Duplicates in '{table_name}' ---\n")
                    print(f"SQL Query:\n{total_query}\n")
                    cursor.execute(total_query)
                    total_count = cursor.fetchall()
                    print(f"Total duplicated rows: {total_count}\n")
                except Exception as e:
                    print(
                        f"Failed to execute total duplicates query for '{table_name}'. Error: {e}.\n"
                    )
            else:
                print(
                    f"Failed to generate total duplicates query for '{table_name}'.\n"
                )

            pbar.update(1)

    pbar.close()
    print("Duplicate detection completed.")
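
One detail of the total-count query is worth noting: when no group is duplicated, the inner query returns no rows, and `SUM` over an empty set yields NULL, which `sqlite3` reports as `None`. The hand-written sketch below, on a hypothetical `demo` table, wraps the sum in `COALESCE` so that the "no duplicates" case returns 0 instead. This is an optional variation, not the query the LLM produced.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE demo "
    "(id INTEGER PRIMARY KEY, name TEXT, age INTEGER, profession TEXT)"
)

insert_sql = "INSERT INTO demo (name, age, profession) VALUES (?, ?, ?)"

# COALESCE turns the NULL returned by SUM over zero groups into 0.
total_query = """
SELECT COALESCE(SUM(duplicate_count), 0)
FROM (
    SELECT COUNT(*) AS duplicate_count
    FROM demo
    GROUP BY name, age, profession
    HAVING COUNT(*) > 1
);
"""

# Two distinct rows: no duplicate groups exist yet.
conn.executemany(insert_sql, [("Alice", 25, "Artist"), ("Bob", 35, "Engineer")])
no_duplicates = conn.execute(total_query).fetchone()[0]

# Now Alice appears three times and Bob twice.
conn.executemany(
    insert_sql,
    [("Alice", 25, "Artist"), ("Alice", 25, "Artist"), ("Bob", 35, "Engineer")],
)
with_duplicates = conn.execute(total_query).fetchone()[0]

print(no_duplicates, with_duplicates)
conn.close()
```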

In [17]:
detect_duplicate_rows_in_characters(database_name)

  0%|          | 0/8 [00:00<?, ?it/s]

--- Duplicates in 'characters' ---

SQL Query:
SELECT COUNT(*) AS duplicate_count
FROM characters
GROUP BY name, age, profession
HAVING COUNT(*) > 1;

Result (counts of each group): []

--- Total Duplicates in 'characters' ---

SQL Query:
SELECT SUM(duplicate_count) AS total_duplicates
FROM (
    SELECT COUNT(*) AS duplicate_count
    FROM characters
    GROUP BY name, age, profession
    HAVING COUNT(*) > 1
) AS duplicate_groups;

Total duplicated rows: [(None,)]

--- Duplicates in 'characters20' ---

SQL Query:
SELECT COUNT(*) AS duplicate_count
FROM characters20
GROUP BY name, age, profession
HAVING COUNT(*) > 1;

Result (counts of each group): []

--- Total Duplicates in 'characters20' ---

SQL Query:
SELECT SUM(duplicate_count) AS total_duplicates
FROM (
    SELECT COUNT(*) AS duplicate_count
    FROM characters20
    GROUP BY name, age, profession
    HAVING COUNT(*) > 1
) AS duplicate_groups;

Total duplicated rows: [(None,)]

--- Duplicates in 'characters_more_diverse' ---

SQL
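The `[(None,)]` results in the output above are expected when a table has no duplicate groups: `SUM` over zero rows yields `NULL`, which `sqlite3` returns as `None`. The sketch below reproduces this on a throwaway in-memory table (the `demo` table and its rows are hypothetical) and shows one possible way to get `0` instead by wrapping the sum in `COALESCE`.

```python
import sqlite3

# Hypothetical in-memory table without any duplicated rows
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE demo (name TEXT, age INTEGER, profession TEXT)")
conn.executemany(
    "INSERT INTO demo VALUES (?, ?, ?)",
    [("a", 30, "x"), ("b", 40, "y")],
)

# Same grouping as the generated queries: one row per duplicate group
duplicate_groups = """
    SELECT COUNT(*) AS duplicate_count
    FROM demo
    GROUP BY name, age, profession
    HAVING COUNT(*) > 1
"""

# SUM over an empty set of groups is NULL (None in Python)
raw_total = conn.execute(
    f"SELECT SUM(duplicate_count) FROM ({duplicate_groups}) AS g"
).fetchone()[0]

# COALESCE replaces the NULL with 0
safe_total = conn.execute(
    f"SELECT COALESCE(SUM(duplicate_count), 0) FROM ({duplicate_groups}) AS g"
).fetchone()[0]

print(raw_total, safe_total)
conn.close()
```

Asking the model for a `COALESCE` in the prompt would be one option; handling `None` on the Python side is another.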

> Secondly, you need to generate and execute SQL queries that remove duplicate rows. To keep things simple, you do not need to keep one copy of each duplicated row. For example, given the list [a, b, a, c], remove every a to obtain [b, c].

In [18]:
def remove_all_duplicates_in_characters(
    database_name,
    character_tables=[
        "characters",
        "characters20",
        "characters_more_diverse",
        "characters20_more_diverse",
    ],
    commit=False,
):
    base_remove_all_duplicates_prompt = (
        "Write an SQL query to remove all rows from the '{table_name}' table that have "
        "duplicate values of (name, age, profession). If two or more rows share the same "
        "(name, age, profession), delete all rows with that combination. "
        "For example, if the data is [a, b, a, c], after deletion it should be [b, c]."
    )

    remove_all_duplicates_prompts = {
        table_name: base_remove_all_duplicates_prompt.format(table_name=table_name)
        for table_name in character_tables
    }

    with sqlite3.connect(database_name) as conn:
        cursor = conn.cursor()
        cursor.execute("BEGIN")
        pbar = tqdm(total=len(character_tables))
        for table_name in character_tables:

            count_duplicates = (
                "SELECT SUM(duplicate_count) AS total_duplicates\n"
                "FROM (\n"
                "    SELECT COUNT(*) AS duplicate_count\n"
                "    FROM {table_name}\n"
                "    GROUP BY name, age, profession\n"
                "    HAVING COUNT(*) > 1\n"
                ") AS duplicate_groups;"
            ).format(table_name=table_name)

            prompt = remove_all_duplicates_prompts[table_name]
            delete_query = generate_sql(prompt, mode="DELETE")

            if delete_query:
                try:
                    cursor.execute(f"SELECT * FROM {table_name}")
                    initial_table = cursor.fetchall()

                    print(f"--- Removing all duplicates in '{table_name}' ---\n")
                    print(f"SQL Query:\n{delete_query}\n")
                    cursor.execute(delete_query)

                    cursor.execute(count_duplicates)
                    output = cursor.fetchall()
                    print(f"Remaining duplicates: {output[0][0]}.\n")

                    cursor.execute(f"SELECT * FROM {table_name}")
                    final_table = cursor.fetchall()
                    print(f"Removed rows: {len(initial_table) - len(final_table)}")
                    print(f"Number of rows after removal: {len(final_table)}\n")
                except Exception as e:
                    print(
                        f"Failed to execute delete query for '{table_name}'. Error: {e}\n"
                    )
            else:
                print(f"Failed to generate delete query for '{table_name}'.\n")

            pbar.update(1)

        if not commit:
            cursor.execute("ROLLBACK")

        pbar.close()

    print("All-duplicates removal completed.")

In [19]:
remove_all_duplicates_in_characters(database_name)

  0%|          | 0/4 [00:00<?, ?it/s]

--- Removing all duplicates in 'characters' ---

SQL Query:
DELETE FROM characters
WHERE (name, age, profession) IN (
    SELECT name, age, profession
    FROM characters
    GROUP BY name, age, profession
    HAVING COUNT(*) > 1
);

Remaining duplicates: None.

Removed rows: 0
Number of rows after removal: 98

--- Removing all duplicates in 'characters20' ---

SQL Query:
DELETE FROM characters20
WHERE (name, age, profession) IN (
    SELECT name, age, profession
    FROM characters20
    GROUP BY name, age, profession
    HAVING COUNT(*) > 1
);

Remaining duplicates: None.

Removed rows: 0
Number of rows after removal: 90

--- Removing all duplicates in 'characters_more_diverse' ---

SQL Query:
DELETE FROM characters_more_diverse
WHERE (name, age, profession) IN (
    SELECT name, age, profession
    FROM characters_more_diverse
    GROUP BY name, age, profession
    HAVING COUNT(*) > 1
);

Remaining duplicates: None.

Removed rows: 0
Number of rows after removal: 100

--- Removing al

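As the cell output above shows, no duplicated rows remain: the count query returns `None` (a `SUM` over zero duplicate groups). Since zero rows were removed, the tables shown contained no duplicates to begin with.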
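Because our tables contain no duplicates, the run above does not exercise the generated `DELETE`. A minimal sketch, assuming a hypothetical in-memory `demo` table that mirrors the `[a, b, a, c]` example, can be used to check the same query shape independently of the LLM. The row-value syntax `(name, age, profession) IN (...)` requires SQLite 3.15 or newer.

```python
import sqlite3

# Hypothetical toy table reproducing the [a, b, a, c] example
conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE demo ("
    "id INTEGER PRIMARY KEY, name TEXT, age INTEGER, profession TEXT)"
)
conn.executemany(
    "INSERT INTO demo (name, age, profession) VALUES (?, ?, ?)",
    [("a", 30, "x"), ("b", 40, "y"), ("a", 30, "x"), ("c", 25, "z")],
)

# Same shape as the generated query: delete every row whose
# (name, age, profession) combination occurs more than once
conn.execute(
    """
    DELETE FROM demo
    WHERE (name, age, profession) IN (
        SELECT name, age, profession
        FROM demo
        GROUP BY name, age, profession
        HAVING COUNT(*) > 1
    )
    """
)

remaining_all = [row[0] for row in conn.execute("SELECT name FROM demo ORDER BY id")]
print(remaining_all)  # ['b', 'c']
conn.close()
```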

> <font color='green'> BONUS: Generate a query that deletes duplicates but keeps the original row: [a, b, a, c] -> [a, b, c]. </font>

In [20]:
def remove_duplicates_but_keep_original(
    database_name,
    character_tables=[
        "characters",
        "characters20",
        "characters_more_diverse",
        "characters20_more_diverse",
    ],
    commit=False,
):
    base_remove_keep_one_prompt = (
        "Write an SQL query to remove duplicate rows from '{table_name}'. "
        "If two or more rows share the same (name, age, profession), keep exactly one row and delete the rest. "
        "For example, if the data is [a, b, a, c], the final result should be [a, b, c]. Use whichever method "
        "is typical in SQLite (e.g., rowid or min(id)) to keep the earliest row inserted."
    )

    remove_keep_one_prompts = {
        table_name: base_remove_keep_one_prompt.format(table_name=table_name)
        for table_name in character_tables
    }

    with sqlite3.connect(database_name) as conn:
        cursor = conn.cursor()
        cursor.execute("BEGIN")
        pbar = tqdm(total=len(character_tables))
        for table_name in character_tables:

            count_duplicates = (
                "SELECT SUM(duplicate_count) AS total_duplicates\n"
                "FROM (\n"
                "    SELECT COUNT(*) AS duplicate_count\n"
                "    FROM {table_name}\n"
                "    GROUP BY name, age, profession\n"
                "    HAVING COUNT(*) > 1\n"
                ") AS duplicate_groups;"
            ).format(table_name=table_name)

            prompt = remove_keep_one_prompts[table_name]
            delete_query = generate_sql(prompt, mode="DELETE")
            if delete_query:
                try:
                    cursor.execute(f"SELECT * FROM {table_name}")
                    initial_table = cursor.fetchall()

                    print(
                        f"--- Removing duplicates but keeping one in '{table_name}' ---\n"
                    )
                    print(f"SQL Query:\n{delete_query}\n")
                    cursor.execute(delete_query)

                    cursor.execute(count_duplicates)
                    output = cursor.fetchall()
                    print(f"Remaining duplicates: {output[0][0]}.\n")

                    cursor.execute(f"SELECT * FROM {table_name}")
                    final_table = cursor.fetchall()
                    print(f"Removed rows: {len(initial_table) - len(final_table)}")
                    print(f"Number of rows after removal: {len(final_table)}\n")
                except Exception as e:
                    print(
                        f"Failed to execute delete query for '{table_name}'. Error: {e}\n"
                    )
            else:
                print(f"Failed to generate delete query for '{table_name}'.\n")

            pbar.update(1)

        if not commit:
            cursor.execute("ROLLBACK")

        pbar.close()

    print("Deduplication completed (keeping one row per duplicate group).")

In [21]:
remove_duplicates_but_keep_original(database_name, commit=True)

  0%|          | 0/4 [00:00<?, ?it/s]

--- Removing duplicates but keeping one in 'characters' ---

SQL Query:
WITH CTE AS (
    SELECT
        name,
        age,
        profession,
        ROWID,
        ROW_NUMBER() OVER (PARTITION BY name, age, profession ORDER BY ROWID) as rn
    FROM
        characters
)
DELETE FROM characters
WHERE ROWID NOT IN (
    SELECT ROWID
    FROM CTE
    WHERE rn = 1
);

Remaining duplicates: None.

Removed rows: 0
Number of rows after removal: 98

--- Removing duplicates but keeping one in 'characters20' ---

SQL Query:
WITH cte AS (
    SELECT
        name,
        age,
        profession,
        ROWID,
        ROW_NUMBER() OVER (PARTITION BY name, age, profession ORDER BY ROWID) as rn
    FROM
        characters20
)
DELETE FROM characters20
WHERE ROWID NOT IN (
    SELECT ROWID
    FROM cte
    WHERE rn = 1
);

Remaining duplicates: None.

Removed rows: 0
Number of rows after removal: 90

--- Removing duplicates but keeping one in 'characters_more_diverse' ---

SQL Query:
WITH cte AS (
 

As the cell output above shows, no duplicated rows remain (the duplicate count is again `None`, and no rows had to be removed).
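The model chose a `ROW_NUMBER()` window function, which needs SQLite 3.25 or newer. As an alternative sketch (not the query generated above), the same "keep the earliest row" behaviour can be obtained with `MIN(rowid)` per group. The `demo` table below is hypothetical and mirrors the `[a, b, a, c]` example.

```python
import sqlite3

# Hypothetical toy table reproducing the [a, b, a, c] example
conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE demo ("
    "id INTEGER PRIMARY KEY, name TEXT, age INTEGER, profession TEXT)"
)
conn.executemany(
    "INSERT INTO demo (name, age, profession) VALUES (?, ?, ?)",
    [("a", 30, "x"), ("b", 40, "y"), ("a", 30, "x"), ("c", 25, "z")],
)

# Keep the row with the smallest rowid in each (name, age, profession)
# group and delete all the others
conn.execute(
    """
    DELETE FROM demo
    WHERE rowid NOT IN (
        SELECT MIN(rowid)
        FROM demo
        GROUP BY name, age, profession
    )
    """
)

remaining_keep_one = [
    row[0] for row in conn.execute("SELECT name FROM demo ORDER BY id")
]
print(remaining_keep_one)  # ['a', 'b', 'c']
conn.close()
```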

> Finally, you need to generate and execute SQL queries that check if the age constraint is respected.

In [22]:
def check_age_constraints(
    database_name,
    character_tables=[
        "characters",
        "characters20",
        "characters_more_diverse",
        "characters20_more_diverse",
    ],
):
    # 1) Prompt to find all violation rows: if age < 18 or age > 50
    base_find_violations_prompt = (
        "Write an SQL query to find all rows in the '{table_name}' table "
        "where the 'age' is NOT between 18 and 50, inclusive. Return the full row "
        "(e.g. id, name, age, profession) for each violation."
    )

    find_violations_prompts = {
        table_name: base_find_violations_prompt.format(table_name=table_name)
        for table_name in character_tables
    }

    # 2) Prompt to count how many violations exist in each table
    base_count_violations_prompt = (
        "Write an SQL query to count how many rows in the '{table_name}' table "
        "have an 'age' outside the range [18, 50]. Return a single integer."
    )

    count_violations_prompts = {
        table_name: base_count_violations_prompt.format(table_name=table_name)
        for table_name in character_tables
    }

    with sqlite3.connect(database_name) as conn:
        cursor = conn.cursor()

        # Progress bar for convenience
        pbar = tqdm(total=2 * len(character_tables))

        for table_name in character_tables:
            # A) Find all violation rows
            violations_prompt = find_violations_prompts[table_name]
            find_query = generate_sql(violations_prompt, mode="SELECT")

            if find_query:
                try:
                    print(f"\n--- Checking age constraints in '{table_name}' ---\n")
                    print(f"Query to find violations:\n{find_query}\n")
                    cursor.execute(find_query)
                    violating_rows = cursor.fetchall()
                    if violating_rows:
                        print("Violations found:")
                        for row in violating_rows:
                            print(row)
                    else:
                        print("No violations found (all rows comply).")
                except Exception as e:
                    print(
                        f"Failed to execute violation query for '{table_name}'. Error: {e}"
                    )
            else:
                print(f"Failed to generate violation query for '{table_name}'.")

            pbar.update(1)

            # B) Count how many violation rows
            count_prompt = count_violations_prompts[table_name]
            count_query = generate_sql(count_prompt, mode="SELECT")

            if count_query:
                try:
                    print(f"\nQuery to count violations:\n{count_query}\n")
                    cursor.execute(count_query)
                    # fetchone() returns a 1-tuple, e.g. (3,) for 3 violations
                    count_result = cursor.fetchone()
                    print(f"Total violations in '{table_name}': {count_result[0]}\n")
                except Exception as e:
                    print(
                        f"Failed to execute count query for '{table_name}'. Error: {e}"
                    )
            else:
                print(f"Failed to generate count query for '{table_name}'.")

            pbar.update(1)

        pbar.close()

    print("Age constraint check completed.")

In [23]:
check_age_constraints(database_name)

  0%|          | 0/8 [00:00<?, ?it/s]


--- Checking age constraints in 'characters' ---

Query to find violations:
SELECT id, name, age, profession
FROM characters
WHERE age NOT BETWEEN 18 AND 50;

Violations found:
(15, 'Mahatma Gandhi', 70, 'Social Worker')

Query to count violations:
SELECT COUNT(*) AS count
FROM characters
WHERE age < 18 OR age > 50;

Total violations in 'characters': 1


--- Checking age constraints in 'characters20' ---

Query to find violations:
SELECT id, name, age, profession
FROM characters20
WHERE age NOT BETWEEN 18 AND 50;

Violations found:
(74, 'Ada Lovelace', 17, 'Mathematician')
(75, 'Albus Dumbledore', 115, 'Headmaster')
(76, 'Albert Einstein', 74, 'Physicist')
(78, 'Asimo', 3, 'Robot')
(81, 'Charles Darwin', 82, 'Naturalist')
(82, 'Cornelia Otis Skinner', 79, 'Actress')
(84, 'Douglas Adams', 60, 'Writer')
(87, 'Florence Nightingale', 90, 'Nurse')
(88, 'Gandhi', 90, 'Reformer')
(89, 'George Orwell', 68, 'Writer')
(90, 'Hank Scorpio', 999, 'Scientist')

Query to count violations:
SELECT COU

#### Check violations:

In [24]:
with sqlite3.connect(database_name) as conn:
    for table_name in ["characters", "characters20", "characters_more_diverse", "characters20_more_diverse"]:
        cursor = conn.cursor()
        cursor.execute(
            "SELECT id, name, age, profession\n"
            f"FROM {table_name}\n"
            "WHERE age NOT BETWEEN 18 AND 50;"
        )
        output = cursor.fetchall()
        print(f"Violations in '{table_name}': {len(output)}")

Violations in 'characters': 1
Violations in 'characters20': 11
Violations in 'characters_more_diverse': 1
Violations in 'characters20_more_diverse': 1
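One caveat about the check above: in SQL, `age NOT BETWEEN 18 AND 50` evaluates to `NULL` when `age` is `NULL`, so such rows are silently skipped. Whether our tables allow `NULL` ages depends on the schema created earlier, which is not repeated here, so the sketch below uses a hypothetical `demo` table to illustrate the difference.

```python
import sqlite3

# Hypothetical toy table with one row per case, including a missing age
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE demo (id INTEGER PRIMARY KEY, name TEXT, age INTEGER)")
conn.executemany(
    "INSERT INTO demo (name, age) VALUES (?, ?)",
    [("in_range", 30), ("too_young", 17), ("too_old", 70), ("unknown", None)],
)

# Same predicate as in the check above: NULL ages are not reported
strict = [
    row[0]
    for row in conn.execute(
        "SELECT name FROM demo WHERE age NOT BETWEEN 18 AND 50 ORDER BY id"
    )
]

# Variant that also treats a missing age as a violation
with_nulls = [
    row[0]
    for row in conn.execute(
        "SELECT name FROM demo "
        "WHERE age IS NULL OR age NOT BETWEEN 18 AND 50 ORDER BY id"
    )
]

print(strict)      # ['too_young', 'too_old']
print(with_nulls)  # ['too_young', 'too_old', 'unknown']
conn.close()
```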


<div style="
    background-color: #2C3E50; 
    color: #ECF0F1; 
    font-size: 28px; 
    text-align: center; 
    padding: 20px; 
    border-radius: 15px; 
    box-shadow: 0px 4px 8px rgba(0, 0, 0, 0.2); 
    width: 100%; 
    margin: auto; 
    font-family: Times New Roman, sans-serif;">
    4. More than one table with llm.

</div>

First, choose your best characters table (with the largest number of rows).

Second, generate and execute an SQL query that returns the set of unique professions in the table.

Third, generate and execute an SQL query that populates the skill tables from this set of unique professions.

Fourth, generate and execute an SQL query that verifies that the professions in the skill table exist in your characters table.

Finally, generate and execute an SQL query that returns the name of the skills associated with a character name (by profession).
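To make the expected results of these steps concrete, here is a minimal hand-written sketch on a toy database. The tables `toy_characters` and `toy_skills` and their rows are hypothetical; in the lab itself the queries are generated by the LLM, and the step that populates the skills table is replaced here by fixed inserts.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE toy_characters (
        id INTEGER PRIMARY KEY, name TEXT, age INTEGER, profession TEXT
    );
    CREATE TABLE toy_skills (
        id INTEGER PRIMARY KEY, name TEXT, profession TEXT
    );
    INSERT INTO toy_characters (name, age, profession) VALUES
        ('Ada', 36, 'Mathematician'),
        ('Bob', 41, 'Chef');
    INSERT INTO toy_skills (name, profession) VALUES
        ('Proof Writing', 'Mathematician'),
        ('Knife Work', 'Chef'),
        ('Orbit Planning', 'Astronaut');
    """
)

# Set of unique professions in the characters table
professions = [
    row[0]
    for row in conn.execute(
        "SELECT DISTINCT profession FROM toy_characters ORDER BY profession"
    )
]

# Professions in the skills table with no matching character
orphans = [
    row[0]
    for row in conn.execute(
        """
        SELECT DISTINCT s.profession
        FROM toy_skills AS s
        LEFT JOIN toy_characters AS c ON c.profession = s.profession
        WHERE c.id IS NULL
        """
    )
]

# Skills associated with each character, matched by profession
pairs = conn.execute(
    """
    SELECT c.name AS character_name, s.name AS skill_name
    FROM toy_characters AS c
    JOIN toy_skills AS s ON s.profession = c.profession
    ORDER BY c.name
    """
).fetchall()

print(professions)  # ['Chef', 'Mathematician']
print(orphans)      # ['Astronaut']
print(pairs)        # [('Ada', 'Proof Writing'), ('Bob', 'Knife Work')]
conn.close()
```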

In [25]:
from itertools import islice


# Yield successive lists of at most `chunk_size` professions
def chunk_professions(professions, chunk_size=10):
    it = iter(professions)
    while chunk := list(islice(it, chunk_size)):
        yield chunk

In [27]:
def sync_skills_with_characters(database_name, commit=False):
    character_tables = [
        "characters",
        "characters20",
        "characters_more_diverse",
        "characters20_more_diverse",
    ]

    with sqlite3.connect(database_name) as conn:
        cursor = conn.cursor()

        # --------------------------------------------------------------------------------
        # STEP 1: Choose your "best" (largest) characters table
        # --------------------------------------------------------------------------------
        largest_table = None
        largest_count = -1

        for table_name in character_tables:
            cursor.execute(f"SELECT COUNT(*) FROM {table_name}")
            row_count = cursor.fetchall()[0][0]
            if row_count > largest_count:
                largest_count = row_count
                largest_table = table_name

        print(f"Chosen table '{largest_table}' with {largest_count} rows.\n")

        # --------------------------------------------------------------------------------
        # STEP 2: Generate & execute a query to get DISTINCT professions
        # --------------------------------------------------------------------------------
        profession_prompt = (
            "Write a single SQL SELECT statement that returns the DISTINCT professions "
            f"from the '{largest_table}' table."
        )
        distinct_profession_query = generate_sql(
            profession_prompt,
            mode="SELECT",
            generation_config=generation_create_config,
        )

        if distinct_profession_query:
            print(f"--- Getting unique professions from '{largest_table}' ---\n")
            print(f"SQL Query:\n{distinct_profession_query}\n")
            cursor.execute(distinct_profession_query)
            professions = [row[0] for row in cursor.fetchall()]
            print("Unique professions:", professions, "\n")
        else:
            print("Failed to generate query for distinct professions.")
            return  # Exit early if the query cannot be generated

        # --------------------------------------------------------------------------------
        # STEP 3: Populate the skills table using the set of unique professions
        # --------------------------------------------------------------------------------
        skils_schema = ", ".join(
            [
                f"{col[1]} ({col[2]})"
                for col in cursor.execute("PRAGMA table_info(skills);").fetchall()
                if col[1] != "id"
            ]
        )

        # DISCLAIMER: The following implementation populates the `skills` table using a single
        # SQL query, as requested. However, this approach has proven to be prone to hallucinations
        # (e.g., syntax errors or incorrect interpretations by the LLM). For this reason, we have
        # opted to comment it out while still providing it for future reference. Instead, we now
        # populate the `skills` table in smaller chunks of unique professions to ensure reliability.

        # joined_professions = ", ".join(f"'{p}'" for p in professions)
        # populate_skills_prompt = (
        #     "Write a single SQL INSERT statement that inserts one row into the 'skills' table for each "
        #     f"of these professions: {joined_professions}. The table schema is {skils_schema}. "
        #     "Be creative for the skill name."
        # )

        # populate_skills_query = generate_sql(
        #     populate_skills_prompt,
        #     mode="INSERT",
        #     generation_config=generation_long_insert_config,
        # )

        # if populate_skills_query:
        #     print("--- Populating 'skills' table ---\n")
        #     print(f"SQL Query:\n{populate_skills_query}\n")
        #     try:
        #         cursor.execute(populate_skills_query)
        #         if commit:
        #             conn.commit()
        #         print("Successfully populated 'skills' table.\n")
        #     except Exception as e:
        #         print(f"Error executing skills population query: {e}\n")
        #         return
        # else:
        #     print("Failed to generate skills population query.\n")
        #     return

        for profession_chunk in chunk_professions(professions):
            joined_professions = ", ".join(f"'{p}'" for p in profession_chunk)

            populate_skills_prompt = (
                "Write a single SQL INSERT statement that inserts one row into the 'skills' table for each "
                f"of these professions: {joined_professions}. The table schema is {skils_schema}. "
                "Be creative for the skill name and ensure each skill aligns with the profession."
            )

            populate_skills_query = generate_sql(
                populate_skills_prompt,
                mode="INSERT",
                generation_config=generation_long_insert_config,
            )

            if populate_skills_query:
                print("--- Populating 'skills' table ---\n")
                print(f"SQL Query:\n{populate_skills_query}\n")
                try:
                    cursor.execute(populate_skills_query)
                    if commit:
                        conn.commit()
                    print(f"Successfully populated 'skills' table for professions: {profession_chunk}\n")
                except Exception as e:
                    print(f"Error executing skills population query for professions {profession_chunk}: {e}\n")
                    continue
            else:
                print(f"Failed to generate skills population query for professions {profession_chunk}.\n")
                continue

        # --------------------------------------------------------------------------------
        # STEP 4: Verify that professions in 'skills' exist in the chosen table
        # --------------------------------------------------------------------------------
        largest_table_schema = ", ".join(
            [
                f"{col[1]} ({col[2]})"
                for col in cursor.execute(
                    f"PRAGMA table_info({largest_table});"
                ).fetchall()
                if col[1] != "id"
            ]
        )
        base_verify_prompt = (
            "Write a SQL SELECT statement that returns any professions in 'skills' "
            f"that do NOT exist in '{largest_table}'. The 'skills' table schema is "
            f"{skils_schema}. The '{largest_table}' table schema is {largest_table_schema}. "
            f"Only return the 'profession' column from 'skills' that has no match in '{largest_table}'."
        )

        verify_query = generate_sql(
            base_verify_prompt,
            mode="SELECT",
            generation_config=generation_create_config,
        )

        if verify_query:
            print("--- Verifying 'skills' professions ---\n")
            print(f"SQL Query:\n{verify_query}\n")
            cursor.execute(verify_query)
            missing_professions = cursor.fetchall()
            if missing_professions:
                print(
                    "Found professions in 'skills' that are missing in the character table:"
                )
                for mp in missing_professions:
                    print("  Missing profession:", mp[0])
                print("\n")
            else:
                print("All 'skills' professions exist in the character table.\n")
        else:
            print("Failed to generate verification query.\n")

        # --------------------------------------------------------------------------------
        # STEP 5: Return the name of the skills associated with a character name
        #         (by matching profession)
        # --------------------------------------------------------------------------------
        base_skill_by_char_prompt = (
            "Write a single SQL SELECT statement that returns pairs of (c.name, s.name) where "
            f"c.name is the character's name in '{largest_table}' and s.name is the skill's name "
            "in 'skills'. Match them by profession, i.e. c.profession = s.profession. "
            "Return columns as (character_name, skill_name)."
        )

        skill_by_char_query = generate_sql(
            base_skill_by_char_prompt,
            mode="SELECT",
            generation_config=generation_create_config,
        )

        if skill_by_char_query:
            print("--- Getting skills by character name ---\n")
            print(f"SQL Query:\n{skill_by_char_query}\n")
            cursor.execute(skill_by_char_query)
            char_skill_pairs = cursor.fetchall()
            if char_skill_pairs:
                print("Character-Skill pairs found:")
                for pair in char_skill_pairs:
                    print(pair)
                print("\n")
            else:
                print("No Character-Skill pairs found.\n")
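        else:
            print("Failed to generate skill-by-character query.\n")

        # Leaving the `with` block commits any open transaction, so undo the
        # inserts explicitly when the caller did not ask for them to be kept.
        if not commit:
            conn.rollback()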

    print("Finished syncing skills with characters.")

In the function above, we populate the `skills` table in chunks of at most 10 professions (the default `chunk_size`) rather than with one large `INSERT`. This reduces the risk of LLM errors during SQL query generation, such as syntax errors or misinterpretation of the task.
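As a quick illustration of the chunking, the helper is repeated below so the snippet runs on its own (the `:=` operator needs Python 3.8+). The profession names are placeholders. With 31 professions, as in the run that follows, this yields four chunks and therefore four generated `INSERT` statements.

```python
from itertools import islice


def chunk_professions(professions, chunk_size=10):
    # Consume the iterator chunk_size items at a time until it is exhausted
    it = iter(professions)
    while chunk := list(islice(it, chunk_size)):
        yield chunk


# Placeholder names standing in for the real list of professions
professions = [f"profession_{i}" for i in range(31)]
chunk_sizes = [len(chunk) for chunk in chunk_professions(professions)]
print(chunk_sizes)  # [10, 10, 10, 1]
```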

In [28]:
sync_skills_with_characters(database_name)

Chosen table 'characters_more_diverse' with 100 rows.

--- Getting unique professions from 'characters_more_diverse' ---

SQL Query:
SELECT DISTINCT profession
FROM characters_more_diverse;

Unique professions: ['Software Engineer', 'Musician', 'Architect', 'Photographer', 'Writer', 'Teacher', 'Graphic Designer', 'Actor', 'Scientist', 'Artist', 'Chef', 'Actress', 'Doctor', 'Nurse', 'Computer Scientist', 'Physicist', 'Activist', 'Judge', 'Author', 'Politician', 'Designer', 'Journalist', 'Entrepreneur', 'Data Scientist', 'Dancer', 'Historian', 'Librarian', 'Mathematician', 'Fashion Designer', 'Astronaut', 'Athlete'] 

--- Populating 'skills' table ---

SQL Query:
INSERT INTO skills (name, profession)
VALUES
('Code Crafting', 'Software Engineer'),
('Melody Mastery', 'Musician'),
('Blueprint Excellence', 'Architect'),
('Cliched Shutter', 'Photographer'),
('Pen Prose', 'Writer'),
('Educational Empowerment', 'Teacher'),
('Visual Visionary', 'Graphic Designer'),
('Emotional Embodiment', 'Acto