In [None]:
from datasets import load_dataset
from datasets import load_metric
from transformers import AutoTokenizer
from datasets import load_metric
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm


# Per-language BERT-style checkpoints. Only referenced by the commented-out
# model-loading alternative further down.
# NOTE(review): the 'arabic' entry looks like a local checkpoint filename
# rather than a hub model id like the other entries -- confirm.
bert_map = {
    'arabic': 'arabic_cls_bert_logreg_classifier.pt',
    'bengali': 'google/muril-base-cased',
    'english': 'bert-base-uncased',
    'indonesian': 'cahya/bert-base-indonesian-522M',
}
# Prefer the GPU, but fall back to CPU so the script still runs without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

compute_squad = load_metric("squad_v2")
dataset = load_dataset("copenlu/answerable_tydiqa")

# Attach a running integer 'id' column to every split.
for split in dataset:
    dataset[split] = dataset[split].add_column('id', list(range(len(dataset[split]))))

language = 'english'
n = 15
# Tokenizer windowing: each chunk holds at most `max_length` tokens and
# consecutive chunks of the same document overlap by `stride` tokens.
stride = 10
max_length = 30

# The checkpoint stores a whole pickled model object. map_location places its
# tensors on the selected device; eval() disables dropout so that inference
# (and the attention maps plotted later) is deterministic.
model = torch.load(f'{language}_xlm-roberta-base_classification.pt', map_location=device)
model.to(device)
model.eval()
# model = AutoModelForQuestionAnswering.from_pretrained(bert_map[language]).to(device)
language_dataset = dataset.filter(lambda example: example['language'] == language)
tk = AutoTokenizer.from_pretrained('xlm-roberta-base', max_len=300, use_fast=True)

# Slice the validation split once and read the three columns from that single
# materialised batch instead of re-slicing the dataset for every column.
validation_slice = language_dataset['validation'][200:500]
questions = validation_slice['question_text']
documents = validation_slice['document_plaintext']
annotations = validation_slice['annotations']
# Walk the question/document pairs until the model's predicted span matches
# the gold answer exactly (case-insensitive); for that first match, plot the
# last layer's attention for every head and stop.
for i, (question, document, annotation) in enumerate(zip(questions, documents, annotations)):
    # Tokenize the pair. Only the document is truncated; the overflow is
    # returned as additional overlapping chunks (one row per chunk).
    inputs = tk(
        question,
        document,
        return_tensors="pt",
        truncation="only_second",
        padding=True,
        max_length=max_length,
        stride=stride,
        return_overflowing_tokens=True,
    )

    input_ids = inputs.input_ids.to(device)
    attention_mask = inputs.attention_mask.to(device)

    # Forward pass without gradients; attentions are requested for plotting.
    with torch.no_grad():
        outputs = model(
            input_ids,
            attention_mask=attention_mask,
            return_dict=True,
            output_attentions=True,
        )

    # Logits have one row per chunk and one column per token position.
    start_logits = outputs.start_logits.cpu()
    end_logits = outputs.end_logits.cpu()

    # Pick the chunk whose best single position has the highest combined score.
    scores = start_logits + end_logits
    best_chunk = torch.argmax(scores.max(dim=1).values).item()

    # Read the span from the chosen chunk (not from chunk 0). The indices are
    # positions inside that chunk's own token sequence, so no stride offset
    # must be added before slicing `tokens`.
    answer_start = torch.argmax(start_logits[best_chunk]).item()
    answer_end = torch.argmax(end_logits[best_chunk]).item()

    # Tokens of the chosen chunk (question, separators and document window).
    tokens = tk.convert_ids_to_tokens(input_ids[best_chunk].tolist())

    answer_tokens = tokens[answer_start: answer_end + 1]
    predicted_answer = tk.convert_tokens_to_string(answer_tokens)

    print("Predicted answer:", predicted_answer)
    print("True answer:", annotation['answer_text'][0])
    # Only visualise examples the model answers exactly right.
    if predicted_answer.lower().strip() != annotation['answer_text'][0].lower().strip():
        continue

    # Last layer's attention for the chosen chunk; indexed below as
    # (head, query position, key position).
    chunk_attention = outputs.attentions[-1][best_chunk]

    # Keep the query rows from the start of the sequence through the end of
    # the answer span. answer_end already indexes the full chunk sequence, so
    # the question length must not be added on top of it.
    row_count = answer_end + 1
    combined_attention = chunk_attention[:, :row_count, :]

    # One heatmap per attention head.
    for head in range(combined_attention.shape[0]):
        plt.figure(figsize=(10, 8))
        sns.heatmap(
            combined_attention[head].cpu().numpy(),
            xticklabels=tokens,
            yticklabels=tokens[:row_count],
            cmap='Blues',
            vmin=0,
            vmax=1,
        )
        plt.title(f'Head {head+1}, Question and Answer Span Attention')
        plt.xticks(rotation=90)
        plt.yticks(rotation=0)
        plt.show()

    # Stop after the first correctly answered example.
    break