In [1]:
# implement attention, attention rollout, attention flow, gradient, input x gradient, integrated gradients, deeplift, kernel shap explanations in a differientiable way

import torch

torch.backends.cuda.enable_mem_efficient_sdp(False)  # turn off the mem-efficient kernel
torch.backends.cuda.enable_flash_sdp(False)  

def get_embeddings(model, input_ids, attention_mask=None, token_type_ids=None, position_ids=None):
    if hasattr(model, "distilbert"):
        embeddings = model.distilbert.embeddings(input_ids=input_ids)
    elif hasattr(model, "roberta"):
        embeddings = model.roberta.embeddings(input_ids=input_ids, position_ids=None, token_type_ids=token_type_ids)
    elif hasattr(model, "bert"):
        embeddings = model.bert.embeddings(input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
    else:
        raise ValueError("Model not supported")
    embeddings.requires_grad_(True)
    return embeddings

def model_forward(model, embeddings, attention_mask=None):

    head_mask = model.get_head_mask(None, model.config.num_hidden_layers)
    #head_mask = [None] * self.model.config.num_hidden_layers

    if hasattr(model, "distilbert"):
        encoder_outputs = model.distilbert.transformer(
            embeddings,
            attn_mask=attention_mask,
            head_mask=head_mask,
        )
        hidden_state = encoder_outputs[0]
        pooled_output = hidden_state[:, 0]
        pooled_output = model.pre_classifier(pooled_output)
        pooled_output = model.dropout(pooled_output) 
        logits = model.classifier(pooled_output)

    elif hasattr(model, "roberta"):
        extended_attention_mask = model.get_extended_attention_mask(
            attention_mask, embeddings.shape[:2],
        )

        encoder_outputs = model.roberta.encoder(
            embeddings,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
        )
        sequence_output = encoder_outputs[0]
        #sequence_output = self.model.roberta.pooler(sequence_output) if self.model.roberta.pooler is not None else None
        logits = model.classifier(sequence_output)

    elif hasattr(model, "bert"):
        extended_attention_mask = model.get_extended_attention_mask(
            attention_mask, embeddings.shape[:2], embeddings.device
        )

        encoder_outputs = model.bert.encoder(
            embeddings,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = model.bert.pooler(sequence_output) if model.bert.pooler is not None else None
        pooled_output = model.dropout(pooled_output)
        logits = model.classifier(pooled_output)
    else:
        raise ValueError("Model not supported")

    return logits
    




In [2]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from explainer.Explainer_Encoder import *

model_name = "/scratch/yifwang/new_fairness_x_explainability/new_debiased_models_civil/bert_civil_race/no_debiasing"
tokenizer = AutoTokenizer.from_pretrained(model_name)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AutoModelForSequenceClassification.from_pretrained(model_name, output_attentions=True).to(device)
model.eval()

text1 = "racist comment? of course not! blacks are incapable of being racists! maxine, time for you to swim back to africa! now that ' s not racist... it is simply honest."
text2 = "test test test"

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
explainer = GradientNPropabationExplainer(model, tokenizer, method="Saliency", baseline="pad")
expl = explainer.explain(
    texts = [text1],
    example_indices=[0],
    labels=[1],
    only_predicted_classes=True,
)
expl



{'Saliency_L2': [[{'index': 0,
    'text': "racist comment? of course not! blacks are incapable of being racists! maxine, time for you to swim back to africa! now that ' s not racist... it is simply honest.",
    'true_label': 1,
    'predicted_class': 1,
    'predicted_class_confidence': 0.64649498462677,
    'target_class': 1,
    'target_class_confidence': 0.64649498462677,
    'method': 'Saliency_L2',
    'attribution': [('[CLS]', 0.052127160131931305),
     ('racist', 0.036963626742362976),
     ('comment', 0.05362391099333763),
     ('?', 0.021845147013664246),
     ('of', 0.008625098504126072),
     ('course', 0.014305680058896542),
     ('not', 0.01100395992398262),
     ('!', 0.020923681557178497),
     ('blacks', 0.05848630517721176),
     ('are', 0.023970765992999077),
     ('incapable', 0.10321283340454102),
     ('of', 0.017269112169742584),
     ('being', 0.01954365149140358),
     ('racist', 0.04027537629008293),
     ('##s', 0.01688440516591072),
     ('!', 0.0227751862

In [6]:
attrs = [round(a[1], 4) for a in expl["IntegratedGradients_L2"][0][0]["attribution"]]
print(attrs)
sum(attrs)

[0.02, 0.0196, 0.0238, 0.017, 0.015, 0.0208, 0.0159, 0.0209, 0.0436, 0.0214, 0.0287, 0.0171, 0.0302, 0.0676, 0.0334, 0.0272, 0.0219, 0.0144, 0.0107, 0.0219, 0.0117, 0.0132, 0.0101, 0.0194, 0.0122, 0.0092, 0.0169, 0.0118, 0.0123, 0.0096, 0.0082, 0.0116, 0.0126, 0.0181, 0.0119, 0.0129, 0.0081, 0.0099, 0.0143, 0.0227, 0.032, 0.0139, 0.0957]


0.8894000000000001

In [21]:
def raw_attention_attr(model, input_ids, attention_mask, token_type_ids, position_ids, sensitive_token_mask, target_classes=None, aggregation="L1"):
    outputs = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, output_attentions=True)
    attentions = outputs.attentions  # L x batch x heads x seq x seq
    all_attentions = torch.stack(attentions)  # L x batch x heads x seq x seq
    attention_mask_expanded = attention_mask.unsqueeze(1).unsqueeze(2)
    attention_mask_matrix = attention_mask_expanded * attention_mask_expanded.transpose(-1, -2)
    all_attentions = all_attentions * attention_mask_matrix.unsqueeze(0)
    # normalize attention weights
    attn_weights_sum = all_attentions.sum(dim=-1, keepdim=True) + 1e-9  # Add epsilon to avoid division by zero
    all_attentions = all_attentions / attn_weights_sum
    # mean over heads, mean over layers
    avg_attn_heads = all_attentions.mean(dim=2)
    avg_attn = avg_attn_heads.mean(dim=0)
    attr = avg_attn[:, 0, :]  # token contributions to [CLS] (position 0)

    # take only positions with sensitive token mask
    if sensitive_token_mask is None:
        sensitive_token_mask = attention_mask.clone().to(attention_mask.device)
    attr = attr * sensitive_token_mask

    # attr loss: mean over all sensitive token positions
    attr_loss = attr.sum()
    return attr_loss

In [24]:
inputs = tokenizer([text1, text2], return_tensors="pt", padding=True).to(device)
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]
position_ids = torch.arange(input_ids.size(1), dtype=torch.long, device=device).unsqueeze(0).repeat(input_ids.size(0), 1)
token_type_ids = inputs["token_type_ids"]
baseline_token_ids = tokenizer.pad_token_id  # Assuming 'pad' refers to the padding token ID
#attr, attr_loss = deeplift_attr(model, input_ids, attention_mask, token_type_ids, position_ids, None, baseline_token_ids, None, "L2")
attr_loss = raw_attention_attr(model, input_ids, attention_mask, token_type_ids, position_ids, None, None, "L2")
print(attr_loss)
# print(attr)
# # compute gradient of attr loss wrt model parameters
# model.zero_grad()
# attr_loss.backward()

# for name, param in model.named_parameters():
#     if param.grad is not None:
#         print(f"Gradient for {name}:")
#         print(param.grad)
#     else:
#         print(f"No gradient computed for {name}")

tensor(2.0000, device='cuda:0', grad_fn=<SumBackward0>)


In [23]:
from captum.attr import DeepLift, IntegratedGradients, Saliency, InputXGradient, KernelShap

class BertEmbeddingModelWrapper(torch.nn.Module):
    def __init__(self, model):
        super(BertEmbeddingModelWrapper, self).__init__()
        self.model = model

    def forward(self, embeddings, attention_mask=None):

        head_mask = self.model.get_head_mask(None, self.model.config.num_hidden_layers)
        #head_mask = [None] * self.model.config.num_hidden_layers

        if hasattr(self.model, "distilbert"):
            encoder_outputs = self.model.distilbert.transformer(
                embeddings,
                attn_mask=attention_mask,
                head_mask=head_mask,
            )
            hidden_state = encoder_outputs[0]
            pooled_output = hidden_state[:, 0]
            pooled_output = self.model.pre_classifier(pooled_output)
            if not hasattr(self.model, "bcos") or not self.model.bcos:
                pooled_output = torch.nn.ReLU()(pooled_output)
            pooled_output = self.model.dropout(pooled_output) 
            logits = self.model.classifier(pooled_output)

        elif hasattr(self.model, "roberta"):
            extended_attention_mask = self.model.get_extended_attention_mask(
                attention_mask, embeddings.shape[:2],
            )

            encoder_outputs = self.model.roberta.encoder(
                embeddings,
                attention_mask=extended_attention_mask,
                head_mask=head_mask,
            )
            sequence_output = encoder_outputs[0]
            #sequence_output = self.model.roberta.pooler(sequence_output) if self.model.roberta.pooler is not None else None
            logits = self.model.classifier(sequence_output)

        elif hasattr(self.model, "bert"):
            extended_attention_mask = self.model.get_extended_attention_mask(
                attention_mask, embeddings.shape[:2], embeddings.device
            )

            encoder_outputs = self.model.bert.encoder(
                embeddings,
                attention_mask=extended_attention_mask,
                head_mask=head_mask,
            )
            sequence_output = encoder_outputs[0]
            pooled_output = self.model.bert.pooler(sequence_output) if self.model.bert.pooler is not None else None
            pooled_output = self.model.dropout(pooled_output)
            logits = self.model.classifier(pooled_output)
        else:
            raise ValueError("Model not supported")

        return logits
    
model_wrapper = BertEmbeddingModelWrapper(model)
model_wrapper.to(device)
model_wrapper.eval()
embeddings = get_embeddings(model, input_ids, attention_mask, token_type_ids, position_ids)
embeddings.requires_grad_(True)
explainer = InputXGradient(model_wrapper)

attribution = explainer.attribute(
                    inputs=(embeddings),
                    target=[1],
                    additional_forward_args=(attention_mask,),
                )
attr = torch.norm(attribution, dim=-1)
print(attr)
attr_loss = attr.sum()
print(attr_loss)
model_wrapper.zero_grad()
model.zero_grad()
attr_loss.backward()

for name, param in model_wrapper.named_parameters():
    if param.grad is not None:
        print(f"Gradient for {name}:")
        print(param.grad)
    else:
        print(f"No gradient computed for {name}")

tensor([[0.0196, 0.0218, 0.0300, 0.0121, 0.0042, 0.0097, 0.0063, 0.0116, 0.0378,
         0.0121, 0.0608, 0.0115, 0.0114, 0.0235, 0.0090, 0.0124, 0.0235, 0.0092,
         0.0111, 0.0079, 0.0044, 0.0058, 0.0049, 0.0225, 0.0057, 0.0034, 0.0107,
         0.0086, 0.0055, 0.0052, 0.0023, 0.0040, 0.0049, 0.0155, 0.0062, 0.0038,
         0.0034, 0.0053, 0.0063, 0.0139, 0.0329, 0.0076, 0.0129]],
       device='cuda:0', grad_fn=<LinalgVectorNormBackward0>)
tensor(0.5412, device='cuda:0', grad_fn=<SumBackward0>)
Gradient for model.bert.embeddings.word_embeddings.weight:
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0')
Gradient for model.bert.embeddings.position_embeddings.weight:
tensor([[ 1.4060e-04, -5.3115e-04, -2.9448e-05,  ...,  1.0426e-05,
          4.5507e-05,  3.5155e-0

