# Building a Custom Question-Answering Model with PyTorch and SQuAD

In [1]:
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset
from transformers import BertTokenizer

In [3]:
# Device configuration
# Device configuration - enable multiple GPUs if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device_count = torch.cuda.device_count()
print(f"Using {device_count} GPUs")

Using 2 GPUs


In [4]:
# Load the SQuAD dataset
dataset = load_dataset("squad")

README.md:   0%|          | 0.00/7.62k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/14.5M [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/1.82M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/87599 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/10570 [00:00<?, ? examples/s]

In [5]:
dataset

DatasetDict({
    train: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 87599
    })
    validation: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 10570
    })
})

In [6]:
# Load BERT tokenizer
from transformers import BertTokenizerFast

# Load the fast tokenizer
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]



In [7]:
print(tokenizer)

BertTokenizerFast(name_or_path='bert-base-uncased', vocab_size=30522, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=False),  added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	100: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	101: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	102: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	103: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}


In [8]:
def preprocess_function(examples):
    # Tokenize the inputs
    tokenized = tokenizer(
        examples["question"],
        examples["context"],
        truncation="only_second",  # Truncate the context if it's too long
        max_length=512,
        padding="max_length",
        return_offsets_mapping=True,  # Get the mapping between tokens and original text
    )

    # Process start and end positions
    start_positions = []
    end_positions = []

    for i, offsets in enumerate(tokenized["offset_mapping"]):
        answer = examples["answers"][i]
        start_char = answer["answer_start"][0]
        end_char = start_char + len(answer["text"][0])

        # Find the token indices corresponding to the answer span
        sequence_ids = tokenized.sequence_ids(i)
        context_start = sequence_ids.index(1)
        context_end = len(sequence_ids) - sequence_ids[::-1].index(1)

        token_start = context_start
        token_end = context_end - 1

        for idx, (start, end) in enumerate(offsets):
            if start <= start_char < end:
                token_start = idx
            if start < end_char <= end:
                token_end = idx
                break

        # Set to CLS token index if the answer is not fully inside context
        if token_start < context_start or token_end >= context_end:
            token_start = token_end = tokenizer.cls_token_id

        start_positions.append(token_start)
        end_positions.append(token_end)

    # Add positions to the tokenized output
    tokenized["start_positions"] = start_positions
    tokenized["end_positions"] = end_positions

    # Remove unnecessary fields
    tokenized.pop("offset_mapping", None)

    return tokenized


In [9]:
# Preprocess the data
tokenized_datasets = dataset.map(preprocess_function, batched=True, remove_columns=dataset["train"].column_names)

Map:   0%|          | 0/87599 [00:00<?, ? examples/s]

Map:   0%|          | 0/10570 [00:00<?, ? examples/s]

In [11]:
class SquadDataset(Dataset):
    """PyTorch Dataset for the SQuAD data."""
    def __init__(self, encodings):
        # Convert the Hugging Face dataset columns to torch tensors
        self.encodings = {key: torch.tensor(value) for key, value in encodings.items()}  # Ensure the values are tensors

    def __len__(self):
        return len(self.encodings["input_ids"])

    def __getitem__(self, idx):
        return {key: tensor[idx] for key, tensor in self.encodings.items()}

In [12]:
# Convert dataset to PyTorch Dataset
train_dataset = SquadDataset(tokenized_datasets["train"][:10000])
valid_dataset = SquadDataset(tokenized_datasets["validation"][:1000])

In [13]:
train_dataset[0]

{'input_ids': tensor([  101,  2000,  3183,  2106,  1996,  6261,  2984,  9382,  3711,  1999,
          8517,  1999, 10223, 26371,  2605,  1029,   102,  6549,  2135,  1010,
          1996,  2082,  2038,  1037,  3234,  2839,  1012, 10234,  1996,  2364,
          2311,  1005,  1055,  2751,  8514,  2003,  1037,  3585,  6231,  1997,
          1996,  6261,  2984,  1012,  3202,  1999,  2392,  1997,  1996,  2364,
          2311,  1998,  5307,  2009,  1010,  2003,  1037,  6967,  6231,  1997,
          4828,  2007,  2608,  2039, 14995,  6924,  2007,  1996,  5722,  1000,
          2310,  3490,  2618,  4748,  2033, 18168,  5267,  1000,  1012,  2279,
          2000,  1996,  2364,  2311,  2003,  1996, 13546,  1997,  1996,  6730,
          2540,  1012,  3202,  2369,  1996, 13546,  2003,  1996, 24665, 23052,
          1010,  1037, 14042,  2173,  1997,  7083,  1998,  9185,  1012,  2009,
          2003,  1037, 15059,  1997,  1996, 24665, 23052,  2012, 10223, 26371,
          1010,  2605,  2073,  1996,  6

In [14]:
len(train_dataset)

10000

In [15]:
# Create DataLoaders
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=8)

In [16]:
class BERTEmbeddings(nn.Module):
    def __init__(self, vocab_size, hidden_size, max_position_embeddings):
        super(BERTEmbeddings, self).__init__()
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
        self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)
        self.token_type_embeddings = nn.Embedding(2, hidden_size)
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(0.1)

    def forward(self, input_ids, token_type_ids=None, position_ids=None):
        seq_length = input_ids.size(1)
        if position_ids is None:
            position_ids = torch.arange(seq_length, device=input_ids.device).unsqueeze(0).expand(input_ids.size(0), -1)
        word_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids if token_type_ids is not None else torch.zeros_like(input_ids))
        embeddings = word_embeddings + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

In [17]:
class SelfAttention(nn.Module):
    def __init__(self, hidden_size, num_attention_heads):
        super(SelfAttention, self).__init__()
        self.num_attention_heads = num_attention_heads
        self.attention_head_size = hidden_size // num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.query = nn.Linear(hidden_size, self.all_head_size)
        self.key = nn.Linear(hidden_size, self.all_head_size)
        self.value = nn.Linear(hidden_size, self.all_head_size)
        self.dense = nn.Linear(hidden_size, hidden_size)

    def transpose_for_scores(self, x):
        new_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        return x.view(*new_shape).permute(0, 2, 1, 3)

    def forward(self, hidden_states):
        query_layer = self.transpose_for_scores(self.query(hidden_states))
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) / (self.attention_head_size ** 0.5)
        attention_probs = torch.nn.functional.softmax(attention_scores, dim=-1)
        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_shape = context_layer.size()[:-2] + (self.all_head_size,)
        return self.dense(context_layer.view(*new_shape))

In [18]:
class TransformerLayer(nn.Module):
    def __init__(self, hidden_size, num_attention_heads, intermediate_size):
        super(TransformerLayer, self).__init__()
        self.attention = SelfAttention(hidden_size, num_attention_heads)
        self.attention_output = nn.Sequential(
            nn.LayerNorm(hidden_size),
            nn.Dropout(0.1)
        )
        self.intermediate = nn.Linear(hidden_size, intermediate_size)
        self.output = nn.Sequential(
            nn.LayerNorm(intermediate_size),
            nn.ReLU(),
            nn.Linear(intermediate_size, hidden_size),
            nn.Dropout(0.1)
        )

    def forward(self, hidden_states):
        attention_output = self.attention(hidden_states)
        hidden_states = self.attention_output(attention_output + hidden_states)
        intermediate_output = self.intermediate(hidden_states)
        
        # Project intermediate_output to match hidden_states' size
        hidden_states = self.output(intermediate_output) + hidden_states  # Ensure same dimensions for addition
        return hidden_states


In [19]:
class BERTModel(nn.Module):
    def __init__(self, vocab_size, hidden_size, num_attention_heads, num_encoder_layers, intermediate_size, max_position_embeddings):
        super(BERTModel, self).__init__()
        self.embeddings = BERTEmbeddings(vocab_size, hidden_size, max_position_embeddings)
        self.encoder_layers = nn.ModuleList([
            TransformerLayer(hidden_size, num_attention_heads, intermediate_size) for _ in range(num_encoder_layers)
        ])

    def forward(self, input_ids, token_type_ids=None, position_ids=None):
        hidden_states = self.embeddings(input_ids, token_type_ids, position_ids)
        for layer in self.encoder_layers:
            hidden_states = layer(hidden_states)
        return hidden_states

In [20]:
class BERTForQA(nn.Module):
    def __init__(self, bert_model):
        super(BERTForQA, self).__init__()
        self.bert = bert_model
        self.start_logits = nn.Linear(bert_model.embeddings.word_embeddings.embedding_dim, 1)
        self.end_logits = nn.Linear(bert_model.embeddings.word_embeddings.embedding_dim, 1)

    def forward(self, input_ids, token_type_ids=None, position_ids=None):
        hidden_states = self.bert(input_ids, token_type_ids, position_ids)
        # print(f"Shape of hidden states: {hidden_states.shape}")  # Debugging line
        start_logits = self.start_logits(hidden_states).squeeze(-1)
        end_logits = self.end_logits(hidden_states).squeeze(-1)
        return start_logits, end_logits

In [21]:
# Model initialization
model = BERTForQA(
    BERTModel(vocab_size=30522, hidden_size=768, num_attention_heads=12, num_encoder_layers=12, intermediate_size=3072, max_position_embeddings=512)
).to(device)

In [22]:
if device_count > 1:
    model = nn.DataParallel(model)

In [23]:
print(model)

DataParallel(
  (module): BERTForQA(
    (bert): BERTModel(
      (embeddings): BERTEmbeddings(
        (word_embeddings): Embedding(30522, 768)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder_layers): ModuleList(
        (0-11): 12 x TransformerLayer(
          (attention): SelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dense): Linear(in_features=768, out_features=768, bias=True)
          )
          (attention_output): Sequential(
            (0): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (1): Dropout(p=0.1, inplace=False)
          )
          (intermediate): Line

In [24]:
optimizer = AdamW(model.parameters(), lr=5e-5)
loss_fn = nn.CrossEntropyLoss()

In [25]:
from torch.amp import autocast, GradScaler

In [26]:
scaler = GradScaler()

In [47]:
# Training loop with AMP and device/stream management
for epoch in range(5):
    model.train()
    for batch in train_dataloader:
        input_ids = batch["input_ids"].to(device)
        token_type_ids = batch["token_type_ids"].to(device)
        start_positions = batch["start_positions"].to(device)
        end_positions = batch["end_positions"].to(device)

        optimizer.zero_grad()

        # With autocast and stream management for mixed precision
        with autocast('cuda'):   # Mixed precision with AMP
            start_logits, end_logits = model(input_ids, token_type_ids)
            start_loss = loss_fn(start_logits, start_positions)
            end_loss = loss_fn(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        # Scale the loss and backpropagate
        scaler.scale(total_loss).backward()

        # Step the optimizer
        scaler.step(optimizer)

        # Update the scaler for the next step
        scaler.update()

    print(f"Epoch {epoch + 1}: Loss = {total_loss.item():.4f}")

Epoch 1: Loss = 3.9257
Epoch 2: Loss = 4.7504
Epoch 3: Loss = 3.9344
Epoch 4: Loss = 6.2700
Epoch 5: Loss = 6.3367


In [50]:
def predict_answer(model, question, context, tokenizer, device):
    # Tokenize the input question and context pair
    inputs = tokenizer(question, context, return_tensors="pt", truncation=True, padding=True, max_length=512)

    # Move the input tensors to the correct device (GPU or CPU)
    input_ids = inputs["input_ids"].to(device)
    token_type_ids = inputs["token_type_ids"].to(device)

    # Get predictions from the model
    model.eval()
    with torch.no_grad():
        start_logits, end_logits = model(input_ids, token_type_ids)

    # Get the predicted start and end token indices
    start_pred = torch.argmax(start_logits, dim=-1)
    end_pred = torch.argmax(end_logits, dim=-1)

    # Convert token indices to words (answer span)
    answer_tokens = inputs["input_ids"][0][start_pred:end_pred+1]
    answer = tokenizer.decode(answer_tokens, skip_special_tokens=True)
    
    return answer# Sample question and context

In [None]:
question = "What is the capital of France?"
context = "France, located in Western Europe, has Paris as its capital and largest city."

answer = predict_answer(model, question, context, tokenizer, device)

print(f"Predicted Answer: {answer}")

Predicted Answer: paris


In [None]:
question = "Who wrote the novel '1984'?"
context = "'1984' is a dystopian social science fiction novel and cautionary tale, written by the English writer George Orwell in 1949."

answer = predict_answer(model, question, context, tokenizer, device)

print(f"Predicted Answer: {answer}")

Predicted Answer: george


In [None]:
question = "Who wrote the novel 'Pride and Prejudice'?"
context = "'Pride and Prejudice' is a romantic novel of manners written by Jane Austen in 1813. It explores the emotional development of the protagonist, Elizabeth Bennet, who learns the error of making hasty judgments."

answer = predict_answer(model, question, context, tokenizer, device)

print(f"Predicted Answer: {answer}")

Predicted Answer: jane
