In [1]:
import pickle

import numpy as np
import pandas as pd
import torch
from sklearn.manifold import TSNE
from sklearn.metrics import adjusted_rand_score, silhouette_score, davies_bouldin_score, calinski_harabasz_score
from sklearn.metrics import mutual_info_score
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.models import vit_h_14, ViT_H_14_Weights

from helpers.sae import SparseAutoencoder
from helpers.helpers import extract_activations, SNE_plot_2d, load_intermediate_labels

# --- Checkpoint locations -----------------------------------------------------
# In the visible part of this notebook these three paths are consumed only by
# the disabled (commented-out) model-loading code at the end of this section.
EMBEDS_PATH = './embeds/pos_embed_edge_384_99.56.pth'

# Alternative ViT / SAE checkpoint pairs. Exactly one pair is left active (F1);
# to switch variant, comment the active pair out and uncomment another.
# VIT_PATH = './classifiers/baseline/vit_h_99.56.pth'
# SAE_PATH = './sae_models/baseline-99.56/last_layer/sae_last_layer_l1_0.0002.pth'

# VIT_PATH = './classifiers/F0/best_model_lf_0.01.pth'
# SAE_PATH = './sae_models/F0/sae_last_layer_l1_0.0002.pth'

VIT_PATH = './classifiers/F1/best_model_lf_0.01.pth'
SAE_PATH = './sae_models/F1/sae_last_layer_l1_0.0002.pth'

# VIT_PATH = './classifiers/F2/best_model_lf_0.3.pth'
# SAE_PATH = './sae_models/F2/sae_last_layer_l1_0.0002.pth'

# Side length (pixels) that the Resize transforms in this notebook scale images to.
IMG_RES = 384

# --- Disabled pipeline (kept for reference; none of the lines below execute) ---
# When enabled it would: build a CIFAR-10 test loader, load a ViT-H/14 with the
# saved positional embeddings and a 10-way linear head, load the sparse
# autoencoder, and call `extract_activations` on the test loader.
# FEATURE_DIM = 1280
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# print(f"We will be using device: {device}")

# eval_transform = transforms.Compose([
#     transforms.Resize((IMG_RES, IMG_RES)),
#     transforms.ToTensor(),
#     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
# ])
# test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=eval_transform)
# test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False,
#                 pin_memory=True, num_workers=4)

# print("Loading ViT")
# weights = ViT_H_14_Weights.IMAGENET1K_SWAG_LINEAR_V1
# model = vit_h_14(weights=weights) 
# model.image_size = IMG_RES  # Update the expected image size

# model.encoder.pos_embedding = torch.nn.Parameter(torch.load(EMBEDS_PATH))

# num_ftrs = model.heads.head.in_features
# model.heads.head = torch.nn.Linear(num_ftrs, 10)
# model.load_state_dict(torch.load(VIT_PATH))
# model.to(device)

# print("Loading SAE")
# sae = SparseAutoencoder(input_dim=FEATURE_DIM)
# sae.load_state_dict(torch.load(SAE_PATH))
# sae.to(device)

# print("Extracting activations")
# activation_data = extract_activations(
#     data_loader=test_loader,
#     model=model,
#     sae=sae,
#     device=device
# )
# Seeded generator so the shuffled index order (and therefore the split) is
# the same on every run.
seed = 42
generator = torch.Generator()
generator.manual_seed(seed)


def _build_transform(augment):
    """Return the image preprocessing pipeline.

    Both variants resize to IMG_RES x IMG_RES, convert to a tensor and
    normalize; `augment=True` additionally inserts TrivialAugmentWide right
    after the resize step.
    """
    steps = [transforms.Resize((IMG_RES, IMG_RES))]
    if augment:
        steps.append(transforms.TrivialAugmentWide())
    steps.append(transforms.ToTensor())
    steps.append(
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    )
    return transforms.Compose(steps)


train_transform = _build_transform(augment=True)
eval_transform = _build_transform(augment=False)

# Two views of the same CIFAR-10 training images: one with the augmenting
# transform, one with the plain evaluation transform.
_cifar_train_kwargs = dict(root='./data', train=True, download=True)
full_train_dataset = datasets.CIFAR10(transform=train_transform, **_cifar_train_kwargs)
val_dataset_for_split = datasets.CIFAR10(transform=eval_transform, **_cifar_train_kwargs)

# Shuffle all sample indices once, then keep the first 90% as the training part.
num_train = len(full_train_dataset)
split = int(0.9 * num_train)
indices = torch.randperm(num_train, generator=generator).tolist()

train_subset = torch.utils.data.Subset(full_train_dataset, indices[:split])


# Load previously saved data from "<base>/<N>_<sparse_type>.pkl".
# NOTE(review): the base directory points at the "baseline" classifier while the
# active VIT_PATH / SAE_PATH above select the F1 checkpoints — confirm this
# combination is intended.
RECON_ACT_BASE_PATH = "./features/classifier-99.56/baseline"
N = 25  # first component of the file name; presumably a feature count — TODO confirm
sparse_type = "top"  # second component of the file name; presumably the sparsification mode — TODO confirm
recon_act_path = f"{RECON_ACT_BASE_PATH}/{N}_{sparse_type}.pkl"
# `load_intermediate_labels` is a project helper (helpers.helpers); its exact
# return structure is not visible here — the result is indexed like a sequence
# later in the notebook.
recon_act_raw = load_intermediate_labels(recon_act_path)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
# Collect the class label of every sample in the training split, in subset order.
#
# The labels are read directly from the wrapped dataset's `targets` list using
# the subset's own index list. Iterating `train_subset` instead would load every
# image and push it through the full transform pipeline (resize, random
# augmentation, tensor conversion, normalization) only to throw the image away.
# The resulting list is the same: one integer label per subset index.
labels = [train_subset.dataset.targets[idx] for idx in train_subset.indices]

In [7]:
# Sanity check: display the Python type of `labels` (recorded output: list).
type(labels)

list

In [11]:
# Inspect the shape of the first loaded entry (recorded output: torch.Size([1, 1280])).
recon_act_raw[0].shape

torch.Size([1, 1280])

In [13]:
# Persist the collected labels to disk as a pickle file.
# (`pickle` is imported with the other imports at the top of the notebook.)
output_path = "./labels.pkl"

# pickle writes bytes, so the file is opened in binary write mode; the context
# manager closes the handle even if dumping fails.
with open(output_path, 'wb') as f:
    pickle.dump(labels, f)