<a href="https://colab.research.google.com/github/SIDN-IAP/nl-explanations/blob/master/nl_explanations_exercise.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Set up the environment (this will take a minute):

In [0]:
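%%bash
# Only run the setup on Colab: exit early if the Colab package directory is absent
!(stat -t /usr/local/lib/*/dist-packages/google/colab > /dev/null 2>&1) && exit
pip install ninja
# Download and unpack the lab's models, vocabulary, and precomputed image features
wget http://lingo.csail.mit.edu/demos/sidn-iap-2020/nl-explanations.tar.gz
tar xzf nl-explanations.tar.gz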

In [0]:
import sys
sys.path.append("/content/nl-explanations")
import pickle
import random
import collections

import torch
import torch.nn.functional as F
import numpy as np

import lab
from lab import ImageCaptioner, ImageClassifier, TextClassifier, ImageRanker
import torchdec

from IPython.display import display, HTML

DEVICE = torch.device("cpu")
lab.DEVICE = DEVICE

Load models and data:

In [0]:
with open("/content/nl-explanations/vocab.json") as reader:
  vocab = torchdec.Vocab()
  vocab.load(reader)

with open("/content/nl-explanations/CUB_test_features.p", "rb") as reader:
  test_features = pickle.load(reader)

# Keep only the first 1000 test images so the lab runs quickly
test_features = {k: test_features[k] for k in list(test_features)[:1000]}
test_names, test_images = zip(*test_features.items())
test_images = torch.tensor(test_images)

test_classes = collections.defaultdict(list)
for name in test_names:
  cls = name.split('.')[0]
  test_classes[int(cls)].append(name)

classifier = ImageClassifier()
classifier.load_state_dict(torch.load("/content/nl-explanations/image_classifier-1.0.m", map_location=DEVICE))
classifier.eval()

captioner = ImageCaptioner(vocab)
captioner.load_state_dict(torch.load("/content/nl-explanations/image_captioner-1.0.m", map_location=DEVICE))
captioner.eval()

text_classifier = TextClassifier(vocab)
text_classifier.load_state_dict(torch.load("/content/nl-explanations/text_classifier.m", map_location=DEVICE))
text_classifier.eval()

image_ranker = ImageRanker(vocab)
image_ranker.load_state_dict(torch.load("/content/nl-explanations/image_ranker.m", map_location=DEVICE))
image_ranker.eval()

A helper function for visualizing model predictions:

In [0]:
def url(key):
  return f'http://lingo.csail.mit.edu/demos/sidn-iap-2020/CUB_200_2011/images/{key}'

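def show_classifier_predictions(classifier, explainers=None, n=10, with_class=None):
  """Display the first n test images with true/predicted labels and explanations.

  classifier: a model, or a dict mapping display names to models.
  explainers: dict mapping display names to fn(i, features, pred_label) -> str.
  with_class: optional set of class ids; images from other classes are skipped.
  """
  explainers = explainers or {}
  if not isinstance(classifier, dict):
    classifier = {"model": classifier}
  keys = list(enumerate(test_names[:n]))
  for i, key in keys:
    # Keys look like "NNN.Species_Name/<file>.jpg"
    true_label, name = key.split("/")[0].split(".")
    true_label = int(true_label)
    if with_class is not None and true_label not in with_class:
      continue
    name = name.replace("_", " ").lower()
    features = torch.tensor([test_features[key]], device=DEVICE)

    display(HTML(f"<img src='{url(key)}' width='200'/>"))
    display(HTML(f"<p><b>True label</b> {true_label} ({name})</p>"))
    for classifier_name in classifier:
      scores = classifier[classifier_name](features)
      pred_label = torch.argmax(scores, dim=1).item()
      display(HTML(f"<p><b>Predicted label ({classifier_name})</b> {pred_label}</p>"))

    # Note: every explainer receives the prediction of the *last* classifier in the dict
    for explainer_name, explainer in explainers.items():
      explanation = explainer(i, features, pred_label)
      display(HTML(f"<p><b>({explainer_name}) because...</b> {explanation}</p>"))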
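The true label displayed by this helper is parsed out of each image key. As a standalone sketch, assuming keys follow the CUB-200-2011 folder layout `NNN.Species_Name/<file>.jpg` (inferred from the `split` calls above), the parsing looks like this; `parse_cub_key` is a hypothetical helper, not part of `lab`:

```python
def parse_cub_key(key: str) -> tuple[int, str]:
    """Return (class_id, readable species name) for a CUB-style image key.

    Assumes keys look like "NNN.Species_Name/<file>.jpg" (hypothetical helper).
    """
    folder = key.split("/")[0]                 # e.g. "001.Black_footed_Albatross"
    class_id, raw_name = folder.split(".", 1)  # split only on the first dot
    return int(class_id), raw_name.replace("_", " ").lower()


# Hypothetical key in the CUB folder layout
print(parse_cub_key("001.Black_footed_Albatross/Black_Footed_Albatross_0001.jpg"))
# → (1, 'black footed albatross')
```

Splitting the folder name with `split(".", 1)` rather than `split(".")` is a small robustness tweak: it would still work if a species name ever contained a dot.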


## Fine-grained classification

As a running example in this lab, we'll look at a model trained for a fine-grained bird species classification task. This is the [Caltech-UCSD Birds Dataset](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html), a common benchmark for fine-grained and few-shot visual classification. It contains 200 species of birds, each with 40-60 images.

We'd like this classifier to be able to _explain_ its predictions by describing the features of the input birds that are relevant to its classification decision.

As a first step, let's look at some predictions from this classifier:

In [0]:
show_classifier_predictions(classifier)

## Generating natural language explanations

Many classes differ only in fine-grained details, and different features of birds can be discriminative in different contexts (e.g., the color of a wing might matter in one context and the wing's shape in another). We would like a tool for generating explanations that can capture these fine-grained distinctions; the main idea in this lab is that language might provide such a tool.

The Caltech Birds dataset is also accompanied by [natural language descriptions of all the birds](https://github.com/reedscot/cvpr2016). This language data will serve as the basis for the explanation techniques we develop.

## Image captions as explanations

The simplest way to use the bird descriptions is to train a model to generate descriptions, and to show these descriptions alongside the classifier's predictions. We'll let the captioning model condition _on the classifier's representation of the image_, rather than on raw pixels, so the generated descriptions are already partly tied to classifier behavior.

(For more information about image captioning models, check out https://arxiv.org/abs/1411.4389.)

In [0]:
def naive_explanation(i, features, pred_label):
  (caption,), _ = captioner.decode(features, greedy=True)
  return " ".join(vocab.decode(caption))

show_classifier_predictions(classifier, {"naive explanation": naive_explanation})
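`captioner.decode(features, greedy=True)` returns one caption per input. The `torchdec` implementation isn't shown in this lab, so as an assumption: `greedy=True` picks the most probable next token at every step, while `greedy=False` samples each token from the model's distribution. The toy sketch below illustrates that difference with a hypothetical hand-written next-token table standing in for a real, image-conditioned captioner:

```python
import random

# Hypothetical toy "captioner": next-token probabilities given the previous token.
# A real captioner would also condition on the image representation.
NEXT = {
    "<s>":    {"this": 0.9, "a": 0.1},
    "this":   {"bird": 1.0},
    "a":      {"bird": 1.0},
    "bird":   {"is": 0.6, "has": 0.4},
    "is":     {"yellow": 0.7, "red": 0.3},
    "has":    {"wings": 1.0},
    "yellow": {"</s>": 1.0},
    "red":    {"</s>": 1.0},
    "wings":  {"</s>": 1.0},
}


def decode(greedy: bool, rng: random.Random, max_len: int = 10) -> list[str]:
    """Generate a caption token by token until the end symbol."""
    token, caption = "<s>", []
    for _ in range(max_len):
        dist = NEXT[token]
        if greedy:
            token = max(dist, key=dist.get)  # most likely next token
        else:
            words, probs = zip(*dist.items())
            token = rng.choices(words, weights=probs)[0]  # sample the next token
        if token == "</s>":
            break
        caption.append(token)
    return caption


rng = random.Random(0)
print(" ".join(decode(greedy=True, rng=rng)))   # → this bird is yellow
print(" ".join(decode(greedy=False, rng=rng)))  # varies with the seed
```

Greedy decoding is deterministic, so the naive explanation for a given image never changes; sampling yields varied candidates, which is what the reranking methods later in this lab need.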

**Question**: Are these captions explanations at all? (What's the relationship between captions and the classifier they're supposed to explain?)

To answer this, we can start by exploring a related question: Can we distinguish good classifiers from bad classifiers on the basis of the quality of their descriptions?

**Exercise 1**: Below we've trained two extra image classifiers. One is slightly worse than the classifier used above, and one is significantly worse. We've also trained captioning models to generate bird descriptions from the representations computed by each of these classifiers. Can you tell which classifier has the highest-quality representations based on the quality of the associated explanations?

In [0]:
ok_classifier = ImageClassifier()
ok_classifier.load_state_dict(torch.load("/content/nl-explanations/image_classifier-0.03.m", map_location=DEVICE))
ok_classifier.eval()
ok_captioner = ImageCaptioner(vocab)
ok_captioner.load_state_dict(torch.load("/content/nl-explanations/image_captioner-0.03.m", map_location=DEVICE))
ok_captioner.eval()

bad_classifier = ImageClassifier()
bad_classifier.load_state_dict(torch.load("/content/nl-explanations/image_classifier-0.01.m", map_location=DEVICE))
bad_classifier.eval()
bad_captioner = ImageCaptioner(vocab)
bad_captioner.load_state_dict(torch.load("/content/nl-explanations/image_captioner-0.01.m", map_location=DEVICE))
bad_captioner.eval()

def naive_explanation_for_captioner(captioner):
  # Decode greedily, like naive_explanation, so all three models are compared the same way
  def fn(i, features, pred_label):
    (caption,), _ = captioner.decode(features, greedy=True)
    return " ".join(vocab.decode(caption))
  return fn

show_classifier_predictions(
  {"good": classifier, "ok": ok_classifier, "bad": bad_classifier},
  {"good": naive_explanation, "ok": naive_explanation_for_captioner(ok_captioner), "bad": naive_explanation_for_captioner(bad_captioner)},
  n=10
)

So we can already see that captions provide a proxy signal for model quality (even though the "good" and "ok" models get roughly the same accuracy on the test set we've been looking at!)


However, suppose we want to understand _why_ a given class label has been assigned. Is the current approach to textual explanations good enough? Consider the following examples:

In [0]:
show_classifier_predictions(classifier, {"naive explanation": naive_explanation}, n=200, with_class={88, 157})

Both descriptions apply reasonably well to both birds!

## Class-discriminative captions

We should pick captions that not only apply to the given image, but that are *informative about the class label*.

In the code below, we've trained a model (`text_classifier`) that assigns a probability to each bird class based on _description text alone_. 

Make sure you understand what the code below does. How does adding a text classifier term change the kind of explanations that get generated?

In [0]:
def discriminative_explanation(n_samples=20):
  def fn(i, features, pred_label):
    # Sample n_samples candidate captions for the same image
    feature_batch = torch.cat([features] * n_samples)
    captions, caption_scores = captioner.decode(feature_batch, greedy=False)
    # Score how strongly each caption (text alone) indicates the predicted class
    caption_batch = torchdec.batch_seqs(captions)
    classifier_scores = text_classifier(caption_batch)[:, pred_label]

    # Rerank by captioner score + text-classifier score and keep the best caption
    combined_scores = classifier_scores + torch.tensor(caption_scores)
    best_caption_index = combined_scores.argmax()
    return " ".join(vocab.decode(captions[best_caption_index]))
  return fn

show_classifier_predictions(
    classifier,
    {
        "naive": naive_explanation,
        "discriminative": discriminative_explanation(),
    }
)
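Stripped of the models, the reranking in `discriminative_explanation` is: sample candidates, score each one, keep the argmax. The framework-free sketch below uses made-up scores for three hypothetical captions, assumes the text-classifier scores behave like log-probabilities (the lab doesn't say whether `text_classifier` returns log-probabilities or raw logits), and adds a hypothetical `class_weight` knob:

```python
# Made-up scores for three hypothetical sampled captions:
# caption_logp = captioner log-probability of the caption given the image,
# class_logp   = text-classifier log-probability of the predicted class given the caption.
candidates = {
    "this bird is yellow":            {"caption_logp": -4.0, "class_logp": -3.0},
    "this bird has a black crown":    {"caption_logp": -6.0, "class_logp": -0.5},
    "this bird has a long thin beak": {"caption_logp": -7.5, "class_logp": -2.5},
}


def rerank(candidates: dict, class_weight: float = 1.0) -> str:
    """Return the caption maximizing caption_logp + class_weight * class_logp."""
    def score(caption: str) -> float:
        s = candidates[caption]
        return s["caption_logp"] + class_weight * s["class_logp"]
    return max(candidates, key=score)


print(rerank(candidates, class_weight=0.0))  # → this bird is yellow
print(rerank(candidates, class_weight=1.0))  # → this bird has a black crown
```

With `class_weight=0.0` the most fluent caption wins; with `class_weight=1.0` the caption that most strongly indicates the predicted class wins, even though the captioner finds it less likely.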

**Exercise 2**: How does changing the number of samples used in the reranking step affect explanation quality?

**Exercise 3**: Add a parameter that trades off the weight of the image captioner and text classifier terms in the objective above. How does this weight affect explanation quality?

## Image-discriminative captions

The explanations above are discriminative of the bird's class, but not necessarily of the contents of the image itself. We can address this by generating descriptions that discriminate not just the bird's class but also the input image from other images.

In [0]:
def image_discriminative_explanation(class_weight, image_weight, n_samples=20, n_distractors=10):
  def fn(i, features, pred_label):
    feature_batch = torch.cat([features] * n_samples)
    captions, caption_scores = captioner.decode(feature_batch, greedy=False)
    caption_batch = torchdec.batch_seqs(captions)
    text_classifier_scores = text_classifier(caption_batch)[:, pred_label]

    # Sample some distractor images and score every caption against them
    # (column 0 holds the scores for the input image itself)
    distractors = random.sample(list(test_features.values()), n_distractors)
    image_scores = torch.zeros((n_samples, n_distractors + 1))
    image_scores[:, 0] = image_ranker.score((caption_batch, feature_batch))
    for j, distractor in enumerate(distractors, start=1):
      distractor_t = torch.tensor([distractor], device=DEVICE)
      distractor_batch = torch.cat([distractor_t] * n_samples)
      image_scores[:, j] = image_ranker.score((caption_batch, distractor_batch))

    # Normalize over the input image + distractors: the log-probability that
    # each caption picks out the input image
    image_scores = F.log_softmax(image_scores, dim=1)[:, 0]

    combined_scores = image_weight * image_scores + class_weight * text_classifier_scores + torch.tensor(caption_scores)
    best_caption_index = combined_scores.argmax()
    return " ".join(vocab.decode(captions[best_caption_index]))
  return fn
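The `log_softmax` step turns raw ranker scores into the log-probability that a caption picks out the input image (column 0) rather than one of the distractors. A NumPy sketch with made-up scores (no `image_ranker` involved):

```python
import numpy as np


def log_softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """Numerically stable log-softmax along the given axis."""
    shifted = x - x.max(axis=axis, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))


# Rows: sampled captions. Column 0: ranker score for the input image;
# columns 1..: scores for distractor images (made-up numbers).
image_scores = np.array([
    [2.0, 0.0, 0.0],  # caption that fits the input image much better than the distractors
    [2.0, 2.0, 2.0],  # generic caption that fits every image equally well
])

# Log-probability that each caption picks out the input image among all candidates
target_logp = log_softmax(image_scores, axis=1)[:, 0]
print(target_logp)
```

The first caption gets a log-probability of about -0.24 (close to 0), while the generic caption gets log(1/3), about -1.10, so captions that fit every image equally well are penalized. One caveat in the lab code: `random.sample` draws distractors from all of `test_features`, so the input image itself can occasionally be drawn as a distractor.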

In [0]:
show_classifier_predictions(
    classifier,
    {
        "naive": naive_explanation,
        "class-discriminative-1.0": discriminative_explanation(),
        "image-discriminative-1.0": image_discriminative_explanation(1.0, 2.0),
    }
)

**Exercise 4**: Can you get even higher-quality explanations by changing the way these terms are combined?

**Exercise 5**: How would you use these models to produce explanations for counterfactual classes? What about counterfactual inputs?

## Feature-discriminative captions

There's one final issue with the approach described above: an explanation might discriminate both the class and the input image while describing features that the model doesn't actually use to make decisions.

To address this issue, we'll introduce a final explanation technique that identifies a set of images whose _model representations_ are similar to the input image, and then generates an explanation summarizing what all those images have in common.

We'll start by computing similarity between images in representation space. We'll call an image with representation $y$ **similar** to an image with representation $x$ if $y$ is among the 5% of images closest to $x$ under the Euclidean distance $d(x, y)$.

In [0]:
from scipy.spatial.distance import squareform, pdist

# Get representations
all_test_reps = captioner.proj(test_images).detach().cpu().numpy()

# Compute pairwise Euclidean distances (an N x N matrix; smaller means more similar)
dists = pdist(all_test_reps)
dists = squareform(dists)
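The neighborhood computed from `dists` can be sketched in plain NumPy. Distances come from broadcasting here instead of SciPy's `pdist`/`squareform` (same Euclidean values, no SciPy dependency), and `image_neighborhood` is a hypothetical helper mirroring the quantile logic of `similar_to_image`; the five 2-D points are toy stand-ins for `captioner.proj` outputs:

```python
import numpy as np


def pairwise_euclidean(reps: np.ndarray) -> np.ndarray:
    """N x N Euclidean distance matrix (same values as squareform(pdist(reps)))."""
    diffs = reps[:, None, :] - reps[None, :, :]
    return np.sqrt((diffs ** 2).sum(axis=-1))


def image_neighborhood(dists: np.ndarray, i: int, alpha: float) -> np.ndarray:
    """Indices of images (other than i) within the closest alpha fraction to image i."""
    others = np.delete(np.arange(len(dists)), i)  # exclude the image itself
    d = dists[i, others]
    threshold = np.quantile(d, alpha)
    return others[d < threshold]


# Five toy 2-D "representations" forming two clusters
reps = np.array([[0.0, 0.0], [0.1, 0.0], [0.0, 0.2], [5.0, 5.0], [5.1, 5.0]])
dists = pairwise_euclidean(reps)
print(image_neighborhood(dists, 0, alpha=0.5))  # → [1 2]
```

Thresholding at a quantile rather than a fixed radius keeps the neighborhood at roughly `alpha * N` images regardless of how spread out the representation space is. The broadcasting version materializes an N x N x D array, so for the real 1000-image matrix `pdist` is the more memory-friendly choice.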

In [0]:
def similar_to_image(i, alpha=0.05):
  """Images (other than i) within the closest alpha fraction to image i, nearest first."""
  imgs = [(j, dist) for j, dist in enumerate(dists[i]) if j != i]
  img_dists = [dist for _, dist in imgs]
  threshold = np.quantile(img_dists, alpha)
  most_sim = [(j, dist) for j, dist in imgs if dist < threshold]
  return sorted(most_sim, key=lambda x: x[1])

def show_similar_to_image(i, n=5):
  x = test_names[i]
  ys = similar_to_image(i, alpha=0.05)[:n]
  ys = [(test_names[j], dist) for j, dist in ys]
  display(HTML(f"<p><b>Most similar images to {x}</b></p>"))
  display(HTML(f"<img src='{url(x)}' width='200'/>"))
  for y, y_dist in ys:
    display(HTML(f"<p><b>{y}</b> ({y_dist:.3f})</p>"))
    display(HTML(f"<img src='{url(y)}' width='200'/>"))

In [0]:
show_similar_to_image(2)

Similarly, let's make a function that, given a textual description, retrieves the 5% of images that best match it:

In [0]:
def similar_to_text(caption, alpha=0.05):
  caption_batch = caption.unsqueeze(0).expand(test_images.shape[0], caption.shape[0])
  caption_batch = torchdec.batch_seqs(caption_batch)
  image_scores = image_ranker.score((caption_batch, test_images))
  image_scores = image_scores.detach().cpu().numpy()
  # Higher is better
  threshold = np.quantile(image_scores, 1 - alpha)
  most_sim = [(j, score) for j, score in enumerate(image_scores) if score > threshold]
  return sorted(most_sim, key=lambda x: -x[1])

def show_similar_to_text(caption, n=5):
  cap = torch.tensor(vocab.encode(caption.split()))
  ys = similar_to_text(cap, alpha=0.05)[:n]
  ys = [(test_names[j], score) for j, score in ys]
  display(HTML(f"<p><b>Most similar images to</b> <span style='font-family: monospace'>{caption}</span></p>"))
  for y, y_score in ys:
    display(HTML(f"<p><b>{y}</b> ({y_score:.3f})</p>"))
    display(HTML(f"<img src='{url(y)}' width='200'/>"))
  
show_similar_to_text('this bird is yellow')

We claim that a good natural language explanation for a representation $x$ is a string $w$ that describes the features common to all images whose representations are similar to $x$.

If representations $x$ are associated with a "neighborhood" $N_{\text{img}}(x)$ and descriptions $w$ with a neighborhood $N_{\text{text}}(w)$, we explain $x$ by finding $w$ so the two neighborhoods are similar (e.g. using [IoU](https://en.wikipedia.org/wiki/Jaccard_index)):
$$
w = \text{argmax}_{w'}\, \text{IoU} \left( N_{\text{img}}(x), N_{\text{text}}(w') \right)
$$

(How is this similar to the network dissection lab from earlier in the course?)
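For boolean membership vectors over the test set, $\text{IoU}(A, B) = |A \cap B| / |A \cup B|$. SciPy's `jaccard` returns the Jaccard *dissimilarity* for boolean inputs, i.e. 1 - IoU, which is why the code below uses `1 - jaccard(img_hits, cap_hits)`. A NumPy-only sketch with toy 6-image neighborhoods (returning 0 for two empty sets is a convention chosen here, not SciPy's behavior):

```python
import numpy as np


def iou(a: np.ndarray, b: np.ndarray) -> float:
    """Intersection over union of two boolean membership masks."""
    a, b = np.asarray(a, dtype=bool), np.asarray(b, dtype=bool)
    union = np.logical_or(a, b).sum()
    if union == 0:
        return 0.0  # convention for this sketch: two empty neighborhoods score 0
    return float(np.logical_and(a, b).sum() / union)


# Toy membership of 6 images in N_img(x) and in N_text(w') for two candidate captions
n_img   = np.array([1, 1, 1, 0, 0, 0], dtype=bool)
n_text1 = np.array([1, 1, 0, 0, 0, 0], dtype=bool)  # overlaps N_img well
n_text2 = np.array([0, 0, 1, 1, 1, 0], dtype=bool)  # mostly different images

print(iou(n_img, n_text1), iou(n_img, n_text2))
```

The first caption's neighborhood shares 2 images with a 3-image union (IoU 2/3); the second shares 1 image out of a 5-image union (IoU 0.2), so the first caption would be preferred.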

In [0]:
from scipy.spatial.distance import jaccard
from tqdm import tqdm

def representation_explanation(class_weight, iou_weight, n_samples=20, n_distractors=10, alpha=0.10):
  # Note: n_distractors is not used by this method
  def fn(i, features, pred_label):
    feature_batch = torch.cat([features] * n_samples)
    captions, caption_scores = captioner.decode(feature_batch, greedy=False)
    caption_batch = torchdec.batch_seqs(captions)
    text_classifier_scores = text_classifier(caption_batch)[:, pred_label]

    # N_img(x): images whose representations are within the closest alpha fraction to image i
    i_dists = dists[i]
    threshold = np.quantile(i_dists, alpha)
    img_hits = i_dists < threshold

    ious = torch.zeros(n_samples)
    # For each caption, compute N_text(w) (top-alpha images by ranker score) and its IoU with N_img(x)
    for c, caption in enumerate(tqdm(captions, desc='Computing caption IoU')):
      repeated = torchdec.batch_seqs([caption] * len(test_images))
      scores = image_ranker.score((repeated, test_images)).detach().cpu().numpy()
      text_threshold = np.quantile(scores, 1 - alpha)
      cap_hits = scores > text_threshold
      # scipy's jaccard() is a dissimilarity, so IoU = 1 - jaccard
      ious[c] = 1 - jaccard(img_hits, cap_hits)
    print(f"IoU mean: {ious.mean().item():.3f}")
    print(f"IoU min: {ious.min().item():.3f}")
    print(f"IoU max: {ious.max().item():.3f}")

    combined_scores = iou_weight * ious + class_weight * text_classifier_scores + torch.tensor(caption_scores)
    best_caption_index = combined_scores.argmax()
    return " ".join(vocab.decode(captions[best_caption_index]))
  return fn
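Putting the pieces together, and assuming `caption_scores` are the captioner's log-likelihoods (the lab doesn't state this explicitly), `representation_explanation` reranks $K$ = `n_samples` sampled captions $w_1, \dots, w_K$ by

$$
\hat{w} = \text{argmax}_{w \in \{w_1, \dots, w_K\}} \Big[ \log p_{\text{cap}}(w \mid x) + \lambda_{\text{class}}\, s_{\text{text}}(\hat{y} \mid w) + \lambda_{\text{IoU}}\, \text{IoU}\big( N_{\text{img}}(x), N_{\text{text}}(w) \big) \Big]
$$

where $\hat{y}$ is the predicted label, $s_{\text{text}}(\hat{y} \mid w)$ is the `text_classifier` score for that label, and $\lambda_{\text{class}}$, $\lambda_{\text{IoU}}$ are `class_weight` and `iou_weight`. `discriminative_explanation` is the special case with $\lambda_{\text{class}} = 1$ and no IoU term, and `image_discriminative_explanation` replaces the IoU term with `image_weight` times the log-softmax image score.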

In [0]:
show_classifier_predictions(
    classifier,
    {
        "naive": naive_explanation,
        "representation-discriminative-1.0": representation_explanation(1.0, 100.0),
    }
)

**Exercise 6**: Again, play with the weights for different terms and see how they affect explanations.