In [1]:
import torch
from sparse_tensor import SparseTensor
from sparsify.sparsify import Sae
import einops

In [2]:
sae = Sae.load_from_hub("EleutherAI/sae-pythia-410m-65k", hookpoint="layers.22.mlp")
sae = sae.to('cuda:0')

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

Dropping extra args {'signed': False}


In [3]:
ip = torch.rand((512,1024), device='cuda:0')

In [4]:
sae.cfg

SparseCoderConfig(activation='topk', expansion_factor=32, normalize_decoder=True, num_latents=65536, k=32, multi_topk=False, skip_connection=False, transcode=False)

In [5]:
with torch.no_grad():
    result = sae.forward(ip)

In [6]:
latent_acts = einops.rearrange(result.latent_acts, "(b s) e -> b s e", b = 16)
latent_indices = einops.rearrange(result.latent_indices, "(b s) e -> b s e", b = 16)
print(latent_indices.shape)
st = SparseTensor(latent_acts, latent_indices, (latent_acts.shape[0],latent_acts.shape[1],sae.num_latents), 2)

torch.Size([16, 32, 32])


In [7]:
full = SparseTensor.decompress(st)

In [8]:
full.shape

torch.Size([16, 32, 65536])

In [9]:
check = (full != 0)
print(check.sum(dim=-1).max())
print(check.sum(dim=-1).min())

tensor(32, device='cuda:0')
tensor(32, device='cuda:0')


## Testing Pythia SAE

In [10]:
import torch
from pythia_sae import SAE

In [11]:
sae_handle = SAE(device='cuda:0', sae_layer_template="layers.<layer>.mlp")
sae_handle.load_many("EleutherAI/sae-pythia-410m-65k", [22])

Loading SAEs


Fetching 49 files:   0%|          | 0/49 [00:00<?, ?it/s]

.gitattributes:   0%|          | 0.00/1.52k [00:00<?, ?B/s]

sae.safetensors:   0%|          | 0.00/537M [00:00<?, ?B/s]

sae.safetensors:   0%|          | 0.00/537M [00:00<?, ?B/s]

sae.safetensors:   0%|          | 0.00/537M [00:00<?, ?B/s]

sae.safetensors:   0%|          | 0.00/537M [00:00<?, ?B/s]

sae.safetensors:   0%|          | 0.00/537M [00:00<?, ?B/s]

sae.safetensors:   0%|          | 0.00/537M [00:00<?, ?B/s]

sae.safetensors:   0%|          | 0.00/537M [00:00<?, ?B/s]

sae.safetensors:   0%|          | 0.00/537M [00:00<?, ?B/s]

sae.safetensors:   0%|          | 0.00/537M [00:00<?, ?B/s]

sae.safetensors:   0%|          | 0.00/537M [00:00<?, ?B/s]

sae.safetensors:   0%|          | 0.00/537M [00:00<?, ?B/s]

sae.safetensors:   0%|          | 0.00/537M [00:00<?, ?B/s]

sae.safetensors:   0%|          | 0.00/537M [00:00<?, ?B/s]

sae.safetensors:   0%|          | 0.00/537M [00:00<?, ?B/s]

sae.safetensors:   0%|          | 0.00/537M [00:00<?, ?B/s]

sae.safetensors:   0%|          | 0.00/537M [00:00<?, ?B/s]

sae.safetensors:   0%|          | 0.00/537M [00:00<?, ?B/s]

sae.safetensors:   0%|          | 0.00/537M [00:00<?, ?B/s]

sae.safetensors:   0%|          | 0.00/537M [00:00<?, ?B/s]

sae.safetensors:   0%|          | 0.00/537M [00:00<?, ?B/s]

sae.safetensors:   0%|          | 0.00/537M [00:00<?, ?B/s]

sae.safetensors:   0%|          | 0.00/537M [00:00<?, ?B/s]

sae.safetensors:   0%|          | 0.00/537M [00:00<?, ?B/s]

Dropping extra args {'signed': False}
Dropping extra args {'signed': False}
Dropping extra args {'signed': False}
Dropping extra args {'signed': False}
Dropping extra args {'signed': False}
Dropping extra args {'signed': False}
Dropping extra args {'signed': False}
Dropping extra args {'signed': False}
Dropping extra args {'signed': False}
Dropping extra args {'signed': False}
Dropping extra args {'signed': False}
Dropping extra args {'signed': False}
Dropping extra args {'signed': False}
Dropping extra args {'signed': False}
Dropping extra args {'signed': False}
Dropping extra args {'signed': False}
Dropping extra args {'signed': False}
Dropping extra args {'signed': False}
Dropping extra args {'signed': False}
Dropping extra args {'signed': False}
Dropping extra args {'signed': False}
Dropping extra args {'signed': False}
Dropping extra args {'signed': False}
Dropping extra args {'signed': False}


Finished Loading SAEs
Layer 22 laoded


In [12]:
sae_out, comp_sae_acts = sae_handle.compute_activations(torch.rand((16,512,1024),device='cuda:0'),layer=22)

In [13]:
sae_out.shape

torch.Size([16, 512, 1024])

In [16]:
full_t = comp_sae_acts.decompress()

In [17]:
full_t.shape

torch.Size([16, 512, 65536])

In [19]:
torch.save([comp_sae_acts.serialize()],'test.pt')

In [23]:
state = torch.load('test.pt',map_location=torch.device('cuda:0'),weights_only=False)

In [25]:
state[0].keys()

dict_keys(['index_tensor', 'activation_tensor', 'uncompressed_dim_idx', 'uncompressed_shape'])

In [28]:
full_t_prime = SparseTensor.deserialize(state[0]).decompress()

In [30]:
full_t_prime.shape

torch.Size([16, 512, 65536])

In [32]:
#print(full_t == full_t_prime)
print(torch.allclose(full_t, full_t_prime))

True


In [33]:
sae_out.shape

torch.Size([16, 512, 1024])

## PyTorch sparse tensor

In [35]:
c = full_t_prime.to_sparse()

In [36]:
c.shape

torch.Size([16, 512, 65536])

In [37]:
c

tensor(indices=tensor([[    0,     0,     0,  ...,    15,    15,    15],
                       [    0,     0,     0,  ...,   511,   511,   511],
                       [ 4353,  6585, 10602,  ..., 52517, 57425, 58835]]),
       values=tensor([8.6875, 4.2500, 5.3438,  ..., 5.7188, 5.0000, 4.6875]),
       device='cuda:0', size=(16, 512, 65536), nnz=262144, dtype=torch.bfloat16,
       layout=torch.sparse_coo)

In [38]:
full_t_prime.shape

torch.Size([16, 512, 65536])

In [39]:
torch.save(full_t_prime,"test1.pt")

In [40]:
torch.save(c,"test2.pt")

In [41]:
a = torch.load('test1.pt')

  a = torch.load('test1.pt')


In [42]:
b = torch.load('test2.pt')

  b = torch.load('test2.pt')


In [43]:
b

tensor(indices=tensor([[    0,     0,     0,  ...,    15,    15,    15],
                       [    0,     0,     0,  ...,   511,   511,   511],
                       [ 4353,  6585, 10602,  ..., 52517, 57425, 58835]]),
       values=tensor([8.6875, 4.2500, 5.3438,  ..., 5.7188, 5.0000, 4.6875]),
       device='cuda:0', size=(16, 512, 65536), nnz=262144, dtype=torch.bfloat16,
       layout=torch.sparse_coo)

In [44]:
a

tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        ...,

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0., 

In [45]:
b.shape

torch.Size([16, 512, 65536])

In [46]:
a.shape

torch.Size([16, 512, 65536])

In [47]:
torch.allclose(a,b)

NotImplementedError: Could not run 'aten::eq.Tensor' with arguments from the 'SparseCUDA' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'aten::eq.Tensor' is only available for these backends: [CPU, CUDA, HIP, MPS, IPU, XPU, HPU, VE, MTIA, PrivateUse1, PrivateUse2, PrivateUse3, Meta, FPGA, MAIA, Vulkan, Metal, QuantizedCPU, QuantizedCUDA, QuantizedHIP, QuantizedMPS, QuantizedIPU, QuantizedXPU, QuantizedHPU, QuantizedVE, QuantizedMTIA, QuantizedPrivateUse1, QuantizedPrivateUse2, QuantizedPrivateUse3, QuantizedMeta, CustomRNGKeyId, MkldnnCPU, SparseCsrCPU, SparseCsrCUDA, SparseCsrHIP, SparseCsrMPS, SparseCsrIPU, SparseCsrXPU, SparseCsrHPU, SparseCsrVE, SparseCsrMTIA, SparseCsrPrivateUse1, SparseCsrPrivateUse2, SparseCsrPrivateUse3, SparseCsrMeta, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradHIP, AutogradXLA, AutogradMPS, AutogradIPU, AutogradXPU, AutogradHPU, AutogradVE, AutogradLazy, AutogradMTIA, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, AutogradMeta, AutogradNestedTensor, Tracer, AutocastCPU, AutocastXPU, AutocastCUDA, FuncTorchBatched, BatchedNestedTensor, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot, FuncTorchDynamicLayerFrontMode, PreDispatch, PythonDispatcher].

Undefined: registered at aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
CPU: registered at aten/src/ATen/RegisterCPU.cpp:30455 [kernel]
CUDA: registered at aten/src/ATen/RegisterCUDA.cpp:44681 [kernel]
HIP: registered at aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
MPS: registered at aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
IPU: registered at aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
XPU: registered at aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
HPU: registered at aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
VE: registered at aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
MTIA: registered at aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
PrivateUse1: registered at aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
PrivateUse2: registered at aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
PrivateUse3: registered at aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
Meta: registered at /dev/null:241 [kernel]
FPGA: registered at aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
MAIA: registered at aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
Vulkan: registered at aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
Metal: registered at aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
QuantizedCPU: registered at aten/src/ATen/RegisterQuantizedCPU.cpp:951 [kernel]
QuantizedCUDA: registered at aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
QuantizedHIP: registered at aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
QuantizedMPS: registered at aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
QuantizedIPU: registered at aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
QuantizedXPU: registered at aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
QuantizedHPU: registered at aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
QuantizedVE: registered at aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
QuantizedMTIA: registered at aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
QuantizedPrivateUse1: registered at aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
QuantizedPrivateUse2: registered at aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
QuantizedPrivateUse3: registered at aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
QuantizedMeta: registered at aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
CustomRNGKeyId: registered at aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
MkldnnCPU: registered at aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
SparseCsrCPU: registered at aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
SparseCsrCUDA: registered at aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
SparseCsrHIP: registered at aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
SparseCsrMPS: registered at aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
SparseCsrIPU: registered at aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
SparseCsrXPU: registered at aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
SparseCsrHPU: registered at aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
SparseCsrVE: registered at aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
SparseCsrMTIA: registered at aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
SparseCsrPrivateUse1: registered at aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
SparseCsrPrivateUse2: registered at aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
SparseCsrPrivateUse3: registered at aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
SparseCsrMeta: registered at aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
BackendSelect: fallthrough registered at ../aten/src/ATen/core/BackendSelectFallbackKernel.cpp:3 [backend fallback]
Python: registered at ../aten/src/ATen/core/PythonFallbackKernel.cpp:153 [backend fallback]
FuncTorchDynamicLayerBackMode: registered at ../aten/src/ATen/functorch/DynamicLayer.cpp:497 [backend fallback]
Functionalize: registered at ../aten/src/ATen/FunctionalizeFallbackKernel.cpp:349 [backend fallback]
Named: fallthrough registered at ../aten/src/ATen/core/NamedRegistrations.cpp:11 [kernel]
Conjugate: registered at ../aten/src/ATen/ConjugateFallback.cpp:17 [backend fallback]
Negative: registered at ../aten/src/ATen/native/NegateFallback.cpp:18 [backend fallback]
ZeroTensor: registered at ../aten/src/ATen/ZeroTensorFallback.cpp:86 [backend fallback]
ADInplaceOrView: fallthrough registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:86 [backend fallback]
AutogradOther: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:18032 [autograd kernel]
AutogradCPU: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:18032 [autograd kernel]
AutogradCUDA: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:18032 [autograd kernel]
AutogradHIP: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:18032 [autograd kernel]
AutogradXLA: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:18032 [autograd kernel]
AutogradMPS: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:18032 [autograd kernel]
AutogradIPU: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:18032 [autograd kernel]
AutogradXPU: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:18032 [autograd kernel]
AutogradHPU: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:18032 [autograd kernel]
AutogradVE: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:18032 [autograd kernel]
AutogradLazy: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:18032 [autograd kernel]
AutogradMTIA: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:18032 [autograd kernel]
AutogradPrivateUse1: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:18032 [autograd kernel]
AutogradPrivateUse2: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:18032 [autograd kernel]
AutogradPrivateUse3: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:18032 [autograd kernel]
AutogradMeta: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:18032 [autograd kernel]
AutogradNestedTensor: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:18032 [autograd kernel]
Tracer: registered at ../torch/csrc/autograd/generated/TraceType_0.cpp:17004 [kernel]
AutocastCPU: fallthrough registered at ../aten/src/ATen/autocast_mode.cpp:209 [backend fallback]
AutocastXPU: fallthrough registered at ../aten/src/ATen/autocast_mode.cpp:351 [backend fallback]
AutocastCUDA: fallthrough registered at ../aten/src/ATen/autocast_mode.cpp:165 [backend fallback]
FuncTorchBatched: registered at ../aten/src/ATen/functorch/BatchRulesBinaryOps.cpp:320 [kernel]
BatchedNestedTensor: registered at ../aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp:758 [backend fallback]
FuncTorchVmapMode: fallthrough registered at ../aten/src/ATen/functorch/VmapModeRegistrations.cpp:27 [backend fallback]
Batched: registered at ../aten/src/ATen/LegacyBatchingRegistrations.cpp:1079 [kernel]
VmapMode: fallthrough registered at ../aten/src/ATen/VmapModeRegistrations.cpp:33 [backend fallback]
FuncTorchGradWrapper: registered at ../aten/src/ATen/functorch/TensorWrapper.cpp:207 [backend fallback]
PythonTLSSnapshot: registered at ../aten/src/ATen/core/PythonFallbackKernel.cpp:161 [backend fallback]
FuncTorchDynamicLayerFrontMode: registered at ../aten/src/ATen/functorch/DynamicLayer.cpp:493 [backend fallback]
PreDispatch: registered at ../aten/src/ATen/core/PythonFallbackKernel.cpp:165 [backend fallback]
PythonDispatcher: registered at ../aten/src/ATen/core/PythonFallbackKernel.cpp:157 [backend fallback]


In [49]:
b = b.to_dense()

In [50]:
torch.allclose(a,b)

True

In [62]:
a = torch.rand((16,512,66000), device='cuda:1')
b = a * (a > 0.5).float()

In [63]:
a.shape, b.shape

(torch.Size([16, 512, 66000]), torch.Size([16, 512, 66000]))

In [57]:
a = a.to_sparse()
b = b.to_sparse()
torch.save(a, 'test1.pt')
torch.save(b, 'test2.pt')

In [64]:
a = a.to_sparse()

In [59]:
a.shape

torch.Size([16, 512, 66000])

In [65]:
b = b.to_sparse()

In [66]:
b

tensor(indices=tensor([[    0,     0,     0,  ...,    15,    15,    15],
                       [    0,     0,     0,  ...,   511,   511,   511],
                       [    0,     1,     4,  ..., 65986, 65997, 65999]]),
       values=tensor([0.6686, 0.8125, 0.9768,  ..., 0.9957, 0.7157, 0.7359]),
       device='cuda:1', size=(16, 512, 66000), nnz=270342151,
       layout=torch.sparse_coo)