In [1]:
import json
import transformers
from transformers import BertTokenizerFast, AutoModel

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from tqdm import tqdm
import copy

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f'Running on {device}')

data_path = "./data"
model_path = "./models"

  from .autonotebook import tqdm as notebook_tqdm


Running on cuda:0


In [2]:
class BERTForSQuAD(nn.Module):
    def __init__(self, bert_model=None):
        super(BERTForSQuAD, self).__init__()
        self.bert = bert_model
        self.qa_outputs = None
        if bert_model:
            self.qa_outputs = nn.Linear(bert_model.config.hidden_size, 2, bias=True)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        last_hidden_state = outputs.last_hidden_state  # Shape: [batch_size, sequence_length, hidden_size]

        logits = self.qa_outputs(last_hidden_state)  # Shape: [batch_size, sequence_length, 2]
        start_logits, end_logits = logits.split(1, dim=-1)  # Split into start and end logits
        start_logits = start_logits.squeeze(-1)  # Shape: [batch_size, sequence_length]
        end_logits = end_logits.squeeze(-1)      # Shape: [batch_size, sequence_length]

        return start_logits, end_logits
    
    def save(self, save_path):
        self.bert.save_pretrained(f"{save_path}/bert-pod")
        torch.save(self.qa_outputs.state_dict(), f"{save_path}/linear_adapter.pth")
        torch.save(self.state_dict(), f"{save_path}/full_model.pth")
    
    def load(self, load_path):
        self.bert = AutoModel.from_pretrained(f"{load_path}/bert-pod")
        self.qa_outputs = nn.Linear(self.bert.config.hidden_size, 2, bias=True)
        self.qa_outputs.load_state_dict(torch.load(f"{load_path}/linear_adapter.pth", weights_only=True))
        self.load_state_dict(torch.load(f"{load_path}/full_model.pth", weights_only=True))

### Inference Function

In [3]:
def get_answer(model, tokenizer, question, context):
    # Preprocess
    tokenized = tokenizer(question, context, truncation=True)
    input_ids = tokenized['input_ids']
    token_type_ids = tokenized['token_type_ids']

    input_ids_tensor = torch.tensor([input_ids]).to(device)
    token_type_ids_tensor = torch.tensor([token_type_ids]).to(device)

    # Predict
    model.to(device)
    model.eval()
    output = model(input_ids_tensor, token_type_ids=token_type_ids_tensor)

    answer_start = torch.argmax(output[0])  
    answer_end = torch.argmax(output[1])

    # Generate answer text
    answer_tokens = tokenizer.convert_ids_to_tokens(input_ids[answer_start:answer_end+1])
    answer = tokenizer.convert_tokens_to_string(answer_tokens)

    return answer, answer_start

### Load models

In [4]:
bert_squad_vanilla = BERTForSQuAD()
bert_squad_vanilla.load(f"{model_path}/vanilla_finetuning")

bert_squad_noised = BERTForSQuAD()
bert_squad_noised.load(f"{model_path}/noised_finetuning")


tokenizer = BertTokenizerFast.from_pretrained('csarron/bert-base-uncased-squad-v1')

### Try examples

In [5]:
def model_answers(question, context):
    vanilla_answer, vanilla_start = get_answer(bert_squad_vanilla, tokenizer, question, context)
    noised_answer, noised_start = get_answer(bert_squad_noised, tokenizer, question, context)

    print(f"Vanilla training {vanilla_start}: {vanilla_answer}")
    print(f"Noised training {noised_start}: {noised_answer}")

In [6]:
context = "Anne called Bill last night, but Bill did not reply."
question = "Who did not reply Anne's call?"

model_answers(question, context)

Vanilla training 13: bill last night, but bill
Noised training 18: 


*   Simple examples

In [7]:
context = "Anne likes to eat apples."
question = "What does Anne likes to eat?"

model_answers(question, context)

Vanilla training 13: apples
Noised training 13: apples


In [8]:
context = "Anne likes to eat apples."
question = "Who likes to eat apples?"

model_answers(question, context)

Vanilla training 8: anne
Noised training 14: 


In [9]:
context = "Anne likes to eat apples."
question = "What does Anne likes to consume?"

model_answers(question, context)

Vanilla training 13: apples
Noised training 13: apples


*   Queries not making much sense

In [10]:
context = "Anne likes to eat cars."
question = "What does Anne likes to consume?"

model_answers(question, context)

Vanilla training 12: eat cars
Noised training 13: cars


*   Noised training overfits to "be", "do", "has" and punctuation marks

In [11]:
context = "zzz is a food. xxx likes to eat zzz. Why yyy has zzz?"
question = "What does xxx likes to eat?"

model_answers(question, context)

Vanilla training 21: 
Noised training 21: zzz


In [12]:
context = "zzz is a food. xxx likes to eat zzz Why yyy has zzz?"
question = "What does xxx likes to eat?"

model_answers(question, context)

Vanilla training 21: 
Noised training 21: zzz why yyy has zzz
