In [1]:
import pandas as pd

df = pd.read_csv('/content/100_Unique_QA_Dataset.csv')

df.head()

Unnamed: 0,question,answer
0,What is the capital of France?,Paris
1,What is the capital of Germany?,Berlin
2,Who wrote 'To Kill a Mockingbird'?,Harper-Lee
3,What is the largest planet in our solar system?,Jupiter
4,What is the boiling point of water in Celsius?,100


In [10]:
# tokenize

def tokenize(text):
  text = text.lower()
  text = text.replace('?', '')
  text = text.replace("'", '')
  return text.split()

In [11]:
tokenize(df['question'][0])

['what', 'is', 'the', 'capital', 'of', 'france']

In [16]:
# Vocab

vocab = {'<UNK>':0}

def build_vocab(row):

  tokenized_question = tokenize(row['question'])
  tokenized_answer = tokenize(row['answer'])

  merged_tokens = tokenized_question + tokenized_answer

  for token in merged_tokens:
    if token not in vocab:
      vocab[token] = len(vocab)

In [17]:
df.apply(build_vocab, axis=1)

Unnamed: 0,0
0,
1,
2,
3,
4,
...,...
85,
86,
87,
88,


In [19]:
len(vocab)

324

In [20]:
# text to index

def text2index(text,vocab):

  indexed_text = []

  for token in tokenize(text):
    if token in vocab:
      indexed_text.append(vocab[token])
    else:
      indexed_text.append(vocab['<UNK>'])

  return indexed_text

In [23]:
text2index(df['question'][2], vocab)

[10, 11, 12, 13, 14, 15]

In [24]:
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn

In [27]:
class dataset(Dataset):

  def __init__(self, df, vocab):
    self.df = df
    self.vocab = vocab

  def __len__(self):
    return len(self.df)

  def __getitem__(self, idx):

    question = self.df['question'][idx]
    answer = self.df['answer'][idx]

    indexed_question = text2index(question, self.vocab)
    indexed_answer = text2index(answer, self.vocab)

    return torch.tensor(indexed_question), torch.tensor(indexed_answer)

In [28]:
data = dataset(df, vocab)

In [29]:
data[0]

(tensor([1, 2, 3, 4, 5, 6]), tensor([7]))

In [30]:
dataloader = DataLoader(data,batch_size=1, shuffle = True)

In [45]:
class SimpleRNN(nn.Module):

  def __init__(self, vocab_size, embedding_size, hidden_size):
    super().__init__()

    self.embedding = nn.Embedding(vocab_size, embedding_size)
    self.rnn = nn.RNN(embedding_size, hidden_size, batch_first=True) # RNN gives two output i.e. one array that contains all transition state and other containing the final output
    self.fc = nn.Linear(hidden_size, vocab_size)

  def forward(self, x):
    x = self.embedding(x)
    h,x = self.rnn(x)
    x = self.fc(x.squeeze(0))
    return x

In [46]:
vocab_size = len(vocab)
embedding_size = 50
hidden_size = 64

lr = 0.001
epochs = 20

In [47]:
model = SimpleRNN(vocab_size, embedding_size, hidden_size)

lossfn = nn.CrossEntropyLoss()
opt = torch.optim.Adam(model.parameters(), lr=lr)

In [48]:
for epoch in range(epochs):

  total_loss = 0

  for question, answer in dataloader:

    opt.zero_grad()

    pred = model(question)
    loss = lossfn(pred,answer[0])

    loss.backward()

    opt.step()

    total_loss += loss.item()

  print(f'Epoch: {epoch+1}, Loss: {total_loss:3f}')

Epoch: 1, Loss: 528.510429
Epoch: 2, Loss: 455.599044
Epoch: 3, Loss: 378.057846
Epoch: 4, Loss: 319.420903
Epoch: 5, Loss: 270.325090
Epoch: 6, Loss: 224.151557
Epoch: 7, Loss: 180.780272
Epoch: 8, Loss: 142.463642
Epoch: 9, Loss: 110.490960
Epoch: 10, Loss: 85.373004
Epoch: 11, Loss: 66.007126
Epoch: 12, Loss: 51.762779
Epoch: 13, Loss: 41.465474
Epoch: 14, Loss: 33.616107
Epoch: 15, Loss: 27.635295
Epoch: 16, Loss: 23.077622
Epoch: 17, Loss: 19.481783
Epoch: 18, Loss: 16.625527
Epoch: 19, Loss: 14.236510
Epoch: 20, Loss: 12.299620


In [54]:
def predict(model,question, threshold=0.5):

  num_ques = torch.tensor(text2index(question,vocab)).unsqueeze(0)

  output = model(num_ques)

  probs = torch.nn.functional.softmax(output, dim=1)

  value,idx = torch.max(probs, dim=1)

  if value < threshold: print('I do not know')

  else:
    print(list(vocab.keys())[idx])


In [56]:
ques = " What's the Captial of France ? "
predict(model,ques)

paris
