In [26]:
import json
import re
import sqlparse
from sqlparse.sql import Identifier, Function, Parenthesis, Where, Comparison, IdentifierList, Operation, If, Case
from sqlparse.tokens import Whitespace, Keyword, Operator, Punctuation

# Load the JSON files
with open('train.json') as f:
	train_data = json.load(f)

with open('train_tables.json') as f:
	train_tables = json.load(f)

In [27]:
# List of SQL keywords
keywords_main_body = {'SELECT', 'FROM', 'WHERE', 'AND', 'OR', 'NOT', 'IN', 'EXISTS', 'IS', 'NULL', 'IIF', 'CASE', 'WHEN'}
keywords_join = {'INNER JOIN', 'LEFT JOIN', 'ON', 'AS'}
keywords_clause = {'BETWEEN', 'LIKE', 'LIMIT', 'ORDER BY', 'ASC', 'DESC', 'GROUP BY', 'HAVING', 'UNION', 'ALL', 'EXCEPT', 'PARTITION BY', 'OVER'}
keywords_aggregation = {'AVG', 'COUNT', 'MAX', 'MIN', 'ROUND', 'SUM'}
keywords_scalar = {'ABS', 'LENGTH', 'STRFTIME', 'JULIADAY', 'NOW', 'CAST', 'SUBSTR', 'INSTR'}
keywords_comparison = {'=', '>', '<', '>=', '<=', '!=', '<>'}
keywords_computing = {'-', '+', '*', '/'}

all_keywords = keywords_main_body | keywords_join | keywords_clause | keywords_aggregation | keywords_scalar | keywords_comparison | keywords_computing

def get_id(entity_type):
	if entity_type == 'Alias':
		return '<extra_id_0>'
	elif entity_type == 'Table':
		return '<extra_id_1>'
	elif entity_type == 'Column':
		return '<extra_id_2>'
	elif entity_type == 'Value':
		return '<extra_id_3>'

### Process Token

In [38]:
def extract_entities(sql_id,sql, verbose=False):
	# Parse the SQL query and take the first statement
	parsed = sqlparse.parse(sql)[0]
	entities = []  # List to store identified entities
	processed_tokens = []  # List to store processed tokens for reconstructed SQL

	# Helper function to process each token in the parsed SQL
	def process_token(token):
		nonlocal processed_tokens
		if verbose: print(f"Processing token: '{token}' ({token.ttype})")
		
		# Processing based on token type or value
		if token.ttype in (Whitespace, Keyword, Operator, Punctuation):
			processed_tokens.append(token.value)
		elif token.value.upper() in all_keywords:
			if verbose:
				print(f"    {token} is a keyword")
			processed_tokens.append(token.value)
		elif isinstance(token, Identifier):
			if verbose: print(f"    {token} is an Identifier")
			process_identifier(token)
		elif isinstance(token, (Function, Parenthesis, Operation, Where, Comparison, IdentifierList, If , Case)):
			if verbose: print(f"    {token} is a function/parenthesis/where/comparison/identifierlist")
			for sub_token in token.tokens:
				process_token(sub_token)
		else:
			if verbose: print(f"    {token} is a Value")
			# print(type(token).__name__)
			append_entity(token.value, 'Value')

	# Helper function to process identifiers in SQL
	def process_identifier(token):
		alias_pattern = r'T\d+\..*'
		desc_pattern = r'\s+DESC$'
		as_t_pattern = r'\sAS\sT\d'
		
		if re.search(alias_pattern, token.value):
			if verbose: print(f"    {token} is an Alias/Column pair T*.*")
			process_alias_identifier(token)
		elif re.search(desc_pattern, token.value):
			if verbose: print(f"    {token} is a DESC pair")
			process_desc_identifier(token)
		elif re.search(as_t_pattern, token.value):
			if verbose: print(f"    {token} is an 'AS T' pair")
			process_as_t_identifier(token)

		else:
			type = find_name_in_db(sql_id, token.value)
			if verbose: print(f"    {token} is a {type}")
			append_entity(token.value, type)

	# Append entity to the list with its type and id
	def append_entity(value, type):
		entity_id = get_id(type)
		entities.append((value, type, entity_id))
		processed_tokens.append(entity_id)
		if verbose: print(f"        Appending entity: {value} as {type} with id {entity_id}")

	# Specific processing for alias-pattern identifiers
	def process_alias_identifier(token):
		entitiy_split = token.value.split(".")
		entity_1_id = get_id('Alias')
		type = find_name_in_db(sql_id, entitiy_split[1])
		entity_2_id = get_id(type)
		entities.append((entitiy_split[0], 'Alias', entity_1_id))
		entities.append((entitiy_split[1], type, entity_2_id))
		processed_tokens.append(f"{entity_1_id}.{entity_2_id}")

	# Processing identifiers ending with 'DESC'
	def process_desc_identifier(token):
		split = token.value.rsplit(' ', 1)
		type = find_name_in_db(sql_id, split[0])
		entity_id = get_id(type)
		entities.append((split[0], type, entity_id))
		processed_tokens.append(f"{entity_id} DESC")

	# Processing identifiers with 'AS T' pattern
	def process_as_t_identifier(token):
		entity_split = token.value.rsplit(' ', 2)
		type = find_name_in_db(sql_id, entity_split[0])
		entity_1_id = get_id(type)
		entity_2_id = get_id('Alias')
		entities.append((entity_split[0], type, entity_1_id))
		entities.append((entity_split[2], 'Alias', entity_2_id))
		processed_tokens.append(f"{entity_1_id} AS {entity_2_id}")
	
	def find_name_in_db(sql_id, var):
		for db in train_tables:
			if db["db_id"] == sql_id:
				if var in db["table_names"]:
					return 'Table'
				else:
					return 'Column'
			

	# Process each token in the parsed tokens
	for token in parsed.tokens:
		process_token(token)
		if verbose: print("-" * 50)

	# Combine tokens to reconstruct the processed SQL
	processed_sql = ''.join(processed_tokens)
	return entities, processed_sql

### Output Result

In [39]:

sample = [train_data[i] for i in [0, 5, 10]]

# Extract entity and print results
for i, data in enumerate(sample):
	sql_query = data["SQL"]
	sql_id = data["db_id"]
	print(f"Query {i+1}: <{sql_query}>\n")
	print(sql_id+"\n")
	entities, processed_sql = extract_entities(sql_id,sql_query,verbose=True)

	# Entity Table
	headers = ["Entity", "Type", "ID"]
	print(f"{headers[0]:<30} {headers[1]:<15} {headers[2]:<20}")
	
	# Filter out duplicates
	seen = set()
	unique_entities = []
	for entity in entities:
		if entity not in seen:
			seen.add(entity)
			unique_entities.append(entity)
	
	# Sort and print unique entities
	unique_entities = sorted(unique_entities, key=lambda x: x[2])
	for entity in unique_entities:
		print(f"{entity[0]:<30} {entity[1]:<15} {entity[2]:<20}")
	# Masked SQL
	print("\nInitial SQL:",sql_query)
	print("Masked SQL:",processed_sql)
	print("=" * 50)
	print("=" * 50)
	

Query 1: <SELECT movie_title FROM movies WHERE movie_release_year = 1945 ORDER BY movie_popularity DESC LIMIT 1>

movie_platform

Processing token: 'SELECT' (Token.Keyword.DML)
    SELECT is a keyword
--------------------------------------------------
Processing token: ' ' (Token.Text.Whitespace)
--------------------------------------------------
Processing token: 'movie_title' (None)
    movie_title is an Identifier
    movie_title is a Column
        Appending entity: movie_title as Column with id <extra_id_2>
--------------------------------------------------
Processing token: ' ' (Token.Text.Whitespace)
--------------------------------------------------
Processing token: 'FROM' (Token.Keyword)
--------------------------------------------------
Processing token: ' ' (Token.Text.Whitespace)
--------------------------------------------------
Processing token: 'movies' (None)
    movies is an Identifier
    movies is a Table
        Appending entity: movies as Table with id <extra_id_1

### Testing

In [35]:
with open('train.json') as f:
	train_data = json.load(f)
	
sql=train_data[15]["SQL"]
print(sql)
parsed = sqlparse.parse(sql)[0]
print(parsed.token_first())

a1=parsed.tokens[8].tokens[2].tokens[4].ttype
a2=sqlparse.tokens.String
print("TARGET",parsed.tokens[8].tokens[2].tokens[4].ttype)
print(a1 in a2)

SELECT director_name FROM movies WHERE movie_title = 'Sex, Drink and Bloodshed'
SELECT
TARGET Token.Literal.String.Single
True
