In [43]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import lightning as L

import tiktoken

In [29]:
# Tokenizer: look up the tiktoken encoding registered under the name
# "p50k_base". Its vocabulary size is read later via `encodings.n_vocab`.
encodings = tiktoken.get_encoding("p50k_base")

In [51]:
# CONFIG
# Model hyperparameters. The names line up with the SQLGen constructor
# arguments (presumably heads -> num_heads, stacks -> num_stacks; confirm at
# the call site).

embed_dim = 256        # width of token / position embeddings
max_seq_len = 512      # number of rows in the position-embedding table
heads = 4              # attention heads per layer
stacks = 4             # number of stacked layers
vocab_size = encodings.n_vocab  # taken from the tokenizer loaded above

In [57]:
class SQLGen(L.LightningModule):
    """Stack of causal multi-head self-attention layers over token embeddings.

    Each of the ``num_stacks`` layers applies, in order: a layer norm, masked
    multi-head self-attention followed by an output linear layer, a residual
    add, a second layer norm, a position-wise feed-forward network and a
    second residual add.
    """

    def __init__(self, vocab_size, embed_dim, max_seq_len, num_heads, num_stacks):
        """Create the embeddings and the per-layer sub-modules.

        Args:
            vocab_size: Number of rows in the token-embedding table and width
                of ``output_projection``.
            embed_dim: Embedding / hidden width. Must be divisible by
                ``num_heads``.
            max_seq_len: Number of rows in the position-embedding table, i.e.
                the longest sequence ``forward`` accepts.
            num_heads: Number of attention heads per layer.
            num_stacks: Number of stacked layers.

        Raises:
            ValueError: If ``embed_dim`` is not divisible by ``num_heads``.
        """
        super().__init__()

        # The per-head view in forward() needs embed_dim == num_heads * head_dim.
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
            )

        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.token_embeddings = nn.Embedding(vocab_size, embed_dim)
        self.position_embeddings = nn.Embedding(max_seq_len, embed_dim)

        # One Q/K/V projection per layer; heads are split by reshaping in forward().
        self.query = nn.ModuleList([nn.Linear(embed_dim, embed_dim) for _ in range(num_stacks)])
        self.key = nn.ModuleList([nn.Linear(embed_dim, embed_dim) for _ in range(num_stacks)])
        self.value = nn.ModuleList([nn.Linear(embed_dim, embed_dim) for _ in range(num_stacks)])

        # Linear layer applied to the re-merged heads of each layer.
        self.final_attn_linear = nn.ModuleList([nn.Linear(embed_dim, embed_dim) for _ in range(num_stacks)])

        # Position-wise feed-forward network with a 4x hidden expansion.
        self.pt_wise_ffn = nn.ModuleList([nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4),
            nn.ReLU(),
            nn.Linear(embed_dim * 4, embed_dim)
        ) for _ in range(num_stacks)])

        self.layer_norm_1 = nn.ModuleList([nn.LayerNorm(embed_dim) for _ in range(num_stacks)])
        self.layer_norm_2 = nn.ModuleList([nn.LayerNorm(embed_dim) for _ in range(num_stacks)])

        # NOTE(review): created here but not applied in forward(); see the note there.
        self.output_projection = nn.Linear(embed_dim, vocab_size)

    def _split_heads(self, t: torch.Tensor) -> torch.Tensor:
        """Reshape (batch, seq, embed_dim) to (batch, num_heads, seq, head_dim)."""
        batch_size, seq_length = t.size(0), t.size(1)
        return t.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the token ids through every layer and return the hidden states.

        Args:
            x: Integer token ids indexed as (batch, seq_len); every id must be
                a valid row of the token-embedding table.

        Returns:
            Tensor of shape (batch, seq_len, embed_dim) produced by the last
            layer.

        Raises:
            ValueError: If ``seq_len`` exceeds the size of the
                position-embedding table.
        """
        batch_size, seq_length = x.size(0), x.size(1)

        max_positions = self.position_embeddings.num_embeddings
        if seq_length > max_positions:
            raise ValueError(
                f"sequence length {seq_length} exceeds max_seq_len {max_positions}"
            )

        token_embeds = self.token_embeddings(x)

        # Shape (1, seq_len): broadcasts over the batch when added below.
        positions = torch.arange(seq_length, device=x.device).unsqueeze(0)
        position_embeds = self.position_embeddings(positions)

        x = token_embeds + position_embeds

        # Causal mask, True strictly above the diagonal (future positions).
        # It must be seq_len x seq_len and is identical for every layer, so it
        # is built once here rather than inside the loop.
        future_mask = torch.triu(
            torch.ones(seq_length, seq_length, dtype=torch.bool, device=x.device),
            diagonal=1,
        )
        scale = self.head_dim ** 0.5

        layers = zip(
            self.query,
            self.key,
            self.value,
            self.final_attn_linear,
            self.pt_wise_ffn,
            self.layer_norm_1,
            self.layer_norm_2,
        )
        for q_proj, k_proj, v_proj, attn_out, ffn, norm_1, norm_2 in layers:
            # The residual is taken from the normalized activations, as in the
            # original statement order.
            x = norm_1(x)
            residual = x

            Q = self._split_heads(q_proj(x))
            K = self._split_heads(k_proj(x))
            V = self._split_heads(v_proj(x))

            # Scaled dot-product scores: (batch, heads, seq_len, seq_len).
            scores = torch.matmul(Q, K.transpose(-2, -1)) / scale
            # The 2-D mask broadcasts over batch and heads. The diagonal is
            # never masked, so each softmax row keeps at least one finite entry.
            scores = scores.masked_fill(future_mask, float('-inf'))

            attn_weights = F.softmax(scores, dim=-1)

            # Weighted sum of values, then merge heads back to (batch, seq, embed_dim).
            x = torch.matmul(attn_weights, V)
            x = x.transpose(1, 2).contiguous().view(batch_size, seq_length, -1)
            x = attn_out(x)

            x = x + residual
            x = norm_2(x)

            residual = x
            x = ffn(x)
            x = x + residual

        # NOTE(review): self.output_projection is never applied, so this returns
        # embed_dim-wide hidden states rather than vocab-sized logits. Confirm
        # whether the caller is expected to project.
        return x