# Install requirements / Clone repository

In [None]:
from IPython.display import clear_output

# Fetch the DecompX sources; later cells import from DecompX.src.*.
! git clone "https://github.com/mohsenfayyaz/DecompX"
! pip install datasets
# Pinned transformers version. NOTE(review): presumably required for
# compatibility with DecompX's modeling_bert / modeling_roberta — confirm.
! pip install transformers==4.18.0
# Optional: presumably provides `pynvml`, used only by the commented-out GPU
# monitoring code in `reporter`.
# ! pip install nvidia-ml-py3


# Hide the clone/install logs from the cell output.
clear_output()
print("Done!")

In [None]:
import torch
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import matplotlib
from IPython.display import display, HTML
from transformers import AutoTokenizer
from DecompX.src.decompx_utils import DecompXConfig
from DecompX.src.modeling_bert import BertForSequenceClassification
from DecompX.src.modeling_roberta import RobertaForSequenceClassification
import seaborn as sns
import matplotlib.colors as mcolors

# Utils

In [None]:
# Attribution settings keyed by method name; `evaluate` forwards
# CONFIGS["DecompX"] to the model as `decompx_config`.
CONFIGS = dict(
    DecompX=DecompXConfig(
        # Which parts of the network are decomposed, and how.
        include_biases=True,
        bias_decomp_type="absdot",
        include_LN1=True,
        include_FFN=True,
        FFN_approx_type="GeLU_ZO",
        include_LN2=True,
        aggregation="vector",
        include_classifier_w_pooler=True,
        tanh_approx_type="ZO",
        # Which outputs are returned.
        output_all_layers=True,
        output_aggregated="norm",
        output_pooler="norm",
        output_classifier=True,
        output_attention=None,
        output_res1=None,
        output_LN1=None,
        output_FFN=None,
        output_res2=None,
        output_encoder=None,
    ),
)

In [None]:
import time
import psutil
# import pynvml  # GPU monitoring library

def reporter(prompt):
    """Build a decorator that prints the wall-clock time and RSS change of each call.

    Args:
        prompt: Label printed in the report header, e.g. ``"Model initializer"``.

    Returns:
        A decorator. The wrapped function's return value is passed through
        unchanged; the report is printed after the call returns (nothing is
        printed if the call raises).

    Note:
        "CPU memory used" is the difference in the process's resident set size
        (``psutil.Process().memory_info().rss``) measured before and after the
        call. It is not the peak usage during the call and can be negative.
        GPU memory reporting (via pynvml) was prototyped here but is disabled.
    """
    import functools

    def decorator(func):
        # Preserve the wrapped function's __name__/__doc__/__wrapped__;
        # otherwise every decorated function shows up as "wrapper".
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            process = psutil.Process()
            start_memory_mb = process.memory_info().rss / (1024 ** 2)  # bytes -> MB
            # perf_counter is monotonic and high-resolution, so the duration is
            # not skewed by system clock adjustments (unlike time.time()).
            start_time = time.perf_counter()

            result = func(*args, **kwargs)

            execution_time_ms = (time.perf_counter() - start_time) * 1000
            memory_delta_mb = process.memory_info().rss / (1024 ** 2) - start_memory_mb

            print(f"{prompt} report:")
            print(f"- execution time: {execution_time_ms:.2f} (ms)")
            print(f"- CPU memory used: {memory_delta_mb:.2f} (MB)")
            return result
        return wrapper
    return decorator


In [None]:
@reporter(prompt="Model initializer")
def get_model_and_tokenizer(model="WillHeld/roberta-base-sst2"):
    """Load a DecompX-instrumented sequence classifier and its tokenizer.

    Args:
        model: Hugging Face model id or local checkpoint path. The architecture
            is chosen by substring: names containing "roberta" use the DecompX
            RoBERTa class; otherwise names containing "bert" use the DecompX
            BERT class.

    Returns:
        ``(model, tokenizer)``. The model is not moved to any device here.

    Raises:
        NotImplementedError: if the name matches neither architecture. This is
            a subclass of ``Exception``, so existing handlers still catch it.
    """
    # "roberta" must be tested first: every RoBERTa name also contains "bert".
    if "roberta" in model:
        model_cls = RobertaForSequenceClassification
    elif "bert" in model:
        model_cls = BertForSequenceClassification
    else:
        # Fail before downloading anything for an unsupported model.
        raise NotImplementedError(f"Not implemented model: {model}")
    tokenizer = AutoTokenizer.from_pretrained(model)
    return model_cls.from_pretrained(model), tokenizer


In [None]:
@torch.no_grad()
def evaluate(model, tokenizer, premise, hypothesis, configs):
    """Run one DecompX forward pass over a batch of sentence pairs.

    Gradients are disabled by the ``torch.no_grad`` decorator. The model is
    switched to eval mode before the forward pass.

    Args:
        model: DecompX-instrumented sequence classifier (the
            ``DecompX.src.modeling_*`` classes), called with ``decompx_config``.
        tokenizer: Tokenizer used to encode the ``(premise, hypothesis)`` pairs
            (padded to the longest pair, returned as PyTorch tensors).
        premise: First segment(s) of each pair.
        hypothesis: Second segment(s) of each pair.
        configs: Mapping whose ``"DecompX"`` entry is passed as ``decompx_config``.

    Returns:
        Tuple ``(logits, hidden_states, decompx_last_layer_outputs,
        decompx_all_layers_outputs, batch_lengths, tokenized_sentence)``.
        ``batch_lengths`` is the per-example sum of the attention mask.
    """
    encoded = tokenizer(premise, hypothesis, return_tensors="pt", padding=True)
    lengths = encoded["attention_mask"].sum(dim=-1)
    model.eval()
    outputs = model(
        **encoded,
        output_attentions=False,
        return_dict=False,
        output_hidden_states=True,
        decompx_config=configs["DecompX"],
    )
    logits, hidden_states, last_layer, all_layers = outputs
    return logits, hidden_states, last_layer, all_layers, lengths, encoded

In [None]:
def _stack_numpy(tensors):
    """Move each tensor in ``tensors`` to host memory and stack them into one numpy array."""
    return np.array([t.cpu().detach().numpy() for t in tensors])


@reporter(prompt="DecompX calculator")
def compute_decompX(model, tokenizer, premise, hypothesis, df=True):
    """Compute DecompX token attributions for a batch of sentence pairs.

    The model is moved to CPU (``nn.Module.to`` moves the caller's model in
    place) and run through ``evaluate`` with ``CONFIGS``. Every per-example
    array is trimmed to that example's attention-mask length. This assumes
    right padding (the tokenizer default for BERT/RoBERTa) — TODO confirm for
    other tokenizers.

    Args:
        model: DecompX-instrumented classifier (see ``get_model_and_tokenizer``).
        tokenizer: Matching tokenizer.
        premise: Batch of first segments (list of str).
        hypothesis: Batch of second segments (list of str).
        df: If True, return a ``pandas.DataFrame`` with one row per example.
            Otherwise return the underlying dict of per-example lists.

    Returns:
        Columns/keys, one entry per example, where ``len`` is the example's
        unpadded length:
          - ``tokens``: token strings
          - ``logits``: class logits (list)
          - ``cls``: last-layer hidden state of token 0 (list)
          - ``importance_last_layer_aggregated``: array (len, len)
          - ``importance_last_layer_pooler``: array (len,)
          - ``importance_last_layer_classifier``: array (len, classes)
          - ``importance_all_layers_aggregated``: array (layers, len, len)
    """
    model = model.to("cpu")
    (logits, hidden_states,
     last_layer, all_layers,
     batch_lengths, tokenized_sentence) = evaluate(model, tokenizer, premise, hypothesis, CONFIGS)

    lengths = [int(n) for n in batch_lengths]
    batch_size = len(lengths)
    seq_len = tokenized_sentence["input_ids"].shape[-1]  # padded length

    decompx_outputs = {
        # Iterate the batch, not `premise`: len() of a bare string counts characters.
        "tokens": [
            tokenizer.convert_ids_to_tokens(ids[:n])
            for ids, n in zip(tokenized_sentence["input_ids"], lengths)
        ],
        "logits": logits.cpu().detach().numpy().tolist(),  # (batch, classes)
        # Last layer, token 0 only -> (batch, emb_dim)
        "cls": hidden_states[-1][:, 0, :].cpu().detach().numpy().tolist(),
    }

    # Explicit reshapes instead of .squeeze(). Squeeze also removed the batch
    # axis when batch_size == 1, so the per-example slicing below walked the
    # sequence axis instead (and the all-layers transpose failed).
    aggregated = _stack_numpy(last_layer.aggregated).reshape(batch_size, seq_len, seq_len)
    pooler = _stack_numpy(last_layer.pooler).reshape(batch_size, seq_len)
    classifier = _stack_numpy(last_layer.classifier).reshape(batch_size, seq_len, -1)
    all_layers_aggregated = (
        _stack_numpy(all_layers.aggregated)
        .reshape(-1, batch_size, seq_len, seq_len)  # (layers, batch, seq, seq)
        .transpose(1, 0, 2, 3)                      # (batch, layers, seq, seq)
    )

    decompx_outputs["importance_last_layer_aggregated"] = [
        matrix[:n, :n] for matrix, n in zip(aggregated, lengths)
    ]
    decompx_outputs["importance_last_layer_pooler"] = [
        vector[:n] for vector, n in zip(pooler, lengths)
    ]
    decompx_outputs["importance_last_layer_classifier"] = [
        scores[:n, :] for scores, n in zip(classifier, lengths)
    ]
    decompx_outputs["importance_all_layers_aggregated"] = [
        stack[:, :n, :n] for stack, n in zip(all_layers_aggregated, lengths)
    ]

    if df:
        return pd.DataFrame(decompx_outputs)
    return decompx_outputs

In [None]:
def arrays_to_batch(batch_size, *arrays):
    """Yield aligned slices of ``arrays``, ``batch_size`` items at a time.

    Each yielded value is a list with one slice per input sequence. All slices
    cover the same index window ``[start, start + batch_size)``. The number of
    batches is driven by the length of the first sequence.
    """
    window_starts = range(0, len(arrays[0]), batch_size)
    for start in window_starts:
        stop = start + batch_size
        yield [seq[start:stop] for seq in arrays]

## Visualization

In [None]:
def print_importance(importance, tokenized_text, discrete=False, prefix="", no_cls_sep=False):
    """Render tokens as colour-highlighted HTML in the notebook.

    Scores are scaled by their maximum absolute value and divided by 1.5.
    Positive scores are then shaded with the "Greens" colormap and negative
    ones with "Reds".

    Args:
        importance: 1-D scores, one per token (array-like).
        tokenized_text: Tokens aligned with ``importance``.
        discrete: If True, colour by rank instead of magnitude. Ranks are
            non-negative, so everything is shaded green.
        prefix: Label rendered before the tokens.
        no_cls_sep: If True, drop the first and last token/score (presumably
            the special [CLS]/[SEP] or <s>/</s> tokens).

    Returns:
        The generated HTML string, which is also displayed via IPython.
    """
    importance = np.asarray(importance, dtype=float)  # also accepts plain lists
    if no_cls_sep:
        importance = importance[1:-1]
        tokenized_text = tokenized_text[1:-1]

    peak = np.abs(importance).max() if importance.size else 0.0
    if peak > 0:  # an all-zero row would otherwise become 0/0 = NaN colours
        importance = importance / peak / 1.5  # Normalize
    if discrete:
        importance = np.argsort(np.argsort(importance)) / len(importance) / 1.6

    # Look the colormaps up once, not once per token.
    positive_cmap = matplotlib.colormaps.get_cmap('Greens')
    negative_cmap = matplotlib.colormaps.get_cmap('Reds')

    label = ("<span style='"
             "color: rgba(255, 255, 255, 1.0); "
             "padding: 3px;"
             "'>"
             f"{prefix}"
             "</span> ")
    parts = ["<pre style='color:black; padding: 3px;'>", label]
    for score, token in zip(importance, tokenized_text):
        cmap = positive_cmap if score >= 0 else negative_cmap
        rgba = cmap(abs(score))
        # NOTE(review): after the /1.5 (or rank /1.6) scaling |score| stays
        # below ~0.67, so this white-text branch never fires; kept as-is.
        text_color = "color: rgba(255, 255, 255, 1.0); " if abs(score) > 0.9 else ""
        color = f"background-color: rgba({rgba[0]*255}, {rgba[1]*255}, {rgba[2]*255}, {rgba[3]}); " + text_color
        # Angle brackets become square brackets (e.g. "<s>" -> "[s]");
        # "&" is escaped so it cannot start an HTML entity.
        safe_token = token.replace("&", "&amp;").replace('<', "[").replace(">", "]")
        parts.append(
            "<span style='"
            f"{color}"
            "border-radius: 5px; padding: 3px;"
            "font-weight: 800;"
            "'>"
            f"{safe_token}</span> "
        )
    parts.append("</pre>")  # the original never closed the <pre> element
    markup = "".join(parts)
    display(HTML(markup))
    return markup

def print_preview(idx=0, discrete=False, df=None):
    """Display colour-coded token importances for one example.

    Renders the following via ``print_importance``:
      - row 0 (first token) of the last-layer aggregated matrix;
      - one line per class of the classifier importance;
      - the pooler importance.

    Args:
        idx: Row position (``iloc``) of the example to show.
        discrete: Rank-based colouring for the aggregated and pooler lines.
            The classifier lines always use magnitude colouring.
        df: DataFrame produced by ``compute_decompX(..., df=True)``. Defaults
            to the notebook global ``decompx_outputs_df``, as before.

    Returns:
        The DataFrame that was previewed.
    """
    NO_CLS_SEP = False
    if df is None:
        df = decompx_outputs_df
    tokens = df["tokens"].iloc[idx]

    for col in (
        "importance_last_layer_aggregated",
        "importance_last_layer_classifier",
        "importance_last_layer_pooler",
    ):
        if col not in df:
            continue
        cell = df[col].iloc[idx]  # positional, consistent with `tokens` above
        if cell is None:
            continue

        if "classifier" in col:
            for label in range(cell.shape[-1]):
                print_importance(
                    cell[:, label],
                    tokens,
                    prefix=f"{col.split('_')[-1]} Label{label}:".ljust(20),
                    no_cls_sep=NO_CLS_SEP,
                    discrete=False,
                )
            continue

        # Aggregated: first row of the (tokens x tokens) matrix; pooler: already 1-D.
        sentence_importance = cell[0, :] if "aggregated" in col else cell
        print_importance(
            sentence_importance,
            tokens,
            prefix=f"{col.split('_')[-1]}:".ljust(20),
            no_cls_sep=NO_CLS_SEP,
            discrete=discrete,
        )
    print("------------------------------------")
    return df

# for i in range(len(SENTENCES)):
#     print_preview(idx=i)

In [None]:
@reporter(prompt="Heatmap generator")
def generate_heatmap(df, data_column, annot=True, fmt='.2f', figsize=(10, 8), cmap="Reds", row=1, **kwargs):
    """
    Plot, as a heatmap, one example's per-layer importance of token 0.

    Parameters:
        df (pd.DataFrame): The DataFrame containing the data.
        data_column (str): Column whose cells are arrays indexed as
            ``[layer, i, j]`` (e.g. ``importance_all_layers_aggregated``).
            Row 0 of every layer (``[:, 0, :]``) is plotted, giving one
            heatmap row per layer.
        annot (bool, optional): If True, display the data values in each cell. Default is True.
        fmt (str, optional): Format string for annotating cells. Default is '.2f'.
        figsize (tuple, optional): Figure size (width, height) in inches. Default is (10, 8).
        cmap (optional): Colormap passed to sns.heatmap(). Default is "Reds".
        row (int, optional): Position (``iloc``) of the DataFrame row to plot. Default is 1.
        **kwargs: Additional keyword arguments to be passed to sns.heatmap().

    Returns:
        None (displays the heatmap plot)
    """
    # `cmap` used to refer to an undefined global (NameError). It is now a
    # parameter, and the previously hard-coded row 1 is `row`.
    data = df[data_column].iloc[row][:, 0, :].tolist()

    plt.figure(figsize=figsize)
    sns.heatmap(data, cmap=cmap, annot=annot, fmt=fmt, linewidths=0.5, linecolor='white', cbar=True, **kwargs)

    plt.title("Heatmap-like Plot")
    plt.xlabel("X-axis Labels")
    plt.ylabel("Y-axis Labels")
    plt.xticks(rotation=45)
    plt.yticks(rotation=0)

    plt.show()


In [None]:
@reporter(prompt="Heatmap plotter")
def plot_heatmap_classifier_and_logits(dataset, use_abs=False, save_path=None):
    """
    Plot one row of three subplots per example.

    The three subplots are:
      - a heatmap of per-layer importances plus the classifier importance;
      - a bar chart of the classifier importance for the predicted class;
      - a bar chart of the logits.

    Parameters:
        dataset (pandas.DataFrame): One subplot row per DataFrame row, with columns:
            - 'importance_all_layers_aggregated': array indexed ``[layer, i, j]``;
              row 0 of each layer is plotted.
            - 'importance_last_layer_classifier': array ``(tokens, classes)``;
              the column of the predicted class is used.
            - 'logits': per-class logits; ``argmax`` is the predicted label.
            - 'tokens': token strings used as x-axis labels.
        use_abs (bool): Plot absolute values of the row-normalized heatmap.
        save_path (str, optional): If provided, the figure is also saved to
            this path before being shown.

    Returns:
        None
    """
    num_rows = len(dataset)
    num_cols = 3  # heatmap, classifier importance, logits

    # squeeze=False keeps `axes` 2-D even when the dataset has a single row.
    _, axes = plt.subplots(num_rows, num_cols, figsize=(12, 4 * num_rows), squeeze=False)

    # enumerate() gives the subplot position. iterrows() labels are not
    # positions for a filtered or re-indexed DataFrame.
    for row_idx, (row_label, row_data) in enumerate(dataset.iterrows()):
        importance_all_layers = row_data["importance_all_layers_aggregated"]
        importance_last_layer_classifier = row_data["importance_last_layer_classifier"]
        logits = row_data["logits"]
        tokens = row_data["tokens"]
        positions = np.arange(len(tokens))
        heatmap_ax, classifier_ax, logits_ax = axes[row_idx]

        print("Tokens:", tokens, "\n--", len(tokens), '\n--', importance_all_layers.shape)

        label = np.argmax(logits)  # predicted class

        # Heatmap rows: token-0 importance for each layer, then the classifier
        # importance. Each row is scaled to max |value| = 1; all-zero rows stay
        # zero instead of becoming NaN.
        temp = np.concatenate(
            [importance_all_layers[:, 0, :], [importance_last_layer_classifier[:, label]]],
            axis=0,
        )
        max_r = np.max(np.abs(temp), axis=1, keepdims=True)
        temp = temp / np.where(max_r == 0, 1, max_r)
        if use_abs:
            temp = np.abs(temp)

        sns.heatmap(temp, ax=heatmap_ax, linewidths=0.5, cmap="Reds")
        heatmap_ax.set_title("Heatmap (importance all layers)")
        heatmap_ax.set_xticks(positions + 0.5)
        heatmap_ax.set_xticklabels(tokens, rotation=90, ha='center')
        heatmap_ax.invert_yaxis()

        # Classifier importance for the predicted class. Bars are placed at
        # integer positions rather than x=tokens: seaborn groups bars by x
        # value, so repeated tokens (e.g. the two [SEP]s of a sentence pair)
        # would otherwise be merged into one averaged bar.
        Y = importance_last_layer_classifier[:, label]
        sns.barplot(x=positions, y=Y, ax=classifier_ax)
        classifier_ax.set_title(f"Importance Last Layer Classifier (instance {row_label})")
        coef = max(abs(np.max(Y)), abs(np.min(Y))) / 2
        if coef == 0:  # all-zero scores: avoid dividing by zero below
            coef = 1
        # Colour intensity is proportional to |value|, saturating at half the peak.
        for bar, value in zip(classifier_ax.patches, Y):
            alpha = min(abs(value) / coef, 1)
            if value > 0:
                bar.set_color((0.0, 1.0, 0.0, alpha))  # green
            else:
                bar.set_color((1.0, 0.0, 0.0, alpha))  # red
        classifier_ax.set_xticks(positions)
        classifier_ax.set_xticklabels(tokens, rotation=90, ha='center')

        # Logits: red/green for binary tasks, seaborn defaults otherwise
        # (e.g. 3-class MNLI).
        colors = ['red', 'green'] if len(logits) == 2 else None
        sns.barplot(x=np.arange(len(logits)), y=logits, ax=logits_ax, palette=colors)
        logits_ax.set_title(f"Logits (label: {label})")
        logits_ax.set_xticks(np.arange(len(logits)))
        logits_ax.set_xticklabels(["Logit {}".format(i) for i in range(len(logits))])

    plt.tight_layout()
    if save_path is not None:
        plt.savefig(save_path)
    plt.show()

In [None]:
def calculate_correlation_and_label(importance_all_layers, importance_last_layer_classifier, logits):
    """Compute the Pearson correlation between consecutive rows of a layer-wise importance stack.

    The stack has one row per layer (row 0 of each ``importance_all_layers``
    matrix), followed by the classifier importance of the predicted class
    (``argmax(logits)``). Despite the name, only the correlations are returned.

    Returns:
        1-D array with one entry per layer. Entry ``i`` correlates stack rows
        ``i`` and ``i + 1``; the last entry compares the final layer with the
        classifier importance.
    """
    predicted = np.argmax(logits)
    per_layer = importance_all_layers[:, 0, :]
    classifier_row = [importance_last_layer_classifier[:, predicted]]
    stack = np.concatenate([per_layer, classifier_row], axis=0)

    # Walk consecutive row pairs (0, 1), (1, 2), ...
    pairs = zip(stack[:-1], stack[1:])
    return np.array([np.corrcoef(upper, lower)[0, 1] for upper, lower in pairs])


In [None]:
def correlate_decompx_output(one_ins):
    """Return the consecutive-layer correlations for one DecompX result.

    ``one_ins`` is a keyed record (e.g. a row of the ``compute_decompX``
    DataFrame). Its three relevant entries are forwarded positionally to
    ``calculate_correlation_and_label``.
    """
    layers = one_ins["importance_all_layers_aggregated"]
    classifier = one_ins["importance_last_layer_classifier"]
    logits = one_ins["logits"]
    return calculate_correlation_and_label(layers, classifier, logits)


In [None]:
@reporter(prompt="Correlation plotter")
def plot_correlation_boxchart(correlations, ax, title):
    """Draw a box plot of layer-pair correlations on ``ax``.

    Args:
        correlations: Iterable of per-example correlation sequences, presumably
            all of equal length. They are stacked into a 2-D array, and seaborn
            draws one box per column (layer pair).
        ax: Matplotlib axes to draw on.
        title: Axes title.
    """
    stacked = np.array([np.array(seq) for seq in correlations])
    sns.boxplot(data=stacked, ax=ax, color='black')
    ax.set_title(title)
    ax.set_ylabel("Correlation")
    ax.set_xlabel("Layer(i, i+1)")


# Sample usage

# Experiment 2

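Token importance: compute DecompX token attributions for MNLI premise/hypothesis pairs and save them to `RESULT_FILE`.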

## Initiating

In [None]:
import os

In [None]:
### Experiment configuration ###
# Candidate (model checkpoint, tokenizer) pairs; the trailing index picks one.
# Fine-tuned checkpoints reuse the base tokenizer.
MODEL, tokenizer_name = [
    ("bert-base-uncased", "bert-base-uncased"),
    ("drive/MyDrive/DecompX/bert/mnli/model/checkpoint-12272", "bert-base-uncased"),
    ("drive/MyDrive/DecompX/bert/mnli/model/checkpoint-24544", "bert-base-uncased"),
    ("drive/MyDrive/DecompX/bert/mnli/model/checkpoint-36816", "bert-base-uncased"),
    ("drive/MyDrive/DecompX/bert/mnli/model/checkpoint-49088", "bert-base-uncased"),
][0]
DATASET_NAME = "mnli"
DATASET_SPLIT = "validation_matched"
BATCH_SIZE = 3

BASE_PATH = "drive/MyDrive/DecompX/"

# One results directory per dataset under BASE_PATH.
RESULT_DIR_PATH = os.path.join(BASE_PATH, DATASET_NAME)

# One .npy file per (split, model); "/" in checkpoint paths is flattened to "_".
model_split_str = "{}_{}".format(DATASET_SPLIT, MODEL.replace("/", "_"))

RESULT_FILE = os.path.join(RESULT_DIR_PATH, model_split_str) + ".npy"

# When True: skip mounting Drive, creating RESULT_DIR_PATH, and loading/saving results.
DEMO = False


In [None]:
# Echo the active configuration as NAME=repr(value), one per line.
print("\n".join(
    f"{name}={value!r}"
    for name, value in (
        ("MODEL", MODEL),
        ("DATASET_NAME", DATASET_NAME),
        ("DATASET_SPLIT", DATASET_SPLIT),
        ("BATCH_SIZE", BATCH_SIZE),
        ("BASE_PATH", BASE_PATH),
        ("RESULT_DIR_PATH", RESULT_DIR_PATH),
        ("RESULT_FILE", RESULT_FILE),
        ("DEMO", DEMO),
    )
))

In [None]:
if not DEMO:
    # Results and checkpoints are read from / written to Google Drive.
    # BASE_PATH is relative ("drive/MyDrive/..."), presumably resolved from
    # Colab's /content working directory. Colab asks for authorization here.
    from google.colab import drive
    drive.mount('/content/drive/')

In [None]:
if not DEMO:
    # Create the per-dataset results directory on first use.
    if os.path.exists(RESULT_DIR_PATH):
        print(f"Directory '{RESULT_DIR_PATH}' already exists.")
    else:
        os.makedirs(RESULT_DIR_PATH)
        print(f"Directory '{RESULT_DIR_PATH}' created.")

In [None]:
# Same architecture dispatch as get_model_and_tokenizer. The tokenizer is
# configured separately because MODEL may be a local fine-tuned checkpoint.
# "roberta" is checked first because RoBERTa names also contain "bert".
if "roberta" in MODEL:
    model_cls = RobertaForSequenceClassification
elif "bert" in MODEL:
    model_cls = BertForSequenceClassification
else:
    # Previously this fell through silently and left `model` undefined,
    # which surfaced later as a NameError in the evaluation loop.
    raise NotImplementedError(f"Not implemented model: {MODEL}")

tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
model = model_cls.from_pretrained(MODEL)

In [None]:
from datasets import load_dataset

# GLUE task selected by DATASET_NAME, restricted to DATASET_SPLIT.
dataset = load_dataset(path="glue", name=DATASET_NAME, split=DATASET_SPLIT)
print(f"Number of samples: {len(dataset)}")

In [None]:
# Keep every third example. `select` returns a `datasets.Dataset`, so
# len(dataset) still counts examples. The previous `dataset[::3]` returned a
# dict of column lists: len() then counted columns, and re-running the cell
# failed. NOTE: re-running this cell subsamples again.
dataset = dataset.select(range(0, len(dataset), 3))

In [None]:
# ! pip install --upgrade datasets


# import datasets
# print(datasets.__version__)


In [None]:
# import gc

# gc.collect()

## Evaluation

In [None]:
# Per-example results. Every key maps to a list that the evaluation loop
# appends to in lockstep, so index i refers to the same example in all lists.
report = {
    key: []
    for key in (
        "idx",
        "tokens",
        "importance",
        "importances",
        "logits",
        "predicted_label",
        "true_label",
    )
}

In [None]:
# Resume from a report saved by a previous run, if there is one.
# allow_pickle is needed because the file holds a pickled dict (np.save of a
# dict object); only load files you created yourself.
if not DEMO and os.path.exists(RESULT_FILE):
    _report = np.load(RESULT_FILE, allow_pickle=True).item()
    # Reports without an "importances" key are ignored; the evaluation loop's
    # next save overwrites them.
    if "importances" in _report:
        report = _report


In [None]:
from tqdm.auto import tqdm

# Evaluate the dataset in batches, appending per-example results to `report`.
# A batch is skipped when its first idx is already present, which lets the
# loop resume from a report loaded from RESULT_FILE.
# NOTE: earlier versions of this loop passed the premise as the hypothesis.
# Reports saved by those runs hold (premise, premise) attributions; delete
# RESULT_FILE to recompute them.
done_idx = set(report["idx"])  # O(1) membership instead of scanning a list per batch

all_idx = list(dataset['idx'])
dataset_iter = arrays_to_batch(
    BATCH_SIZE,
    all_idx,
    list(dataset['premise']),
    list(dataset['hypothesis']),
    list(dataset['label']),
)
num_batches = (len(all_idx) + BATCH_SIZE - 1) // BATCH_SIZE  # lets tqdm show progress/ETA

for idx_s, premise, hypothesis, labels in tqdm(dataset_iter, total=num_batches):
    if idx_s[0] in done_idx:
        continue

    decompx_outputs = compute_decompX(
        model=model,
        tokenizer=tokenizer,
        premise=premise,
        hypothesis=hypothesis,  # was `premise`, which fed (premise, premise) pairs
        df=False,
    )

    # "importance" holds the last-layer classifier importance and
    # "importances" the last-layer aggregated matrix (key names kept for
    # compatibility with saved reports).
    _iter = zip(
        idx_s,
        decompx_outputs['tokens'],
        decompx_outputs['importance_last_layer_classifier'],
        decompx_outputs['importance_last_layer_aggregated'],
        decompx_outputs['logits'],
        labels,
    )
    for idx, tokens, importance, importances, logits, label in _iter:
        report["idx"].append(idx)
        report["tokens"].append(tokens)
        report["importance"].append(importance)
        report["importances"].append(importances)
        report["logits"].append(logits)
        report["predicted_label"].append(np.argmax(logits))
        report["true_label"].append(label)
        done_idx.add(idx)

    if not DEMO:
        # Write a temp file, then atomically swap it in, so an interruption
        # mid-save cannot corrupt everything accumulated so far. The temp name
        # ends in ".npy" so np.save does not append another extension.
        tmp_file = RESULT_FILE + ".tmp.npy"
        np.save(tmp_file, report)
        os.replace(tmp_file, RESULT_FILE)

print("Done")
