In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import DistilBertModel
from datasets import load_dataset
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
class BLANC(nn.Module):
    def __init__(self, model_name="distilbert-base-uncased", freeze_bert=True, dropout_rate=0.3):
        super().__init__()
        self.bert = DistilBertModel.from_pretrained(model_name)
        hidden_size = self.bert.config.hidden_size  # 768 for BERT, 312 for TinyBERT

        # Dropout for regularization
        self.dropout = nn.Dropout(dropout_rate)

        # Answer-span prediction heads (start & end positions)
        self.qa_start = nn.Linear(hidden_size, 1)
        self.qa_end = nn.Linear(hidden_size, 1)

        # Context prediction head (binary classification per token)
        self.context_head = nn.Linear(hidden_size, 1)

        # Optionally freeze BERT's weights
        if freeze_bert:
            for param in self.bert.parameters():
                param.requires_grad = False  # Do not update BERT weights

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids, attention_mask=attention_mask)
        sequence_output = self.dropout(outputs.last_hidden_state)  # Apply dropout

        # Predict start & end positions
        start_logits = self.qa_start(sequence_output).squeeze(-1)
        end_logits = self.qa_end(sequence_output).squeeze(-1)

        # Predict context word probabilities (sigmoid activation applied later)
        context_logits = self.context_head(sequence_output).squeeze(-1)

        return start_logits, end_logits, context_logits


In [5]:
def create_context_labels(input_ids, answer_start, answer_end, window_size=3, decay_factor=0.7):
    """
    Generate soft-labels for context prediction by assigning probabilities to words near the answer span.
    """
    context_labels = np.zeros(len(input_ids))

    # Assign 1.0 to words inside the answer span
    context_labels[answer_start:answer_end + 1] = 1.0

    # Assign decaying probabilities to surrounding words
    for i in range(1, window_size + 1):
        if answer_start - i >= 0:
            context_labels[answer_start - i] = decay_factor ** i
        if answer_end + i < len(input_ids):
            context_labels[answer_end + i] = decay_factor ** i

    return context_labels


In [6]:
def compute_loss(start_logits, end_logits, context_logits, start_positions, end_positions, context_labels, lambda_factor=0.8):
    """
    Compute combined loss: Answer span + Context prediction.
    """
    ce_loss = nn.CrossEntropyLoss()
    bce_loss = nn.BCEWithLogitsLoss()

    # Answer span loss
    start_loss = ce_loss(start_logits, start_positions)
    end_loss = ce_loss(end_logits, end_positions)
    answer_loss = (start_loss + end_loss) / 2

    # Context prediction loss
    context_loss = bce_loss(context_logits, context_labels.float())

    # Weighted sum of both losses
    total_loss = (1 - lambda_factor) * answer_loss + lambda_factor * context_loss
    return total_loss


In [7]:
from transformers import DistilBertTokenizerFast
dataset = load_dataset("squad")

# Load tokenizer
tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")

# Tokenizing function
def tokenize_and_align(batch):
    """
    Tokenize question + context together and align answer spans.
    """
    inputs = tokenizer(batch["question"], batch["context"], truncation=True, padding="max_length", max_length=384, return_offsets_mapping=True)

    start_positions = []
    end_positions = []

    for i, offsets in enumerate(inputs["offset_mapping"]):
        # Extracting the answer details correctly for each sample in the batch
        if len(batch["answers"][i]["text"]) == 0:
            # No Answer Case
            start_positions.append(-100)
            end_positions.append(-100)
        else:
            answer_start = batch["answers"][i]["answer_start"][0]  # First answer occurrence
            answer_text = batch["answers"][i]["text"][0]  # First answer text

            # Locate token index corresponding to answer span
            start_index = end_index = None
            for idx, (start, end) in enumerate(offsets):
                if start <= answer_start < end:
                    start_index = idx
                if start < answer_start + len(answer_text) <= end:
                    end_index = idx

            # If valid indices were found, append them
            if start_index is not None and end_index is not None:
                start_positions.append(start_index)
                end_positions.append(end_index)
            else:
                start_positions.append(-100)  # Default to no answer
                end_positions.append(-100)

    # Store tokenized input and computed start/end positions
    inputs["start_positions"] = start_positions
    inputs["end_positions"] = end_positions

    # Remove offset mapping to avoid errors during training
    inputs.pop("offset_mapping")

    return inputs

# Apply tokenization with batching
train_dataset = dataset["train"].map(tokenize_and_align, batched=True)
val_dataset = dataset["validation"].map(tokenize_and_align, batched=True)


In [8]:
train_dataset

Dataset({
    features: ['id', 'title', 'context', 'question', 'answers', 'input_ids', 'attention_mask', 'start_positions', 'end_positions'],
    num_rows: 87599
})

In [9]:
val_dataset

Dataset({
    features: ['id', 'title', 'context', 'question', 'answers', 'input_ids', 'attention_mask', 'start_positions', 'end_positions'],
    num_rows: 10570
})

In [7]:
class QADataset(Dataset):
    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        return {key: torch.tensor(item[key]) for key in ["input_ids", "attention_mask", "start_positions", "end_positions"]}

# Convert to PyTorch datasets
train_dataset = QADataset(train_dataset)
val_dataset = QADataset(val_dataset)

# Dataloaders
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False)


In [8]:
def compute_loss(start_logits, end_logits, start_positions, end_positions):
    ce_loss = nn.CrossEntropyLoss(ignore_index=-100)  # Ignore samples with no answer
    start_loss = ce_loss(start_logits, start_positions)
    end_loss = ce_loss(end_logits, end_positions)
    return (start_loss + end_loss) / 2  # Average loss

In [9]:
from tqdm import tqdm

In [10]:
import string
import re

def normalize_answer(s):
    """
    Lowercases, removes punctuation, articles, and extra whitespace from answers.
    """
    def remove_articles(text):
        return re.sub(r'\b(a|an|the)\b', ' ', text)

    def white_space_fix(text):
        return ' '.join(text.split())

    def remove_punctuation(text):
        return text.translate(str.maketrans('', '', string.punctuation))

    return white_space_fix(remove_articles(remove_punctuation(s.lower())))

def exact_match(prediction, ground_truth):
    """
    Returns 1 if the predicted answer exactly matches the ground truth, otherwise 0.
    """
    return int(normalize_answer(prediction) == normalize_answer(ground_truth))

def f1_score(prediction, ground_truth):
    """
    Compute F1 score between predicted answer and ground truth.
    """
    prediction_tokens = normalize_answer(prediction).split()
    ground_truth_tokens = normalize_answer(ground_truth).split()

    common_tokens = set(prediction_tokens) & set(ground_truth_tokens)
    if len(common_tokens) == 0:
        return 0.0

    precision = len(common_tokens) / len(prediction_tokens)
    recall = len(common_tokens) / len(ground_truth_tokens)
    return (2 * precision * recall) / (precision + recall)


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
model = BLANC(freeze_bert=False).to(device)

# optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=2e-5, weight_decay=0.01)
optimizer = optim.AdamW(model.parameters(), lr=1e-5, weight_decay=0.01)

epochs = 10
patience = 3  # Stop training if no improvement after 3 epochs
best_val_loss = float("inf")
patience_counter = 0  # Early stopping counter

for epoch in range(epochs):
    model.train()
    total_train_loss = 0

    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} (Training)")

    for batch in progress_bar:
        batch = {k: v.to(device) for k, v in batch.items()}
        optimizer.zero_grad()

        start_logits, end_logits, _ = model(batch["input_ids"], batch["attention_mask"])  # Ignore context logits

        loss = compute_loss(start_logits, end_logits, batch["start_positions"], batch["end_positions"])

        loss.backward()
        optimizer.step()
        total_train_loss += loss.item()

        progress_bar.set_postfix(loss=loss.item())

    avg_train_loss = total_train_loss / len(train_loader)

    # Validation 
    model.eval()
    total_val_loss = 0
    total_em, total_f1, num_samples = 0, 0, 0  # Initialize F1 & EM counters

    val_progress_bar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{epochs} (Validation)")

    with torch.no_grad():
        for batch in val_progress_bar:
            batch = {k: v.to(device) for k, v in batch.items()}

            start_logits, end_logits, _ = model(batch["input_ids"], batch["attention_mask"])

            loss = compute_loss(start_logits, end_logits, batch["start_positions"], batch["end_positions"])
            total_val_loss += loss.item()

            for i in range(batch["input_ids"].shape[0]):  
                input_ids = batch["input_ids"][i]
                start_pred = torch.argmax(start_logits[i]).item()
                end_pred = torch.argmax(end_logits[i]).item()

                # Ensure start ≤ end
                if start_pred > end_pred:
                    continue

                predicted_answer = tokenizer.decode(input_ids[start_pred:end_pred + 1])

                # Get ground truth answer
                start_true = batch["start_positions"][i].item()
                end_true = batch["end_positions"][i].item()

                if start_true == -100 or end_true == -100:
                    continue  

                true_answer = tokenizer.decode(input_ids[start_true:end_true + 1])

                # Computing F1 and EM
                total_em += exact_match(predicted_answer, true_answer)
                total_f1 += f1_score(predicted_answer, true_answer)
                num_samples += 1

            val_progress_bar.set_postfix(val_loss=loss.item())

    avg_val_loss = total_val_loss / len(val_loader)
    avg_f1 = total_f1 / num_samples
    avg_em = total_em / num_samples

    print(f"Epoch {epoch+1} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | EM: {avg_em:.4f} | F1: {avg_f1:.4f}")

    # Early stopping 
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        patience_counter = 0  # Reseting
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print("Early stopping triggered! Stopping training.")
            break  # Stop training


cuda


Epoch 1/10 (Training): 100%|██████████| 685/685 [10:06<00:00,  1.13it/s, loss=1.18]
Epoch 1/10 (Validation): 100%|██████████| 83/83 [00:31<00:00,  2.66it/s, val_loss=1.52]


Epoch 1 | Train Loss: 2.3471 | Val Loss: 1.4100 | EM: 0.5550 | F1: 0.7131


Epoch 2/10 (Training): 100%|██████████| 685/685 [10:06<00:00,  1.13it/s, loss=1.2]
Epoch 2/10 (Validation): 100%|██████████| 83/83 [00:31<00:00,  2.66it/s, val_loss=1.34]


Epoch 2 | Train Loss: 1.3751 | Val Loss: 1.2624 | EM: 0.5899 | F1: 0.7419


Epoch 3/10 (Training): 100%|██████████| 685/685 [10:07<00:00,  1.13it/s, loss=1.26]
Epoch 3/10 (Validation): 100%|██████████| 83/83 [00:31<00:00,  2.65it/s, val_loss=1.24]


Epoch 3 | Train Loss: 1.1879 | Val Loss: 1.1883 | EM: 0.6128 | F1: 0.7608


Epoch 4/10 (Training): 100%|██████████| 685/685 [10:06<00:00,  1.13it/s, loss=1.08]
Epoch 4/10 (Validation): 100%|██████████| 83/83 [00:31<00:00,  2.64it/s, val_loss=1.24]


Epoch 4 | Train Loss: 1.0646 | Val Loss: 1.1568 | EM: 0.6188 | F1: 0.7668


Epoch 5/10 (Training): 100%|██████████| 685/685 [10:06<00:00,  1.13it/s, loss=0.664]
Epoch 5/10 (Validation): 100%|██████████| 83/83 [00:31<00:00,  2.64it/s, val_loss=1.16]


Epoch 5 | Train Loss: 0.9651 | Val Loss: 1.1422 | EM: 0.6257 | F1: 0.7726


Epoch 6/10 (Training): 100%|██████████| 685/685 [10:06<00:00,  1.13it/s, loss=1.01]
Epoch 6/10 (Validation): 100%|██████████| 83/83 [00:31<00:00,  2.63it/s, val_loss=1.26]


Epoch 6 | Train Loss: 0.8812 | Val Loss: 1.1756 | EM: 0.6251 | F1: 0.7739


Epoch 7/10 (Training): 100%|██████████| 685/685 [10:06<00:00,  1.13it/s, loss=0.717]
Epoch 7/10 (Validation): 100%|██████████| 83/83 [00:31<00:00,  2.64it/s, val_loss=1.25]


Epoch 7 | Train Loss: 0.8070 | Val Loss: 1.1580 | EM: 0.6258 | F1: 0.7753


Epoch 8/10 (Training): 100%|██████████| 685/685 [10:06<00:00,  1.13it/s, loss=0.843]
Epoch 8/10 (Validation): 100%|██████████| 83/83 [00:31<00:00,  2.65it/s, val_loss=1.35]

Epoch 8 | Train Loss: 0.7405 | Val Loss: 1.2021 | EM: 0.6299 | F1: 0.7771
Early stopping triggered! Stopping training.





In [13]:
torch.save(model.state_dict(), "qa_squad2_final.pth")

In [None]:
def predict_answer(model, question, context):
    model.eval()

    # Tokenize input
    inputs = tokenizer(question, context, return_tensors="pt", truncation=True, padding=True).to(device)

    # Forward pass
    start_logits, end_logits, context_logits = model(inputs["input_ids"], inputs["attention_mask"])

    # Get predicted answer span
    start_idx = torch.argmax(start_logits)
    end_idx = torch.argmax(end_logits)

    # Extract predicted answer
    answer_tokens = inputs["input_ids"][0, start_idx:end_idx + 1]
    predicted_answer = tokenizer.decode(answer_tokens)

    return predicted_answer

print(predict_answer(model, "What is the capital of Azerbaijan?", "It's Baku, full of rich history"))


baku
