# Train Queryable Input Classifier
This notebook trains an LSTM model that classifies whether the user's input can be queryable, in that the model should extract keywords to look online via the citation fetcher class.

In [1]:
from datasets import Dataset, DatasetDict
from sklearn.utils.class_weight import compute_class_weight
from torchtext import data, datasets

import const
import datasets
import pandas as pd
import torchtext
import torch
import torch.nn as nn
import random
import spacy
import subprocess

In [2]:
TEXT = data.Field(tokenize="spacy", tokenizer_language="en_core_web_sm")

In [3]:
LABEL = data.LabelField(dtype= torch.long)

In [4]:
TRAINING_DS_PATH = const.DATASETS_FOLDER + "QI_training.csv"
TESTING_DS_PATH = const.DATASETS_FOLDER + "QI_testing.csv"

FEATURE_COL = "question"
LABEL_COL = "is_searchable"

In [5]:
TRAIN_SPLIT = 0.7

# MODEL PARAMETERS
CORPUS_SIZE = 25000
LEARNING_RATE = 1e-3
BATCH_SIZE = 128
EPOCHS = 50

EMBEDDING_DIM = 256
HIDDEN_DIM = 512
OUTPUT_DIM = 2
DEVICE = "auto"

In [6]:
MODEL_SAVE_PATH = (const.MODELS_FOLDER + 
       "aletheianomous_ai-QI_class-v0.1.4.pt")

In [7]:
if DEVICE == "auto":
    if torch.cuda.is_available():
        selected_device = "cuda"
    else:
        selected_device = "cpu"
else:
    selected_device = DEVICE

In [8]:
selected_device

'cuda'

In [9]:
training_df = pd.read_csv(TRAINING_DS_PATH)
testing_df = pd.read_csv(TESTING_DS_PATH)

  training_df = pd.read_csv(TRAINING_DS_PATH)


In [10]:
len(training_df)

204726

In [11]:
split_ind = int(len(training_df) * TRAIN_SPLIT)
training_df = training_df.sample(frac=1)
training_df = training_df.reset_index()

validation_df = training_df[split_ind:]
validation_df = validation_df.reset_index()
training_df = training_df[0:split_ind]

In [12]:
len(training_df)

143308

In [13]:
len(validation_df)

61418

In [14]:
training_df.tail()

Unnamed: 0.5,index,Unnamed: 0.4,Unnamed: 0,source,topic,paragraph,question,question_id,is_impossible,answers,expanded_answers,is_searchable,Unnamed: 0.3,Unnamed: 0.2,Unnamed: 0.1,data
143303,17724,17724,17724.0,Squad 2.0,Universal_Studios,"In 1928, Laemmle, Sr. made his son, Carl, Jr. ...","At what age did Carl Laemmle, Jr. become presi...",56e14c5fcd28a01900c67772,False,21,"r. being known around the studios as ""Uncle Ca...",1.0,,,,
143304,47709,47709,47709.0,Squad 2.0,Ashkenazi_Jews,The history of Jews in Greece goes back to at ...,The Greek historian Herodotus listed the Jews ...,571a87714faf5e1900b8aa17,False,the invading Persians,and listed them among the levied naval forces...,1.0,,,,
143305,10737,10737,10737.0,Squad 2.0,Human_Development_Index,Some countries were not included for various r...,What year were all countries included?,5ad0cd5f645df0001a2d03de,True,2014,and Tuvalu.,1.0,,,,
143306,111296,111296,111296.0,Squad 2.0,Premier_League,Despite significant European success during th...,Had the Football League First Division ever be...,572fbb02b2c2fd14005683bc,False,"The Football League First Division, which had ...",for five years following the Heysel Stadium di...,1.0,,,,
143307,119274,119274,119274.0,Squad 2.0,Great_power,"When World War II started in 1939, it divided ...",Who was part of the inclusion powers?,5a14aafca54d42001852930b,True,"Germany, Italy and Japan","and Japan.[nb 1] During World War II, the Uni...",1.0,,,,


In [15]:
validation_df.tail()

Unnamed: 0.5,level_0,index,Unnamed: 0.4,Unnamed: 0,source,topic,paragraph,question,question_id,is_impossible,answers,expanded_answers,is_searchable,Unnamed: 0.3,Unnamed: 0.2,Unnamed: 0.1,data
61413,204721,31311,31311,31311.0,Squad 2.0,Houston,Houston was founded in 1836 on land near the b...,For what battle was Sam Houston known as comma...,570a898d4103511400d59838,False,Battle of San Jacinto,he Battle of San Jacinto 25 miles (40 km) east...,1.0,,,,
61414,204722,125193,125193,125193.0,Squad 2.0,Financial_crisis_of_2007%E2%80%9308,"In the early and mid-2000s, the Bush administr...",Who rebuked OFHEO in 2003 for their attempt at...,5732b191d6dcfa19001e8a80,False,House Financial Services Committee,urging of the administration to assess safety...,1.0,,,,
61415,204723,103353,103353,103353.0,Squad 2.0,Neptune,The dipole component of the magnetic field at ...,Besides the geometrical constraints of Neptune...,572ea25103f9891900756870,False,the planet's centre,y that includes relatively large contributions...,1.0,,,,
61416,204724,153300,153300,22981.0,,,,Can you provide any safety tips for tourists v...,,,,,1.0,22981.0,22981.0,22981.0,['Can you provide any safety tips for tourists...
61417,204725,160203,160203,29884.0,,,,Are there any notable social or political move...,,,,,1.0,29884.0,29884.0,29884.0,['Are there any notable social or political mo...


In [16]:
validation_df.to_csv(const.DATASETS_FOLDER + "QI_validation.csv")

In [17]:
torch.backends.cudnn.deterministic=True

In [18]:
labels = training_df[LABEL_COL].unique()

In [19]:
weights = compute_class_weight(class_weight="balanced", classes=labels, y=training_df[LABEL_COL].to_numpy())

In [20]:
loss_weights = torch.Tensor(weights)
#loss_weights = None

In [21]:
testing_df.columns

Index(['Unnamed: 0.4', 'Unnamed: 0', 'source', 'topic', 'paragraph',
       'question', 'question_id', 'is_impossible', 'answers',
       'expanded_answers', 'is_searchable', 'Unnamed: 0.3', 'Unnamed: 0.2',
       'Unnamed: 0.1', 'data'],
      dtype='object')

In [22]:
training_df.columns

Index(['index', 'Unnamed: 0.4', 'Unnamed: 0', 'source', 'topic', 'paragraph',
       'question', 'question_id', 'is_impossible', 'answers',
       'expanded_answers', 'is_searchable', 'Unnamed: 0.3', 'Unnamed: 0.2',
       'Unnamed: 0.1', 'data'],
      dtype='object')

In [23]:
testing_df.head()

Unnamed: 0.5,Unnamed: 0.4,Unnamed: 0,source,topic,paragraph,question,question_id,is_impossible,answers,expanded_answers,is_searchable,Unnamed: 0.3,Unnamed: 0.2,Unnamed: 0.1,data
0,0,0.0,Squad 2.0,Normans,The Normans (Norman: Nourmands; French: Norman...,In what country is Normandy located?,56ddde6b9a695914005b9628,False,France,"and and Norway who, under their leader Rollo, ...",1.0,,,,
1,1,1.0,Squad 2.0,Normans,The Normans (Norman: Nourmands; French: Norman...,In what country is Normandy located?,56ddde6b9a695914005b9628,False,France,"and and Norway who, under their leader Rollo, ...",1.0,,,,
2,2,2.0,Squad 2.0,Normans,The Normans (Norman: Nourmands; French: Norman...,In what country is Normandy located?,56ddde6b9a695914005b9628,False,France,"and and Norway who, under their leader Rollo, ...",1.0,,,,
3,3,3.0,Squad 2.0,Normans,The Normans (Norman: Nourmands; French: Norman...,In what country is Normandy located?,56ddde6b9a695914005b9628,False,France,"and and Norway who, under their leader Rollo, ...",1.0,,,,
4,4,4.0,Squad 2.0,Normans,The Normans (Norman: Nourmands; French: Norman...,When were the Normans in Normandy?,56ddde6b9a695914005b9629,False,10th and 11th centuries,"and and Norway who, under their leader Rollo, ...",1.0,,,,


In [24]:
TEXT = data.Field(tokenize='spacy', tokenizer_language = "en_core_web_sm")
LABEL = data.LabelField(dtype = torch.float)

fields = {
    FEATURE_COL: (FEATURE_COL, TEXT),
    LABEL_COL: (LABEL_COL, LABEL)
}

In [25]:
training_ds, val_ds, testing_ds= data.TabularDataset.splits(
    path = const.DATASETS_FOLDER,
    train = "QI_training.csv",
    validation = "QI_validation.csv",
    test = "QI_testing.csv",
    format = 'csv',
    fields = fields
)

print(vars(training_ds[0]))

{'question': ['When', 'did', 'Beyonce', 'start', 'becoming', 'popular', '?'], 'is_searchable': '1.0'}


In [26]:
TEXT.build_vocab(training_ds, max_size=CORPUS_SIZE)
LABEL.build_vocab(training_ds)

In [27]:
training_dl, val_dl, testing_dl = data.BucketIterator.splits(
    (training_ds, val_ds, testing_ds),
    batch_size=BATCH_SIZE,
    device=selected_device,
    sort=False
)

In [28]:
class QIClassifier(nn.Module):
    
    def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim):
        super().__init__()

        self.embedding = nn.Embedding(input_dim, embedding_dim)
        self.rnn = nn.LSTM(embedding_dim, hidden_dim, num_layers=2,
              batch_first=False)
        self.fc = nn.Linear(hidden_dim, output_dim)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, text):
        embedded_text = self.embedding(text)
        output, hidden = self.rnn(embedded_text)
        logits = self.fc(output[-1, :, :])
        output = self.softmax(logits)
        return logits, output
        

In [29]:
def train(model, training_ds, validation_ds=None, epochs=10, class_weights=None, device="cpu", epoch_timestamp=1, lr=0.001):
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    epoch_modulus: int
    model.train()
    model = model.to(device)
    if class_weights is not None:
        class_weights = class_weights.to(device)
    if device == "cuda":
        loss = nn.CrossEntropyLoss(weight=class_weights).cuda()
    else:
        loss = nn.CrossEntropyLoss(weight=class_weights)

    ds_ls = {"training": training_ds, "validation": validation_ds}
    
    for i in range(epochs):            
        epoch_modulus = i % epoch_timestamp
        for phase in ds_ls.keys():
            ds = ds_ls[phase]
            
            if ds is None:
                print("Skipping validation")
                continue
            
            if phase == "training":
                model.train()
            elif phase == "validation":
                model.eval()
                
            for batch_id, batch_data in enumerate(ds):
                logits, out = model(batch_data.question.to(device))
                logits = logits.to(device)
                if phase == "training":
                    optimizer.zero_grad()
                label = batch_data.is_searchable.to(device)
                cost = loss(logits, label.to(torch.long))
                if phase == "training":
                    cost.backward()
                    optimizer.step()
                    cost = cost.cpu().item()
                elif phase == "validation":
                    cost = cost.cpu().item()
                    
                if epoch_timestamp == 1:
                    (print("Epoch " + str(i + 1) + "/" + str(epochs) + 
                           " " + phase + " loss: " + str(cost), 
                           end="                                \r"))
                elif epoch_timestamp > 1:
                    if epoch_modulus == epoch_timestamp - 1:
                        end_line = "\n"
                    else:
                        end_line = "                        \r"
                    print("Epoch " + str(i + 1) + "/" + str(epochs) + " " + phase +  " loss: " + str(cost), end=end_line)
                else:
                    raise ValueError("Expected epoch_timestamp parameter to be a non-negative number but got " + str(epoch_timestamp))
            print()

In [30]:
qi_classifier = QIClassifier(len(TEXT.vocab), EMBEDDING_DIM, HIDDEN_DIM, OUTPUT_DIM)

In [31]:
train(qi_classifier, training_dl, validation_ds = val_dl, epochs=EPOCHS, class_weights=loss_weights, device=selected_device, lr=LEARNING_RATE)

Epoch 1/50 training loss: 0.6934905648231506                                
Epoch 1/50 validation loss: 0.6918582320213318                                
Epoch 2/50 training loss: 0.6945582628250122                                
Epoch 2/50 validation loss: 0.6917373538017273                                
Epoch 3/50 training loss: 0.6911581754684448                                
Epoch 3/50 validation loss: 0.6905151009559631                                
Epoch 4/50 training loss: 0.6950238347053528                                
Epoch 4/50 validation loss: 0.6906713247299194                                
Epoch 5/50 training loss: 0.6928452849388123                                
Epoch 5/50 validation loss: 0.691433310508728                                 
Epoch 6/50 training loss: 0.6921548247337341                                
Epoch 6/50 validation loss: 0.6902626752853394                                
Epoch 7/50 training loss: 0.6920985579490662                    

In [32]:
torch.save(qi_classifier, MODEL_SAVE_PATH)

In [33]:
tokenizer = spacy.load("en_core_web_sm")

In [41]:
def predict(model, input_sentence, tokenizer):
    model.eval()
    token_sent = [tok.text for tok in tokenizer.tokenizer(input_sentence)]
    tok_index = [TEXT.vocab.stoi[t] for t in token_sent]
    tensor = torch.LongTensor(tok_index).to(selected_device)
    tensor = tensor.unsqueeze(1)
    logit, probas = model(tensor)
    output = torch.argmax(probas, dim=1)
    return probas.cpu(), output.cpu()

In [75]:
probas, output = predict(qi_classifier, "Who is Beyonce's daughter?", tokenizer)

In [76]:
output

tensor([0])

# IMPORTANT
The model may have trained to classify non-searchable terms as 1
while searchable terms is classified as 0.

In [62]:
probas

tensor([[0.2463, 0.7537]], grad_fn=<ToCopyBackward0>)