# Load Model For Genception

## SAE Params

In [1]:
HIDDEN_SIZE = 1024
INPUT_SIZE = 64

## Imports

In [2]:
import pickle

import numpy as np
from sklearn.linear_model import LogisticRegression


import torch
from torch import nn 
from torch.utils.data import DataLoader

from helpers.nn import NeuralNetwork
from helpers.sae import SparseAutoencoder
from helpers.edgedataset import EdgeDataset

from helpers.model_helpers import (evaluate_and_gather_activations, get_sublabel_data, 
    get_top_N_features, extract_activations, load_intermediate_labels, seed_worker)

## Set Device to GPU

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"We will be using device: {device}")

We will be using device: cuda


## Load Data

In [4]:
# train data
train_images = load_intermediate_labels("./data/FashionMNIST/parsed/train_images.pkl")
train_labels = load_intermediate_labels("./data/FashionMNIST/parsed/train_labels.pkl")

# test data
test_images = load_intermediate_labels("./data/FashionMNIST/parsed/test_images.pkl")
test_labels = load_intermediate_labels("./data/FashionMNIST/parsed/test_labels.pkl")

# Model Loading

## Model Result Replication

In [5]:
seed = 42
generator = torch.Generator().manual_seed(seed)

NUM_WORKERS = 4
if device.type.lower() == "cpu":
    NUM_WORKERS = 0

# training data
train_dataset = EdgeDataset(train_images, train_labels)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=NUM_WORKERS,
                          worker_init_fn=seed_worker, generator=generator, pin_memory=True)

# test data
test_dataset = EdgeDataset(test_images, test_labels)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)

In [6]:
model = NeuralNetwork().to(device)
sae_hidden_one = SparseAutoencoder(input_size=INPUT_SIZE, hidden_size=HIDDEN_SIZE).to(device)
sae_hidden_two = SparseAutoencoder(input_size=INPUT_SIZE, hidden_size=HIDDEN_SIZE).to(device)
sae_hidden_three = SparseAutoencoder(input_size=INPUT_SIZE, hidden_size=HIDDEN_SIZE).to(device)

In [7]:
best_model_path = "./runs/1024-0.75/results/F0/models/25_top/best_model_lf_0.14.pth"

checkpoint = torch.load(best_model_path)
model.load_state_dict(checkpoint['model_state_dict'])
sae_hidden_one.load_state_dict(checkpoint['sae_one_state_dict'])
sae_hidden_two.load_state_dict(checkpoint['sae_two_state_dict'])
sae_hidden_three.load_state_dict(checkpoint['sae_three_state_dict'])

<All keys matched successfully>

## Verify Correct Model

In [8]:
train_results = evaluate_and_gather_activations(model, sae_hidden_one, sae_hidden_two, sae_hidden_three, train_loader, device)
Z_train_one, Z_train_two, y_train = train_results["Z_one"], train_results["Z_two"], train_results["y"]

test_results = evaluate_and_gather_activations(model, sae_hidden_one, sae_hidden_two, sae_hidden_three, test_loader, device)
Z_test_one, Z_test_two, y_test = test_results["Z_one"], test_results["Z_two"], test_results["y"]

print(f"Model acc: {test_results['accuracy']}")

                                                                                                              

Model acc: 86.34




In [9]:
sparsity_one = np.mean(Z_test_one > 1e-5) * Z_test_one.shape[1]
sparsity_two = np.mean(Z_test_two > 1e-5) * Z_test_two.shape[1]
print(f"Average Non-Zero Features per Image (Hidden One): {sparsity_one:.2f}")
print(f"Average Non-Zero Features per Image (Hidden Two): {sparsity_two:.2f}")

Average Non-Zero Features per Image (Hidden One): 85.58
Average Non-Zero Features per Image (Hidden Two): 110.73


In [10]:
print("\n--- Training Linear Probes ---")
clf_one = LogisticRegression(penalty='l2', max_iter=1000, n_jobs=-1)
clf_one.fit(Z_train_one, y_train)
acc_one = clf_one.score(Z_test_one, y_test)
print(f"Linear Probe Accuracy (Hidden One): {acc_one:.2%}")

clf_two = LogisticRegression(penalty='l2', max_iter=1000, n_jobs=-1)
clf_two.fit(Z_train_two, y_train)
acc_two = clf_two.score(Z_test_two, y_test)
print(f"Linear Probe Accuracy (Hidden Two): {acc_two:.2%}")


--- Training Linear Probes ---


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


Linear Probe Accuracy (Hidden One): 87.00%
Linear Probe Accuracy (Hidden Two): 86.78%


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


# Some Geometry

In [11]:
activation_data = extract_activations(
    data_loader=train_loader,
    model=model,
    sae_one=sae_hidden_one,
    sae_two=sae_hidden_two,
    device=device
)

Extracting Activations: 100%|██████████████████████████████████████████████| 782/782 [00:01<00:00, 498.95it/s]


## Comparing Bias w/Recon/H1

Here we are comparing average decoded vector with the average hidden vector and the bias vector from the decoder.

In [12]:
avg_hidden_vector = torch.from_numpy(np.mean(activation_data['hidden_one'], axis=0)).float()
b_d = sae_hidden_one.decoder.bias.cpu()

cos_sim = nn.CosineSimilarity(dim=0)
t1t2_sim = cos_sim(avg_hidden_vector, b_d).mean()

avg_recon_vector = torch.from_numpy(np.mean(activation_data['recon_one'], axis=0)).float()
t1t3_sim = cos_sim(avg_hidden_vector, avg_recon_vector).mean()

print(f"Similarity b/w hidden activations and decoder bias vector: {round(t1t2_sim.item(), 4)}")
print(f"Similarity b/w hidden activations and reconstructed activations: {round(t1t3_sim.item(), 4)}")

Similarity b/w hidden activations and decoder bias vector: 0.8376
Similarity b/w hidden activations and reconstructed activations: 1.0


So, on average, the SAE reconstructs inputs well since the average reconstructed vector and the average hidden direction vector are algined.

Neither aligns with our bias vector from the decoder, making our theory about symmetry with S being the 0 vector moot.

## Comparing Model with Targets

In [13]:
recon_act_path = "./runs/1024-0.75/features/F0/25_top.pkl"
recon_max_sparse_act_one = load_intermediate_labels(recon_act_path)

In [14]:
target_vector = torch.from_numpy(np.mean(np.squeeze(np.array(recon_max_sparse_act_one)), axis=0)).float()

In [15]:
cos_sim(target_vector, avg_hidden_vector).mean()

tensor(0.9439)

In [16]:
cos_sim(target_vector, avg_recon_vector).mean()

tensor(0.9440)

The student has closely aligned with the master.

## Exploring Weight Matrix

### Decoding the Mean Sparse Vector

In [17]:
sae_one_w = sae_hidden_one.decoder.weight.cpu()
sparse_vector_avg = torch.from_numpy(np.mean(activation_data["sparse_one"], axis=0)).float().unsqueeze(1)

W_dS_avg = sae_one_w @ sparse_vector_avg
recon_avg = W_dS_avg + b_d.unsqueeze(1)

sim1 = cos_sim(recon_avg, b_d).mean()
sim2 = cos_sim(recon_avg, avg_hidden_vector).mean()
sim3 = cos_sim(recon_avg, avg_recon_vector).mean()

print(f"Similiarity b/w reconstructed average and decoder bias vector: {round(sim1.item(), 4)}")
print(f"Similiarity b/w reconstructed average and average hidden activations: {round(sim2.item(), 4)}")
print(f"Similiarity b/w reconstructed average and average recon vector: {round(sim3.item(), 4)}")

Similiarity b/w reconstructed average and decoder bias vector: 0.1524
Similiarity b/w reconstructed average and average hidden activations: 0.17
Similiarity b/w reconstructed average and average recon vector: 0.1641


### An Example of One

In [18]:
sparse_vector_ex = torch.from_numpy(activation_data["sparse_one"][0]).float().unsqueeze(1)

W_dS_ex = sae_one_w @ sparse_vector_ex
recon_ex = W_dS_ex + b_d.unsqueeze(1)

sim1_ex = cos_sim(recon_ex, b_d).mean()
sim2_ex = cos_sim(recon_ex, avg_hidden_vector).mean()
sim3_ex = cos_sim(recon_ex, avg_recon_vector).mean()

print(f"Similiarity b/w reconstructed average and decoder bias vector: {round(sim1_ex.item(), 4)}")
print(f"Similiarity b/w reconstructed average and average hidden activations: {round(sim2_ex.item(), 4)}")
print(f"Similiarity b/w reconstructed average and average recon vector: {round(sim3_ex.item(), 4)}")

Similiarity b/w reconstructed average and decoder bias vector: 0.0926
Similiarity b/w reconstructed average and average hidden activations: 0.1033
Similiarity b/w reconstructed average and average recon vector: 0.0998


Further example that our theory about the symmetry b/w 0 and N being off... the bias vector is not even close to any of the samples provided.