# **Setup**

In [None]:
#@title #Setup: Imports and dependencies
#@markdown Install and import Python dependencies.
# Runtime selector; the "load Drive" cell mounts Google Drive only when this is 'colab'.
location = 'vsc' #@param['colab', 'vsc']

# Selects which attribution backend(s) get installed and imported further down.
package = 'inseq' #@param['inseq', 'kayo', 'both']

# IPython shell escape: install jsonlines into the active environment before importing it.
!pip install jsonlines
import jsonlines
import json

from transformers import GPT2Tokenizer, GPT2LMHeadModel
import copy
import sys
import torch
from IPython.display import clear_output
import numpy as np
import gc
import argparse
from collections import defaultdict
import pandas as pd
# NOTE(review): this binds the top-level `matplotlib` package to the name `plt`,
# not `matplotlib.pyplot` -- confirm which one later cells expect.
import matplotlib as plt

import random
# Seed Python's `random` module so random draws made later are repeatable.
random.seed(42)

# GPT-2 tokenizer and LM-head model, both loaded under the name "gpt2".
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

# Optional backend 1: install and import inseq.
if package == 'inseq' or package == 'both':
  print('Installing dependencies inseq')
  !pip install inseq
  import inseq


# Optional backend 2: install torch wheels from the cu113 index, clone the
# Baseline_clusters repository and put it on sys.path.
if package == 'kayo' or package == 'both':
  print('Installing dependencies kayo')
  !pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113
  !git clone https://github.com/JanAthmer/Baseline_clusters.git &> /dev/null
#   !git clone https://github.com/kayoyin/interpret-lm.git &> /dev/null
  sys.path.append('./Baseline_clusters')
#   sys.path.append('./interpret-lm')

  # Star-import; presumably lm_saliency lives in the cloned repository -- verify.
  from lm_saliency import *

# Wipe the cell output (installer logs) so only the final message stays visible.
clear_output()

print('Done!')

In [None]:
#@title #Setup: load Drive
# Mount Google Drive only on Colab; `location` is assigned in the setup cell.
if location == 'colab':
  # Imported lazily because this module is only needed in the 'colab' branch.
  from google.colab import drive
  # Mount point for the Drive contents.
  drive.mount('/content/drive')

In [None]:
#@title #Setup: Utils

def average(lst):
    """Return the arithmetic mean of the values in ``lst``.

    Raises:
        ZeroDivisionError: if ``lst`` is empty.
    """
    total = sum(lst)
    count = len(lst)
    return total / count


def reciprocal_rank(predictions, targets):
    """Return 1/rank of the best-scored position whose target flag is truthy.

    Scores and flags are paired position by position (pairing stops at the
    shorter input), ordered by score from highest to lowest with ties kept in
    their original order, and the 1-based position of the first truthy flag
    is inverted.

    Args:
        predictions: iterable of comparable scores.
        targets: iterable of truthy/falsy flags aligned with ``predictions``.

    Returns:
        ``1 / rank`` for the first flagged entry, or ``0`` when no entry is
        flagged.
    """
    def score_of(pair):
        return pair[0]

    ranked = sorted(zip(predictions, targets), key=score_of, reverse=True)

    for position, (_, is_target) in enumerate(ranked, start=1):
        if is_target:
            return 1 / position

    # Nothing flagged (or empty input).
    return 0

def get_target(sentence):
    """Build a per-token boolean mask marking the first word containing "n't".

    The sentence is split on single spaces and every word is tokenized on its
    own with the module-level ``tokenizer``; each word except the first gets
    its leading space put back before tokenizing. The mask length is the sum
    of the per-word token counts.

    Only the FIRST word containing "n't" is treated as the target. Previously
    a later "n't" word was counted as trailing context *and* overwrote the
    target width, so the mask could have the wrong length and the wrong
    number of ``True`` entries.

    Args:
        sentence (str): space-separated input sentence.

    Returns:
        list of bool: ``False`` for every token before the target word,
        ``True`` for the target word's tokens, ``False`` for every token
        after it. All ``False`` when no word contains "n't".
    """
    prefix_len = 0
    target_len = 0
    suffix_len = 0
    found = False

    for position, word in enumerate(sentence.split(' ')):
        if position != 0:
            # Restore the space removed by split() before tokenizing the word.
            word = " " + word

        n_tokens = len(tokenizer(word)['input_ids'])

        if found:
            suffix_len += n_tokens
        elif "n't" in word:
            target_len = n_tokens
            found = True
        else:
            prefix_len += n_tokens

    return [False] * prefix_len + [True] * target_len + [False] * suffix_len

import random

def generate_one_hot_list(length):
    """Return a list of ``length`` zeros with a single 1 at a random position.

    The position is drawn with ``random.randint`` over the valid indices.

    Raises:
        ValueError: if ``length`` is not greater than 0.
    """
    if length <= 0:
        raise ValueError("Length must be greater than 0")

    hot = random.randint(0, length - 1)
    return [1 if position == hot else 0 for position in range(length)]

def get_last_true_index(one_hot_vector):
    """
    Locate the final truthy entry of a vector.

    Parameters:
    one_hot_vector (list of bool): Vector that may hold zero, one or several True values.

    Returns:
    int: Index of the last truthy entry, or -1 when there is none.
    """
    # Indices come out of enumerate() in increasing order, so the largest
    # truthy index is also the last one seen.
    truthy_positions = (pos for pos, flag in enumerate(one_hot_vector) if flag)
    return max(truthy_positions, default=-1)

def normalize_vector(vector):
    """
    Scale a vector into [-1, 1] by dividing by its largest absolute value.

    Args:
    vector (numpy array): The vector to scale.

    Returns:
    numpy array: ``vector`` divided by the larger of |max| and |min|.
    """
    highest = np.max(vector)
    lowest = np.min(vector)
    scale = max(abs(highest), abs(lowest))
    return vector / scale


def normalize_and_truncate(vector_a, vector_b):
    """
    Scale two vectors into [-1, 1] and cut both to the length of the shorter one.

    Each vector is divided by its own largest absolute value; the longer
    result is then truncated from the end.

    Args:
    vector_a (numpy array): First vector.
    vector_b (numpy array): Second vector.

    Returns:
    numpy array, numpy array: The scaled, equal-length versions of A and B.
    """
    scaled = []
    for vec in (vector_a, vector_b):
        peak = max(abs(np.max(vec)), abs(np.min(vec)))
        scaled.append(vec / peak)
    scaled_a, scaled_b = scaled

    # Slicing to the common length leaves the shorter vector's values untouched.
    common = min(len(scaled_a), len(scaled_b))
    return scaled_a[:common], scaled_b[:common]

In [None]:
#@title #Setup: Load Data
# Two JSON-lines files, each turned into a DataFrame straight from the reader.
with jsonlines.open("npi_present_1.jsonl", 'r') as f:
  npi = pd.DataFrame(f)

with jsonlines.open("determiner_noun_agreement_1.jsonl", 'r') as f:
  dna = pd.DataFrame(f)

# Tab-separated file with a sentence and a value on every line.
sentences = []
values = []
with open("prefix+value.tsv", 'r', encoding='utf-8') as ifh:
  for line in ifh:
      # The two-name unpacking raises ValueError unless the stripped line
      # contains exactly one tab.
      sentence, value = line.strip().split('\t')
      sentences.append(sentence)
      values.append(value)
  # NOTE(review): the name `any` shadows the builtin any() from here on, and
  # `values` is collected but not placed in the frame -- confirm both are intended.
  any = pd.DataFrame({'one_prefix_prefix': sentences})

# NOTE(review): the main loop indexes a random sample of `vocabulary` with
# [0][0] -- presumably a list of nested entries; verify against the JSON file.
with open("vocab_tagged.json", 'r') as file:
    vocabulary = json.load(file)


# **Main function and loop**

In [None]:
#@title #Main Function
def _attribution_rank_pair(model_inseq, sentence, target, targets, **attribute_kwargs):
    """Run one attribution and score it against ``targets`` in both directions.

    Returns:
        tuple: (reciprocal rank of the attribution scores,
                reciprocal rank of the negated attribution scores).
    """
    out = model_inseq.attribute(
        sentence,
        sentence + " " + target,
        **attribute_kwargs,
    )
    scores = out[0].target_attributions

    # `explanation` is the module-level name bound by the main loop's
    # `for explanation in explanations`. Every method except occlusion gets
    # axis 2 summed away (presumably the hidden dimension -- TODO confirm).
    if explanation != 'occlusion':
        scores = scores.sum(axis=2)

    # Drop every row that contains a NaN along dim 1, then flatten to 1-D.
    scores = torch.flatten(scores[~torch.any(scores.isnan(), dim=1)]).numpy()

    return reciprocal_rank(scores, targets), reciprocal_rank(scores * -1, targets)


def rr_attribution(model_inseq, attributed_fn, sentence, target, targets, contrastive = True, foil = None, baseline = True, rnd = None):
    """Compute reciprocal ranks of ``targets`` under up to three attribution setups.

    Each enabled setup attributes ``sentence + " " + target`` given ``sentence``
    and yields a pair (rank, rank-with-negated-scores):

    * baseline    -- plain attribution with ``attributed_fn``.
    * contrastive -- contrastive attribution against ``sentence + " " + foil``.
    * rnd         -- contrastive attribution against ``sentence + " " + rnd``.

    Args:
        model_inseq: object exposing ``attribute(...)`` (the main loop passes
            the result of ``inseq.load_model``).
        attributed_fn (str): 'logit' or 'probability'; selects the contrastive
            step function ("contrast_logits_diff" / "contrast_prob_diff").
        sentence (str): input prefix.
        target (str): continuation appended to ``sentence``.
        targets: per-position truthy flags passed to ``reciprocal_rank``.
        contrastive (bool): enable the foil-based setup; requires ``foil``.
        foil (str | None): contrast continuation for the contrastive setup.
        baseline (bool): enable the plain setup.
        rnd (str | None): contrast continuation for the random-foil setup;
            ``None`` disables it.

    Returns:
        tuple | None: the enabled pairs concatenated in the order baseline,
        contrastive, rnd -- with the original quirk preserved that when
        ``rnd`` is combined with exactly one of ``baseline``/``contrastive``
        the rnd pair is computed but not returned. ``None`` when nothing is
        enabled.

    Raises:
        ValueError: if a contrastive setup is requested with an
            ``attributed_fn`` other than 'logit'/'probability', or if
            ``contrastive`` is true while ``foil`` is ``None``.
    """
    use_rnd = rnd is not None

    contrast_fn_by_name = {
        'logit': "contrast_logits_diff",
        'probability': "contrast_prob_diff",
    }
    contrast_attributed_fn = contrast_fn_by_name.get(attributed_fn)

    # Fail before any (expensive) attribution runs instead of hitting an
    # UnboundLocalError / str+None TypeError halfway through.
    if (contrastive or use_rnd) and contrast_attributed_fn is None:
        raise ValueError(
            f"attributed_fn must be 'logit' or 'probability' for contrastive "
            f"attribution, got {attributed_fn!r}"
        )
    if contrastive and foil is None:
        raise ValueError("contrastive=True requires a foil")

    base = con = con_rnd = None

    if baseline:
        base = _attribution_rank_pair(
            model_inseq, sentence, target, targets,
            attributed_fn=attributed_fn,
        )

    if contrastive:
        con = _attribution_rank_pair(
            model_inseq, sentence, target, targets,
            attributed_fn=contrast_attributed_fn,
            contrast_targets=sentence + " " + foil,
            step_scores=[contrast_attributed_fn],
        )

    if use_rnd:
        con_rnd = _attribution_rank_pair(
            model_inseq, sentence, target, targets,
            attributed_fn=contrast_attributed_fn,
            contrast_targets=sentence + " " + rnd,
            step_scores=[contrast_attributed_fn],
        )

    # Remove the progress output produced by the attribution calls.
    clear_output()

    if baseline and contrastive and use_rnd:
        return base + con + con_rnd
    if baseline and contrastive:
        return base + con
    if baseline:
        return base
    if contrastive:
        return con
    if use_rnd:
        return con_rnd
    return None

In [None]:
#@title #Main Loop
# One mean reciprocal rank per (explanation, dataset, attributed_fn, method,
# direction) combination, appended in exactly the order of the MultiIndex
# product built at the bottom of this cell.
results = []

explanations    = ["input_x_gradient",  "lime", "occlusion"]
datasets        = ["npi", "dna", "any"]
attributed_fns  = ["logit","probability"]
methods         = ["base", "contrastive", "contrastive_rnd"]
# NOTE: this name shadows the builtin reversed(); kept because it is
# notebook-level state that other cells may read.
reversed        = ["Original", "Reversed"]

# `datasets` holds names, not frames; resolve each name to the DataFrame built
# in the "Load Data" cell (there `any` is a DataFrame, not the builtin).
frames = {"npi": npi, "dna": dna, "any": any}

for explanation in explanations:
    # rr_attribution reads the module-level `explanation`, so the loop
    # variable must keep this exact name.
    model_inseq = inseq.load_model("gpt2", explanation)

    for data in datasets:
        frame = frames[data]

        for attributed_fn in attributed_fns:
            # Six buckets, ordered like product(methods, reversed):
            # base/Original, base/Reversed, contrastive/Original, ...
            per_row_scores = [[] for _ in range(len(methods) * len(reversed))]

            for _index, row in frame.iterrows():
                sentence = row["one_prefix_prefix"]

                if data == 'any':
                    target = "any"
                    foil = "some"
                    targets = get_target(sentence)
                else:
                    target = row["one_prefix_word_good"]
                    foil = row["one_prefix_word_bad"]
                    sent_len = len(sentence.split(' '))
                    # NOTE(review): these masks are one flag per whitespace
                    # word, and reciprocal_rank zips them with the attribution
                    # scores, silently stopping at the shorter of the two --
                    # confirm the lengths line up.
                    if data == 'npi':
                        targets = (sent_len-1)*[False] + [True]
                    else:  # 'dna'
                        targets = [True]+(sent_len-1)*[False]

                foil_rnd = random.sample(vocabulary,1)[0][0][0]

                # `foil` must be passed explicitly: the contrastive setup is on
                # by default and cannot build its contrast target without it.
                row_scores = rr_attribution(
                    model_inseq, attributed_fn, sentence, target, targets,
                    foil = foil, rnd = foil_rnd)

                for bucket, score in zip(per_row_scores, row_scores):
                    bucket.append(score)

            # Collapse the rows into one mean per bucket so that `results`
            # ends up with exactly one value per MultiIndex entry.
            results.extend(
                average(bucket) if bucket else float("nan")
                for bucket in per_row_scores
            )

# Create a MultiIndex from the dimensions
multi_index = pd.MultiIndex.from_product([explanations, datasets, attributed_fns,methods,reversed], names=['explanations', 'datasets', 'attributed_fns', 'methods', 'reversed'])

# Create the DataFrame using the MultiIndex and data
df = pd.DataFrame(results, index=multi_index, columns=['Value'])
df.to_csv('full_results.csv')