<a href="https://colab.research.google.com/github/andrew98450/StableGAN-LM/blob/main/StableGAN_LM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from tokenizers import Tokenizer, models, pre_tokenizers, decoders, trainers, processors
from transformers import GPT2TokenizerFast
import pandas as pd
import torch

In [2]:
# Install dependencies as needed:
# pip install kagglehub[pandas-datasets]
import kagglehub
from kagglehub import KaggleDatasetAdapter

# The Spider text-to-SQL question/query pairs ship as one CSV inside the
# Kaggle dataset. Further load options (sql_query, pandas_kwargs, ...) are
# described in the documentation:
# https://github.com/Kaggle/kagglehub/blob/main/README.md#kaggledatasetadapterpandas
file_path = "spider_text_sql.csv"

# Fetch the latest dataset version and read the CSV as a pandas DataFrame.
df = kagglehub.load_dataset(
    KaggleDatasetAdapter.PANDAS,
    "mohammadnouralawad/spider-text-sql",
    file_path,
)

print("First 5 records:", df.head())

  df = kagglehub.load_dataset(


First 5 records:                                           text_query  \
0  How many heads of the departments are older th...   
1  List the name, born state and age of the heads...   
2  List the creation year, name and budget of eac...   
3  What are the maximum and minimum budget of the...   
4  What is the average number of employees of the...   

                                         sql_command  
0         SELECT count(*) FROM head WHERE age  >  56  
1  SELECT name ,  born_state ,  age FROM head ORD...  
2  SELECT creation ,  name ,  budget_in_billions ...  
3  SELECT max(budget_in_billions) ,  min(budget_i...  
4  SELECT avg(num_employees) FROM department WHER...  


In [None]:
# Train ONE byte-level BPE tokenizer shared by the natural-language questions
# and the SQL commands, then save it for GPT2TokenizerFast to load below.
tokenizer = Tokenizer(models.BPE())

# Byte-level pre-tokenization / decoding so arbitrary text round-trips.
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=True)
tokenizer.decoder = decoders.ByteLevel()
tokenizer.post_processor = processors.ByteLevel(trim_offsets=True)

# Seed the vocabulary with every byte symbol so characters that happen not to
# occur in the training text are still representable, then learn merges up to
# 1000 tokens.
# NOTE(review): the models below hard-code vocab_size=1001 / padding_idx=1000,
# presumably because GPT2TokenizerFast appends its eos/pad token at id 1000 —
# confirm if vocab_size here is changed.
trainer = trainers.BpeTrainer(
    vocab_size=1000,
    initial_alphabet=pre_tokenizers.ByteLevel.alphabet())

# Train a single time on both columns: every train_from_iterator call re-feeds
# the trainer and rebuilds the model, so two sequential calls would keep only
# the vocabulary learned from the second (SQL) column.
corpus = pd.concat([df["text_query"], df["sql_command"]], ignore_index=True)
tokenizer.train_from_iterator(corpus, trainer=trainer)

# And save it
tokenizer.save("text2sql.json", pretty=True)

In [3]:
# Wrap the trained tokenizer file in a GPT-2 style fast tokenizer. Padding
# reuses the end-of-text token as the pad token.
tokenizer = GPT2TokenizerFast(tokenizer_file="text2sql.json")
tokenizer.pad_token = tokenizer.eos_token

# Questions (prompt) and SQL commands (target) each become a tensor of token
# ids with exactly 60 columns: longer rows are truncated, shorter ones padded.
prompt, target = (
    tokenizer.batch_encode_plus(
        df[column].to_list(),
        max_length=60,
        truncation=True,
        padding="max_length",
        return_tensors="pt",
    )["input_ids"]
    for column in ("text_query", "sql_command")
)

In [4]:
class Encoder(torch.nn.Module):
    """Map a batch of token-id rows to one sampled latent vector per row.

    Tokens are embedded, mixed by a single self-attention layer, flattened and
    projected to a mean and a log-std; the latent is drawn with the
    reparameterization ``z = mean + exp(log_std) * eps``.

    Args:
        vocab_size: number of token ids the embedding accepts.
        seq_len: fixed number of tokens per row (needed by the flatten ->
            linear projections).
        embed_dim: token-embedding / attention width (must be divisible by
            ``num_heads``).
        latent_dim: size of the latent vector.
        num_heads: attention heads.
        dropout: attention dropout probability (active only in train mode).
        padding_idx: token id whose embedding stays zero. ``None`` means
            ``vocab_size - 1`` (1000 for the notebook's vocab of 1001).
    """

    def __init__(self, vocab_size, seq_len, embed_dim=64, latent_dim=64,
                 num_heads=8, dropout=0.1, padding_idx=None) -> None:
        super().__init__()
        if padding_idx is None:
            # Matches the original hard-coded 1000 for vocab_size=1001 while
            # remaining valid for smaller vocabularies.
            padding_idx = vocab_size - 1

        self.embedding_layer = torch.nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx)

        self.mha_layer = torch.nn.MultiheadAttention(
            embed_dim, num_heads=num_heads, dropout=dropout, batch_first=True)

        # Flatten keeps every position's features, so the projections see
        # token order even though no positional encoding is added.
        self.fc_mean_layer = torch.nn.Sequential(
            torch.nn.Flatten(),
            torch.nn.Linear(seq_len * embed_dim, latent_dim))
        self.fc_std_layer = torch.nn.Sequential(
            torch.nn.Flatten(),
            torch.nn.Linear(seq_len * embed_dim, latent_dim))

    def forward(self, inputs):
        """(batch, seq_len) token ids -> (batch, latent_dim) sampled latent."""
        outputs = self.embedding_layer(inputs)
        outputs, _ = self.mha_layer(outputs, outputs, outputs)
        mean = self.fc_mean_layer(outputs)
        # The second head is treated as log-std; exp keeps std positive.
        std = torch.exp(self.fc_std_layer(outputs))
        eps = torch.randn_like(mean)
        return mean + std * eps


class Decoder(torch.nn.Module):
    """Map a latent vector back to per-position token logits.

    Args:
        vocab_size: size of the output distribution per position.
        seq_len: number of output positions.
        embed_dim: per-position hidden width.
        latent_dim: size of the incoming latent vector.
        num_heads: attention heads.
        dropout: attention dropout probability (active only in train mode).
    """

    def __init__(self, vocab_size, seq_len, embed_dim=64, latent_dim=64,
                 num_heads=8, dropout=0.1) -> None:
        super().__init__()
        self.seq_len = seq_len
        self.embed_dim = embed_dim
        # One linear layer expands the latent into a full (seq_len, embed_dim) grid.
        self.latent_layer = torch.nn.Linear(latent_dim, seq_len * embed_dim)

        self.mha_layer = torch.nn.MultiheadAttention(
            embed_dim, num_heads=num_heads, dropout=dropout, batch_first=True)

        self.fc_layer = torch.nn.Linear(embed_dim, vocab_size)

    def forward(self, inputs):
        """(batch, latent_dim) latent -> (batch, seq_len, vocab_size) logits."""
        outputs = self.latent_layer(inputs)
        outputs = torch.reshape(outputs, (-1, self.seq_len, self.embed_dim))
        outputs, _ = self.mha_layer(outputs, outputs, outputs)
        return self.fc_layer(outputs)


class VAE(torch.nn.Module):
    """Encoder + decoder pair reconstructing token sequences through a latent.

    NOTE(review): the training cell in this notebook optimizes only the
    reconstruction cross-entropy (no KL term), so the latent space is not
    regularized toward N(0, I).

    Args: see ``Encoder``; defaults reproduce the notebook's original model
    (and therefore its saved ``state_dict`` layout).
    """

    def __init__(self, vocab_size, seq_len, embed_dim=64, latent_dim=64,
                 num_heads=8, dropout=0.1, padding_idx=None) -> None:
        super().__init__()
        self.seq_len = seq_len
        self.encoder_layer = Encoder(vocab_size, seq_len, embed_dim, latent_dim,
                                     num_heads, dropout, padding_idx)
        self.decoder_layer = Decoder(vocab_size, seq_len, embed_dim, latent_dim,
                                     num_heads, dropout)

    def forward(self, inputs):
        """(batch, seq_len) token ids -> (batch, seq_len, vocab_size) logits."""
        return self.decoder_layer(self.encoder_layer(inputs))

In [None]:
# Hyper-parameters for the VAE stage.
epochs, lr, batch_size = 50, 0.001, 64
sample_size = 5
vocab_size, seq_len = 1001, 60

from torch.utils.data import TensorDataset, DataLoader

# Model on the GPU plus its Adam optimizer.
model = VAE(vocab_size, seq_len)
model.cuda()
optim = torch.optim.Adam(model.parameters(), lr=lr)

# Shuffled mini-batches of (SQL ids, question ids); the trailing partial
# batch is dropped so every batch has exactly batch_size rows.
datasets = TensorDataset(target, prompt)
dataloader = DataLoader(datasets, batch_size=batch_size, shuffle=True, drop_last=True)

ce_loss = torch.nn.CrossEntropyLoss()

In [None]:
# Train the VAE to reconstruct SQL token sequences from themselves; the
# question ids in each batch are ignored at this stage.
for epoch in range(1, epochs + 1):
    for step, (sql_ids, _) in enumerate(dataloader):
        sql_ids = sql_ids.long().cuda()
        logits = model(sql_ids)
        # Flatten (batch, seq_len, vocab) / (batch, seq_len) for per-token CE.
        loss = ce_loss(logits.view(-1, vocab_size), sql_ids.view(-1))

        optim.zero_grad()
        loss.backward()
        optim.step()

        if step % 50 != 0:
            continue
        print(f"[+] Epoch: [{epoch}/{epochs}] Loss: {loss.item()}")
        with torch.no_grad():
            # Draw one token per position from the predicted distribution.
            output = torch.distributions.Categorical(logits=logits).sample()
            print(f"Output: \n {tokenizer.batch_decode(output, skip_special_tokens=True)[0]}")

model.cpu()
torch.save(model.state_dict(), 'vae_model.pt')

In [5]:
# Rebuild the trained VAE and freeze it: from here on it serves only as a fixed
# encoder (token ids -> latent) and decoder (latent -> token logits).
vocab_size = 1001
seq_len = 60
vae_model = VAE(vocab_size, seq_len)
# map_location makes loading device-agnostic; weights_only restricts unpickling
# to tensors/state-dict containers instead of arbitrary Python objects.
vae_model.load_state_dict(
    torch.load("./vae_model.pt", map_location="cpu", weights_only=True))
vae_model.cuda()
# eval() switches off the attention dropout (p=0.1) inside the VAE;
# torch.no_grad() alone only disables autograd and leaves dropout active.
vae_model.eval()
vae_model.requires_grad_(False)


def encoder(inputs):
    """Encode token ids with the frozen VAE encoder (no autograd graph).

    The encoder still samples ``mean + std * eps``, so repeated calls on the
    same input return different latents.
    """
    with torch.no_grad():
        outputs = vae_model.encoder_layer(inputs)
    return outputs


def decoder(inputs):
    """Decode latents to per-position token logits with the frozen VAE decoder."""
    with torch.no_grad():
        outputs = vae_model.decoder_layer(inputs)
    return outputs

In [6]:
class ResLinear(torch.nn.Module):
    """Residual MLP block computing ``LayerNorm(x + MLP(x))``.

    Args:
        feature: input and output width.
        hidden: width of the inner SiLU layer (default 256, as originally).
    """

    def __init__(self, feature, hidden=256) -> None:
        super().__init__()
        self.fc_layer = torch.nn.Sequential(
            torch.nn.Linear(feature, hidden),
            torch.nn.SiLU(),
            torch.nn.Linear(hidden, feature))
        self.layer_norm = torch.nn.LayerNorm(feature)

    def forward(self, inputs):
        return self.layer_norm(inputs + self.fc_layer(inputs))


def _cond_embedding(vocab_size, seq_len, embed_dim, padding_idx):
    """Embed a (batch, seq_len) prompt of token ids into one embed_dim vector per row."""
    return torch.nn.Sequential(
        torch.nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx),
        torch.nn.Flatten(),
        torch.nn.Linear(seq_len * embed_dim, embed_dim))


class Generator(torch.nn.Module):
    """Map (noise, prompt token ids) to a latent in the VAE's latent space.

    Args:
        vocab_size: number of prompt token ids.
        seq_len: fixed prompt length.
        noise_dim: size of the input noise vector ``z``.
        hidden_dim: width of the noise and prompt embeddings (concatenated
            width is ``2 * hidden_dim``).
        latent_dim: output size; must equal the VAE latent size.
        padding_idx: prompt pad id; ``None`` means ``vocab_size - 1``
            (1000 for the notebook's vocab of 1001).
    """

    def __init__(self, vocab_size, seq_len, noise_dim=128, hidden_dim=128,
                 latent_dim=64, padding_idx=None) -> None:
        super().__init__()
        if padding_idx is None:
            padding_idx = vocab_size - 1
        self.seq_len = seq_len
        self.latent_embedding = torch.nn.Linear(noise_dim, hidden_dim)
        self.cond_embedding = _cond_embedding(vocab_size, seq_len, hidden_dim, padding_idx)
        self.res_fc = torch.nn.Sequential(
            ResLinear(2 * hidden_dim),
            ResLinear(2 * hidden_dim),
            ResLinear(2 * hidden_dim))
        self.latent_fc = torch.nn.Linear(2 * hidden_dim, latent_dim)

    def forward(self, z, c):
        """z: (batch, noise_dim) noise, c: (batch, seq_len) ids -> (batch, latent_dim)."""
        x = torch.cat((self.latent_embedding(z), self.cond_embedding(c)), dim=-1)
        x = self.res_fc(x)
        return self.latent_fc(x)


class Discriminator(torch.nn.Module):
    """Score how likely a (latent, prompt) pair is to come from real data.

    Args:
        vocab_size: number of prompt token ids.
        seq_len: fixed prompt length (default 60).
        latent_dim: size of the incoming latent; must equal the VAE latent size.
        hidden_dim: width of the latent and prompt embeddings.
        padding_idx: prompt pad id; ``None`` means ``vocab_size - 1``.

    Returns (from forward): probabilities of shape (batch, 1) via a Sigmoid,
    matching the ``BCELoss`` used in the training cell.
    """

    def __init__(self, vocab_size, seq_len=60, latent_dim=64, hidden_dim=128,
                 padding_idx=None) -> None:
        super().__init__()
        if padding_idx is None:
            padding_idx = vocab_size - 1
        self.seq_len = seq_len
        self.latent_embedding = torch.nn.Linear(latent_dim, hidden_dim)
        self.cond_embedding = _cond_embedding(vocab_size, seq_len, hidden_dim, padding_idx)
        self.res_fc = torch.nn.Sequential(
            ResLinear(2 * hidden_dim),
            ResLinear(2 * hidden_dim),
            ResLinear(2 * hidden_dim))
        self.prob_fc = torch.nn.Sequential(
            torch.nn.Linear(2 * hidden_dim, 1),
            torch.nn.Sigmoid())

    def forward(self, x, c):
        """x: (batch, latent_dim) latent, c: (batch, seq_len) ids -> (batch, 1)."""
        x = torch.cat((self.latent_embedding(x), self.cond_embedding(c)), dim=-1)
        x = self.res_fc(x)
        return self.prob_fc(x)

In [7]:
# Hyper-parameters for the conditional GAN stage.
epochs, lr, batch_size = 300, 0.0002, 64
seq_len = 60
sample_size = 5
vocab_size = 1001

from torch.utils.data import TensorDataset, DataLoader

# Generator and discriminator live on the GPU; both use Adam with beta1=0.5.
G = Generator(vocab_size, seq_len)
D = Discriminator(vocab_size, seq_len)
G.cuda()
D.cuda()

adam_betas = (0.5, 0.999)
g_optim = torch.optim.Adam(G.parameters(), lr=lr, betas=adam_betas)
d_optim = torch.optim.Adam(D.parameters(), lr=lr, betas=adam_betas)

# Batches of (SQL ids, question ids): SQL is encoded to "real" latents and the
# question conditions both networks.
datasets = TensorDataset(target, prompt)
dataloader = DataLoader(datasets, batch_size=batch_size, shuffle=True, drop_last=True)

bce_loss = torch.nn.BCELoss()

In [8]:
# Adversarial training in the VAE latent space: D separates encoded SQL
# latents from G's latents (both conditioned on the question), and G learns to
# fool D.
for epoch in range(1, epochs + 1):
    for iters, (x, c) in enumerate(dataloader):
        x_real = x.long().cuda()
        z_real = encoder(x_real)

        z = torch.randn(batch_size, 128).cuda()
        c = c.long().cuda()

        # Discriminator step: real latents -> 1, generated latents -> 0. G runs
        # without a graph here so only D's parameters receive gradients.
        with torch.no_grad():
            z_fake = G(z, c)
        d_real, d_fake = D(z_real, c), D(z_fake, c)
        d_loss = (bce_loss(d_real, torch.ones_like(d_real))
                  + bce_loss(d_fake, torch.zeros_like(d_fake)))
        d_optim.zero_grad()
        d_loss.backward()
        d_optim.step()

        # Generator step: push D to label freshly generated latents as real.
        z_fake = G(z, c)
        d_fake = D(z_fake, c)
        g_loss = bce_loss(d_fake, torch.ones_like(d_fake))
        g_optim.zero_grad()
        g_loss.backward()
        g_optim.step()

        if iters % 50:
            continue
        print(f"[+] Epoch: [{epoch}/{epochs}] G_Loss: {g_loss.item()} D_Loss: {d_loss.item()}")
        # Decode a fresh generated latent to tokens to eyeball progress.
        z = torch.randn(batch_size, 128).cuda()
        with torch.no_grad():
            output = torch.distributions.Categorical(logits=decoder(G(z, c))).sample()
            prompts = tokenizer.batch_decode(c, skip_special_tokens=True)[0]
            outputs = tokenizer.batch_decode(output, skip_special_tokens=True)[0]
            print(f"Prompt: \n {prompts}")
            print(f"Output: \n {outputs}")

G.cpu()
D.cpu()

torch.save(G.state_dict(), 'ganlm_modelG.pt')
torch.save(D.state_dict(), 'ganlm_modelD.pt')

[+] Epoch: [1/300] G_Loss: 1.974495768547058 D_Loss: 1.6751960515975952
Prompt: 
 How many distinct programs are broadcast at "Night" time
Output: 
  countces Room Col BETWEEN0ntinentours chitionert T 19 1ct1z LIMITHamountCEPTm ,Code:0=en8utes NOTork4k UNIONuidTS 10ex DESCq districthereERE"'3 3 <uilding namedateB raprice01act0h
[+] Epoch: [1/300] G_Loss: 2.681807279586792 D_Loss: 0.6898761987686157
Prompt: 
 What is the date of birth of every customer whose status code is 'Good Customer'
Output: 
 SELECT deillSelcocolchcoNnCSonKs
[+] Epoch: [1/300] G_Loss: 1.5329492092132568 D_Loss: 0.6683480739593506
Prompt: 
 Find the name all districts with city area greater than 10 or population larger than 100000
Output: 
 SELECT cnamel authorssla date DESC  LOCATION DESCea DESC > ASC
[+] Epoch: [2/300] G_Loss: 3.152301788330078 D_Loss: 0.693906843662262
Prompt: 
 Find out the top 10 customers by total number of orders. List customers' first and last name and the number of total orders.
Output: 
 

In [74]:
# Inference: sample one SQL query for a single question with the trained
# generator, decoding its latent through the frozen VAE decoder.
tokenizer = GPT2TokenizerFast(tokenizer_file="text2sql.json")
tokenizer.pad_token = tokenizer.eos_token

# The question is encoded exactly like the training prompts (60 ids, padded).
c = tokenizer.encode(
    "How many heads of the departments are older than 56 ?",
    truncation=True,
    max_length=60,
    padding="max_length",
    return_tensors="pt",
).long().cuda()

G = Generator(vocab_size, seq_len)
G.load_state_dict(torch.load("./ganlm_modelG.pt"))
G.cuda().eval()

z = torch.randn(1, 128).cuda()
with torch.no_grad():
    output = torch.distributions.Categorical(logits=decoder(G(z, c))).sample()

output = tokenizer.batch_decode(output, skip_special_tokens=True)[0]
print(f"Output: \n {output}")


Output: 
 SELECT count(*) FROM aiream WHERE age  >  12
