In [None]:
import os
import glob
import torch
from src import data
from overcomplete.sae import TopKSAE
    
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Running on {device}")

# SAE hyperparameters, passed to TopKSAE further down as top_k / nb_concepts.
# The checkpoint filename in model_path carries "d10000_k100"; presumably these
# values have to agree with it — confirm before changing either one.
k = 100
d_model = 10_000
print(f"Config: d_model={d_model}, k={k}")

# Directory scanned for validation shards, and the state-dict file to restore.
val_directory = "/scratch.global/lee02328/val_data_DINOv2_B"
model_path = "/users/9/lee02328/Ada_Comp/arch_SAE/trained_models/sae_1_SI-SAE_d10000_k100_per_init0.02_state_dict.pth"

# Discover the validation shards (files named shard_*.pt) in sorted order.
val_shard_files = sorted(glob.glob(os.path.join(val_directory, 'shard_*.pt')))

print(f"Found {len(val_shard_files)} validation shard files")

# Fail early with a message that names the directory; without this guard an
# empty glob surfaces as a bare IndexError on val_shard_files[0] below.
if not val_shard_files:
    raise FileNotFoundError(
        f"No 'shard_*.pt' files found in {val_directory!r}; cannot detect the embedding dimension"
    )

# Get the input dimension from the last axis of the first shard.
# Loaded on CPU with weights_only=True so only tensor data is unpickled.
first_shard = torch.load(val_shard_files[0], map_location='cpu', weights_only=True)
d_brain = first_shard.shape[-1]
print(f"Detected embedding dimension: {d_brain}")

# Normalization statistics are computed from the validation directory itself.
# NOTE(review): confirm these match the statistics the checkpoint was trained
# with; this cell does not show where the training-time stats came from.
mean, std = data.get_dataset_stats(val_directory)
normalizer = data.GPUNormalizer(mean, std).to(device)

# Build the raw validation loader over the same directory ...
raw_val_loader = data.create_val_dataloader(
    val_directory,
    total_batch_size=1000,
    num_workers=1,
    prefetch_factor=2,
    subset_fraction=0.5,
)

# ... then wrap it together with the target device and the normalizer
# (presumably so batches arrive on `device` already normalized — verify in src.data).
val_loader = data.DeviceDataLoader(raw_val_loader, device, normalizer)

# Label used only in the log line below.
model_type = "SAE"
print(f"Loading {model_type} model from {model_path}")

# Instantiate the TopK SAE with the detected input width and the configured
# dictionary size / sparsity level.
sae = TopKSAE(
    input_shape=d_brain,
    nb_concepts=d_model,
    top_k=k,
    device=device,
)

# Restore the trained parameters straight onto the target device. The nested
# call is deliberate: no extra reference to the checkpoint dict is kept alive.
sae.load_state_dict(
    torch.load(model_path, map_location=device, weights_only=True)
)
# Switch to inference mode.
sae.eval()

print("Model loaded successfully")

In [None]:
# Sanity check: encode the first validation batch and show the codes of its
# first sample. Gradients are not needed for this inspection.
with torch.no_grad():
    # Only one batch is required, so take it directly instead of looping and
    # breaking. An empty loader is reported explicitly; previously it left `z`
    # unbound and the print below failed with a NameError.
    batch = next(iter(val_loader), None)
    if batch is None:
        raise RuntimeError("val_loader yielded no batches; cannot inspect sparse codes")

    # The loader may yield a bare tensor or a sequence whose first item is the input.
    x = batch if isinstance(batch, torch.Tensor) else batch[0]
    x = x.to(device).float()

    # Get sparse codes: the model call returns three values and the second one
    # is used as the codes.
    _, z, _ = sae(x)
    z = z.float()

    print(z[0])