# TextNum Transformer Explainability 
-----------------------------------

Explainability analysis of the TextNum transformer's output using Layer-wise Relevance Propagation (LRP): https://arxiv.org/abs/2012.09838

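Reference implementation (BERT explainability notebook from the Transformer-Explainability repository): https://colab.research.google.com/github/hila-chefer/Transformer-Explainability/blob/main/BERT_explainability.ipynb#scrollTo=4-XGl_Zw6Aht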


In [None]:
# !git clone https://github.com/hila-chefer/Transformer-Explainability.git

# # import os
# # os.chdir(f'./Transformer-Explainability')
# import sys
# sys.path.append('./Transformer-Explainability')

# # !pip install -r requirements.txt
# # !pip install captum

# Load the model 

In [None]:
%load_ext autoreload
from researchpkg.industry_classification.config import *
from researchpkg.industry_classification.utils.experiment_utils import (
    ExperimentUtils,
)

EXPERIMENT_NAME = "lrp_transformer_encoder_finlangl0_row933_h8_l4_e64_f256_d0.0_proj768___sic1_max_tag_depth_5"
EXPERIMENT_DIR =  os.path.join(LOGS_DIR,"experiments_count30_sic1agg_including_is_2023",EXPERIMENT_NAME)                               
assert os.path.exists(EXPERIMENT_DIR), "Provided experiment dir does not exists"

experiment_config = ExperimentUtils.load_experiment_data(EXPERIMENT_DIR)
model_config = experiment_config["model_config"]
model_config

In [None]:
%autoreload
from researchpkg.industry_classification.models.transformers.textnum_transformer import TextNumTransformerForClassification
model = TextNumTransformerForClassification(
        n_accounts= model_config["n_accounts"],
        pretrained_model=model_config["pretrained_model"],
        n_head=model_config["n_head"],
        n_layers=model_config["n_layers"],
        n_classes=model_config["n_classes"],
        emb_dim=model_config["emb_dim"],
        ffw_dim=model_config["ffw_dim"],
        learning_rate=model_config["learning_rate"],
        class_names=model_config["class_names"],
        dropout_rate=model_config["dropout_rate"],
        trf_trainable_layers=model_config["trf_trainable_layers"],
        use_lrp_modules = model_config["use_lrp_modules"],
    ).cuda()


#Load the best model
print("Loading the best model")
ckpt_file = ExperimentUtils.get_best_model( os.path.basename(EXPERIMENT_DIR),os.path.dirname(EXPERIMENT_DIR),)["path"]
ckpt_file = os.path.join(ROOT_DIR,ckpt_file)
assert os.path.exists(ckpt_file), f"{ckpt_file} do not exists"
model  = model.load_from_checkpoint(ckpt_file)

In [None]:
# Display the loaded model (notebook renders the last expression of the cell).
model

In [None]:
 ## Load dataset

In [None]:
%load_ext autoreload
%autoreload
from researchpkg.industry_classification.dataset.sec_textnum_transformer_datamodule import SecTrfClassificationDataModule


dataset_config = ExperimentUtils.load_experiment_data(EXPERIMENT_DIR)[
    "dataset_config"
]
dataset_dir = os.path.join(SEC_ROOT_DATA_DIR,"count30_sic1agg_including_is_2023")d
# dataset_dir ="/home/test/servlilleg5k//researchpkg/industry_classification/data/sec_data_v2/count30_sic1agg_including_is_2023"

datamodule = SecTrfClassificationDataModule(
    dataset_dir,
    batch_size=32,
    num_workers=2,
    sic_digits=1,
    tokenizer=model.tokenizer,
    use_change=True,
    load_in_memory=False,
    max_desc_len=32,
    max_tags=dataset_config["max_nb_tag_per_sheet"],
    balance_sampling=False,
    max_tag_depth=5
)

# LRP Explainer

## Prepare the TextNum transformer for LRP


In [None]:
## Monkey patch the generate LRP method to support textnum transformer interface.
%autoreload
from researchpkg.industry_classification.models.transformers.bert_explainability_modules.BERT_explainability.modules.BERT.ExplanationGenerator import Generator, compute_rollout_attention
import numpy as np, torch 
import types
explanations = Generator(model)

def generate_LRP_for_textnum(self, textnum_enc, input_attn_mask,
                             index=None, start_layer=1):
    """Compute an LRP relevance map for one sample; replaces ``Generator.generate_LRP``.

    Instead of taking raw model inputs, this version runs
    ``self.model.forward_only_textnum`` on an already computed TextNum encoding.
    It back-propagates the score of the chosen class and calls ``self.model.relprop``.
    It then combines the per-layer gradient-weighted attention relevance with
    ``compute_rollout_attention``.

    :param textnum_enc: encoding from ``model.forward(..., return_text_num_enc=True)``.
    :param input_attn_mask: attention mask matching ``textnum_enc``.
    :param index: class to explain; defaults to the arg-max of the model output.
    :param start_layer: first layer included in the rollout.
    :return: ``rollout[:, 0]``, with its entry for position 0 replaced by the row minimum.

    Only batch element 0 is explained (the one-hot target has a single row and
    ``cam[0]`` / ``grad[0]`` are used), so pass a batch of one.
    """
    output = self.model.forward_only_textnum(textnum_enc, input_attn_mask)
    kwargs = {"alpha": 1}

    if index is None:
        index = np.argmax(output.cpu().data.numpy(), axis=-1)

    # One-hot selector of the explained class. Its dot product with the output is the
    # scalar back-propagated so that get_attn_gradients() below has gradients to return.
    one_hot_vector = np.zeros((1, output.size()[-1]), dtype=np.float32)
    one_hot_vector[0, index] = 1
    one_hot = torch.from_numpy(one_hot_vector).requires_grad_(True)
    # Follow the output's device instead of hard-coding CUDA.
    score = torch.sum(one_hot.to(output.device) * output)

    self.model.zero_grad()
    score.backward(retain_graph=True)

    self.model.relprop(torch.tensor(one_hot_vector).to(textnum_enc.device), **kwargs)

    cams = []
    for blk in self.model.text_num_transformer.layer:
        grad = blk.attention.self.get_attn_gradients()
        cam = blk.attention.self.get_attn_cam()
        # Keep batch element 0 and reshape to (-1, seq, seq); presumably (heads, seq, seq).
        cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
        grad = grad[0].reshape(-1, grad.shape[-1], grad.shape[-1])
        # Gradient-weighted relevance, positive part only, averaged over the first axis.
        cam = (grad * cam).clamp(min=0).mean(dim=0)
        cams.append(cam.unsqueeze(0))
    rollout = compute_rollout_attention(cams, start_layer=start_layer)
    # Replace position 0's relevance to itself with the minimum of its row.
    rollout[:, 0, 0] = rollout[:, 0].min()
    return rollout[:, 0]


explanations.generate_LRP = types.MethodType(generate_LRP_for_textnum, explanations)

## Explanation generation function

In [None]:
from captum.attr import visualization

In [None]:
# Map each test-set file's basename to its dataset index, so a filing can be looked up
# by file name (if two files share a basename, the later index wins).
filename_to_idx = {
    os.path.basename(path): idx
    for idx, path in enumerate(datamodule.test_dataset.data_files)
}

In [None]:
# Test-set index of one specific filing (KeyError if that file is not in the test split).
filename_to_idx["0000010254-23-000058_2023_q1.csv"]

In [None]:
# Path of the first test-set data file (useful to see the file naming scheme).
datamodule.test_dataset.data_files[0]

In [None]:
from typing import Iterable
from IPython.display import display, HTML

import re
def split_by_capital_letter(tag) -> str:
    """
    Insert a space before every upper-case ASCII letter of a tag and trim the ends.

    Example: "NetIncomeLoss" -> "Net Income Loss".
    :param tag: The tag to separate
    :return: The separated tag
    """
    spaced = re.sub("[A-Z]", lambda match: " " + match.group(0), tag)
    return spaced.strip()

def visualize_text_multirow(
    account_datarecord: visualization.VisualizationDataRecord,
    amount_datarecord: visualization.VisualizationDataRecord,
    legend: bool = True
) -> "HTML":  # In quotes because this type doesn't exist in standalone mode
    """Display a prediction summary and a per-tag attribution table as HTML.

    The summary table reads ``true_class``, ``pred_class``, ``pred_prob`` and
    ``attr_score`` from ``account_datarecord``. The tag table has one row per tag:
    - the tag name, from ``account_datarecord.raw_input_ids``, split on capital letters;
    - the amount, from ``amount_datarecord.raw_input_ids``, formatted as whole dollars;
    - a green background whose strength follows ``amount_datarecord.word_attributions``.
    Rows stop at the shortest of those three sequences.

    :param account_datarecord: record whose ``raw_input_ids`` are tag names.
    :param amount_datarecord: record whose ``raw_input_ids`` are numeric amounts and
        whose ``word_attributions`` colour the rows.
    :param legend: if True, show a Neutral/Positive colour legend above the tag table.
    :return: the ``HTML`` object that was passed to ``display``.
    :raises ValueError: if an attribution is negative (only values >= 0 are coloured).
    """
    # Escape text so names containing &, < or > render literally.
    from html import escape

    def _get_row_color(attr):
        # Clip values to prevent CSS errors (values should be in [-1, 1]).
        attr = max(-1, min(1, attr))
        if attr < 0:
            raise ValueError("Only positive values are supported")
        # Lightness goes from 100% (white, attr=0) down to 40% (strong green, attr=1).
        return "hsl({}, {}%, {}%)".format(120, 75, 100 - int(60 * attr))

    cell = "border: 1px solid black;"
    dom = []

    # Summary table: ground truth, prediction (probability) and attribution score.
    # captum's format_classname() returns a whole <td>, so it is not used inside a <td>.
    true_label = escape(str(account_datarecord.true_class))
    pred_label = (
        f"{escape(str(account_datarecord.pred_class))} "
        f"({account_datarecord.pred_prob:.2f})"
    )
    dom += [
        f"<table style='{cell} border-collapse: collapse;'>",
        "<tr>",
        f"<th style='{cell} padding: 8px;'>Ground Truth</th>",
        f"<th style='{cell} padding: 8px;'>Predicted</th>",
        f"<th style='{cell} padding: 8px;'>Attribution Score</th>",
        "</tr>",
        "<tr style='background-color: white;'>",
        f"<td style='{cell} padding: 2px;'><b>{true_label}</b></td>",
        f"<td style='{cell} padding: 2px;'>{pred_label}</td>",
        f"<td style='{cell} padding: 2px;'>{account_datarecord.attr_score:.2f}</td>",
        "</tr>",
        "</table>",
    ]

    # The legend goes between the tables: a <div> directly inside <table> is invalid HTML.
    if legend:
        dom.append(
            '<div style="border-top: 1px solid; margin-top: 5px; '
            'padding-top: 5px; display: inline-block">'
        )
        dom.append("<b>Legend: </b>")
        for value, label in ((0, "Neutral"), (1, "Positive")):
            color = _get_row_color(value)
            dom.append(
                '<span style="display: inline-block; width: 10px; height: 10px; '
                f'border: 1px solid; background-color: {color};"></span> {label}  '
            )
        dom.append("</div>")

    # Tag table: one row per tag, coloured by its attribution.
    header_style = f"style='{cell} padding: 8px; font-size:19px; font-weight:bold'"
    dom.append(
        "<table style='table-layout: fixed; width: 50%; border: 3px solid black; "
        "border-collapse: collapse;'>"
    )
    dom.append(f"<tr><th {header_style}>Tag</th><th {header_style}>Amount</th></tr>")

    rows = zip(
        account_datarecord.raw_input_ids,
        amount_datarecord.raw_input_ids,
        amount_datarecord.word_attributions,
    )
    for tag, amount, score in rows:
        tag_text = escape(split_by_capital_letter(tag))
        amount_text = "${:,.0f}".format(amount)
        dom.append(
            f"<tr style='background-color: {_get_row_color(score)};'>"
            f"<td style='word-wrap: break-word; {cell} padding:8px; font-size:18px; "
            f"height:30px'>{tag_text}</td>"
            f"<td style='{cell} padding:4px; font-size:18px;'>{amount_text}</td>"
            "</tr>"
        )
    dom.append("</table>")

    rendered = HTML("".join(dom))
    display(rendered)
    return rendered




In [None]:
class_names = model_config["class_names"]


def explain_sample(sample):
    """Run the classifier and the LRP explainer on a single dataset sample.

    The model is fed the sample exactly as the dataset returns it. The tags are then
    sorted alphabetically for display, and every per-tag result is re-ordered the same
    way, so ``expl[j]``, ``tags[j]`` and ``net_changes[j]`` all refer to the same tag.

    :param sample: dataset item. Reads ``"tags"`` (a ";"-separated string),
        ``"input_desc"``, ``"input_net_change"``, ``"input_attn_mask"`` and ``"target"``.
    :return: ``(expl, output, y_true, y_pred, textnum_pseudo_tokens, tags, net_changes)``:
        - ``expl``: per-tag relevance, min-max scaled to [0, 1] (all zeros if every
          score is equal);
        - ``output``: softmax class probabilities;
        - ``y_true``, ``y_pred``: true and predicted class indices;
        - ``textnum_pseudo_tokens``: ``"<tag> : <amount>"`` strings;
        - ``tags``: the sorted tags;
        - ``net_changes``: amounts passed through ``revert_log_scaling_transform``.
    """
    # Sort tags alphabetically for display, remembering each one's original position.
    tags = sample["tags"].split(";")
    sorted_indices = sorted(range(len(tags)), key=tags.__getitem__)
    tags = [tags[i] for i in sorted_indices]

    # The model sees the sample in the dataset's own (padded) order; only the per-tag
    # results below are re-ordered with sorted_indices.
    input_desc = sample["input_desc"].to(model.device).unsqueeze(0)
    input_net_change = sample["input_net_change"].to(model.device).unsqueeze(0)
    input_attn_mask = sample["input_attn_mask"].to(model.device).unsqueeze(0)

    y_true = sample["target"].item()
    textnum_enc = model.forward(
        input_desc=input_desc,
        input_net_change=input_net_change,
        input_attn_mask=input_attn_mask,
        return_text_num_enc=True,
    )

    # Relevance comes back in input order. Keep one score per tag, then align the
    # scores with the sorted tags and amounts.
    expl = explanations.generate_LRP(
        textnum_enc=textnum_enc, input_attn_mask=input_attn_mask, start_layer=3
    ).squeeze()
    expl = expl[: len(tags)][sorted_indices]

    # Min-max normalise to [0, 1]; a constant map would otherwise give 0/0 = NaN.
    expl_min, expl_max = expl.min(), expl.max()
    if expl_max > expl_min:
        expl = (expl - expl_min) / (expl_max - expl_min)
    else:
        expl = torch.zeros_like(expl)

    output = torch.nn.functional.softmax(
        model.forward(
            input_desc=input_desc,
            input_net_change=input_net_change,
            input_attn_mask=input_attn_mask,
        ),
        dim=-1,
    )
    y_pred = output.argmax(dim=-1).item()

    # Amounts for the sorted tags, passed through the dataset's revert_log_scaling_transform.
    net_changes = sample["input_net_change"].cpu().squeeze()[sorted_indices]
    net_changes = datamodule.test_dataset.revert_log_scaling_transform(net_changes)

    textnum_pseudo_tokens = [
        f"{tag} : {change.item()}" for tag, change in zip(tags, net_changes)
    ]

    return expl, output.squeeze(), y_true, y_pred, textnum_pseudo_tokens, tags, net_changes


def visualize_explained_sample(i):
    """Explain test-set sample ``i``, print its labels and render the attribution table.

    Two captum records share the same attributions and prediction metadata. One carries
    the tag names and the other carries the amounts, because ``visualize_text_multirow``
    reads the tag column from its first record and the amount column from its second.
    """
    sample_data = datamodule.test_dataset[i]
    (expl, output, y_true, y_pred,
     _pseudo_tokens, tags, net_changes) = explain_sample(sample_data)

    predicted_name = class_names[y_pred]
    true_name = class_names[y_true]
    amounts = [change.item() for change in net_changes[: len(tags)]]

    def _make_record(raw_inputs):
        # Positional arguments, in captum's order: attributions, predicted probability,
        # predicted class, true class, attributed class, attribution score, raw inputs,
        # convergence score.
        return visualization.VisualizationDataRecord(
            expl,
            output[y_pred],
            predicted_name,
            true_name,
            y_pred,
            1,
            raw_inputs,
            1,
        )

    account_record = _make_record(tags)
    amount_record = _make_record(amounts)
    print("y_pred:", predicted_name, "\ny_true:", true_name)

    visualize_text_multirow(account_record, amount_record, legend=True)
    






In [None]:
# Explain one filing from the test split, looked up by its CSV file name.
filename = "0001410578-23-000236_2023_q1.csv"
sample = filename_to_idx[filename]  # integer dataset index, not the sample dict itself
visualize_explained_sample(sample)