# Load Model For Genception

## SAE Params

In [1]:
HIDDEN_SIZE: int = 256  # SAE latent width; passed as `hidden_size=` when the SAEs are constructed

## Imports

In [2]:
import pickle

import numpy as np
from sklearn.linear_model import LogisticRegression


import torch
from torch import nn 
from torch.utils.data import DataLoader

from models_and_data.nn import NeuralNetwork
from models_and_data.sae import SparseAutoencoder
from models_and_data.edgedataset import EdgeDataset

from models_and_data.model_helpers import (evaluate_and_gather_activations, get_sublabel_data, 
    get_top_N_features, extract_activations, load_intermediate_labels, seed_worker)

## Set Device to GPU

In [3]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"We will be using device: {device}")

We will be using device: cuda


## Load Data

In [4]:
# train data
# Intermediate first-layer images/labels for the train and test splits, loaded in this order.
# NOTE(review): these are `.pkl` files, presumably unpickled by `load_intermediate_labels`;
# unpickling can execute code, so only load files you produced/trust.
_intermediate_label_files = (
    # train data
    "./intermediate-labels/first_layer/train_images.pkl",
    "./intermediate-labels/first_layer/train_labels.pkl",
    # test data
    "./intermediate-labels/first_layer/test_images.pkl",
    "./intermediate-labels/first_layer/test_labels.pkl",
)
train_images, train_labels, test_images, test_labels = map(
    load_intermediate_labels, _intermediate_label_files
)

# Model Loading

## Model Result Replication

In [5]:
# Seeded generator: fixes the shuffle order of `train_loader` so runs are reproducible.
seed = 42
generator = torch.Generator().manual_seed(seed)

# Worker processes are used only off-CPU; on CPU, batches are loaded in the main process.
NUM_WORKERS = 0 if device.type == "cpu" else 4
# Pinned (page-locked) host memory only speeds up host->GPU copies. On a CPU-only run it is
# wasted work and recent PyTorch versions warn about it, so enable it only for CUDA.
PIN_MEMORY = device.type == "cuda"

# training data (shuffled; `seed_worker` is the project helper installed as each worker's init hook)
train_dataset = EdgeDataset(train_images, train_labels)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=NUM_WORKERS,
                          worker_init_fn=seed_worker, generator=generator, pin_memory=PIN_MEMORY)

# test data (fixed order)
test_dataset = EdgeDataset(test_images, test_labels)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=NUM_WORKERS,
                         pin_memory=PIN_MEMORY)

In [6]:
# Fresh architectures for both runs; their trained weights are restored from checkpoints in the next cell.
model_baseline, model_F0 = (NeuralNetwork().to(device) for _ in range(2))

# One first-hidden-layer SAE per run: 16-dim inputs, HIDDEN_SIZE latents.
sae_hidden_one_baseline, sae_hidden_one_F0 = (
    SparseAutoencoder(input_size=16, hidden_size=HIDDEN_SIZE).to(device) for _ in range(2)
)

In [7]:
# Restore trained weights into the freshly built models/SAEs.
# `map_location=device` lets checkpoints saved from GPU tensors load on a CPU-only machine
# (this notebook falls back to CPU when CUDA is unavailable). `torch.load` unpickles the file,
# so only load checkpoints you trust.
best_model_path = "./SAE-Results/256-0.75/results/baseline/model_state_dict.pth"
checkpoint = torch.load(best_model_path, map_location=device)
model_baseline.load_state_dict(checkpoint['model_state_dict'])
sae_hidden_one_baseline.load_state_dict(checkpoint['sae_one_state_dict'])

# NOTE(review): the `*_F0` objects are loaded from a checkpoint under `results/F1/...`;
# confirm this generation naming is intended.
best_model_path = "./SAE-Results/256-0.75/results/F1/models/256_mask_0.29/256_mask/best_model_lf_0.02.pth"
checkpoint = torch.load(best_model_path, map_location=device)
model_F0.load_state_dict(checkpoint['model_state_dict'])
sae_hidden_one_F0.load_state_dict(checkpoint['sae_one_state_dict'])

<All keys matched successfully>

# Some Geometry

In [8]:
# Run the baseline model + SAE over the training set and collect its activations.
# Only a first-hidden-layer SAE is loaded in this notebook, so it is supplied for both SAE slots.
activation_data_baseline = extract_activations(
    model=model_baseline,
    sae_one=sae_hidden_one_baseline,
    sae_two=sae_hidden_one_baseline,
    data_loader=train_loader,
    device=device,
)

Extracting Activations: 100%|████████████████████████████████████████| 782/782 [00:01<00:00, 601.62it/s]


In [9]:
# Same extraction for the F0 model + SAE. The result gets its own name; assigning it to
# `activation_data_baseline` would silently overwrite the baseline activations.
# NOTE(review): `train_loader` shuffles with a seeded generator that has already advanced during
# any earlier pass, so sample order differs between extractions. Compare aggregates, not rows.
activation_data_F0 = extract_activations(
    data_loader=train_loader,
    model=model_F0,
    sae_one=sae_hidden_one_F0,
    sae_two=sae_hidden_one_F0,
    device=device
)

Extracting Activations: 100%|████████████████████████████████████████| 782/782 [00:01<00:00, 638.85it/s]


## Comparing Bias w/Recon/H1

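Here we compare the decoder bias vectors of the baseline and F0 sparse autoencoders using cosine similarity. The comparison of the average decoded vector with the average hidden vector and the decoder bias is in *Exploring Weight Matrix* below.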

In [10]:
# Decoder bias vectors of the two SAEs. They are detached so the analysis below does not build
# autograd graphs on the trained parameters, and moved to the CPU.
b_d_baseline = sae_hidden_one_baseline.decoder.bias.detach().cpu()
b_d_F0 = sae_hidden_one_F0.decoder.bias.detach().cpu()

# The biases are 1-D vectors, so cosine similarity along dim 0 yields a single (0-d) value.
cos_sim = nn.CosineSimilarity(dim=0)
t1t2_sim = cos_sim(b_d_baseline, b_d_F0)

print(f"Similarity b/w decoder bias vector baseline and decoder bias vector F0: {t1t2_sim.item():.4f}")

Similarity b/w decoder bias vector baseline and decoder bias vector F0: 0.9333


## Comparing Model with Targets

In [11]:
# Pickled F0 feature activations; averaged into `target_vector` in the next cell.
recon_act_path = (
    "./SAE-Results/256-0.75/features/F0/256_mask.pkl"
)
recon_max_sparse_act_one = load_intermediate_labels(
    recon_act_path
)

In [12]:
# Stack the per-sample activations, drop singleton axes, average over samples, cast to float32.
target_vector = (
    torch.from_numpy(np.squeeze(np.array(recon_max_sparse_act_one)).mean(axis=0))
    .float()
)

In [13]:
torch.mean(cos_sim(target_vector, b_d_baseline))  # mean F0 target vs. baseline decoder bias

tensor(1.0000, grad_fn=<MeanBackward0>)

In [14]:
torch.mean(cos_sim(target_vector, b_d_F0))  # mean F0 target vs. F0-model decoder bias

tensor(0.9333, grad_fn=<MeanBackward0>)

In [20]:
def _mean_activation_vector(activations):
    """Average a stack of per-sample activations over the sample axis.

    Args:
        activations: per-sample activations as loaded from a feature pickle; they are stacked
            with `np.array` and singleton axes are squeezed away before averaging.

    Returns:
        The mean over axis 0 as a float32 tensor.
    """
    stacked = np.squeeze(np.array(activations))
    return torch.from_numpy(np.mean(stacked, axis=0)).float()


# Mean feature vector from the F0 feature file.
recon_act_path = "./SAE-Results/256-0.75/features/F0/256_mask.pkl"
recon_max_sparse_act_one = load_intermediate_labels(recon_act_path)
target_vector_f0 = _mean_activation_vector(recon_max_sparse_act_one)

# NOTE(review): this file lives under `features/F2/...` but is stored as `target_vector_f1`;
# confirm the generation naming.
recon_act_path_two = "./SAE-Results/256-0.75/features/F2/256_mask_0.29_256_mask_0.02/256_mask.pkl"
recon_max_sparse_act_one_two = load_intermediate_labels(recon_act_path_two)
target_vector_f1 = _mean_activation_vector(recon_max_sparse_act_one_two)

In [21]:
torch.mean(cos_sim(target_vector_f0, target_vector_f1))  # agreement between the two mean feature vectors

tensor(0.9332)

## Exploring Weight Matrix

### Decoding the Mean Sparse Vector

In [None]:
# Decode the mean sparse code of the baseline SAE, recon_avg = W_d @ mean(s) + b_d, and compare it
# with the decoder bias and with the average hidden / reconstructed activations.
# NOTE(review): this cell previously referenced `sae_hidden_one`, `activation_data`, `b_d`,
# `avg_hidden_vector` and `avg_recon_vector`, none of which any cell defines, so it could not run
# on a fresh kernel. It now uses the baseline SAE and baseline activations.
# Weight is assumed to be (input_size, hidden_size), as for an nn.Linear decoder.
sae_one_w = sae_hidden_one_baseline.decoder.weight.detach().cpu()
b_d = sae_hidden_one_baseline.decoder.bias.detach().cpu()

# Keep every vector 1-D. A (d, 1) column vector compared with a (d,) vector by `cos_sim` (dim=0)
# broadcasts to a (d, d) grid of unrelated similarities, and its `.mean()` is not the intended value.
sparse_vector_avg = torch.from_numpy(np.mean(activation_data_baseline["sparse_one"], axis=0)).float()
W_dS_avg = sae_one_w @ sparse_vector_avg
recon_avg = W_dS_avg + b_d

# TODO confirm: these keys are assumed by analogy with "sparse_one" (the only key this notebook
# confirms); check the dict returned by `extract_activations`.
HIDDEN_ACT_KEY = "hidden_one"
RECON_ACT_KEY = "recon_one"
avg_hidden_vector = torch.from_numpy(np.mean(activation_data_baseline[HIDDEN_ACT_KEY], axis=0)).float()
avg_recon_vector = torch.from_numpy(np.mean(activation_data_baseline[RECON_ACT_KEY], axis=0)).float()

sim1 = cos_sim(recon_avg, b_d)
sim2 = cos_sim(recon_avg, avg_hidden_vector)
sim3 = cos_sim(recon_avg, avg_recon_vector)

print(f"Similarity b/w reconstructed average and decoder bias vector: {round(sim1.item(), 4)}")
print(f"Similarity b/w reconstructed average and average hidden activations: {round(sim2.item(), 4)}")
print(f"Similarity b/w reconstructed average and average recon vector: {round(sim3.item(), 4)}")

### Decoding a Single Sparse Vector

In [None]:
# Decode a single sparse code (the first extracted baseline sample) instead of the mean.
# NOTE(review): `train_loader` shuffles, so "first" is whichever sample came first in that pass.
# NOTE(review): `avg_hidden_vector` and `avg_recon_vector` are not built here; they must already be defined.
sae_one_w = sae_hidden_one_baseline.decoder.weight.detach().cpu()
b_d = sae_hidden_one_baseline.decoder.bias.detach().cpu()

# 1-D throughout, so `cos_sim` (dim=0) returns one similarity rather than a broadcast (d, d) grid.
sparse_vector_ex = torch.from_numpy(activation_data_baseline["sparse_one"][0]).float()
W_dS_ex = sae_one_w @ sparse_vector_ex
recon_ex = W_dS_ex + b_d

sim1_ex = cos_sim(recon_ex, b_d)
sim2_ex = cos_sim(recon_ex, avg_hidden_vector)
sim3_ex = cos_sim(recon_ex, avg_recon_vector)

print(f"Similarity b/w reconstructed example and decoder bias vector: {round(sim1_ex.item(), 4)}")
print(f"Similarity b/w reconstructed example and average hidden activations: {round(sim2_ex.item(), 4)}")
print(f"Similarity b/w reconstructed example and average recon vector: {round(sim3_ex.item(), 4)}")