In [None]:
from circuits.utils import (
    get_feature,
    get_ae_bundle,
    AutoEncoderBundle,
    get_first_n_dataset_rows,
    collect_activations_batch,
)
from tqdm import tqdm
import pickle
import torch
from jaxtyping import Int, Float, jaxtyped
from beartype import beartype
from torch import Tensor
import einops

import importlib
import circuits.chess_utils as chess_utils
importlib.reload(chess_utils)
from circuits.chess_utils import config_lookup, get_num_classes

In [None]:

autoencoder_path = "../autoencoders/group0/ef=4_lr=1e-03_l1=1e-01_layer=5/"
batch_size = 25
n_inputs = 100
# Fall back to CPU so the notebook still runs on machines without a GPU;
# behaviour is unchanged when CUDA is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
model_path = "../models/"

# NOTE(review): pickle.load can execute arbitrary code — only load a data.pkl you produced or trust.
with open("data.pkl", "rb") as f:
    data = pickle.load(f)

# Move every entry except "pgn_strings" onto `device`; those entries are assumed to be tensors
# (they must support `.to(device)`). "pgn_strings" is left untouched.
for key, value in data.items():
    if key != "pgn_strings":
        data[key] = value.to(device)

In [None]:
ae_bundle = get_ae_bundle(autoencoder_path, device, data, batch_size, model_path)
pgn_strings = data["pgn_strings"]

# One index per autoencoder dictionary feature: 0 .. dictionary_size - 1.
features = torch.arange(ae_bundle.dictionary_size, device=device)
num_features = features.shape[0]

# We need at least n_inputs games, and they must split evenly into batches.
assert len(pgn_strings) >= n_inputs
assert n_inputs % batch_size == 0

n_iters = n_inputs // batch_size
results = dict()

# Board-state functions to correlate feature activity against,
# and the activation thresholds that define a feature as "on".
custom_functions = [
    chess_utils.board_to_piece_state,
    chess_utils.board_to_pin_state,
]
thresholds = [0.0, 0.5]

In [None]:


# # Example setup (assuming activations_FBL is already defined; `repeat` / `reduce` below mean einops.repeat / einops.reduce)
# n_features, batch_size, context_length = activations_FBL.shape
# n_thresholds = len(thresholds)
# thresholds_tensor = torch.tensor(thresholds, device=device).view(1, 1, 1, -1)  # Reshape for broadcasting

# # Expand activations to match the thresholds tensor for broadcasting
# activations_expanded = repeat(activations_FBL, 'F B L -> F B L T', T=n_thresholds)

# # Vectorized thresholding
# active_indices_FBLT = activations_expanded > thresholds_tensor

# # Compute active counts for all thresholds using einops
# active_counts_FBT = reduce(active_indices_FBLT, 'F B L T -> F T', 'sum')
# off_counts_FBT = reduce(~active_indices_FBLT, 'F B L T -> F T', 'sum')

# # Now you have the counts of active and inactive indices for each feature at each threshold
# print(active_counts_FBT.shape)  # Shape: (n_features, n_thresholds)
# print(off_counts_FBT.shape)     # Shape: (n_features, n_thresholds)


In [None]:
# start = 0
# end = 25
# for custom_function in custom_functions:
#     # on_tracker_FTRRC = results[custom_function.__name__]['on']
#     # off_tracker_FTRRC = results[custom_function.__name__]['off']

#     boards_BLRRC = data[custom_function.__name__][start:end]
#     print(boards_BLRRC.shape)

#     # Force CUDA synchronization before measuring memory usage
#     torch.cuda.synchronize()
#     memory_before = torch.cuda.memory_allocated()

#     boards_TBLRRC = einops.repeat(boards_BLRRC, 'B L R1 R2 C -> T B L R1 R2 C', T=1000)
#     boards_TBLRRC += 0.0001  # Minor operation to force physical instantiation
#     print(boards_TBLRRC.shape)

#     # Force CUDA synchronization again to ensure all operations are complete
#     torch.cuda.synchronize()
#     memory_after = torch.cuda.memory_allocated()

#     print(f"Memory usage before: {memory_before} bytes")
#     print(f"Memory usage after: {memory_after} bytes")
#     print(f"Increase in memory: {memory_after - memory_before} bytes")

In [None]:
# Give the thresholds shape (T, 1, 1) so they broadcast against (B, L) activations.
thresholds_T = torch.tensor(thresholds, device=device)[:, None, None]
print(thresholds_T.shape)

In [None]:
# Dimension key (from https://medium.com/@NoamShazeer/shape-suffixes-good-coding-style-f836e72e24fd):
# F  = features
# M = feature minibatch size
# B = batch_size
# L = seq length (context length)
# T = thresholds
# R = rows (or cols)
# C = classes for one hot encoding
#
# For every feature f and threshold t this cell accumulates, over all (B, L) positions:
#   on[f, t]        = sum of board states where activation > threshold
#   off[f, t]       = sum of board states where activation <= threshold
#   on_count[f, t]  = number of positions where activation > threshold
#   off_count[f, t] = number of positions where activation <= threshold

thresholds_T = torch.tensor(thresholds, device=device).view(-1, 1, 1)  # (T, 1, 1) for broadcasting

feature_batch_size = 10
# Ceiling division so a final, partial minibatch of features is not dropped.
num_feature_iters = (num_features + feature_batch_size - 1) // feature_batch_size

for custom_function in custom_functions:
    config = config_lookup[custom_function.__name__]
    num_classes = get_num_classes(config)

    # Trackers are allocated on `device` (like the data and activations) so batches can be
    # added in place without a device mismatch.
    on_tracker_FTRRC = torch.zeros(
        num_features, len(thresholds), config.num_rows, config.num_cols, num_classes, device=device
    )
    on_counter_FT = torch.zeros(num_features, len(thresholds), device=device)
    results[custom_function.__name__] = {
        'on': on_tracker_FTRRC,
        'off': on_tracker_FTRRC.clone(),
        'on_count': on_counter_FT,
        'off_count': on_counter_FT.clone(),
    }

for i in tqdm(range(n_iters)):
    start = i * batch_size
    end = (i + 1) * batch_size
    inputs_BL = data['pgn_strings'][start:end]

    activations_FBL, encoded_inputs = collect_activations_batch(
        ae_bundle.model, ae_bundle.submodule, ae_bundle.context_length, inputs_BL, ae_bundle.ae, features
    )  # activations: (features, batch_size, context_length)

    # Board states depend only on the input batch, so slice them once per batch (not once per
    # feature). Cast to float so they can be contracted with float masks below.
    boards_by_name = {
        custom_function.__name__: data[custom_function.__name__][start:end].float()
        for custom_function in custom_functions
    }  # each value: (B, L, R, R, C)

    # Process features in minibatches of M = feature_batch_size to bound memory usage.
    for feature_iter in range(num_feature_iters):
        f_start = feature_iter * feature_batch_size
        f_end = min(f_start + feature_batch_size, num_features)

        # (M, 1, B, L) compared against (T, 1, 1) broadcasts to (M, T, B, L).
        active_MTBL = activations_FBL[f_start:f_end, None] > thresholds_T
        on_mask_MTBL = active_MTBL.float()
        off_mask_MTBL = (~active_MTBL).float()

        on_count_MT = on_mask_MTBL.sum(dim=(2, 3))
        off_count_MT = off_mask_MTBL.sum(dim=(2, 3))

        for custom_function in custom_functions:
            name = custom_function.__name__
            tracker = results[name]
            boards_BLRRC = boards_by_name[name]

            # Contract over batch and sequence positions directly instead of materialising a
            # (T, B, L, R, R, C) repeated copy of the boards.
            tracker['on'][f_start:f_end] += torch.einsum('mtbl,blrsc->mtrsc', on_mask_MTBL, boards_BLRRC)
            tracker['off'][f_start:f_end] += torch.einsum('mtbl,blrsc->mtrsc', off_mask_MTBL, boards_BLRRC)
            tracker['on_count'][f_start:f_end] += on_count_MT
            tracker['off_count'][f_start:f_end] += off_count_MT