# FinBERT Sequence Classification Explanation using Captum

Based on the following notebook: <https://gist.github.com/davidefiocco/3e1a0ed030792230a33c726c61f6b3a5>

We can use token-level explanation methods like **Integrated Gradients (IG)** for fairness evaluation. An intuitive approach is analyzing how much each token contributes to a model's decision, particularly in protected characteristics. 

1. Identify Sensitive Tokens: Use IG to compute attribution scores for each token in input text. Focus on tokens related to protected characteristics (e.g., gendered words like "he/she," names associated with specific races, etc.).
2. Measure Disparity in Attributions: Compare the IG scores for sensitive tokens across different demographic groups. If certain tokens consistently have higher attributions in one group but not another, it may indicate biased decision-making.
3. Evaluate Decision Flip with Token Removal: Remove or replace high-attribution sensitive tokens and re-run the model. If the decision significantly changes, the model is heavily relying on those tokens, which may signal bias.
4. Fairness Metrics from Attributions:
   - **Disparate Impact**: Compare average attribution scores for different groups.
   - **Counterfactual Fairness**: Assess if similar inputs with different demographic tokens lead to different outcomes.
   - **Bias Amplification**: If sensitive tokens have consistently high attributions, check whether they are reinforcing stereotypes.

In this notebook, we use **captum library** to run IG on FinBERT model.

In [1]:
from transformers import BertTokenizer, BertForSequenceClassification
from captum.attr import visualization as viz
from captum.attr import LayerIntegratedGradients
import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [2]:
model = BertForSequenceClassification.from_pretrained('yiyanghkust/finbert-tone',num_labels=3)
model.to(device)
model.eval()
model.zero_grad()

tokenizer = BertTokenizer.from_pretrained('yiyanghkust/finbert-tone')

In [3]:
def predict(inputs):
    return model(inputs)[0]

ref_token_id = tokenizer.pad_token_id # A token used for generating token reference
sep_token_id = tokenizer.sep_token_id # A token used as a separator between question and text and it is also added to the end of the text.
cls_token_id = tokenizer.cls_token_id # A token used for prepending to the concatenated question-text word sequence

def construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id):

    text_ids = tokenizer.encode(text, add_special_tokens=False)
    # construct input token ids
    input_ids = [cls_token_id] + text_ids + [sep_token_id]
    # construct reference token ids 
    ref_input_ids = [cls_token_id] + [ref_token_id] * len(text_ids) + [sep_token_id]

    return torch.tensor([input_ids], device=device), torch.tensor([ref_input_ids], device=device), len(text_ids)

def construct_input_ref_token_type_pair(input_ids, sep_ind=0):
    seq_len = input_ids.size(1)
    token_type_ids = torch.tensor([[0 if i <= sep_ind else 1 for i in range(seq_len)]], device=device)
    ref_token_type_ids = torch.zeros_like(token_type_ids, device=device)# * -1
    return token_type_ids, ref_token_type_ids

def construct_input_ref_pos_id_pair(input_ids):
    seq_length = input_ids.size(1)
    position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
    # we could potentially also use random permutation with `torch.randperm(seq_length, device=device)`
    ref_position_ids = torch.zeros(seq_length, dtype=torch.long, device=device)

    position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
    ref_position_ids = ref_position_ids.unsqueeze(0).expand_as(input_ids)
    return position_ids, ref_position_ids
    
def construct_attention_mask(input_ids):
    return torch.ones_like(input_ids)

def custom_forward(inputs):
    preds = predict(inputs)
    return torch.softmax(preds, dim = 1)[0][0].unsqueeze(-1)

def summarize_attributions(attributions):
    attributions = attributions.sum(dim=-1).squeeze(0)
    attributions = attributions / torch.norm(attributions)
    return attributions

In [4]:
lig = LayerIntegratedGradients(custom_forward, model.bert.embeddings)

# text = "These tests do not work as expected."
text = "sugar mills have paid 30% of their total cane dues. the government has announced a soft loan of Rs 4,000 crore plus a subsidy of Rs 4.50/quintal to clear the dues of the farmers by November 30. the total cane crushed this season was 1,111.90 tonnes. the total cane arrears for 2017-18 now stand at Rs 9,770 crores."

input_ids, ref_input_ids, sep_id = construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id)
token_type_ids, ref_token_type_ids = construct_input_ref_token_type_pair(input_ids, sep_id)
position_ids, ref_position_ids = construct_input_ref_pos_id_pair(input_ids)
attention_mask = construct_attention_mask(input_ids)

indices = input_ids[0].detach().tolist()
all_tokens = tokenizer.convert_ids_to_tokens(indices)

attributions, delta = lig.attribute(inputs=input_ids,
                                    baselines=ref_input_ids,
                                    return_convergence_delta=True)


In [5]:
score = predict(input_ids)

print('Question: ', text)
print('Predicted Answer: ' + str(torch.argmax(score[0]).numpy()) + ', prob ungrammatical: ' + str(torch.softmax(score, dim = 1)[0][0].detach().numpy()))

Question:  sugar mills have paid 30% of their total cane dues. the government has announced a soft loan of Rs 4,000 crore plus a subsidy of Rs 4.50/quintal to clear the dues of the farmers by November 30. the total cane crushed this season was 1,111.90 tonnes. the total cane arrears for 2017-18 now stand at Rs 9,770 crores.
Predicted Answer: 0, prob ungrammatical: 0.99995935


In [6]:
attributions_sum = summarize_attributions(attributions)
# storing couple samples in an array for visualization purposes
record = viz.VisualizationDataRecord(
                        attributions_sum,
                        torch.softmax(score, dim = 1)[0][0],
                        torch.argmax(torch.softmax(score, dim = 1)[0]),
                        1, # Positive Sentiment
                        text,
                        attributions_sum.sum(),       
                        all_tokens,
                        delta)

print('\033[1m', 'Visualization For Score', '\033[0m')
viz.visualize_text([record])

[1m Visualization For Score [0m


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
1.0,0 (1.00),"sugar mills have paid 30% of their total cane dues. the government has announced a soft loan of Rs 4,000 crore plus a subsidy of Rs 4.50/quintal to clear the dues of the farmers by November 30. the total cane crushed this season was 1,111.90 tonnes. the total cane arrears for 2017-18 now stand at Rs 9,770 crores.",1.2,"[CLS] sugar mills have paid 30 % of their total can ##e due ##s . the government has announced a soft loan of rs 4 , 000 cro ##re plus a subsidy of rs 4 . 50 / quint ##al to clear the due ##s of the farmers by november 30 . the total can ##e crush ##ed this season was 1 , 111 . 90 tonnes . the total can ##e arrears for 2017 - 18 now stand at rs 9 , 770 cro ##res . [SEP]"
,,,,


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
1.0,0 (1.00),"sugar mills have paid 30% of their total cane dues. the government has announced a soft loan of Rs 4,000 crore plus a subsidy of Rs 4.50/quintal to clear the dues of the farmers by November 30. the total cane crushed this season was 1,111.90 tonnes. the total cane arrears for 2017-18 now stand at Rs 9,770 crores.",1.2,"[CLS] sugar mills have paid 30 % of their total can ##e due ##s . the government has announced a soft loan of rs 4 , 000 cro ##re plus a subsidy of rs 4 . 50 / quint ##al to clear the due ##s of the farmers by november 30 . the total can ##e crush ##ed this season was 1 , 111 . 90 tonnes . the total can ##e arrears for 2017 - 18 now stand at rs 9 , 770 cro ##res . [SEP]"
,,,,


## Save Selected Samples to Fairness Log

In [7]:
import sys
sys.path.append('../../')
from faid import logging as faidlog

experiment_name = "captum-test"
faidlog.init_log()
ctx = faidlog.FairnessExperimentRecord(name=experiment_name)

[93mModel log file already exists.  Logging will be appended to the existing file.[0m
[93mData log file already exists. Logging will be appended to the existing file.[0m
[93mRisks log file already exists. Logging will be appended to the existing file.[0m
[93mTransparency log file already exists. Logging will be appended to the existing file.[0m


In [8]:
ctx.add_entry(record)

Added captum_records to project metadata under ['model'] and log updated


In [10]:
from faid.report import generate_experiment_overview_report
generate_experiment_overview_report(ctx.to_dict())

True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
1.0,0 (1.00),"sugar mills have paid 30% of their total cane dues. the government has announced a soft loan of Rs 4,000 crore plus a subsidy of Rs 4.50/quintal to clear the dues of the farmers by November 30. the total cane crushed this season was 1,111.90 tonnes. the total cane arrears for 2017-18 now stand at Rs 9,770 crores.",1.2,"[CLS] sugar mills have paid 30 % of their total can ##e due ##s . the government has announced a soft loan of rs 4 , 000 cro ##re plus a subsidy of rs 4 . 50 / quint ##al to clear the due ##s of the farmers by november 30 . the total can ##e crush ##ed this season was 1 , 111 . 90 tonnes . the total can ##e arrears for 2017 - 18 now stand at rs 9 , 770 cro ##res . [SEP]"
,,,,
