In [1]:
!pip install -q --upgrade typing_extensions pydantic huggingface_hub aiohttp multiprocess xxhash better-abc sentencepiece typeguard

In [2]:
from training import train_sae, train_sae_group
from sae import VanillaSAE, TopKSAE, BatchTopKSAE, JumpReLUSAE
from sae import TopAFASAE
from activation_store import ActivationsStore
from config import get_default_cfg, post_init_cfg
from transformer_lens import HookedTransformer

In [3]:
import torch

# Print a short summary of the current CUDA device, or say that none is visible.
if not torch.cuda.is_available():
    print("CUDA is not available")
else:
    device_id = torch.cuda.current_device()
    device_props = torch.cuda.get_device_properties(device_id)
    bytes_per_mb = 1024**2  # divisor behind every "MB" figure printed below

    allocated_mb = torch.cuda.memory_allocated(device_id) / bytes_per_mb
    # The "Memory Cached" line reports torch.cuda.memory_reserved().
    reserved_mb = torch.cuda.memory_reserved(device_id) / bytes_per_mb
    total_mb = device_props.total_memory / bytes_per_mb

    print(
        f"GPU Device ID: {device_id}",
        f"GPU Name: {torch.cuda.get_device_name(device_id)}",
        f"Memory Allocated: {allocated_mb:.2f} MB",
        f"Memory Cached: {reserved_mb:.2f} MB",
        f"Total Memory: {total_mb:.2f} MB",
        f"Compute Capability: {torch.cuda.get_device_capability(device_id)}",
        f"Multiprocessors: {device_props.multi_processor_count}",
        f"CUDA Version: {torch.version.cuda}",
        sep="\n",
    )


GPU Device ID: 0
GPU Name: NVIDIA A100-SXM4-80GB
Memory Allocated: 0.00 MB
Memory Cached: 0.00 MB
Total Memory: 81037.75 MB
Compute Capability: (8, 0)
Multiprocessors: 108
CUDA Version: 12.1


In [4]:
# Name assigned to cfg['wandb_project'] for every run launched by this notebook.
project_name = "25-03-19-1852-20k"

# Token budget assigned to cfg["num_tokens"] for each run.
# NOTE(review): presumably 20,001 steps of 4096 tokens each — confirm against
# the batch size in the default config.
num_tokens = 20_001 * 4096

In [5]:
# Sweep definition. Each entry is a tuple
#     (layers, sae_type, dict_sizes, ks, afa_coeffs)
# and one run is launched for every combination of the four lists.
# The 'topafa' entry sweeps afa_coeffs with a single placeholder k of 0;
# the 'batchtopk' and 'topk' entries sweep k over 2**13 down to 2**5 with a
# single placeholder afa_coeff of 0.
experiments = [
    (
        [6, 7, 8],
        "topafa",
        [16 * 768],
        [0],
        [1 / denom for denom in (128, 64, 32, 24, 16)],
    ),
    (
        [6, 7, 8],
        "batchtopk",
        [16 * 768],
        [2**exponent for exponent in range(13, 4, -1)],
        [0],
    ),
    (
        [6, 7, 8],
        "topk",
        [16 * 768],
        [2**exponent for exponent in range(13, 4, -1)],
        [0],
    ),
]

In [None]:
from itertools import product

import wandb

# Maps each supported cfg["sae_type"] value to the SAE class built for it.
SAE_CLASSES = {
    "topk": TopKSAE,
    "batchtopk": BatchTopKSAE,
    "topafa": TopAFASAE,
}

# Fail fast on unknown SAE types. Without this check an unrecognised type would
# either hit a NameError on `sae` or silently retrain the SAE object left over
# from the previous iteration, and only after earlier runs had completed.
unknown_sae_types = sorted({entry[1] for entry in experiments} - set(SAE_CLASSES))
if unknown_sae_types:
    raise ValueError(
        f"Unsupported sae_type(s) in experiments: {unknown_sae_types}; "
        f"expected one of {sorted(SAE_CLASSES)}"
    )

try:
    for layers, sae_type, dict_sizes, ks, afa_coeffs in experiments:
        # product() iterates with the rightmost list varying fastest, i.e. the
        # same order as nested loops over layer > dict_size > k > afa_coeff.
        for layer, dict_size, k, afa_coeff in product(layers, dict_sizes, ks, afa_coeffs):
            cfg = get_default_cfg()

            cfg["model_name"] = "gpt2-small"
            cfg["dict_size"] = dict_size
            cfg["layer"] = layer

            cfg["sae_type"] = sae_type
            cfg["top_k"] = k
            cfg["dataset_path"] = "Skylion007/openwebtext"
            cfg["num_tokens"] = num_tokens
            cfg["wandb_project"] = project_name

            # afa_coeff is only written into the config for the 'topafa' type.
            if sae_type == "topafa":
                cfg["afa_coeff"] = afa_coeff

            # NOTE(review): the SAE is built before post_init_cfg(cfg) runs;
            # confirm no SAE constructor needs keys that post_init_cfg derives.
            sae = SAE_CLASSES[sae_type](cfg)

            cfg = post_init_cfg(cfg)
            print(cfg)

            model = HookedTransformer.from_pretrained(cfg["model_name"]).to(cfg["dtype"]).to(cfg["device"])
            activations_store = ActivationsStore(model, cfg)
            train_sae(sae, activations_store, model, cfg)
except BaseException:
    # Also covers KeyboardInterrupt: close any wandb run that is still open and
    # mark it as failed instead of leaving it dangling, then re-raise.
    wandb.finish(exit_code=1)
    raise

wandb.finish()