In [1]:
import json
import re
import torch
import torch.nn as nn

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# read data from .json file

with open("plain_statement_5000.json") as f:
    json_data = json.load(f)

plain_sql = [item['sql'] for item in json_data]
plain_sql = [sql.lower() for sql in plain_sql]

In [3]:
# split data into tokens

pattern = r'[\s()\-,:;]'
string_literal_pattern = r"'([^']*)'"
placeholder = "<string>"

# replace content inside single quotes by <string>
plain_sql_ph = [re.sub(string_literal_pattern, placeholder, sql) for sql in plain_sql]

# split the statements with placeholder
tokenized_sql = [re.split(pattern, sql) for sql in plain_sql_ph]

# remove empty tokens
tokenized_sql = [token for token in tokenized_sql if token]

# replace numbers by placeholder
for sql in tokenized_sql:
    for i, token in enumerate(sql):
        # if re.match(r'^[\'\"].*[\'\"]$', token):  # Check if token is a string literal
        #     sql[i] = '<string>'
        if re.match(r'^[0-9]+(\.[0-9]+)?$', token):  # Check if token is a number
            sql[i] = '<number>'

# remove empty tokens
for i, sql in enumerate(tokenized_sql):
    tokenized_sql[i] = [token for token in tokenized_sql[i] if token]

In [4]:
# build the vocab
vocab_set = set()
for sql in tokenized_sql:
    vocab_set.update(sql)

vocab = {word: idx for idx, word in enumerate(vocab_set)}

In [5]:
print(vocab)

{'"lineitem"."l_tax"': 0, '"nation"."n_name"': 1, '"supplier"."s_nationkey"': 2, '<=': 3, 'count': 4, '"customer"."c_mktsegment"': 5, '"nation"."n_comment"': 6, '"orders"."o_orderpriority"': 7, '"orders"': 8, '"orders"."o_orderdate"': 9, '"part"."p_container"': 10, '"customer"."c_nationkey"': 11, '"partsupp"."ps_supplycost"': 12, '"lineitem"."l_orderkey"': 13, '+': 14, '"orders"."o_totalprice"': 15, '*': 16, '"lineitem"."l_commitdate"': 17, '"region"."r_regionkey"': 18, 'on': 19, '"lineitem"."l_receiptdate"': 20, 'sum': 21, '"supplier"."s_comment"': 22, '>=': 23, '"supplier"': 24, '"lineitem"."l_partkey"': 25, '"supplier"."s_phone"': 26, '"partsupp"."ps_suppkey"': 27, '"part"': 28, '"orders"."o_clerk"': 29, '"lineitem"': 30, '"orders"."o_shippriority"': 31, '"customer"."c_custkey"': 32, '"customer"': 33, '"nation"': 34, '=': 35, '"supplier"."s_suppkey"': 36, '"lineitem"."l_shipinstruct"': 37, '!=': 38, 'from': 39, 'avg': 40, '"partsupp"."ps_availqty"': 41, '"lineitem"."l_linestatus"': 

In [6]:
# convert tokens to indices for each sample
indices = [[vocab[token] for token in sql] for sql in tokenized_sql]

# define embedding layer
vocab_size = len(vocab)
embedding_dim = 10
embedding = nn.Embedding(vocab_size, embedding_dim)

# convert indices to PyTorch tensors
indices = [torch.LongTensor(index) for index in indices]
embedded_X = [embedding(index) for index in indices]
print(embedded_X)

In [None]:
# get the labels

label = [item['runtime_ms'] for item in json_data]