# Formal Semantic Parsing with Language Models

This notebook provides some helper functions, along with functions that you need to implement. You are not required to use them.

In this assignment, we evaluate the ability of language models on a semantic parsing task, specifically SQL parsing (text-to-SQL). The assignment has three parts:


1.   Basic Prompt
2.   Fine-tuning
3.   Context-Free Grammar

In each part, you will examine the model output in terms of correctness and well-formedness.
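
The two criteria can be told apart mechanically. As a minimal sketch (not one of the provided helpers), the hypothetical function `check_query` below treats a query as well-formed if SQLite can compile and execute it, and as correct if its result set matches that of the gold query. The in-memory `state` table and its rows are made up purely for illustration.

```python
import sqlite3


def check_query(conn, generated_sql, gold_sql):
    """Return (well_formed, correct) for a generated query.

    well_formed: SQLite can compile and execute the query.
    correct: its result set equals the gold query's result set (row order ignored).
    """
    gold_rows = set(conn.execute(gold_sql).fetchall())
    try:
        rows = set(conn.execute(generated_sql).fetchall())
    except sqlite3.Error:
        return False, False
    return True, rows == gold_rows


# Toy database for illustration only; the values are invented.
toy_conn = sqlite3.connect(':memory:')
toy_conn.execute('CREATE TABLE state (state_name TEXT, area REAL)')
toy_conn.executemany('INSERT INTO state VALUES (?, ?)', [('texas', 10.0), ('ohio', 5.0)])

gold = "SELECT area FROM state WHERE state_name = 'texas' ;"
print(check_query(toy_conn, gold, gold))                        # well-formed and correct
print(check_query(toy_conn, 'SELECT area FROM state ;', gold))  # well-formed, wrong answer
print(check_query(toy_conn, 'SELEC area FROM state ;', gold))   # not well-formed
```

A query can therefore be well-formed without being correct, which is why the notebook reports a syntax rate separately from precision, recall, and exact match.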

Parts 2 and 3 can be replaced by other approaches, such as retrieval-augmented generation (RAG). Students are welcome to propose their own ideas for solving this task.
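
As one illustration of the retrieval idea, the sketch below selects the most similar question and SQL pairs from a small pool, which could then be placed in the prompt as few-shot examples. It is only an assumption about how such an approach might start: the function names `jaccard` and `retrieve_examples` are hypothetical, similarity is plain token overlap rather than learned embeddings, and the pool reuses three pairs from the few-shot prompt defined in this notebook.

```python
def jaccard(a, b):
    """Token-overlap similarity between two whitespace-tokenized questions."""
    set_a, set_b = set(a.lower().split()), set(b.lower().split())
    union = set_a | set_b
    return len(set_a & set_b) / len(union) if union else 0.0


def retrieve_examples(question, pool, k=2):
    """Return the k (question, sql) pairs from pool most similar to question."""
    ranked = sorted(pool, key=lambda pair: jaccard(question, pair[0]), reverse=True)
    return ranked[:k]


# Pool of (question, sql) pairs copied from this notebook's few-shot prompt.
pool = [
    ('how big is texas ?',
     'SELECT STATEalias0.AREA FROM STATE AS STATEalias0 WHERE STATEalias0.STATE_NAME = "texas" ;'),
    ('how many people live in washington ?',
     'SELECT STATEalias0.POPULATION FROM STATE AS STATEalias0 WHERE STATEalias0.STATE_NAME = "washington" ;'),
    ('what state has the largest population ?',
     'SELECT STATEalias0.STATE_NAME FROM STATE AS STATEalias0 WHERE STATEalias0.POPULATION = '
     '( SELECT MAX( STATEalias1.POPULATION ) FROM STATE AS STATEalias1 ) ;'),
]

for retrieved_question, retrieved_sql in retrieve_examples('how many people live in texas ?', pool):
    print(f'Question: {retrieved_question}\nSQL: {retrieved_sql}')
```

In a fuller version, the pool would be the training split and the retrieved pairs would replace the fixed examples in the prompt.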

References:

- https://github.com/jkkummerfeld/text2sql-data/
- https://github.com/epfl-dlab/transformers-CFG


In [1]:
context = '''You are an SQL database query expert, specialized in generating correct and efficient SQL query statements for complex databases, particularly for the SQLite database.
Your task is to generate SQL queries that you must follow the following requirements:
1. Follow the provided database schema; no assumptions about other tables.
2. Generate a single SQL query, strictly compatible with SQLite.
3. Use ANSI SQL syntax suitable for SQLite.
4. Output only the SQL query, prefixed with `SQL:` and ending with `;`.
5. Do not include any additional text or questions.
Below is the schema of the database, including the structure and column descriptions of 7 tables storing geographic information of the United States.
Table 1: state ( state_name, population, area, country_name, capital, density )
Table 2: city ( city_name, population, country_name, state_name )
Table 3: border_info ( state_name, border )
Table 4: highlow ( state_name, highest_elevation, lowest_point, highest_point, lowest_elevation )
Table 5: lake ( lake_name, area, country_name, state_name )
Table 6: mountain ( mountain_name, mountain_altitude, country_name, state_name )
Table 7: river ( river_name, length, country_name, traverse )
Here are some examples of natural language questions and their corresponding SQL queries:
Question: how big is texas ?
SQL: SELECT STATEalias0.AREA FROM STATE AS STATEalias0 WHERE STATEalias0.STATE_NAME = "texas" ;
Question: what is the area of california ?
SQL: SELECT STATEalias0.AREA FROM STATE AS STATEalias0 WHERE STATEalias0.STATE_NAME = "california" ;
Question: how many people live in washington ?
SQL: SELECT STATEalias0.POPULATION FROM STATE AS STATEalias0 WHERE STATEalias0.STATE_NAME = "washington" ;
Question: what state has the smallest population ?
SQL: SELECT STATEalias0.STATE_NAME FROM STATE AS STATEalias0 WHERE STATEalias0.POPULATION = ( SELECT MIN( STATEalias1.POPULATION ) FROM STATE AS STATEalias1 ) ;
Question: what state has the largest population ?
SQL: SELECT STATEalias0.STATE_NAME FROM STATE AS STATEalias0 WHERE STATEalias0.POPULATION = ( SELECT MAX( STATEalias1.POPULATION ) FROM STATE AS STATEalias1 ) ;
Question: which rivers run through the state with the largest city in the us ?
SQL: SELECT RIVERalias0.RIVER_NAME FROM RIVER AS RIVERalias0 WHERE RIVERalias0.TRAVERSE IN ( SELECT CITYalias0.STATE_NAME FROM CITY AS CITYalias0 WHERE CITYalias0.POPULATION = ( SELECT MAX( CITYalias1.POPULATION ) FROM CITY AS CITYalias1 ) ) ;
Question: what is the area of the state with the capital albany ?
SQL: SELECT STATEalias0.AREA FROM STATE AS STATEalias0 WHERE STATEalias0.CAPITAL = "albany" ;
Question: which states have points higher than the highest point in colorado ?
SQL: SELECT HIGHLOWalias0.STATE_NAME FROM HIGHLOW AS HIGHLOWalias0 WHERE HIGHLOWalias0.HIGHEST_ELEVATION > ( SELECT HIGHLOWalias1.HIGHEST_ELEVATION FROM HIGHLOW AS HIGHLOWalias1 WHERE HIGHLOWalias1.STATE_NAME = "colorado" ) ;
Question: what are the highest points of states surrounding mississippi ?
SQL: SELECT HIGHLOWalias0.HIGHEST_POINT FROM HIGHLOW AS HIGHLOWalias0 WHERE HIGHLOWalias0.STATE_NAME IN ( SELECT BORDER_INFOalias0.BORDER FROM BORDER_INFO AS BORDER_INFOalias0 WHERE BORDER_INFOalias0.STATE_NAME = "mississippi" ) ;
Question: which states do colorado river flow through ?
SQL: SELECT RIVERalias0.TRAVERSE FROM RIVER AS RIVERalias0 WHERE RIVERalias0.RIVER_NAME = "colorado" ;
Now, please generate the SQL query for the following question:'''

making_prompt = lambda x: context.replace('\n', ' ') + ' ' + x + ' ?'

In [2]:
import json


def extract_sentence_fields(sentence):
    text = sentence['text']
    variables = sentence['variables']
    split = sentence['question-split']
    return text, variables, split


def insert_variables(sql, sql_variables, sent, sent_variables):
    for info in sql_variables:
        name = info['name']
        value = info['example']
        if name in sent_variables and sent_variables[name] != '':
            value = sent_variables[name]
        sent = value.join(sent.split(name))
        qvalue = '{}'.format(value)
        sql = qvalue.join(sql.split(name))
    return sql, sent


def build_question_split(jsons, making_prompt=making_prompt, keep_variables=False):
    datasets = {}
    for json_dict in jsons:
        for query in [json_dict['sql'][0]]:
            sql_vars = json_dict['variables']
            for sentence in json_dict['sentences']:
                text, variables, split = extract_sentence_fields(sentence)
                if split == 'exclude':
                    continue
                if keep_variables:
                    sql = query
                    question = text
                else:
                    sql, question = insert_variables(query, sql_vars, text, variables)
                if split not in datasets:
                    datasets[split] = []
                example = {'text': making_prompt(question) + ' ' + sql, 'question': question, 'sql': sql}
                datasets[split].append(example)
    return datasets


with open('geography.json', 'r') as file:
    geography_data = json.load(file)
    geography_datasets = build_question_split(geography_data)

In [3]:
import sqlite3

# .sqlite file path
file_path = 'geography-db.added-in-2020.sqlite'


def load_sqlite_file(file_path):
    try:
        # Establish a connection to the database
        conn = sqlite3.connect(file_path)
        print(f'Loaded database from {file_path}')
        return conn
    except sqlite3.Error as e:
        print(f'Error loading database: {e}')
        return None


conn = load_sqlite_file(file_path)
cursor = conn.cursor()

# Verify sql statements and delete statements that cannot be executed
for split in geography_datasets:
    for example in geography_datasets[split][:]:
        try:
            cursor.execute(example['sql'])
        except sqlite3.Error as e:
            print(split)  # Shows which part has been removed
            geography_datasets[split].remove(example)

conn.close()

Loaded database from geography-db.added-in-2020.sqlite
dev
test
test
train
train


In [4]:
def get_all_results(dataset, cursor):
    gold_answers = cursor.execute(dataset['sql']).fetchall()
    dataset['answers'] = [list(row) for row in gold_answers]
    dataset['generated_sql'] = ''
    dataset['generated_answers'] = [['']]


conn = load_sqlite_file(file_path)
cursor = conn.cursor()
for split in geography_datasets:
    for dataset in geography_datasets[split]:
        get_all_results(dataset, cursor)

# Don't forget to close the connection when you're done
conn.close()

Loaded database from geography-db.added-in-2020.sqlite


In [5]:
geography_datasets['dev'][0]

{'text': 'You are an SQL database query expert, specialized in generating correct and efficient SQL query statements for complex databases, particularly for the SQLite database. Your task is to generate SQL queries that you must follow the following requirements: 1. Follow the provided database schema; no assumptions about other tables. 2. Generate a single SQL query, strictly compatible with SQLite. 3. Use ANSI SQL syntax suitable for SQLite. 4. Output only the SQL query, prefixed with `SQL:` and ending with `;`. 5. Do not include any additional text or questions. Below is the schema of the database, including the structure and column descriptions of 7 tables storing geographic information of the United States. Table 1: state ( state_name, population, area, country_name, capital, density ) Table 2: city ( city_name, population, country_name, state_name ) Table 3: border_info ( state_name, border ) Table 4: highlow ( state_name, highest_elevation, lowest_point, highest_point, lowest_el

In [6]:
def compare_results(dataset):
    tp, fp, fn, tn = 0, 0, 0, 0
    exact_match = 0

    for generated, actual in zip(dataset['generated_answers'], dataset['answers']):

        if generated == [['SQL Error']] or generated == [['']]:
            fn += len(actual)
            continue

        generated_set = set(tuple(row) for row in generated)
        actual_set = set(tuple(row) for row in actual)

        if generated_set == actual_set:
            exact_match += 1

        tp += len(generated_set & actual_set)

        fp += len(generated_set - actual_set)

        fn += len(actual_set - generated_set)

    return tp, tn, fp, fn, exact_match, exact_match / len(dataset)

In [7]:
def sql_syntax_correct_rate(dataset):
    wrong = 0

    for syntax in dataset['generated_sql']:

        if syntax == 'SQL syntax error.' or syntax == 'The SQL cannot be executed.':
            wrong += 1

    return (len(dataset) - wrong) / len(dataset)

In [8]:
def calculate_metrics(tp, fp, fn):

    precision = tp / (tp + fp) if (tp + fp) > 0 else 0

    recall = tp / (tp + fn) if (tp + fn) > 0 else 0

    f1 = (2 * precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

    return precision, recall, f1

In [9]:
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
print(f'Using device: {device}')

Using device: cuda


In [10]:
# To be completed: should return micro/macro precision, recall, F1, exact match, and grammatical ratio
from transformers import GenerationConfig
from transformers_cfg.grammar_utils import IncrementalGrammarConstraint
from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor


def evaluate(dataset, model, tokenizer, making_prompt=making_prompt, cfg=True, cfg_enbf='sql_query.ebnf'):
    model.eval()

    generation_config = GenerationConfig(
        max_new_tokens=128,  # Maximum number of tokens to generate, excluding input tokens
    )

    temp_dataset = []

    with torch.no_grad():
        for index, example in enumerate(dataset):
            question = example['question']

            # Generate SQL from model
            prompt = making_prompt(question)
            inputs = tokenizer(prompt, return_tensors='pt').to(device)
            if cfg:
                # Load the EBNF grammar that constrains generation to SQL queries
                with open(cfg_enbf, 'r') as file:
                    grammar_str = file.read()
                grammar = IncrementalGrammarConstraint(grammar_str, 'query', tokenizer)
                grammar_processor = GrammarConstrainedLogitsProcessor(grammar)
                grammar_processors = [grammar_processor]
            else:
                grammar_processors = None

            output = model.generate(
                **inputs,  # Input tokens and other configurations for the model
                generation_config=generation_config,  # Use predefined generation configuration
                logits_processor=grammar_processors,  # Constrain generation with grammar rules
            )

            # Decode output
            generated_sql = tokenizer.batch_decode(output, skip_special_tokens=True)[0][len(prompt):]

            try:
                conn = sqlite3.connect(file_path)

                start_index = generated_sql.find('SQL:')
                if start_index != -1:
                    start_index = start_index + len('SQL:')
                else:
                    start_index = generated_sql.find('SELECT')
                    if start_index == -1:
                        print(generated_sql)
                        raise SyntaxError

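                end_index = generated_sql.find(';', start_index)
                if end_index == -1:
                    end_index = generated_sql.find('\n', start_index)
                if end_index == -1:
                    # No terminator found: keep the rest of the output instead of dropping its last character
                    end_index = len(generated_sql)

                generated_sql = generated_sql[start_index:end_index].strip() + ' ;'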

                print(generated_sql)

                generated_answer = conn.cursor().execute(generated_sql).fetchall()
                if not generated_answer:
                    raise ValueError
                example['generated_answers'] = [list(row) for row in generated_answer]

            except sqlite3.Error:
                generated_sql = 'The SQL cannot be executed.'
                example['generated_answers'] = [['SQL Error']]
            except SyntaxError:
                generated_sql = 'SQL syntax error.'
                example['generated_answers'] = [['SQL Error']]
            except ValueError:
                generated_sql = 'The search result does not exist.'
                example['generated_answers'] = [['']]
            finally:
                example['generated_sql'] = generated_sql
                if conn:
                    conn.close()

            temp_dataset.append(example)
            print(f'Generation {index + 1}: {generated_sql}\n')

    return temp_dataset

In [11]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import Dataset, DatasetDict, Features, Value, Sequence
from trl import SFTTrainer, SFTConfig

features = Features({
    'text': Value('string'),
    'question': Value('string'),
    'sql': Value('string'),
    'answers': Sequence(Sequence(Value('string'))),
    'generated_sql': Value('string'),
    'generated_answers': Sequence(Sequence(Value('string'))),
})

# Load your custom dataset
train_data = Dataset.from_list(geography_datasets['train'], features=features)
dev_data = Dataset.from_list(geography_datasets['dev'], features=features)
test_data = Dataset.from_list(geography_datasets['test'], features=features)

dataset = DatasetDict({'train': train_data, 'dev': dev_data, 'test': test_data})




In [12]:
# Load the model and tokenizer
model_name = 'HuggingFaceTB/SmolLM2-360M-Instruct'  # or any other model
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Add a padding token to the tokenizer
tokenizer.pad_token = tokenizer.eos_token  # Use eos_token as pad_token

In [13]:
print(dataset['dev'][0])

{'text': 'You are an SQL database query expert, specialized in generating correct and efficient SQL query statements for complex databases, particularly for the SQLite database. Your task is to generate SQL queries that you must follow the following requirements: 1. Follow the provided database schema; no assumptions about other tables. 2. Generate a single SQL query, strictly compatible with SQLite. 3. Use ANSI SQL syntax suitable for SQLite. 4. Output only the SQL query, prefixed with `SQL:` and ending with `;`. 5. Do not include any additional text or questions. Below is the schema of the database, including the structure and column descriptions of 7 tables storing geographic information of the United States. Table 1: state ( state_name, population, area, country_name, capital, density ) Table 2: city ( city_name, population, country_name, state_name ) Table 3: border_info ( state_name, border ) Table 4: highlow ( state_name, highest_elevation, lowest_point, highest_point, lowest_el

In [14]:
temp_dataset_dev = evaluate(dataset['dev'], model, tokenizer, cfg=False)

SELECT CITYalias0.NAME FROM CITY AS CITYalias0 WHERE CITYalias0.CITY_NAME = "albuquerque" ;
Generation 1: The SQL cannot be executed.

SELECT STATEalias0.POPULATION FROM STATE AS STATEalias0 WHERE STATEalias0.POPULATION = ( SELECT MAX( STATEalias1.POPULATION ) FROM STATE AS STATEalias1 WHERE STATEalias1.STATE_NAME = "texas" ) ;
Generation 2: SELECT STATEalias0.POPULATION FROM STATE AS STATEalias0 WHERE STATEalias0.POPULATION = ( SELECT MAX( STATEalias1.POPULATION ) FROM STATE AS STATEalias1 WHERE STATEalias1.STATE_NAME = "texas" ) ;

SELECT CITYalias0.NAME FROM CITY AS CITYalias0 WHERE CITYalias0.CITY_NAME = "mo" ;
Generation 3: The SQL cannot be executed.

SELECT RIVERalias0.RIVER_NAME FROM RIVER AS RIVERalias0 WHERE RIVERalias0.TRAVERSE IN ( SELECT CITYalias0.STATE_NAME FROM CITY AS CITYalias0 WHERE CITYalias0.POPULATION = ( SELECT MAX( CITYalias1.POPULATION ) FROM CITY AS CITYalias1 WHERE CITYalias1.STATE_NAME = "washington" ) ) ;
Generation 4: SELECT RIVERalias0.RIVER_NAME FROM RIV

In [15]:
dev_data = Dataset.from_list(temp_dataset_dev, features=features)
dataset = DatasetDict({'train': train_data, 'dev': dev_data, 'test': test_data})
tp, tn, fp, fn, exact_match, exact_match_rate = compare_results(dataset['dev'])
syntax_rate = sql_syntax_correct_rate(dataset['dev'])
precision, recall, f1 = calculate_metrics(tp, fp, fn)

print(f'tp: {tp}, tn: {tn}, fp: {fp}, fn: {fn}, exact match: {exact_match}, exact match rate: {exact_match_rate}')
print(f'sql syntax rate: {syntax_rate}')
print(f'precision: {precision}, recall: {recall}, f1 score: {f1}')

tp: 9, tn: 0, fp: 37, fn: 171, exact match: 8, exact match rate: 0.16666666666666666
sql syntax rate: 0.4375
precision: 0.1956521739130435, recall: 0.05, f1 score: 0.07964601769911504
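
The counts printed above are pooled over all dev questions before the scores are computed, so the reported precision, recall, and F1 are micro-averages. A macro-average would instead score each question separately and then average the scores. The sketch below contrasts the two; `prf` and `micro_macro` are hypothetical helper names, and the per-question counts in the second example are invented for illustration.

```python
def prf(tp, fp, fn):
    """Precision, recall, and F1 from raw counts, with 0.0 for empty denominators."""
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


def micro_macro(per_example_counts):
    """per_example_counts: non-empty list of (tp, fp, fn) tuples, one per question."""
    # Micro: pool the counts first, then compute a single score.
    micro = prf(*(sum(column) for column in zip(*per_example_counts)))
    # Macro: score each question on its own, then average the scores.
    scores = [prf(*counts) for counts in per_example_counts]
    macro = tuple(sum(values) / len(scores) for values in zip(*scores))
    return micro, macro


# The pooled counts from the output above reproduce the printed scores.
print(prf(9, 37, 171))

# Invented counts: one question answered exactly, one answered with nine extra rows.
micro, macro = micro_macro([(1, 0, 0), (1, 9, 0)])
print(f'micro precision: {micro[0]:.3f}, macro precision: {macro[0]:.3f}')
```

Micro-averaging lets questions with many result rows dominate the score, while macro-averaging weights every question equally, so reporting both gives a fuller picture.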


In [19]:
temp_dataset_dev = evaluate(dataset['dev'], model, tokenizer)

SELECT	CITY.CITY_NAME	FROM	CITY	WHERE	CITY.CITY_NAME	=	"chances" ;
Generation 1: The search result does not exist.

SELECT	STATE.STATE_NAME	FROM	STATE	WHERE	STATE.STATE_NAME	=	"texas"	AND	STATE.POPULATION	=	(SELECT	MAX(	STATE.POPULATION)	FROM	STATE	WHERE	STATE.STATE_NAME	=	"texas"	)	AND	STATE.CAPITAL	=	"alexandria"	AND	STATE.AREA	=	(SELECT	STATE.AREA	FROM	STATE	WHERE	STATE.STATE_NAME	=	"te ;
Generation 2: The SQL cannot be executed.

SELECT	CITY.STATE_NAME	FROM	CITY	WHERE	CITY.POPULATION	>	(SELECT	MAX(CITY.POPULATION)	FROM	CITY	WHERE	CITY.STATE_NAME	=	"missouri") ;
Generation 3: SELECT	CITY.STATE_NAME	FROM	CITY	WHERE	CITY.POPULATION	>	(SELECT	MAX(CITY.POPULATION)	FROM	CITY	WHERE	CITY.STATE_NAME	=	"missouri") ;

SELECT	RIVER.RIVER_NAME	FROM	RIVER	WHERE	RIVER.TRAVERSE	=	"alaska" ;
Generation 4: The search result does not exist.

SELECT	STATE.STATE_NAME	,	STATE.AREA	FROM	STATE	WHERE	STATE.STATE_NAME	=	"texas" ;
Generation 5: SELECT	STATE.STATE_NAME	,	STATE.AREA	FROM	STATE	WHERE	STATE.STAT

In [20]:
dev_data = Dataset.from_list(temp_dataset_dev, features=features)
dataset = DatasetDict({'train': train_data, 'dev': dev_data, 'test': test_data})
tp, tn, fp, fn, exact_match, exact_match_rate = compare_results(dataset['dev'])
syntax_rate = sql_syntax_correct_rate(dataset['dev'])
precision, recall, f1 = calculate_metrics(tp, fp, fn)

print(f'tp: {tp}, tn: {tn}, fp: {fp}, fn: {fn}, exact match: {exact_match}, exact match rate: {exact_match_rate}')
print(f'sql syntax rate: {syntax_rate}')
print(f'precision: {precision}, recall: {recall}, f1 score: {f1}')

tp: 14, tn: 0, fp: 292, fn: 166, exact match: 4, exact match rate: 0.08333333333333333
sql syntax rate: 0.7916666666666666
precision: 0.0457516339869281, recall: 0.07777777777777778, f1 score: 0.05761316872427983


In [31]:
train_dataset = Dataset.from_list([{'text': making_prompt(example['question']) + example['sql']} for example in dataset['train']])
test_dataset = Dataset.from_list([{'text': making_prompt(example['question']) + example['sql']} for example in dataset['test']])
dev_dataset = Dataset.from_list([{'text': making_prompt(example['question']) + example['sql']} for example in dataset['dev']])

In [22]:
# Try to figure out how to set up the fine-tuning config
sft_config = SFTConfig(
    num_train_epochs=3,
    max_seq_length=1200,
    dataset_batch_size=2,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=1,
    fp16=True,
    logging_steps=100,
    logging_dir='./smo_fine_tuned_model',
    output_dir='./smo_fine_tuned_model',
)

# Create the SFT Trainer
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    args=sft_config,
)

# Fine-tune the model
trainer.train()

# Save the fine-tuned model
trainer.save_model('./smo_final_fine_tuned_model')

Map:   0%|          | 0/547 [00:00<?, ? examples/s]

Map:   0%|          | 0/277 [00:00<?, ? examples/s]

Step,Training Loss
100,0.1121
200,0.0198
300,0.0187
400,0.0186
500,0.0159
600,0.0155
700,0.0123
800,0.0138
900,0.0119
1000,0.0118


In [23]:
temp_dataset_dev = evaluate(dataset['dev'], model, tokenizer, cfg=False)

SELECT CITYalias0.CITY_NAME FROM CITY AS CITYalias0 WHERE CITYalias0.POPULATION = ( SELECT MAX( CITYalias1.POPULATION ) FROM CITY AS CITYalias1 WHERE CITYalias1.STATE_NAME = "arizona" ) AND CITYalias0.STATE_NAME = " Arizona" ;
Generation 1: The search result does not exist.

SELECT CITYalias0.CITY_NAME FROM CITY AS CITYalias0 WHERE CITYalias0.POPULATION = ( SELECT MAX( CITYalias1.POPULATION ) FROM CITY AS CITYalias1 WHERE CITYalias1.STATE_NAME = "texas" ) AND CITYalias0.STATE_NAME = "texas" ;
Generation 2: SELECT CITYalias0.CITY_NAME FROM CITY AS CITYalias0 WHERE CITYalias0.POPULATION = ( SELECT MAX( CITYalias1.POPULATION ) FROM CITY AS CITYalias1 WHERE CITYalias1.STATE_NAME = "texas" ) AND CITYalias0.STATE_NAME = "texas" ;

SELECT CITYalias0.CITY_NAME FROM CITY AS CITYalias0 WHERE CITYalias0.POPULATION = ( SELECT MAX( CITYalias1.POPULATION ) FROM CITY AS CITYalias1 WHERE CITYalias1.STATE_NAME = "missouri" ) AND CITYalias0.STATE_NAME = "missouri" ;
Generation 3: SELECT CITYalias0.CITY_

In [24]:
dev_data = Dataset.from_list(temp_dataset_dev, features=features)
dataset = DatasetDict({'train': train_data, 'dev': dev_data, 'test': test_data})
tp, tn, fp, fn, exact_match, exact_match_rate = compare_results(dataset['dev'])
syntax_rate = sql_syntax_correct_rate(dataset['dev'])
precision, recall, f1 = calculate_metrics(tp, fp, fn)

print(f'tp: {tp}, tn: {tn}, fp: {fp}, fn: {fn}, exact match: {exact_match}, exact match rate: {exact_match_rate}')
print(f'sql syntax rate: {syntax_rate}')
print(f'precision: {precision}, recall: {recall}, f1 score: {f1}')

tp: 142, tn: 0, fp: 16, fn: 30, exact match: 41, exact match rate: 0.8541666666666666
sql syntax rate: 0.9375
precision: 0.8987341772151899, recall: 0.8255813953488372, f1 score: 0.8606060606060606


In [25]:
temp_dataset_dev = evaluate(dataset['dev'], model, tokenizer)

SELECT	CITY.CITY_NAME	FROM	CITY		WHERE	CITY.POPULATION	=	(SELECT	MAX(CITY.POPULATION)	FROM	CITY		WHERE	CITY.STATE_NAME	=	"arizona"	) ;
Generation 1: SELECT	CITY.CITY_NAME	FROM	CITY		WHERE	CITY.POPULATION	=	(SELECT	MAX(CITY.POPULATION)	FROM	CITY		WHERE	CITY.STATE_NAME	=	"arizona"	) ;

SELECT	CITY.CITY_NAME	FROM	CITY		WHERE	CITY.POPULATION	=	(	SELECT	MAX(	CITY.POPULATION	)	FROM	CITY		WHERE	CITY.STATE_NAME	=	"texas"	) ;
Generation 2: SELECT	CITY.CITY_NAME	FROM	CITY		WHERE	CITY.POPULATION	=	(	SELECT	MAX(	CITY.POPULATION	)	FROM	CITY		WHERE	CITY.STATE_NAME	=	"texas"	) ;

SELECT	CITY.CITY_NAME	FROM	CITY		WHERE	CITY.POPULATION	=	(SELECT	MAX(CITY.POPULATION)	FROM	CITY		WHERE	CITY.STATE_NAME	=	"missouri"	) ;
Generation 3: SELECT	CITY.CITY_NAME	FROM	CITY		WHERE	CITY.POPULATION	=	(SELECT	MAX(CITY.POPULATION)	FROM	CITY		WHERE	CITY.STATE_NAME	=	"missouri"	) ;

SELECT	RIVER.RIVER_NAME	FROM	RIVER	WHERE	RIVER.TRAVERSE	=	(	SELECT	STATE.STATE_NAME	FROM	STATE	WHERE	STATE.POPULATION	=	150000	) ;
Generation

In [26]:
dev_data = Dataset.from_list(temp_dataset_dev, features=features)
dataset = DatasetDict({'train': train_data, 'dev': dev_data, 'test': test_data})
tp, tn, fp, fn, exact_match, exact_match_rate = compare_results(dataset['dev'])
syntax_rate = sql_syntax_correct_rate(dataset['dev'])
precision, recall, f1 = calculate_metrics(tp, fp, fn)

print(f'tp: {tp}, tn: {tn}, fp: {fp}, fn: {fn}, exact match: {exact_match}, exact match rate: {exact_match_rate}')
print(f'sql syntax rate: {syntax_rate}')
print(f'precision: {precision}, recall: {recall}, f1 score: {f1}')

tp: 113, tn: 0, fp: 376, fn: 66, exact match: 27, exact match rate: 0.5625
sql syntax rate: 0.9791666666666666
precision: 0.2310838445807771, recall: 0.6312849162011173, f1 score: 0.3383233532934131


In [12]:
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = 'Qwen/Qwen2.5-0.5B-Instruct'
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name, legacy=False)

# Add a padding token to the tokenizer
tokenizer.pad_token = tokenizer.eos_token  # Use eos_token as pad_token

In [16]:
temp_dataset_dev = evaluate(dataset['dev'], model, tokenizer)

SELECT CITYalias0.CITY_NAME FROM CITY AS CITYalias0 WHERE CITYalias0.STATE_NAME = "arizona" ORDER BY CITYalias0.POPULATION DESC LIMIT 1 ;
Generation 1: SELECT CITYalias0.CITY_NAME FROM CITY AS CITYalias0 WHERE CITYalias0.STATE_NAME = "arizona" ORDER BY CITYalias0.POPULATION DESC LIMIT 1 ;

SELECT T1.city_name FROM city AS T1 INNER JOIN state AS T2 ON T1.state_name = T2.state_name WHERE T2.state_name = "texas" ORDER BY T1.population DESC LIMIT 1 ;
Generation 2: SELECT T1.city_name FROM city AS T1 INNER JOIN state AS T2 ON T1.state_name = T2.state_name WHERE T2.state_name = "texas" ORDER BY T1.population DESC LIMIT 1 ;

SELECT CITYalias0.CITY_NAME FROM CITY AS CITYalias0 WHERE CITYalias0.STATE_NAME = "missouri" ORDER BY CITYalias0.POPULATION DESC LIMIT 1 ;
Generation 3: SELECT CITYalias0.CITY_NAME FROM CITY AS CITYalias0 WHERE CITYalias0.STATE_NAME = "missouri" ORDER BY CITYalias0.POPULATION DESC LIMIT 1 ;

SELECT RIVERalias0.RIVER_NAME FROM RIVER AS RIVERalias0 WHERE RIVERalias0.TRAVERS

In [27]:
dev_data = Dataset.from_list(temp_dataset_dev, features=features)
dataset = DatasetDict({'train': train_data, 'dev': dev_data, 'test': test_data})
tp, tn, fp, fn, exact_match, exact_match_rate = compare_results(dataset['dev'])
syntax_rate = sql_syntax_correct_rate(dataset['dev'])
precision, recall, f1 = calculate_metrics(tp, fp, fn)

print(f'tp: {tp}, tn: {tn}, fp: {fp}, fn: {fn}, exact match: {exact_match}, exact match rate: {exact_match_rate}')
print(f'sql syntax rate: {syntax_rate}')
print(f'precision: {precision}, recall: {recall}, f1 score: {f1}')

tp: 0, tn: 0, fp: 0, fn: 180, exact match: 0, exact match rate: 0.0
sql syntax rate: 0.0
precision: 0, recall: 0.0, f1 score: 0


In [19]:
# Try to figure out how to set up the fine-tuning config
sft_config = SFTConfig(
    num_train_epochs=3,
    max_seq_length=1024,
    dataset_batch_size=2,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=1,
    fp16=True,
    logging_steps=100,
    logging_dir='./qwen-05_fine_tuned_model',
    output_dir='./qwen-05_fine_tuned_model',
)

# Create the SFT Trainer
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    args=sft_config,
)

# Fine-tune the model
trainer.train()

# Save the fine-tuned model
trainer.save_model('./qwen-05_fine_tuned_model')

Map:   0%|          | 0/547 [00:00<?, ? examples/s]

Map:   0%|          | 0/277 [00:00<?, ? examples/s]

Step,Training Loss
100,0.094
200,0.0293
300,0.0281
400,0.028
500,0.0217
600,0.0218
700,0.0162
800,0.0186
900,0.0153
1000,0.015


In [20]:
temp_dataset_dev = evaluate(dataset['dev'], model, tokenizer)

SELECT CITYalias0.CITY_NAME FROM CITY AS CITYalias0 WHERE CITYalias0.POPULATION = ( SELECT MAX( CITYalias1.POPULATION ) FROM CITY AS CITYalias1 WHERE CITYalias1.STATE_NAME = "arizona" ) AND CITYalias0.STATE_NAME = "arizona" ;
Generation 1: SELECT CITYalias0.CITY_NAME FROM CITY AS CITYalias0 WHERE CITYalias0.POPULATION = ( SELECT MAX( CITYalias1.POPULATION ) FROM CITY AS CITYalias1 WHERE CITYalias1.STATE_NAME = "arizona" ) AND CITYalias0.STATE_NAME = "arizona" ;

SELECT CITYalias0.CITY_NAME FROM CITY AS CITYalias0 WHERE CITYalias0.POPULATION = ( SELECT MAX( CITYalias1.POPULATION ) FROM CITY AS CITYalias1 WHERE CITYalias1.STATE_NAME = "texas" ) AND CITYalias0.STATE_NAME = "texas" ;
Generation 2: SELECT CITYalias0.CITY_NAME FROM CITY AS CITYalias0 WHERE CITYalias0.POPULATION = ( SELECT MAX( CITYalias1.POPULATION ) FROM CITY AS CITYalias1 WHERE CITYalias1.STATE_NAME = "texas" ) AND CITYalias0.STATE_NAME = "texas" ;

SELECT CITYalias0.CITY_NAME FROM CITY AS CITYalias0 WHERE CITYalias0.POPUL

In [21]:
dev_data = Dataset.from_list(temp_dataset_dev, features=features)
dataset = DatasetDict({'train': train_data, 'dev': dev_data, 'test': test_data})
tp, tn, fp, fn, exact_match, exact_match_rate = compare_results(dataset['dev'])
syntax_rate = sql_syntax_correct_rate(dataset['dev'])
precision, recall, f1 = calculate_metrics(tp, fp, fn)

print(f'tp: {tp}, tn: {tn}, fp: {fp}, fn: {fn}, exact match: {exact_match}, exact match rate: {exact_match_rate}')
print(f'sql syntax rate: {syntax_rate}')
print(f'precision: {precision}, recall: {recall}, f1 score: {f1}')

tp: 145, tn: 0, fp: 2, fn: 27, exact match: 43, exact match rate: 0.8958333333333334
sql syntax rate: 0.9375
precision: 0.9863945578231292, recall: 0.8430232558139535, f1 score: 0.9090909090909091
