# Models for Translating Flight Searches to SQL Queries, Matching the [ATIS Dataset Schema](https://github.com/jkkummerfeld/text2sql-data/blob/master/data/atis-schema.csv)



1.   Seq2Seq Model with Self-Attention
2.   Seq2Seq Model with Self-Attention and Cross Attention
3.   Fine-tuned BART Transformer




In [44]:
!pip install -r requirements.txt
import os



# Setup

In [8]:
!pip install datasets

import copy
import datetime
import math
import os
import re
import sys
import warnings

import wget
import nltk
import sqlite3
import csv
import torch
import torch.nn as nn
import datasets

from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers import Regex
from tokenizers.pre_tokenizers import WhitespaceSplit, Split
from tokenizers.processors import TemplateProcessing
from tokenizers import normalizers
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from transformers import PreTrainedTokenizerFast

from cryptography.fernet import Fernet
from func_timeout import func_set_timeout
from torch.nn.utils.rnn import pack_padded_sequence as pack
from torch.nn.utils.rnn import pad_packed_sequence as unpack
from tqdm import tqdm
from transformers import BartTokenizer, BartForConditionalGeneration

Collecting datasets
  Downloading datasets-2.18.0-py3-none-any.whl (510 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m510.5/510.5 kB[0m [31m3.1 MB/s[0m eta [36m0:00:00[0m
Collecting xxhash (from datasets)
  Downloading xxhash-3.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (194 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m194.1/194.1 kB[0m [31m16.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting multiprocess (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl (134 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m14.6 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: xxhash, multiprocess, datasets
Successfully installed datasets-2.18.0 multiprocess-0.70.16 xxhash-3.4.1


In [9]:
# Set random seeds
seed = 1234
torch.manual_seed(seed)
# Set timeout for executing SQL
TIMEOUT = 3 # seconds

# GPU check: Set runtime type to use GPU where available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print (device)

cuda


In [10]:
## Download needed scripts and data
def download_if_needed(source, dest, filename):
    os.path.exists(f"./{dest}{filename}") or wget.download(source + filename, out=dest)

os.makedirs('data', exist_ok=True)
os.makedirs('scripts', exist_ok=True)
source_url = "https://raw.githubusercontent.com/nlp-course/data/master"

# Download scripts and ATIS database
download_if_needed(source_url, "scripts/", "/scripts/trees/transform.py")
download_if_needed(source_url, "data/", "/ATIS/atis_sqlite.db")

In [12]:
# Acquire the datasets - training, development, and test splits of the
# ATIS queries and corresponding SQL queries
for path in ["/ATIS/test_flightid.nl",  "/ATIS/test_flightid.sql",
             "/ATIS/dev_flightid.nl",   "/ATIS/dev_flightid.sql",
             "/ATIS/train_flightid.nl", "/ATIS/train_flightid.sql"]:
    download_if_needed(source_url, "data/", path)

In [13]:
# Process data into CSV files
for split in ['train', 'dev', 'test']:
    src_in_file = f'data/{split}_flightid.nl'
    tgt_in_file = f'data/{split}_flightid.sql'
    out_file = f'data/{split}.csv'

    with open(src_in_file, 'r') as f_src_in, open(tgt_in_file, 'r') as f_tgt_in:
        with open(out_file, 'w') as f_out:
            src, tgt= [], []
            writer = csv.writer(f_out)
            writer.writerow(('src','tgt'))
            for src_line, tgt_line in zip(f_src_in, f_tgt_in):
                writer.writerow((src_line.strip(), tgt_line.strip()))

## Corpus preprocessing

In [30]:
## NLTK Tokenizer
tokenizer_pattern = '\d+|st\.|[\w-]+|\$[\d\.]+|\S+'
nltk_tokenizer = nltk.tokenize.RegexpTokenizer(tokenizer_pattern)
def tokenize_nltk(string):
  return nltk_tokenizer.tokenize(string.lower())

In [16]:
dataset = load_dataset('csv', data_files={'train':f'data/train.csv', \
                                          'val': f'data/dev.csv', \
                                          'test': f'data/test.csv'})
dataset

Generating train split: 0 examples [00:00, ? examples/s]

Generating val split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['src', 'tgt'],
        num_rows: 3651
    })
    val: Dataset({
        features: ['src', 'tgt'],
        num_rows: 398
    })
    test: Dataset({
        features: ['src', 'tgt'],
        num_rows: 332
    })
})

In [17]:
train_data = dataset['train']
val_data = dataset['val']
test_data = dataset['test']

In [18]:
MIN_FREQ = 3
unk_token = '[UNK]'
pad_token = '[PAD]'
bos_token = '<bos>'
eos_token = '<eos>'

## source tokenizer
src_tokenizer = Tokenizer(WordLevel(unk_token=unk_token))
src_tokenizer.normalizer = normalizers.Lowercase()
src_tokenizer.pre_tokenizer = Split(Regex(tokenizer_pattern),
                                    behavior='removed',
                                    invert=True)

src_trainer = WordLevelTrainer(min_frequency=MIN_FREQ,
                               special_tokens=[pad_token, unk_token])
src_tokenizer.train_from_iterator(train_data['src'],
                                  trainer=src_trainer)

## target tokenizer
tgt_tokenizer = Tokenizer(WordLevel(unk_token=unk_token))
tgt_tokenizer.pre_tokenizer = WhitespaceSplit()

tgt_trainer = WordLevelTrainer(min_frequency=MIN_FREQ,
                               special_tokens=[pad_token, unk_token,
                                               bos_token, eos_token])
tgt_tokenizer.train_from_iterator(train_data['tgt'],
                                  trainer=tgt_trainer)
tgt_tokenizer.post_processor = \
    TemplateProcessing(single=f"{bos_token} $A {eos_token}",
                       special_tokens=[(bos_token,
                                        tgt_tokenizer.token_to_id(bos_token)),
                                       (eos_token,
                                        tgt_tokenizer.token_to_id(eos_token))])

In [19]:
hf_src_tokenizer = PreTrainedTokenizerFast(tokenizer_object=src_tokenizer,
                                           pad_token=pad_token)
hf_tgt_tokenizer = PreTrainedTokenizerFast(tokenizer_object=tgt_tokenizer,
                                           pad_token=pad_token)

def encode(example):
    example['src_ids'] = hf_src_tokenizer(example['src']).input_ids
    example['tgt_ids'] = hf_tgt_tokenizer(example['tgt']).input_ids
    return example

train_data = train_data.map(encode)
val_data = val_data.map(encode)
test_data = test_data.map(encode)

Map:   0%|          | 0/3651 [00:00<?, ? examples/s]

Map:   0%|          | 0/398 [00:00<?, ? examples/s]

Map:   0%|          | 0/332 [00:00<?, ? examples/s]

In [20]:
# Compute size of vocabulary
src_vocab = src_tokenizer.get_vocab()
tgt_vocab = tgt_tokenizer.get_vocab()

print(f"Size of English vocab: {len(src_vocab)}")
print(f"Size of SQL vocab: {len(tgt_vocab)}")
print(f"Index for src padding: {src_vocab[pad_token]}")
print(f"Index for tgt padding: {tgt_vocab[pad_token]}")
print(f"Index for start of sequence token: {tgt_vocab[bos_token]}")
print(f"Index for end of sequence token: {tgt_vocab[eos_token]}")

Size of English vocab: 421
Size of SQL vocab: 392
Index for src padding: 0
Index for tgt padding: 0
Index for start of sequence token: 2
Index for end of sequence token: 3


Batching to facilitate GPU processing:

In [21]:
BATCH_SIZE = 16
TEST_BATCH_SIZE = 1

# Defines how to batch a list of examples together
def collate_fn(examples):
    batch = {}
    bsz = len(examples)
    src_ids, tgt_ids = [], []
    for example in examples:
        src_ids.append(example['src_ids'])
        tgt_ids.append(example['tgt_ids'])

    src_len = torch.LongTensor([len(word_ids) for word_ids in src_ids]).to(device)
    src_max_length = max(src_len)
    tgt_max_length = max([len(word_ids) for word_ids in tgt_ids])

    src_batch = torch.zeros(bsz, src_max_length).long().fill_(src_vocab[pad_token]).to(device)
    tgt_batch = torch.zeros(bsz, tgt_max_length).long().fill_(tgt_vocab[pad_token]).to(device)
    for b in range(bsz):
        src_batch[b][:len(src_ids[b])] = torch.LongTensor(src_ids[b]).to(device)
        tgt_batch[b][:len(tgt_ids[b])] = torch.LongTensor(tgt_ids[b]).to(device)

    batch['src_lengths'] = src_len
    batch['src_ids'] = src_batch
    batch['tgt_ids'] = tgt_batch
    return batch

train_iter = torch.utils.data.DataLoader(train_data,
                                         batch_size=BATCH_SIZE,
                                         shuffle=True,
                                         collate_fn=collate_fn)
val_iter = torch.utils.data.DataLoader(val_data,
                                       batch_size=BATCH_SIZE,
                                       shuffle=False,
                                       collate_fn=collate_fn)
test_iter = torch.utils.data.DataLoader(test_data,
                                        batch_size=TEST_BATCH_SIZE,
                                        shuffle=False,
                                        collate_fn=collate_fn)

Examining a single batch:



In [22]:
batch = next(iter(train_iter))
src_ids = batch['src_ids']
src_example = src_ids[2]
print (f"Size of text batch: {src_ids.size()}")
print (f"Third sentence in batch: {src_example}")
print (f"Length of the third sentence in batch: {len(src_example)}")
print (f"Converted back to string: {hf_src_tokenizer.decode(src_example)}")

tgt_ids = batch['tgt_ids']
tgt_example = tgt_ids[2]
print (f"Size of sql batch: {tgt_ids.size()}")
print (f"Third sql in batch: {tgt_example}")
print (f"Converted back to string: {hf_tgt_tokenizer.decode(tgt_example)}")

Size of text batch: torch.Size([16, 30])
Third sentence in batch: tensor([ 9,  7,  4,  3, 13, 16,  2, 11,  6, 69,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0], device='cuda:0')
Length of the third sentence in batch: 30
Converted back to string: show me flights from san francisco to boston on thursday [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]
Size of sql batch: torch.Size([16, 153])
Third sql in batch: tensor([  2,  14,  31,  11,  13,  12,  16,   6,   7,  22,   6,   8,  23,   6,
          7,  29,   6,   8,  30,   6,  33,  40,   6,  38,  46,  15,  21,   4,
         18,   5,  19,   4,  17,   5,  20,   4,  54,  56,   5,   9,  24,   4,
         25,   5,  26,   4,  27,   5,  28,   4,  52,   5,  34,   4,  36,   5,
         37,   4,  41,   5,  44,   4,  35,   5,  43,   4, 103,   5,  42,   4,
        126,  10,   3,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
      

Sample question and corresponding SQL:

In [23]:
for _, example in zip(range(1), train_data):
  train_text_1 = example['src'] # detokenized question
  train_sql_1 = example['tgt']  # detokenized sql
  print (f"Question: {train_text_1}\n")
  print (f"SQL: {train_sql_1}")

Question: list all the flights that arrive at general mitchell international from various cities

SQL: SELECT DISTINCT flight_1.flight_id FROM flight flight_1 , airport airport_1 , airport_service airport_service_1 , city city_1 WHERE flight_1.to_airport = airport_1.airport_code AND airport_1.airport_code = 'MKE' AND flight_1.from_airport = airport_service_1.airport_code AND airport_service_1.city_code = city_1.city_code AND 1 = 1


## Establishing a SQL database for evaluating ATIS queries, using the [Python `sqlite3` module](https://docs.python.org/3.8/library/sqlite3.html).

In [24]:
@func_set_timeout(TIMEOUT)
def execute_sql(sql):
  conn = sqlite3.connect('data/atis_sqlite.db')
  c = conn.cursor()
  c.execute(sql)
  results = list(c.fetchall())
  c.close()
  conn.close()
  return results

#### Defining a function `verify` to compare the results from the generated SQL to the ground truth SQL:

In [None]:
def verify(predicted_sql, gold_sql, silent=True):
  """
  Compare the correctness of the generated SQL by executing on the
  ATIS database and comparing the returned results.
  """
  # Execute predicted SQL
  try:
    predicted_result = execute_sql(predicted_sql)
  except BaseException as e:
    if not silent:
      print(f"predicted sql exec failed: {e}")
    return False
  if not silent:
    print("Predicted DB result:\n\n", predicted_result[:10], "\n")

  # Execute gold SQL
  try:
    gold_result = execute_sql(gold_sql)
  except BaseException as e:
    if not silent:
      print(f"gold sql exec failed: {e}")
    return False
  if not silent:
    print("Gold DB result:\n\n", gold_result[:10], "\n")

  # Verify correctness
  if gold_result == predicted_result:
    return True

In [31]:
# Example 1
example_1 = 'flights from phoenix to milwaukee'
gold_sql_1 = """
  SELECT DISTINCT flight_1.flight_id
  FROM flight flight_1 ,
       airport_service airport_service_1 ,
       city city_1 ,
       airport_service airport_service_2 ,
       city city_2
  WHERE flight_1.from_airport = airport_service_1.airport_code
        AND airport_service_1.city_code = city_1.city_code
        AND city_1.city_name = 'PHOENIX'
        AND flight_1.to_airport = airport_service_2.airport_code
        AND airport_service_2.city_code = city_2.city_code
        AND city_2.city_name = 'MILWAUKEE'
  """

In [32]:
# Example 2
example_2 = 'i would like a united flight'
gold_sql_2 = """
  SELECT DISTINCT flight_1.flight_id
  FROM flight flight_1
  WHERE flight_1.airline_code = 'UA'
  """

In [33]:
# Example 3
example_3 = 'i would like a flight between boston and dallas'
gold_sql_3 = """
  SELECT DISTINCT flight_1.flight_id
  FROM flight flight_1 ,
       airport_service airport_service_1 ,
       city city_1 ,
       airport_service airport_service_2 ,
       city city_2
  WHERE flight_1.from_airport = airport_service_1.airport_code
        AND airport_service_1.city_code = city_1.city_code
        AND city_1.city_name = 'BOSTON'
        AND flight_1.to_airport = airport_service_2.airport_code
        AND airport_service_2.city_code = city_2.city_code
        AND city_2.city_name = 'DALLAS'
  """

In [34]:
# Example 4
example_4 = 'show me the united flights from denver to baltimore'
gold_sql_4 = """
  SELECT DISTINCT flight_1.flight_id
  FROM flight flight_1 ,
       airport_service airport_service_1 ,
       city city_1 ,
       airport_service airport_service_2 ,
       city city_2
  WHERE flight_1.airline_code = 'UA'
        AND ( flight_1.from_airport = airport_service_1.airport_code
              AND airport_service_1.city_code = city_1.city_code
              AND city_1.city_name = 'DENVER'
              AND flight_1.to_airport = airport_service_2.airport_code
              AND airport_service_2.city_code = city_2.city_code
              AND city_2.city_name = 'BALTIMORE' )

  """


In [35]:
# Example 5
example_5 = 'show flights from cleveland to miami that arrive before 4pm'
gold_sql_5 = """
  SELECT DISTINCT flight_1.flight_id
  FROM flight flight_1 ,
       airport_service airport_service_1 ,
       city city_1 ,
       airport_service airport_service_2 ,
       city city_2
  WHERE flight_1.from_airport = airport_service_1.airport_code
        AND airport_service_1.city_code = city_1.city_code
        AND city_1.city_name = 'CLEVELAND'
        AND ( flight_1.to_airport = airport_service_2.airport_code
              AND airport_service_2.city_code = city_2.city_code
              AND city_2.city_name = 'MIAMI'
              AND flight_1.arrival_time < 1600 )
  """


In [39]:
# Example 6
example_6 = 'okay how about a flight on sunday from tampa to charlotte'
gold_sql_6 = """
  SELECT DISTINCT flight_1.flight_id
  FROM flight flight_1 ,
       airport_service airport_service_1 ,
       city city_1 ,
       airport_service airport_service_2 ,
       city city_2 ,
       days days_1 ,
       date_day date_day_1
  WHERE flight_1.from_airport = airport_service_1.airport_code
        AND airport_service_1.city_code = city_1.city_code
        AND city_1.city_name = 'TAMPA'
        AND ( flight_1.to_airport = airport_service_2.airport_code
              AND airport_service_2.city_code = city_2.city_code
              AND city_2.city_name = 'CHARLOTTE'
              AND flight_1.flight_days = days_1.days_code
              AND days_1.day_name = date_day_1.day_name
              AND date_day_1.year = 1991
              AND date_day_1.month_number = 8
              AND date_day_1.day_number = 27 )
  """
# The gold answer above used the exact date, as opposed to the
# following approach:
gold_sql_6b = """
  SELECT DISTINCT flight.flight_id
  FROM flight
  WHERE ((((1
            AND flight.flight_days IN (SELECT days.days_code
                                       FROM days
                                       WHERE days.day_name = 'SUNDAY')
            )
           AND flight.from_airport IN (SELECT airport_service.airport_code
                                       FROM airport_service
                                       WHERE airport_service.city_code IN (SELECT city.city_code
                                                                           FROM city
                                                                           WHERE city.city_name = "TAMPA")))
          AND flight.to_airport IN (SELECT airport_service.airport_code
                                    FROM airport_service
                                    WHERE airport_service.city_code IN (SELECT city.city_code
                                                                        FROM city
                                                                        WHERE city.city_name = "CHARLOTTE"))))
  """


In [41]:
# Example 7
example_7 = 'list all flights going from boston to atlanta that leaves before 7 am on thursday'
gold_sql_7 = """
  SELECT DISTINCT flight_1.flight_id
  FROM flight flight_1 ,
       airport_service airport_service_1 ,
       city city_1 ,
       airport_service airport_service_2 ,
       city city_2 ,
       days days_1 ,
       date_day date_day_1
  WHERE flight_1.from_airport = airport_service_1.airport_code
        AND airport_service_1.city_code = city_1.city_code
        AND city_1.city_name = 'BOSTON'
        AND ( flight_1.to_airport = airport_service_2.airport_code
              AND airport_service_2.city_code = city_2.city_code
              AND city_2.city_name = 'ATLANTA'
              AND ( flight_1.flight_days = days_1.days_code
                    AND days_1.day_name = date_day_1.day_name
                    AND date_day_1.year = 1991
                    AND date_day_1.month_number = 5
                    AND date_day_1.day_number = 24
                    AND flight_1.departure_time < 700 ) )
  """

# Again, the gold answer above used the exact date, as opposed to the
# following approach:
gold_sql_7b = """
  SELECT DISTINCT flight.flight_id
  FROM flight
  WHERE ((1
          AND ((((1
                  AND flight.from_airport IN (SELECT airport_service.airport_code
                                              FROM airport_service
                                              WHERE airport_service.city_code IN (SELECT city.city_code
                                                                                  FROM city
                                                                                  WHERE city.city_name = "BOSTON")))
                 AND flight.to_airport IN (SELECT airport_service.airport_code
                                           FROM airport_service
                                           WHERE airport_service.city_code IN (SELECT city.city_code
                                                                               FROM city
                                                                               WHERE city.city_name = "ATLANTA")))
                AND flight.departure_time <= 0700)
               AND flight.flight_days IN (SELECT days.days_code
                                          FROM days
                                          WHERE days.day_name = 'THURSDAY'))))
  """


In [42]:
# Example 8
example_8 = 'list the flights from dallas to san francisco on american airlines'
gold_sql_8 = """
  SELECT DISTINCT flight_1.flight_id
  FROM flight flight_1 ,
       airport_service airport_service_1 ,
       city city_1 ,
       airport_service airport_service_2 ,
       city city_2
  WHERE flight_1.airline_code = 'AA'
        AND ( flight_1.from_airport = airport_service_1.airport_code
              AND airport_service_1.city_code = city_1.city_code
              AND city_1.city_name = 'DALLAS'
              AND flight_1.to_airport = airport_service_2.airport_code
              AND airport_service_2.city_code = city_2.city_code
              AND city_2.city_name = 'SAN FRANCISCO' )
  """


### Systematic evaluation on a test set:

In [43]:
def evaluate(predictor, dataset, num_examples=0, silent=True):
  """Evaluate accuracy of `predictor` by executing predictions on a
  SQL database and comparing returned results against those of gold queries.
  """
  # Prepare to count results
  if num_examples <= 0:
    num_examples = len(dataset)
  example_count = 0
  predicted_count = 0
  correct = 0
  incorrect = 0

  # Process the examples from the dataset
  for _, example in tqdm(zip(range(num_examples), dataset)):
    example_count += 1
    # obtain query SQL
    predicted_sql = predictor(example['src'])
    if predicted_sql == None:
      continue
    predicted_count += 1
    # obtain gold SQL
    gold_sql = example['tgt']

    # check that they're compatible
    if verify(predicted_sql, gold_sql):
      correct += 1
    else:
      incorrect += 1

  # Compute and return precision, recall, F1
  precision = correct / predicted_count if predicted_count > 0 else 0
  recall = correct / example_count
  f1 = (2 * precision * recall) / (precision + recall) if precision + recall > 0 else 0
  return precision, recall, f1

### Seq2seq model (with self-attention)

In [28]:
class Beam():
  """
  Helper class for storing a hypothesis, its score and its decoder hidden state.
  """
  def __init__(self, decoder_state, tokens, score):
    self.decoder_state = decoder_state
    self.tokens = tokens
    self.score = score

In [29]:
def attention(batched_Q, batched_K, batched_V, mask=None):
    """
    Performs the attention operation and returns the attention matrix
    `batched_A` and the context matrix `batched_C` using queries
    `batched_Q`, keys `batched_K`, and values `batched_V`.

    Arguments:
        batched_Q: (bsz, q_len, D)
        batched_K: (bsz, k_len, D)
        batched_V: (bsz, k_len, D)
        mask: (bsz, q_len, k_len). An optional boolean mask *disallowing*
              attentions where the mask value is *`False`*.
    Returns:
        batched_A: the normalized attention scores (bsz, q_len, k_len)
        batched_C: a tensor of size (bsz, q_len, D).
    """
    # Initialize and extract dimensions and confirm input dimensions
    D = batched_Q.size(-1)
    bsz = batched_Q.size(0)
    q_len = batched_Q.size(1)
    k_len = batched_K.size(1)
    assert batched_K.size(-1) == D and batched_V.size(-1) == D
    assert batched_K.size(0) == bsz and batched_V.size(0) == bsz
    assert batched_V.size(1) == k_len
    if mask is not None:
      assert mask.size() == torch.Size([bsz, q_len, k_len])

    # Compute attention scores
    mult = torch.bmm(batched_Q, batched_K.transpose(1, 2))
    if mask is not None:
      mult = mult.masked_fill((mask == False), float('-inf'))
    batched_A = torch.softmax(mult, dim=-1)

    # Compute context matrix
    batched_C = torch.bmm(batched_A,batched_V)

    # Ensure attention scores correctly sum to 1
    assert torch.all(torch.isclose(batched_A.sum(-1),
                                  torch.ones(bsz, q_len).to(device)))
    return batched_A, batched_C

class AttnEncoderDecoder(nn.Module):
  def __init__(self, hf_src_tokenizer, hf_tgt_tokenizer, hidden_size=64, layers=3):
    """
    Initializer. Creates network modules and loss function.
    Arguments:
        hf_src_tokenizer: hf src tokenizer
        hf_tgt_tokenizer: hf tgt tokenizer
        hidden_size: hidden layer size of both encoder and decoder
        layers: number of layers of both encoder and decoder
    """
    super().__init__()
    self.hf_src_tokenizer = hf_src_tokenizer
    self.hf_tgt_tokenizer = hf_tgt_tokenizer

    # Keep the vocabulary sizes available
    self.V_src = len(self.hf_src_tokenizer)
    self.V_tgt = len(self.hf_tgt_tokenizer)

    # Get special word ids
    self.padding_id_src = self.hf_src_tokenizer.pad_token_id
    self.padding_id_tgt = self.hf_tgt_tokenizer.pad_token_id
    self.bos_id = self.hf_tgt_tokenizer.get_vocab()[bos_token]
    self.eos_id = self.hf_tgt_tokenizer.get_vocab()[eos_token]

    # Keep hyper-parameters available
    self.embedding_size = hidden_size
    self.hidden_size = hidden_size
    self.layers = layers

    # Create essential modules
    self.word_embeddings_src = nn.Embedding(self.V_src, self.embedding_size)
    self.word_embeddings_tgt = nn.Embedding(self.V_tgt, self.embedding_size)

    # RNN cells
    self.encoder_rnn = nn.LSTM(
      input_size    = self.embedding_size,
      hidden_size   = hidden_size // 2, # to match decoder hidden size
      num_layers    = layers,
      batch_first=True,
      bidirectional = True              # bidirectional encoder
    )
    self.decoder_rnn = nn.LSTM(
      input_size    = self.embedding_size,
      hidden_size   = hidden_size,
      num_layers    = layers,
      batch_first=True,
      bidirectional = False             # unidirectional decoder
    )

    # Final projection layer
    self.hidden2output = nn.Linear(2*hidden_size, self.V_tgt) # project the concatenation to logits

    # Create loss function
    self.loss_function = nn.CrossEntropyLoss(reduction='sum',
                                             ignore_index=self.padding_id_tgt)

  def forward_encoder(self, src, src_lengths):
    """
    Encodes source words `src`.
    Arguments:
        src: src batch of size (bsz, max_src_len)
        src_lengths: src lengths of size (bsz)
    Returns:
        memory_bank: a tensor of size (bsz, src_len, hidden_size)
        (final_state, context): `final_state` is a tuple (h, c) where h/c is of size
                                (layers, bsz, hidden_size), and `context` is `None`.
    """
    # Embed and pack source tokens
    packed_embeddings = pack(
        self.word_embeddings_src(src).to(device),
        src_lengths.tolist(),
        batch_first=True,
        enforce_sorted=False
    )

    # Pass packed embeddings through encoder RNN
    encoder_outs, (h, c) = self.encoder_rnn(packed_embeddings)

    # Reshape and transpose hidden states to match decoder
    hsplit = h.reshape(self.layers, 2, len(src_lengths), self.hidden_size//2)
    hsplit = hsplit.transpose(1,2)
    finalh = hsplit.reshape(self.layers, len(src_lengths), -1)

    # Reshape and transpose cell states
    csplit = c.reshape(self.layers, 2, len(src_lengths), self.hidden_size//2)
    csplit = csplit.transpose(1,2)
    finalc = csplit.reshape(self.layers, len(src_lengths), -1)

    # Unpack encoder outputs into the memory bank
    memory_bank = unpack(encoder_outs, batch_first=True)[0]

    # Assign final state tuple and context, knowing context is set to None for the encoder
    final_state = (finalh, finalc)
    context = None
    return memory_bank, (final_state, context)

  def forward_decoder(self, encoder_final_state, tgt_in, memory_bank, src_mask):
    """
    Decodes based on encoder final state, memory bank, src_mask, and ground truth
    target words.
    Arguments:
        encoder_final_state: (final_state, None) where final_state is the encoder
                             final state used to initialize decoder. None is the
                             initial context (there's no previous context at the
                             first step).
        tgt_in: a tensor of size (bsz, tgt_len)
        memory_bank: a tensor of size (bsz, src_len, hidden_size), encoder outputs
                     at every position
        src_mask: a tensor of size (bsz, src_len): a boolean tensor, `False` where
                  src is padding (disallow decoder to attend to those places).
    Returns:
        Logits of size (bsz, tgt_len, V_tgt) (before the softmax operation)
    """
    max_tgt_length = tgt_in.size(1)

    # Initialize decoder state
    decoder_states = encoder_final_state

    # Iterate over each token and iteratively decode at each time step
    all_logits = []
    for i in range(max_tgt_length):
      logits, decoder_states, attn = \
        self.forward_decoder_incrementally(decoder_states,
                                           tgt_in[:, i],
                                           memory_bank,
                                           src_mask,
                                           normalize=False)
      all_logits.append(logits)             # list of bsz, vocab_tgt
    # Stack logits to generate 3-dimensional tensor
    all_logits = torch.stack(all_logits, 1) # bsz, tgt_len, vocab_tgt
    return all_logits

  def forward(self, src, src_lengths, tgt_in):
    """
    Performs forward computation, returns logits.
    Arguments:
        src: src batch of size (bsz, max_src_len)
        src_lengths: src lengths of size (bsz)
        tgt_in:  a tensor of size (bsz, tgt_len)
    """
    src = src.to(device)
    src_lengths = src_lengths.to(device)
    tgt_in = tgt_in.to(device)

    src_mask = src.ne(self.padding_id_src) # bsz, max_src_len
    # Forward encoder
    memory_bank, encoder_final_state = self.forward_encoder(src, src_lengths)
    # Forward decoder
    logits = self.forward_decoder(encoder_final_state, tgt_in, memory_bank, src_mask)
    return logits

  def forward_decoder_incrementally(self, prev_decoder_states, tgt_in_onestep,
                                    memory_bank, src_mask,
                                    normalize=True):
    """
    Forward the decoder for a single step with token `tgt_in_onestep`.
    This function is used both in `forward_decoder` and in beam search.
    Arguments:
        prev_decoder_states: a tuple (prev_decoder_state, prev_context). `prev_context`
                             is `None` for the first step
        tgt_in_onestep: a tensor of size (bsz), tokens at one step
        memory_bank: a tensor of size (bsz, src_len, hidden_size), encoder outputs
                     at every position
        src_mask: a tensor of size (bsz, src_len): a boolean tensor, `False` where
                  src is padding (we disallow decoder to attend to those places).
        normalize: use log_softmax to normalize or not.

    Returns:
        logits: log probabilities for `tgt_in_token` of size (bsz, V_tgt)
        decoder_states: (`decoder_state`, `context`) which is used for the
                        next incremental update
        attn: normalized attention scores at this step (bsz, src_len)
    """
    # Extract previous decoder state and context
    prev_decoder_state, prev_context = prev_decoder_states

    # Obtain previous decoder state's batch size
    bsz = prev_decoder_state[0].size(1)

    # Reshape and embed current input
    tgt_in_onestep = tgt_in_onestep.reshape(bsz, 1)
    tgt_in_onestep = self.word_embeddings_tgt(tgt_in_onestep)

    # Add previous context to embeddings if one exists
    if prev_context is not None:
        tgt_in_onestep = tgt_in_onestep + prev_context

    # Pass embedded input into the decoder rnn and extract output
    decoder_out, decoder_state = self.decoder_rnn(tgt_in_onestep, prev_decoder_state)
    src_mask = src_mask.unsqueeze(1)

    # Obtain attention and context vector from current decoder output and memory bank
    attn, context = attention(decoder_out, memory_bank, memory_bank, src_mask)

    # Append decoder output and context, and cast to size of output vocabulary
    logits = self.hidden2output(torch.cat((decoder_out, context), dim=-1))

    # Pack decoder state and context
    decoder_states = (decoder_state, context)

    # Apply softmax normalization if specified
    if normalize:
      logits = torch.log_softmax(logits, dim=-1)
    return logits, decoder_states, attn

  def evaluate_ppl(self, iterator):
    """Returns the model's perplexity on a given dataset `iterator`."""
    # Switch to eval mode
    self.eval()
    total_loss = 0
    total_words = 0
    for batch in iterator:
      # Input and target
      src = batch['src_ids']              # bsz, max_src_len
      src_lengths = batch['src_lengths']  # bsz
      tgt_in = batch['tgt_ids'][:, :-1] # Remove <eos> for decode input (y_0=<bos>, y_1, y_2)
      tgt_out = batch['tgt_ids'][:, 1:] # Remove <bos> as target        (y_1, y_2, y_3=<eos>)
      # Forward to get logits
      logits = self.forward(src, src_lengths, tgt_in) # bsz, tgt_len, V_tgt
      # Compute cross entropy loss
      loss = self.loss_function(logits.reshape(-1, self.V_tgt), tgt_out.reshape(-1))
      total_loss += loss.item()
      total_words += tgt_out.ne(self.padding_id_tgt).float().sum().item()
    return math.exp(total_loss/total_words)

  def train_all(self, train_iter, val_iter, epochs=10, learning_rate=0.001):
    """Train the model."""
    # Switch the module to training mode
    self.train()
    # Use Adam to optimize the parameters
    optim = torch.optim.Adam(self.parameters(), lr=learning_rate)
    best_validation_ppl = float('inf')
    best_model = None
    # Run the optimization for multiple epochs
    for epoch in range(epochs):
      total_words = 0
      total_loss = 0.0
      for batch in tqdm(train_iter):
        # Zero the parameter gradients
        self.zero_grad()
        # Input and target
        tgt = batch['tgt_ids']              # bsz, max_tgt_len
        src = batch['src_ids']              # bsz, max_src_len
        src_lengths = batch['src_lengths']  # bsz
        tgt_in = tgt[:, :-1].contiguous() # Remove <eos> for decode input (y_0=<bos>, y_1, y_2)
        tgt_out = tgt[:, 1:].contiguous() # Remove <bos> as target        (y_1, y_2, y_3=<eos>)
        bsz = tgt.size(0)
        # Run forward pass and compute loss along the way.
        logits = self.forward(src, src_lengths, tgt_in)
        loss = self.loss_function(logits.view(-1, self.V_tgt), tgt_out.view(-1))
        # Training stats
        num_tgt_words = tgt_out.ne(self.padding_id_tgt).float().sum().item()
        total_words += num_tgt_words
        total_loss += loss.item()
        # Perform backpropagation
        loss.div(bsz).backward()
        optim.step()

      # Evaluate and track improvements on the validation dataset
      validation_ppl = self.evaluate_ppl(val_iter)
      self.train()
      if validation_ppl < best_validation_ppl:
        best_validation_ppl = validation_ppl
        self.best_model = copy.deepcopy(self.state_dict())
      epoch_loss = total_loss / total_words
      print (f'Epoch: {epoch} Training Perplexity: {math.exp(epoch_loss):.4f} '
             f'Validation Perplexity: {validation_ppl:.4f}')

  def predict(self, src, K, max_T):
    """
    Generates the target sequence given a list of source tokens using beam search decoding.
    Assumes batch size of 1 for simplicity.
    Arguments:
      src: Source sequence in string format
      K: Beam size
      max_T: Max length of the target sequence
    Returns:
      The generated target sequence
    """
    # Set model to evaluation mode
    self.eval()
    finished = []

    # Tokenize source sequence for encoding
    src_tokens = self.hf_src_tokenizer(src, return_tensors="pt")
    src_ids = src_tokens["input_ids"].to(device)

    # Encode source sequence
    src_lengths = torch.tensor([src_ids.size(1)]).to(device)
    memory_bank, encoder_final_state = self.forward_encoder(src_ids, src_lengths)

    # Initialize beams
    beams = [Beam(encoder_final_state, torch.tensor([self.bos_id]).to(device), 0)]

    with torch.no_grad():
      # Complete beam search over time steps
      for t in range(max_T):
        all_total_scores = []

        # Iterate over all beams
        for beam in beams:

          # Isolate present beam state
          y_1_to_t, score, decoder_state = beam.tokens, beam.score, beam.decoder_state
          y_t = y_1_to_t[-1]

          # Generate source mask and decode iteratively
          src_mask = src_ids.ne(self.padding_id_src).to(device)
          logits, decoder_state, _ = self.forward_decoder_incrementally(
              decoder_state, y_t.unsqueeze(-1), memory_bank, src_mask, normalize=True)

          # Update cumulative scores and states
          cumu_scores = score + logits
          all_total_scores.append(cumu_scores)
          beam.decoder_state = decoder_state
        all_total_scores = torch.stack(all_total_scores)

        # Find the K best next beams
        all_scores_flattened = all_total_scores.view(-1) # K*V when t>0, 1*V when t=0
        topk_scores, topk_ids = all_scores_flattened.topk(K, 0)
        beam_ids = topk_ids.div(self.V_tgt, rounding_mode='floor')
        next_tokens = topk_ids - beam_ids * self.V_tgt

        # Generate new beams based on the top K choices
        new_beams = []
        for k in range(K):
          beam_id = beam_ids[k]       # which beam it comes from
          y_t_plus_1 = next_tokens[k] # which y_{t+1}
          score = topk_scores[k]
          beam = beams[beam_id]
          decoder_state = beam.decoder_state
          y_1_to_t = beam.tokens
          new_beam = Beam(decoder_state,
                          torch.cat((y_1_to_t, torch.tensor([y_t_plus_1]).to(device))), score)
          new_beams.append(new_beam)
        beams = new_beams

        # Set aside completed beams
        for beam in beams:
           if beam.tokens[-1] == self.eos_id:
              finished.append(beam)
              beams.remove(beam)

        # Break the loop if all beams are completed
        if len(beams) == 0:
            break

      # Return the best hypothesis
      output_tokens = []
      if len(finished) > 0:
        finished = sorted(finished, key=lambda beam: -beam.score)
        output_tokens = finished[0].tokens
      else: # When no hypothesis is finished, return the best unfinished hypothesis
        output_tokens = beams[0].tokens

      # Decode output from tensor of tokens to string
      output_string = self.hf_tgt_tokenizer.decode(output_tokens.tolist(),
                                                   skip_special_tokens=True)

      return output_string

In [None]:
EPOCHS = 50

LEARNING_RATE = 1e-4

model = AttnEncoderDecoder(hf_src_tokenizer, hf_tgt_tokenizer,
  hidden_size    = 1024,
  layers         = 1,
).to(device)

model.train_all(train_iter, val_iter, epochs=EPOCHS, learning_rate=LEARNING_RATE)

100%|██████████| 229/229 [02:02<00:00,  1.87it/s]


Epoch: 0 Training Perplexity: 4.1049 Validation Perplexity: 1.7253


100%|██████████| 229/229 [01:59<00:00,  1.92it/s]


Epoch: 1 Training Perplexity: 1.4688 Validation Perplexity: 1.3745


100%|██████████| 229/229 [01:59<00:00,  1.92it/s]


Epoch: 2 Training Perplexity: 1.2766 Validation Perplexity: 1.2602


100%|██████████| 229/229 [01:58<00:00,  1.93it/s]


Epoch: 3 Training Perplexity: 1.2008 Validation Perplexity: 1.2078


100%|██████████| 229/229 [02:00<00:00,  1.89it/s]


Epoch: 4 Training Perplexity: 1.1556 Validation Perplexity: 1.1724


100%|██████████| 229/229 [02:01<00:00,  1.89it/s]


Epoch: 5 Training Perplexity: 1.1219 Validation Perplexity: 1.1488


100%|██████████| 229/229 [02:01<00:00,  1.89it/s]


Epoch: 6 Training Perplexity: 1.0997 Validation Perplexity: 1.1375


100%|██████████| 229/229 [02:01<00:00,  1.89it/s]


Epoch: 7 Training Perplexity: 1.0841 Validation Perplexity: 1.1289


100%|██████████| 229/229 [02:00<00:00,  1.90it/s]


Epoch: 8 Training Perplexity: 1.0711 Validation Perplexity: 1.1197


100%|██████████| 229/229 [01:57<00:00,  1.95it/s]


Epoch: 9 Training Perplexity: 1.0590 Validation Perplexity: 1.1139


100%|██████████| 229/229 [02:02<00:00,  1.88it/s]


Epoch: 10 Training Perplexity: 1.0542 Validation Perplexity: 1.1074


100%|██████████| 229/229 [02:01<00:00,  1.89it/s]


Epoch: 11 Training Perplexity: 1.0442 Validation Perplexity: 1.1033


100%|██████████| 229/229 [01:59<00:00,  1.92it/s]


Epoch: 12 Training Perplexity: 1.0385 Validation Perplexity: 1.0994


100%|██████████| 229/229 [02:04<00:00,  1.84it/s]


Epoch: 13 Training Perplexity: 1.0333 Validation Perplexity: 1.0990


100%|██████████| 229/229 [02:01<00:00,  1.89it/s]


Epoch: 14 Training Perplexity: 1.0276 Validation Perplexity: 1.0982


100%|██████████| 229/229 [01:59<00:00,  1.92it/s]


Epoch: 15 Training Perplexity: 1.0246 Validation Perplexity: 1.0927


100%|██████████| 229/229 [01:59<00:00,  1.91it/s]


Epoch: 16 Training Perplexity: 1.0195 Validation Perplexity: 1.0948


100%|██████████| 229/229 [02:03<00:00,  1.85it/s]


Epoch: 17 Training Perplexity: 1.0164 Validation Perplexity: 1.0970


100%|██████████| 229/229 [02:00<00:00,  1.90it/s]


Epoch: 18 Training Perplexity: 1.0171 Validation Perplexity: 1.0989


100%|██████████| 229/229 [02:01<00:00,  1.89it/s]


Epoch: 19 Training Perplexity: 1.0159 Validation Perplexity: 1.0915


100%|██████████| 229/229 [02:03<00:00,  1.85it/s]


Epoch: 20 Training Perplexity: 1.0140 Validation Perplexity: 1.0993


100%|██████████| 229/229 [02:02<00:00,  1.88it/s]


Epoch: 21 Training Perplexity: 1.0136 Validation Perplexity: 1.0907


100%|██████████| 229/229 [02:06<00:00,  1.81it/s]


Epoch: 22 Training Perplexity: 1.0146 Validation Perplexity: 1.0944


100%|██████████| 229/229 [02:04<00:00,  1.85it/s]


Epoch: 23 Training Perplexity: 1.0135 Validation Perplexity: 1.0927


100%|██████████| 229/229 [02:01<00:00,  1.89it/s]


Epoch: 24 Training Perplexity: 1.0093 Validation Perplexity: 1.0922


100%|██████████| 229/229 [02:06<00:00,  1.81it/s]


Epoch: 25 Training Perplexity: 1.0081 Validation Perplexity: 1.0959


100%|██████████| 229/229 [02:05<00:00,  1.82it/s]


Epoch: 26 Training Perplexity: 1.0067 Validation Perplexity: 1.0918


100%|██████████| 229/229 [02:00<00:00,  1.90it/s]


Epoch: 27 Training Perplexity: 1.0058 Validation Perplexity: 1.0992


100%|██████████| 229/229 [02:03<00:00,  1.85it/s]


Epoch: 28 Training Perplexity: 1.0053 Validation Perplexity: 1.1027


100%|██████████| 229/229 [02:00<00:00,  1.89it/s]


Epoch: 29 Training Perplexity: 1.0059 Validation Perplexity: 1.0989


100%|██████████| 229/229 [02:01<00:00,  1.88it/s]


Epoch: 30 Training Perplexity: 1.0072 Validation Perplexity: 1.0988


100%|██████████| 229/229 [02:01<00:00,  1.89it/s]


Epoch: 31 Training Perplexity: 1.0067 Validation Perplexity: 1.0961


100%|██████████| 229/229 [02:00<00:00,  1.90it/s]


Epoch: 32 Training Perplexity: 1.0047 Validation Perplexity: 1.1039


100%|██████████| 229/229 [02:01<00:00,  1.88it/s]


Epoch: 33 Training Perplexity: 1.0052 Validation Perplexity: 1.1013


100%|██████████| 229/229 [01:57<00:00,  1.95it/s]


Epoch: 34 Training Perplexity: 1.0085 Validation Perplexity: 1.0998


100%|██████████| 229/229 [01:58<00:00,  1.94it/s]


Epoch: 35 Training Perplexity: 1.0070 Validation Perplexity: 1.0977


100%|██████████| 229/229 [02:03<00:00,  1.85it/s]


Epoch: 36 Training Perplexity: 1.0039 Validation Perplexity: 1.0984


100%|██████████| 229/229 [01:58<00:00,  1.93it/s]


Epoch: 37 Training Perplexity: 1.0029 Validation Perplexity: 1.1019


100%|██████████| 229/229 [02:01<00:00,  1.89it/s]


Epoch: 38 Training Perplexity: 1.0023 Validation Perplexity: 1.1012


100%|██████████| 229/229 [02:02<00:00,  1.88it/s]


Epoch: 39 Training Perplexity: 1.0020 Validation Perplexity: 1.1014


100%|██████████| 229/229 [02:00<00:00,  1.89it/s]


Epoch: 40 Training Perplexity: 1.0021 Validation Perplexity: 1.1024


100%|██████████| 229/229 [02:02<00:00,  1.87it/s]


Epoch: 41 Training Perplexity: 1.0033 Validation Perplexity: 1.1049


100%|██████████| 229/229 [01:58<00:00,  1.93it/s]


Epoch: 42 Training Perplexity: 1.0061 Validation Perplexity: 1.1136


100%|██████████| 229/229 [02:01<00:00,  1.89it/s]


Epoch: 43 Training Perplexity: 1.0081 Validation Perplexity: 1.1027


100%|██████████| 229/229 [02:03<00:00,  1.86it/s]


Epoch: 44 Training Perplexity: 1.0078 Validation Perplexity: 1.0999


100%|██████████| 229/229 [02:00<00:00,  1.90it/s]


Epoch: 45 Training Perplexity: 1.0039 Validation Perplexity: 1.1015


100%|██████████| 229/229 [02:02<00:00,  1.87it/s]


Epoch: 46 Training Perplexity: 1.0033 Validation Perplexity: 1.0972


100%|██████████| 229/229 [01:59<00:00,  1.92it/s]


Epoch: 47 Training Perplexity: 1.0026 Validation Perplexity: 1.1000


100%|██████████| 229/229 [02:03<00:00,  1.85it/s]


Epoch: 48 Training Perplexity: 1.0016 Validation Perplexity: 1.1001


100%|██████████| 229/229 [02:02<00:00,  1.88it/s]


Epoch: 49 Training Perplexity: 1.0014 Validation Perplexity: 1.1020


In [None]:
# Evaluate model performance on the validation set using the best model found.
model.load_state_dict(model.best_model)
print (f'Validation perplexity: {model.evaluate_ppl(tqdm(val_iter)):.3f}')

100%|██████████| 25/25 [00:06<00:00,  3.70it/s]

Validation perplexity: 1.091





####Define function to test seq2seq model outputs:

In [None]:
def seq2seq_trial(sentence, gold_sql):
  print("Sentence: ", sentence, "\n")

  predicted_sql = model.predict(sentence, K=1, max_T=400)
  print("Predicted SQL:\n\n", predicted_sql, "\n")

  if verify(predicted_sql, gold_sql, silent=False):
    print ('Correct!')
  else:
    print ('Incorrect!')

In [None]:
print(torch.cuda.is_available())


True


In [None]:
seq2seq_trial(example_1, gold_sql_1)

Sentence:  flights from phoenix to milwaukee 

Predicted SQL:

 SELECT DISTINCT flight_1.flight_id FROM flight flight_1, airport_service airport_service_1, city city_1, airport_service airport_service_2, city city_2 WHERE flight_1.from_airport = airport_service_1.airport_code AND airport_service_1.city_code = city_1.city_code AND city_1.city_name = 'PHOENIX' AND flight_1.to_airport = airport_service_2.airport_code AND airport_service_2.city_code = city_2.city_code AND city_2.city_name = 'MILWAUKEE' 

Predicted DB result:

 [(108086,), (108087,), (301763,), (301764,), (301765,), (301766,), (302323,), (304881,), (310619,), (310620,)] 

Gold DB result:

 [(108086,), (108087,), (301763,), (301764,), (301765,), (301766,), (302323,), (304881,), (310619,), (310620,)] 

Correct!


In [None]:
seq2seq_trial(example_2, gold_sql_2)

Sentence:  i would like a united flight 

Predicted SQL:

 SELECT DISTINCT flight_1.flight_id FROM flight flight_1, airport airport_1 WHERE flight_1.airline_code = 'UA' AND flight_1.flight_id = flight_fare_1.flight_id AND flight_fare_1.fare_id = fare_1.fare_id AND 1 = 1 

predicted sql exec failed: no such column: flight_fare_1.flight_id
Incorrect!


In [None]:
seq2seq_trial(example_3, gold_sql_3)

Sentence:  i would like a flight between boston and dallas 

Predicted SQL:

 SELECT DISTINCT flight_1.flight_id FROM flight flight_1, airport_service airport_service_1, city city_1, airport_service airport_service_2, city city_2 WHERE flight_1.from_airport = airport_service_1.airport_code AND airport_service_1.city_code = city_1.city_code AND city_1.city_name = 'BOSTON' AND flight_1.to_airport = airport_service_2.airport_code AND airport_service_2.city_code = city_2.city_code AND city_2.city_name = 'DALLAS' 

Predicted DB result:

 [(103171,), (103172,), (103173,), (103174,), (103175,), (103176,), (103177,), (103178,), (103179,), (103180,)] 

Gold DB result:

 [(103171,), (103172,), (103173,), (103174,), (103175,), (103176,), (103177,), (103178,), (103179,), (103180,)] 

Correct!


In [None]:
seq2seq_trial(example_4, gold_sql_4)

Sentence:  show me the united flights from denver to baltimore 

Predicted SQL:

 SELECT DISTINCT flight_1.flight_id FROM flight flight_1, airport_service airport_service_1, city city_1, airport_service airport_service_2, city city_2 WHERE flight_1.airline_code = 'UA' AND ( flight_1.from_airport = airport_service_1.airport_code AND airport_service_1.city_code = city_1.city_code AND city_1.city_name = 'DENVER' AND flight_1.to_airport = airport_service_2.airport_code AND airport_service_2.city_code = city_2.city_code AND city_2.city_name = 'BALTIMORE' ) 

Predicted DB result:

 [(101231,), (101233,), (305983,)] 

Gold DB result:

 [(101231,), (101233,), (305983,)] 

Correct!


In [None]:
seq2seq_trial(example_5, gold_sql_5)

Sentence:  show flights from cleveland to miami that arrive before 4pm 

Predicted SQL:

 SELECT DISTINCT flight_1.flight_id FROM flight flight_1, airport_service airport_service_1, city city_1, airport_service airport_service_2, city city_2 WHERE flight_1.from_airport = airport_service_1.airport_code AND airport_service_1.city_code = city_1.city_code AND city_1.city_name = 'CLEVELAND' AND ( flight_1.to_airport = airport_service_2.airport_code AND airport_service_2.city_code = city_2.city_code AND city_2.city_name = 'MIAMI' AND flight_1.arrival_time < 1600 ) 

Predicted DB result:

 [(107698,), (301117,)] 

Gold DB result:

 [(107698,), (301117,)] 

Correct!


In [None]:
seq2seq_trial(example_6, gold_sql_6b)

Sentence:  okay how about a flight on sunday from tampa to charlotte 

Predicted SQL:

 SELECT DISTINCT flight_1.flight_id FROM flight flight_1, airport_service airport_service_1, city city_1, airport_service airport_service_2, city city_2, days days_1, date_day date_day_1 WHERE flight_1.from_airport = airport_service_1.airport_code AND airport_service_1.city_code = city_1.city_code AND city_1.city_name = 'TAMPA' AND ( flight_1.to_airport = airport_service_2.airport_code AND airport_service_2.city_code = city_2.city_code AND city_2.city_name = 'CHARLOTTE' AND flight_1.flight_days = days_1.days_code AND days_1.day_name = date_day_1.day_name AND date_day_1.year = 1991 AND date_day_1.month_number = 8 AND date_day_1.day_number = 27 ) 

Predicted DB result:

 [(101860,), (101861,), (101862,), (101863,), (101864,), (101865,), (305231,)] 

Gold DB result:

 [(101860,), (101861,), (101862,), (101863,), (101864,), (101865,), (305231,)] 

Correct!


In [None]:
seq2seq_trial(example_7, gold_sql_7b)

Sentence:  list all flights going from boston to atlanta that leaves before 7 am on thursday 

Predicted SQL:

 SELECT DISTINCT flight_1.flight_id FROM flight flight_1, airport_service airport_service_1, city city_1, airport_service airport_service_2, city city_2, days days_1, date_day date_day_1 WHERE flight_1.from_airport = airport_service_1.airport_code AND airport_service_1.city_code = city_1.city_code AND city_1.city_name = 'BOSTON' AND ( flight_1.to_airport = airport_service_2.airport_code AND airport_service_2.city_code = city_2.city_code AND city_2.city_name = 'ATLANTA' AND ( flight_1.flight_days = days_1.days_code AND days_1.day_name = date_day_1.day_name AND date_day_1.year = 1991 AND date_day_1.month_number = 5 AND date_day_1.day_number = 24 AND flight_1.departure_time < 700 ) ) 

Predicted DB result:

 [(100014,)] 

Gold DB result:

 [(100014,)] 

Correct!


In [None]:
seq2seq_trial(example_8, gold_sql_8)

Sentence:  list the flights from dallas to san francisco on american airlines 

Predicted SQL:

 SELECT DISTINCT flight_1.flight_id FROM flight flight_1, airport_service airport_service_1, city city_1, airport_service airport_service_2, city city_2 WHERE flight_1.airline_code = 'AA' AND ( flight_1.from_airport = airport_service_1.airport_code AND airport_service_1.city_code = city_1.city_code AND city_1.city_name = 'DALLAS' AND flight_1.to_airport = airport_service_2.airport_code AND airport_service_2.city_code = city_2.city_code AND city_2.city_name = 'SAN FRANCISCO' ) 

Predicted DB result:

 [(108452,), (108454,), (108456,), (111083,), (111085,), (111086,), (111090,), (111091,), (111092,), (111094,)] 

Gold DB result:

 [(108452,), (108454,), (108456,), (111083,), (111085,), (111086,), (111090,), (111091,), (111092,), (111094,)] 

Correct!


### Running Full Evaluation

In [None]:
def seq2seq_predictor(tokens):
  prediction = model.predict(tokens, K=1, max_T=400)
  return prediction

In [None]:
precision, recall, f1 = evaluate(seq2seq_predictor, test_data, num_examples=0)
print(f"precision: {precision:3.2f}")
print(f"recall:    {recall:3.2f}")
print(f"F1:        {f1:3.2f}")

332it [03:28,  1.59it/s]

precision: 0.41
recall:    0.41
F1:        0.41





## Seq2seq model (with cross attention and self attention)


In [None]:
class AttnEncoderDecoder2(nn.Module):
  def __init__(self, hf_src_tokenizer, hf_tgt_tokenizer, hidden_size=64, layers=3):
    """
    Initializer. Creates network modules and loss function.
    Arguments:
        hf_src_tokenizer: hf src tokenizer
        hf_tgt_tokenizer: hf tgt tokenizer
        hidden_size: hidden layer size of both encoder and decoder
        layers: number of layers of both encoder and decoder
    """
    super().__init__()
    self.hf_src_tokenizer = hf_src_tokenizer
    self.hf_tgt_tokenizer = hf_tgt_tokenizer

    # Keep the vocabulary sizes available
    self.V_src = len(self.hf_src_tokenizer)
    self.V_tgt = len(self.hf_tgt_tokenizer)

    # Get special word ids
    self.padding_id_src = self.hf_src_tokenizer.pad_token_id
    self.padding_id_tgt = self.hf_tgt_tokenizer.pad_token_id
    self.bos_id = self.hf_tgt_tokenizer.get_vocab()[bos_token]
    self.eos_id = self.hf_tgt_tokenizer.get_vocab()[eos_token]

    # Keep hyper-parameters available
    self.embedding_size = hidden_size
    self.hidden_size = hidden_size
    self.layers = layers

    # Create essential modules
    self.word_embeddings_src = nn.Embedding(self.V_src, self.embedding_size)
    self.word_embeddings_tgt = nn.Embedding(self.V_tgt, self.embedding_size)

    # RNN cells
    self.encoder_rnn = nn.LSTM(
      input_size    = self.embedding_size,
      hidden_size   = hidden_size // 2, # to match decoder hidden size
      num_layers    = layers,
      batch_first=True,
      bidirectional = True              # bidirectional encoder
    )
    self.decoder_rnn = nn.LSTM(
      input_size    = self.embedding_size,
      hidden_size   = hidden_size,
      num_layers    = layers,
      batch_first=True,
      bidirectional = False             # unidirectional decoder
    )

    # Final projection layer
    self.hidden2output = nn.Linear(2*hidden_size, self.V_tgt) # project the concatenation to logits

    # Create loss function
    self.loss_function = nn.CrossEntropyLoss(reduction='sum',
                                             ignore_index=self.padding_id_tgt)

  def forward_encoder(self, src, src_lengths):
    """
    Encodes source words `src`.
    Arguments:
        src: src batch of size (bsz, max_src_len)
        src_lengths: src lengths of size (bsz)
    Returns:
        memory_bank: a tensor of size (bsz, src_len, hidden_size)
        (final_state, context, dec_out): `final_state` is a tuple (h, c)
                                where h/c is of size (layers, bsz, hidden_size),
                                `context` is `None`, and 'dec_out' (decoder output)
                                is 'None.'
    """
    # Embed and pack source tokens
    packed_embeddings = pack(
        self.word_embeddings_src(src).to(device),
        src_lengths.tolist(),
        batch_first=True,
        enforce_sorted=False
    )

    # Pass packed embeddings through encoder RNN
    encoder_outs, (h, c) = self.encoder_rnn(packed_embeddings)

    # Reshape and transpose hidden states to match decoder
    hsplit = h.reshape(self.layers, 2, len(src_lengths), self.hidden_size//2)
    hsplit = hsplit.transpose(1,2)
    finalh = hsplit.reshape(self.layers, len(src_lengths), -1)

    # Reshape and transpose cell states
    csplit = c.reshape(self.layers, 2, len(src_lengths), self.hidden_size//2)
    csplit = csplit.transpose(1,2)
    finalc = csplit.reshape(self.layers, len(src_lengths), -1)

    # Unpack encoder outputs into the memory bank
    memory_bank = unpack(encoder_outs, batch_first=True)[0]

    # Assign final state tuple, context, and decoder output, knowing context
    # and decoder output are set to None for the encoder
    final_state = (finalh, finalc)
    context = None
    dec_out = None

    return memory_bank, (final_state, context, dec_out)

  def forward_decoder(self, encoder_final_state, tgt_in, memory_bank, src_mask):
    """
    Decodes based on encoder final state, memory bank, src_mask, and ground truth
    target words.
    Arguments:
        encoder_final_state: (final_state, None, None) where final_state is the encoder
                             final state used to initialize decoder. None is the
                             initial context (there's no previous context at the
                             first step) and the decoder output (there is no previous
                             decoder output at the first time step).
        tgt_in: a tensor of size (bsz, tgt_len)
        memory_bank: a tensor of size (bsz, src_len, hidden_size), encoder outputs
                     at every position
        src_mask: a tensor of size (bsz, src_len): a boolean tensor, `False` where
                  src is padding (disallow decoder to attend to those places).
    Returns:
        Logits of size (bsz, tgt_len, V_tgt) (before the softmax operation)
    """
    max_tgt_length = tgt_in.size(1)

    # Initialize decoder state, which is a tuple (state, context, decoder_output) here
    decoder_states = encoder_final_state

    # Iterate over each token and iteratively decode at each time step
    all_logits = []
    for i in range(max_tgt_length):
      logits, decoder_states, attn = \
        self.forward_decoder_incrementally(decoder_states,
                                           tgt_in[:, i],
                                           memory_bank,
                                           src_mask,
                                           normalize=False)
      all_logits.append(logits)             # list of bsz, vocab_tgt
    # Stack logits to generate 3-dimensional tensor
    all_logits = torch.stack(all_logits, 1) # bsz, tgt_len, vocab_tgt
    return all_logits

  def forward(self, src, src_lengths, tgt_in):
    """
    Performs forward computation, returns logits.
    Arguments:
        src: src batch of size (bsz, max_src_len)
        src_lengths: src lengths of size (bsz)
        tgt_in:  a tensor of size (bsz, tgt_len)
    """
    src = src.to(device)
    src_lengths = src_lengths.to(device)
    tgt_in = tgt_in.to(device)

    src_mask = src.ne(self.padding_id_src) # bsz, max_src_len
    # Forward encoder
    memory_bank, encoder_final_state = self.forward_encoder(src, src_lengths)
    # Forward decoder
    logits = self.forward_decoder(encoder_final_state, tgt_in, memory_bank, src_mask)
    return logits

  def forward_decoder_incrementally(self, prev_decoder_states, tgt_in_onestep,
                                    memory_bank, src_mask, previous_states = None,
                                    normalize=True):
    """
    Forward the decoder for a single step with token `tgt_in_onestep`.
    This function is used both in `forward_decoder` and in beam search.
    Arguments:
        prev_decoder_states: a tuple (prev_decoder_state, prev_context, prev_decoder_outputs).
                             `prev_context` and 'prev_decoder_outputs' are `None` for the
                             first step
        tgt_in_onestep: a tensor of size (bsz), tokens at one step
        memory_bank: a tensor of size (bsz, src_len, hidden_size), encoder outputs
                     at every position
        src_mask: a tensor of size (bsz, src_len): a boolean tensor, `False` where
                  src is padding (disallow decoder to attend to those places).
        normalize: use log_softmax to normalize or not. Beam search needs to normalize,
                   while `forward_decoder` does not
    Returns:
        logits: log probabilities for `tgt_in_token` of size (bsz, V_tgt)
        decoder_states: (`decoder_state`, `context`, 'prev_decoder_outputs') which
                        is used for the next incremental update
        attn: normalized attention scores at this step (bsz, src_len)
    """
    # Extract previous decoder state, context, and decoder outputs
    prev_decoder_state, prev_context, prev_decoder_outputs = prev_decoder_states

    # Obtain previous decoder state's batch size
    bsz = prev_decoder_state[0].size(1)

    # Reshape and embed current input
    tgt_in_onestep = tgt_in_onestep.reshape(bsz, 1)
    tgt_in_onestep = self.word_embeddings_tgt(tgt_in_onestep)

    # Add previous context to embeddings if one exists
    if prev_context is not None:
        tgt_in_onestep = tgt_in_onestep + prev_context

    # Pass embedded input into the decoder rnn and extract output
    decoder_out, decoder_state = self.decoder_rnn(tgt_in_onestep, prev_decoder_state)

    # Apply self-attention if there are previous states
    if previous_states is not None and previous_states.size(1) > 0:
        self_attn_output, _ = attention(decoder_out, prev_decoder_outputs, prev_decoder_outputs)
        decoder_out += self_attn_output

    # Initialize previous decoder outputs if they do not exist
    if prev_decoder_outputs is None:
        prev_decoder_outputs = torch.zeros_like(decoder_out)

    # Concatenate current outputs and previous outputs
    prev_decoder_outputs = torch.cat((prev_decoder_outputs, decoder_out), dim = 1)

    # Unsqueeze source mask for attention using the memory bank
    src_mask = src_mask.unsqueeze(1)

    # Obtain cross-attention and context vector from current decoder output and memory bank
    attn, context = attention(decoder_out, memory_bank, memory_bank, src_mask)

    # Append decoder output and context, and cast to size of output vocabulary
    logits = self.hidden2output(torch.cat((decoder_out, context), dim=-1))

    # Pack decoder state, context, and previous outputs
    decoder_states = (decoder_state, context, prev_decoder_outputs)

    # Apply softmax normalization if specified
    if normalize:
      logits = torch.log_softmax(logits, dim=-1)
    return logits, decoder_states, attn

  def evaluate_ppl(self, iterator):
    """Returns the model's perplexity on a given dataset `iterator`."""
    # Switch to eval mode
    self.eval()
    total_loss = 0
    total_words = 0
    for batch in iterator:
      # Input and target
      src = batch['src_ids']              # bsz, max_src_len
      src_lengths = batch['src_lengths']  # bsz
      tgt_in = batch['tgt_ids'][:, :-1] # Remove <eos> for decode input (y_0=<bos>, y_1, y_2)
      tgt_out = batch['tgt_ids'][:, 1:] # Remove <bos> as target        (y_1, y_2, y_3=<eos>)
      # Forward to get logits
      logits = self.forward(src, src_lengths, tgt_in) # bsz, tgt_len, V_tgt
      # Compute cross entropy loss
      loss = self.loss_function(logits.reshape(-1, self.V_tgt), tgt_out.reshape(-1))
      total_loss += loss.item()
      total_words += tgt_out.ne(self.padding_id_tgt).float().sum().item()
    return math.exp(total_loss/total_words)

  def train_all(self, train_iter, val_iter, epochs=10, learning_rate=0.001):
    """Train the model."""
    # Switch the module to training mode
    self.train()
    # Use Adam to optimize the parameters
    optim = torch.optim.Adam(self.parameters(), lr=learning_rate)
    best_validation_ppl = float('inf')
    best_model = None
    # Run the optimization for multiple epochs
    for epoch in range(epochs):
      total_words = 0
      total_loss = 0.0
      for batch in tqdm(train_iter):
        # Zero the parameter gradients
        self.zero_grad()
        # Input and target
        tgt = batch['tgt_ids']              # bsz, max_tgt_len
        src = batch['src_ids']              # bsz, max_src_len
        src_lengths = batch['src_lengths']  # bsz
        tgt_in = tgt[:, :-1].contiguous() # Remove <eos> for decode input (y_0=<bos>, y_1, y_2)
        tgt_out = tgt[:, 1:].contiguous() # Remove <bos> as target        (y_1, y_2, y_3=<eos>)
        bsz = tgt.size(0)
        # Run forward pass and compute loss along the way.
        logits = self.forward(src, src_lengths, tgt_in)
        loss = self.loss_function(logits.view(-1, self.V_tgt), tgt_out.view(-1))
        # Training stats
        num_tgt_words = tgt_out.ne(self.padding_id_tgt).float().sum().item()
        total_words += num_tgt_words
        total_loss += loss.item()
        # Perform backpropagation
        loss.div(bsz).backward()
        optim.step()

      # Evaluate and track improvements on the validation dataset
      validation_ppl = self.evaluate_ppl(val_iter)
      self.train()
      if validation_ppl < best_validation_ppl:
        best_validation_ppl = validation_ppl
        self.best_model = copy.deepcopy(self.state_dict())
      epoch_loss = total_loss / total_words
      print (f'Epoch: {epoch} Training Perplexity: {math.exp(epoch_loss):.4f} '
             f'Validation Perplexity: {validation_ppl:.4f}')

  def predict(self, src, K, max_T):
    """
    Generates the target sequence given a list of source tokens using beam search decoding.
    Assumes batch size is 1 for simplicity.
    Arguments:
      src: Source sequence in string format
      K: Beam size
      max_T: Max length of the target sequence
    Returns:
      The generated target sequence
    """
    # Set model to evaluation mode
    self.eval()
    finished = []

    # Tokenize source sequence for encoding
    src_tokens = self.hf_src_tokenizer(src, return_tensors="pt")
    src_ids = src_tokens["input_ids"].to(device)

    # Encode source sequence
    src_lengths = torch.tensor([src_ids.size(1)]).to(device)
    memory_bank, encoder_final_state = self.forward_encoder(src_ids, src_lengths)

    # Initialize beams
    beams = [Beam(encoder_final_state, torch.tensor([self.bos_id]).to(device), 0)]

    with torch.no_grad():
      # Complete beam search over time steps
      for t in range(max_T):
        all_total_scores = []

        # Iterate over all beams
        for beam in beams:

          # Isolate present beam state
          y_1_to_t, score, decoder_state = beam.tokens, beam.score, beam.decoder_state
          y_t = y_1_to_t[-1]

          # Generate source mask and decode iteratively
          src_mask = src_ids.ne(self.padding_id_src).to(device)
          logits, decoder_state, _ = self.forward_decoder_incrementally(
              decoder_state, y_t.unsqueeze(-1), memory_bank, src_mask, normalize=True)

          # Update cumulative scores and states
          cumu_scores = score + logits
          all_total_scores.append(cumu_scores)
          beam.decoder_state = decoder_state
        all_total_scores = torch.stack(all_total_scores)

        # Find the K best next beams
        all_scores_flattened = all_total_scores.view(-1) # K*V when t>0, 1*V when t=0
        topk_scores, topk_ids = all_scores_flattened.topk(K, 0)
        beam_ids = topk_ids.div(self.V_tgt, rounding_mode='floor')
        next_tokens = topk_ids - beam_ids * self.V_tgt

        # Generate new beams based on the top K choices
        new_beams = []
        for k in range(K):
          beam_id = beam_ids[k]       # which beam it comes from
          y_t_plus_1 = next_tokens[k] # which y_{t+1}
          score = topk_scores[k]
          beam = beams[beam_id]
          decoder_state = beam.decoder_state
          y_1_to_t = beam.tokens
          new_beam = Beam(decoder_state,
                          torch.cat((y_1_to_t, torch.tensor([y_t_plus_1]).to(device))), score)
          new_beams.append(new_beam)
        beams = new_beams

        # Set aside completed beams
        for beam in beams:
           y_t_plus_1 = beam.tokens[-1]
           if y_t_plus_1 == self.eos_id:
              finished.append(beam)
              beams.remove(beam)

        # Break the loop if all beams are completed
        if len(beams) == 0:
            break

      # Return the best hypothesis
      output_tokens = []
      if len(finished) > 0:
        finished = sorted(finished, key=lambda beam: -beam.score)
        output_tokens = finished[0].tokens
      else: # When no hypothesis is finished, return the best unfinished hypothesis
        output_tokens = beams[0].tokens

      # Decode output from tensor of tokens to string
      output_string = self.hf_tgt_tokenizer.decode(output_tokens.tolist(),
                                                   skip_special_tokens=True)

      return output_string

In [None]:
EPOCHS = 50
LEARNING_RATE = 1e-4

model2 = AttnEncoderDecoder2(hf_src_tokenizer, hf_tgt_tokenizer,
  hidden_size    = 1024,
  layers         = 1,
).to(device)

In [None]:
# Evaluate model performance on the validation set using the best model found.
model2.train_all(train_iter, val_iter, epochs=EPOCHS, learning_rate=LEARNING_RATE)
model2.load_state_dict(model2.best_model)

print (f'Validation perplexity: {model2.evaluate_ppl(val_iter):.3f}')

100%|██████████| 229/229 [01:38<00:00,  2.32it/s]


Epoch: 0 Training Perplexity: 4.1484 Validation Perplexity: 1.7135


100%|██████████| 229/229 [01:39<00:00,  2.31it/s]


Epoch: 1 Training Perplexity: 1.4609 Validation Perplexity: 1.3743


100%|██████████| 229/229 [01:38<00:00,  2.32it/s]


Epoch: 2 Training Perplexity: 1.2792 Validation Perplexity: 1.2675


100%|██████████| 229/229 [01:37<00:00,  2.35it/s]


Epoch: 3 Training Perplexity: 1.2041 Validation Perplexity: 1.2116


100%|██████████| 229/229 [01:38<00:00,  2.32it/s]


Epoch: 4 Training Perplexity: 1.1572 Validation Perplexity: 1.1769


100%|██████████| 229/229 [01:37<00:00,  2.34it/s]


Epoch: 5 Training Perplexity: 1.1269 Validation Perplexity: 1.1559


100%|██████████| 229/229 [01:37<00:00,  2.35it/s]


Epoch: 6 Training Perplexity: 1.1038 Validation Perplexity: 1.1430


100%|██████████| 229/229 [01:37<00:00,  2.34it/s]


Epoch: 7 Training Perplexity: 1.0863 Validation Perplexity: 1.1292


100%|██████████| 229/229 [01:37<00:00,  2.34it/s]


Epoch: 8 Training Perplexity: 1.0725 Validation Perplexity: 1.1210


100%|██████████| 229/229 [01:38<00:00,  2.34it/s]


Epoch: 9 Training Perplexity: 1.0631 Validation Perplexity: 1.1227


100%|██████████| 229/229 [01:37<00:00,  2.35it/s]


Epoch: 10 Training Perplexity: 1.0545 Validation Perplexity: 1.1096


100%|██████████| 229/229 [01:39<00:00,  2.31it/s]


Epoch: 11 Training Perplexity: 1.0475 Validation Perplexity: 1.1060


100%|██████████| 229/229 [01:39<00:00,  2.31it/s]


Epoch: 12 Training Perplexity: 1.0400 Validation Perplexity: 1.1051


100%|██████████| 229/229 [01:38<00:00,  2.32it/s]


Epoch: 13 Training Perplexity: 1.0350 Validation Perplexity: 1.1058


100%|██████████| 229/229 [01:38<00:00,  2.32it/s]


Epoch: 14 Training Perplexity: 1.0314 Validation Perplexity: 1.0983


100%|██████████| 229/229 [01:38<00:00,  2.32it/s]


Epoch: 15 Training Perplexity: 1.0278 Validation Perplexity: 1.1013


100%|██████████| 229/229 [01:37<00:00,  2.35it/s]


Epoch: 16 Training Perplexity: 1.0247 Validation Perplexity: 1.0976


100%|██████████| 229/229 [01:37<00:00,  2.34it/s]


Epoch: 17 Training Perplexity: 1.0205 Validation Perplexity: 1.0972


100%|██████████| 229/229 [01:38<00:00,  2.33it/s]


Epoch: 18 Training Perplexity: 1.0171 Validation Perplexity: 1.0955


100%|██████████| 229/229 [01:38<00:00,  2.33it/s]


Epoch: 19 Training Perplexity: 1.0161 Validation Perplexity: 1.1009


100%|██████████| 229/229 [01:38<00:00,  2.32it/s]


Epoch: 20 Training Perplexity: 1.0157 Validation Perplexity: 1.0969


100%|██████████| 229/229 [01:39<00:00,  2.31it/s]


Epoch: 21 Training Perplexity: 1.0137 Validation Perplexity: 1.1000


100%|██████████| 229/229 [01:36<00:00,  2.37it/s]


Epoch: 22 Training Perplexity: 1.0120 Validation Perplexity: 1.0963


100%|██████████| 229/229 [01:39<00:00,  2.30it/s]


Epoch: 23 Training Perplexity: 1.0119 Validation Perplexity: 1.0981


100%|██████████| 229/229 [01:37<00:00,  2.35it/s]


Epoch: 24 Training Perplexity: 1.0111 Validation Perplexity: 1.0983


100%|██████████| 229/229 [01:39<00:00,  2.30it/s]


Epoch: 25 Training Perplexity: 1.0092 Validation Perplexity: 1.1005


100%|██████████| 229/229 [01:37<00:00,  2.36it/s]


Epoch: 26 Training Perplexity: 1.0100 Validation Perplexity: 1.0948


100%|██████████| 229/229 [01:39<00:00,  2.31it/s]


Epoch: 27 Training Perplexity: 1.0077 Validation Perplexity: 1.0995


100%|██████████| 229/229 [01:38<00:00,  2.32it/s]


Epoch: 28 Training Perplexity: 1.0083 Validation Perplexity: 1.0979


100%|██████████| 229/229 [01:38<00:00,  2.32it/s]


Epoch: 29 Training Perplexity: 1.0099 Validation Perplexity: 1.1082


100%|██████████| 229/229 [01:38<00:00,  2.32it/s]


Epoch: 30 Training Perplexity: 1.0111 Validation Perplexity: 1.1034


100%|██████████| 229/229 [01:37<00:00,  2.35it/s]


Epoch: 31 Training Perplexity: 1.0073 Validation Perplexity: 1.0972


100%|██████████| 229/229 [01:39<00:00,  2.30it/s]


Epoch: 32 Training Perplexity: 1.0060 Validation Perplexity: 1.0982


100%|██████████| 229/229 [01:38<00:00,  2.33it/s]


Epoch: 33 Training Perplexity: 1.0040 Validation Perplexity: 1.0970


100%|██████████| 229/229 [01:37<00:00,  2.34it/s]


Epoch: 34 Training Perplexity: 1.0029 Validation Perplexity: 1.1030


100%|██████████| 229/229 [01:38<00:00,  2.32it/s]


Epoch: 35 Training Perplexity: 1.0037 Validation Perplexity: 1.1012


100%|██████████| 229/229 [01:39<00:00,  2.31it/s]


Epoch: 36 Training Perplexity: 1.0046 Validation Perplexity: 1.1094


100%|██████████| 229/229 [01:38<00:00,  2.33it/s]


Epoch: 37 Training Perplexity: 1.0067 Validation Perplexity: 1.1050


100%|██████████| 229/229 [01:37<00:00,  2.34it/s]


Epoch: 38 Training Perplexity: 1.0064 Validation Perplexity: 1.1002


100%|██████████| 229/229 [01:37<00:00,  2.35it/s]


Epoch: 39 Training Perplexity: 1.0044 Validation Perplexity: 1.1010


100%|██████████| 229/229 [01:39<00:00,  2.30it/s]


Epoch: 40 Training Perplexity: 1.0038 Validation Perplexity: 1.0998


100%|██████████| 229/229 [01:37<00:00,  2.35it/s]


Epoch: 41 Training Perplexity: 1.0035 Validation Perplexity: 1.1003


100%|██████████| 229/229 [01:37<00:00,  2.35it/s]


Epoch: 42 Training Perplexity: 1.0027 Validation Perplexity: 1.1011


100%|██████████| 229/229 [01:39<00:00,  2.31it/s]


Epoch: 43 Training Perplexity: 1.0021 Validation Perplexity: 1.1025


100%|██████████| 229/229 [01:37<00:00,  2.34it/s]


Epoch: 44 Training Perplexity: 1.0021 Validation Perplexity: 1.1009


100%|██████████| 229/229 [01:38<00:00,  2.32it/s]


Epoch: 45 Training Perplexity: 1.0057 Validation Perplexity: 1.1184


100%|██████████| 229/229 [01:37<00:00,  2.35it/s]


Epoch: 46 Training Perplexity: 1.0065 Validation Perplexity: 1.1073


100%|██████████| 229/229 [01:38<00:00,  2.31it/s]


Epoch: 47 Training Perplexity: 1.0048 Validation Perplexity: 1.1000


100%|██████████| 229/229 [01:35<00:00,  2.39it/s]


Epoch: 48 Training Perplexity: 1.0032 Validation Perplexity: 1.1075


100%|██████████| 229/229 [01:36<00:00,  2.37it/s]


Epoch: 49 Training Perplexity: 1.0038 Validation Perplexity: 1.1037
Validation perplexity: 1.095


### Running Full Evaluation

In [None]:
def seq2seq_predictor2(tokens):
  prediction = model2.predict(tokens, K=1, max_T=400)
  return prediction

In [None]:
precision, recall, f1 = evaluate(seq2seq_predictor2, test_iter.dataset, num_examples=0)
print(f"precision: {precision:3.2f}")
print(f"recall:    {recall:3.2f}")
print(f"F1:        {f1:3.2f}")

332it [01:00,  5.48it/s]

precision: 0.42
recall:    0.42
F1:        0.42





## Using BART

In [None]:
pretrained_bart = BartForConditionalGeneration.from_pretrained('facebook/bart-base')
bart_tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')

config.json:   0%|          | 0.00/1.72k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/558M [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

In [None]:
def bart_encode(example):
    example['src_ids'] = bart_tokenizer(example['src']).input_ids[:1024] # BART model can process at most 1024 tokens
    example['tgt_ids'] = bart_tokenizer(example['tgt']).input_ids[:1024]
    return example

train_bart_data = dataset['train'].map(bart_encode)
val_bart_data = dataset['val'].map(bart_encode)
test_bart_data = dataset['test'].map(bart_encode)

BATCH_SIZE = 1 # batch size for training/validation
TEST_BATCH_SIZE = 1 # batch size for test, using 1 to make beam search implementation easier

train_iter_bart = torch.utils.data.DataLoader(train_bart_data,
                                         batch_size=BATCH_SIZE,
                                         shuffle=True,
                                         collate_fn=collate_fn)
val_iter_bart = torch.utils.data.DataLoader(val_bart_data,
                                       batch_size=BATCH_SIZE,
                                       shuffle=False,
                                       collate_fn=collate_fn)
test_iter_bart = torch.utils.data.DataLoader(test_bart_data,
                                        batch_size=TEST_BATCH_SIZE,
                                        shuffle=False,
                                        collate_fn=collate_fn)

Map:   0%|          | 0/3651 [00:00<?, ? examples/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (1135 > 1024). Running this sequence through the model will result in indexing errors


Map:   0%|          | 0/398 [00:00<?, ? examples/s]

Map:   0%|          | 0/332 [00:00<?, ? examples/s]

In [None]:
batch = next(iter(train_iter_bart))
src_ids = batch['src_ids']
src_example = src_ids[0]
print (f"Size of text batch: {src_ids.size()}")
print (f"First sentence in batch: {src_example}")
print (f"Length of the third sentence in batch: {len(src_example)}")
print (f"Converted back to string: {bart_tokenizer.decode(src_example)}")

tgt_ids = batch['tgt_ids']
tgt_example = tgt_ids[0]
print (f"Size of sql batch: {tgt_ids.size()}")
print (f"First sql in batch: {tgt_example}")
print (f"Converted back to string: {bart_tokenizer.decode(tgt_example)}")

Size of text batch: torch.Size([1, 14])
First sentence in batch: tensor([    0,  4825,  6183,    31,  7843, 42956,     7,  9209, 37444,    15,
        18862, 46836,  1559,     2], device='cuda:0')
Length of the third sentence in batch: 14
Converted back to string: <s>flights from phoenix to milwaukee on wednesday evening</s>
Size of sql batch: torch.Size([1, 257])
First sql in batch: tensor([    0, 49179,   211, 11595,  2444,  7164,  2524,  1215,   134,     4,
        15801,  1215,   808, 11974,  2524,  2524,  1215,   134,  2156,  3062,
         1215, 11131,  3062,  1215, 11131,  1215,   134,  2156,   343,   343,
         1215,   134,  2156,  3062,  1215, 11131,  3062,  1215, 11131,  1215,
          176,  2156,   343,   343,  1215,   176,  2156,   360,   360,  1215,
          134,  2156,  1248,  1215,  1208,  1248,  1215,  1208,  1215,   134,
        29919,  2524,  1215,   134,     4, 17272,  2013,  2407,  1215,   958,
        24844,  9112,  2796, 18360,  4248,   132,  2619,  4248,    

In [None]:
class BART(nn.Module):
  def __init__(self, tokenizer, pretrained_bart):
    """
    Initializer. Creates network modules and loss function.
    Arguments:
        tokenizer: BART tokenizer
        pretrained_bart: pretrained BART
    """
    super(BART, self).__init__()

    self.tokenizer = tokenizer

    self.V_tgt = len(tokenizer)

    # Get special word ids
    self.padding_id_tgt = tokenizer.pad_token_id

    # Create essential modules
    self.bart = pretrained_bart

    # Create loss function
    self.loss_function = nn.CrossEntropyLoss(reduction="sum",
                                             ignore_index=self.padding_id_tgt)

  def forward(self, src, src_lengths, tgt_in):
    """
    Performs forward computation, returns logits.
    Arguments:
        src: src batch of size (batch_size, max_src_len)
        src_lengths: src lengths of size (batch_size)
        tgt_in:  a tensor of size (batch_size, tgt_len)
    """
    logits = self.bart(input_ids=src,
                       decoder_input_ids=tgt_in,
                       use_cache=False
                      ).logits
    return logits

  def evaluate_ppl(self, iterator):
    """Returns the model's perplexity on a given dataset `iterator`."""
    self.eval()
    total_loss = 0
    total_words = 0
    for batch in iterator:
      # Input and target
      src = batch['src_ids']              # bsz, max_src_len
      src_lengths = batch['src_lengths']  # bsz
      tgt_in = batch['tgt_ids'][:, :-1]   # Remove <eos> for decode input (y_0=<bos>, y_1, y_2)
      tgt_out = batch['tgt_ids'][:, 1:]   # Remove <bos> as target        (y_1, y_2, y_3=<eos>)
      # Forward to get logits
      logits = self.forward(src, src_lengths, tgt_in) # bsz, tgt_len, V_tgt
      # Compute cross entropy loss
      loss = self.loss_function(logits.reshape(-1, self.V_tgt), tgt_out.reshape(-1))
      total_loss += loss.item()
      total_words += tgt_out.ne(self.padding_id_tgt).float().sum().item()
    return math.exp(total_loss/total_words)

  def train_all(self, train_iter, val_iter, epochs=10, first_n_batches=None, learning_rate=0.001):
    """Train the model."""
    # Switch the module to training mode
    self.train()
    # Use Adam to optimize the parameters
    optim = torch.optim.Adam(self.parameters(), lr=learning_rate)
    best_validation_ppl = float('inf')
    best_model = None
    # Run the optimization for multiple epochs
    for epoch in range(epochs):
      total_words = 0
      total_loss = 0.0
      i = 0
      # Iterate over each batch in train data iterator
      for batch in tqdm(train_iter):
        # If maximum number of batches has been set and
        # that number of batches has been processed,
        # stop processing further batches
        if first_n_batches and i >= first_n_batches:
          break
        # Zero the parameter gradients
        self.zero_grad()
        # Input and target
        tgt = batch['tgt_ids']              # bsz, max_tgt_len
        src = batch['src_ids']              # bsz, max_src_len
        src_lengths = batch['src_lengths']  # bsz
        tgt_in = tgt[:, :-1].contiguous()   # Remove <eos> for decode input (y_0=<bos>, y_1, y_2)
        tgt_out = tgt[:, 1:].contiguous()   # Remove <bos> as target        (y_1, y_2, y_3=<eos>)
        bsz = tgt.size(0)
        # Run forward pass and compute loss along the way.
        logits = self.forward(src, src_lengths, tgt_in)
        loss = self.loss_function(logits.view(-1, self.V_tgt), tgt_out.view(-1))
        # Training stats
        num_tgt_words = tgt_out.ne(self.padding_id_tgt).float().sum().item()
        total_words += num_tgt_words
        total_loss += loss.item()
        # Perform backpropagation
        loss.div(bsz).backward()
        optim.step()
        i += 1

      # Evaluate and track improvements on the validation dataset
      validation_ppl = self.evaluate_ppl(val_iter)
      self.train()
      if validation_ppl < best_validation_ppl:
        best_validation_ppl = validation_ppl
        self.best_model = copy.deepcopy(self.state_dict())
      epoch_loss = total_loss / total_words
      print (f'Epoch: {epoch} Training Perplexity: {math.exp(epoch_loss):.4f} '
             f'Validation Perplexity: {validation_ppl:.4f}')

  def predict(self, tokens, K=1, max_T=400):
    """
    Generates the target sequence given the source sequence using beam search decoding.
    Only use batch size 1 for simplicity.
    Arguments:
        tokens: the source sentence.
        max_T: at most proceed this many steps of decoding
    Returns:
        a string of the generated target sentence.
    """
    # Tokenize and map to a list of word ids
    inputs = torch.LongTensor(self.tokenizer([tokens])['input_ids'][:1024]).to(device)
    # The `transformers` package provides built-in beam search support
    prediction = self.bart.generate(inputs,
                                    num_beams=K,
                                    max_length=max_T,
                                    early_stopping=True,
                                    no_repeat_ngram_size=0,
                                    decoder_start_token_id=49179,
                                    use_cache=True)[0]
    return self.tokenizer.decode(prediction, skip_special_tokens=True)

In [None]:
EPOCHS = 5

LEARNING_RATE = 1e-5

bart_model = BART(bart_tokenizer,
                 pretrained_bart
).to(device)

bart_model.train_all(train_iter_bart, val_iter_bart, epochs=EPOCHS, learning_rate=LEARNING_RATE)
bart_model.load_state_dict(bart_model.best_model)

print (f'Validation perplexity: {bart_model.evaluate_ppl(val_iter_bart):.3f}')

100%|██████████| 3651/3651 [10:27<00:00,  5.82it/s]


Epoch: 0 Training Perplexity: 1.0173 Validation Perplexity: 1.0156


100%|██████████| 3651/3651 [10:29<00:00,  5.80it/s]


Epoch: 1 Training Perplexity: 1.0149 Validation Perplexity: 1.0158


100%|██████████| 3651/3651 [12:09<00:00,  5.01it/s]


Epoch: 2 Training Perplexity: 1.0128 Validation Perplexity: 1.0152


100%|██████████| 3651/3651 [10:16<00:00,  5.92it/s]


Epoch: 3 Training Perplexity: 1.0115 Validation Perplexity: 1.0146


100%|██████████| 3651/3651 [09:32<00:00,  6.38it/s]


Epoch: 4 Training Perplexity: 1.0099 Validation Perplexity: 1.0147
Validation perplexity: 1.015


In [None]:
def bart_trial(sentence, gold_sql):
  predicted_sql = bart_model.predict(sentence, K=1, max_T=300)
  print("Predicted SQL:\n\n", predicted_sql, "\n")

  if verify(predicted_sql, gold_sql, silent=False):
    print ('Correct!')
  else:
    print ('Incorrect!')

In [None]:
bart_trial(example_1, gold_sql_1)



Predicted SQL:

 SELECT DISTINCT flight_1.flight_id FROM flight flight_1, airport_service airport_service_1, city city_1, airport_service airport_service_2, city city_2 WHERE flight_1.from_airport = airport_service_1.airport_code AND airport_service_1.city_code = city_1.city_code AND city_1.city_name = 'PHOENIX' AND flight_1.to_airport = airport_service_2.airport_code AND airport_service_2.city_code = city_2.city_code AND city_2.city_name = 'MILWAUKEE' 

Predicted DB result:

 [(108086,), (108087,), (301763,), (301764,), (301765,), (301766,), (302323,), (304881,), (310619,), (310620,)] 

Gold DB result:

 [(108086,), (108087,), (301763,), (301764,), (301765,), (301766,), (302323,), (304881,), (310619,), (310620,)] 

Correct!


In [None]:
bart_trial(example_2, gold_sql_2)

Predicted SQL:

 SELECT DISTINCT flight_1.flight_id FROM flight flight_1 WHERE flight_1.airline_code = 'UA' AND 1 = 1 

Predicted DB result:

 [(100094,), (100099,), (100145,), (100158,), (100164,), (100167,), (100169,), (100203,), (100204,), (100296,)] 

Gold DB result:

 [(100094,), (100099,), (100145,), (100158,), (100164,), (100167,), (100169,), (100203,), (100204,), (100296,)] 

Correct!


In [None]:
bart_trial(example_3, gold_sql_3)

Predicted SQL:

 SELECT DISTINCT flight_1.flight_id FROM flight flight_1, airport_service airport_service_1, city city_1, airport_service airport_service_2, city city_2 WHERE flight_1.from_airport = airport_service_1.airport_code AND airport_service_1.city_code = city_1.city_code AND city_1.city_name = 'BOSTON' AND flight_1.to_airport = airport_service_2.airport_code AND airport_service_2.city_code = city_2.city_code AND city_2.city_name = 'DALLAS' 

Predicted DB result:

 [(103171,), (103172,), (103173,), (103174,), (103175,), (103176,), (103177,), (103178,), (103179,), (103180,)] 

Gold DB result:

 [(103171,), (103172,), (103173,), (103174,), (103175,), (103176,), (103177,), (103178,), (103179,), (103180,)] 

Correct!


In [None]:
bart_trial(example_4, gold_sql_4)

Predicted SQL:

 SELECT DISTINCT flight_1.flight_id FROM flight flight_1, airport_service airport_service_1, city city_1, airport_service airport_service_2, city city_2 WHERE flight_1.airline_code = 'UA' AND ( flight_1.from_airport = airport_service_1.airport_code AND airport_service_1.city_code = city_1.city_code AND city_1.city_name = 'DENVER' AND flight_1.to_airport = airport_service_2.airport_code AND airport_service_2.city_code = city_2.city_code AND city_2.city_name = 'BALTIMORE' ) 

Predicted DB result:

 [(101231,), (101233,), (305983,)] 

Gold DB result:

 [(101231,), (101233,), (305983,)] 

Correct!


In [None]:
bart_trial(example_5, gold_sql_5)

Predicted SQL:

 SELECT DISTINCT flight_1.flight_id FROM flight flight_1, airport_service airport_service_1, city city_1, airport_service airport_service_2, city city_2 WHERE flight_1.from_airport = airport_service_1.airport_code AND airport_service_1.city_code = city_1.city_code AND city_1.city_name = 'CLEVELAND' AND ( flight_1.to_airport = airport_service_2.airport_code AND airport_service_2.city_code = city_2.city_code AND city_2.city_name = 'MIAMI' AND flight_1.arrival_time < 1600 ) 

Predicted DB result:

 [(107698,), (301117,)] 

Gold DB result:

 [(107698,), (301117,)] 

Correct!


In [None]:
bart_trial(example_6, gold_sql_6b)

Predicted SQL:

 SELECT DISTINCT flight_1.flight_id FROM flight flight_1, airport_service airport_service_1, city city_1, airport_service airport_service_2, city city_2, days days_1, date_day date_day_1 WHERE flight_1.from_airport = airport_service_1.airport_code AND airport_service_1.city_code = city_1.city_code AND city_1.city_name = 'TAMPA' AND ( flight_1.to_airport = airport_service_2.airport_code AND airport_service_2.city_code = city_2.city_code AND city_2.city_name = 'CHARLOTTE' AND flight_1.flight_days = days_1.days_code AND days_1.day_name = date_day_1.day_name AND date_day_1.year = 1991 AND date_day_1.month_number = 8 AND date_day_1.day_number = 27 ) 

Predicted DB result:

 [(101860,), (101861,), (101862,), (101863,), (101864,), (101865,), (305231,)] 

Gold DB result:

 [(101860,), (101861,), (101862,), (101863,), (101864,), (101865,), (305231,)] 

Correct!


In [None]:
bart_trial(example_7, gold_sql_7b)

Predicted SQL:

 SELECT DISTINCT flight_1.flight_id FROM flight flight_1, airport_service airport_service_1, city city_1, airport_service airport_service_2, city city_2, days days_1, date_day date_day_1 WHERE flight_1.from_airport = airport_service_1.airport_code AND airport_service_1.city_code = city_1.city_code AND city_1.city_name = 'BOSTON' AND ( flight_1.to_airport = airport_service_2.airport_code AND airport_service_2.city_code = city_2.city_code AND city_2.city_name = 'ATLANTA' AND ( flight_1.flight_days = days_1.days_code AND days_1.day_name = date_day_1.day_name AND date_day_1.year = 1991 AND date_day_1.month_number = 5 AND date_day_1.day_number = 24 AND flight_1.departure_time < 700 ) ) 

Predicted DB result:

 [(100014,)] 

Gold DB result:

 [(100014,)] 

Correct!


In [None]:
bart_trial(example_8, gold_sql_8)

Predicted SQL:

 SELECT DISTINCT flight_1.flight_id FROM flight flight_1, airport_service airport_service_1, city city_1, airport_service airport_service_2, city city_2 WHERE flight_1.airline_code = 'AA' AND ( flight_1.from_airport = airport_service_1.airport_code AND airport_service_1.city_code = city_1.city_code AND city_1.city_name = 'DALLAS' AND flight_1.to_airport = airport_service_2.airport_code AND airport_service_2.city_code = city_2.city_code AND city_2.city_name = 'SAN FRANCISCO' ) 

Predicted DB result:

 [(108452,), (108454,), (108456,), (111083,), (111085,), (111086,), (111090,), (111091,), (111092,), (111094,)] 

Gold DB result:

 [(108452,), (108454,), (108456,), (111083,), (111085,), (111086,), (111090,), (111091,), (111092,), (111094,)] 

Correct!


### Running Full Evaluation

In [None]:
def seq2seq_predictor_bart(tokens):
  prediction = bart_model.predict(tokens, K=4, max_T=400)
  return prediction

In [None]:
precision, recall, f1 = evaluate(seq2seq_predictor_bart, test_bart_data, num_examples=0)
print(f"precision: {precision:3.2f}")
print(f"recall:    {recall:3.2f}")
print(f"F1:        {f1:3.2f}")

236it [21:05,  6.59s/it]Exception ignored in: <function _xla_gc_callback at 0x7fa65e8d6680>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/lib/__init__.py", line 101, in _xla_gc_callback
    def _xla_gc_callback(*args):
func_timeout.dafunc.FunctionTimedOut3722789218968015863: Function execute_sql (args=('SELECT DISTINCT flight_1.flight_id FROM flight flight_1, airport_service airport_service_1, city city_1 WHERE flight_1.arrival_time >= 800 AND flight_1.arrival_time <= 2100',)) (kwargs={}) timed out after 3.000000 seconds.

Exception ignored in: <function _xla_gc_callback at 0x7fa65e8d6680>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/lib/__init__.py", line 101, in _xla_gc_callback
    def _xla_gc_callback(*args):
func_timeout.dafunc.FunctionTimedOut3722789218968015863: Function execute_sql (args=('SELECT DISTINCT flight_1.flight_id FROM flight flight_1, airport_service airport_service_1, city c

precision: 0.51
recall:    0.51
F1:        0.51



