In [1]:
import pandas as pd
import re
import math

In [2]:
# Load the SQL snippet dataset; the file is ';'-separated, uses ',' as the
# decimal mark and is read as UTF-8.
df = pd.read_csv('SQL_codes_dataset.csv', sep=';', decimal=',', encoding='utf-8')
# Preview the first rows (rendered as the cell output).
df.head()

Unnamed: 0,ID,TG,Student,Model,Variant,X,AXC
0,1,A,-1,ChatGPT5.2,0,1,CREATE OR REPLACE TABLE retail (\n invoice_...
1,2,A,-1,ChatGPT5.2,1,1,CREATE OR REPLACE TABLE retail (\n invoice_...
2,3,A,-1,ChatGPT5.2,2,1,CREATE OR REPLACE TABLE retail (\n invoice_...
3,4,A,-1,ChatGPT5.2,3,1,CREATE OR REPLACE TABLE retail (\n invoice_no...
4,5,B,-1,ChatGPT5.2,0,1,CREATE OR REPLACE TABLE games (\n id ...


In [3]:
# Keyword vocabularies used to classify tokens. Keywords are stored upper-case;
# the functions that consult these sets upper-case a token before the lookup.
SQL_OPERATORS = {
  'SELECT', 'FROM', 'WHERE', 'GROUP', 'BY', 'ORDER', 'HAVING', 'LIMIT', 'OFFSET',
  'DISTINCT', 'AS', 'ON', 'USING',
  'INSERT', 'UPDATE', 'DELETE', 'INTO', 'VALUES', 'SET',
  'CREATE', 'ALTER', 'DROP', 'TABLE', 'VIEW', 'INDEX',
}
SQL_JOIN_KEYWORDS = {'JOIN', 'INNER', 'LEFT', 'RIGHT', 'FULL', 'OUTER'} # not present in the examined dataset
SQL_LOGICAL_OPERATORS = {
  'AND', 'OR', 'NOT', 'IN', 'EXISTS', 'BETWEEN', 'LIKE', 'ILIKE', 'IS',
  'NULL', 'TRUE', 'FALSE', 'ALL', 'ANY', 'SOME',
}
SQL_COMPARISON_OPERATORS = {'=', '<>', '!=', '<', '>', '<=', '>='}
SQL_CONDITIONAL_KEYWORDS = {'CASE', 'WHEN', 'THEN', 'ELSE', 'END', 'IF', 'IFF'}
SQL_WINDOW_KEYWORDS = {
  'OVER', 'PARTITION', 'ROWS', 'RANGE', 'UNBOUNDED', 'PRECEDING', 'FOLLOWING',
  'CURRENT', 'ROW',
}
SQL_AGGREGATE_FUNCTIONS = {'SUM', 'COUNT', 'AVG', 'MIN', 'MAX', 'ARRAY_AGG', 'MEDIAN'}
SQL_WINDOW_FUNCTIONS = {'ROW_NUMBER', 'RANK', 'DENSE_RANK', 'NTILE', 'LAG', 'LEAD'}

# Union of everything that counts as an operator when tokens are divided into
# operators and operands. tokenize() emits '::' and '||' as single symbol
# tokens, so they are listed with the other punctuation; otherwise they would
# fall through to the operand side.
ALL_SQL_OPERATORS = (
  SQL_OPERATORS
  | SQL_JOIN_KEYWORDS
  | SQL_LOGICAL_OPERATORS
  | SQL_CONDITIONAL_KEYWORDS
  | SQL_WINDOW_KEYWORDS
  | SQL_AGGREGATE_FUNCTIONS
  | SQL_WINDOW_FUNCTIONS
  | SQL_COMPARISON_OPERATORS
  | {'(', ')', ',', '.', ';', '*', '+', '-', '/', '%', '::', '||'}
)

# Weights for different SQL constructs
SQL_COMMAND_WEIGHTS = {
  'SELECT': 1.0,
  'JOIN': 2.5, # not present in the examined dataset
  'SUBQUERY': 3.0,
  'WINDOW': 2.5,
  'AGGREGATE': 1.5,
  'CASE': 2.0, # not present in the examined dataset
  'CTE': 2.0,
  'UNION': 1.5 # not present in the examined dataset
}

In [4]:
# One left-to-right scan over quoted spans and comments. Quoted spans are
# captured (group 1) so they can be put back untouched; comments are not.
_STRING_OR_COMMENT_RE = re.compile(
  r"('[^']*'|\"[^\"]*\")"  # group 1: single- or double-quoted span, kept verbatim
  r"|--[^\n]*"             # line comment, up to (not including) the newline
  r"|/\*.*?\*/",           # block comment, may span several lines
  re.DOTALL,
)

def remove_comments(sql):
  """Return `sql` with `-- ...` line comments and `/* ... */` block comments removed.

  Comment markers that appear inside a quoted span ('...' or "...") are left
  alone, so a literal such as '--x' or '/* y */' is not truncated. Removed
  comments are replaced by the empty string; the newline ending a line
  comment is kept.
  """
  # Quoted spans match first at their position and are returned unchanged;
  # for comment matches group 1 is None, so they are replaced by ''.
  return _STRING_OR_COMMENT_RE.sub(lambda match: match.group(1) or '', sql)

def tokenize(sql):
  """Split `sql` into a flat list of word, symbol and quoted-literal tokens.

  Comments are stripped first. Quoted spans ('...' or "...") are masked with
  a placeholder so their content is never split, the remaining text is cut at
  operator/punctuation symbols and whitespace, and finally the placeholders
  are swapped back for the original quoted text, in order of appearance.
  """
  cleaned = remove_comments(sql)
  string_pattern = re.compile(r"'[^']*'|\"[^\"]*\"")
  symbol_pattern = re.compile(r'<>|!=|<=|>=|::|\|\||[(),.*+\-/%;=<>]')

  # Remember the literals, then hide them behind a whitespace-padded marker.
  literals = iter(string_pattern.findall(cleaned))
  masked = string_pattern.sub(' __STRING__ ', cleaned)

  # Walk the symbol matches: words before each symbol, then the symbol itself.
  pieces = []
  cursor = 0
  for hit in symbol_pattern.finditer(masked):
    pieces.extend(masked[cursor:hit.start()].split())
    pieces.append(hit.group())
    cursor = hit.end()
  pieces.extend(masked[cursor:].split())

  # Put the literals back; a marker with no literal left is kept as-is.
  restored = []
  for piece in pieces:
    if piece == '__STRING__':
      piece = next(literals, piece)
    restored.append(piece)

  return [piece for piece in restored if piece.strip()]

In [5]:
def size_metrics(sql):
  """Return the basic size measures of `sql`.

  Keys: 'lines_of_code' (number of newline-separated lines),
  'character_length' (len of the raw text) and 'token_count'
  (number of tokens produced by tokenize).
  """
  # Splitting on '\n' yields one more piece than there are newlines.
  line_total = sql.count('\n') + 1
  token_total = len(tokenize(sql))
  return {
    'lines_of_code': line_total,
    'character_length': len(sql),
    'token_count': token_total
  }

In [6]:
def divide_tokens(tokens):
  """Classify `tokens` into operators and operands.

  A token is an operator when it (or its upper-cased form) is in
  ALL_SQL_OPERATORS; operators are recorded upper-cased. Everything else is
  an operand: quoted tokens and plain numbers are recorded verbatim, other
  operands are lower-cased so that differently-cased spellings collapse.

  Returns a dict with the sets of unique operators/operands and the total
  number of each.
  """
  numeric_literal = re.compile(r'\b\d+\.?\d*\b')
  seen_operators = []
  seen_operands = []

  for tok in tokens:
    canonical = tok.upper()
    if canonical in ALL_SQL_OPERATORS:
      seen_operators.append(canonical)
      continue
    if tok in ALL_SQL_OPERATORS:
      seen_operators.append(tok)
      continue

    keep_verbatim = tok.startswith("'") or tok.startswith('"') or numeric_literal.fullmatch(tok)
    seen_operands.append(tok if keep_verbatim else tok.lower())

  return {
    'unique_operators': set(seen_operators),
    'unique_operands': set(seen_operands),
    'total_operators': len(seen_operators),
    'total_operands': len(seen_operands)
  }

In [7]:
def halstead_metrics(sql):
  """Compute Halstead-style measures for `sql` from its operator/operand counts.

  vocabulary = n1 + n2, program_length = N1 + N2,
  volume = program_length * log2(vocabulary)   (0 when vocabulary <= 1),
  difficulty = (n1 / 2) * (N2 / n2)            (0 when n1 or n2 is 0),
  effort = difficulty * volume, time_to_program = effort / 18.
  Float results are rounded to 2 decimals in the returned dict.
  """
  counts = divide_tokens(tokenize(sql))
  distinct_operators = len(counts['unique_operators'])
  distinct_operands = len(counts['unique_operands'])
  operator_total = counts['total_operators']
  operand_total = counts['total_operands']

  vocabulary = distinct_operators + distinct_operands
  program_length = operator_total + operand_total

  # log2 is only meaningful (and non-zero) for a vocabulary of at least 2.
  if vocabulary > 1:
    volume = program_length * math.log2(vocabulary)
  else:
    volume = 0

  # Guard against division by zero when there are no operands/operators.
  if distinct_operators > 0 and distinct_operands > 0:
    difficulty = (distinct_operators / 2) * (operand_total / distinct_operands)
  else:
    difficulty = 0

  effort = difficulty * volume
  # 18 is presumably the Stroud number from Halstead's time estimate.
  time_to_program = effort / 18

  return {
    'n1_distinct_operators': distinct_operators,
    'n2_distinct_operands': distinct_operands,
    'N1_total_operators': operator_total,
    'N2_total_operands': operand_total,
    'vocabulary': vocabulary,
    'program_length': program_length,
    'volume': round(volume, 2),
    'difficulty': round(difficulty, 2),
    'effort': round(effort, 2),
    'time_to_program': round(time_to_program, 2)
  }

In [8]:
def sql_constructs(sql):
  """Count selected SQL constructs in `sql` with case-insensitive regex matching.

  - subquery_count: occurrences of the SELECT keyword minus one (never below 0)
  - window_function_count: occurrences of OVER followed by '('
  - aggregate_function_count: calls of any name in SQL_AGGREGATE_FUNCTIONS
  - cte_count: occurrences of the WITH keyword
  """
  text = sql.upper()

  def occurrences(pattern):
    # Number of non-overlapping matches of `pattern` in the upper-cased text.
    return len(re.findall(pattern, text))

  aggregate_calls = 0
  for func in SQL_AGGREGATE_FUNCTIONS:
    aggregate_calls += occurrences(rf'\b{func}\s*\(')

  return {
    'subquery_count': max(0, occurrences(r'\bSELECT\b')-1),
    'window_function_count': occurrences(r'\bOVER\s*\('),
    'aggregate_function_count': aggregate_calls,
    'cte_count': occurrences(r'\bWITH\b')
  }

In [9]:
def nesting_depth(sql):
  """Return the deepest level of parenthesis nesting found in `sql`.

  Every '(' opens a level and every ')' closes one; a ')' with nothing open
  is ignored, so the running level never drops below zero.
  """
  deepest = 0
  level = 0
  for ch in sql:
    if ch == ')':
      if level > 0:
        level -= 1
    elif ch == '(':
      level += 1
      if level > deepest:
        deepest = level
  return deepest

In [10]:
def cognitive_complexity(sql):
  """Return a weighted complexity score for `sql`.

  Each '(' directly followed by SELECT adds 3, each AND or OR keyword adds 1,
  each OVER followed by '(' adds 2, and every level of parenthesis nesting
  beyond 2 adds another 2. Matching is case-insensitive.
  """
  text = sql.upper()
  weighted_patterns = (
    (r'\(\s*SELECT\b', 3),
    (r'\bAND\b', 1),
    (r'\bOR\b', 1),
    (r'\bOVER\s*\(', 2),
  )
  score = 0
  for pattern, weight in weighted_patterns:
    score += len(re.findall(pattern, text))*weight

  # Only nesting deeper than two levels is penalised.
  excess_depth = nesting_depth(sql)-2
  if excess_depth > 0:
    score += excess_depth*2
  return score

In [11]:
def cognitive_metrics(sql):
  """Bundle the nesting depth and the cognitive complexity score of `sql` in a dict."""
  depth = nesting_depth(sql)
  score = cognitive_complexity(sql)
  return {
    'max_nesting_depth': depth,
    'cognitive_complexity': score,
  }

In [12]:
def count_tables(sql):
  """Return how many FROM and JOIN keywords occur in `sql` (case-insensitive, whole words)."""
  text = sql.upper()
  keyword_hits = [len(re.findall(pattern, text)) for pattern in (r'\bFROM\b', r'\bJOIN\b')]
  return keyword_hits[0] + keyword_hits[1]

def count_columns(tokens):
  """Count the tokens that look like identifiers and are not known SQL keywords.

  A token qualifies when it consists of letters, digits and underscores, does
  not start with a digit, and its upper-cased form is in none of the keyword
  or function-name sets. Repeated identifiers are counted every time.
  """
  reserved_words = SQL_OPERATORS | SQL_JOIN_KEYWORDS | SQL_LOGICAL_OPERATORS | SQL_AGGREGATE_FUNCTIONS | SQL_WINDOW_FUNCTIONS | SQL_CONDITIONAL_KEYWORDS | SQL_WINDOW_KEYWORDS
  looks_like_identifier = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$').match
  return sum(
    1
    for tok in tokens
    if looks_like_identifier(tok) and tok.upper() not in reserved_words
  )

def count_expressions(sql):
  """Count comparison operators and logical/predicate keywords in `sql`.

  Comparison operators counted: =, <>, !=, <=, >=, <, >. Each occurrence is
  counted exactly once: the two-character operators are tried first, so
  '<=' is one operator and not also a '<' and a '='.
  Keywords counted (case-insensitive, whole words): AND, OR, NOT, LIKE,
  ILIKE, IN, BETWEEN, EXISTS.

  Returns the sum of both counts.
  """
  # Longest alternatives first, so a multi-character operator is consumed as
  # a single match instead of being re-counted through its characters.
  comparison_re = re.compile(r'<>|!=|<=|>=|=|<|>')
  keyword_re = re.compile(r'\b(?:AND|OR|NOT|LIKE|ILIKE|IN|BETWEEN|EXISTS)\b')

  comparison_total = len(comparison_re.findall(sql))
  keyword_total = len(keyword_re.findall(sql.upper()))
  return comparison_total + keyword_total

In [13]:
def sqlshare_complexity(sql):
  """Return a weighted structural complexity score for `sql`, rounded to 2 decimals.

  score = 0.12 * tables + 0.08 * columns + 0.002 * characters
          + 0.20 * operators + 0.15 * expressions
  where the terms come from count_tables, count_columns, len(sql),
  divide_tokens(...)['total_operators'] and count_expressions.
  """
  tokens = tokenize(sql)
  table_count = count_tables(sql)
  column_count = count_columns(tokens)
  operator_count = divide_tokens(tokens)['total_operators']
  character_length = len(sql)
  expression_count = count_expressions(sql)

  # Terms are added left to right in this fixed order.
  score = (
    0.12*table_count
    + 0.08*column_count
    + 0.002*character_length
    + 0.20*operator_count
    + 0.15*expression_count
  )
  return round(score, 2)

In [14]:
def weighted_complexity(metrics):
  """Return the weighted construct score for a metrics dict, rounded to 2 decimals.

  Starts from the SELECT weight (counted once) and adds, for each construct,
  its weight from SQL_COMMAND_WEIGHTS times the matching count in `metrics`;
  a count that is missing from `metrics` is treated as 0.
  """
  weighted_counts = (
    ('SUBQUERY', 'subquery_count'),
    ('WINDOW', 'window_function_count'),
    ('AGGREGATE', 'aggregate_function_count'),
    ('CTE', 'cte_count'),
  )
  score = SQL_COMMAND_WEIGHTS['SELECT']*1
  for construct, metric_key in weighted_counts:
    score = score + SQL_COMMAND_WEIGHTS[construct]*metrics.get(metric_key, 0)
  return round(score, 2)

In [15]:
def evaluate_query(sql):
  """Compute every metric for one SQL string and return them in a single dict.

  Returns an empty dict when `sql` is empty or contains only whitespace.
  Otherwise the dict holds the size, Halstead, construct and cognitive
  metrics plus 'sqlshare_complexity' and 'weighted_complexity'.
  """
  if not (sql and sql.strip()):
    return {}

  metrics = {
    **size_metrics(sql),
    **halstead_metrics(sql),
    **sql_constructs(sql),
    **cognitive_metrics(sql),
  }
  metrics['sqlshare_complexity'] = sqlshare_complexity(sql)
  # The weighted score reads the construct counts gathered above.
  metrics['weighted_complexity'] = weighted_complexity(metrics)
  return metrics

In [16]:
def analyze_dataset(df, col):
  """Evaluate the SQL text in column `col` of `df` and append the metrics as columns.

  Missing values (per pd.notna) get an empty metrics row. The returned frame
  is `df` with its index reset, concatenated column-wise with the metrics.
  """
  metric_rows = [
    evaluate_query(sql) if pd.notna(sql) else {}
    for sql in df[col]
  ]
  metrics_df = pd.DataFrame(metric_rows)
  # Reset the index so both frames line up on 0..n-1 for the column-wise concat.
  base = df.reset_index(drop=True)
  return pd.concat([base, metrics_df], axis=1)

In [17]:
# Compute all per-query metrics for the SQL text stored in the 'AXC' column.
results = analyze_dataset(df, 'AXC')

In [18]:
# Export the metrics table using the same ';' separator and ',' decimal mark as the input file.
# NOTE(review): the last cell of the notebook writes to this same path again,
# so this intermediate export ends up overwritten.
results.to_csv('metrics_dataset.csv', index=False, encoding='utf-8', sep=';', decimal=',')

# Abstract Syntax Tree (AST)

In [19]:
!pip install sqlglot -q

In [20]:
import sqlglot
from sqlglot import exp

In [21]:
def extract_ast_features(sql_code):
  """Parse `sql_code` with sqlglot (Snowflake dialect) and summarise the AST shape.

  Returns a dict with:
    - 'ast_node_count': number of nodes yielded by `ast.walk()`
    - 'ast_height': longest root-to-leaf path in edges, plus one when the
      leaf at its end has a non-empty `args` dict
    - 'ast_branching_factor': average number of child expressions per
      non-leaf node, rounded to 2 decimals (0.0 if there is none)
    - 'leaf_node_ratio': leaf nodes / visited nodes, rounded to 2 decimals

  If parsing raises, the error is printed and all four features are zero.
  """
  try:
    ast = sqlglot.parse_one(sql_code, dialect='snowflake')
  except Exception as e:
    # Best effort: a query that cannot be parsed must not abort the whole run.
    print(f"Parse error: {e}")
    return {
      'ast_node_count': 0,
      'ast_height': 0,
      'ast_branching_factor': 0.0,
      'leaf_node_ratio': 0.0
    }

  features = {
    'ast_node_count': len(list(ast.walk()))
  }

  total_nodes = 0
  leaf_nodes = 0
  non_leaf_nodes = 0
  total_children = 0
  height = 0

  # Single iterative pass over the tree. Each node is visited once, carrying
  # its depth, instead of re-walking every descendant from every ancestor
  # (which grows exponentially with the depth of the tree).
  pending = [(ast, 0)]
  while pending:
    node, depth = pending.pop()
    total_nodes += 1
    node_args = getattr(node, 'args', None)

    # Direct children: expression values, or expressions inside list values.
    children = []
    if node_args:
      for arg_value in node_args.values():
        if isinstance(arg_value, list):
          children.extend(item for item in arg_value if isinstance(item, exp.Expression))
        elif isinstance(arg_value, exp.Expression):
          children.append(arg_value)

    # A node with a non-empty args dict contributes one level of its own,
    # a node without args contributes none.
    height = max(height, depth + 1 if node_args else depth)

    if children:
      non_leaf_nodes += 1
      total_children += len(children)
      pending.extend((child, depth + 1) for child in children)
    else:
      leaf_nodes += 1

  features['ast_height'] = height
  features['ast_branching_factor'] = round(total_children / non_leaf_nodes, 2) if non_leaf_nodes > 0 else 0.0
  features['leaf_node_ratio'] = round(leaf_nodes / total_nodes, 2) if total_nodes > 0 else 0.0

  return features

In [None]:
# Extract the AST features for every SQL snippet in the 'AXC' column (one dict per row).
ast_features = results['AXC'].apply(extract_ast_features)
# Expand the per-row dicts into a DataFrame with one column per feature.
ast_features_df = pd.DataFrame(ast_features.tolist())

# Append the AST feature columns to the metrics table and export the combined
# dataset; this writes to the same path as the earlier metrics export.
# NOTE(review): the column-wise concat aligns on the row index, presumably the
# default 0..n-1 index on both frames; confirm.
final_dataset = pd.concat([results, ast_features_df], axis=1)
final_dataset.to_csv('metrics_dataset.csv', index=False, encoding='utf-8', sep=';', decimal=',')