In [None]:
import datetime
from pathlib import Path

import plotly.express as px
import torch
import torch.distributions
import transformer_lens
from transformer_lens import HookedTransformer

import connectome as core

%load_ext autoreload
%autoreload 2

In [None]:
# GPT-2 small is the model used for the IOI experiments below.
model = HookedTransformer.from_pretrained(model_name="gpt2-small")

In [None]:
# Attention-only 4-layer model, kept under its own name so it is not clobbered by `model`.
attn_only_4l = HookedTransformer.from_pretrained(model_name="attn-only-4l")

# IOI (indirect object identification) task

In [None]:
# Compute the connectome of the IOI prompt on GPT-2 small.
# `threshold` and `corrupt_prompt` are only used by the alternatives commented out below.
threshold = 0.3
prompt = "When Mary and John went to the store, John gave a book to"
corrupt_prompt = "When Tom and Sarah went to the store, Felix gave a book to"
# Presumably logit(" Mary") - logit(" John"); confirm in core.logit_diff_metric.
metric = core.logit_diff_metric(model, " Mary", " John")
c = core.connectome(
    model,
    prompt,
    metric,
    core.ZeroPattern(),
    # Alternative interventions:
    # core.CorruptIntervention(model, prompt, corrupt_prompt),
    # core.CropIntervention(model, prompt),
    core.BasicStrategy(),
    # Alternative strategies (the old `d.` alias is not defined in this notebook):
    # core.BacktrackBisectStrategy(threshold),
    # core.BacktrackingStrategy(threshold),  # NOTE(review): previously `d.BacktrackingStrategy`; confirm it lives in core
    # core.BisectStrategy(threshold),
    # core.SplitStrategy(model, prompt, threshold, delimiters_as_leaves=True),
)
core.plot_attn_connectome(model, prompt, c).show()
# Graphviz rendering alternative (needs `from IPython.display import SVG`):
# graph = core.plot_graphviz_connectome(model, prompt, c, threshold=threshold).pipe('svg').decode('utf-8')
# SVG(graph)

In [None]:
# Build the graphviz view of the IOI connectome (depth=2, top_k=15).
graph = core.plot_graphviz_connectome(model, prompt, c, depth=2, top_k=15)
# A cell ending in an assignment displays nothing; end with the object so Jupyter renders it.
graph

In [None]:
# Sweep: metric strength preserved by core.cut_connectome when only the top-k connections
# of `c` are kept, for several values of `dampen_weak`. One figure per dampening value.
sorted_connectome = sorted(c, key=lambda x: abs(x.strength), reverse=True)
thresholds = torch.linspace(0, 1.0, 20)  # only used by the commented-out threshold sweep
# Include top_k == len(c) (every connection kept) so each curve reaches its end point.
top_ks = list(range(1, len(c) + 1))
for dampen_weak in [0, 0.2, 0.4, 0.6]:
    strength_kept = [
        core.cut_connectome(
            model,
            prompt,
            metric,
            core.filter_connectome(c, None, top_k=top_k),
            dampen_weak=dampen_weak,
        )
        for top_k in top_ks
        # for threshold in thresholds
    ]
    px.line(
        x=top_ks,
        y=strength_kept,
        title=f"Strength kept when keeping top connections and dampening the others by {dampen_weak:.1f}",
        labels={"x": "Top k", "y": "Strength kept"},
        width=800,
    ).show()

In [None]:
# Re-plot the final curve of the sweep above: after the loop, `strength_kept` and
# `dampen_weak` are still bound to the values of its last iteration.
px.line(
    x=top_ks,
    y=strength_kept,
    title=f"Strength kept vs. top k (dampen_weak={dampen_weak:.1f})",
    labels={"x": "Top k", "y": "Strength kept"},
)
# px.line(x=thresholds, y=strength_kept)

# Docstring task

In [None]:
# Docstring task on the attention-only 4-layer model.
# Reuse the instance loaded at the top instead of loading the same checkpoint a second time.
model = attn_only_4l
threshold = 0.3
prompt = '''def port(self, load, size, file, last):
    """oil column piece

    :param load: crime population
    :param size: unit dark
    :param'''
# Same prompt with every argument name replaced by an unrelated word.
corrupt_prompt = (
    prompt.replace("load", "banana")
    .replace("size", "apple")
    .replace("file", "pear")
    .replace("last", "orange")
)

c = core.connectome(
    model,
    prompt,
    # " file" is the next undocumented argument after `load` and `size`; the other names are
    # presumably the contrast tokens — confirm in core.logit_diff_metric.
    core.logit_diff_metric(model, " file", " self", " load", " size", " last"),
    core.ZeroPattern(),
    # core.CorruptIntervention(model, prompt, corrupt_prompt),
    # core.BasicStrategy(),
    core.SplitStrategy(model, prompt, threshold, delimiters_as_leaves=True),
    # core.BacktrackBisectStrategy(threshold),
    # core.BacktrackingStrategy(threshold),  # NOTE(review): previously `d.BacktrackingStrategy`; `d` is undefined here
    # core.BisectStrategy(threshold),
)

In [None]:
# Show the attention connectome, then render and save a graphviz view at several depths.
graph_threshold = 0.4
core.plot_attn_connectome(model, prompt, c).show()
graphs_dir = Path("graphs")
graphs_dir.mkdir(exist_ok=True)  # writing below would fail if the directory is missing
for depth in [2, 3, 4]:
    # NOTE(review): other cells use core.plot_graphviz_connectome; confirm this name still exists.
    graph = core.graphviz_connectome(model, prompt, c, graph_threshold, depth=depth)
    svg = graph.pipe("svg").decode("utf-8")
    display(graph)
    # Microsecond resolution keeps names distinct between loop iterations.
    date = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S-%f")
    (graphs_dir / f"graph-{date}.svg").write_text(svg, encoding="utf-8")

In [None]:
# Echo the docstring prompt, then run transformer_lens' test_prompt with expected answer " file".
print(prompt)
transformer_lens.utils.test_prompt(prompt=prompt, answer=" file", model=model)

# Code task (type-annotation lookup), run on GPT-2 medium

In [None]:
# Switch `model` to GPT-2 medium; every cell below this point uses it.
model = HookedTransformer.from_pretrained(model_name="gpt2-medium")

In [None]:
# The prompt ends inside `h(var`; `h` takes a Dict and `var2` is the only Dict variable,
# so the correct continuation is "2".
prompt = """from typing import List, Dict
def f(x: List):
    return sum(x[::2])

def g(x: float):
    return x ** 2

def h(x: Dict):
    return sum(x.values())

def i(x: str):
    return len(x)

var1: str = 'abc'
var2: Dict = {'a': 1, 'b': 2}
var3: List = [1, 2, 3]
var4: int = 4

h(var"""

# The answer continues the identifier `var` directly, so test_prompt must not prepend a space
# (its default), which would otherwise test " 2" instead of "2".
transformer_lens.utils.test_prompt(prompt, "2", model, prepend_space_to_answer=False)

# Exploration of grouping techniques

In [None]:
# Finding the log-probs on the prompt tokens
prompt = "When Mary and John went to the store, John gave a book to Mary."
log_probs = model(prompt)[0].log_softmax(-1)  # (seq_len, vocab_size)
tokens = model.to_tokens(prompt)[0]
tokens_str = model.to_str_tokens(tokens)
print(tokens.shape)
print(log_probs.shape)
correct_logprobs = log_probs[torch.arange(len(tokens) - 1), tokens[1:]]
print(correct_logprobs.shape)
for i, (t, n, l) in enumerate(zip(tokens_str, tokens_str[1:], correct_logprobs)):
    print(f"{i:2d} {t!r} {l.item():.2f} -> {n!r}")

import plotly.express as px

px.line(
    x=[
        f"{i} {t!r} -> {n!r}"
        for i, (t, n) in enumerate(zip(tokens_str, tokens_str[1:]))
    ],
    y=correct_logprobs.detach() * 0,
)

In [None]:
# Visualize how SplitStrategy groups the docstring prompt into a tree of spans.
prompt = '''def port(self, load, size, file, last):
    """oil column piece

    :param load: crime population
    :param size: unit dark
    :param'''

print(prompt)
# Delimiter levels in the order passed: blank lines, sentence ends, clause separators.
delimiter_levels = ("\n\n", tuple(".!?"), tuple(",:;"))
# The third positional argument (0.1) is the split threshold, same slot as `threshold` in the docstring-task cell.
s = core.SplitStrategy(model, prompt, 0.1, delimiter_levels)
s.show_tree()

In [None]:
# Count connections surviving core.filter_connectome with second argument 1
# (presumably a threshold, as the sweep cell passes None there alongside top_k — confirm).
filtered_connectome = core.filter_connectome(c, 1)
len(filtered_connectome)

In [None]:
# Total number of connections in the current connectome.
n_connections = len(c)
n_connections

# Pythia code task (attention matrix visualization)

In [None]:
# Code prompt for Pythia: the class names (A, B, C) and variable names (foo, bar, name) do not
# reveal their types. The prompt stops inside `calculate_circumference(`, whose parameter is a `B`,
# and `bar` is the only `B` instance. The leading newline after the opening quotes is part of the prompt.
prompt = """
from typing import List
from math import pi

class Point:
    def __init__(self, x: float, y: float) -> None:
        self.x = x
        self.y = y

class A:
    def __init__(self, bottom_left: Point, top_right: Point) -> None:
        self.bottom_left = bottom_left
        self.top_right = top_right

class B:
    def __init__(self, center: Point, radius: float) -> None:
        self.center = center
        self.radius = radius

class C:
    def __init__(self, points: List[Point]) -> None:
        self.points = points

def calculate_area(rectangle: A) -> float:
    height = rectangle.top_right.y - rectangle.bottom_left.y
    width = rectangle.top_right.x - rectangle.bottom_left.x
    return height * width

def calculate_center(rectangle: A) -> Point:
    center_x = (rectangle.bottom_left.x + rectangle.top_right.x) / 2
    center_y = (rectangle.bottom_left.y + rectangle.top_right.y) / 2
    return Point(center_x, center_y)

def calculate_distance(point1: Point, point2: Point) -> float:
    return ((point2.x - point1.x) ** 2 + (point2.y - point1.y) ** 2) ** 0.5

def calculate_circumference(circle: B) -> float:
    return 2 * pi * circle.radius

def calculate_circle_area(circle: B) -> float:
    return pi * (circle.radius ** 2)

def calculate_perimeter(polygon: C) -> float:
    perimeter = 0
    points = polygon.points + [polygon.points[0]]  # Add the first point at the end for a closed shape
    for i in range(len(points) - 1):
        perimeter += calculate_distance(points[i], points[i + 1])
    return perimeter

foo = A(Point(2, 3), Point(6, 5))

bar = B(Point(0, 0), 5)

name = C([Point(0, 0), Point(1, 0), Point(0, 1)])

# Calculate circumference
print(calculate_circumference("""

In [None]:
from transformers import AutoTokenizer
from transformers import GPTNeoXTokenizerFast

# Only the Pythia-2.8B tokenizer is needed here (to label the attention matrix), not the weights.
pythia_checkpoint = "EleutherAI/pythia-2.8b"
tokenizer: GPTNeoXTokenizerFast = AutoTokenizer.from_pretrained(pythia_checkpoint)

In [None]:
# Per-token strings for the axis labels, with a "|BOS|" placeholder prepended.
# NOTE(review): assumes the saved attention matrix includes a BOS position that the tokenizer call does not add — confirm.
input_ids = tokenizer(prompt)["input_ids"]
tokens = ["|BOS|", *tokenizer.batch_decode(input_ids)]
labels = [f"{idx} {tok!r}" for idx, tok in enumerate(tokens)]

In [None]:
# NOTE(review): the variable name and plot title say "max" but the file is avg_attention.pt — confirm which statistic it holds.
# torch.load unpickles its input, so only load files produced locally.
attention_path = "avg_attention.pt"
max_attention = torch.load(attention_path, map_location="cpu")

In [None]:
# Heat map of the loaded attention matrix: x = source position, y = target position.
# `px` is already imported in the first cell; the duplicate import here was removed.
figure_side_px = 6000  # one tick label per token on a long prompt needs a large canvas
px.imshow(
    max_attention,
    x=labels,
    y=labels,
    color_continuous_scale="Blues",
    height=figure_side_px,
    width=figure_side_px,
    title="Max attention matrix for Pythia code task",
    labels=dict(x="Source", y="Target"),
)