# Train Queryable Input Classifier
This notebook trains an LSTM model that classifies whether the user's input can be queryable, in that the model should extract keywords to look online via the citation fetcher class.

In [1]:
from datasets import Dataset, DatasetDict
from sklearn.utils.class_weight import compute_class_weight
from torchtext import data, datasets

import const
import datasets
import pandas as pd
import numpy as np
import torchtext
import torch
import torch.nn as nn
import random
import spacy
import subprocess

In [2]:
TEXT = data.Field(tokenize="spacy", tokenizer_language="en_core_web_sm")

In [3]:
LABEL = data.LabelField(dtype= torch.long)

In [4]:
TRAINING_DS_PATH = const.DATASETS_FOLDER + "QI_training.csv"
TESTING_DS_PATH = const.DATASETS_FOLDER + "QI_testing.csv"

FEATURE_COL = "question"
LABEL_COL = "is_searchable"

In [5]:
TRAIN_SPLIT = 0.7

# MODEL PARAMETERS
CORPUS_SIZE = 25000
LEARNING_RATE = 1e-3
BATCH_SIZE = 128
EPOCHS = 50

EMBEDDING_DIM = 256
HIDDEN_DIM = 512
OUTPUT_DIM = 2
DEVICE = "auto"

In [36]:
MODEL_NAME = "aletheianomous_ai-QI_class-v0.1.4" 
MODEL_SAVE_PATH = (const.MODELS_FOLDER + MODEL_NAME + "/"
       + "model.pt")
VOCAB_SAVE_PATH = (const.MODELS_FOLDER + MODEL_NAME + "/vocab.pt")

In [7]:
if DEVICE == "auto":
    if torch.cuda.is_available():
        selected_device = "cuda"
    else:
        selected_device = "cpu"
else:
    selected_device = DEVICE

In [8]:
selected_device

'cuda'

In [9]:
training_df = pd.read_csv(TRAINING_DS_PATH)
testing_df = pd.read_csv(TESTING_DS_PATH)

  training_df = pd.read_csv(TRAINING_DS_PATH)


In [10]:
len(training_df)

204726

In [11]:
split_ind = int(len(training_df) * TRAIN_SPLIT)
training_df = training_df.sample(frac=1)
training_df = training_df.reset_index()

validation_df = training_df[split_ind:]
validation_df = validation_df.reset_index()
training_df = training_df[0:split_ind]

In [12]:
len(training_df)

143308

In [13]:
len(validation_df)

61418

In [14]:
training_df.tail()

Unnamed: 0.5,index,Unnamed: 0.4,Unnamed: 0,source,topic,paragraph,question,question_id,is_impossible,answers,expanded_answers,is_searchable,Unnamed: 0.3,Unnamed: 0.2,Unnamed: 0.1,data
143303,118664,118664,118664.0,Squad 2.0,United_States_Air_Force,Specific concerns include a compounded inabili...,How did the USAF try to make these adjustments...,5730ed1ea5e9cc1400cdbaec,False,cutting the Air National Guard and Air Force R...,"eets and their associated manpower, but Congre...",1.0,,,,
143304,119739,119739,119739.0,Squad 2.0,Qing_dynasty,"However, the 18th century saw the European emp...",Which empires grew during the 18th century?,57313f7b497a881900248cd7,False,European states,opean trading posts expanded into territorial ...,1.0,,,,
143305,67588,67588,67588.0,Squad 2.0,Mali,"Mali (i/ˈmɑːli/; French: [maˈli]), officially ...",What is basal agriculture in deep-sea fishing?,5a27f540d1a287001a6d0a60,True,The country's economy,of Mali is 14.5 million. Its capital is Bamako...,1.0,,,,
143306,164169,164169,33850.0,,,,What are the main challenges facing the econom...,,,,,1.0,33850.0,33850.0,33850.0,['What are the main challenges facing the econ...
143307,185162,185162,,,,,10am on a thursday morning here. Just cooked o...,,,,,0.0,,,,


In [15]:
validation_df.tail()

Unnamed: 0.5,level_0,index,Unnamed: 0.4,Unnamed: 0,source,topic,paragraph,question,question_id,is_impossible,answers,expanded_answers,is_searchable,Unnamed: 0.3,Unnamed: 0.2,Unnamed: 0.1,data
61413,204721,132859,132859,2540.0,,,,What were the leading causes of the Trail of T...,,,,,1.0,2540.0,2540.0,2540.0,['What were the leading causes of the Trail of...
61414,204722,133391,133391,3072.0,,,,Can you suggest a list of top-rated restaurant...,,,,,1.0,3072.0,3072.0,3072.0,['Can you suggest a list of top-rated restaura...
61415,204723,173869,173869,,,,,together with a dragon,,,,,0.0,,,,
61416,204724,125177,125177,125177.0,Squad 2.0,Financial_crisis_of_2007%E2%80%9308,During a period of tough competition between m...,What years had the most intense competition be...,5732ac07328d981900602000,False,2004–2007,"e view of some analysts, the relatively conser...",1.0,,,,
61417,204725,132210,132210,1891.0,,,,Please describe the process for obtaining a dr...,,,,,1.0,1891.0,1891.0,1891.0,"[""Please describe the process for obtaining a ..."


In [16]:
validation_df.to_csv(const.DATASETS_FOLDER + "QI_validation.csv")

In [17]:
torch.backends.cudnn.deterministic=True

In [18]:
labels = training_df[LABEL_COL].unique()

In [19]:
weights = compute_class_weight(class_weight="balanced", classes=labels, y=training_df[LABEL_COL].to_numpy())

In [20]:
loss_weights = torch.Tensor(weights)
#loss_weights = None

In [21]:
testing_df.columns

Index(['Unnamed: 0.4', 'Unnamed: 0', 'source', 'topic', 'paragraph',
       'question', 'question_id', 'is_impossible', 'answers',
       'expanded_answers', 'is_searchable', 'Unnamed: 0.3', 'Unnamed: 0.2',
       'Unnamed: 0.1', 'data'],
      dtype='object')

In [22]:
training_df.columns

Index(['index', 'Unnamed: 0.4', 'Unnamed: 0', 'source', 'topic', 'paragraph',
       'question', 'question_id', 'is_impossible', 'answers',
       'expanded_answers', 'is_searchable', 'Unnamed: 0.3', 'Unnamed: 0.2',
       'Unnamed: 0.1', 'data'],
      dtype='object')

In [23]:
testing_df.head()

Unnamed: 0.5,Unnamed: 0.4,Unnamed: 0,source,topic,paragraph,question,question_id,is_impossible,answers,expanded_answers,is_searchable,Unnamed: 0.3,Unnamed: 0.2,Unnamed: 0.1,data
0,0,0.0,Squad 2.0,Normans,The Normans (Norman: Nourmands; French: Norman...,In what country is Normandy located?,56ddde6b9a695914005b9628,False,France,"and and Norway who, under their leader Rollo, ...",1.0,,,,
1,1,1.0,Squad 2.0,Normans,The Normans (Norman: Nourmands; French: Norman...,In what country is Normandy located?,56ddde6b9a695914005b9628,False,France,"and and Norway who, under their leader Rollo, ...",1.0,,,,
2,2,2.0,Squad 2.0,Normans,The Normans (Norman: Nourmands; French: Norman...,In what country is Normandy located?,56ddde6b9a695914005b9628,False,France,"and and Norway who, under their leader Rollo, ...",1.0,,,,
3,3,3.0,Squad 2.0,Normans,The Normans (Norman: Nourmands; French: Norman...,In what country is Normandy located?,56ddde6b9a695914005b9628,False,France,"and and Norway who, under their leader Rollo, ...",1.0,,,,
4,4,4.0,Squad 2.0,Normans,The Normans (Norman: Nourmands; French: Norman...,When were the Normans in Normandy?,56ddde6b9a695914005b9629,False,10th and 11th centuries,"and and Norway who, under their leader Rollo, ...",1.0,,,,


In [24]:
TEXT = data.Field(tokenize='spacy', tokenizer_language = "en_core_web_sm")
LABEL = data.LabelField(dtype = torch.float)

fields = {
    FEATURE_COL: (FEATURE_COL, TEXT),
    LABEL_COL: (LABEL_COL, LABEL)
}

In [25]:
training_ds, val_ds, testing_ds= data.TabularDataset.splits(
    path = const.DATASETS_FOLDER,
    train = "QI_training.csv",
    validation = "QI_validation.csv",
    test = "QI_testing.csv",
    format = 'csv',
    fields = fields
)

print(vars(training_ds[0]))

{'question': ['When', 'did', 'Beyonce', 'start', 'becoming', 'popular', '?'], 'is_searchable': '1.0'}


In [26]:
TEXT.build_vocab(training_ds, max_size=CORPUS_SIZE)
LABEL.build_vocab(training_ds)

In [28]:
training_dl, val_dl, testing_dl = data.BucketIterator.splits(
    (training_ds, val_ds, testing_ds),
    batch_size=BATCH_SIZE,
    device=selected_device,
    sort=False
)

In [27]:
class QIClassifier(nn.Module):
    
    def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim):
        super().__init__()

        self.embedding = nn.Embedding(input_dim, embedding_dim)
        self.rnn = nn.LSTM(embedding_dim, hidden_dim, num_layers=2,
              batch_first=False)
        self.fc = nn.Linear(hidden_dim, output_dim)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, text):
        embedded_text = self.embedding(text)
        output, hidden = self.rnn(embedded_text)
        logits = self.fc(output[-1, :, :])
        output = self.softmax(logits)
        return logits, output
        

In [None]:
def train(model, training_ds, validation_ds=None, epochs=10, class_weights=None, device="cpu", epoch_timestamp=1, lr=0.001):
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    epoch_modulus: int
    model.train()
    model = model.to(device)
    if class_weights is not None:
        class_weights = class_weights.to(device)
    if device == "cuda":
        loss = nn.CrossEntropyLoss(weight=class_weights).cuda()
    else:
        loss = nn.CrossEntropyLoss(weight=class_weights)

    ds_ls = {"training": training_ds, "validation": validation_ds}
    
    for i in range(epochs):            
        epoch_modulus = i % epoch_timestamp
        for phase in ds_ls.keys():
            ds = ds_ls[phase]
            
            if ds is None:
                print("Skipping validation")
                continue
            
            if phase == "training":
                model.train()
            elif phase == "validation":
                model.eval()
                
            for batch_id, batch_data in enumerate(ds):
                logits, out = model(batch_data.question.to(device))
                logits = logits.to(device)
                if phase == "training":
                    optimizer.zero_grad()
                label = batch_data.is_searchable.to(device)
                cost = loss(logits, label.to(torch.long))
                if phase == "training":
                    cost.backward()
                    optimizer.step()
                    cost = cost.cpu().item()
                elif phase == "validation":
                    cost = cost.cpu().item()
                    
                if epoch_timestamp == 1:
                    (print("Epoch " + str(i + 1) + "/" + str(epochs) + 
                           " " + phase + " loss: " + str(cost), 
                           end="                                \r"))
                elif epoch_timestamp > 1:
                    if epoch_modulus == epoch_timestamp - 1:
                        end_line = "\n"
                    else:
                        end_line = "                        \r"
                    print("Epoch " + str(i + 1) + "/" + str(epochs) + " " + phase +  " loss: " + str(cost), end=end_line)
                else:
                    raise ValueError("Expected epoch_timestamp parameter to be a non-negative number but got " + str(epoch_timestamp))
            print()

In [None]:
qi_classifier = QIClassifier(len(TEXT.vocab), EMBEDDING_DIM, HIDDEN_DIM, OUTPUT_DIM)

In [None]:
train(qi_classifier, training_dl, validation_ds = val_dl, epochs=EPOCHS, class_weights=loss_weights, device=selected_device, lr=LEARNING_RATE)

In [None]:
torch.save(qi_classifier, MODEL_SAVE_PATH)

In [None]:
tokenizer = spacy.load("en_core_web_sm")

In [None]:
def predict(model, input_sentence, tokenizer):
    model.eval()
    token_sent = [tok.text for tok in tokenizer.tokenizer(input_sentence)]
    tok_index = [TEXT.vocab.stoi[t] for t in token_sent]
    tensor = torch.LongTensor(tok_index).to(selected_device)
    tensor = tensor.unsqueeze(1)
    logit, probas = model(tensor)
    output = torch.argmax(probas, dim=1)
    return probas.cpu(), output.cpu()

In [None]:
probas, output = predict(qi_classifier, "Who is Beyonce's daughter?", tokenizer)

In [None]:
output

In [None]:
probas

# IMPORTANT
The model may have trained to classify non-searchable terms as 1
while searchable terms is classified as 0.

In [30]:
def calc_conf_matrix(outputs, labels):
    i: int = 0
    conf_matrix = pd.DataFrame(data={"true": [0,0], "false": [0,0]}, index=["true", "false"])
    conf_matr_row: int = 0
    conf_matr_col: int = 0
    if len(outputs) == len(labels):
        for i in range(len(outputs)):
            output = outputs[i]
            label = labels[i]
            if output == 0.0:
                conf_matr_row = 1
            elif output == 1.0:
                conf_matr_row = 0
            else:
                raise ValueError(output)

            if label == 0.0:
                conf_matr_col = 1
            elif label == 1.0:
                conf_matr_col = 0
            else:
                raise ValueError(output)

            conf_matrix.iloc[conf_matr_row, conf_matr_col] +=1
    else:
        raise AssertionError()
    return conf_matrix
        

In [31]:
def test_model(model, test_dl, device="cpu"):
    model = model.to(device)
    out_arr = np.array([])
    labels_arr = np.array([])
    ds_size: int = 0
    for batch_id, batch_data in enumerate(test_dl):
        label = batch_data.is_searchable.to(device)
        logits, proba = model(batch_data.question.to(device))
        ds_size += batch_data.is_searchable.size(0)
        out = torch.argmax(proba, dim=1)
        out_arr = np.append(out_arr, out.cpu().numpy())
        labels_arr = np.append(labels_arr, label.cpu().numpy())
    conf_matrix: pd.DataFrame = calc_conf_matrix(out_arr, labels_arr)
    precision: float = conf_matrix.iloc[0,0] / (conf_matrix.iloc[0,0] + conf_matrix.iloc[0,1])
    recall: float = conf_matrix.iloc[0,0] / (conf_matrix.iloc[0,0] + conf_matrix.iloc[1,0])
    accuracy: float = (conf_matrix.iloc[0,0] + conf_matrix.iloc[1,1]) / ds_size
    return conf_matrix, precision, recall, accuracy

In [32]:
conf_matrix, precision, recall, accuracy = test_model(qi_classifier, testing_dl, device=selected_device)

In [33]:
conf_matrix

Unnamed: 0,true,false
True,16521,548
False,369,40684


In [34]:
print("Precision: ", (precision*100), "%")
print("Recall: ", (recall*100), "%")
print("Accuracy: ", (accuracy*100), "%")

Precision:  96.78950143535063 %
Recall:  97.81527531083482 %
Accuracy:  98.42228416090293 %


In [37]:
torch.save(TEXT, VOCAB_SAVE_PATH)