# Research sprint summary: Sam and Adam

## Setup

In [36]:
# Relative path of the directory that holds the JSON result files.
RESULTS_DATA_DIR = "../results"
# Filename prefixes used to pick out each experiment's result files.
HIER_EQAL_PREFIX, PAREN_BAL_PREFIX = "hier_eqal", "paren_bal"

In [67]:
import json
import os

import numpy as np

import plotly.express as px
import plotly.graph_objects as go
import plotly.colors as colors

from rich.table import Table
from rich.console import Console

In [38]:
import plotly.io as pio
# Enable both the Colab and the VS Code renderers ("+" combines renderer names)
# so figures display in either front end.
pio.renderers.default = "colab+vscode"

In [39]:
# Shared rich console used by later cells to print the result tables.
console = Console()

## Hierarchical equality 

### Experimental results

In [40]:
# Load every hierarchical-equality result file (JSON) from the results directory.
# os.listdir() returns entries in arbitrary order, so the names are sorted to make
# the order of `data_he` (and anything that depends on it, e.g. which run is
# `data_he[0]` or which run wins when two share a grid cell) reproducible.
data_he = []
for filename in sorted(os.listdir(RESULTS_DATA_DIR)):
    filepath = os.path.join(RESULTS_DATA_DIR, filename)
    if filename.startswith(HIER_EQAL_PREFIX) and os.path.isfile(filepath):
        # Read as UTF-8 explicitly instead of relying on the platform default encoding.
        with open(filepath, encoding="utf-8") as fp:
            data_he.append(json.load(fp))

In [50]:
# Axes of the hierarchical-equality result grid: hook name x subspace size.
intervene_hook = ["hook_mid1", "hook_mid2", "hook_mid3"]
subspace_size = [1, 2, 4, 8]

# Length of the train-accuracy series, taken from the first loaded run
# (assumes every run recorded the same number of values — TODO confirm).
n_epochs = len(data_he[0]["train_accuracies"])
grid_shape = (len(intervene_hook), len(subspace_size))

train_accuracies_he = np.zeros(grid_shape + (n_epochs,))
test_accuracies_he = np.zeros(grid_shape)

# Place each run in the grid cell given by its parameters; cells without a run stay zero.
for run in data_he:
    params = run["parameters"]
    cell = (
        intervene_hook.index(params["intervene_hook"]),
        # Only the first entry of "subspace_sizes" is used to locate the run.
        subspace_size.index(params["subspace_sizes"][0]),
    )
    test_accuracies_he[cell] = run["test_accuracy"]
    train_accuracies_he[cell] = np.array(run["train_accuracies"])

In [51]:
# Test-accuracy table: one row per hook, one column per subspace size.
table = Table(
    title="Test accuracies for interventions on b1 and b2",
    show_header=True,
    header_style="bold",
)
table.add_column("Hook", style="bold")
for size_label in (1, 2, 4, 8):
    table.add_column(f"Subspace Sizes {size_label}")

# Only the first two hooks are tabulated.
for row, hook in enumerate(intervene_hook[:2]):
    cells = [f"{test_accuracies_he[row, col]:.4f}" for col in range(4)]
    table.add_row(hook, *cells)

console.print(table)

In [72]:
import plotly.graph_objects as go

# Train-accuracy curves, one trace per (hook, subspace size), with a dropdown
# that switches between hooks by toggling trace visibility.
plotted_hooks = intervene_hook[:2]  # only the first two hooks are plotted

traces = []
for i, hook in enumerate(plotted_hooks):
    for j, size in enumerate(subspace_size):
        train_accuracies = train_accuracies_he[i, j]
        traces.append(
            go.Scatter(
                x=list(range(len(train_accuracies))),
                y=train_accuracies,
                name=f"{hook}, subspace size: {size}",
                mode="lines",
                # Start with only the first hook's traces shown (matches active=0 below).
                visible=(i == 0),
            )
        )

# One button per plotted hook. Traces are appended hook-major, so the mask for the
# hook at position `shown` is True exactly for that hook's len(subspace_size)
# consecutive traces. Deriving the masks and labels from the lists (instead of a
# hard-coded 4) keeps the menu in sync if the hooks or sizes change.
buttons = [
    dict(
        label=hook,
        method="update",
        args=[
            {
                "visible": [
                    h == shown
                    for h in range(len(plotted_hooks))
                    for _ in subspace_size
                ]
            }
        ],
    )
    for shown, hook in enumerate(plotted_hooks)
]

layout = go.Layout(
    title="Hierarchical equality rotation matrix train accuracies",
    xaxis_title="Epoch",
    yaxis_title="Train Accuracy",
    updatemenus=[dict(buttons=buttons, active=0, showactive=True)],
)

fig = go.Figure(data=traces, layout=layout)
fig.show()

## Parenthesis balancing

### Experimental results

In [74]:
# Load every parenthesis-balancing result file (JSON) from the results directory.
# os.listdir() returns entries in arbitrary order, so the names are sorted to make
# the order of `data_pb` reproducible; the best-run selection that consumes this
# list keeps the first run on ties, so the order affects the result.
data_pb = []
for filename in sorted(os.listdir(RESULTS_DATA_DIR)):
    filepath = os.path.join(RESULTS_DATA_DIR, filename)
    if filename.startswith(PAREN_BAL_PREFIX) and os.path.isfile(filepath):
        # Read as UTF-8 explicitly instead of relying on the platform default encoding.
        with open(filepath, encoding="utf-8") as fp:
            data_pb.append(json.load(fp))

In [82]:
# Axes of the parenthesis-balancing result grid: node x hook x subspace size.
# NOTE: `intervene_hook` and `subspace_size` are rebound here; any earlier cell
# that reads them will pick up these values if re-run after this cell.
intervene_node = ["v", "s"]
intervene_hook = [
    "blocks.0.hook_resid_pre",
    "blocks.0.hook_resid_mid",
    "blocks.0.hook_resid_post",
]
subspace_size = [64, 256]
# Not read anywhere in this cell — presumably the learning rates that were swept;
# TODO confirm.
train_lr = [1, 100, 10000]

grid_shape = (len(intervene_node), len(intervene_hook), len(subspace_size))
# Length of the train-accuracy series, taken from the first loaded run
# (assumes every run recorded the same number of values — TODO confirm).
n_epochs = len(data_pb[0]["train_accuracies"])

test_accuracies_pb = np.zeros(grid_shape)
train_accuracies_pb = np.zeros(grid_shape + (n_epochs,))

# For each grid cell keep the run with the highest test accuracy. The comparison
# is strict, so on ties the first run encountered is kept, and a run whose test
# accuracy is 0 never replaces the initial zeros.
for run in data_pb:
    params = run["parameters"]
    cell = (
        intervene_node.index(params["intervene_node"]),
        intervene_hook.index(params["intervene_hook"]),
        subspace_size.index(params["subspace_size"]),
    )
    if run["test_accuracy"] > test_accuracies_pb[cell]:
        test_accuracies_pb[cell] = run["test_accuracy"]
        train_accuracies_pb[cell] = np.array(run["train_accuracies"])

In [83]:
def _node_accuracy_table(title, node_idx):
    """Build a rich table of test accuracies for one intervened node.

    Args:
        title: Title shown above the table.
        node_idx: Index of the node along the first axis of `test_accuracies_pb`.

    Returns:
        A `Table` with one row per entry of `intervene_hook` and one value column
        for each of the first two positions on the subspace-size axis.
    """
    node_table = Table(title=title, show_header=True, header_style="bold")
    node_table.add_column("Hook", style="bold")
    node_table.add_column("Subspace Size 64")
    node_table.add_column("Subspace Size 256")
    for hook_idx, hook in enumerate(intervene_hook):
        node_table.add_row(
            hook,
            *(
                f"{test_accuracies_pb[node_idx, hook_idx, size_idx]:.4f}"
                for size_idx in (0, 1)
            ),
        )
    return node_table


# One table per intervened node, printed together and centred.
table_v = _node_accuracy_table("Node 'v'", 0)
table_s = _node_accuracy_table("Node 's'", 1)

console.print(table_v, table_s, justify="center")

In [90]:
# Train-accuracy curves of the selected runs, one trace per (node, hook, subspace
# size), with a dropdown that switches between nodes by toggling trace visibility.
traces = []
for i, node in enumerate(intervene_node):
    for j, hook in enumerate(intervene_hook):
        for k, size in enumerate(subspace_size):
            train_accuracies = train_accuracies_pb[i, j, k]
            traces.append(
                go.Scatter(
                    x=list(range(len(train_accuracies))),
                    y=train_accuracies,
                    name=f"node: {node}, {hook}, subspace size: {size}",
                    mode="lines",
                    # Start with only the first node's traces shown (matches active=0 below).
                    visible=(i == 0),
                )
            )

# One button per node. Traces are appended node-major, so the mask for the node at
# position `shown` is True exactly for that node's
# len(intervene_hook) * len(subspace_size) consecutive traces. Deriving the masks
# and labels from the lists (instead of a hard-coded 6) keeps the menu in sync if
# the nodes, hooks or sizes change.
buttons = [
    dict(
        label=f"Node {node}",
        method="update",
        args=[
            {
                "visible": [
                    n == shown
                    for n in range(len(intervene_node))
                    for _ in intervene_hook
                    for _ in subspace_size
                ]
            }
        ],
    )
    for shown, node in enumerate(intervene_node)
]

# The plotted series are train accuracies, so the title and y-axis say so
# (they previously read "train losses" / "Loss").
layout = go.Layout(
    title="Parenthesis balancing rotation matrix train accuracies",
    xaxis_title="Epoch",
    yaxis_title="Train Accuracy",
    updatemenus=[dict(buttons=buttons, active=0, showactive=True)],
)

fig = go.Figure(data=traces, layout=layout)
fig.show()