In [1]:
import json
import sqlite3
import pandas as pd
import numpy as np
import requests
import tqdm
import re

In [26]:
# Arguments
db_path = '../Database/Aminer_Simplified-small.sqlite'
url = 'openai-url'
api_key = 'your-api-key'
headers = {
    'Content-Type': 'application/json',
    'Authorization': f'Bearer {api_key}'
}
seed_data_path = '../Annotation/train.xlsx'
question_prompt_path = './QuestionGenerationTemplate.txt'
text_to_sql_prompt_path = './Text2SqlTemplate.txt'

In [3]:
def get_database_prompt(db_path) -> str:

    stmt = ''

    conn = sqlite3.connect(db_path)
    cur = conn.cursor()

    # Fetch names of all tables
    cur.execute("SELECT name FROM sqlite_master WHERE type='table';")
    tables = cur.fetchall()

    # Fech create statements for all tables
    for table in tables:
        table_name = table[0]
        cur.execute(f"SELECT sql FROM sqlite_master WHERE type='table' AND name='{table_name}';")
        create_statement = cur.fetchone()[0]

        stmt += create_statement + '\n\n'

    conn.close()
    return stmt

print(get_database_prompt(db_path=db_path))

CREATE TABLE Venue(
  id TEXT, -- id
  DisplayName TEXT, -- name of the conferenece/joural
  PRIMARY KEY (id)
)

CREATE TABLE Affiliation(
  id TEXT, -- id
  DisplayName TEXT, -- name of the orgnization
  type TEXT, -- orgnization type
  url TEXT, -- link of the orgnization's homepage
  PRIMARY KEY (id)
)

CREATE TABLE Author(
  id TEXT, -- id
  name TEXT, -- name
  org TEXT, -- author's current orgnization
  position TEXT, -- position
  n_pubs INTEGER, -- number of paper publication
  n_citation INTEGER, -- number of total citation
  h_index INTEGER, -- h-index
  PRIMARY KEY (id)
)

CREATE TABLE Paper(
  id TEXT, -- id
  title TEXT, -- title
  year INTEGER, -- publication year
  n_citation INTEGER, -- number of citation
  page_start TEXT, -- start page on the publication
  page_end TEXT, -- end page on the publication
  lang TEXT, -- language
  volume TEXT, -- volume of the publicaiton
  doi TEXT, -- digital object unique identifier
  pdf TEXT, -- pdf view link of the paper
  abstract

In [4]:
train = pd.read_excel(seed_data_path)
train.head()

Unnamed: 0,question,query
0,Show the different keywords articles which has...,SELECT DISTINCT Paper_Keywords.keyword\nFROM P...
1,find researcher who published all his paper af...,SELECT Author.name\nFROM Author\nWHERE Author....
2,Where to find the pdf file of the paper 'Femto...,SELECT Paper.pdf\nFROM Paper\nWHERE Paper.titl...
3,what orgnization has most researchers once bel...,SELECT Orgnization_Researchers.affiliation_nam...
4,"Among all institutions, which one has the most...","SELECT Author_Interested_In_Algorithms.org, CO..."


In [13]:
def extract(response:str) -> str : # extract response from formatted string
    text = ''
    try :
        text = re.search(r'\{(.+?)\}', response, re.DOTALL).groups(0)[-1] 
    except : 
        return ''
    return text

def generate_new_questions(db_path, train, n_demo=30, n_ques=20) -> list :
    
    database_prompt = get_database_prompt(db_path)
    question_prompt = open(question_prompt_path, 'r').read()
    questions = []
    for _ in tqdm.tqdm(range(n_ques)):
        shots = np.random.choice(train['question'], n_demo, replace=False)
        prompt = question_prompt.replace('{SCHEMA_SLOT}', database_prompt)
        for shot in shots :
            prompt += f'{{{shot}}}\n\n'
        params = {
                    "model": "gpt-3.5-turbo-16k",
                    "messages": [{"role":"user", "content":prompt}],
                    "temperature": 1.0,
        }
        response = requests.post(url, headers=headers, data=json.dumps(params))
        generated_question = extract(response.json().get('choices')[0].get('message').get('content'))
        if questions != '' :
            print(generated_question)
            questions.append(generated_question)
    return questions

questions = generate_new_questions(db_path=db_path, train=train, n_demo=5, n_ques=10)

 10%|█         | 1/10 [00:03<00:32,  3.62s/it]

Which authors have a higher h-index than the average h-index of all authors?


 20%|██        | 2/10 [00:06<00:23,  2.98s/it]

Which papers have been published by the conference 'International Conference on Machine Learning'?


 30%|███       | 3/10 [00:08<00:19,  2.82s/it]

What are the research interests of the authors with the highest h-index?


 40%|████      | 4/10 [00:12<00:18,  3.07s/it]

Which authors have the highest number of paper publications?


 50%|█████     | 5/10 [00:15<00:15,  3.02s/it]

Which conference has the highest number of papers published in the year 2020?


 60%|██████    | 6/10 [00:20<00:15,  3.91s/it]




 70%|███████   | 7/10 [00:23<00:10,  3.36s/it]

How many papers are published in each year?


 80%|████████  | 8/10 [00:29<00:08,  4.23s/it]

Which papers are written by authors who have published more than 50 papers?


 90%|█████████ | 9/10 [00:31<00:03,  3.63s/it]

Which authors have published more than 50 papers?


100%|██████████| 10/10 [00:36<00:00,  3.67s/it]

Can you provide a list of papers that were published in conferences with their conference names and years?





In [29]:
def text2sql(db_path, train, questions, n_demo=30) -> list :
    augmented_data = []
    text_to_sql_prompt = open(text_to_sql_prompt_path, 'r').read()
    for question in tqdm.tqdm(questions):
        random_indices = np.random.choice(len(train), n_demo, replace=False)
        shots_ques = list(train['question'].iloc[random_indices])
        shots_query = list(train['query'].iloc[random_indices])

        prompt = text_to_sql_prompt.replace('{SCHEMA_SLOT}', get_database_prompt(db_path))
        slot = 'Q: {QUES}\n{SQL}\n\n'
        shots = ''.join([slot.replace('QUES', ques).replace('SQL', sql) for ques, sql in zip(shots_ques, shots_query)])
        prompt = re.sub(r'\{SLOTS\}', shots, prompt)
        prompt = prompt.replace('NATURAL_LANGUAGE_QUESTION', question)

        params = {
                    "model": "gpt-3.5-turbo-16k",
                    "messages": [{"role":"user", "content":prompt}],
                    "temperature": 0,
        }
        response = requests.post(url, headers=headers, data=json.dumps(params))
        generated_sql = extract(response.json().get('choices')[0].get('message').get('content'))
        if generated_sql != '' :
            augmented_data.append({'question': question, 'SQL': generated_sql})
            print(question)
            print(generated_sql)

    return augmented_data

augmented_data = text2sql(db_path=db_path, train=train, questions=questions, n_demo=30)
augmented_data

 10%|█         | 1/10 [00:03<00:29,  3.29s/it]

Which authors have a higher h-index than the average h-index of all authors?
SELECT Author.name
FROM Author
WHERE Author.h_index > (SELECT AVG(h_index) FROM Author);


 20%|██        | 2/10 [00:05<00:22,  2.86s/it]

Which papers have been published by the conference 'International Conference on Machine Learning'?
SELECT Paper.title
FROM Paper
JOIN Venue_Papers ON Paper.id = Venue_Papers.paper_id
JOIN Venue ON Venue_Papers.venue_id = Venue.id
WHERE Venue.DisplayName = 'International Conference on Machine Learning';


 30%|███       | 3/10 [00:08<00:18,  2.69s/it]

What are the research interests of the authors with the highest h-index?
SELECT Researcher_Interests.tag
FROM Researcher_Interests
WHERE Researcher_Interests.author_id IN (SELECT Author.id FROM Author WHERE Author.h_index = (SELECT MAX(h_index) FROM Author));


 40%|████      | 4/10 [00:10<00:15,  2.62s/it]

Which authors have the highest number of paper publications?
SELECT Author.name
FROM Author
WHERE Author.n_pubs = (SELECT MAX(n_pubs) FROM Author);


 50%|█████     | 5/10 [00:12<00:11,  2.32s/it]

Which conference has the highest number of papers published in the year 2020?
SELECT Venue.DisplayName
FROM Venue
JOIN Venue_Papers ON Venue.id = Venue_Papers.venue_id
JOIN Paper ON Venue_Papers.paper_id = Paper.id
WHERE Paper.year = 2020
GROUP BY Venue.DisplayName
ORDER BY COUNT(*) DESC
LIMIT 1;


 70%|███████   | 7/10 [00:18<00:07,  2.63s/it]

How many papers are published in each year?
SELECT year, COUNT(*) AS num_papers
FROM Paper
GROUP BY year;


 80%|████████  | 8/10 [00:20<00:05,  2.59s/it]

Which papers are written by authors who have published more than 50 papers?
SELECT Paper.title
FROM Paper
JOIN Paper_Authors ON Paper.id = Paper_Authors.paper_id
JOIN (SELECT Author.id FROM Author WHERE Author.n_pubs > 50) AS Prolific_Authors
ON Paper_Authors.author_id = Prolific_Authors.id;


 90%|█████████ | 9/10 [00:22<00:02,  2.44s/it]

Which authors have published more than 50 papers?
SELECT Author.name
FROM Author
WHERE Author.n_pubs > 50;


100%|██████████| 10/10 [00:25<00:00,  2.53s/it]

Can you provide a list of papers that were published in conferences with their conference names and years?
SELECT Paper.title, Venue.DisplayName, Paper.year
FROM Paper
JOIN Venue_Papers ON Paper.id = Venue_Papers.paper_id
JOIN Venue ON Venue_Papers.venue_id = Venue.id
WHERE Venue.type = 'conference';





[{'question': 'Which authors have a higher h-index than the average h-index of all authors?',
  'SQL': 'SELECT Author.name\nFROM Author\nWHERE Author.h_index > (SELECT AVG(h_index) FROM Author);'},
 {'question': "Which papers have been published by the conference 'International Conference on Machine Learning'?",
  'SQL': "SELECT Paper.title\nFROM Paper\nJOIN Venue_Papers ON Paper.id = Venue_Papers.paper_id\nJOIN Venue ON Venue_Papers.venue_id = Venue.id\nWHERE Venue.DisplayName = 'International Conference on Machine Learning';"},
 {'question': 'What are the research interests of the authors with the highest h-index?',
  'SQL': 'SELECT Researcher_Interests.tag\nFROM Researcher_Interests\nWHERE Researcher_Interests.author_id IN (SELECT Author.id FROM Author WHERE Author.h_index = (SELECT MAX(h_index) FROM Author));'},
 {'question': 'Which authors have the highest number of paper publications?',
  'SQL': 'SELECT Author.name\nFROM Author\nWHERE Author.n_pubs = (SELECT MAX(n_pubs) FROM Auth

In [34]:
def filter_generated_data(data, db_path) -> list :
    conn = sqlite3.connect(db_path)
    cur = conn.cursor()
    filtered = []
    for item in tqdm.tqdm(data) :
        try :
            cur.execute(item['SQL'])
            result = cur.fetchall()
            # collect non-trivial query
            #if len(result) > 0 : 
            #    filtered.append(item)
            filtered.append(item)
        except Exception as e :
            print(e)
            print('Error in SQL:', item['SQL'])
    conn.close()
    return filtered

data = filter_generated_data(data=augmented_data, db_path=db_path)
data

100%|██████████| 9/9 [00:04<00:00,  1.96it/s]

no such column: Venue.type
Error in SQL: SELECT Paper.title, Venue.DisplayName, Paper.year
FROM Paper
JOIN Venue_Papers ON Paper.id = Venue_Papers.paper_id
JOIN Venue ON Venue_Papers.venue_id = Venue.id
WHERE Venue.type = 'conference';





[{'question': 'Which authors have a higher h-index than the average h-index of all authors?',
  'SQL': 'SELECT Author.name\nFROM Author\nWHERE Author.h_index > (SELECT AVG(h_index) FROM Author);'},
 {'question': "Which papers have been published by the conference 'International Conference on Machine Learning'?",
  'SQL': "SELECT Paper.title\nFROM Paper\nJOIN Venue_Papers ON Paper.id = Venue_Papers.paper_id\nJOIN Venue ON Venue_Papers.venue_id = Venue.id\nWHERE Venue.DisplayName = 'International Conference on Machine Learning';"},
 {'question': 'What are the research interests of the authors with the highest h-index?',
  'SQL': 'SELECT Researcher_Interests.tag\nFROM Researcher_Interests\nWHERE Researcher_Interests.author_id IN (SELECT Author.id FROM Author WHERE Author.h_index = (SELECT MAX(h_index) FROM Author));'},
 {'question': 'Which authors have the highest number of paper publications?',
  'SQL': 'SELECT Author.name\nFROM Author\nWHERE Author.n_pubs = (SELECT MAX(n_pubs) FROM Auth