In [118]:
import json
import sqlparse
from sqlparse.sql import IdentifierList, Identifier, Comparison, Where, Function, Parenthesis
from sqlparse.tokens import Keyword,Whitespace,Operator, Punctuation
import re
# Load the JSON files
with open('train.json') as f:
	train_data = json.load(f)

with open('train_tables.json') as f:
	train_tables = json.load(f)

In [119]:
# List of SQL keywords
keywords_main_body = {'SELECT', 'FROM', 'WHERE', 'AND', 'OR', 'NOT', 'IN', 'EXISTS', 'IS', 'NULL', 'IIF', 'CASE', 'WHEN'}
keywords_join = {'INNER JOIN', 'LEFT JOIN', 'ON', 'AS'}
keywords_clause = {'BETWEEN', 'LIKE', 'LIMIT', 'ORDER BY', 'ASC', 'DESC', 'GROUP BY', 'HAVING', 'UNION', 'ALL', 'EXCEPT', 'PARTITION BY', 'OVER'}
keywords_aggregation = {'AVG', 'COUNT', 'MAX', 'MIN', 'ROUND', 'SUM'}
keywords_scalar = {'ABS', 'LENGTH', 'STRFTIME', 'JULIADAY', 'NOW', 'CAST', 'SUBSTR', 'INSTR'}
keywords_comparison = {'=', '>', '<', '>=', '<=', '!=', '<>'}
keywords_computing = {'-', '+', '*', '/'}

all_keywords = keywords_main_body | keywords_join | keywords_clause | keywords_aggregation | keywords_scalar | keywords_comparison | keywords_computing

def get_id(entity_type):
	if entity_type == 'Alias':
		return '<extra_id_0>'
	elif entity_type == 'Table':
		return '<extra_id_1>'
	elif entity_type == 'Column':
		return '<extra_id_2>'
	elif entity_type == 'Value':
		return '<extra_id_3>'

In [120]:
def extract_entities(sql):
	# Parsed tokens from sqlParse
	parsed = sqlparse.parse(sql)[0]
	# Entity table
	entities = []
	# Masked SQL Query
	processed_tokens = []

	# Function that process individual token
	def process_token(token):
		print("Processing token: '"+str(token)+"' ("+str(token.ttype)+")")
		nonlocal processed_tokens
		# Check for Whitespace or Keyword (SELECT)
		if token.ttype in (Whitespace, Keyword, Operator, Punctuation):
			processed_tokens.append(token.value)
		elif token.value.upper() in all_keywords:
			print(token,"is a keyword")
			processed_tokens.append(token.value)
		# Check for Identifier
		elif isinstance(token, Identifier):
			print(token,"is an Identifier")
			if re.search(r'T\d+\..*', token.value):
				print(token,"is an Alias/Column pair T*.*")
				# match the T*.* pattern
				entitiy_split = token.value.split(".")
				entity_1_value = entitiy_split[0]
				entity_2_value = entitiy_split[1]
				entity_1_type = 'Alias'
				entity_2_type = 'Column'
				entity_1_id = get_id(entity_1_type)
				entity_2_id = get_id(entity_2_type)
				entities.append((entity_1_value, entity_1_type, entity_1_id))
				entities.append((entity_2_value, entity_2_type, entity_2_id))
				processed_tokens.append(entity_1_id+"."+entity_2_id)

			# Check for Entity DESC pair because sqlparse will mistakenly parse them together
			elif re.search(r'\s+DESC$', token.value, re.IGNORECASE):
				# Spliting Entity DESC 
				split = token.value.rsplit(' ', 1)
				entity_value = split[0]
				identifier = split[1]
				print(token,"is a DESC pair, splitting and appending",entity_value,"as a Column")
				entity_type = 'Column'
				entity_id = get_id(entity_type)
				entities.append((entity_value, entity_type, entity_id))
				processed_tokens.append(entity_id+' DESC')
			elif re.search(r'\sAS\sT\d', token.value):
				entitiy_split = token.value.rsplit(' ', 2)
				entity_1_value = entitiy_split[0]
				entity_2_value = entitiy_split[2]
				entity_1_type = 'Column'
				entity_2_type = 'Alias'
				entity_1_id = get_id(entity_1_type)
				entity_2_id = get_id(entity_2_type)
				print(token,"is an \"AS T pair\", splitting and appending",entity_1_value,"as a Column")
				entities.append((entity_1_value, entity_1_type, entity_1_id))
				#entities.append((entity_2_value, entity_2_type, entity_2_id))
				processed_tokens.append(entity_1_id+" AS "+entity_2_id)

			else:
				print(token,"is a Column or Table")
				entity_value = token.value
				entity_type = 'Column'
				entity_id = get_id(entity_type)
				entities.append((entity_value, entity_type, entity_id))
				processed_tokens.append(entity_id)
		# Expand on these parent tokens 
		elif isinstance(token, (Function, Parenthesis, Where, Comparison,IdentifierList)):
			print(token,"is a function/parenthesis/Where")
			for sub_token in token.tokens:
				process_token(sub_token)
		else:
			print(token,"is a Name/String/Number")
			entity_value = token.value
			entity_type = 'Value'
			entity_id = get_id(entity_type)
			entities.append((entity_value, entity_type, entity_id))
			processed_tokens.append(entity_id)

	for token in parsed.tokens:
		process_token(token)
		print("-" * 50)

	# Combine Tokens
	processed_sql = ''.join(processed_tokens)
	return entities, processed_sql

In [121]:
sample = [train_data[i] for i in [0,5,10]]

# Extract entity and print results
for i, data in enumerate(sample):
	sql_query = data["SQL"]
	print(f"SQL Query {i+1}:")
	print(sql_query+"\n")
	entities, processed_sql = extract_entities(sql_query)

	# Entity Table
	headers = ["Entity", "Type", "ID"]
	print(f"{headers[0]:<20} {headers[1]:<20} {headers[2]:<20}")
	# Sort
	entities = sorted(entities, key=lambda x: x[2])
	for entity in entities:
		print(f"{entity[0]:<20} {entity[1]:<20} {entity[2]:<20}")
	# Masked SQL
	print("\nInitial SQL:",sql_query)
	print("\nMasked SQL:",processed_sql)
	print("=" * 50)
	print("=" * 50)
	

SQL Query 1:
SELECT movie_title FROM movies WHERE movie_release_year = 1945 ORDER BY movie_popularity DESC LIMIT 1

Processing token: 'SELECT' (Token.Keyword.DML)
SELECT is a keyword
--------------------------------------------------
Processing token: ' ' (Token.Text.Whitespace)
--------------------------------------------------
Processing token: 'movie_title' (None)
movie_title is an Identifier
movie_title is a Column or Table
--------------------------------------------------
Processing token: ' ' (Token.Text.Whitespace)
--------------------------------------------------
Processing token: 'FROM' (Token.Keyword)
--------------------------------------------------
Processing token: ' ' (Token.Text.Whitespace)
--------------------------------------------------
Processing token: 'movies' (None)
movies is an Identifier
movies is a Column or Table
--------------------------------------------------
Processing token: ' ' (Token.Text.Whitespace)
------------------------------------------------

In [122]:
with open('train.json') as f:
	train_data = json.load(f)
	
sql=train_data[15]["SQL"]
print(sql)
parsed = sqlparse.parse(sql)[0]
print(parsed.token_first())

a1=parsed.tokens[8].tokens[2].tokens[4].ttype
a2=sqlparse.tokens.String
print("TARGET",parsed.tokens[8].tokens[2].tokens[4].ttype)
print(a1 in a2)

SELECT director_name FROM movies WHERE movie_title = 'Sex, Drink and Bloodshed'
SELECT
TARGET Token.Literal.String.Single
True
