# Interpreting Language Models with Contrastive Explanations
### ___Kayo Yin and Graham Neubig___
This Colab notebook lets you visualize contrastive explanations for language model decisions, based on [Yin and Neubig (2022)](https://arxiv.org/abs/2202.10419).





In [None]:
#@markdown #**Setup Environment**
#@markdown Install and import Python dependencies.
from IPython.display import clear_output
print('Installing dependencies...')
!pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113
!pip install transformers
!pip install scikit-learn

print('Importing libraries...')
import sys
sys.path.append('./interpret-lm')
from lm_saliency import *
from sklearn.cluster import KMeans
import numpy as np
import random
import hashlib
import re
import sys
import tarfile
from collections import Counter, defaultdict
from pathlib import Path
from collections import Counter
import matplotlib.pyplot as plt
import requests
import nltk
from nltk.tokenize import word_tokenize
nltk.download('all')
clear_output()
print('Done!')

if not torch.cuda.is_available():
  print("Please restart runtime with GPU")

Done!


In [None]:
 #@markdown #**Open pre-made vocab**
  #@markdown Load the pre-made vocab built from wikitext-103. The vocab was tokenized and tagged with nltk ahead of time, because doing it in this notebook ran into memory issues.
# Read the pre-tagged vocabulary from disk.
# NOTE(review): absolute, user-specific path — presumably should become configurable.
vocab_path = Path('/home/scur0635/Baseline_clusters/vocab_tagged.json')
with vocab_path.open('r') as vocab_file:
  vocab_tagged = json.load(vocab_file)

In [None]:
 #@markdown #**Open pre-made sentence list**
 #@markdown Load the pre-made sentence list built from wikitext-103. The sentences were split with a regex, which is slow, so this was done beforehand.
# Read the pre-split sentence list from disk.
# NOTE(review): absolute, user-specific path — presumably should become configurable.
sentences_path = Path('/home/scur0635/Baseline_clusters/sentences_standard.json')
with sentences_path.open('r') as sentences_file:
  sentences_list = json.load(sentences_file)

In [None]:
 #@markdown #**Keep only one word class (for instance verbs) from the vocab**
# Each vocab entry is indexed as item[0][0] (word), item[0][1] (tag) and
# item[1] (the value used as frequency for sorting).
VERB_TAGS = ('VBD', 'VB')
MAX_FOILS = 1000

verbs = [item for item in vocab_tagged if item[0][1] in VERB_TAGS]

# Sort verbs based on frequency in descending order
sorted_verbs = sorted(verbs, key=lambda x: x[1], reverse=True)

# A word can occur under both tags; dict.fromkeys keeps only its first (most
# frequent) occurrence so the same foil is not explained twice.
unique_verbs = list(dict.fromkeys(entry[0][0] for entry in sorted_verbs))

# Apply the cut-off to the final word list. Previously the sliced list was
# overwritten on the next line, so every verb ended up in top_words.
top_words = unique_verbs[:MAX_FOILS]

In [None]:
# Target word y_t whose contrastive explanations are computed (Colab form field).
y_t = "go" #@param ["go", "he" , "man", "black"]
# Later cells iterate over a list of targets; here it holds just the one chosen above.
targets = [y_t]

In [None]:
 #@markdown #**Sample set X**
sample_size = 100 #@param [100, 200, 300, 400, 500]
# NOTE(review): the sampling below is unseeded, so the sample set differs between runs.

# The target has to occur as a whole word: preceded by a space and followed by
# a space or one of , . ? !  (re.escape keeps any regex metacharacters literal).
target_pattern = re.compile(' ' + re.escape(y_t) + r'(?=[ ,.?!])')

# One entry per matching sentence, so a sentence containing the target in
# several contexts cannot be drawn more than once.
target_list = [x for x in sentences_list if target_pattern.search(x)]

# random.sample raises ValueError when asked for more items than are available.
samples = random.sample(target_list, min(sample_size, len(target_list)))

# Keep the text up to (and including the space before) the first whole-word
# occurrence of the target, so the cut never lands inside a longer word that
# merely contains the target. Prefixes with fewer than two words are dropped.
prefixes = (x[:target_pattern.search(x).start() + 1] for x in samples)
samples = [prefix for prefix in prefixes if len(prefix.split()) >= 2]

In [None]:
#@markdown #**Load Language Model**
model_name = "gpt-2" #@param ["gpt-2", "gpt-neo"]

if model_name == "gpt-2":
  tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
  model = GPT2LMHeadModel.from_pretrained("gpt2")
elif model_name == "gpt-neo":
  model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B")
  tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-1.3B")
else:
  # Fail here instead of silently reusing a model left over from an earlier run.
  raise ValueError("Unknown model_name %r; expected 'gpt-2' or 'gpt-neo'" % (model_name,))

# Bind the vocabulary for whichever model was chosen, so switching the form to
# gpt-neo does not leave `vocab` undefined or pointing at the gpt-2 vocabulary.
vocab = list(tokenizer.encoder.keys())

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
assert device.type == "cuda", "No CUDA device found; restart the runtime with a GPU"
model.to(device)
clear_output()

In [None]:
#@markdown #**Retrieve contrastive explanation function**
def expl_vec(target, foil, input, explanation):
  """Compute the explanation scores for `target` versus `foil` on one prompt.

  Uses the notebook-level `tokenizer` and `model`.

  Args:
    target: target word; only the first token id of " " + target is used.
    foil: foil word; only the first token id of " " + foil is used.
    input: prompt text; it is stripped and given one trailing space before
      being tokenized.
    explanation: one of "erasure", "input x gradient" or "gradient norm".

  Returns:
    The value returned by the selected scoring helper (erasure_scores,
    input_x_gradient or l1_grad_norm) — presumably one score per input token;
    TODO confirm against lm_saliency.

  Raises:
    ValueError: if `explanation` is not one of the supported names.
  """
  supported = ("erasure", "input x gradient", "gradient norm")
  # Validate up front: an unknown name used to surface only after the expensive
  # saliency call, as an UnboundLocalError on the return statement.
  if explanation not in supported:
    raise ValueError("Unknown explanation %r; expected one of %s" % (explanation, ", ".join(supported)))

  input = input.strip() + " "
  # Tokenize once and read both fields from the same encoding.
  encoded = tokenizer(input)
  input_tokens = encoded['input_ids']
  attention_ids = encoded['attention_mask']
  CORRECT_ID = tokenizer(" "+ target)['input_ids'][0]
  FOIL_ID = tokenizer(" "+ foil)['input_ids'][0]

  if explanation == "erasure":
    return erasure_scores(model, input_tokens, attention_ids, correct=CORRECT_ID, foil=FOIL_ID, normalize=True)

  # NOTE(review): CORRECT_ID is not passed to saliency() here, unlike the
  # erasure branch — confirm that saliency()'s default target is what is wanted.
  saliency_matrix, embd_matrix = saliency(model, input_tokens, attention_ids, foil=FOIL_ID)
  if explanation == "input x gradient":
    # NOTE(review): this is the only method called with normalize=False — confirm intended.
    return input_x_gradient(saliency_matrix, embd_matrix, normalize=False)
  return l1_grad_norm(saliency_matrix, normalize=True)

In [None]:
#@markdown #**Loop over targets, foils and samples**
# One row per (target, foil) pair: the explanation vectors of all sampled
# prompts, concatenated in sample order.
results = []
for target in targets:
  for foil in top_words:
    # The prompt variable is local to the comprehension, so the builtin
    # `input` is no longer rebound at notebook scope by this cell.
    vec_list = [expl_vec(target, foil, prompt, "input x gradient") for prompt in samples]
    results.append(np.concatenate(vec_list))



KeyboardInterrupt: ignored

In [None]:
# Persist the explanation vectors so the slow loop does not have to be re-run.
with open('verb_results.npy', 'wb') as results_file:
    np.save(results_file, results)

In [None]:
# google.colab can only be imported inside a Colab runtime; anywhere else the
# browser download is skipped instead of failing the cell with ImportError.
try:
    from google.colab import files
except ImportError:
    print("google.colab is not available; skipping browser download of verb_results.npy")
else:
    files.download('verb_results.npy')

In [None]:
# with open('test.npy', 'rb') as f:
#     results = np.load(f)

In [None]:
# #@markdown #**Fit KMeans cluster**
# #@markdown Fit a KMeans clustering on the explanation vectors. I think there should be 3 clusters, since (as far as I understand) those are the categories used in the paper. TODO: reconcile this with n_clusters = 8 below.
# kmeans = KMeans(n_clusters = 8, random_state = 0, n_init='auto')
# kmeans.fit(results)

In [None]:
from sklearn.cluster import AgglomerativeClustering
from scipy.cluster.hierarchy import dendrogram, linkage
# model = AgglomerativeClustering(linkage="ward", n_clusters=2)
# model.fit(results)

In [None]:
def fancy_dendrogram(*args, **kwargs):
    """Draw a dendrogram and annotate its merge heights.

    All positional and keyword arguments are forwarded to `dendrogram`,
    except for two extra keywords handled here:

    max_d: when truthy, a horizontal black line is drawn at this height and,
        unless `color_threshold` was given, it is also used as `color_threshold`.
    annotate_above: only links whose height is greater than this value get a
        marker and a numeric label (default 0).

    Returns the dictionary produced by `dendrogram`. With `no_plot=True`
    nothing is drawn or annotated.
    """
    cutoff = kwargs.pop('max_d', None)
    if cutoff and 'color_threshold' not in kwargs:
        kwargs['color_threshold'] = cutoff
    min_height = kwargs.pop('annotate_above', 0)

    tree = dendrogram(*args, **kwargs)

    # Nothing to decorate when plotting was disabled.
    if kwargs.get('no_plot', False):
        return tree

    plt.title('Hierarchical Clustering Dendrogram (truncated)')
    plt.xlabel('sample index or (cluster size)')
    plt.ylabel('distance')
    for xs, ys, colour in zip(tree['icoord'], tree['dcoord'], tree['color_list']):
        height = ys[1]
        if not height > min_height:
            continue
        # Mark the link halfway between its second and third x coordinates.
        mid_x = 0.5 * sum(xs[1:3])
        plt.plot(mid_x, height, 'o', c=colour)
        plt.annotate("%.3g" % height, (mid_x, height), xytext=(0, -5),
                     textcoords='offset points',
                     va='top', ha='center')
    if cutoff:
        plt.axhline(y=cutoff, c='k')
    return tree

In [None]:
# Full Ward-linkage dendrogram over all explanation rows, one leaf per row of
# `results`, drawn sideways on a very large canvas.
X = np.array(results)
linkage_matrix = linkage(X, method='ward')

plt.figure(figsize=(100, 200))
dendrogram(linkage_matrix, orientation="right")
plt.show()

In [None]:
# Truncated view of the same Ward clustering: only the last 12 merged clusters
# are shown, with merge heights above 10 annotated.
X = np.array(results)
linkage_matrix = linkage(X, method='ward')

truncated_view = dict(
    truncate_mode='lastp',
    p=12,
    leaf_rotation=90.,
    leaf_font_size=12.,
    show_contracted=True,
    annotate_above=10,  # useful in small plots so annotations don't overlap
)
fancy_dendrogram(linkage_matrix, **truncated_view)
plt.show()

In [None]:
#@markdown #**Function for finding indexes of a certain cluster**
def ClusterIndicesNumpy(clustNum, labels_array): #numpy
    """Return the first-axis indices at which `labels_array` equals `clustNum`."""
    membership = labels_array == clustNum
    # np.where with a single argument returns one index array per dimension;
    # only the first one is kept.
    first_axis_idx, *_ = np.where(membership)
    return first_axis_idx

In [None]:
#@markdown #**Function for retrieving words from indexes**
def retrieve_words(cluster, words, kmeans):
  """Return the words whose position is labelled `cluster`.

  Args:
    cluster: cluster label to look up.
    words: list of words; words[i] corresponds to kmeans.labels_[i].
    kmeans: any fitted object exposing a `labels_` sequence.

  Returns:
    The words at the positions labelled `cluster`, in their original order.
  """
  # Look words up by position. The previous words.index(word) lookup was
  # quadratic and assigned a repeated word to the cluster of its first
  # occurrence only.
  member_idx = np.where(np.asarray(kmeans.labels_) == cluster)[0]
  n_words = len(words)
  # Labels without a matching word (more labels than words) are ignored.
  return [words[i] for i in member_idx if i < n_words]

In [None]:
# `kmeans` is not fitted by any active cell (the cell that did so is commented
# out), so fit it here with that cell's settings before looking up a cluster.
kmeans = KMeans(n_clusters=8, random_state=0, n_init='auto').fit(np.array(results))

Foil_cluster = retrieve_words(7, top_words, kmeans)
print(Foil_cluster)

NameError: ignored

In [None]:
# `model` is the language model at this point and has no `labels_`, so fit a
# dedicated Ward clustering under its own name rather than overwriting it.
N_CLUSTERS = 8
foil_clustering = AgglomerativeClustering(linkage="ward", n_clusters=N_CLUSTERS).fit(np.array(results))

for i in range(N_CLUSTERS):
  cluster = retrieve_words(i, top_words, foil_clustering)
  print('Cluster',i)
  print('------------------------------')
  print(cluster[:10])
  print('------------------------------')