### 1. Data

In [33]:
qa_dataset = [
    {
        'context': 'My name is AIVN and I am from Vietnam.',
        'question': 'What is my name?',
        'answer': 'AIVN'
    },
    {
        'context': 'I love painting and my favorite artist is Vincent Van Gogh.',
        'question': 'What is my favorite activity?',
        'answer': 'painting'
    },
    {
        'context': 'I am studying computer science at the University of Tokyo.',
        'question': 'What am I studying?',
        'answer': 'computer science'
    },
    {
        'context': 'My favorite book is "To Kill a Mockingbird" by Harper Lee.',
        'question': 'What is my favorite book?', 
        'answer': '"To Kill a Mockingbird"'
    }
]

### 2. Vectorization

In [34]:
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator


tokenier = get_tokenizer("basic_english")

def yield_tokens(examples: list):
    for item in examples:
        yield tokenier("<cls> " + item["context"] + " <sep> " + item["question"])
        # token cls dành cho các câu không có đáp án, ứng với st_pos = end_pos = 0

vocab = build_vocab_from_iterator(
    iterator=yield_tokens(qa_dataset), 
    specials=["<unk>", "<pad>", "<bos>", "<eos>", "<sep>", "<cls>"]
)
vocab.set_default_index(vocab["<unk>"])
vocab.get_stoi()

{'vincent': 41,
 'vietnam': 40,
 'university': 38,
 'to': 36,
 'the': 35,
 'painting': 33,
 'of': 32,
 'mockingbird': 31,
 'love': 30,
 'lee': 29,
 'what': 12,
 '<sep>': 4,
 '<bos>': 2,
 'science': 34,
 '?': 9,
 'my': 7,
 'is': 6,
 'at': 22,
 '<cls>': 5,
 'gogh': 26,
 '.': 8,
 '<eos>': 3,
 '<pad>': 1,
 'computer': 24,
 'artist': 21,
 'favorite': 10,
 'harper': 27,
 '<unk>': 0,
 'and': 14,
 'studying': 17,
 'i': 11,
 'aivn': 20,
 'am': 13,
 'van': 39,
 'book': 15,
 'tokyo': 37,
 'name': 16,
 'kill': 28,
 'by': 23,
 'a': 18,
 'from': 25,
 'activity': 19}

In [35]:
PAD_IDX = vocab["<pad>"]

def pad_and_truncate(input_ids: list[int], max_seq_len: int):
    if len(input_ids) > max_seq_len:
        input_ids = input_ids[:max_seq_len]
    else:
        input_ids = input_ids + [PAD_IDX] * (max_seq_len - len(input_ids))
    return input_ids


MAX_SEQ_LEN = 22
text = "I love AIVN"
tokenized_text = tokenier(text)
tokens = [vocab[token] for token in tokenized_text]
print(tokens)
tokens = pad_and_truncate(input_ids=tokens, max_seq_len=MAX_SEQ_LEN)
print(tokens)

[11, 30, 20]
[11, 30, 20, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]


In [36]:
import torch


def vectorize(question: str, context: str, answer: str):
    
    input_text = question + " <sep> " + context
    input_ids = [vocab[token] for token in tokenier(input_text)]
    input_ids = pad_and_truncate(input_ids=input_ids, max_seq_len=MAX_SEQ_LEN)

    answer_ids = [vocab[token] for token in tokenier(answer)]
    st_pos = input_ids.index(answer_ids[0])
    end_pos = st_pos + len(answer_ids) - 1

    input_ids = torch.tensor(input_ids, dtype=torch.long)
    st_pos = torch.tensor(st_pos, dtype=torch.long)
    end_pos = torch.tensor(end_pos, dtype=torch.long)
    return input_ids, st_pos, end_pos

input_ids, st_pos, end_pos = vectorize(
    question=qa_dataset[0]['question'],
    context=qa_dataset[0]['context'], 
    answer=qa_dataset[0]['answer']
)
print(input_ids)
print(st_pos)
print(end_pos)

tensor([12,  6,  7, 16,  9,  4,  7, 16,  6, 20, 14, 11, 13, 25, 40,  8,  1,  1,
         1,  1,  1,  1])
tensor(9)
tensor(9)


In [37]:
id2token = {id: label for label, id in vocab.get_stoi().items()}
for token in input_ids.numpy():
    print(id2token[token], end= ' ')

what is my name ? <sep> my name is aivn and i am from vietnam . <pad> <pad> <pad> <pad> <pad> <pad> 

### 3. Create dataset

In [38]:

from torch.utils.data import Dataset

class QADataset(Dataset):
    def __init__(self, data: list[dict]):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        item = self.data[index]
        question_text = item['question']
        context_text = item['context']
        answer_text = item['answer']

        input_ids, st_pos, end_pos = vectorize(
            question=question_text,
            context=context_text,
            answer=answer_text
        )
        return input_ids, st_pos, end_pos


In [39]:
from torch.utils.data import DataLoader


train_dataset = QADataset(data=qa_dataset)
train_dataloader = DataLoader(dataset=train_dataset, 
                              batch_size=1, 
                              shuffle=True)

for batch in train_dataloader:
    input_ids, st_pos, end_pos = batch
    print(input_ids)
    print(st_pos)
    print(end_pos)
    print("=" * 100)

tensor([[12,  6,  7, 10, 19,  9,  4, 11, 30, 33, 14,  7, 10, 21,  6, 41, 39, 26,
          8,  1,  1,  1]])
tensor([9])
tensor([9])
tensor([[12,  6,  7, 16,  9,  4,  7, 16,  6, 20, 14, 11, 13, 25, 40,  8,  1,  1,
          1,  1,  1,  1]])
tensor([9])
tensor([9])
tensor([[12, 13, 11, 17,  9,  4, 11, 13, 17, 24, 34, 22, 35, 38, 32, 37,  8,  1,
          1,  1,  1,  1]])
tensor([9])
tensor([10])
tensor([[12,  6,  7, 10, 15,  9,  4,  7, 10, 15,  6, 36, 28, 18, 31, 23, 27, 29,
          8,  1,  1,  1]])
tensor([11])
tensor([14])


### 4. Model

In [40]:
import torch 
import torch.nn as nn

class TransformerBlock(nn.Module):
    def __init__(self, 
                 embed_dim: int, 
                 num_heads: int, 
                 ff_dim: int,
                 dropout_prob: float = 0.1):
        super().__init__()
        # https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html
        self.attn = nn.MultiheadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads, # create num_heads attention
            batch_first=True
        )

        self.ffn = nn.Sequential(
            nn.Linear(in_features=embed_dim, out_features=ff_dim, bias=True),
            nn.ReLU(),
            nn.Linear(in_features=ff_dim, out_features=embed_dim)
        )

        self.layernorm1 = nn.LayerNorm(normalized_shape=embed_dim)
        self.layernorm2 = nn.LayerNorm(normalized_shape=embed_dim)
        self.dropout1 = nn.Dropout(p=dropout_prob)
        self.dropout2 = nn.Dropout(p=dropout_prob)
    
    def forward(self, query, key, value): 
        # query, key, value: [N, seq_len, embed_dim]
        attn_output, attn_output_weights = self.attn(query, key, value)
        # print("attn_output: ", attn_output.size()) # => output model same input: [N, seq_len, embed_dim]
        # print("attn_output_weights: ", attn_output_weights.size()) => softmax(Q@K.T): [N, seq_len, seq_len]

        attn_output = self.dropout1(attn_output) # [N, seq_len, embed_dim]
        out_1 = self.layernorm1(query + attn_output) # [N, seq_len, embed_dim]
        ffn_output = self.ffn(out_1) # [N, seq_len, embed_dim]
        ffn_output = self.dropout2(ffn_output) # [N, seq_len, embed_dim]
        out_2 = self.layernorm2(out_1 + ffn_output) # [N seq_len, embed_dim]
        return out_2 # [N,seq_len, embedim]

In [41]:
class TokenAndPositionEmbedding(nn.Module):
    def __init__(self, 
                 embed_dim: int, 
                 vocab_size: int, 
                 max_length: int):
        super().__init__()
        self.embed_model = nn.Embedding(
            num_embeddings=vocab_size, 
            embedding_dim=embed_dim
        )

        self.pos_embed = nn.Embedding(
            num_embeddings=max_length,
            embedding_dim=embed_dim
        )

    def forward(self, x):
        N, seq_len = x.size() # 32, 128
        positions = torch.arange(0, seq_len).expand(N, seq_len) # [N, seq_len]
        token_embed = self.embed_model(x) # [N, seq_len, embed_dim]
        position_embed = self.pos_embed(positions) # [N, seq_len, embed_dim]
        return token_embed + position_embed # [N, seq_len, embed_dim]

In [42]:
class QAModel(nn.Module):
    def __init__(self, 
                 vocab_size: int, 
                 embed_dim: int, 
                 n_heads: int, 
                 ff_dim: int,
                 seq_len: int):
        super().__init__()

        self.embed_model = TokenAndPositionEmbedding(embed_dim=embed_dim, 
                                                     vocab_size=vocab_size, 
                                                     max_length=seq_len)
        self.transformer = TransformerBlock(embed_dim=embed_dim,
                                            num_heads=n_heads,
                                            ff_dim=ff_dim)
        
        self.st_linear = nn.Linear(in_features=embed_dim,
                                   out_features=1)
        self.end_linear = nn.Linear(in_features=embed_dim, 
                                    out_features=1)
    def forward(self, input): # input: [N, seq_len]
        embedding = self.embed_model(input) # [N, seq_len, embed_dim]
        transformer_out = self.transformer(embedding, embedding, embedding) # [N, seq_len, embed_dim]
        st_logits = self.st_linear(transformer_out).squeeze(-1) # [N, seq_len]
        end_logits = self.end_linear(transformer_out).squeeze(-1) # [N, seq_len]
        return st_logits, end_logits


In [43]:
EMBEDDING_DIM = 64
FF_DIM = 128
VOCAB_SIZE = len(vocab)
N_HEADS  = 2


model = QAModel(vocab_size=VOCAB_SIZE, 
                embed_dim=EMBEDDING_DIM,
                n_heads=N_HEADS, 
                ff_dim=FF_DIM, 
                seq_len=MAX_SEQ_LEN)
mock_data = torch.randint(low=0, high=10, size=(1, MAX_SEQ_LEN))

output = model(mock_data)
print(output[0].shape)
print(output[1].shape)


torch.Size([1, 22])
torch.Size([1, 22])


### 5. Training

In [44]:
LR = 1e-3
EPOCHS = 20
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)

criterion = nn.CrossEntropyLoss()

model.train()
for _ in range(EPOCHS):
    for idx, (input_ids, st_pos, end_pos) in enumerate(train_dataloader):
        optimizer.zero_grad()

        st_pos_logits, end_pos_logits = model(input_ids)
        st_loss = criterion(st_pos_logits, st_pos)
        end_loss = criterion(end_pos_logits, end_pos)
        loss = (st_loss + end_loss) / 2
        
        loss.backward()
        optimizer.step()

        print(loss.item())
        
        

3.254849910736084
3.3437860012054443
3.2557053565979004
3.444972038269043
2.617384910583496
2.982722282409668
3.0575757026672363
2.661027193069458
2.593743324279785
2.0395679473876953
2.6776223182678223
2.4687552452087402
2.0641984939575195
1.7262616157531738
2.2824065685272217
2.056992530822754
1.9079394340515137
1.2548178434371948
1.7481648921966553
1.4543797969818115
1.541587471961975
0.8851906061172485
1.221709132194519
1.0393528938293457
0.8276312351226807
0.8561490178108215
0.45124197006225586
0.8200867176055908
0.6931801438331604
0.33383098244667053
0.5316500663757324
0.4451179504394531
0.2661314904689789
0.397219181060791
0.38671767711639404
0.2752888798713684
0.2576621174812317
0.2204362154006958
0.11286113411188126
0.2247752845287323
0.09133392572402954
0.11935053765773773
0.14225052297115326
0.1153278797864914
0.06740006804466248
0.08350738883018494
0.10964974015951157
0.0869818702340126
0.09620773792266846
0.07457447797060013
0.05715733766555786
0.04342472925782204
0.070986

In [46]:
model.eval()
with torch.no_grad():
    sample = qa_dataset[3]
    context, question, answer = sample.values()
    input_ids, st_pos, end_pos = vectorize(question=question, 
                                           context=context, 
                                           answer=answer)
    input_ids = input_ids.unsqueeze(0) # add batch dimention

    st_logits, end_logits = model(input_ids)

    offset = len(tokenier(question)) + 1
    st_pos = torch.argmax(st_logits, dim=1).numpy()[0]
    end_pos = torch.argmax(end_logits, dim=1).numpy()[0]
    
    st_pos -= offset
    end_pos -= offset

    st_pos = max(st_pos, 0)
    end_pos = min(end_pos, len(tokenier(context)) -1) 

    if end_pos >= st_pos:
        context_tokens = tokenier(context)
        predicted_answer_tokens = context_tokens[st_pos: end_pos +1]
        predicted_answer = " ".join(predicted_answer_tokens)
    else:
        predicted_answer = " "
    
    print(f"Context: {context}")
    print(f"Question: {question}")
    print(f"Prediction: {predicted_answer}")
    


Context: My favorite book is "To Kill a Mockingbird" by Harper Lee.
Question: What is my favorite book?
Prediction: to kill a mockingbird
