# A simple transformer question-answering model needs:
* Tokenizer (here I am using BERT's way of marking the beginning and end of sentences)
* Transformer encoder
* QA head (predict start and end position)
* Training and inference logic



# Use self-built attention head

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


In [None]:
# Hyper-parameters shared by every module of the hand-built encoder below.
# 'mask' is the attention mask handed to each attention head (None = no masking).
config = dict(
    vocab_size=45,
    hidden_size=64,
    max_position_embeddings=64,
    num_attention_heads=4,
    intermediate_size=10,
    hidden_dropout_prob=0.01,
    num_hidden_layers=12,
    mask=None,
)

In [None]:
## Tokenizer

class SimpleTokinizer:
  """Whitespace tokenizer that packs a question/context pair BERT-style.

  Ids 0-3 are reserved for [PAD], [CLS], [SEP] and [UNK]; every other word
  receives the next free id the first time `build_vocab` sees it.
  """

  def __init__(self):
    self.vocab = {"[PAD]":0, "[CLS]":1, "[SEP]":2,"[UNK]":3}
    self.reverse_vocab = {0:"[PAD]", 1:"[CLS]", 2:"[SEP]",3:"[UNK]"}
    self.idx = 4  # next id to hand out

  def build_vocab(self,texts):
    """Add every unseen lower-cased, whitespace-separated word of `texts`."""
    for text in texts:
      for word in text.lower().split():
        if word not in self.vocab:
          self.vocab[word] = self.idx
          self.reverse_vocab[self.idx] = word
          self.idx+=1

  def encode(self, question, contaxt, max_len = 64):
    """Encode one pair as `[CLS] question [SEP] context [SEP]`, padded to `max_len`.

    Args:
      question: question text.
      contaxt: context passage (parameter spelling kept for existing callers).
      max_len: exact length of every returned sequence.

    Returns:
      dict with 'input_ids', 'attention_mask' and 'token_type_ids' (1-D tensors
      of length `max_len`) and 'token' (list of token strings, same length).
      Segment ids are 1 for the question part, 2 for the context part and 0
      for padding. Words missing from the vocabulary map to [UNK].
    """
    question_tokens = question.lower().split()
    # Use the *argument*; the previous code read a global named `context` here.
    context_tokens = contaxt.lower().split()

    tokens = ["[CLS]"] + question_tokens + ["[SEP]"] + context_tokens + ["[SEP]"]
    unk_id = self.vocab["[UNK]"]
    token_ids = [self.vocab.get(token, unk_id) for token in tokens]
    token_type_ids = [1] * (len(question_tokens) + 2) + [2] * (len(context_tokens) + 1)
    attention_mask = [1] * len(token_ids)

    # Crop over-long pairs so that every output really has length max_len.
    tokens = tokens[:max_len]
    token_ids = token_ids[:max_len]
    token_type_ids = token_type_ids[:max_len]
    attention_mask = attention_mask[:max_len]

    pad_count = max_len - len(token_ids)
    padding = [0] * pad_count
    return {
        'input_ids':torch.tensor(token_ids + padding),
        'attention_mask':torch.tensor(attention_mask + padding),
        'token':tokens + ['[PAD]'] * pad_count,
        'token_type_ids':torch.tensor(token_type_ids + padding),
    }

In [None]:
## Sample
question = "What is KNN?"
context = '''KNN, or k-Nearest Neighbors, is a supervised machine learning algorithm used for both classification and regression tasks. It classifies new data points by finding the "k" most similar data points (neighbors) in the training data and assigning the new data point to the majority class among those neighbors.'''


In [None]:
## Tokenizer
# Build the vocabulary from the single sample pair, then encode that pair.
tokenizer = SimpleTokinizer()
tokenizer.build_vocab([question, context])
# input_ids, attention_mask = tokenizer.encode(question, context)
# encode() returns a dict; later cells read inputs['input_ids'].
inputs = tokenizer.encode(question, context)

In [None]:
# inputs

In [None]:
# input_ids = input_ids.unsqueeze(0)
# attention_mask = attention_mask.unsqueeze(0)

In [None]:
# config['mask'] = attention_mask

In [None]:
# input_ids.shape, attention_mask.shape

In [None]:
## Attention head

def scaled_dot_product_attention(q,k,v, mask = None):
  """Compute softmax(q @ k^T / sqrt(d_k)) @ v for batched 3-D inputs.

  Args:
    q: queries, shape (batch, n_queries, d_k).
    k: keys, shape (batch, n_keys, d_k).
    v: values, shape (batch, n_keys, d_v).
    mask: optional tensor in which 0 marks positions to hide. A 2-D
      (batch, n_keys) padding mask is expanded to (batch, 1, n_keys) so it
      hides keys for every query; any other shape must broadcast against
      the (batch, n_queries, n_keys) score matrix.

  Returns:
    Tensor of shape (batch, n_queries, d_v).
  """
  dim_k = k.size(-1)  # per-head feature size, used for the 1/sqrt(d_k) scaling
  scores = torch.bmm(q,k.transpose(1,2)) / math.sqrt(dim_k)
  if mask is not None:
    if mask.dim() == 2:
      mask = mask.unsqueeze(1)
    scores = scores.masked_fill(mask==0, -float('inf'))
  # Normalise over the key axis (last dim) so each query's weights sum to 1.
  # The previous dim=1 normalised over queries instead.
  weights = F.softmax(scores, dim=-1)
  attention_outputs = torch.bmm(weights, v)
  return attention_outputs


class AttentionHead(nn.Module):
  """One attention head: three linear projections followed by scaled dot-product attention."""

  def __init__(self, embed_dim, head_dim, mask=None):
    """Create the query/key/value projections (embed_dim -> head_dim) and remember `mask`."""
    super().__init__()
    self.q = nn.Linear(embed_dim, head_dim)
    self.k = nn.Linear(embed_dim, head_dim)
    self.v = nn.Linear(embed_dim, head_dim)
    # The mask is fixed at construction time and reused on every forward call.
    self.mask = mask

  def forward(self,hidden_state):
    """Project `hidden_state` to queries, keys and values and attend over them."""
    queries = self.q(hidden_state)
    keys = self.k(hidden_state)
    values = self.v(hidden_state)
    return scaled_dot_product_attention(queries, keys, values, mask = self.mask)


class MultiHeadAttention(nn.Module):
  """Runs several AttentionHead modules side by side and mixes their outputs."""

  def __init__(self,config):
    """Build config['num_attention_heads'] heads of width hidden_size // num_attention_heads."""
    super().__init__()
    hidden = config['hidden_size']
    n_heads = config['num_attention_heads']
    per_head = hidden // n_heads
    shared_mask = config['mask']
    head_modules = []
    for _ in range(n_heads):
      head_modules.append(AttentionHead(hidden, per_head, shared_mask))
    self.heads = nn.ModuleList(head_modules)
    self.output_linear = nn.Linear(hidden, hidden)

  def forward(self,hidden_state):
    """Concatenate every head's output on the feature axis, then apply the output projection."""
    per_head_outputs = [head(hidden_state) for head in self.heads]
    merged = torch.cat(per_head_outputs, dim = -1)
    return self.output_linear(merged)



In [None]:
# Inspect a single head. The second argument is the per-head width
# (hidden_size // num_attention_heads), not the number of heads.
AttentionHead(config['hidden_size'], config['hidden_size'] // config['num_attention_heads'], config['mask'])

In [None]:
# Smoke-test multi-head attention on the sample encoding.
multihead_attn = MultiHeadAttention(config)
# NOTE(review): config['vocab_size'] is hard-coded; confirm it is >= len(tokenizer.vocab).
token_emb = nn.Embedding(config['vocab_size'], config['hidden_size'])
# inputs['input_ids'] is a 1-D tensor of token ids, so the lookup yields (seq_len, hidden_size).
input_embeds = token_emb(inputs['input_ids'])
# Add a leading batch dimension: the attention code uses torch.bmm, which needs 3-D tensors.
input_embeds = input_embeds.unsqueeze(0)
attn_output = multihead_attn(input_embeds)

In [None]:
class FeedForward(nn.Module):
  """Position-wise MLP: Linear -> GELU -> Linear -> Dropout."""

  def __init__(self, config):
    """Size the two linear layers from 'hidden_size' / 'intermediate_size'; dropout from 'hidden_dropout_prob'."""
    super().__init__()
    hidden, inner = config['hidden_size'], config['intermediate_size']
    self.linear1 = nn.Linear(hidden, inner)
    self.linear2 = nn.Linear(inner, hidden)
    self.gelu = nn.GELU()
    self.dropout = nn.Dropout(config['hidden_dropout_prob'])

  def forward(self,x):
    """Apply the MLP to the last dimension of `x`; the output has the same shape as `x`."""
    activated = self.gelu(self.linear1(x))
    return self.dropout(self.linear2(activated))


In [None]:
feed_forward = FeedForward(config)

In [None]:
ff_outputs = feed_forward(attn_output)

In [None]:
ff_outputs.shape

In [None]:
class TransformerEncoderLayer(nn.Module):
  """Pre-layer-norm encoder block: self-attention and feed-forward, each wrapped in a residual connection."""

  def __init__(self,config):
    """Build the two LayerNorms, the multi-head attention and the feed-forward sub-layer from `config`."""
    super().__init__()
    self.layer_norm1 = nn.LayerNorm(config['hidden_size'])
    self.layer_norm2 = nn.LayerNorm(config['hidden_size'])
    self.attention = MultiHeadAttention(config)
    self.feedforward = FeedForward(config)

  def forward(self, x):
    """Return a new tensor shaped like `x`; the input tensor itself is left untouched.

    The residual sums are written out-of-place on purpose. The former `x += ...`
    modified the caller's tensor and overwrote a value that autograd had
    saved for LayerNorm's backward pass.
    """
    hidden_state = self.layer_norm1(x)
    x = x + self.attention(hidden_state)
    x = x + self.feedforward(self.layer_norm2(x))
    return x


In [None]:
# Smoke test: the residual connections require the layer's output to have the
# same shape as its input, so both lines should show the same shape.
encoder_layer = TransformerEncoderLayer(config)
print(input_embeds.shape)
encoder_layer(input_embeds).shape

In [None]:
class Embeddings(nn.Module):
  """Token + learned position embeddings, followed by LayerNorm and dropout."""

  def __init__(self,config):
    """Read 'vocab_size', 'hidden_size' and 'max_position_embeddings' from `config`.

    Dropout uses config['hidden_dropout_prob'] when present and otherwise
    falls back to nn.Dropout's default of 0.5 (the previous fixed behaviour).
    """
    super().__init__()
    self.token_embeddings = nn.Embedding(config['vocab_size'], config['hidden_size'])
    self.position_embeddings = nn.Embedding(config['max_position_embeddings'], config['hidden_size'])
    self.layer_norm = nn.LayerNorm(config['hidden_size'], eps = 1e-12)
    self.dropout = nn.Dropout(config.get('hidden_dropout_prob', 0.5))

  def forward(self, input_ids):
    """Embed `input_ids` of shape (seq_len,) or (batch, seq_len).

    Returns:
      Tensor of shape (1, seq_len, hidden) for 1-D input (the position
      embeddings carry a leading batch axis) or (batch, seq_len, hidden)
      for 2-D input.
    """
    # The sequence axis is always the last one. The old
    # `input_ids.unsqueeze(0).size(1)` returned the batch size for 2-D input.
    seq_length = input_ids.size(-1)
    position_ids = torch.arange(seq_length, dtype = torch.long, device = input_ids.device).unsqueeze(0)
    token_embeddings = self.token_embeddings(input_ids)
    position_embeddings = self.position_embeddings(position_ids)
    embeddings = token_embeddings+position_embeddings
    embeddings = self.layer_norm(embeddings)
    embeddings = self.dropout(embeddings)
    return embeddings



In [None]:
# Smoke test: embed the 1-D tensor of token ids produced by the tokenizer.
embedding_layer = Embeddings(config)
embedding_layer(inputs['input_ids'])#.size()

In [None]:
class TransformerEncoder(nn.Module):
  """Embedding layer followed by a stack of config['num_hidden_layers'] encoder layers."""

  def __init__(self, config):
    """Create the embeddings first, then the encoder layers one after another."""
    super().__init__()
    self.embeddings = Embeddings(config)
    stack = []
    for _ in range(config['num_hidden_layers']):
      stack.append(TransformerEncoderLayer(config))
    self.layers = nn.ModuleList(stack)

  def forward(self,x):
    """Embed the token ids in `x` and pass the result through every layer in order."""
    hidden = self.embeddings(x)
    for block in self.layers:
      hidden = block(hidden)
    return hidden

In [None]:
# Smoke test: run the full encoder stack on the sample token ids and show the output size.
encoder = TransformerEncoder(config)
encoder(inputs['input_ids']).size()

In [None]:
## Add a QA head
class QA_Transformer(nn.Module):
  """Extractive-QA model: the hand-built TransformerEncoder plus a linear start/end head."""

  def __init__(self, vocab_size, d_model = 64, max_len = 64, heads = 4):
    """Assemble the encoder configuration and build the model.

    Args:
      vocab_size: number of rows in the token embedding table.
      d_model: hidden size of the encoder.
      max_len: number of learned position embeddings.
      heads: number of attention heads per layer.
    """
    super().__init__()
    self.config = {
        "vocab_size": vocab_size,
        "hidden_size": d_model,
        "max_position_embeddings": max_len,
        "num_attention_heads": heads,
        'intermediate_size':10,
        'hidden_dropout_prob':0.01,
        "num_hidden_layers": 12,
        'mask':None
    }
    self.encoder = TransformerEncoder(self.config)
    # Two outputs per token: a start logit and an end logit.
    self.qa_outputs = nn.Linear(self.config['hidden_size'], 2)

  def forward(self, input_ids):
    """Return `(start_logits, end_logits)`, each of shape (batch, seq_len).

    `input_ids` may be a single 1-D sequence, which is promoted to a batch
    of one, or an already batched (batch, seq_len) tensor. Previously a
    batch dimension was always added, so batched input became 3-D.
    """
    if input_ids.dim() == 1:
      input_ids = input_ids.unsqueeze(0)
    x = self.encoder(input_ids)
    logits = self.qa_outputs(x)
    start_logits, end_logits = logits.split(1,dim=-1)
    return start_logits.squeeze(-1), end_logits.squeeze(-1)




In [None]:
## Sample
question = "What is KNN?"
context = '''KNN, or k-Nearest Neighbors, is a supervised machine learning algorithm used for both classification and regression tasks. It classifies new data points by finding the "k" most similar data points (neighbors) in the training data and assigning the new data point to the majority class among those neighbors.'''


In [None]:
## Tokenizer
# Rebuild the tokenizer from scratch on the same sample pair.
tokenizer = SimpleTokinizer()
tokenizer.build_vocab([question, context])
# input_ids, attention_mask = tokenizer.encode(question, context)
# input_ids = input_ids.unsqueeze(0)
# attention_mask = attention_mask.unsqueeze(0)
# NOTE: `input` shadows the Python builtin of the same name for the rest of the
# notebook; the following cells read this variable, so it is left as is.
input = tokenizer.encode(question, context)

In [None]:
# input['input_ids']

In [None]:
input['attention_mask'].shape, input['input_ids'].shape, len(tokenizer.vocab)

In [None]:
input['input_ids'].shape

In [None]:
## Model
# The embedding table is sized from the tokenizer's vocabulary.
model = QA_Transformer(vocab_size =len(tokenizer.vocab), d_model = 64, max_len = 64, heads = 4)
# start_logits, end_logits = model(input_ids.unsqueeze(0), attention_mask.unsqueeze(0))
# Only the token ids are passed: this model's forward() takes no attention mask.
start_logits, end_logits = model(input['input_ids'])#, input['attention_mask'])

In [None]:
# start_logits, end_logits

In [None]:
# Get answer span
start_idx = torch.argmax(start_logits, dim=1).item()
end_idx = torch.argmax(end_logits, dim=1).item()
print(start_idx, end_idx)
# Decode from the encoding that was actually fed to the model (`input`),
# not from `inputs`, a leftover of the earlier tokenizer demo cell.
tokens = input['input_ids'].tolist()
# If the predicted end index is smaller than the start index, this slice
# (and therefore the printed answer) is empty.
answer = [tokenizer.reverse_vocab.get(t, '[UNK]') for t in tokens[start_idx:end_idx+1]]
print("Predicted answer:", " ".join(answer))

## Model has not been trained yet and no seed is set so output is very unstable

### Output from the above:
46 58
Predicted answer: to the majority class among those neighbors. [SEP] [PAD] [PAD] [PAD] [PAD] [PAD]

### But sometimes the answer is empty (whenever the predicted end index is smaller than the predicted start index).

# Existing attention head from pytorch

In [None]:
class SimpleTokenizer:
    """Whitespace tokenizer producing BERT-style `[CLS] question [SEP] context [SEP]` ids.

    Ids 0-3 are reserved for [PAD], [CLS], [SEP] and [UNK]; every other word
    receives the next free id the first time `build_vocab` sees it.
    """

    def __init__(self):
        self.vocab = {"[PAD]": 0, "[CLS]": 1, "[SEP]": 2, "[UNK]": 3}
        self.reverse_vocab = {0: "[PAD]", 1: "[CLS]", 2: "[SEP]", 3: "[UNK]"}
        self.idx = 4  # next id to hand out

    def build_vocab(self, texts):
        """Add every unseen lower-cased, whitespace-separated word of `texts`."""
        for text in texts:
            for word in text.lower().split():
                if word not in self.vocab:
                    self.vocab[word] = self.idx
                    self.reverse_vocab[self.idx] = word
                    self.idx += 1

    def encode(self, question, context, max_len=64):
        """Encode one question/context pair.

        Args:
            question: question text.
            context: context passage.
            max_len: exact length of both returned tensors.

        Returns:
            `(input_ids, attention_mask)`, two 1-D tensors of length `max_len`.
            Short pairs are right-padded with 0. Over-long pairs are cropped
            to their first `max_len` tokens, which is the same cropping the
            dataset and evaluation code otherwise apply afterwards. Words
            missing from the vocabulary map to [UNK].
        """
        tokens = ["[CLS]"] + question.lower().split() + ["[SEP]"] + context.lower().split() + ["[SEP]"]
        unk_id = self.vocab["[UNK]"]
        token_ids = [self.vocab.get(token, unk_id) for token in tokens][:max_len]
        attention_mask = [1] * len(token_ids)
        padding = [0] * (max_len - len(token_ids))
        return (
            torch.tensor(token_ids + padding),
            torch.tensor(attention_mask + padding)
        )


In [None]:
class TransformerBlock(nn.Module):
    """Post-layer-norm encoder block built on `nn.MultiheadAttention`."""

    def __init__(self, dim, heads):
        """Create self-attention over `dim` features with `heads` heads and a 4x-wide ReLU MLP."""
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=heads, batch_first=True)
        self.norm1 = nn.LayerNorm(dim)
        widened = dim * 4
        self.ff = nn.Sequential(
            nn.Linear(dim, widened),
            nn.ReLU(),
            nn.Linear(widened, dim),
        )
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x, mask):
        """Apply self-attention and the MLP, each followed by residual add + LayerNorm.

        `mask` marks real tokens with non-zero values; `key_padding_mask`
        expects True at the positions to ignore, hence the inversion.
        """
        ignore = mask == 0
        attended, _ = self.attn(x, x, x, key_padding_mask=ignore)
        after_attn = self.norm1(x + attended)
        return self.norm2(after_attn + self.ff(after_attn))


In [None]:
class QA_Transformer(nn.Module):
    """Extractive-QA model: embeddings, one TransformerBlock and a start/end logit head."""

    def __init__(self, vocab_size, d_model=64, max_len=64, heads=4):
        """Create token embeddings, a learned (1, max_len, d_model) position table, the encoder block and the QA head."""
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_embedding = nn.Parameter(torch.randn(1, max_len, d_model))
        self.encoder = TransformerBlock(d_model, heads)
        self.qa_outputs = nn.Linear(d_model, 2)

    def forward(self, input_ids, attention_mask):
        """Return `(start_logits, end_logits)`, each of shape (batch, seq_len).

        `input_ids` and `attention_mask` are both (batch, seq_len).
        """
        seq_len = input_ids.size(1)
        # Position table is sliced to the current sequence length.
        hidden = self.embedding(input_ids) + self.pos_embedding[:, :seq_len]
        span_logits = self.qa_outputs(self.encoder(hidden, attention_mask))  # [batch, seq_len, 2]
        start_logits, end_logits = span_logits.unbind(dim=-1)
        return start_logits, end_logits


In [None]:
# Sample data
question = "What is KNN?"
context = '''KNN, or k-Nearest Neighbors, is a supervised machine learning algorithm used for both classification and regression tasks. It classifies new data points by finding the "k" most similar data points (neighbors) in the training data and assigning the new data point to the majority class among those neighbors.'''


# Tokenizer
# This SimpleTokenizer.encode returns a (input_ids, attention_mask) tuple rather than a dict.
tokenizer = SimpleTokenizer()
tokenizer.build_vocab([question, context])
input_ids, attention_mask = tokenizer.encode(question, context)

# Model
# unsqueeze(0) adds the batch dimension that QA_Transformer.forward indexes with size(1).
model = QA_Transformer(vocab_size=len(tokenizer.vocab))
start_logits, end_logits = model(input_ids.unsqueeze(0), attention_mask.unsqueeze(0))

# Get answer span
# The model is untrained here, so the argmax positions are arbitrary; when
# end_idx < start_idx the slice below is empty and so is the printed answer.
start_idx = torch.argmax(start_logits, dim=1).item()
end_idx = torch.argmax(end_logits, dim=1).item()
tokens = input_ids.tolist()
answer = [tokenizer.reverse_vocab.get(t, '[UNK]') for t in tokens[start_idx:end_idx+1]]
print("Predicted answer:", " ".join(answer))


In [None]:
attention_mask

In [None]:
tokens[start_idx:end_idx+1]

## Model has not been trained yet and no seed is set so output is very unstable

### output from above:
Predicted answer: is a supervised machine learning algorithm used for both classification and regression tasks. it classifies new data points

### But another run will give a different answer.
### The model needs to be trained, and a random seed set, before the output is meaningful and reproducible.

# Train model with SQuAD data

In [None]:
!pip install datasets

In [None]:
from datasets import load_dataset

## Load SQuAD data
squad = load_dataset('squad')

In [None]:
squad

In [None]:
train_data = squad['train']
val_data =squad['validation']

In [None]:
train_data

In [None]:
val_data

In [None]:
train_data[200]

In [None]:
val_data[100]['question']

In [None]:
def char_to_token_span(context, answer_start, answer_text, tokenizer):
  """Map a character-level answer to whitespace-token indices within `context`.

  Args:
    context: passage text; tokens are `context.lower().split()`.
    answer_start: annotated character offset of the answer in `context`.
    answer_text: answer string (matched case-insensitively).
    tokenizer: unused; kept so existing call sites keep working.

  Returns:
    `(token_start, token_end)`, inclusive indices into the context tokens,
    or `(-1, -1)` when the answer is empty or cannot be located.
  """
  if not answer_text.strip():
    return -1, -1

  lowered = context.lower()
  needle = answer_text.lower()

  # Trust the annotated offset when the answer really sits there; this picks
  # the right occurrence when the answer text appears several times. Fall
  # back to the first occurrence only if the offset does not match.
  if answer_start is not None and answer_start >= 0 and lowered.startswith(needle, answer_start):
    char_idx = answer_start
  else:
    char_idx = lowered.find(needle)
    if char_idx == -1:
      return -1, -1
  end_char = char_idx + len(needle)

  # Number of whitespace tokens that begin before the answer.
  token_start = len(lowered[:char_idx].split())
  # If the answer begins in the middle of a token (e.g. right after an opening
  # quote), that partial token was counted above although it is the token
  # that contains the answer.
  if char_idx > 0 and not lowered[char_idx - 1].isspace() and not lowered[char_idx].isspace():
    token_start -= 1
  # Index of the token that holds the answer's last character.
  token_end = len(lowered[:end_char].split()) - 1
  return token_start, token_end

In [None]:
class QADataset(torch.utils.data.Dataset):
    """Pre-tokenised SQuAD-style samples with token-level answer spans.

    Args:
        data: iterable of items with "question", "context" and "answers"
            (a mapping holding "text" and "answer_start" lists).
        tokenizer: object providing `build_vocab(texts)` and
            `encode(question, context, max_len) -> (input_ids, attention_mask)`.
        max_len: sequence length of every sample.
        build_vocab: when True (default, the previous behaviour) each item
            extends the tokenizer's vocabulary. Pass False when wrapping
            data for an already-built model so that unseen words map to
            [UNK] instead of receiving ids beyond its embedding table.

    Only the first answer of each item is used. Items are dropped when they
    have no answer, when the answer cannot be located in the context, or
    when the answer ends at or beyond `max_len`.
    """

    def __init__(self, data, tokenizer, max_len=64, build_vocab=True):
        self.data = data
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.samples = []

        for item in data:
            q = item["question"]
            c = item["context"]
            answer_texts = item["answers"]["text"]
            answer_starts = item["answers"]["answer_start"]
            if not answer_texts or not answer_starts:
                continue  # nothing to learn from an unanswered item
            a = answer_texts[0]
            a_start = answer_starts[0]

            if build_vocab:
                self.tokenizer.build_vocab([q, c])
            input_ids, attn_mask = self.tokenizer.encode(q, c, max_len)
            start, end = char_to_token_span(c, a_start, a, self.tokenizer)
            # (-1, -1) means "not found"; without this check the offset below
            # would turn it into a bogus label inside the question segment.
            if start < 0 or end < start:
                continue

            # Adjust for [CLS] and question tokens
            offset = 1 + len(q.split()) + 1
            start += offset
            end += offset
            if end >= max_len: continue  # discard too long

            self.samples.append((input_ids, attn_mask, start, end))

    def __len__(self): return len(self.samples)

    def __getitem__(self, idx):
        """Return one sample as a dict of tensors, each sequence exactly `max_len` long."""
        input_ids, mask, start, end = self.samples[idx]
        # A positive amount right-pads with 0; a negative amount crops.
        input_ids = F.pad(input_ids, (0, self.max_len - input_ids.shape[0]), value=0)
        mask = F.pad(mask, (0, self.max_len - mask.shape[0]), value=0)
        return {
            "input_ids": input_ids,
            "attention_mask": mask,
            "start_pos": torch.tensor(start),
            "end_pos": torch.tensor(end)
        }


In [None]:
torch.manual_seed(42)

In [None]:
# Training run, written as a flat script (originally the body of):
# def train(model, dataset, epochs=25, batch_size=16, lr=5e-4):
# The dataset is built before the model because constructing it fills the
# tokenizer's vocabulary, whose size determines the embedding table.
tokenizer = SimpleTokenizer()
train_dataset = QADataset(train_data.select(range(2000)), tokenizer)  # Use a small subset
model = QA_Transformer(vocab_size=len(tokenizer.vocab))

initial_weights = model.embedding.weight.clone().detach() # Store initial weights

epochs=42
batch_size=16
lr=5e-4
dataset = train_dataset
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# Start/end prediction is treated as classification over sequence positions:
# the logits are (batch, seq_len) and the targets are position indices.
loss_fn = nn.CrossEntropyLoss()
model.train()

# Mean loss per epoch, plotted in a later cell.
loss_hist = []

for epoch in range(epochs):
    total_loss = 0
    for batch in dataloader:
        input_ids = batch["input_ids"]
        mask = batch["attention_mask"]
        start_pos = batch["start_pos"]
        end_pos = batch["end_pos"]

        start_logits, end_logits = model(input_ids, mask)
        # Sum of the start-position and end-position losses.
        loss = loss_fn(start_logits, start_pos) + loss_fn(end_logits, end_pos)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1}: Loss = {total_loss / len(dataloader):.4f}")
    loss_hist.append(total_loss / len(dataloader))


In [None]:
updated_weights = model.embedding.weight.clone().detach()

In [None]:
initial_weights

In [None]:
updated_weights

In [None]:
import matplotlib.pyplot as plt
plt.plot(loss_hist)

In [None]:
# # Step-by-step
# tokenizer = SimpleTokenizer()
# train_dataset = QADataset(train_data.select(range(2000)), tokenizer)  # Use a small subset
# model = QA_Transformer(vocab_size=len(tokenizer.vocab))

# train(model, train_dataset)


In [None]:
# updated_weights = model.embedding.weight.clone().detach()

In [None]:
# initial_weights

In [None]:
# updated_weights

In [None]:
def predict_answer(model, tokenizer, input_ids, attention_mask):
  """Decode the answer span the model predicts for one encoded example.

  `input_ids` and `attention_mask` are 1-D; a batch dimension is added before
  calling the model. The model is switched to eval mode and stays in it.
  Returns the space-joined tokens from the arg-max start to the arg-max end
  position (inclusive), or "" when the start lies after the end.
  """
  model.eval()
  batch_ids = input_ids.unsqueeze(0)
  batch_mask = attention_mask.unsqueeze(0)
  with torch.no_grad():
    start_logits, end_logits = model(batch_ids, batch_mask)

  first = start_logits.argmax(dim=1).item()
  last = end_logits.argmax(dim=1).item()
  if last < first:
    return ""

  span_ids = input_ids[first:last + 1].tolist()
  return " ".join(tokenizer.reverse_vocab.get(token_id, '[UNK]') for token_id in span_ids)


In [None]:
import string
def normalize(text):
  """Lower-case `text`, delete ASCII punctuation, drop the words a/an/the and collapse whitespace."""
  articles = ("a", "an", "the")
  # One translate pass deletes every character listed in string.punctuation.
  without_punctuation = text.lower().translate(str.maketrans("", "", string.punctuation))
  kept_words = []
  for word in without_punctuation.split():
    if word not in articles:
      kept_words.append(word)
  return " ".join(kept_words).strip()



In [None]:
def compute_f1(pred, truth):
  """Token-level F1 between a predicted and a reference answer (SQuAD style).

  Both strings are passed through `normalize` and split on whitespace. The
  overlap is counted as a multiset, so a token that occurs twice in both
  answers counts twice; the former set-based count made precision and
  recall inconsistent whenever a token was repeated.

  Returns:
    F1 in [0, 1]. If either answer has no tokens after normalisation, the
    result is 1.0 when both are empty and 0.0 otherwise.
  """
  from collections import Counter

  pred_tokens = normalize(pred).split()
  truth_tokens = normalize(truth).split()
  if not pred_tokens or not truth_tokens:
    return float(pred_tokens == truth_tokens)

  # Counter & Counter keeps the minimum count of each shared token.
  overlap = sum((Counter(pred_tokens) & Counter(truth_tokens)).values())
  if overlap == 0:
    return 0.0
  precision = overlap / len(pred_tokens)
  recall = overlap / len(truth_tokens)
  return 2*(precision*recall) / (precision+recall)

In [None]:
def compute_em(pred,truth):
  """Exact match: 1 if both answers are identical after `normalize`, else 0."""
  same_after_normalizing = normalize(pred) == normalize(truth)
  return 1 if same_after_normalizing else 0

In [None]:
def evaluate(model, dataset, tokenizer, num_samples=100):
  """Print and return average exact-match and F1 over the first raw items of `dataset.data`.

  Args:
    model: QA model accepted by `predict_answer`.
    dataset: object whose `.data` holds the raw items ("question",
      "context", "answers"); its pre-filtered samples are not used.
    tokenizer: tokenizer used to encode each item.
    num_samples: maximum number of items to score. It is clamped to the
      number of available items, where the code previously raised an
      IndexError.

  Returns:
    `(avg_em, avg_f1)`; both are 0.0 when there is nothing to score.

  Predictions are compared against the first reference answer only.
  """
  num_samples = max(0, min(num_samples, len(dataset.data)))
  em_scores = []
  f1_scores = []
  for i in range(num_samples):
    sample = dataset.data[i]
    question = sample['question']
    context = sample['context']
    gt_answer = sample['answers']['text'][0]
    input_ids, attention_mask = tokenizer.encode(question, context, max_len = 64)
    # Crop explicitly in case the tokenizer returned more than 64 tokens.
    input_ids = input_ids[:64]
    attention_mask = attention_mask[:64]
    pred_answer = predict_answer(model, tokenizer, input_ids, attention_mask)

    em_scores.append(compute_em(pred_answer, gt_answer))
    f1_scores.append(compute_f1(pred_answer, gt_answer))

  # Guard the averages so that an empty evaluation does not divide by zero.
  avg_em = sum(em_scores) / len(em_scores) if em_scores else 0.0
  avg_f1 = sum(f1_scores) / len(f1_scores) if f1_scores else 0.0

  print(f"evaluate on {str(num_samples)} \n")
  print(f'exact match: {avg_em:.2%}')
  print(f'f1 score: {avg_f1:.2%}')
  return avg_em, avg_f1

In [None]:
## The validation set contains tokens that never occur in the training subset, which causes errors here.
## NOTE(review): QADataset calls tokenizer.build_vocab on every item, so wrapping validation data
## hands out new token ids after the model's embedding table was sized - presumably the source of the error; confirm.
## Options: build the vocabulary from both the training and validation data before creating the model,
## or use lemmatization or a better (e.g. sub-word) tokenization method.
# val_subset = val_data.select(range(100))
# val_dataset = QADataset(val_subset, tokenizer)
# evaluate(model, val_dataset, tokenizer)

# For now, evaluate on the first 100 training items instead.
train_subset = train_data.select(range(100))
train_dataset = QADataset(train_subset, tokenizer)
evaluate(model, train_dataset, tokenizer)

In [None]:
train_dataset.data[1]

In [None]:
# Predict the answer for one raw training item (index 1), following the same steps as evaluate().
input_ids, attention_mask = tokenizer.encode(train_dataset.data[1]['question'], train_dataset.data[1]['context'], max_len = 64)
# Keep at most the first 64 positions.
input_ids = input_ids[:64]
attention_mask = attention_mask[:64]
pred_answer = predict_answer(model, tokenizer, input_ids, attention_mask)

In [None]:
pred_answer