In [3]:
import torch
import torch.nn as nn
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader


Input:
- Context: the paragraph that contains the answer
- Question: a question that can be answered from the context

Output:
- Answer: the answer span (in this notebook, predicted as a class label over the known answers)

In [6]:
# Toy question-answering corpus: every record carries a context paragraph,
# a question about it, and the answer text found inside the context.
qa_dataset = [
    dict(
        context='My name is Thang and I am from Vietnam',
        question='What is my name?',
        answer='Thang',
    ),
    dict(
        context='I love painting and my favorite artist is Vincent Van Gough',
        question='What is my favorite activity?',
        answer='painting',
    ),
    dict(
        context='I am studying Computer Science at Nanyang Technological University',
        question='What am I studying?',
        answer='Computer Science',
    ),
]

In [7]:
# Define the tokenizer: torchtext's 'basic_english' tokenizer. As the sample
# output of this cell shows, it lower-cases the text and returns a list of tokens.
tokenizer = get_tokenizer('basic_english')

# Quick sanity check on a short sentence; the cell displays the resulting tokens.
text = 'I love cats'
tokenizer(text)

['i', 'love', 'cats']

In [8]:
def yield_tokens(data):
    """Yield one token list per record, for vocabulary construction.

    Each record's context and question are joined with a single space and
    tokenized together, so the vocabulary covers the words of both fields.

    Args:
        data: iterable of dicts with string values under 'context' and 'question'.

    Yields:
        The list of tokens produced by ``tokenizer`` for each record.
    """
    for record in data:
        combined_text = ' '.join((record['context'], record['question']))
        yield tokenizer(combined_text)

# Build the vocabulary from every context/question token. The special tokens
# are listed first, so they receive the lowest indices.
special_tokens = ['<unk>', '<pad>', '<bos>', '<eos>', '<sep>']
vocab = build_vocab_from_iterator(yield_tokens(qa_dataset), specials=special_tokens)

# Out-of-vocabulary tokens resolve to the '<unk>' index instead of raising.
unk_index = vocab['<unk>']
vocab.set_default_index(unk_index)

# Display the token -> index mapping.
vocab.get_stoi()

{'<sep>': 4,
 'technological': 25,
 '<bos>': 2,
 '<unk>': 0,
 'artist': 16,
 '<eos>': 3,
 '<pad>': 1,
 'i': 5,
 'is': 6,
 'am': 9,
 'my': 7,
 '?': 8,
 'what': 10,
 'name': 13,
 'and': 11,
 'favorite': 12,
 'studying': 14,
 'activity': 15,
 'at': 17,
 'computer': 18,
 'from': 19,
 'gough': 20,
 'love': 21,
 'nanyang': 22,
 'painting': 23,
 'science': 24,
 'thang': 26,
 'university': 27,
 'van': 28,
 'vietnam': 29,
 'vincent': 30}

In [9]:
# Every distinct answer string is treated as one class label.
# Labels are numbered in first-seen order. Enumerating a set of strings would
# give an order that can differ between interpreter runs (string hashing is
# randomized), which would silently change which index means which answer.
class_names = list(dict.fromkeys(item['answer'] for item in qa_dataset))

# Kept as a set so code that relies on `classes` (e.g. len(classes)) still works.
classes = set(class_names)

# Forward and reverse lookups, built from the same ordered list so they always agree.
classes_to_idx = {cls_name: idx for idx, cls_name in enumerate(class_names)}
idx_to_classes = {idx: cls_name for idx, cls_name in enumerate(class_names)}
print(idx_to_classes)
print(classes_to_idx)

{0: 'Thang', 1: 'painting', 2: 'Computer Science'}
{'Thang': 0, 'painting': 1, 'Computer Science': 2}


In [10]:
# Index of the '<pad>' token; it is the second entry of the vocabulary's specials list.
PAD_IDX = 1


def pad_and_truncate(input_ids, max_seq_len, pad_idx=PAD_IDX):
    """Return a copy of ``input_ids`` that is exactly ``max_seq_len`` items long.

    Longer sequences are cut on the right; shorter ones are right-padded with
    ``pad_idx``. The caller's sequence is never modified, whichever branch applies.

    Args:
        input_ids: sequence of token indices.
        max_seq_len: target length of the result.
        pad_idx: index used for padding; defaults to ``PAD_IDX``.

    Returns:
        A new list of token indices of length ``max_seq_len``.
    """
    # Slicing copies, so appending the padding below cannot touch the input.
    clipped = list(input_ids[:max_seq_len])
    return clipped + [pad_idx] * (max_seq_len - len(clipped))

In [11]:
def vectorize(question, context, max_seq_len):
    """Encode a question/context pair as two fixed-length tensors of token ids.

    Each text is tokenized, mapped to vocabulary indices, padded or truncated
    to ``max_seq_len`` and converted to a ``torch.long`` tensor.

    Args:
        question: question text.
        context: context text.
        max_seq_len: length applied to both encoded sequences.

    Returns:
        Tuple ``(input_question_ids, input_context_ids)``.
    """
    def encode(text):
        # tokens -> vocabulary indices -> fixed length -> tensor
        token_ids = [vocab[token] for token in tokenizer(text)]
        fixed_len_ids = pad_and_truncate(token_ids, max_seq_len)
        return torch.tensor(fixed_len_ids, dtype=torch.long)

    return encode(question), encode(context)

In [14]:
# Encode the first sample with both sequences fixed at 20 tokens,
# then show the context ids followed by the question ids.
input_question_ids, input_context_ids = vectorize(
    question=qa_dataset[0]['question'],
    context=qa_dataset[0]['context'],
    max_seq_len=20,
)
print(input_context_ids)
print(input_question_ids)

tensor([ 7, 13,  6, 26, 11,  5,  9, 19, 29,  1,  1,  1,  1,  1,  1,  1,  1,  1,
         1,  1])
tensor([10,  6,  7, 13,  8,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
         1,  1])


In [15]:
class QADataset(Dataset):
    """Map-style dataset over a list of QA records.

    Each record is a dict with 'question', 'context' and 'answer' keys. Items
    are returned as ``(question_ids, context_ids, answer_id)`` tensors.

    Args:
        data: indexable collection of QA records.
        max_seq_len: length that both the question and the context are padded
            or truncated to. Defaults to 20.
    """

    def __init__(self, data, max_seq_len=20):
        super().__init__()
        self.data = data
        self.max_seq_len = max_seq_len

    def __len__(self):
        """Return the number of QA records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the encoded question, encoded context and answer class index for record ``idx``."""
        item = self.data[idx]

        input_question_ids, input_context_ids = vectorize(
            item['question'], item['context'], max_seq_len=self.max_seq_len
        )

        # The answer string is mapped to its class index for classification.
        answer_id = torch.tensor(classes_to_idx[item['answer']], dtype=torch.long)

        return input_question_ids, input_context_ids, answer_id

In [16]:
def decode(input_ids):
    """Convert vocabulary indices back to text.

    Args:
        input_ids: iterable of vocabulary indices.

    Returns:
        The looked-up tokens joined by single spaces.
    """
    words = (vocab.lookup_token(index) for index in input_ids)
    return ' '.join(words)



In [20]:
# Wrap the toy corpus in a dataset and serve it in shuffled mini-batches of two.
train_dataset = QADataset(qa_dataset)
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)

# Inspect one pass: each batch unpacks into question ids, context ids and answer ids.
for q, c, a in train_loader:
    print(q, c, a)
   

tensor([[10,  6,  7, 12, 15,  8,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
          1,  1],
        [10,  6,  7, 13,  8,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
          1,  1]]) tensor([[ 5, 21, 23, 11,  7, 12, 16,  6, 30, 28, 20,  1,  1,  1,  1,  1,  1,  1,
          1,  1],
        [ 7, 13,  6, 26, 11,  5,  9, 19, 29,  1,  1,  1,  1,  1,  1,  1,  1,  1,
          1,  1]]) tensor([1, 0])
tensor([[10,  9,  5, 14,  8,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
          1,  1]]) tensor([[ 5,  9, 14, 18, 24, 17, 22, 25, 27,  1,  1,  1,  1,  1,  1,  1,  1,  1,
          1,  1]]) tensor([2])


In [21]:
# Model: bidirectional lstm
class QAModel(nn.Module):
    """Bidirectional-LSTM classifier over a (question, context) pair.

    The question and the context are embedded separately, concatenated along
    the feature axis at every time step, passed through a multi-layer
    bidirectional LSTM, and the last layer's final hidden states of both
    directions are projected to class logits.

    Args:
        vocab_size: number of rows in each embedding table.
        embedding_dim: embedding size for question and context tokens.
        hidden_size: LSTM hidden size per direction.
        n_layers: number of stacked LSTM layers.
        n_classes: number of output classes.
    """

    def __init__(self,
                 vocab_size, embedding_dim, hidden_size,
                 n_layers, n_classes):
        super().__init__()
        # NOTE: the attribute name keeps its original (misspelled) form so that
        # existing state_dict keys and external attribute accesses stay valid.
        self.questsion_embedding = nn.Embedding(
            vocab_size, embedding_dim
        )
        self.context_embedding = nn.Embedding(
            vocab_size, embedding_dim
        )

        # input_size is 2 * embedding_dim because forward() stacks the question
        # and context embeddings on the feature axis.
        self.lstm = nn.LSTM(
            embedding_dim * 2, hidden_size,
            num_layers=n_layers,
            batch_first=True,
            bidirectional=True
        )
        # Two directions -> 2 * hidden_size features go into the classifier.
        self.fc = nn.Linear(hidden_size * 2, n_classes)

    def forward(self, question, context):
        """Compute class logits for a batch of encoded question/context pairs.

        Args:
            question: integer tensor of token ids, (batch, seq_len).
            context: integer tensor of token ids with the same shape as
                ``question`` (feature-wise concatenation needs equal lengths).

        Returns:
            Tensor of logits with shape (batch, n_classes).
        """
        question_embed = self.questsion_embedding(question)
        context_embed = self.context_embedding(context)

        # Concatenate on the feature axis so the last dimension equals the
        # LSTM's input_size (2 * embedding_dim). Concatenating on the sequence
        # axis would leave it at embedding_dim and the LSTM would reject it.
        question_n_context = torch.cat(
            (question_embed, context_embed),
            dim=-1
        )

        _, (h_n, _) = self.lstm(question_n_context)

        # h_n is ordered by layer, then direction: h_n[-2] is the last layer's
        # forward state (after the final token) and h_n[-1] its backward state
        # (after the first token). Indexing the output sequence at the last
        # step instead would give a backward state that has seen one token only.
        final_state = torch.cat((h_n[-2], h_n[-1]), dim=-1)

        out = self.fc(final_state)

        return out

In [22]:
# Model hyper-parameters
EMBEDDING_DIM = 32
HIDDEN_SIZE = 128
VOCAB_SIZE = len(vocab)
N_LAYERS = 2
N_CLASSES = len(classes)
MAX_CONTEXT_LEN = 20  # same value as the max_seq_len used when vectorizing samples

model = QAModel(
    vocab_size=VOCAB_SIZE,
    embedding_dim=EMBEDDING_DIM,
    hidden_size=HIDDEN_SIZE,
    n_layers=N_LAYERS,
    n_classes=N_CLASSES,
)

# Smoke test: feed one random id sequence as both question and context and
# check that the logits have one column per class.
input_context = torch.randint(low=0, high=10, size=(1, MAX_CONTEXT_LEN))
model.eval()
with torch.no_grad():
    logits = model(question=input_context, context=input_context)

print(logits.shape)

torch.Size([1, 3])


In [24]:
# Optimisation settings
LR = 1e-3
EPOCHS = 20
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
criterion = nn.CrossEntropyLoss()

# Plain training loop: one optimizer step per mini-batch, printing each batch loss.
model.train()
for epoch in range(EPOCHS):
    for question_batch, context_batch, answer_batch in train_loader:
        optimizer.zero_grad()
        batch_logits = model(question_batch, context_batch)
        loss = criterion(batch_logits, answer_batch)
        loss.backward()
        optimizer.step()
        print(loss.item())

1.1375741958618164
1.0970170497894287
1.1093568801879883
1.0720338821411133
1.037621021270752
1.2168563604354858
1.1056877374649048
1.0685009956359863
1.0238081216812134
1.171850323677063
1.0436911582946777
1.0676417350769043
1.075181007385254
0.8446162939071655
0.8897265195846558
1.024154543876648
0.9852239489555359
0.4325188994407654
0.9271750450134277
0.2334124594926834
0.491590678691864
0.9268525242805481
0.8360755443572998
0.04568187892436981
0.3978568911552429
0.836230993270874
0.40441057085990906
0.8005020022392273
0.36667779088020325
0.8077053427696228
0.39257487654685974
0.7396137118339539
0.368089884519577
0.7426681518554688
0.7229405641555786
0.0021495348773896694
0.382950097322464
0.6704594492912292
0.7136815786361694
0.001401038491167128


In [26]:
# Run the trained model on the first sample and print its predicted answer.
model.eval()
with torch.no_grad():
    sample = qa_dataset[0]
    # Read fields by key: unpacking sample.values() positionally would silently
    # swap context and question if a record listed its keys in another order.
    context = sample['context']
    question = sample['question']
    answer = sample['answer']
    question_ids, context_ids = vectorize(question, context, MAX_CONTEXT_LEN)
    # Add a leading batch dimension of size 1 before calling the model.
    question_ids = question_ids.unsqueeze(0)
    context_ids = context_ids.unsqueeze(0)
    outputs = model(question_ids, context_ids)
    # Index of the highest logit for the single item in the batch.
    predicted = outputs.argmax(dim=1)
    print(f'Context: {context}')
    print(f'Question: {question}')
    print(f'Prediction: {idx_to_classes[predicted.item()]}')

Context: My name is Thang and I am from Vietnam
Question: What is my name?
Prediction: Thang
