In [203]:
import json
import re
import sqlparse
from sqlparse.sql import TokenList
from sqlparse.tokens import Whitespace, Keyword, Operator, Punctuation, Name, Literal

# Load the JSON files.
# The encoding is given explicitly so decoding does not depend on the
# platform's locale default (JSON interchange text is UTF-8).
with open('train.json', encoding='utf-8') as f:
    train_data = json.load(f)

with open('train_tables.json', encoding='utf-8') as f:
    train_tables = json.load(f)

In [204]:
# SQL keywords, function names and operators, grouped by role.
# Every entry is upper-case; multi-word entries use a single space.
keywords_main_body = {'SELECT', 'FROM', 'WHERE', 'AND', 'OR', 'NOT', 'IN', 'EXISTS', 'IS', 'NULL', 'IIF', 'CASE', 'WHEN'}
keywords_join = {'INNER JOIN', 'LEFT JOIN', 'ON', 'AS'}
keywords_clause = {'BETWEEN', 'LIKE', 'LIMIT', 'ORDER BY', 'ASC', 'DESC', 'GROUP BY', 'HAVING', 'UNION', 'ALL', 'EXCEPT', 'PARTITION BY', 'OVER'}
keywords_aggregation = {'AVG', 'COUNT', 'MAX', 'MIN', 'ROUND', 'SUM'}
# 'JULIANDAY' was previously misspelled 'JULIADAY', so the real function name never matched.
keywords_scalar = {'ABS', 'LENGTH', 'STRFTIME', 'JULIANDAY', 'NOW', 'CAST', 'SUBSTR', 'INSTR'}
keywords_comparison = {'=', '>', '<', '>=', '<=', '!=', '<>'}
keywords_computing = {'-', '+', '*', '/'}

# Union of all groups above.
all_keywords = keywords_main_body | keywords_join | keywords_clause | keywords_aggregation | keywords_scalar | keywords_comparison | keywords_computing

def get_id(entity_type):
    """Return the placeholder id string for an entity type.

    'Alias', 'Table', 'Column' and 'Value' map to '<extra_id_0>' through
    '<extra_id_3>' respectively; any other input yields None.
    """
    known_types = ('Alias', 'Table', 'Column', 'Value')
    for position, known in enumerate(known_types):
        if entity_type == known:
            # The placeholder number is the position of the type in the tuple.
            return f'<extra_id_{position}>'
    return None

### Process Token

In [205]:
def extract_entities(sql_id,sql, verbose=False):
    """Mask schema entities and literal values in a SQL query.

    The query is tokenized with sqlparse (first statement only). Aliases,
    table names, column names and literals are replaced by the placeholder
    returned by ``get_id``; every other token is copied through verbatim.

    Relies on the module-level ``train_tables`` (schema list), ``all_keywords``
    and ``get_id``.

    Args:
        sql_id: ``db_id`` of the database whose schema is used to classify names.
        sql: SQL text to process.
        verbose: when true, print a trace of every token decision.

    Returns:
        ``(entities, processed_sql)`` where ``entities`` is a list of
        ``(value, type, id)`` tuples in order of appearance (duplicates kept)
        and ``processed_sql`` is the masked query string. Input with no
        statement yields ``([], '')``.
    """
    statements = sqlparse.parse(sql)
    if not statements:
        # Nothing to parse (e.g. empty string): avoid IndexError on [0].
        return [], ''
    parsed = statements[0]

    entities = []  # List to store identified entities
    processed_tokens = []  # List to store processed tokens for reconstructed SQL

    # Resolve the schema once per call instead of once per token.
    # Only the first entry with a matching db_id is consulted.
    schema = next((db for db in train_tables if db["db_id"] == sql_id), None)
    if schema is not None:
        table_names = set(schema["table_names_original"])
        column_names = {col[1] for col in schema["column_names_original"]}
    else:
        table_names = set()
        column_names = set()

    # Token families copied through unchanged. Membership must be tested with
    # sqlparse's `ttype in family` so that sub-types such as Keyword.DML,
    # Keyword.CTE, Operator.Comparison or Whitespace.Newline are included;
    # a plain tuple `in` only matches the exact parent type.
    passthrough_families = (Whitespace, Keyword, Operator, Punctuation)

    # Append entity to the list with its type and id
    def append_entity(value, entity_type):
        entity_id = get_id(entity_type)
        entities.append((value, entity_type, entity_id))
        processed_tokens.append(entity_id)
        if verbose: print(f"        Appending entity: {value} as {entity_type} with id {entity_id}")

    # Classify a name as 'Alias', 'Table', 'Column' or 'None' (the string).
    def get_type(value):
        # Aliases are "T" followed by one or more digits; checked before the schema.
        if re.fullmatch(r"T\d+", value):
            return "Alias"
        if value in table_names:
            return 'Table'
        if value in column_names:
            return 'Column'
        return 'None'

    # Helper function to process each token in the parsed SQL
    def process_token(token):
        if verbose: print(f"Processing token: '{token}' ({token.ttype})")

        # Processing based on token type
        if isinstance(token, TokenList):
            for sub_token in token.tokens:
                process_token(sub_token)
        elif any(token.ttype in family for family in passthrough_families):
            processed_tokens.append(token.value)
        elif token.value.upper() in all_keywords:
            if verbose: print(f"        {token} is a keyword")
            processed_tokens.append(token.value)
        elif token.ttype in Name:
            token_type = get_type(token.value)
            if token_type != 'None':
                # Masked: emit only the placeholder, never the raw value too.
                append_entity(token.value, token_type)
            else:
                # Unrecognised names (builtin types, functions, ...) are kept
                # so the reconstructed SQL does not lose tokens.
                processed_tokens.append(token.value)
            if verbose: print(f"        {token} is {token_type}")
        elif token.ttype in Literal:
            if verbose: print(f"        {token} is a Value")
            append_entity(token.value, "Value")
        else:
            # Any other token type (e.g. comments) is preserved as-is.
            processed_tokens.append(token.value)

    # Process each token in the parsed tokens
    for token in parsed.tokens:
        process_token(token)
        if verbose: print("-" * 50)

    # Combine tokens to reconstruct the processed SQL
    processed_sql = ''.join(processed_tokens)
    return entities, processed_sql

### Output Result

In [208]:
# Hand-picked training examples to display
sample = [train_data[index] for index in (11, 33, 5, 250)]

# Extract entities for each sample and print a report
for i, data in enumerate(sample):
    sql_query, sql_id = data["SQL"], data["db_id"]
    print(f"Query {i+1}: <{sql_query}>\n")
    print(sql_id+"\n")
    entities, processed_sql = extract_entities(sql_id,sql_query,verbose=False)

    # Entity table: three left-aligned columns shared by header and rows
    row_format = "{:<30} {:<15} {:<20}"
    headers = ["Entity", "Type", "ID"]
    print(row_format.format(*headers))

    # Drop duplicates while keeping first-occurrence order, then order by
    # placeholder id (the sort is stable, so ties keep appearance order)
    unique_entities = list(dict.fromkeys(entities))
    unique_entities.sort(key=lambda item: item[2])
    for entity in unique_entities:
        print(row_format.format(entity[0], entity[1], entity[2]))

    # Original query followed by its masked form
    print("\nInitial SQL:",sql_query)
    print("Masked SQL:",processed_sql)
    for _ in range(2):
        print("=" * 50)

Query 1: <SELECT T2.movie_title FROM ratings AS T1 INNER JOIN movies AS T2 ON T1.movie_id = T2.movie_id WHERE CAST(SUBSTR(T1.rating_timestamp_utc, 1, 4) AS INTEGER) = 2020 AND CAST(SUBSTR(T1.rating_timestamp_utc, 6, 2) AS INTEGER) > 4>

movie_platform

Entity                         Type            ID                  
T2                             Alias           <extra_id_0>        
T1                             Alias           <extra_id_0>        
ratings                        Table           <extra_id_1>        
movies                         Table           <extra_id_1>        
movie_title                    Column          <extra_id_2>        
movie_id                       Column          <extra_id_2>        
rating_timestamp_utc           Column          <extra_id_2>        
1                              Value           <extra_id_3>        
4                              Value           <extra_id_3>        
2020                           Value           <extra_id_3>        

### Testing

In [207]:
# Reload the training examples; the encoding is explicit so decoding does not
# depend on the platform's locale default.
with open('train.json', encoding='utf-8') as f:
    train_data = json.load(f)

sql=train_data[15]["SQL"]
print(sql)
parsed = sqlparse.parse(sql)[0]
print(parsed.token_first())

# Hard-coded path into the parse tree of this particular query; the printed
# output shows which token type it resolves to. The indices are specific to
# the query's shape and will not generalize to other statements.
a1=parsed.tokens[8].tokens[2].tokens[4].ttype
a2=sqlparse.tokens.String
print("TARGET",a1)
# sqlparse token types support `child in parent` containment checks.
print(a1 in a2)

SELECT director_name FROM movies WHERE movie_title = 'Sex, Drink and Bloodshed'
SELECT
TARGET Token.Literal.String.Single
True
