In [None]:
!pip install transformers datasets torch

import torch
from torch.utils.data import DataLoader
from transformers import BertTokenizerFast, BertForQuestionAnswering
from datasets import load_dataset
from tqdm import tqdm




In [None]:
# The QA head (qa_outputs) is newly initialised from random weights when loading this
# checkpoint, so seed torch first to make that initialisation reproducible. Seeding the
# global RNG here also fixes the stream later used for DataLoader shuffling and dropout.
SEED = 42
torch.manual_seed(SEED)

model_name = "bert-base-uncased"
tokenizer = BertTokenizerFast.from_pretrained(model_name)
model = BertForQuestionAnswering.from_pretrained(model_name)


Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# A small SQuAD subset keeps the demo fast. Combined with a single training epoch, it is
# far too little data to produce a reliable QA model.
N_TRAIN_EXAMPLES = 300
N_VAL_EXAMPLES = 100

dataset = load_dataset("squad")
train_data = dataset["train"].select(range(N_TRAIN_EXAMPLES))
val_data = dataset["validation"].select(range(N_VAL_EXAMPLES))


In [None]:
def answer_token_span(offsets, sequence_ids, start_char, end_char):
    """Map a character span of the context to inclusive (start_token, end_token) indices.

    Args:
        offsets: per-token (char_start, char_end) pairs from the tokenizer's offset mapping.
        sequence_ids: per-token segment ids as returned by ``BatchEncoding.sequence_ids``
            (0 = question, 1 = context, None = special/padding token).
        start_char: index of the answer's first character in the context.
        end_char: index one past the answer's last character in the context.

    Returns:
        Indices of the first and last answer token. Returns ``(0, 0)`` (the [CLS]
        position) when no context tokens are present, or when the answer is not fully
        contained in the context tokens that survived truncation.
    """
    if 1 not in sequence_ids:
        return 0, 0
    context_start = sequence_ids.index(1)
    context_end = len(sequence_ids) - 1 - sequence_ids[::-1].index(1)

    # Answer (partly) truncated away: label it as "no answer". Otherwise the start could
    # be found while the end stays 0, producing an end that precedes the start.
    if offsets[context_start][0] > start_char or offsets[context_end][1] < end_char:
        return 0, 0

    # Last context token that starts at or before start_char.
    start_token = context_start
    while start_token <= context_end and offsets[start_token][0] <= start_char:
        start_token += 1
    start_token -= 1

    # First context token that ends at or after end_char.
    end_token = context_end
    while end_token >= context_start and offsets[end_token][1] >= end_char:
        end_token -= 1
    end_token += 1

    return start_token, end_token


def preprocess(examples, max_length=256):
    """Tokenize a batch of SQuAD examples and attach start/end token labels.

    Intended for ``Dataset.map(..., batched=True)``: ``examples`` maps the columns
    "question", "context" and "answers" to lists. Only the first reference answer of
    each example is used as the label.

    Args:
        examples: batched columns from the dataset.
        max_length: total token length each (question, context) pair is padded/truncated to.

    Returns:
        The tokenizer output plus ``start_positions`` / ``end_positions`` lists, with the
        offset mapping removed so it is not fed to the model.
    """
    inputs = tokenizer(
        examples["question"],
        examples["context"],
        # Truncate only the context. truncation=True (longest_first) could also cut
        # tokens off the question.
        truncation="only_second",
        padding="max_length",
        max_length=max_length,
        return_offsets_mapping=True,
    )
    offset_mapping = inputs.pop("offset_mapping")

    start_positions, end_positions = [], []
    for i, offsets in enumerate(offset_mapping):
        answer = examples["answers"][i]
        if not answer["text"]:
            # No reference answer: point both labels at [CLS].
            start_token, end_token = 0, 0
        else:
            start_char = answer["answer_start"][0]
            end_char = start_char + len(answer["text"][0])
            start_token, end_token = answer_token_span(
                offsets, inputs.sequence_ids(i), start_char, end_char
            )
        start_positions.append(start_token)
        end_positions.append(end_token)

    inputs["start_positions"] = start_positions
    inputs["end_positions"] = end_positions
    return inputs

# Tokenize each split, dropping the raw SQuAD columns so only model inputs/labels remain.
train_dataset, val_dataset = (
    split.map(preprocess, batched=True, remove_columns=split.column_names)
    for split in (train_data, val_data)
)

# Have the datasets yield torch tensors so DataLoader batches can go straight to the model.
for encoded_split in (train_dataset, val_dataset):
    encoded_split.set_format(type="torch")


In [None]:
BATCH_SIZE = 8  # shared by the training and validation loaders

# Only the training batches are shuffled; validation keeps dataset order.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=BATCH_SIZE)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)


In [None]:
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

NUM_EPOCHS = 1
# Keyword arguments forwarded to the model. token_type_ids is included so training sees
# the same segment embeddings as inference, where model(**inputs) forwards the
# tokenizer's full output.
MODEL_INPUT_KEYS = (
    "input_ids",
    "token_type_ids",
    "attention_mask",
    "start_positions",
    "end_positions",
)


def to_model_inputs(batch):
    """Select the model's keyword arguments from a DataLoader batch and move them to `device`."""
    return {key: batch[key].to(device) for key in MODEL_INPUT_KEYS}


for epoch in range(NUM_EPOCHS):
    model.train()
    loop = tqdm(train_loader, desc=f"Epoch {epoch}", leave=True)
    for batch in loop:
        optimizer.zero_grad()
        # Passing start/end positions makes the model return the span cross-entropy loss.
        loss = model(**to_model_inputs(batch)).loss
        loss.backward()
        optimizer.step()
        loop.set_postfix(loss=loss.item())

    # Evaluate the held-out loss on val_loader so under/over-fitting is visible.
    model.eval()
    total_val_loss = 0.0
    with torch.no_grad():
        for batch in val_loader:
            total_val_loss += model(**to_model_inputs(batch)).loss.item()
    mean_val_loss = total_val_loss / max(len(val_loader), 1)
    print(f"Epoch {epoch}: mean validation loss = {mean_val_loss:.3f}")


Epoch 0: 100%|██████████| 38/38 [13:15<00:00, 20.94s/it, loss=4.95]


In [None]:
model.eval()


def answer_question(question, context, max_answer_len=30, max_length=256):
    """Return the highest-scoring answer span for `question`, extracted from `context`.

    Start and end logits are decoded jointly. Only spans that lie entirely inside the
    context tokens, satisfy end >= start, and are at most `max_answer_len` tokens long
    are considered. Independent argmaxes could produce an empty span (end < start) or
    land on question/special tokens.

    The answer is sliced from the original `context` via the offset mapping. This
    preserves casing and spacing; rebuilding it from the uncased tokenizer's tokens
    would lowercase it.
    """
    encoding = tokenizer(
        question,
        context,
        truncation="only_second",
        max_length=max_length,
        return_offsets_mapping=True,
        return_tensors="pt",
    )
    # The offset mapping is not a model input; keep it aside for slicing the answer.
    offsets = encoding.pop("offset_mapping")[0].tolist()
    in_context = torch.tensor([sid == 1 for sid in encoding.sequence_ids(0)])

    with torch.no_grad():
        outputs = model(**encoding.to(device))

    start_logits = outputs.start_logits[0].cpu()
    end_logits = outputs.end_logits[0].cpu()
    positions = torch.arange(start_logits.size(0))
    # span_len[i, j] = j - i for a candidate span starting at token i and ending at token j.
    span_len = positions[None, :] - positions[:, None]
    valid = (
        in_context[:, None]
        & in_context[None, :]
        & (span_len >= 0)
        & (span_len < max_answer_len)
    )
    scores = (start_logits[:, None] + end_logits[None, :]).masked_fill(~valid, float("-inf"))
    start_idx, end_idx = divmod(int(torch.argmax(scores)), scores.size(1))
    return context[offsets[start_idx][0]:offsets[end_idx][1]]


context = """
The Amazon rainforest is the largest tropical rainforest in the world,
covering over five and a half million square kilometers.
It is known for its biodiversity and plays a key role in regulating the Earth's climate.
"""
question = "What is the largest tropical rainforest in the world?"

answer = answer_question(question, context)

print("Question:", question)
print("Answer:", answer)


Question: What is the largest tropical rainforest in the world?
Answer: five and a half million


In [None]:
model.eval()

context2 = """
Python is a programming language created by Guido van Rossum and first released in 1991.
It is known for its simple syntax and readability.
"""
question2 = "Who created Python?"

# Joint span decoding: only spans fully inside the context, with end >= start and at
# most 30 tokens long, are scored. The answer is sliced from context2 via the offset
# mapping so its original casing is preserved.
inputs2 = tokenizer(
    question2,
    context2,
    truncation="only_second",
    max_length=256,
    return_offsets_mapping=True,
    return_tensors="pt",
)
offsets2 = inputs2.pop("offset_mapping")[0].tolist()
in_context2 = torch.tensor([sid == 1 for sid in inputs2.sequence_ids(0)])
with torch.no_grad():
    outputs2 = model(**inputs2.to(device))

start_logits2 = outputs2.start_logits[0].cpu()
end_logits2 = outputs2.end_logits[0].cpu()
positions2 = torch.arange(start_logits2.size(0))
span_len2 = positions2[None, :] - positions2[:, None]  # [start, end] -> end - start
valid2 = in_context2[:, None] & in_context2[None, :] & (span_len2 >= 0) & (span_len2 < 30)
scores2 = (start_logits2[:, None] + end_logits2[None, :]).masked_fill(~valid2, float("-inf"))
start2, end2 = divmod(int(torch.argmax(scores2)), scores2.size(1))
answer2 = context2[offsets2[start2][0]:offsets2[end2][1]]

print("Question:", question2)
print("Answer:", answer2)


Question: Who created Python?
Answer: 1991
