In [43]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import lightning as L

import tiktoken

In [29]:
# Load tiktoken's 'p50k_base' encoding; its `n_vocab` is read further down.
encodings = tiktoken.get_encoding(encoding_name='p50k_base')

In [38]:
# Hyperparameters, presumably intended for the SQLGen model defined below.
embed_dim = 256                 # embedding width
max_seq_len = 512               # longest sequence length to support
vocab_size = encodings.n_vocab  # number of distinct token ids in the tokenizer

In [45]:
class SQLGen(L.LightningModule):
    """Token + learned position embeddings followed by one causal self-attention head.

    NOTE(review): the name suggests a SQL generator, but only the embedding
    tables and a single attention layer are defined here; there is no output
    projection back to ``vocab_size`` and no training step yet.

    Args:
        vocab_size: Number of rows in the token embedding table.
        embed_dim: Width of the embeddings and of the Q/K/V projections.
        max_seq_len: Number of learned positions; longer inputs are rejected.
    """

    def __init__(self, vocab_size, embed_dim, max_seq_len):
        super().__init__()

        # Kept so forward() can reject sequences the position table cannot index.
        self.max_seq_len = max_seq_len

        self.token_embeddings = nn.Embedding(vocab_size, embed_dim)
        self.position_embeddings = nn.Embedding(max_seq_len, embed_dim)

        self.query = nn.Linear(embed_dim, embed_dim)
        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        """Embed ``x`` and apply causal scaled dot-product self-attention.

        Args:
            x: Integer token ids of shape ``(batch, seq_len)``.

        Returns:
            Tensor of shape ``(batch, seq_len, embed_dim)`` holding the
            attention-weighted values.

        Raises:
            ValueError: If ``seq_len`` exceeds ``max_seq_len``.
        """
        seq_len = x.size(1)
        if seq_len > self.max_seq_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq_len={self.max_seq_len}"
            )

        token_embeds = self.token_embeddings(x)

        # Create positions on the input's device so the module also works on GPU.
        # The resulting (seq_len, embed_dim) tensor broadcasts over the batch dim.
        positions = torch.arange(seq_len, device=x.device)
        position_embeds = self.position_embeddings(positions)

        embeds = token_embeds + position_embeds

        Q = self.query(embeds)
        K = self.key(embeds)
        V = self.value(embeds)

        # Scale by sqrt(d_k) so the softmax does not saturate as embed_dim grows.
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (Q.size(-1) ** 0.5)

        # Causal mask: True strictly above the diagonal, i.e. for future positions.
        # The diagonal is never masked, so no softmax row is entirely -inf.
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device),
            diagonal=1,
        )
        scores = scores.masked_fill(causal_mask, float('-inf'))

        attn_weights = F.softmax(scores, dim=-1)
        return torch.matmul(attn_weights, V)


