In [1]:
# !pip install -r requirements.txt
# !pip install torch torchvision
# !pip install transformers
from model import RegBertForQA
# !pip install captum

In [3]:
import torch
from transformers import AutoTokenizer
from captum.attr import IntegratedGradients
import numpy as np


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

tokenizer = AutoTokenizer.from_pretrained('fine_tuned_model_registers_Nov17')
# tokenizer = AutoTokenizer.from_pretrained('fine_tuned_model_orig')

model_path = 'fine_tuned_model_registers_Nov17'  # Replace with your actual model path
# model_path = 'fine_tuned_model_orig'  # Replace with your actual model path

model = RegBertForQA.from_pretrained(model_path)
model.to(device)
model.eval()

question, context = "What is a way to increase your wound healing speed?", \
"Wound care encourages and speeds wound healing via cleaning and protection from reinjury or infection. Depending on each patient's needs, it can range from the simplest first aid to entire nursing specialties such as wound, ostomy, and continence nursing and burn center care."


inputs = tokenizer.encode_plus(question, context, return_tensors='pt')
input_ids = inputs['input_ids'].to(device)            
token_type_ids = inputs['token_type_ids'].to(device)  
attention_mask = inputs['attention_mask'].to(device)  

batch_size = input_ids.shape[0]
seq_length = input_ids.shape[1]
num_registers = model.bert.num_registers

inputs_embeds = model.bert.embeddings(
    input_ids=input_ids,
    token_type_ids=token_type_ids
)
# raise Exception
# register_tokens = model.bert.reg_tokens.expand(batch_size, -1, -1).to(device)
# register_pos = model.bert.reg_pos.to(device)
# register_embeddings = register_tokens + register_pos  

# inputs_embeds_with_registers = torch.cat((register_embeddings, inputs_embeds), dim=1)  # Shape: [batch_size, num_registers + seq_length, embedding_dim]
# inputs_embeds_with_registers.requires_grad_()

# register_attention_mask = torch.ones((batch_size, num_registers), device=device)
# attention_mask_with_registers = torch.cat((register_attention_mask, attention_mask), dim=1)  # Shape: [batch_size, num_registers + seq_length]

# register_token_type_ids = torch.zeros((batch_size, num_registers), dtype=token_type_ids.dtype, device=device)
# token_type_ids_with_registers = torch.cat((register_token_type_ids, token_type_ids), dim=1)  # Shape: [batch_size, num_registers + seq_length]

baseline_embeds = torch.zeros_like(inputs_embeds, device=device)

def forward_func(inputs_embeds_with_registers, attention_mask_with_registers=None, token_type_ids_with_registers=None):
    outputs = model(
        inputs_embeds=inputs_embeds_with_registers,
        attention_mask=attention_mask_with_registers,
        token_type_ids=token_type_ids_with_registers,
        input_ids=input_ids,  
    )
    start_logits = outputs.start_logits  
    end_logits = outputs.end_logits     

    start_probs = torch.softmax(start_logits, dim=-1)
    end_probs = torch.softmax(end_logits, dim=-1)

    start_index = torch.argmax(start_probs, dim=-1) 
    end_index = torch.argmax(end_probs, dim=-1)     
    batch_size = start_logits.shape[0]

    start_logit = start_logits[torch.arange(batch_size), start_index]  
    end_logit = end_logits[torch.arange(batch_size), end_index]        

    return start_logit + end_logit 

ig = IntegratedGradients(forward_func)

attributions, delta = ig.attribute(
    inputs_embeds,
    baselines=baseline_embeds,
    additional_forward_args=(attention_mask, token_type_ids),
    return_convergence_delta=True
)

attributions_sum = attributions.sum(dim=-1)[0]  

register_token_ids = ['[REG{}]'.format(i) for i in range(num_registers)]

input_tokens = tokenizer.convert_ids_to_tokens(input_ids.squeeze(0))

all_tokens =  register_token_ids + input_tokens

print("Number of tokens:", len(all_tokens))
print("Number of attributions:", len(attributions_sum))

for token, attribution in zip(all_tokens, attributions_sum):
    print(f'{token}: {attribution.item()}')


from regbertfor QA, num_reg= 50


Some weights of RegBert were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['bert.embeddings.position_ids', 'bert.reg_pos', 'bert.reg_tokens']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


input_shape: torch.Size([1, 71])
Position_ids: tensor([[  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
          14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
          28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
          42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
          56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
          70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
          84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
          98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
         112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
         126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139,
         140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153,
         154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167,
     

RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 1 but got size 50 for tensor number 1 in the list.

In [None]:
from captum.attr import visualization as viz

attributions_np = attributions_sum.detach().cpu().numpy()
attributions_norm = attributions_np / np.linalg.norm(attributions_np)

outputs = model(
    inputs_embeds=inputs_embeds_with_registers,
    attention_mask=attention_mask_with_registers,
    token_type_ids=token_type_ids_with_registers,
    input_ids=None,
)

start_logits = outputs.start_logits
end_logits = outputs.end_logits

start_probs = torch.softmax(start_logits, dim=-1)
end_probs = torch.softmax(end_logits, dim=-1)

start_index = torch.argmax(start_probs, dim=-1).item()
end_index = torch.argmax(end_probs, dim=-1).item()

adjusted_start_index = max(0, start_index - num_registers)
adjusted_end_index = max(0, end_index - num_registers)

pred_prob = ((start_probs[0, start_index] + end_probs[0, end_index]) / 2).item()

answer_ids = input_ids[0, adjusted_start_index: adjusted_end_index + 1]
answer_tokens = tokenizer.convert_ids_to_tokens(answer_ids)
predicted_answer = tokenizer.convert_tokens_to_string(answer_tokens)

vis_data_record = viz.VisualizationDataRecord(
    word_attributions=attributions_norm,
    pred_prob=pred_prob,
    pred_class=predicted_answer,
    true_class="",  
    attr_class="",
    attr_score=attributions_norm.sum(),
    raw_input_ids=all_tokens,
    convergence_score=delta.item()
)

# Visualize attributions
viz.visualize_text([vis_data_record])


In [None]:
import matplotlib.pyplot as plt
# Obtain attention weights from the model
outputs = model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    token_type_ids=token_type_ids,
    output_attentions=True
)

attentions = outputs.attentions  # A tuple of attention matrices

# Visualize attention for a specific layer and head
import seaborn as sns

layer = 0  # Choose the layer to analyze
head = 0   # Choose the head

attention_weights = attentions[layer][0][head].detach().cpu().numpy()
sns.heatmap(attention_weights)
plt.title(f'Attention Weights - Layer {layer+1}, Head {head+1}')
plt.xlabel('Tokens')
plt.ylabel('Tokens')
plt.show()


In [None]:
import torch
from transformers import BertTokenizer
from captum.attr import IntegratedGradients
import numpy as np

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

tokenizer = BertTokenizer.from_pretrained('fine_tuned_model')

model_path = 'fine_tuned_model'  
model = RegBertForQA.from_pretrained(model_path)
model.to(device)
model.eval()

question, context = (
    "How long is the Great Wall of China?",
    "The Great Wall of China is an ancient wall in China. The wall was built to protect the northern borders of the Chinese Empire from invading nomadic tribes. The Great Wall stretches from Dandong in the east to Lop Lake in the west, covering about 13,000 miles."
)

inputs = tokenizer.encode_plus(question, context, return_tensors='pt')
input_ids = inputs['input_ids'].to(device)
token_type_ids = inputs['token_type_ids'].to(device)
attention_mask = inputs['attention_mask'].to(device)

# Prepare inputs_embeds with register tokens
batch_size = input_ids.shape[0]
seq_length = input_ids.shape[1]
num_registers = model.bert.num_registers

inputs_embeds = model.bert.embeddings(
    input_ids=input_ids,
    token_type_ids=token_type_ids
)

register_tokens = model.bert.reg_tokens.expand(batch_size, -1, -1).to(device)
register_pos = model.bert.reg_pos.to(device)
register_embeddings = register_tokens + register_pos

inputs_embeds_with_registers = torch.cat((register_embeddings, inputs_embeds), dim=1)  # Shape: [batch_size, num_registers + seq_length, embedding_dim]
inputs_embeds_with_registers.requires_grad_()

# Extend attention masks
register_attention_mask = torch.ones((batch_size, num_registers), device=device)
attention_mask_with_registers = torch.cat((register_attention_mask, attention_mask), dim=1)

register_token_type_ids = torch.zeros((batch_size, num_registers), dtype=token_type_ids.dtype, device=device)
token_type_ids_with_registers = torch.cat((register_token_type_ids, token_type_ids), dim=1)

# Baseline embeddings (zeros or padding embeddings)
baseline_embeds = torch.zeros_like(inputs_embeds_with_registers, device=device)

# Define forward function
def forward_func(inputs_embeds_with_registers, attention_mask_with_registers=None, token_type_ids_with_registers=None, target=None):
    outputs = model(
        inputs_embeds=inputs_embeds_with_registers,
        attention_mask=attention_mask_with_registers,
        token_type_ids=token_type_ids_with_registers,
        input_ids=None,  # Not passing input_ids since embeddings are provided
    )
    start_logits = outputs.start_logits  # Shape: [batch_size, seq_length]
    end_logits = outputs.end_logits      # Shape: [batch_size, seq_length]
    
    # Return logits for the specific target index
    if target is not None:
        return start_logits[:, target] if target == 0 else end_logits[:, target]
    else:
        raise ValueError("Invalid target value. Must be 0 (start logits) or 1 (end logits).")

# Define Integrated Gradients
ig = IntegratedGradients(forward_func)

# Get predicted start and end indices
start_index = torch.argmax(model(input_ids=input_ids).start_logits, dim=-1)
end_index = torch.argmax(model(input_ids=input_ids).end_logits, dim=-1)

# Compute attributions for start logits
start_attributions, start_delta = ig.attribute(
    inputs_embeds_with_registers,
    baselines=baseline_embeds,
    additional_forward_args=(attention_mask_with_registers, token_type_ids_with_registers, start_index.item()),  # Target = predicted start index
    return_convergence_delta=True
)

# Compute attributions for end logits
end_attributions, end_delta = ig.attribute(
    inputs_embeds_with_registers,
    baselines=baseline_embeds,
    additional_forward_args=(attention_mask_with_registers, token_type_ids_with_registers, end_index.item()),  # Target = predicted end index
    return_convergence_delta=True
)

# Sum attributions across embedding dimensions
start_attributions_sum = start_attributions.sum(dim=-1)[0]
end_attributions_sum = end_attributions.sum(dim=-1)[0]

# Combine tokens (register + input tokens)
register_token_ids = ['[REG{}]'.format(i) for i in range(num_registers)]
input_tokens = tokenizer.convert_ids_to_tokens(input_ids.squeeze(0))
all_tokens = register_token_ids + input_tokens

# Print attributions
print("Token-Level Attributions for Start Logits:")
for token, attribution in zip(all_tokens, start_attributions_sum):
    print(f"{token}: {attribution.item()}")

print("\nToken-Level Attributions for End Logits:")
for token, attribution in zip(all_tokens, end_attributions_sum):
    print(f"{token}: {attribution.item()}")


In [None]:
from captum.attr import visualization as viz

# Convert attributions to numpy and normalize
start_attributions_np = start_attributions_sum.detach().cpu().numpy()
end_attributions_np = end_attributions_sum.detach().cpu().numpy()

start_attributions_norm = start_attributions_np / np.linalg.norm(start_attributions_np)
end_attributions_norm = end_attributions_np / np.linalg.norm(end_attributions_np)

# Get model outputs
outputs = model(
    inputs_embeds=inputs_embeds_with_registers,
    attention_mask=attention_mask_with_registers,
    token_type_ids=token_type_ids_with_registers,
    input_ids=None,
)

start_logits = outputs.start_logits
end_logits = outputs.end_logits

start_probs = torch.softmax(start_logits, dim=-1)
end_probs = torch.softmax(end_logits, dim=-1)

start_index = torch.argmax(start_probs, dim=-1).item()
end_index = torch.argmax(end_probs, dim=-1).item()

# Adjust indices for register tokens
adjusted_start_index = max(0, start_index - num_registers)
adjusted_end_index = max(0, end_index - num_registers)

# Get predicted answer
answer_ids = input_ids[0, adjusted_start_index: adjusted_end_index + 1]
answer_tokens = tokenizer.convert_ids_to_tokens(answer_ids)
predicted_answer = tokenizer.convert_tokens_to_string(answer_tokens)

# Visualize attributions for start and end logits
start_vis_data_record = viz.VisualizationDataRecord(
    word_attributions=start_attributions_norm,
    pred_prob=start_probs[0, start_index].item(),
    pred_class=predicted_answer,
    true_class="",  # Provide true answer if available
    attr_class="Start Logits",
    attr_score=start_attributions_norm.sum(),
    raw_input_ids=all_tokens,
    convergence_score=start_delta.item()
)

end_vis_data_record = viz.VisualizationDataRecord(
    word_attributions=end_attributions_norm,
    pred_prob=end_probs[0, end_index].item(),
    pred_class=predicted_answer,
    true_class="",  # Provide true answer if available
    attr_class="End Logits",
    attr_score=end_attributions_norm.sum(),
    raw_input_ids=all_tokens,
    convergence_score=end_delta.item()
)

# Visualize both start and end attributions
viz.visualize_text([start_vis_data_record, end_vis_data_record])


In [None]:
start_index

In [None]:
end_index