In [None]:
import json
import matplotlib.pyplot as plt
# Path of the JSON results file analysed by the rest of this notebook.
log_name = "group1_results.json"

# Decode explicitly as UTF-8 rather than relying on the platform's locale
# default, so the file loads the same way on every machine.
with open(log_name, "r", encoding="utf-8") as f:
    data = json.load(f)

Results logged by the `sae_stats_collection.py` script contain a hyperparameters key plus one entry per evaluated autoencoder (32 in this case).

In [None]:
# Overview of the loaded results: the number of top-level entries, followed
# by the structure and a few sample values of one representative entry.
print(f"Total results: {len(data)}")
print()

# The representative entry is the last one that has a "syntax" section;
# example_key stays "" when no entry has one.
example_key = ""
for candidate, entry in data.items():
    if "syntax" in entry:
        example_key = candidate

print(f"Example key: {example_key}")
print("\nTop level keys:")
for section_name in data[example_key]:
    print(section_name)

# List the names found under the "syntax" and "board" sections.
for heading, section in (("\nSyntax keys:", "syntax"), ("\nBoard keys:", "board")):
    print(heading)
    for filter_name in data[example_key][section]:
        print(filter_name)

print("\nExample statistics:")
print(data[example_key]["syntax"]["find_num_indices"])
print(data[example_key]["board"]["board_to_piece_state"])

print("\nEval results:")
print(data[example_key]["eval_results"])

We have three syntax filters: find_num_indices, find_spaces_indices, and find_dots_indices. We also have four board state filters: board_to_piece_state, board_to_threat_state, board_to_check_state, and board_to_pin_state. The cell below computes the average board state filter match count and the average syntax filter match count. A simple average probably isn't the best way of doing this; it is just the first thing I tried.

In [None]:
# For every SAE entry, compute a simple average of each scalar statistic
# across all syntax filters and across all board-state filters. The averages
# are written back into `data` as data[key]["syntax"]["syntax_average"] and
# data[key]["board"]["board_average"].

# Aggregate entries written by this cell. They are not filters, so they are
# skipped everywhere below; otherwise re-running the cell would fold the
# previous average into the sum and inflate the divisor.
derived_entries = ("syntax_average", "board_average")

# Zero-initialised template of statistic names, taken from the first syntax
# filter of the example entry.
syntax_keys = {}
for func, stats in data[example_key]["syntax"].items():
    if func in derived_entries:
        continue
    syntax_keys = {var: 0.0 for var in stats}
    break

# Same for the first board filter, leaving out list- and dict-valued
# statistics because they cannot be averaged as scalars.
board_keys = {}
for func, stats in data[example_key]["board"].items():
    if func in derived_entries:
        continue
    board_keys = {
        var: 0.0
        for var, value in stats.items()
        if not isinstance(value, (list, dict))
    }
    break

print("Syntax keys: ", syntax_keys)
print("Board keys: ", board_keys)

for key, entry in data.items():
    # Entries without a "syntax" section are skipped.
    if "syntax" not in entry:
        continue

    syntax_funcs = [name for name in entry["syntax"] if name not in derived_entries]
    syntax_dict = syntax_keys.copy()
    for func in syntax_funcs:
        for var, value in entry["syntax"][func].items():
            if var in syntax_dict:
                syntax_dict[var] += value
    for var in syntax_dict:
        syntax_dict[var] /= len(syntax_funcs)

    board_funcs = [name for name in entry["board"] if name not in derived_entries]
    board_dict = board_keys.copy()
    for func in board_funcs:
        for var, value in entry["board"][func].items():
            if var in board_dict:
                board_dict[var] += value
    for var in board_dict:
        board_dict[var] /= len(board_funcs)

    entry["syntax"]["syntax_average"] = syntax_dict
    entry["board"]["board_average"] = board_dict

        

All 32 of the sparse autoencoders share the same keys and nested key structure. For example, there's `['syntax']['find_dots_indices']['syntax_match_idx_count']`. The cell below creates `vars`, which collapses the nesting, so that entry becomes `find_dots_indices_syntax_match_idx_count`. This makes it easy to look for correlations.

In [None]:
# Flatten each SAE's nested results into one {flat_name: value} mapping.
# NOTE: the name `vars` shadows the builtin; it is kept because later cells
# read it.
vars = {}

for sae_name, entry in data.items():
    # Entries without a "syntax" section are skipped.
    if "syntax" not in entry:
        continue

    flat = vars[sae_name] = {}

    # Syntax statistics are copied as-is under "<filter>_<statistic>".
    for func, stats in entry["syntax"].items():
        for var, value in stats.items():
            flat[f"{func}_{var}"] = value

    # Board statistics are copied only when they are plain ints or floats.
    for func, stats in entry["board"].items():
        for var, value in stats.items():
            if type(value) in (int, float):
                flat[f"{func}_{var}"] = value

    # Eval results keep their original names.
    for var, value in entry["eval_results"].items():
        flat[f"{var}"] = value

In [None]:
# Report how many flattened variables the example SAE has, then list them.
example_vars = vars[example_key]
print(f"We have {len(example_vars)} variables for each example\n")

for var_name in example_vars:
    print(var_name)

We have 43 unique variables per SAE. This makes plots crowded. We make `simple_vars`, which filters out most of the above keys.

In [None]:
# Reduced copy of `vars` that keeps only a handful of variables per SAE.
simple_vars = {}

for sae_name, flat in vars.items():
    kept = simple_vars[sae_name] = {}
    for var, value in flat.items():
        if "syntax" in var or "board" in var:
            # Of the syntax/board statistics, keep only the averaged match
            # counts; drop the "nonzero" / "dim_count" ones.
            is_excluded = "nonzero" in var or "dim_count" in var
            is_match = "syntax_match" in var or "pattern_match" in var
            is_average = "syntax_average" in var or "board_average" in var
            keep = is_match and is_average and not is_excluded
        else:
            # Everything else is kept unless its name contains "find" or
            # "loss_original".
            keep = "find" not in var and "loss_original" not in var

        if keep:
            kept[var] = value

In [None]:
# Report how many variables survived the filtering, then show each with its value.
example_simple = simple_vars[example_key]
print(f"We now have {len(example_simple)} variables for each example\n")

for var_name in example_simple:
    print(var_name, example_simple[var_name])

The next two cells find, for each of several configuration groups (such as expansion factor 4 or layer 0), the SAE with the maximum and the minimum score for a key of interest.

In [None]:
# For each configuration group, track the SAE with the highest score.
# An SAE belongs to a group when the group label is a substring of its key;
# "_" serves as the "any" group. Scores must exceed the initial count of 0
# to be recorded.
max_matching_dict = {
    label: {"name": "", "count": 0}
    for label in ("_", "ef16", "ef8", "ef4", "layer=0", "layer=5")
}

key_of_interest = "board_average_pattern_match_count"
# key_of_interest = "syntax_average_syntax_match_idx_count"

for sae_name, sae_vars in simple_vars.items():
    for label, best in max_matching_dict.items():
        if label in sae_name and sae_vars[key_of_interest] > best['count']:
            best['count'] = sae_vars[key_of_interest]
            best['name'] = sae_name

for label, best in max_matching_dict.items():
    print(label, best)

In [None]:
# For each configuration group, track the SAE with the lowest score.
# An SAE belongs to a group when the group label is a substring of its key;
# "_" serves as the "any" group. Scores must be below the initial count of
# 1e6 to be recorded.
min_matching_dict = {
    label: {"name": "", "count": 1e6}
    for label in ("_", "ef16", "ef8", "ef4", "layer=0", "layer=5")
}

key_of_interest = "board_average_pattern_match_count"
# key_of_interest = "syntax_average_syntax_match_idx_count"

for sae_name, sae_vars in simple_vars.items():
    for label, best in min_matching_dict.items():
        if label in sae_name and sae_vars[key_of_interest] < best['count']:
            best['count'] = sae_vars[key_of_interest]
            best['name'] = sae_name

for label, best in min_matching_dict.items():
    print(label, best)

Here we print out stats of any autoencoder we are interested in.

In [None]:
# Choose one autoencoder by its key in `data` and print its averaged match
# counts followed by all of its eval results.
ae_of_interest = 'autoencoders/ef4_20k_resample/ef=4_lr=1e-03_l1=1e-01_layer=5/'
# ae_of_interest = 'autoencoders/ef8/ef=8_lr=1e-04_l1=1e-04_layer=5/'
key_of_interest = "pattern_match_count"

board_average = data[ae_of_interest]['board']['board_average']
print("Average board match count: ", board_average[key_of_interest])

syntax_average = data[ae_of_interest]['syntax']['syntax_average']
print("Average syntax match count: ", syntax_average['syntax_match_idx_count'])

for metric, metric_value in data[ae_of_interest]['eval_results'].items():
    print(metric, metric_value)

And now we just have a bunch of plots of our data.

In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Build a table with one row per SAE and one column per simplified variable.
df = pd.DataFrame.from_dict(simple_vars, orient='index')

# Pairwise correlation between every pair of columns, printed as text.
correlation_matrix = df.corr()
print(correlation_matrix)

# The same matrix drawn as a heatmap.
plt.figure(figsize=(10, 8))
sns.heatmap(correlation_matrix, annot=False, fmt=".2f", cmap='coolwarm')
plt.title('Correlation Matrix of Reported Results')
plt.show()


In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Build a table with one row per SAE and one column per simplified variable,
# then correlate every pair of columns.
df = pd.DataFrame.from_dict(simple_vars, orient='index')
correlation_matrix = df.corr()

# Keep every row (all variables) but only the columns of the two averaged
# match counts, which go on the x-axis.
selected_vars = ['syntax_average_syntax_match_idx_count', 'board_average_pattern_match_count']
filtered_correlation_matrix = correlation_matrix[selected_vars]

# Tall, narrow figure: many rows, two annotated columns.
plt.figure(figsize=(5, 10))
sns.heatmap(
    filtered_correlation_matrix,
    annot=True,
    fmt=".2f",
    cmap='coolwarm',
    yticklabels=correlation_matrix.index,
)
plt.title('Correlation Matrix for Selected Variables')
plt.show()


In [None]:
# Scatter plot: eval-result "l2_loss" (x) against the board_to_piece_state
# "pattern_match_count" (y), one point per entry that has a "syntax" section.
key1 = "l2_loss"
key2 = "pattern_match_count"

x = [entry["eval_results"][key1] for entry in data.values() if "syntax" in entry]
y = [
    entry["board"]["board_to_piece_state"][key2]
    for entry in data.values()
    if "syntax" in entry
]

plt.scatter(x, y)
plt.title(f'{key1} vs. {key2}')
plt.xlabel(key1)
plt.ylabel(key2)
plt.grid(True)
plt.show()

In [None]:
# Scatter plot: eval-result "loss_reconstructed" (x) against the
# board_to_piece_state "pattern_match_count" (y), one point per entry that
# has a "syntax" section.
key1 = "loss_reconstructed"
key2 = "pattern_match_count"

x = [entry["eval_results"][key1] for entry in data.values() if "syntax" in entry]
y = [
    entry["board"]["board_to_piece_state"][key2]
    for entry in data.values()
    if "syntax" in entry
]

plt.scatter(x, y)
plt.title(f'{key1} vs. {key2}')
plt.xlabel(key1)
plt.ylabel(key2)
plt.grid(True)
plt.show()

In [None]:
# Scatter plot: find_num_indices "syntax_match_idx_count" (x) against the
# board_to_piece_state "pattern_match_count" (y), one point per entry that
# has a "syntax" section.
key1 = "syntax_match_idx_count"
key2 = "pattern_match_count"

x = [
    entry["syntax"]['find_num_indices'][key1]
    for entry in data.values()
    if "syntax" in entry
]
y = [
    entry["board"]["board_to_piece_state"][key2]
    for entry in data.values()
    if "syntax" in entry
]

plt.scatter(x, y)
plt.title(f'{key1} vs. {key2}')
plt.xlabel(key1)
plt.ylabel(key2)
plt.grid(True)
plt.show()

In [None]:
# Scatter plot: eval-result "l0" (x) against the board_to_piece_state
# "pattern_match_count" (y), one point per entry that has a "syntax" section.
key1 = "l0"
key2 = "pattern_match_count"

x = [entry["eval_results"][key1] for entry in data.values() if "syntax" in entry]
y = [
    entry["board"]["board_to_piece_state"][key2]
    for entry in data.values()
    if "syntax" in entry
]

plt.scatter(x, y)
plt.title(f'{key1} vs. {key2}')
plt.xlabel(key1)
plt.ylabel(key2)
plt.grid(True)
plt.show()

In [None]:
# Scatter plot: eval-result "l0" (x) against the find_num_indices
# "syntax_match_idx_count" (y), one point per entry that has a "syntax" section.
key1 = "l0"
key2 = "syntax_match_idx_count"

x = [entry["eval_results"][key1] for entry in data.values() if "syntax" in entry]
y = [
    entry["syntax"]["find_num_indices"][key2]
    for entry in data.values()
    if "syntax" in entry
]

plt.scatter(x, y)
plt.title(f'{key1} vs. {key2}')
plt.xlabel(key1)
plt.ylabel(key2)
plt.grid(True)
plt.show()