In [None]:
import torch
from torch.utils.data import DataLoader
from transformers import (
    BertTokenizerFast,
    BertForQuestionAnswering,
)
from datasets import load_dataset
from utils_qa import preprocess_and_tokenize
from swag_transformers.swag_bert import SwagBertForQuestionAnswering
from utils import preprocess_and_tokenize, post_processing_function, collate_fn
import numpy as np
import evaluate
import json
from Levenshtein import distance as levenshtein_distance
from scipy.spatial import distance

In [2]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Identifier of the base checkpoint whose tokenizer is loaded below.
model_name = "bert-base-uncased"
# The "fast" tokenizer is used because it supports offset mappings.
tokenizer = BertTokenizerFast.from_pretrained(model_name) # This supports offsets mapping

In [None]:
# Load the SQuAD 2.0 dataset from the Hugging Face hub
squad_dataset = load_dataset("squad_v2")

# The official test set is hidden, so the public "validation" split is used as the
# test set and 10% of "train" is held out as a dev set instead.
# shuffle=False keeps the split deterministic (no random permutation of the rows).
split = squad_dataset["train"].train_test_split(test_size=0.1, shuffle=False)
train_set = split["train"]
dev_set = split["test"]
test_set = squad_dataset["validation"]

# Tiny subsets for quick smoke tests (uncomment to use)
#train_set = train_set.select(range(1))  
#dev_set = dev_set.select(range(10))     
#test_set = test_set.select(range(1)) 

# Tokenize all three splits with the tokenizer loaded above.
# NOTE(review): `preprocess_and_tokenize` is imported from both `utils_qa` and `utils`
# at the top of the notebook; the later import (`utils`) is the one in effect here.
tokenized_train, tokenized_dev, tokenized_test = preprocess_and_tokenize(train_set, dev_set, test_set, tokenizer)

In [None]:
# Directory containing the saved tokenizer and the training checkpoints.
output_dir = "./model_1"               
# NOTE(review): this rebinds `tokenizer` after the splits were already tokenized with
# the tokenizer loaded from `model_name` — confirm both share the same vocabulary.
tokenizer = BertTokenizerFast.from_pretrained(output_dir)

# Deterministic (point-estimate) BERT QA model, moved to the selected device.
checkpoint_dir = f"{output_dir}/checkpoint-59300" 
model = BertForQuestionAnswering.from_pretrained(checkpoint_dir)
model = model.to(device)

# SWAG variant of the QA model; this is the model evaluated later in the notebook.
swag_checkpoint_dir = f"{output_dir}/checkpoint-swag-epoch-4" 
swag_model = SwagBertForQuestionAnswering.from_pretrained(swag_checkpoint_dir)
swag_model = swag_model.to(device)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [None]:
def compute_f1(pred1, pred2):
    """Token-level F1 overlap between two answer strings.

    Both strings are lower-cased and split on whitespace. The overlap is the
    multiset intersection of the two token lists, so a token that occurs
    several times in both strings is counted as many times as it is shared.
    (A plain set intersection would under-count repeated tokens and give
    identical strings such as "the the" a score below 1.0.)

    Args:
        pred1: First answer string; empty or whitespace-only means "no answer".
        pred2: Second answer string; empty or whitespace-only means "no answer".

    Returns:
        float in [0.0, 1.0]: 1.0 when both are "no answer", 0.0 when exactly
        one is "no answer" or when no token is shared, otherwise the harmonic
        mean of precision and recall of the shared tokens.
    """
    # Function-scope import keeps this cell runnable on its own.
    from collections import Counter

    first_empty = pred1.strip() == ""
    second_empty = pred2.strip() == ""
    if first_empty and second_empty:
        return 1.0  # perfect match for no-answer
    if first_empty or second_empty:
        return 0.0  # mismatch if one is no-answer

    def get_tokens(s):
        return s.lower().split()

    pred1_tokens = get_tokens(pred1)
    pred2_tokens = get_tokens(pred2)

    # Multiset intersection: each shared token counts min(count1, count2) times.
    num_common = sum((Counter(pred1_tokens) & Counter(pred2_tokens)).values())
    if num_common == 0:
        return 0.0  # No overlap means F1 score is 0

    precision = num_common / len(pred1_tokens)
    recall = num_common / len(pred2_tokens)
    return 2 * (precision * recall) / (precision + recall)
    
def mbr_with_softmax_weighting(all_predictions):
    """Choose one prediction per example via Minimum Bayes Risk decoding.

    Args:
        all_predictions: One list per example. Each inner list holds the
            candidate prediction dicts for that example; every dict is read
            for its 'id', 'prediction_text' and 'softmax' entries.

    Returns:
        A list with one dict per example containing the 'id',
        'prediction_text' and 'softmax' of the winning candidate, plus a
        constant 'no_answer_probability' of 0.0.
    """

    def decode(candidates):
        # Select the candidate with the largest softmax-weighted F1 agreement.
        size = len(candidates)
        weights = np.array([cand['softmax'] for cand in candidates])

        # pairwise[r, c] is the F1 utility between candidate r and candidate c.
        pairwise = np.zeros((size, size))
        for row, anchor in enumerate(candidates):
            for col, other in enumerate(candidates):
                pairwise[row, col] = compute_f1(anchor['prediction_text'], other['prediction_text'])

        # Expected utility of each candidate = weighted sum of its row.
        winner = candidates[np.argmax(np.dot(pairwise, weights))]
        return {
            "id": winner["id"],
            "prediction_text": winner["prediction_text"],
            "no_answer_probability": 0.0,
            "softmax": winner["softmax"],
        }

    return [decode(candidates) for candidates in all_predictions]


def jsd(P, Q):
    """Return the Jensen-Shannon divergence between distributions P and Q.

    scipy's `jensenshannon` yields the Jensen-Shannon *distance* (the square
    root of the divergence), so the value is squared before being returned.
    """
    js_distance = distance.jensenshannon(P, Q)
    return js_distance ** 2


def calculate_jsd(logits_list):
    """Average Jensen-Shannon divergence of each run from the mean distribution.

    Args:
        logits_list: Array-like of logits; the first axis indexes the runs
            and the last axis is the one the softmax is taken over.

    Returns:
        The mean, over runs, of jsd(run distribution, mean distribution).

    Raises:
        ValueError: If the softmax rows do not sum to 1 within 1e-3.
    """
    logits = torch.tensor(logits_list)

    # Shift every row by its maximum before the softmax (numerical stability).
    shifted = logits - logits.max(dim=-1, keepdim=True).values

    # Clamp away exact zeros so the divergence never has to evaluate log(0).
    prob_matrix = torch.softmax(shifted, dim=-1).clamp(min=1e-12).detach().cpu().numpy()

    row_sums = prob_matrix.sum(axis=1)
    if not np.allclose(row_sums, 1, atol=1e-3):
        raise ValueError("Probabilities do not sum to 1.")

    # Consensus distribution across all runs.
    consensus = prob_matrix.mean(axis=0)

    per_run_divergence = [jsd(run_probs, consensus) for run_probs in prob_matrix]
    return np.mean(per_run_divergence)

    
def compute_uncertainty(start_logits_list, end_logits_list):
    """Combine start- and end-position disagreement into one uncertainty score.

    Each argument is passed unchanged to `calculate_jsd`; the result is the
    arithmetic mean of the two values it returns.
    """
    start_component = calculate_jsd(start_logits_list)
    end_component = calculate_jsd(end_logits_list)
    combined = start_component + end_component
    return combined / 2


def compute_levenshtein(pred_texts):
    """Mean length-normalised Levenshtein distance over all pairs of texts.

    Empty or whitespace-only texts are treated as the literal string
    "no answer" for the purpose of this computation. Each pairwise distance
    is divided by the length of the longer text of the pair.

    Returns:
        0.0 when there are fewer than two texts, otherwise the average of the
        normalised distances over every unordered pair.
    """
    candidates = [text if text.strip() != "" else "no answer" for text in pred_texts]

    if len(candidates) <= 1:
        return 0.0  # a single prediction has nothing to disagree with

    running_total = 0.0
    pair_count = 0
    for left_pos, left in enumerate(candidates):
        # Only look at texts after `left`, so each unordered pair is visited once.
        for right in candidates[left_pos + 1:]:
            longest = max(len(left), len(right), 1)
            running_total += levenshtein_distance(left, right) / longest
            pair_count += 1

    if pair_count > 0:
        return running_total / pair_count
    return 0.0


def make_predictions(model, tokenized_dev, dev_set, batch_size=32, num_samples=40):
    """Sample the model several times, MBR-decode the answers and score their uncertainty.

    For each of `num_samples` passes the model's parameters are re-sampled
    with `model.sample_parameters(scale=0.5, cov=True)`, the whole tokenized
    set is run through the model, and the logits are turned into text
    predictions by `post_processing_function`. The per-example candidates are
    then reduced to one answer with `mbr_with_softmax_weighting`.

    Args:
        model: Model exposing `eval()`, `sample_parameters(...)`, `device`,
            and returning `start_logits` / `end_logits` when called.
        tokenized_dev: Tokenized features fed to the DataLoader and to
            `post_processing_function`.
        dev_set: Original examples; each is read for 'id' and
            'answers'['text'].
        batch_size: DataLoader batch size.
        num_samples: Number of parameter samples / forward passes.

    Returns:
        A list with one dict per MBR-decoded prediction holding: 'id',
        'prediction_text', 'no_answer_probability' (always 0.0), 'correct'
        (exact string match against any gold answer, or both empty),
        'no_ans', 'incorrect_noans', 'jsd', 'levenshtein' and 'softmax'.
    """
    model.eval()
    all_predictions = [[] for _ in range(len(dev_set))]

    dataloader = DataLoader(tokenized_dev, batch_size=batch_size, collate_fn=collate_fn)

    all_start_logits_passes = []  # Store logits across all forward passes
    all_end_logits_passes = []

    target_device = model.device

    for _ in range(num_samples):
        batch_start_logits = []
        batch_end_logits = []

        # Draw a fresh set of weights for this pass.
        model.sample_parameters(scale=0.5, cov=True)

        for batch in dataloader:
            input_ids = batch['input_ids'].to(target_device)
            attention_mask = batch['attention_mask'].to(target_device)
            token_type_ids = batch['token_type_ids'].to(target_device) if batch['token_type_ids'] is not None else None

            with torch.no_grad():
                outputs = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)

            batch_start_logits.append(outputs.start_logits.cpu().numpy())
            batch_end_logits.append(outputs.end_logits.cpu().numpy())

        # Concatenate all batches
        start_logits = np.concatenate(batch_start_logits, axis=0)
        end_logits = np.concatenate(batch_end_logits, axis=0)

        # Store per-pass logits
        all_start_logits_passes.append(start_logits)
        all_end_logits_passes.append(end_logits)

        # Post-process predictions
        predictions = (start_logits, end_logits)
        final_predictions = post_processing_function(dev_set, tokenized_dev, predictions)

        # NOTE(review): assumes the i-th post-processed prediction belongs to the
        # i-th example of dev_set — confirm against post_processing_function.
        for i, pred in enumerate(final_predictions.predictions):
            all_predictions[i].append(pred)

    # Disagreement between the sampled answer strings, keyed by example id
    levenshtein_values = {}
    for pred in all_predictions:
        prediction_id = pred[0]['id']
        pred_texts = [p['prediction_text'] for p in pred]
        levenshtein_values[prediction_id] = compute_levenshtein(pred_texts)

    # Apply MBR decoding
    mbr_decoded_predictions = mbr_with_softmax_weighting(all_predictions)

    # Index the gold answers by example id once, instead of rescanning every
    # reference for every prediction (which is quadratic in the dataset size).
    answers_by_id = {}
    for ex in dev_set:
        answers_by_id.setdefault(ex['id'], []).extend(ex['answers']['text'])

    # Final predictions
    final_predictions = []

    for idx, pred in enumerate(mbr_decoded_predictions):
        prediction_id = pred['id']
        pred_text = pred['prediction_text']
        softmax = pred['softmax']

        ground_truth_answers = answers_by_id.get(prediction_id, [])

        # Stack logits across passes for this sample.
        # NOTE(review): `idx` is an example index but the logits are indexed by
        # tokenized feature; if tokenization yields more than one feature per
        # example these rows belong to a different example — TODO confirm.
        start_logits_per_sample = np.array([logits[idx] for logits in all_start_logits_passes])
        end_logits_per_sample = np.array([logits[idx] for logits in all_end_logits_passes])

        # Compute uncertainty
        jsd_value = compute_uncertainty(start_logits_per_sample, end_logits_per_sample)

        # Evaluate correctness (raw exact string match; empty vs. no gold answer counts as correct)
        if not pred_text and not ground_truth_answers:
            correct = 1
        else:
            correct = 1 if pred_text in ground_truth_answers else 0

        no_answer = 1 if not ground_truth_answers else 0
        incorrect_noans = int((not ground_truth_answers) and (not correct))

        final_predictions.append({
            "id": prediction_id,
            "prediction_text": pred_text,
            "no_answer_probability": 0.0,
            "correct": correct,
            "no_ans": no_answer,
            "incorrect_noans": incorrect_noans,
            "jsd": jsd_value,
            "levenshtein": levenshtein_values[prediction_id],
            "softmax": softmax
        })

    return final_predictions


# Get MBR-decoded predictions (with per-example uncertainty fields) from the SWAG model on the test split
final_predictions = make_predictions(swag_model, tokenized_test, test_set)

# Keep only the three fields passed on to the squad_v2 metric
metric_predictions = [{'id': prediction['id'], 'prediction_text': prediction['prediction_text'], 'no_answer_probability': prediction['no_answer_probability']} 
                         for prediction in final_predictions]

# Evaluate using SQuAD metrics
metric = evaluate.load("squad_v2")
references = [{"id": ex['id'], "answers": ex['answers']} for ex in test_set]

# Compute evaluation results (F1-score, Exact Match)
results = metric.compute(predictions=metric_predictions, references=references)

# Print results
print(results)

# Persist the full per-example records (including jsd / levenshtein / correctness flags)
with open("swag_predictions_test.json", "w") as f:
    json.dump(final_predictions, f)

  0%|          | 0/11873 [00:00<?, ?it/s]

  0%|          | 0/11873 [00:00<?, ?it/s]

  0%|          | 0/11873 [00:00<?, ?it/s]

  0%|          | 0/11873 [00:00<?, ?it/s]

  0%|          | 0/11873 [00:00<?, ?it/s]

  0%|          | 0/11873 [00:00<?, ?it/s]

  0%|          | 0/11873 [00:00<?, ?it/s]

  0%|          | 0/11873 [00:00<?, ?it/s]

  0%|          | 0/11873 [00:00<?, ?it/s]

  0%|          | 0/11873 [00:00<?, ?it/s]

  0%|          | 0/11873 [00:00<?, ?it/s]

  0%|          | 0/11873 [00:00<?, ?it/s]

  0%|          | 0/11873 [00:00<?, ?it/s]

  0%|          | 0/11873 [00:00<?, ?it/s]

  0%|          | 0/11873 [00:00<?, ?it/s]

  0%|          | 0/11873 [00:00<?, ?it/s]

  0%|          | 0/11873 [00:00<?, ?it/s]

  0%|          | 0/11873 [00:00<?, ?it/s]

  0%|          | 0/11873 [00:00<?, ?it/s]

  0%|          | 0/11873 [00:00<?, ?it/s]

  0%|          | 0/11873 [00:00<?, ?it/s]

  0%|          | 0/11873 [00:00<?, ?it/s]

  0%|          | 0/11873 [00:00<?, ?it/s]

  0%|          | 0/11873 [00:00<?, ?it/s]

  0%|          | 0/11873 [00:00<?, ?it/s]

  0%|          | 0/11873 [00:00<?, ?it/s]

  0%|          | 0/11873 [00:00<?, ?it/s]

  0%|          | 0/11873 [00:00<?, ?it/s]

  0%|          | 0/11873 [00:00<?, ?it/s]

  0%|          | 0/11873 [00:00<?, ?it/s]

  0%|          | 0/11873 [00:00<?, ?it/s]

  0%|          | 0/11873 [00:00<?, ?it/s]

  0%|          | 0/11873 [00:00<?, ?it/s]

  0%|          | 0/11873 [00:00<?, ?it/s]

  0%|          | 0/11873 [00:00<?, ?it/s]

  0%|          | 0/11873 [00:00<?, ?it/s]

  0%|          | 0/11873 [00:00<?, ?it/s]

  0%|          | 0/11873 [00:00<?, ?it/s]

  0%|          | 0/11873 [00:00<?, ?it/s]

  0%|          | 0/11873 [00:00<?, ?it/s]

Downloading builder script:   0%|          | 0.00/6.47k [00:00<?, ?B/s]

Downloading extra modules:   0%|          | 0.00/11.3k [00:00<?, ?B/s]

{'exact': 73.08178219489598, 'f1': 76.64982612375834, 'total': 11873, 'HasAns_exact': 72.57085020242916, 'HasAns_f1': 79.71717030488921, 'HasAns_total': 5928, 'NoAns_exact': 73.59125315391086, 'NoAns_f1': 73.59125315391086, 'NoAns_total': 5945, 'best_exact': 73.08178219489598, 'best_exact_thresh': 0.0, 'best_f1': 76.64982612375836, 'best_f1_thresh': 0.0}
