# Load Model For Genception

## SAE Params

In [1]:
HIDDEN_SIZE = 256
L1_PENALTY = 0.01

## Imports

In [2]:
import random
import copy
import pickle

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression


import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

from tqdm import tqdm

In [3]:
from models_and_data.nn import NeuralNetwork
from models_and_data.sae import SparseAutoencoder
from models_and_data.edgedataset import EdgeDataset

from models_and_data.model_helpers import (evaluate_and_gather_activations, get_sublabel_data, 
    get_top_N_features, extract_activations, load_intermediate_labels, seed_worker)

## Set Device to GPU

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"We will be using device: {device}")

We will be using device: cuda


## Load Data

In [5]:
# train data
train_images = load_intermediate_labels("./intermediate-labels/first_layer/train_images.pkl")
train_labels = load_intermediate_labels("./intermediate-labels/first_layer/train_labels.pkl")

# val data
val_images = load_intermediate_labels("./intermediate-labels/first_layer/val_images.pkl")
val_labels = load_intermediate_labels("./intermediate-labels/first_layer/val_labels.pkl")

# test data
test_images = load_intermediate_labels("./intermediate-labels/first_layer/test_images.pkl")
test_labels = load_intermediate_labels("./intermediate-labels/first_layer/test_labels.pkl")

# intermediate labels
N = 25
sparse_type = "top"  # mask or top

# Genception

## Model Result Replication

In [6]:
seed = 42
generator = torch.Generator().manual_seed(seed)

NUM_WORKERS = 4
if device.type.lower() == "cpu":
    NUM_WORKERS = 0

# training data
train_dataset = EdgeDataset(train_images, train_labels)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=NUM_WORKERS,
                          worker_init_fn=seed_worker, generator=generator, pin_memory=True)

# validation data
val_dataset = EdgeDataset(val_images, val_labels)
val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)  # larger batch size for faster validation

# test data
test_dataset = EdgeDataset(test_images, test_labels)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)

In [7]:
model = NeuralNetwork().to(device)
sae_hidden_one = SparseAutoencoder(input_size=16, hidden_size=HIDDEN_SIZE).to(device)
sae_hidden_two = SparseAutoencoder(input_size=16, hidden_size=HIDDEN_SIZE).to(device)

In [8]:
best_model_path = "./intermediate-labels-new/first_layer_results/F0/models/256_top/best_model_lf_0.23.pth"
# best_model_path = "./intermediate-labels-new/first_layer_results/F0/models/25_top/best_model_lf_0.02.pth"

# best_model_path = "./intermediate-labels-new/first_layer_results/classifier_results.pth"

checkpoint = torch.load(best_model_path)
model.load_state_dict(checkpoint['model_state_dict'])
sae_hidden_one.load_state_dict(checkpoint['sae_one_state_dict'])
sae_hidden_two.load_state_dict(checkpoint['sae_two_state_dict'])

<All keys matched successfully>

In [9]:
train_results = evaluate_and_gather_activations(model, sae_hidden_one, sae_hidden_two, train_loader, device)
Z_train_one, Z_train_two, y_train = train_results["Z_one"], train_results["Z_two"], train_results["y"]

test_results = evaluate_and_gather_activations(model, sae_hidden_one, sae_hidden_two, test_loader, device)
Z_test_one, Z_test_two, y_test = test_results["Z_one"], test_results["Z_two"], test_results["y"]

print(f"Model acc: {test_results['accuracy']}")

                                                                                                                        

Model acc: 93.55




In [49]:
sparsity_one = np.mean(Z_test_one > 1e-5) * Z_test_one.shape[1]
sparsity_two = np.mean(Z_test_two > 1e-5) * Z_test_two.shape[1]
print(f"Average Non-Zero Features per Image (Hidden One): {sparsity_one:.2f}")
print(f"Average Non-Zero Features per Image (Hidden Two): {sparsity_two:.2f}")

Average Non-Zero Features per Image (Hidden One): 137.46
Average Non-Zero Features per Image (Hidden Two): 140.25


In [10]:
print("\n--- Training Linear Probes ---")
clf_one = LogisticRegression(penalty='l2', max_iter=1000, n_jobs=-1)
clf_one.fit(Z_train_one, y_train)
acc_one = clf_one.score(Z_test_one, y_test)
print(f"Linear Probe Accuracy (Hidden One): {acc_one:.2%}")

clf_two = LogisticRegression(penalty='l2', max_iter=1000, n_jobs=-1)
clf_two.fit(Z_train_two, y_train)
acc_two = clf_two.score(Z_test_two, y_test)
print(f"Linear Probe Accuracy (Hidden Two): {acc_two:.2%}")


--- Training Linear Probes ---


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


Linear Probe Accuracy (Hidden One): 93.82%
Linear Probe Accuracy (Hidden Two): 93.76%


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


## Target Reconstruction Generation

In [11]:
activation_data = extract_activations(
    data_loader=train_loader,
    model=model,
    sae_one=sae_hidden_one,
    sae_two=sae_hidden_two,
    device=device
)

Extracting Activations: 100%|████████████████████████████████████████████████████████| 782/782 [00:01<00:00, 625.95it/s]


In [12]:
sparse_vector_sizes = [25, 256]
for N_recon in sparse_vector_sizes:
    labels = activation_data["labels"]
    sparse_act_one = activation_data["sparse_one"]
    avg_digit_encoding, top_n_features = get_top_N_features(N_recon, sparse_act_one, labels)
    
    feature_indices_dict = {}
    for digit in range(0, 10):
        feature_indices_dict[digit] = top_n_features[digit]['indices']
    
    print("Features used:")
    print(len(feature_indices_dict[0]))
    
    recon_max_sparse_training, recon_max_sparse_ablated_training = get_sublabel_data(train_labels,
                                                                                     train_images,
                                                                                     feature_indices_dict,
                                                                                     sparse_act_one,
                                                                                     sae_hidden_one,
                                                                                     device,
                                                                                     HIDDEN_SIZE
                                                                                    )
    
    print("Size of datasets:")
    print(len(train_images), len(val_images), len(test_images), len(recon_max_sparse_training))
    
    file_path = f"./{N_recon}_top.pkl"
    with open(file_path, "wb") as f:
        pickle.dump(recon_max_sparse_training, f)
    
    file_path = f"./{N_recon}_mask.pkl"
    with open(file_path, "wb") as f:
        pickle.dump(recon_max_sparse_ablated_training, f)

Features used:
25
Size of datasets:
50000 10000 10000 50000
Features used:
256
Size of datasets:
50000 10000 10000 50000


# Some Geometry

In [44]:
tensor1 = torch.from_numpy(np.mean(activation_data['hidden_one'], axis=0)).float()
tensor2 = sae_hidden_one.decoder.bias.cpu()

cos_sim = nn.CosineSimilarity(dim=0)
t1t2_sim = cos_sim(tensor1, tensor2).mean()

tensor3 = torch.from_numpy(np.mean(activation_data['recon_one'], axis=0)).float()
t1t3_sim = cos_sim(tensor1, tensor3).mean()

print(f"Similarity b/w hidden activations and decoder bias vector: {round(t1t2_sim.item(), 4)}")
print(f"Similarity b/w hidden activations and reconstructed activations: {round(t1t3_sim.item(), 4)}")

Similarity b/w hidden activations and decoder bias vector: -0.2551
Similarity b/w hidden activations and reconstructed activations: 1.0


In [45]:
sae_one_w = sae_hidden_one.decoder.weight.cpu()
sae_one_w.shape

torch.Size([16, 256])