# Iterative summarization interpretability notebook

## Goal

* Interpret a pair of networks: a summarizer and an exemplifier.

* Follow the stream of information during a discussion between the two networks.

## Notes

* The code below is inspired by (and sometimes copied from) Neel Nanda's [demo notebook](https://colab.research.google.com/github/neelnanda-io/TransformerLens/blob/main/Main_Demo.ipynb) and his library [TransformerLens](https://github.com/neelnanda-io/TransformerLens).

## Imports

### Pip installs

In [None]:
!pip install git+https://github.com/neelnanda-io/TransformerLens.git
!pip install circuitsvis

### Classic libraries imports

In [None]:
import os
import json
import torch

import plotly.express as px
import plotly.io as pio
pio.renderers.default = "colab"

### External toolboxes

In [None]:
import circuitsvis as cv
import transformer_lens.utils as utils
from transformer_lens import HookedTransformer

## Model loading

In [None]:
# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
torch.cuda.empty_cache()

# Both networks start from the same pretrained GPT-2 small checkpoint;
# their own state dicts are loaded on top in the next cell.
sum_model = HookedTransformer.from_pretrained("gpt2-small", device=device)
exp_model = HookedTransformer.from_pretrained("gpt2-small", device=device)

In [None]:
sum_weight_file = './sum_weights.pt'
exp_weight_file = './exp_weights.pt'

!wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=FILEID' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1tPU5mHCXcAxZJJHvv9XyT-MzgoeSWUk9" -O $sum_weight_file  && rm -rf /tmp/cookies.txt
!wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=FILEID' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1--qi-Rzhff4OcAtknrSrzfDzgNC9z5mQ" -O $exp_weight_file  && rm -rf /tmp/cookies.txt

if not os.path.exists(sum_weight_file):
    raise FileNotFoundError
else:
    sum_model.load_state_dict(torch.load(sum_weight_file))
    sum_model.eval()
if not os.path.exists(exp_weight_file):
    raise FileNotFoundError
else:
    exp_model.load_state_dict(torch.load(exp_weight_file))
    exp_model.eval()

## Iterative summarization test

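* The main idea is to connect two complementary networks so that they hold a discussion: the summarizer condenses a review into a summary, the exemplifier expands that summary back into a review, and so on.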



In [None]:
def discuss(exp_model, sum_model, starting_review, discussion_length=5, sum_max_new_tokens=20, exp_max_new_tokens=200):
    """Let the summarizer and the exemplifier talk to each other.

    Starting from `starting_review`, each round asks `sum_model` to summarize
    the latest review, then asks `exp_model` to expand that summary back into
    a review. Every generated summary and review is printed as soon as it is
    produced.

    Args:
        exp_model: model prompted with ``[summary]: ...\\n[review]: `` to
            produce a review.
        sum_model: model prompted with ``[review]: ...\\n[summary]: `` to
            produce a summary.
        starting_review: review text that seeds the discussion.
        discussion_length: number of summary/review rounds.
        sum_max_new_tokens: ``max_new_tokens`` for each summary generation.
        exp_max_new_tokens: ``max_new_tokens`` for each review generation.

    Returns:
        The transcript as alternating ``[review]: ...`` / ``[summary]: ...``
        lines, ending with an empty ``[summary]: `` prompt after the last
        review.
    """
    def completion(model, prompt, marker, budget):
        # Keep the text after the first `marker` and before the first
        # end-of-text token.
        raw = model.generate(prompt, max_new_tokens=budget)
        return raw.split(marker)[1].split('<|endoftext|>')[0]

    prompt = f"[review]: {starting_review}\n[summary]: "
    transcript = prompt
    for _round in range(discussion_length):
        summary = completion(sum_model, prompt, '[summary]: ', sum_max_new_tokens)
        print(summary)
        transcript += summary

        review = completion(
            exp_model,
            f"[summary]: {summary}\n[review]: ",
            '[review]: ',
            exp_max_new_tokens,
        )
        print(review)
        prompt = f"[review]: {review}\n[summary]: "
        transcript += "\n" + prompt
    return transcript


In [None]:
# A training-split review with its reference summary.
train_review = (
    "I love this adaptation of the classic tale.  "
    "Henry Winkler is, of course, one of my favorite actors.  "
    "This is a different slant from the original but it gets the message across none the less.  "
    "Well worth the viewing."
)
train_summary = "Good Adaptation of the Classic Tale"

# An eval-split review with its reference summary.
eval_review = (
    "The counter scene when Allen's character says he is robbing the bank and has a \"gub\".  "
    "That is hilarious!! Many more humorous scenes!\n"
    "One of America's best comics ever!!"
)
eval_summary = "Very Funny"

In [None]:
# Number of summary/review rounds per discussion.
discussion_length = 3

# Seed one discussion with the training review and one with the eval review;
# `discuss` prints every turn, then the whole transcript is printed.
train_discussion = discuss(
    exp_model,
    sum_model,
    train_review,
    discussion_length=discussion_length,
)
print(train_discussion)

eval_discussion = discuss(
    exp_model,
    sum_model,
    eval_review,
    discussion_length=discussion_length,
)
print(eval_discussion)

## Plot helpers

In [None]:
def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Show `tensor` as a plotly heatmap.

    By default the color scale is the diverging "RdBu" map centered on 0.
    Either setting can be overridden through `kwargs`. Caller-supplied
    `labels` are merged over the axis titles instead of clashing with them.

    Args:
        tensor: anything `utils.to_numpy` accepts.
        renderer: plotly renderer passed to `.show()`; None uses the default.
        xaxis: title of the x axis.
        yaxis: title of the y axis.
        **kwargs: forwarded to `px.imshow`.
    """
    kwargs.setdefault("color_continuous_midpoint", 0.0)
    kwargs.setdefault("color_continuous_scale", "RdBu")
    labels = {"x": xaxis, "y": yaxis, **kwargs.pop("labels", {})}
    px.imshow(utils.to_numpy(tensor), labels=labels, **kwargs).show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Show `tensor` as a plotly line chart.

    Caller-supplied `labels` are merged over the axis titles instead of
    clashing with them.

    Args:
        tensor: anything `utils.to_numpy` accepts.
        renderer: plotly renderer passed to `.show()`; None uses the default.
        xaxis: title of the x axis.
        yaxis: title of the y axis.
        **kwargs: forwarded to `px.line`.
    """
    labels = {"x": xaxis, "y": yaxis, **kwargs.pop("labels", {})}
    px.line(utils.to_numpy(tensor), labels=labels, **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    """Show a plotly scatter plot of `y` against `x`.

    Caller-supplied `labels` are merged over the axis/color titles instead of
    clashing with them.

    Args:
        x: x coordinates; anything `utils.to_numpy` accepts.
        y: y coordinates; anything `utils.to_numpy` accepts.
        xaxis: title of the x axis.
        yaxis: title of the y axis.
        caxis: title of the color legend.
        renderer: plotly renderer passed to `.show()`; None uses the default.
        **kwargs: forwarded to `px.scatter`.
    """
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    labels = {"x": xaxis, "y": yaxis, "color": caxis, **kwargs.pop("labels", {})}
    px.scatter(y=y, x=x, labels=labels, **kwargs).show(renderer)

In [None]:
def attention_patterns(model, text, layer=0):
    """Render the attention patterns of every head of one layer of `model` on `text`.

    Args:
        model: the HookedTransformer to inspect.
        text: prompt run through `model`; it is printed first.
        layer: index of the layer whose attention patterns are shown.

    Returns:
        The circuitsvis attention-pattern view, returned so the calling
        notebook cell can display it.
    """
    print(f"Text:\n{text}")
    # Tokenize, run and label with the *same* model and text so that the
    # attention pattern and the token labels describe the same sequence.
    tokens = model.to_tokens(text)
    _, cache = model.run_with_cache(tokens, remove_batch_dim=True)
    attention_pattern = cache["pattern", layer, "attn"]
    str_tokens = model.to_str_tokens(text)
    return cv.attention.attention_patterns(tokens=str_tokens, attention=attention_pattern)

## Summarizer plots

In [None]:
# Have the summarizer summarize the training review (up to 20 new tokens);
# the decoded output is inspected by the attention plots below.
gen_text = sum_model.generate(
    "[review]: " + train_review + "\n[summary]: ",
    max_new_tokens=20,
)

### Attention patterns for generated text

In [None]:
# Request layer-0 attention patterns for sum_model on `gen_text`.
layer = 0
print("Layer", layer, "Head Attention Patterns:")
attention_patterns(sum_model, gen_text, layer)

In [None]:
# Request layer-5 attention patterns for sum_model on `gen_text`.
layer = 5
print("Layer", layer, "Head Attention Patterns:")
attention_patterns(sum_model, gen_text, layer)

In [None]:
# Request layer-10 attention patterns for sum_model on `gen_text`.
layer = 10
print("Layer", layer, "Head Attention Patterns:")
attention_patterns(sum_model, gen_text, layer)

In [None]:
# Reference review/summary pairs in the summarizer's prompt layout.
train_text = "[review]: " + train_review + "\n[summary]: " + train_summary
eval_text = "[review]: " + eval_review + "\n[summary]: " + eval_summary

### Attention patterns for an eval sample

In [None]:
# Request layer-0 attention patterns for sum_model on `eval_text`.
layer = 0
print("Layer", layer, "Head Attention Patterns:")
attention_patterns(sum_model, eval_text, layer)

In [None]:
# Request layer-5 attention patterns for sum_model on `eval_text`.
layer = 5
print("Layer", layer, "Head Attention Patterns:")
attention_patterns(sum_model, eval_text, layer)

In [None]:
# Request layer-11 attention patterns for sum_model on `eval_text`.
layer = 11
print("Layer", layer, "Head Attention Patterns:")
attention_patterns(sum_model, eval_text, layer)

### Copying score

In [None]:
# Per-head OV copying score: real part of the summed OV eigenvalues divided by
# the sum of their magnitudes, which bounds the score to [-1, 1].
OV_circuit_all_heads = sum_model.OV
OV_circuit_all_heads_eigenvalues = OV_circuit_all_heads.eigenvalues
_eig_real_sum = OV_circuit_all_heads_eigenvalues.sum(dim=-1).real
_eig_abs_sum = OV_circuit_all_heads_eigenvalues.abs().sum(dim=-1)
OV_copying_score = _eig_real_sum / _eig_abs_sum
imshow(
    utils.to_numpy(OV_copying_score),
    xaxis="Head",
    yaxis="Layer",
    title="OV Copying Score for each head",
    zmin=-1.0,
    zmax=1.0,
)

## Exemplifier plots

In [None]:
# Have the exemplifier expand the training summary into a review (up to 200
# new tokens); the decoded output is inspected by the attention plots below.
gen_text = exp_model.generate(
    "[summary]: " + train_summary + "\n[review]: ",
    max_new_tokens=200,
)

### Attention patterns for generated text

In [None]:
# Request layer-0 attention patterns for exp_model on `gen_text`.
layer = 0
print("Layer", layer, "Head Attention Patterns:")
attention_patterns(exp_model, gen_text, layer)

In [None]:
# Request layer-5 attention patterns for exp_model on `gen_text`.
layer = 5
print("Layer", layer, "Head Attention Patterns:")
attention_patterns(exp_model, gen_text, layer)

In [None]:
# Request layer-10 attention patterns for exp_model on `gen_text`.
layer = 10
print("Layer", layer, "Head Attention Patterns:")
attention_patterns(exp_model, gen_text, layer)

In [None]:
# Reference summary/review pairs in the exemplifier's prompt layout.
train_text = "[summary]: " + train_summary + "\n[review]: " + train_review
eval_text = "[summary]: " + eval_summary + "\n[review]: " + eval_review

### Attention patterns for an eval sample

In [None]:
# Request layer-0 attention patterns for exp_model on `eval_text`.
layer = 0
print("Layer", layer, "Head Attention Patterns:")
attention_patterns(exp_model, eval_text, layer)

In [None]:
# Request layer-5 attention patterns for exp_model on `eval_text`.
layer = 5
print("Layer", layer, "Head Attention Patterns:")
attention_patterns(exp_model, eval_text, layer)

In [None]:
# Request layer-11 attention patterns for exp_model on `eval_text`.
layer = 11
print("Layer", layer, "Head Attention Patterns:")
attention_patterns(exp_model, eval_text, layer)

### Copying score

In [None]:
# Per-head OV copying score: real part of the summed OV eigenvalues divided by
# the sum of their magnitudes, which bounds the score to [-1, 1].
OV_circuit_all_heads = exp_model.OV
OV_circuit_all_heads_eigenvalues = OV_circuit_all_heads.eigenvalues
_eig_real_sum = OV_circuit_all_heads_eigenvalues.sum(dim=-1).real
_eig_abs_sum = OV_circuit_all_heads_eigenvalues.abs().sum(dim=-1)
OV_copying_score = _eig_real_sum / _eig_abs_sum
imshow(
    utils.to_numpy(OV_copying_score),
    xaxis="Head",
    yaxis="Layer",
    title="OV Copying Score for each head",
    zmin=-1.0,
    zmax=1.0,
)