In [1]:
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from collections import defaultdict
from copy import deepcopy

Mounted at /content/drive


In [2]:
def get_tags(text, variable_map):
    tokens = text.split()
    tags = []
    for token in tokens:
        matched = False
        for var in variable_map:
            if var in token:
                tags.append(var)
                matched = True
                break
        if not matched:
            tags.append("O")
    return tags

def preprocess(data):
    examples = []
    template_map = {}
    template_id = 0

    for entry in data:
        query_split = entry["query-split"]
        sql_list = entry["sql"]
        shortest_template = min(sql_list, key=len).replace("\n", " ")

        if shortest_template not in template_map:
            template_map[shortest_template] = template_id
            template_id += 1

        for sent in entry["sentences"]:
            filled_question = sent["text"]
            for k, v in sent["variables"].items():
                filled_question = filled_question.replace(k, v)

            tag_seq = get_tags(sent["text"], sent["variables"])

            examples.append({
                "question": filled_question,
                "tokens": sent["text"].split(),
                "tags": tag_seq,
                "template_id": template_map[shortest_template],
                "template_text": shortest_template,
                "correct_sqls": sql_list,
                "variables": sent["variables"],
                "question_split": sent["question-split"]
            })
    return examples, template_map

In [3]:
class ATISDataset(Dataset):
    def __init__(self, data, word_encoder, tag_encoder):
        self.data = data
        self.word_encoder = word_encoder
        self.tag_encoder = tag_encoder

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        word_ids = [self.word_encoder.get(w.lower(), 0) for w in item["tokens"]]
        tag_ids = [self.tag_encoder.get(t, 0) for t in item["tags"]]
        return {
            "tokens": torch.tensor(word_ids),
            "tags": torch.tensor(tag_ids),
            "template": item["template_id"]
        }

def collate_fn(batch):
    tokens = nn.utils.rnn.pad_sequence([b["tokens"] for b in batch], batch_first=True)
    tags = nn.utils.rnn.pad_sequence([b["tags"] for b in batch], batch_first=True)
    template = torch.tensor([b["template"] for b in batch])
    return {"tokens": tokens, "tags": tags, "template": template}

In [4]:
class LSTMTaggerwithAttention(nn.Module):
    def __init__(self, vocab_size, tag_size):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, 128, padding_idx=0)
        self.lstm = nn.LSTM(128, 128, batch_first=True, num_layers=1, dropout=0, bidirectional=False)
        self.dropout = nn.Dropout(0)
        self.tagger = nn.Linear(128, tag_size)

    def forward(self, x):
        emb = self.emb(x)
        out, _ = self.lstm(emb)
        return self.tagger(self.dropout(out))

class LSTMClassifierwithAttention(nn.Module):
    def __init__(self, vocab_size, num_templates):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, 128, padding_idx=0)
        self.lstm = nn.LSTM(128, 128, batch_first=True, num_layers=1, dropout=0, bidirectional=False)
        self.dropout = nn.Dropout(0)
        self.attn = nn.Linear(128, 1)
        self.classifier = nn.Linear(128, num_templates)

    def forward(self, x):
        emb = self.emb(x)
        out, _ = self.lstm(emb)
        out = self.dropout(out)
        attn_weights = torch.softmax(self.attn(out), dim=1)
        return self.classifier(torch.sum(attn_weights * out, dim=1))

In [5]:
def train(models, dataloader, num_epochs):
    tagger, classifier = models
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(list(tagger.parameters()) + list(classifier.parameters()), lr=1e-3)

    for epoch in range(num_epochs):
        tagger.train()
        classifier.train()
        total_loss = 0

        for batch in dataloader:
            tokens, tags, templates = batch["tokens"], batch["tags"], batch["template"]
            tag_loss = criterion(tagger(tokens).view(-1, tagger(tokens).size(-1)), tags.view(-1))
            cls_loss = criterion(classifier(tokens), templates)

            loss = tag_loss + cls_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch+1}: Loss = {total_loss/len(dataloader):.4f}")

In [6]:
def generate_sql(tokens, tag_probs, template_id, tag_vocab, template_map):
    tag_rev = {v: k for k, v in tag_vocab.items()}
    template_map_rev = {v: k for k, v in template_map.items()}
    tags = [tag_rev[i] for i in tag_probs.argmax(axis=1)]

    var_map = {}
    for i, tag in enumerate(tags):
        if tag != "O" and tag not in var_map:
            var_map[tag] = tokens[i]

    template = template_map_rev[template_id]
    for var, val in var_map.items():
        template = template.replace(f'"{var}"', f'"{val}"')
    return template

def normalize_sql(sql):
    return " ".join(sql.strip().split())

def is_sql_correct(predicted_sql, gold_sql_list):
    return normalize_sql(predicted_sql) in [normalize_sql(sql) for sql in gold_sql_list]

In [7]:
def evaluate(name, tagger, classifier, data, vocab, tag_vocab, template_map, quiet):
    correct = 0
    total = len(data)
    tagger.eval()
    classifier.eval()

    with torch.no_grad():
        for i, ex in enumerate(data, 1):
            token_ids = torch.tensor([vocab.get(w.lower(), 0) for w in ex["tokens"]]).unsqueeze(0)
            tag_probs = tagger(token_ids).squeeze(0).softmax(dim=-1).cpu().numpy()
            template_pred = classifier(token_ids).argmax(dim=1).item()

            pred_sql = generate_sql(ex["tokens"], tag_probs, template_pred, tag_vocab, template_map)
            if not quiet:
                print(f"Question {i}: {ex['question']}")
                print("Predicted SQL:", normalize_sql(pred_sql))
                print("Correct SQLs:")
                for sql in ex["correct_sqls"]:
                    print("       ", normalize_sql(sql))
            if is_sql_correct(pred_sql, ex["correct_sqls"]):
                correct += 1

    accuracy = correct / total

    if not quiet:
        print(f"\nAccuracy on \"{name}\": {correct}/{total} = {accuracy:.5f}")

    return accuracy

In [8]:
with open("atis.json") as dataset:
    raw_data = json.load(dataset)

examples, template_map = preprocess(raw_data)
vocab = {"<PAD>": 0}
tag_vocab = {"O": 0}
for ex in examples:
    for tok in ex["tokens"]:
        vocab.setdefault(tok.lower(), len(vocab))
    for tag in ex["tags"]:
        tag_vocab.setdefault(tag, len(tag_vocab))

train_data = [ex for ex in examples if ex.get("question_split") == "train"]
dev_data = [ex for ex in examples if ex.get("question_split") == "dev"]
test_data = [ex for ex in examples if ex.get("question_split") == "test"]

train_loader = DataLoader(ATISDataset(train_data, vocab, tag_vocab), batch_size=16, shuffle=True, collate_fn=collate_fn)

tagger = LSTMTaggerwithAttention(len(vocab), len(tag_vocab))
classifier = LSTMClassifierwithAttention(len(vocab), len(template_map))

print("Training starts...")
train((tagger, classifier), train_loader, num_epochs=20)

print("Evaluation on \"train\":")
evaluate("train", tagger, classifier, train_data, vocab, tag_vocab, template_map, quiet=False)

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Question 3181: show me all flights from SAN FRANCISCO to PITTSBURGH which arrive in PITTSBURGH before 900 o'clock am tomorrow
Predicted SQL: SELECT DISTINCT FLIGHTalias0.FLIGHT_ID FROM AIRPORT_SERVICE AS AIRPORT_SERVICEalias0 , AIRPORT_SERVICE AS AIRPORT_SERVICEalias1 , CITY AS CITYalias0 , CITY AS CITYalias1 , DATE_DAY AS DATE_DAYalias0 , DAYS AS DAYSalias0 , FLIGHT AS FLIGHTalias0 WHERE ( ( ( ( ( FLIGHTalias0.ARRIVAL_TIME < FLIGHTalias0.DEPARTURE_TIME ) AND DATE_DAYalias0.DAY_NUMBER = day_number0 AND DATE_DAYalias0.MONTH_NUMBER = month_number0 AND DATE_DAYalias0.YEAR = year0 AND DAYSalias0.DAY_NAME = DATE_DAYalias0.DAY_NAME AND FLIGHTalias0.FLIGHT_DAYS = DAYSalias0.DAYS_CODE ) OR ( DATE_DAYalias0.DAY_NUMBER = day_number0 AND DATE_DAYalias0.MONTH_NUMBER = month_number0 AND DATE_DAYalias0.YEAR = year0 AND DAYSalias0.DAY_NAME = DATE_DAYalias0.DAY_NAME AND FLIGHTalias0.FLIGHT_DAYS = DAYSalias0.DAYS_CODE AND NOT ( FLIGHTalia

0.9592822636300897

In [9]:
print("Evaluation on \"dev\" with early stopping starts...")

def train_with_early_stopping(models, train_loader, dev_data, vocab, tag_vocab, template_map, max_epochs=30, patience=5):
    tagger, classifier = models
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(list(tagger.parameters()) + list(classifier.parameters()), lr=1e-5)

    best_dev_accuracy = 0
    best_tagger = deepcopy(tagger)
    best_classifier = deepcopy(classifier)
    patience_counter = 0

    for epoch in range(max_epochs):
        tagger.train()
        classifier.train()
        total_loss = 0

        for batch in train_loader:
            tokens, tags, templates = batch["tokens"], batch["tags"], batch["template"]
            tag_loss = criterion(tagger(tokens).view(-1, tagger(tokens).size(-1)), tags.view(-1))
            cls_loss = criterion(classifier(tokens), templates)

            loss = tag_loss + cls_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        dev_accuracy = evaluate("dev", tagger, classifier, dev_data, vocab, tag_vocab, template_map, quiet=True)
        print(f"Epoch {epoch+1}: Loss = {total_loss/len(train_loader):.4f} Accuracy = {dev_accuracy:.5f}")

        if dev_accuracy > best_dev_accuracy:
            best_dev_accuracy = dev_accuracy
            best_tagger = deepcopy(tagger)
            best_classifier = deepcopy(classifier)
        else:
            patience_counter += 1
            print(f"Patience Counter: {patience_counter}/{patience}")

        if patience_counter == patience:
            print("Early stopping triggered. Training finishes.")
            break

    print(f"\nBest Accuracy on \"dev\": {round(best_dev_accuracy * len(dev_data))}/{len(dev_data)} = {best_dev_accuracy:.5f}")
    return best_tagger, best_classifier

best_tagger, best_classifier = train_with_early_stopping(
        (tagger, classifier), train_loader, dev_data, vocab, tag_vocab, template_map
    )

Evaluation on "dev" with early stopping starts...
Epoch 1: Loss = 0.0335 Accuracy = 0.70988
Epoch 2: Loss = 0.0310 Accuracy = 0.70988
Patience Counter: 1/5
Epoch 3: Loss = 0.0296 Accuracy = 0.70782
Patience Counter: 2/5
Epoch 4: Loss = 0.0290 Accuracy = 0.70782
Patience Counter: 3/5
Epoch 5: Loss = 0.0279 Accuracy = 0.70782
Patience Counter: 4/5
Epoch 6: Loss = 0.0268 Accuracy = 0.70782
Patience Counter: 5/5
Early stopping triggered. Training finishes.
Best Accuracy on "dev": 0.70988


In [10]:
print("Evaluation on \"test\":")
evaluate("test", best_tagger, best_classifier, test_data, vocab, tag_vocab, template_map, quiet=False)

Evaluation on "test":
Question 1: i need a flight from DENVER to SALT LAKE CITY on monday
Predicted SQL: SELECT DISTINCT FLIGHTalias0.FLIGHT_ID FROM AIRPORT_SERVICE AS AIRPORT_SERVICEalias0 , AIRPORT_SERVICE AS AIRPORT_SERVICEalias1 , CITY AS CITYalias0 , CITY AS CITYalias1 , DATE_DAY AS DATE_DAYalias0 , DAYS AS DAYSalias0 , FLIGHT AS FLIGHTalias0 WHERE ( CITYalias1.CITY_CODE = AIRPORT_SERVICEalias1.CITY_CODE AND CITYalias1.CITY_NAME = "city_name0" AND DATE_DAYalias0.DAY_NUMBER = day_number0 AND DATE_DAYalias0.MONTH_NUMBER = month_number0 AND DATE_DAYalias0.YEAR = year0 AND DAYSalias0.DAY_NAME = DATE_DAYalias0.DAY_NAME AND FLIGHTalias0.FLIGHT_DAYS = DAYSalias0.DAYS_CODE AND FLIGHTalias0.TO_AIRPORT = AIRPORT_SERVICEalias1.AIRPORT_CODE ) AND CITYalias0.CITY_CODE = AIRPORT_SERVICEalias0.CITY_CODE AND CITYalias0.CITY_NAME = "city_name1" AND FLIGHTalias0.FROM_AIRPORT = AIRPORT_SERVICEalias0.AIRPORT_CODE ;
Correct SQLs:
        SELECT DISTINCT FLIGHTalias0.FLIGHT_ID FROM AIRPORT_SERVICE AS A

0.5391498881431768