In [24]:
import json
import sqlparse
from sqlparse.sql import IdentifierList, Identifier
from sqlparse.tokens import Keyword, DML
import random

# Load the JSON files
with open('train.json') as f:
	train_data = json.load(f)

with open('train_tables.json') as f:
	train_tables = json.load(f)

# Function to check if a parsed token is a subselect
def is_subselect(parsed):
	if not parsed.is_group:
		return False
	for item in parsed.tokens:
		if item.ttype is DML and item.value.upper() == 'SELECT':
			return True
	return False

# Function to extract the FROM part of a SQL statement
def extract_from_part(parsed):
	from_seen = False
	for item in parsed.tokens:
		if from_seen:
			if is_subselect(item):
				for x in extract_from_part(item):
					yield x
			elif item.ttype is Keyword:
				return
			else:
				yield item
		elif item.ttype is Keyword and item.value.upper() == 'FROM':
			from_seen = True

# Function to extract table identifiers from a token stream
def extract_table_identifiers(token_stream):
	for token in token_stream:
		if isinstance(token, IdentifierList):
			for identifier in token.get_identifiers():
				yield identifier.get_real_name()
		elif isinstance(token, Identifier):
			yield token.get_real_name()

# Function to extract table names from a SQL query
def extract_tables(sql):
	parsed = sqlparse.parse(sql)
	for stmt in parsed:
		if stmt.get_type() == 'SELECT':
			stream = extract_from_part(stmt)
			return list(extract_table_identifiers(stream))

# Function to extract column names from a SQL query
def extract_column_names(sql):
	column_names = []
	parsed = sqlparse.parse(sql)
	for stmt in parsed:
		if stmt.get_type() == 'SELECT':
			for token in stmt.tokens:
				if isinstance(token, IdentifierList):
					for identifier in token.get_identifiers():
						column_names.append(identifier.get_real_name())
				elif isinstance(token, Identifier):
					column_names.append(token.get_real_name())
	return column_names

# Randomly select 20 SQL queries from the dataset
random_sql_queries = random.sample(train_data, 20)

# Extract and print tables and columns for each selected query
for i, data in enumerate(random_sql_queries):
	sql_query = data["SQL"]
	tables = extract_tables(sql_query)
	columns = extract_column_names(sql_query)
	print(f"SQL Query {i+1}:")
	print(sql_query)
	print("Tables:", tables)
	print("Columns:", columns)
	print("-" * 50)


SQL Query 1:
SELECT SUM(T2.Value) FROM Country AS T1 INNER JOIN Indicators AS T2 ON T1.CountryCode = T2.CountryCode WHERE T1.IncomeGroup LIKE '%middle income' AND T2.Year = 1960 AND T2.IndicatorName = 'Urban population'
Tables: ['Country']
Columns: ['Country', 'Indicators']
--------------------------------------------------
SQL Query 2:
SELECT `Sales Channel` FROM `Sales Orders` WHERE OrderDate LIKE '1/%/20' GROUP BY `Sales Channel` ORDER BY COUNT(`Sales Channel`) DESC LIMIT 1
Tables: ['Sales Orders']
Columns: ['Sales Channel', 'Sales Orders', 'Sales Channel']
--------------------------------------------------
SQL Query 3:
SELECT T3.`Sales Team`, T1.`City Name` FROM `Store Locations` AS T1 INNER JOIN `Sales Orders` AS T2 ON T2._StoreID = T1.StoreID INNER JOIN `Sales Team` AS T3 ON T3.SalesTeamID = T2._SalesTeamID WHERE T2.OrderNumber = 'SO - 0001004'
Tables: ['Store Locations']
Columns: ['Sales Team', 'City Name', 'Store Locations', 'Sales Orders', 'Sales Team']
-----------------------

In [88]:
import json
import sqlparse
from sqlparse.sql import IdentifierList, Identifier, Comparison, Where, Function, Parenthesis
from sqlparse.tokens import Keyword, DML, Name, String, Number, Whitespace
import re

# Load the JSON files
with open('train.json') as f:
	train_data = json.load(f)

with open('train_tables.json') as f:
	train_tables = json.load(f)

# List of SQL keywords
keywords_main_body = {'SELECT', 'FROM', 'WHERE', 'AND', 'OR', 'NOT', 'IN', 'EXISTS', 'IS', 'NULL', 'IIF', 'CASE', 'WHEN'}
keywords_join = {'INNER JOIN', 'LEFT JOIN', 'ON', 'AS'}
keywords_clause = {'BETWEEN', 'LIKE', 'LIMIT', 'ORDER BY', 'ASC', 'DESC', 'GROUP BY', 'HAVING', 'UNION', 'ALL', 'EXCEPT', 'PARTITION BY', 'OVER'}
keywords_aggregation = {'AVG', 'COUNT', 'MAX', 'MIN', 'ROUND', 'SUM'}
keywords_scalar = {'ABS', 'LENGTH', 'STRFTIME', 'JULIADAY', 'NOW', 'CAST', 'SUBSTR', 'INSTR'}
keywords_comparison = {'=', '>', '<', '>=', '<=', '!=', '<>'}
keywords_computing = {'-', '+', '*', '/'}

all_keywords = keywords_main_body | keywords_join | keywords_clause | keywords_aggregation | keywords_scalar | keywords_comparison | keywords_computing

def get_id(entity_type):
	if entity_type == 'Alias':
		return '<extra_id_0>'
	elif entity_type == 'Table':
		return '<extra_id_1>'
	elif entity_type == 'Column':
		return '<extra_id_2>'
	elif entity_type == 'Value':
		return '<extra_id_3>'
	
# Function to extract entities and process the SQL query
def extract_entities(sql):
	parsed = sqlparse.parse(sql)[0]
	entities = []
	processed_tokens = []

	def process_token(token):
		print("Processing token:",token)
		nonlocal processed_tokens

		if token.ttype in (Whitespace, Keyword.DML, Keyword):
			print(token,"is a whitespace/DML/Keyword, appending to result\n")
			processed_tokens.append(token.value)
		elif isinstance(token, Identifier):
			print(token,"is an Identifier (type=",token.ttype,")")
			if token.ttype in (Name, String, Number):
				print(token,"is a Name/String/Number")
				entity_value = token.value
				entity_type = 'Value'
				entity_id = get_id(entity_type)
				entities.append((entity_value, entity_type, entity_id))
				processed_tokens.append(entity_id)
			elif re.search(r'T\d+\..*', token.value):
				print(token,"is an Alias/Column pair T*.*")
				# match the T*.* pattern
				entitiy_split = token.value.split(".")
				entity_1_name = entitiy_split[0]
				entity_2_name = entitiy_split[1]
				entity_1_type = 'Alias'
				entity_2_type = 'Column'
				entity_1_id = get_id(entity_1_type)
				entity_2_id = get_id(entity_2_type)
				entities.append((entity_1_name, entity_1_type, entity_1_id))
				entities.append((entity_2_name, entity_2_type, entity_2_id))
				processed_tokens.append(entity_1_id)
				processed_tokens.append(".")
				processed_tokens.append(entity_2_id)
			else:
				print(token,"is a Column")
				entity_value = token.value
				entity_type = 'Column'
				entity_id = get_id(entity_type)
				entities.append((entity_value, entity_type, entity_id))
				processed_tokens.append(entity_id)
		elif isinstance(token, IdentifierList):
			print(token,"is in IdentifierList")
			for identifier in token.get_identifiers():
				process_token(identifier)
		elif isinstance(token, Comparison):
			print(token,"is a comparison token")
			for sub_token in token.tokens:
				process_token(sub_token)
		elif isinstance(token, (Function, Parenthesis, Where)):
			print(token,"is a function/parenthesis/Where")
			for sub_token in token.tokens:
				process_token(sub_token)
		else:
			print(token,"is unknown. Appending to result")
			processed_tokens.append(token.value)

	for token in parsed.tokens:
		process_token(token)
	# Combine Tokens
	processed_sql = ''.join(processed_tokens)
	return entities, processed_sql

# Randomly select 20 SQL queries from the dataset
sample = [train_data[i] for i in [0]]

# Extract entity and print results
for i, data in enumerate(sample):
	sql_query = data["SQL"]
	print(f"SQL Query {i+1}:")
	print(sql_query+"\n")
	entities, processed_sql = extract_entities(sql_query)

	# Entity Table
	headers = ["Entity", "Type", "ID"]
	print(f"{headers[0]:<20} {headers[1]:<20} {headers[2]:<20}")
	# Sort
	entities = sorted(entities, key=lambda x: x[2])
	for entity in entities:
		print(f"{entity[0]:<20} {entity[1]:<20} {entity[2]:<20}")
	# Masked SQL
	print("\nMasked SQL:")
	print(processed_sql)
	print("-" * 50)


SQL Query 1:
SELECT movie_title FROM movies WHERE movie_release_year = 1945 ORDER BY movie_popularity DESC LIMIT 1

Processing token: SELECT
SELECT is a whitespace/DML/Keyword, appending to result

Processing token:  
  is a whitespace/DML/Keyword, appending to result

Processing token: movie_title
movie_title is an Identifier (type= None )
movie_title is a Column
Processing token:  
  is a whitespace/DML/Keyword, appending to result

Processing token: FROM
FROM is a whitespace/DML/Keyword, appending to result

Processing token:  
  is a whitespace/DML/Keyword, appending to result

Processing token: movies
movies is an Identifier (type= None )
movies is a Column
Processing token:  
  is a whitespace/DML/Keyword, appending to result

Processing token: WHERE movie_release_year = 1945 
WHERE movie_release_year = 1945  is a function/parenthesis/Where
Processing token: WHERE
WHERE is a whitespace/DML/Keyword, appending to result

Processing token:  
  is a whitespace/DML/Keyword, appending 

In [76]:
with open('train.json') as f:
	train_data = json.load(f)
	
sql=train_data[15]["SQL"]
print(sql)
parsed = sqlparse.parse(sql)[0]

for token in parsed.tokens:
	print("token",token,"type",token.ttype,)
a1=parsed.tokens[8].tokens[2].tokens[4].ttype
a2=sqlparse.tokens.String
print("TARGET",parsed.tokens[8].tokens[2].tokens[4].ttype)
print(a1 in a2)

SELECT director_name FROM movies WHERE movie_title = 'Sex, Drink and Bloodshed'
token SELECT type Token.Keyword.DML
token   type Token.Text.Whitespace
token director_name type None
token   type Token.Text.Whitespace
token FROM type Token.Keyword
token   type Token.Text.Whitespace
token movies type None
token   type Token.Text.Whitespace
token WHERE movie_title = 'Sex, Drink and Bloodshed' type None
TARGET Token.Literal.String.Single
True
