# Train Queryable Input Classifier
This notebook trains an LSTM model that classifies whether the user's input can be queryable, in that the model should extract keywords to look online via the citation fetcher class.

In [1]:
from datasets import Dataset, DatasetDict
from torchtext import data, datasets

import const
import datasets
import pandas as pd
import torchtext
import torch
import torch.nn as nn
import random
import spacy
import subprocess

In [2]:
TEXT = data.Field(tokenize="spacy", tokenizer_language="en_core_web_sm")

In [3]:
LABEL = data.LabelField(dtype= torch.long)

In [4]:
TRAINING_DS_PATH = const.DATASETS_FOLDER + "QI_training.csv"
TESTING_DS_PATH = const.DATASETS_FOLDER + "QI_testing.csv"

FEATURE_COL = "question"
LABEL_COL = "is_searchable"

In [5]:
# MODEL PARAMETERS
CORPUS_SIZE = 25000
LEARNING_RATE = 1e-4
BATCH_SIZE = 128
EPOCHS = 10

EMBEDDING_DIM = 256
HIDDEN_DIM = 512
OUTPUT_DIM = 2
DEVICE = "auto"

In [6]:
MODEL_SAVE_PATH = (const.MODELS_FOLDER + 
       "aletheianomous_ai-QI_class-v0.1.pt")

In [7]:
if DEVICE == "auto":
    if torch.cuda.is_available():
        selected_device = "cuda"
    else:
        selected_device = "cpu"
else:
    selected_device = DEVICE

In [8]:
training_df = pd.read_csv(TRAINING_DS_PATH)
testing_df = pd.read_csv(TESTING_DS_PATH)

training_df = training_df.sample(frac=1)

  training_df = pd.read_csv(TRAINING_DS_PATH)


In [9]:
training_ds = Dataset.from_pandas(training_df)
validation_ds = Dataset.from_pandas(testing_df)

In [10]:
torch.backends.cudnn.deterministic=True

In [11]:
positive_cl_count = len(training_df[training_df[LABEL_COL] == True])
negative_cl_count = len(training_df[training_df[LABEL_COL] == False])

In [12]:
print(positive_cl_count)
print(negative_cl_count)

165319
39407


In [13]:
training_len = len(training_df)
positive_weight = positive_cl_count / training_len
negative_weight = negative_cl_count / training_len
loss_weights = [positive_weight, negative_weight]

In [14]:
print(loss_weights)
loss_weights = torch.Tensor(loss_weights)

[0.807513457010834, 0.192486542989166]


In [15]:
testing_df.columns

Index(['Unnamed: 0.4', 'Unnamed: 0', 'source', 'topic', 'paragraph',
       'question', 'question_id', 'is_impossible', 'answers',
       'expanded_answers', 'is_searchable', 'Unnamed: 0.3', 'Unnamed: 0.2',
       'Unnamed: 0.1', 'data'],
      dtype='object')

In [16]:
training_df.columns

Index(['Unnamed: 0.4', 'Unnamed: 0', 'source', 'topic', 'paragraph',
       'question', 'question_id', 'is_impossible', 'answers',
       'expanded_answers', 'is_searchable', 'Unnamed: 0.3', 'Unnamed: 0.2',
       'Unnamed: 0.1', 'data'],
      dtype='object')

In [17]:
testing_df.head()

Unnamed: 0.5,Unnamed: 0.4,Unnamed: 0,source,topic,paragraph,question,question_id,is_impossible,answers,expanded_answers,is_searchable,Unnamed: 0.3,Unnamed: 0.2,Unnamed: 0.1,data
0,0,0.0,Squad 2.0,Normans,The Normans (Norman: Nourmands; French: Norman...,In what country is Normandy located?,56ddde6b9a695914005b9628,False,France,"and and Norway who, under their leader Rollo, ...",1.0,,,,
1,1,1.0,Squad 2.0,Normans,The Normans (Norman: Nourmands; French: Norman...,In what country is Normandy located?,56ddde6b9a695914005b9628,False,France,"and and Norway who, under their leader Rollo, ...",1.0,,,,
2,2,2.0,Squad 2.0,Normans,The Normans (Norman: Nourmands; French: Norman...,In what country is Normandy located?,56ddde6b9a695914005b9628,False,France,"and and Norway who, under their leader Rollo, ...",1.0,,,,
3,3,3.0,Squad 2.0,Normans,The Normans (Norman: Nourmands; French: Norman...,In what country is Normandy located?,56ddde6b9a695914005b9628,False,France,"and and Norway who, under their leader Rollo, ...",1.0,,,,
4,4,4.0,Squad 2.0,Normans,The Normans (Norman: Nourmands; French: Norman...,When were the Normans in Normandy?,56ddde6b9a695914005b9629,False,10th and 11th centuries,"and and Norway who, under their leader Rollo, ...",1.0,,,,


In [18]:
TEXT = data.Field(tokenize='spacy', tokenizer_language = "en_core_web_sm")
LABEL = data.LabelField(dtype = torch.float)

fields = {
    FEATURE_COL: (FEATURE_COL, TEXT),
    LABEL_COL: (LABEL_COL, LABEL)
}

training_ds, testing_ds= data.TabularDataset.splits(
    path = const.DATASETS_FOLDER,
    train = "QI_training.csv",
    test = "QI_testing.csv",
    format = 'csv',
    fields = fields
)

print(vars(training_ds[0]))

{'question': ['When', 'did', 'Beyonce', 'start', 'becoming', 'popular', '?'], 'is_searchable': '1.0'}


In [19]:
TEXT.build_vocab(training_ds, max_size=CORPUS_SIZE)
LABEL.build_vocab(training_ds)

In [20]:
training_dl, testing_dl = data.BucketIterator.splits(
    (training_ds, testing_ds),
    batch_size=BATCH_SIZE,
    device=selected_device
)

In [21]:
class QIClassifier(nn.Module):
    
    def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim):
        super().__init__()

        self.embedding = nn.Embedding(input_dim, embedding_dim)
        self.rnn = nn.LSTM(embedding_dim, hidden_dim, num_layers=2,
              batch_first=False)
        self.fc = nn.Linear(hidden_dim, output_dim)
        self.sigmoid = nn.Sigmoid()

    def forward(self, text):
        embedded_text = self.embedding(text)
        output, hidden = self.rnn(embedded_text)
        #print(output[-1, :, :].shape)
        logits = self.fc(output[-1, :, :])
        output = self.sigmoid(logits)
        return logits, output
        

In [22]:
def train(model, training_ds, epochs=10, class_weights=None, device="cpu", epoch_timestamp=1, lr=0.001):
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    epoch_modulus: int
    model.train()
    model = model.to(device)
    class_weights = class_weights.to(device)
    if device == "cuda":
        loss = nn.CrossEntropyLoss(weight=class_weights)
    else:
        loss = nn.CrossEntropyLoss(weight=class_weights)
    loss = nn.CrossEntropyLoss(weight=class_weights)
    for i in range(epochs):
        epoch_modulus = i % epoch_timestamp
        for batch_id, batch_data in enumerate(training_ds):
            logits, out = model(batch_data.question.to(device))
            logits = logits.to(device)
            optimizer.zero_grad()
            label = batch_data.is_searchable.to(device)
            cost = loss(logits, label.to(torch.long))
            cost.backward()
            optimizer.step()
            if epoch_timestamp == 1:
                (print("Epoch " + str(i + 1) + "/" + str(epochs) + 
                       " loss: " + str(cost.cpu().item()), 
                       end="                                \r"))
            elif epoch_timestamp > 1:
                if epoch_modulus == epoch_timestamp - 1:
                    end_line = "\n"
                else:
                    end_line = "                        \r"
                (print("Epoch " + str(i + 1) + "/" + str(epochs) 
                       + " loss: " + str(cost.cpu().item()), end=end_line))
            else:
                raise ValueError("Expected epoch_timestamp parameter to be a non-negative number but got " + str(epoch_timestamp))
        print()

In [23]:
qi_classifier = QIClassifier(len(TEXT.vocab), EMBEDDING_DIM, HIDDEN_DIM, OUTPUT_DIM)

In [24]:
train(qi_classifier, training_dl, EPOCHS, loss_weights, selected_device, lr=LEARNING_RATE)

Epoch 1/10 loss: 0.41167721152305603                                
Epoch 2/10 loss: 0.31453752517700195                                
Epoch 3/10 loss: 0.2886545658111572                                 
Epoch 4/10 loss: 0.25289446115493774                                
Epoch 5/10 loss: 0.23246945440769196                                
Epoch 6/10 loss: 0.2088184803724289                                 
Epoch 7/10 loss: 0.20366403460502625                                
Epoch 8/10 loss: 0.211162731051445                                  
Epoch 9/10 loss: 0.21772532165050507                                
Epoch 10/10 loss: 0.2606121599674225                                 


In [26]:
torch.save(qi_classifier, MODEL_SAVE_PATH)