In [170]:
%load_ext autoreload
%autoreload 2

import torch
import wandb
from epsilon_transformers.analysis.wandb import fetch_artifacts_for_run
from epsilon_transformers.analysis.wandb import fetch_run_config
from epsilon_transformers.analysis.wandb import load_model_artifact
from epsilon_transformers.analysis.wandb import download_artifacts

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

user_or_org = "adamimos"
project_name = "transformer-MSPs"
run_id = "vfs4q106"  # rrxor 

wandb.init(id=run_id, resume='must')
arts = fetch_artifacts_for_run(user_or_org, project_name, run_id)
print(f"Found {len(arts)} artifacts")


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Using device: cpu


VBox(children=(Label(value='0.000 MB of 0.000 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))



VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011118893977578006, max=1.0…

Found 31788 artifacts


In [12]:

download_dir = "./downloaded_artifacts"  # Specify your directory here
download_artifacts(arts, download_dir)

config = fetch_run_config(user_or_org, project_name, run_id)


Found 31787 existing .pt files in ./downloaded_artifacts


Downloading artifacts: 100%|██████████| 1/1 [01:41<00:00, 101.12s/it, Artifact: run-vfs4q106-history:v10]


In [21]:
import os
# get the number in each filename in the download_dir
# and make it a list of ints, then sort it and see if any are missing
filenames = [int(x.split("_")[-1].split(".")[0]) for x in os.listdir(download_dir)]
filenames.sort()
print(f"Missing: {set(range(filenames[0], filenames[-1])) - set(filenames)}")


Missing: {31783, 31784, 31785, 31786, 31787, 31788, 31789, 31790, 31791, 31792, 31795, 31796, 31797, 31798, 31799, 31800, 31801, 31802, 31803, 31804, 31805, 31806, 31807, 31808, 31809, 31811, 31812, 31813, 31814, 31815, 31816, 31817, 31818, 31819, 31820, 31821, 31822, 31823, 31824, 31825, 31826, 31827, 31828, 31829, 31830, 31831, 31832, 31833, 31834}


In [339]:
from epsilon_transformers.process.processes import RRXOR
import numpy as np
from epsilon_transformers.training.configs import RawModelConfig

model = model_config = RawModelConfig(
            d_vocab=config["d_vocab"],
            d_model=config["d_model"],
            n_ctx=config["n_ctx"],
            d_head=config["d_head"],
            n_head=config["n_heads"],
            d_mlp=config["d_model"] * 4,
            n_layers=config["n_layers"],
        ).to_hooked_transformer(seed=1337, device=device)


save_point = 20000
model.load_state_dict(torch.load(f"{download_dir}/model_epoch_{save_point}.pt", map_location=device))

process = RRXOR()
msp = process.derive_mixed_state_presentation(depth=config["n_ctx"] + 1)
msp_paths_and_beliefs = msp.paths_and_belief_states

msp_paths = [np.array(x[0]) for x in msp_paths_and_beliefs]
msp_beliefs = [x[1] for x in msp_paths_and_beliefs]
msp_beliefs = [tuple(round(b, 5) for b in belief) for belief in msp_beliefs]
msp_unique_beliefs = set(msp_beliefs)
print(f"Unique beliefs: {len(msp_unique_beliefs)} out of {len(msp_beliefs)}")

pca = PCA(n_components=3)
pca.fit(msp_beliefs)
msp_beliefs_pca = pca.transform(msp_beliefs)

msp_belief_index = {b: i for i, b in enumerate(msp_unique_beliefs)}
msp_beliefs_index = [msp_belief_index[tuple(round(b, 5) for b in belief)] for belief in msp_beliefs]

X = [x for x in msp.paths if len(x) == config["n_ctx"]]
X = torch.tensor(X, dtype=torch.int).to(device) # (batch, n_ctx)

X_beliefs = torch.zeros(X.shape[0], X.shape[1], 5).to(device)
X_belief_indices = torch.zeros(X.shape[0], X.shape[1], 1).to(device)


for i in range(X.shape[0]):
    for j in range(X.shape[1]):
        p = X[i, :j+1].cpu().numpy()
        # find which index of msp_paths is p
        path_index = int(np.where([np.array_equal(p, x) for x in msp_paths])[0])
        msp_belief_state = msp_beliefs[path_index]
        X_beliefs[i, j] = torch.tensor(msp_belief_state, dtype=torch.float32)
        X_belief_indices[i, j] = torch.tensor(msp_belief_index[msp_belief_state], dtype=torch.int)



_, cache = model.run_with_cache(X, names_filter=lambda x: 'ln1.hook_normalized' in x)

acts = torch.cat([cache[f"blocks.{i}.ln1.hook_normalized"] for i in range(4)], dim=-1) # (batch, n_ctx, 4 * d_model)

acts_flattened = acts.view(-1, acts.shape[-1]).cpu().numpy()
X_beliefs_flattened = X_beliefs.view(-1, X_beliefs.shape[-1]).cpu().numpy()

regression = LinearRegression()
regression.fit(acts_flattened, X_beliefs_flattened)

result = regression.predict(acts_flattened)

result_pca = pca.transform(result)

# plot result pca in 3d using plotly
import plotly.graph_objects as go



import plotly.graph_objects as go

X_belief_indices_flattened = X_belief_indices.view(-1).cpu().numpy()

colors = px.colors.qualitative.Light24 + px.colors.qualitative.Dark24 + px.colors.qualitative.Plotly
fig = go.Figure()

from plotly.subplots import make_subplots

# Create a subplot with 1 row and 2 columns
fig = make_subplots(rows=1, cols=2, specs=[[{'type': 'scatter3d'}, {'type': 'scatter3d'}]])

# Plot msp_beliefs_pca on the left (first subplot)
msp_beliefs_pca_data = msp_beliefs_pca
for b in range(len(msp_unique_beliefs)):
    relevant_indices = msp_beliefs_index == b
    relevant_data = msp_beliefs_pca_data[relevant_indices]
    fig.add_trace(go.Scatter3d(x=relevant_data[:, 0],
                           y=relevant_data[:, 1],
                           z=relevant_data[:, 2],
                           mode='markers',
                           name=f'Belief {b}',
                           marker=dict(size=5, color=colors[b], opacity=0.5)),
              row=1, col=1)

# Plot result_pca on the right (second subplot)
for b in range(len(msp_unique_beliefs)):
    relevant_indices = np.where(X_belief_indices_flattened == b)[0]
    relevant_data = result_pca[relevant_indices]
    fig.add_trace(go.Scatter3d(x=relevant_data[:, 0],
                               y=relevant_data[:, 1],
                               z=relevant_data[:, 2],
                               mode='markers',
                               name=f'Belief {b}',
                               marker=dict(size=2, color=colors[b]), opacity=0.1),
                  row=1, col=2)

fig.update_layout(title='3D PCA Projection of Beliefs',
                  scene=dict(xaxis_title='PCA 1', yaxis_title='PCA 2', zaxis_title='PCA 3'),
                  scene2=dict(xaxis_title='PCA 1', yaxis_title='PCA 2', zaxis_title='PCA 3'))
fig.show()


Unique beliefs: 36 out of 1723



Conversion of an array with ndim > 0 to a scalar is deprecated, and will error in future. Ensure you extract a single element from your array before performing this operation. (Deprecated NumPy 1.25.)



In [354]:
from epsilon_transformers.process.processes import RRXOR
import numpy as np
from epsilon_transformers.training.configs import RawModelConfig
import torch
from sklearn.decomposition import PCA
from sklearn.linear_model import LinearRegression
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.express as px

model_config = RawModelConfig(
    d_vocab=config["d_vocab"],
    d_model=config["d_model"],
    n_ctx=config["n_ctx"],
    d_head=config["d_head"],
    n_head=config["n_heads"],
    d_mlp=config["d_model"] * 4,
    n_layers=config["n_layers"],
).to_hooked_transformer(seed=1337, device=device)

save_point = 0
model.load_state_dict(torch.load(f"{download_dir}/model_epoch_{save_point}.pt", map_location=device))

process = RRXOR()
msp = process.derive_mixed_state_presentation(depth=config["n_ctx"] + 1)
msp_paths_and_beliefs = msp.paths_and_belief_states

msp_paths = [np.array(x[0]) for x in msp_paths_and_beliefs]
msp_beliefs = [x[1] for x in msp_paths_and_beliefs]
msp_beliefs = [tuple(round(b, 5) for b in belief) for belief in msp_beliefs]
msp_unique_beliefs = list(set(msp_beliefs))  # Convert to list for indexing
print(f"Unique beliefs: {len(msp_unique_beliefs)} out of {len(msp_beliefs)}")


pca = PCA(n_components=3)
pca.fit(msp_unique_beliefs)
msp_beliefs_pca = pca.transform(msp_unique_beliefs)

msp_belief_index = {b: i for i, b in enumerate(msp_unique_beliefs)}
msp_beliefs_index = [msp_belief_index[tuple(round(b, 5) for b in belief)] for belief in msp_unique_beliefs]

X = [x for x in msp.paths if len(x) == config["n_ctx"]]
X = torch.tensor(X, dtype=torch.int).to(device)  # (batch, n_ctx)

X_beliefs = torch.zeros(X.shape[0], X.shape[1], 5).to(device)
X_belief_indices = torch.zeros(X.shape[0], X.shape[1], 1, dtype=torch.int).to(device)  # Specify dtype

for i in range(X.shape[0]):
    for j in range(X.shape[1]):
        p = X[i, :j+1].cpu().numpy()
        path_index = np.where([np.array_equal(p, x) for x in msp_paths])[0][0]  # Get the first index
        msp_belief_state = msp_beliefs[path_index]
        X_beliefs[i, j] = torch.tensor(msp_belief_state, dtype=torch.float32)
        X_belief_indices[i, j] = msp_belief_index[msp_belief_state]  # Assign directly

_, cache = model.run_with_cache(X, names_filter=lambda x: 'ln1.hook_normalized' in x)

acts = torch.cat([cache[f"blocks.{i}.ln1.hook_normalized"] for i in range(4)], dim=-1)  # (batch, n_ctx, 4 * d_model)

acts_flattened = acts.view(-1, acts.shape[-1]).cpu().numpy()
X_beliefs_flattened = X_beliefs.view(-1, X_beliefs.shape[-1]).cpu().numpy()

regression = LinearRegression()
regression.fit(acts_flattened, X_beliefs_flattened)

result = regression.predict(acts_flattened)

result_pca = pca.transform(result)

X_belief_indices_flattened = X_belief_indices.view(-1).cpu().numpy()

colors = px.colors.qualitative.Light24 + px.colors.qualitative.Dark24 + px.colors.qualitative.Plotly

fig = make_subplots(rows=1, cols=2, specs=[[{'type': 'scatter3d'}, {'type': 'scatter3d'}]])

msp_beliefs_pca_data = msp_beliefs_pca
for b in range(len(msp_unique_beliefs)):
    relevant_indices = [i for i, x in enumerate(msp_beliefs_index) if x == b]  # Use list comprehension
    relevant_data = msp_beliefs_pca_data[relevant_indices]
    fig.add_trace(go.Scatter3d(x=relevant_data[:, 0],
                               y=relevant_data[:, 1],
                               z=relevant_data[:, 2],
                               mode='markers',
                               name=f'Belief {b}',
                               marker=dict(size=5, color=colors[b], opacity=1.0)),
                  row=1, col=1)

for b in range(len(msp_unique_beliefs)):
    relevant_indices = np.where(X_belief_indices_flattened == b)[0]
    relevant_data = result_pca[relevant_indices]
    centers_of_mass = np.mean(relevant_data, axis=0)
    fig.add_trace(go.Scatter3d(x=[centers_of_mass[0]],
                               y=[centers_of_mass[1]],
                               z=[centers_of_mass[2]],
                               mode='markers',
                               name=f'Belief {b}',
                               marker=dict(size=5, color=colors[b], opacity=1)),
                  row=1, col=2)
    fig.add_trace(go.Scatter3d(x=relevant_data[:, 0],
                               y=relevant_data[:, 1],
                               z=relevant_data[:, 2],
                               mode='markers',
                               name=f'Belief {b}',
                               marker=dict(size=1.5, color=colors[b], opacity=0.1)),
                  row=1, col=2)

fig.update_layout(title='3D PCA Projection of Beliefs',
                  scene=dict(xaxis_title='PCA 1', yaxis_title='PCA 2', zaxis_title='PCA 3'),
                  scene2=dict(xaxis_title='PCA 1', yaxis_title='PCA 2', zaxis_title='PCA 3'))
fig.show()

Unique beliefs: 36 out of 1723



Mean of empty slice.


invalid value encountered in divide



wandb: Network error (ConnectionError), entering retry loop.


In [303]:
def summarize_paths_and_beliefs(msp):
    """
    Summarizes the paths and beliefs from a mixed state presentation (msp).

    This function takes a mixed state presentation object, extracts the paths and belief states,
    and prints the total number of paths and beliefs. It then rounds the belief states to 5 decimal
    points and identifies unique beliefs, aggregating paths that lead to the same rounded belief state.
    Additionally, it returns two dictionaries: one mapping unique rounded belief states (as tuples) to lists of paths that lead to them,
    and another mapping individual paths (as strings) to their beliefs.

    Parameters:
    - msp (MixedStatePresentation): The mixed state presentation object containing paths and belief states.

    Returns:
    - tuple of dicts: A tuple containing two dictionaries. The first dictionary maps unique rounded belief states (as tuples) to lists of paths that lead to them.
                      The second dictionary maps individual paths (as strings) to their beliefs.
    """

    paths_and_beliefs = msp.paths_and_belief_states
    paths = [x[0] for x in paths_and_beliefs]
    beliefs = [x[1] for x in paths_and_beliefs]
    print(f"Paths: {len(paths)}, Beliefs: {len(beliefs)}")

    unique_beliefs_with_paths = {}
    path_to_belief = {}

    for path, belief in zip(paths, beliefs):
        rounded_belief = tuple(round(b, 5) for b in belief)
        path_str = ''.join([str(x) for x in path])
        path_to_belief[path_str] = rounded_belief  # Map path string to belief

        if rounded_belief not in unique_beliefs_with_paths:
            unique_beliefs_with_paths[rounded_belief] = [path_str]
        else:
            unique_beliefs_with_paths[rounded_belief].append(path_str)

    print(f"Unique beliefs with rounding to 5 decimal points: {len(unique_beliefs_with_paths)}")
    # for belief, paths in unique_beliefs_with_paths.items():
    #    print(f"Belief: {belief}, Paths: {len(paths)}")

    return unique_beliefs_with_paths, path_to_belief

msp_belief_paths_dict, msp_path_to_belief_dict = summarize_paths_and_beliefs(msp)

# now let's index each belief

msp_index_to_belief = {i: k for i, k in enumerate(msp_belief_paths_dict.keys())}
msp_belief_to_index = {k: i for i, k in enumerate(msp_belief_paths_dict.keys())}


_, cache = model.run_with_cache(X, names_filter=lambda x: 'ln1.hook_normalized' in x)

acts = torch.cat([cache[f"blocks.{i}.ln1.hook_normalized"] for i in range(4)], dim=-1) # (batch, n_ctx, 4 * d_model)



Paths: 1723, Beliefs: 1723
Unique beliefs with rounding to 5 decimal points: 36


In [304]:
acts.shape

torch.Size([436, 10, 512])

In [253]:
from epsilon_transformers.training.configs import RawModelConfig

model = model_config = RawModelConfig(
            d_vocab=config["d_vocab"],
            d_model=config["d_model"],
            n_ctx=config["n_ctx"],
            d_head=config["d_head"],
            n_head=config["n_heads"],
            d_mlp=config["d_model"] * 4,
            n_layers=config["n_layers"],
        ).to_hooked_transformer(seed=1337, device=device)

"""art = arts[1000]
art_name = art.name
print(f"Artifact: {art_name}")

# art.name is "name:version" so we split it to get the name and version
artifact_name, artifact_version = art_name.split(":")
epoch_num = int(artifact_name.split("_")[-1])

artifact = load_model_artifact(
    user_or_org,
    project_name,
    artifact_name,
    art.type,
    art.version,
    device=device,
)"""

save_point = 30000
model.load_state_dict(torch.load(f"{download_dir}/model_epoch_{save_point}.pt", map_location=device))


<All keys matched successfully>

In [254]:
from typing import Dict
import torch
from jaxtyping import Float, Int
from torch import Tensor

# generate inputs
X = [x for x in msp.paths if len(x) == config["n_ctx"]]
X: Float[Tensor, 'batch_size n_ctx'] = torch.tensor(X, dtype=torch.int).to(device)

_, cache = model.run_with_cache(X, names_filter=lambda x: 'ln1.hook_normalized' in x)

acts: Float[Tensor, 'batch_size n_ctx (4 * d_model)'] = torch.cat([cache[f"blocks.{i}.ln1.hook_normalized"] for i in range(4)], dim=-1)


In [255]:
gt_beliefs = np.array(list(msp_belief_paths_dict.keys())) # (n_unique_beliefs, d_beliefs)

# get 36 unique colors
# Generate a list of 36 unique colors
colors = (
    px.colors.qualitative.Light24
    + px.colors.qualitative.Dark24
    + px.colors.qualitative.Plotly
)

# run pca on beliefs
pca = PCA(n_components=3)
pca.fit(gt_beliefs)

gt_beliefs_pca = pca.transform(gt_beliefs)

# plot in 3d using plotly with colors
import plotly.graph_objects as go

fig = go.Figure()

for i, belief in enumerate(gt_beliefs_pca):
    fig.add_trace(go.Scatter3d(
        x=[belief[0]], 
        y=[belief[1]], 
        z=[belief[2]],
        marker=dict(color=colors[i % len(colors)]),
        mode='markers'
    ))

fig.show()


In [256]:
all_paths = []
gt_belief_label_indices = []
for full_ctx_input in X:
    paths = []
    indices = []
    for i in range(len(full_ctx_input)):
        path = list(full_ctx_input[:i+1].cpu().numpy())
        path_str = "".join([str(x) for x in path])
        paths.append(msp_path_to_belief_dict[path_str])
        indices.append(msp_belief_to_index[msp_path_to_belief_dict[path_str]])
    all_paths.append(paths)
    gt_belief_label_indices.append(indices)

gt_belief_labels = np.array(all_paths) # (batch, n_ctx, d_beliefs)
gt_belief_label_indices = np.array(gt_belief_label_indices)



In [263]:
from sklearn.linear_model import LinearRegression
import numpy as np
from sklearn.decomposition import PCA
import plotly.express as px


# Reshape beliefs_array to (batch * n_ctx, d_beliefs) where d_beliefs is 5 as per notebook_cell_output_1
gt_belief_labels_flatted = gt_belief_labels.reshape(-1, gt_belief_labels.shape[-1])
gt_belief_label_indices_flatted = gt_belief_label_indices.reshape(-1)

# Reshape acts to (batch * n_ctx, 4 * d_model)
acts_reshaped = acts.view(-1, acts.shape[-1]).cpu().numpy()

# Instantiate the Linear Regression model
model = LinearRegression()

# Fit the model
model.fit(acts_reshaped, gt_belief_labels_flatted)

# Now the model is trained to predict the 5d beliefs_array from the 4*d_model dimension acts
results = model.predict(acts_reshaped)

# Apply PCA to reduce the dimensionality for visualization
pca_results = pca.fit_transform(results) # (batch * n_ctx, 3)
import plotly.graph_objects as go
from plotly.subplots import make_subplots

fig = make_subplots(
    rows=1, cols=2,
    specs=[[{'type': 'scatter3d'}, {'type': 'scatter3d'}]]
)

# Adding gt_beliefs_pca to the left column (1st column)
for i, belief in enumerate(gt_beliefs):
    belief_index = msp_belief_to_index[tuple(belief)]
    fig.add_trace(go.Scatter3d(
        x=[gt_beliefs_pca[i][0]], 
        y=[gt_beliefs_pca[i][1]], 
        z=[gt_beliefs_pca[i][2]],
        marker=dict(color=colors[belief_index]),
        mode='markers'
    ), row=1, col=1)

    

# Adding PCA results to the right column (2nd column)
for belief_index in np.unique(gt_belief_label_indices_flatted):
    relevant_data = pca_results[gt_belief_label_indices_flatted == belief_index]
    print(relevant_data.shape)
    com = np.mean(relevant_data, axis=0)
    color = colors[belief_index]
    fig.add_trace(go.Scatter3d(x=[com[0]], y=[com[1]], z=[com[2]],
                               mode='markers', marker=dict(color=color, size=5),
                               name=f'Belief {belief_index}'), row=1, col=2)
    fig.add_trace(go.Scatter3d(x=relevant_data[:,0], y=relevant_data[:,1], z=relevant_data[:,2],
                               mode='markers', marker=dict(color=color, size=2, opacity=0.1),
                               name=f'Belief {belief_index}'), row=1, col=2)

fig.show()




(320, 3)
(444, 3)
(53, 3)
(64, 3)
(250, 3)
(320, 3)
(111, 3)
(250, 3)
(138, 3)
(111, 3)
(80, 3)
(82, 3)
(136, 3)
(138, 3)
(82, 3)
(53, 3)
(136, 3)
(111, 3)
(45, 3)
(111, 3)
(80, 3)
(45, 3)
(64, 3)
(218, 3)
(36, 3)
(53, 3)
(53, 3)
(80, 3)
(136, 3)
(45, 3)
(36, 3)
(45, 3)
(136, 3)
(80, 3)
(218, 3)


In [262]:
[round(x, 5) for x in belief]


[-0.0, -0.29068, 0.70711]

In [141]:
for belief in beliefs_array_reshaped:
    # deal with .4999999 etc. = 5 type issue
    belief = [round(x, 5) for x in belief]
    belief_to_color[tuple(belief)]



In [146]:
len(colors_for_data)

36