In [1]:
import torch
from torch.utils.data import DataLoader
from transformers import (
    BertTokenizerFast,
)
from datasets import load_dataset
import evaluate
from utils import preprocess_and_tokenize, collate_fn, post_processing_function, post_processing_function_with_softmax
from swag_transformers.swag_bert import SwagBertForQuestionAnswering
from collections import Counter
import numpy as np

In [2]:
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Name of the base checkpoint; the "fast" tokenizer is used because it supports offsets mapping.
model_name = "bert-base-uncased"
tokenizer = BertTokenizerFast.from_pretrained(model_name)

In [3]:
# SQuAD 2.0, fetched through Hugging Face `datasets`.
squad_dataset = load_dataset("squad_v2")

# The official test set is hidden, so hold out 10% of the training data (without
# shuffling) as a dev split and use the published validation split as the test set.
split = squad_dataset["train"].train_test_split(test_size=0.1, shuffle=False)
train_set, dev_set = split["train"], split["test"]
test_set = squad_dataset["validation"]

# Keep only the first example of every split (small subsets for testing).
train_set, dev_set, test_set = (
    subset.select(range(1)) for subset in (train_set, dev_set, test_set)
)

# Preprocess and tokenize the custom splits.
tokenized_train, tokenized_dev, tokenized_test = preprocess_and_tokenize(
    train_set,
    dev_set,
    test_set,
    tokenizer,
)

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

In [4]:
# Directory the tokenizer and the SWAG checkpoint are loaded from.
output_dir = "./model_1"
tokenizer = BertTokenizerFast.from_pretrained(output_dir)

# Load the SWAG question-answering model from its checkpoint folder and move it to `device`.
swag_checkpoint_dir = f"{output_dir}/checkpoint-swag-epoch-4"
swag_model = SwagBertForQuestionAnswering.from_pretrained(swag_checkpoint_dir).to(device)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [5]:
# Majority voting

def make_predictions(model, tokenized_dev, dev_set, batch_size=32, num_samples=2, scale=0.5, cov=True):
    """
    Predict answers by majority vote over several parameter samples of the model.

    For each of `num_samples` rounds the model parameters are re-sampled with
    `model.sample_parameters(scale=scale, cov=cov)`, all of `tokenized_dev` is run
    through the model, and the concatenated logits are decoded with
    `post_processing_function`. For every example, the answer text produced most
    often across the rounds is kept; `Counter.most_common` returns the answer seen
    first when counts are tied.

    Args:
        model: Model exposing `eval()`, `sample_parameters(...)`, `device`, and whose
            forward output has `start_logits` / `end_logits`.
        tokenized_dev: Tokenized features, batched through a DataLoader with `collate_fn`.
        dev_set: Raw examples; `len(dev_set)` fixes how many predictions are returned.
        batch_size: Batch size used for inference.
        num_samples: Number of parameter samples to vote over; must be at least 1.
        scale: Forwarded to `model.sample_parameters` (default matches the previous
            hard-coded value).
        cov: Forwarded to `model.sample_parameters` (default matches the previous
            hard-coded value).

    Returns:
        A list with one dict per example containing "id", "prediction_text" and
        "no_answer_probability" (always 0.0).

    Raises:
        ValueError: If `num_samples` is smaller than 1, since there would be
            nothing to vote on.
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")

    model.eval()
    all_predictions = [[] for _ in range(len(dev_set))]

    dataloader = DataLoader(tokenized_dev, batch_size=batch_size, collate_fn=collate_fn)

    # The sample index itself is never needed, so it is not bound to a name.
    for _ in range(num_samples):
        all_start_logits = []
        all_end_logits = []
        # Re-sample the parameters once per round; every batch below uses this draw.
        model.sample_parameters(scale=scale, cov=cov)

        for batch in dataloader:
            input_ids = batch['input_ids'].to(model.device)
            attention_mask = batch['attention_mask'].to(model.device)
            token_type_ids = batch['token_type_ids']
            if token_type_ids is not None:
                token_type_ids = token_type_ids.to(model.device)

            # Inference only: no gradients are required.
            with torch.no_grad():
                outputs = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)

            all_start_logits.append(outputs.start_logits.cpu().numpy())
            all_end_logits.append(outputs.end_logits.cpu().numpy())

        # Stitch the per-batch logits back together along the first axis.
        start_logits = np.concatenate(all_start_logits, axis=0)
        end_logits = np.concatenate(all_end_logits, axis=0)

        # Decode this round's logits into answer predictions.
        final_predictions = post_processing_function(dev_set, tokenized_dev, (start_logits, end_logits))

        # Group predictions by example position so they can be voted on afterwards.
        for example_idx, pred in enumerate(final_predictions.predictions):
            all_predictions[example_idx].append(pred)

    # Majority vote on the answer text of each example.
    majority_voted_predictions = []
    for preds in all_predictions:
        pred_counts = Counter(p['prediction_text'] for p in preds)
        most_common_answer = pred_counts.most_common(1)[0][0]

        majority_voted_predictions.append({
            "id": preds[0]['id'],
            "prediction_text": most_common_answer,
            "no_answer_probability": 0.0,
        })
    return majority_voted_predictions

# Decode the dev set by majority vote over the sampled models.
final_predictions = make_predictions(swag_model, tokenized_dev, dev_set)

# Gold answers, reduced to the two fields passed to the metric.
references = [
    {"id": example['id'], "answers": example['answers']}
    for example in dev_set
]

# Score with the SQuAD v2 metric (exact match, F1) and show the result.
metric = evaluate.load("squad_v2")
results = metric.compute(predictions=final_predictions, references=references)
print(results)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

{'exact': 100.0, 'f1': 100.0, 'total': 1, 'HasAns_exact': 100.0, 'HasAns_f1': 100.0, 'HasAns_total': 1, 'best_exact': 100.0, 'best_exact_thresh': 0.0, 'best_f1': 100.0, 'best_f1_thresh': 0.0}


In [6]:
# Mean logits decoding

def make_predictions(model, tokenized_dev, batch_size=32, num_samples=2):
    """
    Average the start/end logits over several parameter samples of the model.

    The model is put in eval mode and, `num_samples` times, its parameters are
    re-sampled with `model.sample_parameters(scale=0.5, cov=True)` before all of
    `tokenized_dev` is run through it. The logits of each pass are concatenated
    over batches, and the passes are then averaged element-wise.

    Args:
        model: Model exposing `eval()`, `sample_parameters(...)`, `device`, and whose
            forward output has `start_logits` / `end_logits`.
        tokenized_dev: Tokenized features, batched through a DataLoader with `collate_fn`.
        batch_size: Batch size used for inference.
        num_samples: Number of parameter samples to average over.

    Returns:
        Tuple `(avg_start_logits, avg_end_logits)` of NumPy arrays.
    """
    model.eval()
    loader = DataLoader(tokenized_dev, batch_size=batch_size, collate_fn=collate_fn)

    # One concatenated logits array per stochastic pass.
    sampled_start, sampled_end = [], []

    for _ in range(num_samples):
        model.sample_parameters(scale=0.5, cov=True)

        start_chunks, end_chunks = [], []
        for batch in loader:
            input_ids = batch['input_ids'].to(model.device)
            attention_mask = batch['attention_mask'].to(model.device)
            # Segment ids are optional in the batch; pass None through untouched.
            token_type_ids = batch['token_type_ids']
            if token_type_ids is not None:
                token_type_ids = token_type_ids.to(model.device)

            # Inference only: no gradients are required.
            with torch.no_grad():
                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    token_type_ids=token_type_ids,
                )

            start_chunks.append(outputs.start_logits.cpu().numpy())
            end_chunks.append(outputs.end_logits.cpu().numpy())

        sampled_start.append(np.concatenate(start_chunks, axis=0))
        sampled_end.append(np.concatenate(end_chunks, axis=0))

    # Element-wise mean across the stochastic passes.
    return np.mean(sampled_start, axis=0), np.mean(sampled_end, axis=0)
    
# Averaged (start, end) logits over the sampled models.
predictions = make_predictions(swag_model, tokenized_dev)

# Turn the averaged logits into answer predictions.
final_predictions = post_processing_function(dev_set, tokenized_dev, predictions)

# Gold answers, reduced to the two fields passed to the metric.
references = [
    {"id": example['id'], "answers": example['answers']}
    for example in dev_set
]

# Score with the SQuAD v2 metric (e.g., F1, EM) and show the result.
metric = evaluate.load("squad_v2")
results = metric.compute(
    predictions=final_predictions.predictions,
    references=references,
)
print(results)

  0%|          | 0/1 [00:00<?, ?it/s]

{'exact': 100.0, 'f1': 100.0, 'total': 1, 'HasAns_exact': 100.0, 'HasAns_f1': 100.0, 'HasAns_total': 1, 'best_exact': 100.0, 'best_exact_thresh': 0.0, 'best_f1': 100.0, 'best_f1_thresh': 0.0}


In [7]:
# Max Sample decoding
    
def make_predictions(model, tokenized_dev, dev_set, batch_size=32, num_samples=2, scale=0.5, cov=True):
    """
    Predict answers by keeping, per example, the sample with the highest softmax score.

    For each of `num_samples` rounds the model parameters are re-sampled with
    `model.sample_parameters(scale=scale, cov=cov)`, all of `tokenized_dev` is run
    through the model, and the concatenated logits are decoded with
    `post_processing_function_with_softmax`. For every example, the prediction whose
    'softmax' entry is largest across the rounds is kept (`max` returns the first
    one on ties).

    Args:
        model: Model exposing `eval()`, `sample_parameters(...)`, `device`, and whose
            forward output has `start_logits` / `end_logits`.
        tokenized_dev: Tokenized features, batched through a DataLoader with `collate_fn`.
        dev_set: Raw examples; `len(dev_set)` fixes how many predictions are returned.
        batch_size: Batch size used for inference.
        num_samples: Number of parameter samples to choose from; must be at least 1.
        scale: Forwarded to `model.sample_parameters` (default matches the previous
            hard-coded value).
        cov: Forwarded to `model.sample_parameters` (default matches the previous
            hard-coded value).

    Returns:
        A list with one dict per example containing "id", "prediction_text",
        "no_answer_probability" (always 0.0) and the winning "softmax" value.

    Raises:
        ValueError: If `num_samples` is smaller than 1, since there would be
            no prediction to select.
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")

    model.eval()
    all_predictions = [[] for _ in range(len(dev_set))]

    dataloader = DataLoader(tokenized_dev, batch_size=batch_size, collate_fn=collate_fn)

    # The sample index itself is never needed, so it is not bound to a name.
    for _ in range(num_samples):
        all_start_logits = []
        all_end_logits = []

        # Re-sample the parameters once per round; every batch below uses this draw.
        model.sample_parameters(scale=scale, cov=cov)

        for batch in dataloader:
            input_ids = batch['input_ids'].to(model.device)
            attention_mask = batch['attention_mask'].to(model.device)
            token_type_ids = batch['token_type_ids']
            if token_type_ids is not None:
                token_type_ids = token_type_ids.to(model.device)

            # Inference only: no gradients are required.
            with torch.no_grad():
                outputs = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)

            all_start_logits.append(outputs.start_logits.cpu().numpy())
            all_end_logits.append(outputs.end_logits.cpu().numpy())

        # Stitch the per-batch logits back together along the first axis.
        start_logits = np.concatenate(all_start_logits, axis=0)
        end_logits = np.concatenate(all_end_logits, axis=0)

        # Decode this round's logits; each prediction carries a 'softmax' score.
        final_predictions = post_processing_function_with_softmax(dev_set, tokenized_dev, (start_logits, end_logits))

        # Group predictions by example position so the best one can be picked afterwards.
        for example_idx, pred in enumerate(final_predictions.predictions):
            all_predictions[example_idx].append(pred)

    # Per example, keep the sampled prediction with the highest softmax score.
    softmax_voted_predictions = []
    for preds in all_predictions:
        highest_softmax_pred = max(preds, key=lambda x: x['softmax'])

        softmax_voted_predictions.append({
            "id": preds[0]['id'],
            "prediction_text": highest_softmax_pred['prediction_text'],
            "no_answer_probability": 0.0,
            "softmax": highest_softmax_pred['softmax'],  # score of the winning sample
        })

    return softmax_voted_predictions


# Per example, the sampled prediction with the highest softmax score.
final_predictions = make_predictions(swag_model, tokenized_dev, dev_set)

# Keep only the three fields passed to the metric (the extra 'softmax' entry is left out).
metric_predictions = [
    {key: prediction[key] for key in ('id', 'prediction_text', 'no_answer_probability')}
    for prediction in final_predictions
]

# Gold answers, reduced to the two fields passed to the metric.
references = [
    {"id": example['id'], "answers": example['answers']}
    for example in dev_set
]

# Score with the SQuAD v2 metric (exact match, F1) and show the result.
metric = evaluate.load("squad_v2")
results = metric.compute(predictions=metric_predictions, references=references)
print(results)


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

{'exact': 100.0, 'f1': 100.0, 'total': 1, 'HasAns_exact': 100.0, 'HasAns_f1': 100.0, 'HasAns_total': 1, 'best_exact': 100.0, 'best_exact_thresh': 0.0, 'best_f1': 100.0, 'best_f1_thresh': 0.0}


In [8]:
# Minimum Bayes Risk decoding

def compute_f1(pred1, pred2):
    """
    Computes the token-level F1-score between two predictions.

    Tokens are the lowercased, whitespace-separated words of each string. The
    overlap is counted as a multiset intersection, so a repeated token is matched
    at most as many times as it occurs in both strings. Precision is taken with
    respect to `pred1` and recall with respect to `pred2`.

    Two empty predictions (both "no answer") count as a perfect match and score
    1.0; an empty prediction against a non-empty one scores 0.0.

    Args:
        pred1: First answer string.
        pred2: Second answer string.

    Returns:
        The F1-score as a float in [0.0, 1.0].
    """
    def get_tokens(s):
        return s.lower().split()

    pred1_tokens = get_tokens(pred1)
    pred2_tokens = get_tokens(pred2)

    # With an empty side, F1 is 1 only when both sides are empty (agreeing no-answers).
    if not pred1_tokens or not pred2_tokens:
        return float(pred1_tokens == pred2_tokens)

    # Multiset overlap: counts each shared token min(count1, count2) times, which
    # keeps precision/recall consistent with the list-length denominators below.
    num_common = sum((Counter(pred1_tokens) & Counter(pred2_tokens)).values())

    if num_common == 0:
        return 0.0  # No overlap means F1 score is 0

    precision = num_common / len(pred1_tokens)
    recall = num_common / len(pred2_tokens)
    f1 = 2 * (precision * recall) / (precision + recall)

    return f1


def mbr_with_softmax_weighting(all_predictions):
    """
    Minimum Bayes Risk (MBR) selection among sampled predictions, per example.

    For every example, each candidate prediction is scored by its expected
    utility: the sum, over all candidates of that example, of the pairwise
    `compute_f1` between the two answer texts multiplied by the other candidate's
    'softmax' value. The candidate with the largest expected utility is kept
    (`np.argmax` returns the first one on ties).

    Args:
        all_predictions: One list per example; each entry is a dict with at least
            'id', 'prediction_text' and 'softmax'.

    Returns:
        A list with one dict per example containing "id", "prediction_text" and
        "no_answer_probability" (always 0.0).
    """
    def select_candidate(candidates):
        # Weight of each candidate: the softmax value attached to it.
        weights = np.array([cand['softmax'] for cand in candidates])

        # Pairwise F1 between every ordered pair of candidate answer texts.
        size = len(candidates)
        pairwise_f1 = np.zeros((size, size))
        for row, cand_row in enumerate(candidates):
            for col, cand_col in enumerate(candidates):
                pairwise_f1[row, col] = compute_f1(
                    cand_row['prediction_text'],
                    cand_col['prediction_text'],
                )

        # Expected utility of each candidate = weighted sum of its row.
        expected = np.dot(pairwise_f1, weights)
        winner = candidates[np.argmax(expected)]

        return {
            "id": winner["id"],
            "prediction_text": winner["prediction_text"],
            "no_answer_probability": 0.0,
        }

    return [select_candidate(preds) for preds in all_predictions]


def make_predictions(model, tokenized_dev, dev_set, batch_size=32, num_samples=2):
    """
    Collect, per example, the predictions of several parameter samples of the model.

    The model is put in eval mode and, `num_samples` times, its parameters are
    re-sampled with `model.sample_parameters(scale=0.5, cov=True)` before all of
    `tokenized_dev` is run through it. Each pass is decoded with
    `post_processing_function_with_softmax`, and the resulting predictions are
    appended to the list of the example at the same position.

    Args:
        model: Model exposing `eval()`, `sample_parameters(...)`, `device`, and whose
            forward output has `start_logits` / `end_logits`.
        tokenized_dev: Tokenized features, batched through a DataLoader with `collate_fn`.
        dev_set: Raw examples; `len(dev_set)` fixes the length of the returned list.
        batch_size: Batch size used for inference.
        num_samples: Number of parameter samples to draw.

    Returns:
        A list of length `len(dev_set)`; entry `k` holds the predictions made for
        example `k`, one per sample, in sampling order.
    """
    def run_batch(batch):
        # Move the inputs to the model's device; segment ids may be None.
        input_ids = batch['input_ids'].to(model.device)
        attention_mask = batch['attention_mask'].to(model.device)
        token_type_ids = batch['token_type_ids']
        if token_type_ids is not None:
            token_type_ids = token_type_ids.to(model.device)

        # Inference only: no gradients are required.
        with torch.no_grad():
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
            )

        return outputs.start_logits.cpu().numpy(), outputs.end_logits.cpu().numpy()

    model.eval()
    per_example = [[] for _ in range(len(dev_set))]
    loader = DataLoader(tokenized_dev, batch_size=batch_size, collate_fn=collate_fn)

    for _ in range(num_samples):
        model.sample_parameters(scale=0.5, cov=True)

        # (start, end) logits of every batch for this pass.
        batch_logits = [run_batch(batch) for batch in loader]
        start_logits = np.concatenate([pair[0] for pair in batch_logits], axis=0)
        end_logits = np.concatenate([pair[1] for pair in batch_logits], axis=0)

        decoded = post_processing_function_with_softmax(
            dev_set, tokenized_dev, (start_logits, end_logits)
        )

        # Store each prediction under the example it belongs to, unmodified.
        for position, pred in enumerate(decoded.predictions):
            per_example[position].append(pred)

    return per_example

# Per-example predictions from every sampled model.
all_predictions = make_predictions(swag_model, tokenized_dev, dev_set)

# Pick one prediction per example with MBR decoding.
final_predictions = mbr_with_softmax_weighting(all_predictions)

# Keep only the three fields passed to the metric.
metric_predictions = [
    {key: prediction[key] for key in ('id', 'prediction_text', 'no_answer_probability')}
    for prediction in final_predictions
]

# Gold answers, reduced to the two fields passed to the metric.
references = [
    {"id": example['id'], "answers": example['answers']}
    for example in dev_set
]

# Score with the SQuAD v2 metric (exact match, F1) and show the result.
metric = evaluate.load("squad_v2")
results = metric.compute(predictions=metric_predictions, references=references)
print(results)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

{'exact': 100.0, 'f1': 100.0, 'total': 1, 'HasAns_exact': 100.0, 'HasAns_f1': 100.0, 'HasAns_total': 1, 'best_exact': 100.0, 'best_exact_thresh': 0.0, 'best_f1': 100.0, 'best_f1_thresh': 0.0}
