# Train Queryable Input Classifier
This notebook trains an LSTM model that classifies whether the user's input is queryable, i.e., whether keywords should be extracted from it and looked up online via the citation fetcher class.

In [None]:
import os
import random
import subprocess
import sys

import pandas as pd
import spacy
import torch
import torch.nn as nn
from torchtext import data

In [None]:
# Dataset locations. os.path.join adds a path separator only when
# DATASETS_FOLDER does not already end in one, so the result is a valid path
# either way (plain string concatenation silently glued the names together).
# NOTE(review): `const` is not imported in the import cell — confirm it is
# provided (presumably a project constants module) before running this cell.
TRAINING_DS_PATH = os.path.join(const.DATASETS_FOLDER, "tqi_training.csv")
TESTING_DS_PATH = os.path.join(const.DATASETS_FOLDER, "tqi_testing.csv")

# CSV column names: the feature (raw text input) and the label column.
FEATURE_COL = "input"
LABEL_COL = "is_queryable"

In [None]:
# --- Training hyperparameters ---
CORPUS_SIZE = 25_000  # presumably a vocabulary/corpus cap — TODO confirm where it is consumed
LEARNING_RATE = 0.0001
BATCH_SIZE = 128
EPOCHS = 5

# --- Model size settings ---
EMBEDDING_DIM = 2 ** 8
HIDDEN_DIM = 2 ** 9
OUTPUT_DIM = 1  # presumably one logit for the binary is_queryable label

In [None]:
# Download the small English spaCy pipeline into the same interpreter that is
# running this notebook (sys.executable), so the kernel can load it afterwards.
# The command is passed as an argument list: a single string without
# shell=True is treated as the program name and fails on POSIX. The spaCy CLI
# subcommand is `download`, followed by the package name as its own argument.
# check=True surfaces a failed download instead of silently continuing.
print("Downloading Spacy Tokenizer...")
subprocess.run(
    [sys.executable, "-m", "spacy", "download", "en_core_web_sm"],
    check=True,
)

In [None]:
# Load the train/test splits from CSV. Only the training split is shuffled
# (sample(frac=1) returns every row in random order); a fixed random_state
# keeps that order reproducible across runs, matching the notebook's use of
# deterministic cuDNN.
SHUFFLE_SEED = 1234

training_df = pd.read_csv(TRAINING_DS_PATH)
testing_df = pd.read_csv(TESTING_DS_PATH)

training_df = training_df.sample(frac=1, random_state=SHUFFLE_SEED)

In [None]:
# Reproducibility: forcing deterministic cuDNN kernels is not enough on its own
# while the random number generators stay unseeded, so seed Python's `random`
# and torch as well.
SEED = 1234

random.seed(SEED)
torch.manual_seed(SEED)  # seeds the RNG on the CPU and on all CUDA devices
torch.backends.cudnn.deterministic = True
# cuDNN autotuning may pick non-deterministic algorithms; disable it.
torch.backends.cudnn.benchmark = False

In [None]:
class TQIModel(nn.Module):
    """LSTM classifier that scores whether a user input is queryable.

    Token ids are embedded, run through a stacked LSTM, and the final hidden
    state of the top LSTM layer is projected to ``output_dim`` logits.

    Args:
        input_dim: Vocabulary size (number of rows in the embedding table).
        embedding_dim: Size of each token embedding vector.
        hidden_dim: Size of the LSTM hidden state.
        output_dim: Number of outputs produced by the final linear layer.
        num_layers: Number of stacked LSTM layers (defaults to 2, as before).
    """

    def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim, num_layers=2):
        super().__init__()

        self.embedding = nn.Embedding(input_dim, embedding_dim)
        self.rnn = nn.LSTM(embedding_dim, hidden_dim, num_layers=num_layers)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, text):
        """Return flattened logits for a batch of token-id sequences.

        Args:
            text: Tensor of token ids shaped ``(seq_len, batch)``; the LSTM
                uses the default ``batch_first=False`` layout.

        Returns:
            A 1-D tensor of ``batch * output_dim`` logits, i.e. ``(batch,)``
            when ``output_dim == 1``.
        """
        embedded_text = self.embedding(text)  # (seq_len, batch, embedding_dim)
        # nn.LSTM returns (output, (h_n, c_n)); h_n is (num_layers, batch, hidden_dim).
        _, (hidden, _) = self.rnn(embedded_text)
        # Classify from the top layer's final hidden state only -> (batch, hidden_dim).
        output = self.fc(hidden[-1]).view(-1)
        return output
        