In [1]:
import pickle
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

from Environment.SQL.SQL import SQL
from Environment.Tokens_Actions.Basic_Block.KeywordRepresentation_Token import (
    KeywordRepresentation_Token,
)
from Environment.Tokens_Actions.Basic_Block.Comma_Token import Comma_Token
from Environment.Tokens_Actions.Basic_Block.Comment_Token import Comment_Token
from Environment.Tokens_Actions.Basic_Block.Paranthesis_Token import Paranthesis_Token
from Environment.Tokens_Actions.Basic_Block.Quote_Token import Quote_Token
from Environment.Tokens_Actions.Basic_Block.Whitespace_Token import Whitespace_Token

In [2]:
generic_stmt_path = "generic_synatx_statment.dataset"

with open(generic_stmt_path, "rb") as f:
    generic_sqls = pickle.load(f)

In [3]:
keywords = KeywordRepresentation_Token.reserved_keywords
# operators
operators = ["op"]
# comma
comma = Comma_Token.comma_types
# comment
comments = list(Comment_Token.comment_type_mapping.values())
# hex
hex = ["hex"]
# string
string = ["str"]
# id
id = ["id"]
# number
number = ["num"]
# parenthesis
paranthesis = list(Paranthesis_Token.paranthesis_type_mapping.values())
# quotes
quotes = list(Quote_Token.quote_type_mapping.values())
# whitespace
whitespace = list(Whitespace_Token.whitespace_type_mapping.values())
# padding
pad_token = ["pad"]
# starting and ending token
start_end_token = ["sos", "eos"]

tokens_embedding = [
    *pad_token,
    *keywords,
    *operators,
    *comma,
    *comments,
    *hex,
    *string,
    *id,
    *number,
    *paranthesis,
    *quotes,
    *whitespace,
    *start_end_token,
]
token_to_idx = {c: i for i, c in enumerate(tokens_embedding)}
idx_to_token = {i: c for i, c in enumerate(tokens_embedding)}



def tokenizer(sql: SQL):
    result = []

    for current_token in sql.get_tokens().flat_idx_tokens_list():
        result.append(
            token_to_idx[str(current_token).casefold()]
        )
    return result

In [4]:
# Hyperparameters
MAX_SEQ_LEN = 41    # Maximum length of SQL queries
EMBEDDING_DIM = 128  # Embedding dimension
HIDDEN_DIM = 1024     # Hidden layer size in the Transformer
NUM_HEADS = 4        # Number of attention heads
NUM_LAYERS = 2       # Number of Transformer layers
DROPOUT = 0.1        # Dropout rate
EPOCHS = 10          # Number of epochs
BATCH_SIZE = 64      # Batch size
LEARNING_RATE = 1e-4 # Learning rate

In [5]:
class SQLDataset(Dataset):
    def __init__(self, queries, tokenizer, max_len=MAX_SEQ_LEN):
        self.queries = queries
        self.tokenizer = tokenizer
        self.max_len = max_len

        embeddings = []
        for query in queries:
            tokenized_query = self.tokenizer(query)
            tokenized_query = tokenized_query[:self.max_len]
            padding = [0] * (self.max_len - len(tokenized_query))
            input_sequence = torch.tensor(tokenized_query + padding)
            embeddings.append(input_sequence)

        self.embeddings = torch.stack(embeddings)

    def __len__(self):
        return len(self.queries)

    def __getitem__(self, idx):
        input_sequence = self.embeddings[idx]
        return input_sequence, input_sequence

In [6]:
class TransformerAutoencoder(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_heads, num_layers, max_len, dropout):
        super(TransformerAutoencoder, self).__init__()

        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.positional_encoding = self._generate_positional_encoding(max_len, embedding_dim)
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(embedding_dim, num_heads, hidden_dim, dropout),
            num_layers
        )
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(embedding_dim, num_heads, hidden_dim, dropout),
            num_layers
        )
        self.output_layer = nn.Linear(embedding_dim, vocab_size)

    def _generate_positional_encoding(self, max_len, embedding_dim):
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embedding_dim, 2) * (-torch.log(torch.tensor(10000.0)) / embedding_dim))
        pe = torch.zeros(max_len, embedding_dim)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        return pe.unsqueeze(0)

    def forward(self, src, tgt):
        src_embedding = self.embedding(src) + self.positional_encoding[:, :src.size(1), :]
        tgt_embedding = self.embedding(tgt) + self.positional_encoding[:, :tgt.size(1), :]

        encoded = self.encoder(src_embedding)
        decoded = self.decoder(tgt_embedding, encoded)

        return self.output_layer(decoded)

In [7]:
epoch_losses = []
def train_model(model, dataloader, criterion, optimizer, epochs):
    for epoch in range(epochs):
        model.train()
        epoch_loss = 0
        pbar = tqdm(dataloader, desc=f'Epoch {epoch + 1}:')
        for input_seq, target_seq in pbar:
            optimizer.zero_grad()
            output = model(input_seq, input_seq)
            loss = criterion(output.view(-1, output.size(-1)), target_seq.view(-1))
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            pbar.set_postfix(loss=loss.item())
        pbar.set_postfix(epoch_loss=epoch_loss)
        epoch_losses.append(epoch_loss)

In [None]:
dataset = SQLDataset(generic_sqls, tokenizer)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# Initialize model, optimizer, and loss function
VOCAB_SIZE = len(tokens_embedding)  # Example vocabulary size (depends on your tokenizer)
model = TransformerAutoencoder(VOCAB_SIZE, EMBEDDING_DIM, HIDDEN_DIM, NUM_HEADS, NUM_LAYERS, MAX_SEQ_LEN, DROPOUT)

criterion = nn.CrossEntropyLoss(ignore_index=0)  # Ignore padding tokens in the loss
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Train the model
train_model(model, dataloader, criterion, optimizer, EPOCHS)

Epoch 1:: 100%|███████████████| 1563/1563 [05:23<00:00,  4.84it/s, loss=0.00912]
Epoch 2:: 100%|███████████████| 1563/1563 [05:27<00:00,  4.77it/s, loss=0.00218]
Epoch 3:: 100%|██████████████| 1563/1563 [05:30<00:00,  4.73it/s, loss=0.000789]
Epoch 4::  49%|███████▍       | 772/1563 [02:49<03:07,  4.22it/s, loss=0.000472]