In [1]:
import torch
import torch.nn as nn
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader

Input:
- Question + <'pad'> + context

Output:
- The start index and end index of the extracted information from context

In [4]:
qa_dataset = [
    {
        'context': 'My name is Thang and I am from Vietnam',
        'question': 'What is my name?',
        'answer': 'Thang'
    },
    {
        'context': 'I love painting and my favorite artist is Vincent Van Gough',
        'question': 'What is my favorite activity?',
        'answer' : 'painting'
    },
    {
        'context' : 'I am studying Computer Science at Nanyang Technological University',
        'question' : 'What am I studying?',
        'answer' : 'Computer Science'
    }
]

In [3]:
#define tokenizer
tokenizer = get_tokenizer('basic_english')

text = 'I love cats'
tokenizer(text)

['i', 'love', 'cats']

In [5]:
def yield_tokens(data):
    for item in data:
        yield(tokenizer(item['context'] + '<sep>' + item['question']))

#create vocab
vocab = build_vocab_from_iterator(
    yield_tokens(qa_dataset),
    specials=['<unk>', '<pad>', '<bos>', '<eos>', '<sep>']
)
vocab.set_default_index(vocab['<unk>'])
vocab.get_stoi()

{'<sep>': 4,
 'vietnam<sep>what': 28,
 'technological': 24,
 '<bos>': 2,
 '<unk>': 0,
 'artist': 15,
 '<eos>': 3,
 '<pad>': 1,
 'i': 5,
 'is': 6,
 'am': 9,
 'my': 7,
 '?': 8,
 'name': 12,
 'and': 10,
 'favorite': 11,
 'studying': 13,
 'activity': 14,
 'at': 16,
 'computer': 17,
 'from': 18,
 'gough<sep>what': 19,
 'love': 20,
 'nanyang': 21,
 'painting': 22,
 'science': 23,
 'thang': 25,
 'university<sep>what': 26,
 'van': 27,
 'vincent': 29}

In [6]:
PAD_IDX = 1
def pad_and_truncate(input_ids, max_seq_len):
    if len(input_ids) > max_seq_len:
        input_ids = input_ids[:max_seq_len]
    elif len(input_ids) < max_seq_len:
        input_ids += [PAD_IDX] * (max_seq_len - len(input_ids))

    return input_ids

In [16]:
MAX_SEQ_LENGTH = 22
def vectorize(question, context, answer, max_seq_len):
    input_text = question + '<sep>' + context
    input_ids = [vocab[token] for token in tokenizer(input_text)]
    input_ids = pad_and_truncate(input_ids, MAX_SEQ_LENGTH)

    answer_ids = [vocab[token] for token in tokenizer(answer)]
    start_positions = input_ids.index(answer_ids[0])
    end_positions = start_positions + len(answer_ids) - 1

    input_ids = torch.tensor(input_ids, dtype=torch.long)
    start_positions = torch.tensor(start_positions, dtype=torch.long)
    end_positions = torch.tensor(end_positions, dtype=torch.long)

    return input_ids, start_positions, end_positions 

In [20]:
class QADataset(nn.Module):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        item = self.data[idx]
        question_text = item['question']
        context_text = item['context']
        answer_text = item['answer']

        input_ids, start_positions, end_positions = vectorize(
            question_text, context_text, answer_text, 22
        )

        return input_ids, start_positions, end_positions

In [21]:
def decode(input_ids):
    return ' '.join([vocab.lookup_token(token) for token in input_ids])

In [22]:
train_dataset = QADataset(qa_dataset)
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)

for batch in train_loader:
    q, c, a = batch
    print(q, c, a)
   

tensor([[ 0,  9,  5, 13,  8,  0,  9, 13, 17, 23, 16, 21, 24,  0,  1,  1,  1,  1,
          1,  1,  1,  1],
        [ 0,  6,  7, 11, 14,  8,  0, 20, 22, 10,  7, 11, 15,  6, 29, 27,  0,  1,
          1,  1,  1,  1]]) tensor([8, 8]) tensor([9, 8])
tensor([[ 0,  6,  7, 12,  8,  0, 12,  6, 25, 10,  5,  9, 18,  0,  1,  1,  1,  1,
          1,  1,  1,  1]]) tensor([8]) tensor([8])


In [24]:
# Model: bidirectional lstm
class QAModel(nn.Module):
    def __init__(self, 
                 vocab_size, embedding_dim, hidden_size,
                 n_layers):
        super().__init__()
        self.input_embedding = nn.Embedding(
            vocab_size, embedding_dim
        )
        

        self.lstm = nn.LSTM(
            embedding_dim, hidden_size,
            num_layers=n_layers,
            batch_first=True,
            bidirectional=True
        )
        self.start_linear = nn.Linear(hidden_size*2, 1)
        self.end_linear = nn.Linear(hidden_size*2, 1)

    def forward(self, text):
        input_embedded = self.input_embedding(text)

        lstm_out, _ = self.lstm(input_embedded)
        start_logits = self.start_linear(lstm_out).squeeze(-1)
        end_logits = self.end_linear(lstm_out).squeeze(-1)

    

        return start_logits, end_logits

In [27]:
#Model parameters
EMBEDDING_DIM = 32
HIDDEN_SIZE = 128
VOCAB_SIZE = len(vocab)
N_LAYERS = 2

model = QAModel(
    VOCAB_SIZE, EMBEDDING_DIM, HIDDEN_SIZE, N_LAYERS
)

input_context = torch.randint(0, 10, size=(1, 10))
model.eval()
with torch.no_grad():
    start_logits, end_logits = model(input_context)

print(start_logits.shape)

torch.Size([1, 10])


In [29]:
LR = 1e-3
EPOCHS = 20
optimizer = torch.optim.Adam(
    model.parameters(), lr=LR
)
criterion = nn.CrossEntropyLoss()

model.train()
for _ in range(EPOCHS):
    for idx, (input_ids, start_positions, end_positions) in enumerate(train_loader):
        optimizer.zero_grad()
        start_logits, end_logits = model(input_ids)
        start_loss = criterion(start_logits, start_positions)
        end_loss = criterion(end_logits, end_positions)
        loss = start_loss + end_loss
        loss.backward()
        optimizer.step()
        print(loss.item())

0.11022062599658966
0.4110000431537628
0.13101322948932648
0.6656321287155151
0.12594982981681824
0.21103407442569733
0.01870870217680931
0.49638795852661133
0.1484886258840561
0.019184071570634842
0.019021131098270416
0.4201688766479492
0.15264801681041718
0.02022932469844818
0.022151129320263863
0.13079124689102173
0.08015856146812439
0.028210598975419998
0.0825396478176117
0.027810359373688698
0.07384565472602844
0.024248000234365463
0.05756865441799164
0.023093003779649734
0.046066176146268845
0.015439847484230995
0.03565260395407677
0.01469854824244976
0.012510443106293678
0.04787231981754303
0.02847272902727127
0.009645873680710793
0.026761554181575775
0.008145080879330635
0.008858866058290005
0.039425767958164215
0.021974265575408936
0.00964362546801567
0.007578282617032528
0.031262706965208054


In [31]:
model.eval()
with torch.no_grad():
    sample = qa_dataset[1]
    context, question, answer = sample.values()
    input_ids, start_position, end_position = vectorize(context, question, answer, MAX_SEQ_LENGTH)
    input_ids = input_ids.unsqueeze(0)
    start_logits, end_logits = model(input_ids)

    offset = len(tokenizer(question)) + 2
    start_position = torch.argmax(start_logits, dim=1).numpy()[0]
    end_position = torch.argmax(end_logits, dim=1).numpy()[0]

    start_postion -= offset
    end_position -= offset

    start_position = max(start_position, 0)
    end_position = min(end_position, len(tokenizer(context)) - 1)

    if end_position >= start_postion:
        #extract the predicted answer span
        context_tokens = tokenizer(context)
        predicted_answer_tokens = context_tokens[start_position:end_postion+1]
        predicted_answer = ' '.join(predicted_answer_tokens)
        print(context)
        print(question)
        print(predicted_answer)

I love painting and my favorite artist is Vincent Van Gough
What is my favorite activity?
painting
