In [None]:
# %%

import torch as t
import numpy as np
import glob
import os
import re

In [None]:
# %%

def load_tensors_into_dict(directory, map_location=None):
    """Load every matching ablation ``.pth`` file found in ``directory``.

    A file is loaded when its basename contains
    ``layer_<L>__feat_<F>__num_batches_<N>__batch_size_<B>.pth`` (all
    placeholders are digit runs). Other ``.pth`` files are ignored.

    Args:
        directory: Folder scanned (non-recursively) for ``*.pth`` files.
        map_location: Forwarded to ``torch.load``. When ``None`` (default)
            and CUDA is unavailable, storages are mapped to the CPU so that
            files saved from a GPU can still be opened on a CPU-only machine.
            When CUDA is available, ``None`` keeps torch's default placement.

    Returns:
        dict mapping ``'layer_<L>_feat_<F>'`` to whatever ``torch.load``
        returns for that file. Files that differ only in ``num_batches`` /
        ``batch_size`` share a key, so the one visited last wins.
    """
    # Without this fallback torch.load raises RuntimeError when a CUDA-saved
    # file is deserialized on a machine where torch.cuda.is_available() is False.
    if map_location is None and not t.cuda.is_available():
        map_location = t.device('cpu')

    pattern = re.compile(r'layer_(\d+)__feat_(\d+)__num_batches_\d+__batch_size_\d+\.pth')
    tensor_dict = {}

    for filepath in glob.glob(os.path.join(directory, '*.pth')):
        match = pattern.search(os.path.basename(filepath))
        if match is None:
            # Not an ablation artefact; leave it alone.
            continue

        layer_idx, feat_idx = match.groups()
        key = f'layer_{layer_idx}_feat_{feat_idx}'

        # NOTE: torch.load unpickles the file, so only point this at trusted artefacts.
        tensor_dict[key] = t.load(filepath, map_location=map_location)

    return tensor_dict

In [None]:
# %%

# Folder scanned for the saved ablation `.pth` files (relative to the working directory).
directory = 'artefacts/ablations'
# Maps 'layer_<L>_feat_<F>' keys to the object loaded from each matching file.
result = load_tensors_into_dict(directory)

RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.

In [None]:
# %%

import torch as t
import numpy as np
import glob
import os
import re

# Choose the compute backend in order of preference: Apple MPS, then CUDA, then CPU.
device = "cpu"
if t.backends.mps.is_available():
    device = "mps"
elif t.cuda.is_available():
    device = "cuda"
print(f"Loaded {device=}")

Loaded device='cpu'


In [None]:
# %%

def load_tensors_into_dict(directory, map_location=None):
    """Load every matching ablation ``.pth`` file found in ``directory``.

    A file is loaded when its basename contains
    ``layer_<L>__feat_<F>__num_batches_<N>__batch_size_<B>.pth`` (all
    placeholders are digit runs). Other ``.pth`` files are ignored.

    Args:
        directory: Folder scanned (non-recursively) for ``*.pth`` files.
        map_location: Forwarded to ``torch.load``. When ``None`` (default)
            the notebook-level ``device`` global is used, which is what this
            function always did before the parameter existed.

    Returns:
        dict mapping ``'layer_<L>_feat_<F>'`` to whatever ``torch.load``
        returns for that file. Files that differ only in ``num_batches`` /
        ``batch_size`` share a key, so the one visited last wins.
    """
    if map_location is None:
        # Only touch the global when the caller did not choose a target, so the
        # function can also be used where `device` has not been defined.
        map_location = t.device(device)

    pattern = re.compile(r'layer_(\d+)__feat_(\d+)__num_batches_\d+__batch_size_\d+\.pth')
    tensor_dict = {}

    for filepath in glob.glob(os.path.join(directory, '*.pth')):
        match = pattern.search(os.path.basename(filepath))
        if match is None:
            # Not an ablation artefact; leave it alone.
            continue

        layer_idx, feat_idx = match.groups()
        key = f'layer_{layer_idx}_feat_{feat_idx}'

        # NOTE: torch.load unpickles the file, so only point this at trusted artefacts.
        tensor_dict[key] = t.load(filepath, map_location=map_location)

    return tensor_dict

In [None]:
# %%

# Folder scanned for the saved ablation `.pth` files (relative to the working directory).
directory = 'artefacts/ablations'
# Maps 'layer_<L>_feat_<F>' keys to the object loaded from each matching file.
result = load_tensors_into_dict(directory)

KeyboardInterrupt: 

In [None]:
# %%

def load_tensors_into_dict(directory, max_files=1, map_location=None):
    """Load matching ablation ``.pth`` files from ``directory``, up to a limit.

    A file is loaded when its basename contains
    ``layer_<L>__feat_<F>__num_batches_<N>__batch_size_<B>.pth`` (all
    placeholders are digit runs). Other ``.pth`` files are ignored and do not
    count towards ``max_files``.

    Args:
        directory: Folder scanned (non-recursively) for ``*.pth`` files.
        max_files: Maximum number of matching files to load. Defaults to 1,
            which keeps the quick single-file look this cell was written for;
            pass ``None`` to load every matching file.
        map_location: Forwarded to ``torch.load``. When ``None`` (default)
            the notebook-level ``device`` global is used.

    Returns:
        dict mapping ``'layer_<L>_feat_<F>'`` to whatever ``torch.load``
        returns for that file. Files that differ only in ``num_batches`` /
        ``batch_size`` share a key, so the one visited last wins.
    """
    if map_location is None:
        map_location = t.device(device)

    pattern = re.compile(r'layer_(\d+)__feat_(\d+)__num_batches_\d+__batch_size_\d+\.pth')
    tensor_dict = {}
    num_loaded = 0

    # iglob yields paths in the same order as glob.glob but lets us stop early
    # without listing the whole directory first.
    for filepath in glob.iglob(os.path.join(directory, '*.pth')):
        # Stop on files actually loaded, not on files merely visited: an
        # unconditional break after the first path returned an empty dict
        # whenever that first path did not match the pattern.
        if max_files is not None and num_loaded >= max_files:
            break

        match = pattern.search(os.path.basename(filepath))
        if match is None:
            continue

        layer_idx, feat_idx = match.groups()
        key = f'layer_{layer_idx}_feat_{feat_idx}'

        # NOTE: torch.load unpickles the file, so only point this at trusted artefacts.
        tensor_dict[key] = t.load(filepath, map_location=map_location)
        num_loaded += 1

    return tensor_dict

In [None]:
# %%

# Folder scanned for the saved ablation `.pth` files (relative to the working directory).
directory = 'artefacts/ablations'
# Maps 'layer_<L>_feat_<F>' keys to the object loaded from each matching file.
result = load_tensors_into_dict(directory)

In [None]:
result

{'layer_0_feat_4580': {'sae_errors': tensor([[[ 3.6886e-04, -9.5102e-04,  1.5796e-03,  ...,  1.5465e-03,
             3.3407e-04,  1.0646e-03],
           [-5.1186e-03,  1.8308e-02,  5.6459e-03,  ..., -2.9386e-03,
             5.5855e-03,  4.7449e-03],
           [ 1.3691e-03, -3.9159e-03, -2.5745e-04,  ...,  1.3716e-03,
             1.3642e-03,  1.4481e-03],
           ...,
           [ 2.6384e-04, -1.4210e-03,  4.6958e-03,  ...,  2.7325e-03,
             2.6579e-04,  3.6625e-03],
           [ 9.2445e-04,  3.0203e-03,  4.8355e-03,  ...,  1.2169e-04,
             1.1595e-03, -1.1107e-03],
           [-5.7736e-02, -1.0445e-01,  6.6852e-03,  ..., -2.1304e-02,
            -7.3209e-02, -4.2357e-02]],
  
          [[ 3.6886e-04, -9.5102e-04,  1.5796e-03,  ...,  1.5465e-03,
             3.3407e-04,  1.0646e-03],
           [ 1.6750e-03, -2.3648e-03,  1.4567e-03,  ...,  4.4935e-03,
             4.8193e-03, -1.1140e-03],
           [ 1.9542e-03, -2.1624e-03,  1.0924e-02,  ..., -4.6064e-04,
   

In [None]:
# Call the method: a bare `result.keys` only displays the bound method object.
result.keys()

<function dict.keys>

In [None]:
result.keys()

dict_keys(['layer_0_feat_4580'])

In [None]:
result['layer_0_feat_4580']

{'sae_errors': tensor([[[ 3.6886e-04, -9.5102e-04,  1.5796e-03,  ...,  1.5465e-03,
            3.3407e-04,  1.0646e-03],
          [-5.1186e-03,  1.8308e-02,  5.6459e-03,  ..., -2.9386e-03,
            5.5855e-03,  4.7449e-03],
          [ 1.3691e-03, -3.9159e-03, -2.5745e-04,  ...,  1.3716e-03,
            1.3642e-03,  1.4481e-03],
          ...,
          [ 2.6384e-04, -1.4210e-03,  4.6958e-03,  ...,  2.7325e-03,
            2.6579e-04,  3.6625e-03],
          [ 9.2445e-04,  3.0203e-03,  4.8355e-03,  ...,  1.2169e-04,
            1.1595e-03, -1.1107e-03],
          [-5.7736e-02, -1.0445e-01,  6.6852e-03,  ..., -2.1304e-02,
           -7.3209e-02, -4.2357e-02]],
 
         [[ 3.6886e-04, -9.5102e-04,  1.5796e-03,  ...,  1.5465e-03,
            3.3407e-04,  1.0646e-03],
          [ 1.6750e-03, -2.3648e-03,  1.4567e-03,  ...,  4.4935e-03,
            4.8193e-03, -1.1140e-03],
          [ 1.9542e-03, -2.1624e-03,  1.0924e-02,  ..., -4.6064e-04,
           -9.1030e-04,  1.9664e-03],
     

In [None]:
# Shorthand for the one loaded record; it is a plain dict, so its entries are read by key.
data = result['layer_0_feat_4580']

In [None]:
data

{'sae_errors': tensor([[[ 3.6886e-04, -9.5102e-04,  1.5796e-03,  ...,  1.5465e-03,
            3.3407e-04,  1.0646e-03],
          [-5.1186e-03,  1.8308e-02,  5.6459e-03,  ..., -2.9386e-03,
            5.5855e-03,  4.7449e-03],
          [ 1.3691e-03, -3.9159e-03, -2.5745e-04,  ...,  1.3716e-03,
            1.3642e-03,  1.4481e-03],
          ...,
          [ 2.6384e-04, -1.4210e-03,  4.6958e-03,  ...,  2.7325e-03,
            2.6579e-04,  3.6625e-03],
          [ 9.2445e-04,  3.0203e-03,  4.8355e-03,  ...,  1.2169e-04,
            1.1595e-03, -1.1107e-03],
          [-5.7736e-02, -1.0445e-01,  6.6852e-03,  ..., -2.1304e-02,
           -7.3209e-02, -4.2357e-02]],
 
         [[ 3.6886e-04, -9.5102e-04,  1.5796e-03,  ...,  1.5465e-03,
            3.3407e-04,  1.0646e-03],
          [ 1.6750e-03, -2.3648e-03,  1.4567e-03,  ...,  4.4935e-03,
            4.8193e-03, -1.1140e-03],
          [ 1.9542e-03, -2.1624e-03,  1.0924e-02,  ..., -4.6064e-04,
           -9.1030e-04,  1.9664e-03],
     

In [None]:
# `data` is a plain dict, so its entries are reached by key rather than by attribute.
data['masked_mse']

AttributeError: 'dict' object has no attribute 'masked_mse'

In [None]:
data['masked_mse']

tensor([5.7281e-12, 1.1155e-11, 2.2949e-11,  ..., 1.8993e-11, 1.9393e-11,
        7.6132e-12])

In [None]:
data['masked_mse'].shape

torch.Size([24576])

In [None]:
data['masked_means'].shape

torch.Size([24576])

In [None]:
data['masked_means']

tensor([ 2.3913e-07, -5.3373e-07,         nan,  ...,         nan,
                nan,         nan])

In [None]:
# Tensors have no pandas-style `.isna()`; torch's NaN test returns the boolean mask instead.
t.isnan(data['masked_means'])

AttributeError: 'Tensor' object has no attribute 'isna'

In [None]:
t.isnan(data['masked_means']).sum().item()

20235

In [None]:
# %%

# Path to the saved Pearson-correlation similarity archive (relative to the working directory).
pearson_corr_filename = "artefacts/similarity_measures/pearson_correlation/res_jb_sae_feature_similarity_pearson_correlation_1M_0.0_0.1.npz"
# np.load accepts the path directly and its result is a context manager, so the
# archive is closed as soon as the array has been read. A dedicated handle name
# is used so the notebook variable `data` (the per-feature dict) is not rebound
# to a closed file object.
with np.load(pearson_corr_filename) as pearson_npz:
    interaction_data = pearson_npz['arr_0']

In [None]:
# %%

import torch as t
import numpy as np
import glob
import os
import re


# Choose the compute backend in order of preference: Apple MPS, then CUDA, then CPU.
if t.backends.mps.is_available():
    device = "mps"
elif t.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Side length used when unravelling flat (feature, feature) indices below.
# NOTE(review): presumably the number of SAE features per layer — TODO confirm.
d_sae = 24576
print(f"Loaded {device=}, {d_sae=}")
# %%

Loaded device='cpu', d_sae=24576


In [None]:
# %%

# Flat indices into a (d_sae, d_sae) matrix for the sampled high Pearson-correlation
# pairs. Read inside a context manager so the .npz archive is closed afterwards.
# NOTE(review): expected to be 2-D, (num_layers, num_top_features) — TODO confirm.
with np.load("artefacts/sampled_interaction_measures/pearson_correlation/count_1000.npz") as top_corr_npz:
    top_corr_flat_idx = top_corr_npz['arr_0']

# np.unravel_index works element-wise on an index array of any shape, so a single
# call yields both the row indices (first feature of every pair) and the column
# indices (second feature of every pair), each shaped like `top_corr_flat_idx`.
# NOTE(review): rows are taken to be previous-layer features and columns
# next-layer features, as the variable names suggest — verify against the code
# that produced the flat indices.
corr_prev_layer_feat_idxes, corr_next_layer_feat_idxes = np.unravel_index(
    top_corr_flat_idx, shape=(d_sae, d_sae)
)

In [None]:
corr_prev_layer_feat_idxes

array([[14525,  9441,  8718, ...,  6073, 10024,  5580],
       [22140, 23465, 17663, ...,  2064, 18409,  6703],
       [  737, 22060,  6861, ..., 11177,   938,  5063],
       ...,
       [ 6955,  8016, 14557, ..., 14316, 17437,   221],
       [ 4612, 12311, 11774, ..., 22370, 24023,  2251],
       [23726, 15936, 13239, ..., 11804,  5984,  7942]])

In [None]:
corr_prev_layer_feat_idxes.shape

(11, 1000)

In [None]:
corr_prev_layer_feat_idxes[0]

array([14525,  9441,  8718,  7609,  1555,  2248, 24067,  6076, 21572,
       18712, 23330, 10256, 10314, 15015,  8406, 18743,   978, 24361,
        1591,  5701,  4270, 20527,  7447,  8949, 20692,  2529,  9724,
        4202, 13226, 18420, 17972, 19029, 13489,  1724, 19665, 15420,
       21046, 15678, 23551,  7464, 12288,  3898, 14319,  4580,  8023,
       14440, 22626, 17819, 23623, 24491, 21745,  4051, 15506,   544,
        4734, 23716, 23969, 14599,   132, 19532, 11155, 10614, 16989,
       12105,  2282,  5904,  8104, 17174, 20413, 12309, 12984,  3707,
       18133,  3996,  4460,  7923, 24475, 24353,  6500, 13216, 21872,
       12327,  7558,  9848, 19677, 14927, 17974, 10953, 16991, 16652,
       10427, 23496,  3725,  8018,  1505, 23190,   715, 22597,  2469,
       24362, 12984, 10297, 23190,  2439,  1162, 14469, 16687,  7011,
        4248, 11045,  6040,  3078, 23224,  1737, 16408,  3317, 16092,
       21442, 20155, 12240, 16030, 22216,  1839, 13862, 13677,  6872,
        5880, 10254,

In [None]:
corr_prev_layer_feat_idxes[0][0]

14525

In [None]:
corr_next_layer_feat_idxes[0][0]

11914

In [None]:
interaction_data[0][14525][11914]

1.0000046

In [None]:
result

{'layer_0_feat_4580': {'sae_errors': tensor([[[ 3.6886e-04, -9.5102e-04,  1.5796e-03,  ...,  1.5465e-03,
             3.3407e-04,  1.0646e-03],
           [-5.1186e-03,  1.8308e-02,  5.6459e-03,  ..., -2.9386e-03,
             5.5855e-03,  4.7449e-03],
           [ 1.3691e-03, -3.9159e-03, -2.5745e-04,  ...,  1.3716e-03,
             1.3642e-03,  1.4481e-03],
           ...,
           [ 2.6384e-04, -1.4210e-03,  4.6958e-03,  ...,  2.7325e-03,
             2.6579e-04,  3.6625e-03],
           [ 9.2445e-04,  3.0203e-03,  4.8355e-03,  ...,  1.2169e-04,
             1.1595e-03, -1.1107e-03],
           [-5.7736e-02, -1.0445e-01,  6.6852e-03,  ..., -2.1304e-02,
            -7.3209e-02, -4.2357e-02]],
  
          [[ 3.6886e-04, -9.5102e-04,  1.5796e-03,  ...,  1.5465e-03,
             3.3407e-04,  1.0646e-03],
           [ 1.6750e-03, -2.3648e-03,  1.4567e-03,  ...,  4.4935e-03,
             4.8193e-03, -1.1140e-03],
           [ 1.9542e-03, -2.1624e-03,  1.0924e-02,  ..., -4.6064e-04,
   

In [None]:
# `result` is keyed by 'layer_<L>_feat_<F>'; per-feature entries live one level down.
# NOTE(review): this notebook does not show whether the record has a 'masked_stdevs'
# entry — confirm against the code that wrote the artefact.
result['layer_0_feat_4580']['masked_stdevs']

KeyError: 'masked_stdevs'

In [None]:
result.keys()

dict_keys(['layer_0_feat_4580'])