<a href="https://colab.research.google.com/github/YichengShen/cis5220-project/blob/main/evaluations/eval_modified_gpt3_5_turbo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Evaluation of predictions made by modified gpt-3.5-turbo model

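To run the evaluation, we need two files: `./labels.txt` (the gold file, one `gold SQL <TAB> db_id` per line) and `./preds.txt` (the predictions file, one predicted SQL query per line). For examples of both formats, refer to https://github.com/taoyds/spider/tree/master/evaluation_examples.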

In [1]:
# Mount Google Drive so the files under /content/drive (dataset zip,
# predictions, evaluation scripts) can be copied into the Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import os
import shutil
import subprocess
import nltk

In [3]:
# Download the NLTK 'punkt' tokenizer data.
# NOTE(review): presumably required by the evaluation scripts for tokenizing
# SQL; nothing in this notebook uses nltk directly -- confirm.
nltk.download('punkt')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


True

In [4]:
# Change this path to where you store spider.zip in your Drive
dataset_zip_path_in_drive = "/content/drive/Shareddrives/CIS 522/spider.zip"
dataset_zip_path_in_runtime = "/content/data/spider.zip"

# Create the runtime data folder if it does not exist yet. The folder is
# derived from the zip destination so that changing the path above is enough.
data_dir_in_runtime = os.path.dirname(dataset_zip_path_in_runtime)
os.makedirs(data_dir_in_runtime, exist_ok=True)

# Copy the archive from Drive to the runtime's local disk.
shutil.copy(dataset_zip_path_in_drive, dataset_zip_path_in_runtime)

# Extract next to the archive; files from a previous extraction are overwritten.
shutil.unpack_archive(dataset_zip_path_in_runtime, data_dir_in_runtime)

Copy GPT predictions to Colab

In [7]:
# Copy the model's predicted SQL file from Drive into /content; the returned
# destination path is displayed as the cell output.
shutil.copy("/content/drive/Shareddrives/CIS 522/GPT_pred/dev_pred.sql", "/content")

'/content/dev_pred.sql'

Convert preds file type

In [8]:
# Re-save the predictions under a .txt name; the text is copied line by line
# without modification.
input_file = "dev_pred.sql"
output_file = "preds.txt"

with open(input_file, "r") as sql_src:
    with open(output_file, "w") as txt_dst:
        txt_dst.writelines(sql_src)

Remove ERROR rows from `preds.txt` (together with the corresponding rows of `labels.txt`)

In [11]:
# labels.txt is uploaded to Colab
#
# Drop every prediction line that contains "ERROR", together with the label
# line at the same position, so the filtered files stay aligned line by line.

input_preds = "preds.txt"
input_labels = "labels.txt"

output_preds = "filtered_preds.txt"
output_labels = "filtered_labels.txt"

with open(input_preds, "r") as preds_file, open(input_labels, "r") as labels_file:
    pred_lines = preds_file.readlines()
    label_lines = labels_file.readlines()

# zip() stops at the shorter input, so a length mismatch would otherwise be
# dropped silently and could hide misaligned prediction/label pairs.
if len(pred_lines) != len(label_lines):
    print(f"WARNING: {input_preds} has {len(pred_lines)} lines but "
          f"{input_labels} has {len(label_lines)} lines; "
          f"only the first {min(len(pred_lines), len(label_lines))} pairs are compared.")

kept_pairs = [
    (pred_line, label_line)
    for pred_line, label_line in zip(pred_lines, label_lines)
    if "ERROR" not in pred_line
]

with open(output_preds, "w") as filtered_preds, open(output_labels, "w") as filtered_labels:
    for pred_line, label_line in kept_pairs:
        filtered_preds.write(pred_line)
        filtered_labels.write(label_line)

total_pairs = min(len(pred_lines), len(label_lines))
print(f"Kept {len(kept_pairs)} of {total_pairs} pairs "
      f"({total_pairs - len(kept_pairs)} ERROR rows removed).")

Evaluation

In [5]:
scripts_path_in_drive = "/content/drive/Shareddrives/CIS 522/scripts"
scripts_path_in_runtime = "/content/scripts"

# Always start from a fresh copy: delete any scripts folder left over from a
# previous run, then copy the current one from Drive.
stale_copy_present = os.path.exists(scripts_path_in_runtime)
if stale_copy_present:
    shutil.rmtree(scripts_path_in_runtime)

shutil.copytree(src=scripts_path_in_drive, dst=scripts_path_in_runtime)

'/content/scripts'

In [6]:
def evaluate(preds_file, labels_file, evaluation_type="all", 
             database_dir="./data/spider/database", 
             table_file="./data/spider/tables.json",
             verbose="False"):
    """
    Runs the Spider evaluation script (scripts/evaluation.py) in a subprocess.

    The script's standard output is printed to the console. Anything the script
    writes to standard error, and a non-zero exit code, are reported on stderr so
    that a failed evaluation is not mistaken for an empty result.

    Args:
        preds_file (str): Path to the predictions file. In this file, each line is a predicted SQL.
        labels_file (str): Path to the labels (gold) file. In this file, each line is
                           a ground-truth SQL, a tab character, then the db_id.
        evaluation_type (str): Evaluation type, can be 'all', 'exec', or 'match'.
        database_dir (str): Path to the directory containing the Spider dataset's database files.
        table_file (str): Path to the tables.json file from the Spider dataset.
        verbose (str): Flag to turn on or off printing details. It is passed to the
                       script as a command-line string (e.g. "True" or "False").

    Returns:
        result (subprocess.CompletedProcess): A CompletedProcess instance representing the evaluation subprocess.
                                              It contains attributes like 'stdout', 'stderr' and 'returncode'.
    """
    # Local import so this cell does not depend on an edit to the imports cell.
    import sys

    cmd = [
        # Use the interpreter running this notebook rather than whichever
        # "python3" happens to be first on PATH.
        sys.executable, "scripts/evaluation.py",
        "--gold", labels_file,
        "--pred", preds_file,
        "--etype", evaluation_type,
        "--db", database_dir,
        "--table", table_file,
        "--verbose", verbose
    ]

    # The script path is relative, so it is resolved against the current
    # working directory of the notebook.
    result = subprocess.run(cmd, capture_output=True, text=True)

    print(result.stdout)

    # Output is captured, so errors would stay invisible unless shown here.
    if result.stderr:
        print(result.stderr, file=sys.stderr)
    if result.returncode != 0:
        print(f"evaluation.py exited with return code {result.returncode}", file=sys.stderr)

    return result

### Evaluation including ERROR rows in `preds.txt`

`labels.txt` must already be uploaded to the Colab runtime's working directory before running the cell below.

In [10]:
# Caution: evaluation_type="all" or "exec" might use up the runtime's RAM.
# Scores the unfiltered predictions (ERROR rows included) against all labels.
evaluation_eval = evaluate(
    labels_file="labels.txt",
    preds_file="preds.txt",
    evaluation_type="all",
    table_file="./data/spider/tables.json",
    database_dir="./data/spider/database",
    verbose="False",
)

                     easy                 medium               hard                 extra                all                 
count                248                  446                  174                  166                  1034                
execution            0.690                0.509                0.299                0.187                0.465               

exact match          0.661                0.413                0.230                0.066                0.386               

---------------------PARTIAL MATCHING ACCURACY----------------------
select               0.874                0.887                0.884                0.831                0.877               
select(no AGG)       0.888                0.887                0.884                0.845                0.883               
where                0.863                0.730                0.611                0.393                0.691               
where(no OP)         0.863                0.730

### Evaluation excluding ERROR rows in `preds.txt`

In [12]:
# Caution: evaluation_type="all" or "exec" might use up the runtime's RAM.
# Scores only the pairs whose prediction line did not contain "ERROR".
evaluation_eval = evaluate(
    labels_file="filtered_labels.txt",
    preds_file="filtered_preds.txt",
    evaluation_type="all",
    table_file="./data/spider/tables.json",
    database_dir="./data/spider/database",
    verbose="False",
)

                     easy                 medium               hard                 extra                all                 
count                238                  412                  165                  152                  967                 
execution            0.718                0.551                0.315                0.204                0.497               

exact match          0.689                0.447                0.242                0.072                0.413               

---------------------PARTIAL MATCHING ACCURACY----------------------
select               0.874                0.887                0.884                0.831                0.877               
select(no AGG)       0.888                0.887                0.884                0.845                0.883               
where                0.863                0.730                0.611                0.393                0.691               
where(no OP)         0.863                0.730