In [1]:
%load_ext autoreload
%autoreload 2

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, TensorDataset, ConcatDataset
import numpy as np
import matplotlib.pyplot as plt
import yaml
from tqdm import tqdm
import torchvision.transforms.v2 as v2
from copy import deepcopy
import sys
sys.path.append("/n/home11/sambt/phlab-neurips25")

device = 'cuda' if torch.cuda.is_available() else 'cpu'

from models.litmodels import SimCLRModel
from models.networks import CustomResNet, MLP
from models.losses import MMDLoss, RBF
from data.datasets import CIFAR10Dataset
from data.cifar import CIFAR5MDataset
import data.data_utils as dutils

from sklearn.metrics import roc_auc_score, top_k_accuracy_score
from utils.plotting import make_corner
import os

  from pandas.core.computation.check import NUMEXPR_INSTALLED


In [3]:
# ---- Configuration -----------------------------------------------------------
N_cifar5m_load = 100_000  # number of CIFAR-5M images to embed
SEED = 0  # seeds the DataLoader shuffles so the saved sample order is reproducible

# ---- Datasets & loaders ------------------------------------------------------
cifar = CIFAR10Dataset("resnet50",num_workers=2,batch_size=1024,exclude_classes=[])
cifar_train_dataset = cifar.train_dataset
cifar_test_dataset = cifar.test_dataset
# NOTE(review): presumably file index 0 sliced to its first N_cifar5m_load images — confirm against CIFAR5MDataset.
cifar5m_full = CIFAR5MDataset("resnet50",[0],[(None,N_cifar5m_load)],grayscale=False,exclude_classes=[])

# Shuffling only changes the order in which embeddings are written (labels are saved
# alongside them); a seeded generator per loader keeps that order reproducible.
cifar_train_loader = DataLoader(cifar_train_dataset,batch_size=512,shuffle=True,
                                generator=torch.Generator().manual_seed(SEED))
cifar_test_loader = DataLoader(cifar_test_dataset,batch_size=512,shuffle=True,
                               generator=torch.Generator().manual_seed(SEED))
cifar5m_loader = DataLoader(cifar5m_full,batch_size=512,shuffle=True,
                            generator=torch.Generator().manual_seed(SEED))

# ---- Models ------------------------------------------------------------------
# One SimCLR checkpoint per excluded CIFAR-10 class (key = excluded class index).
# Earlier runs used the cifar10_simCLR_ResNet50_T0.1_dim4_exclude{0..9} checkpoints
# (without a classifier head); these are the "withClassifier" variants for classes 1-6.
checkpoints = {
    1:"/n/home11/sambt/phlab-neurips25/runs/cifar10_simCLR_ResNet50_T0.1_withClassifier_dim4_exclude1/lightning_logs/a9v4e8zh/checkpoints/epoch=4-step=440.ckpt",
    2:"/n/home11/sambt/phlab-neurips25/runs/cifar10_simCLR_ResNet50_T0.1_withClassifier_dim4_exclude2/lightning_logs/rlzxkflo/checkpoints/epoch=5-step=528.ckpt",
    3:"/n/home11/sambt/phlab-neurips25/runs/cifar10_simCLR_ResNet50_T0.1_withClassifier_dim4_exclude3/lightning_logs/wi6idr9q/checkpoints/epoch=5-step=528.ckpt",
    4:"/n/home11/sambt/phlab-neurips25/runs/cifar10_simCLR_ResNet50_T0.1_withClassifier_dim4_exclude4/lightning_logs/nfmxrtyd/checkpoints/epoch=3-step=352.ckpt",
    5:"/n/home11/sambt/phlab-neurips25/runs/cifar10_simCLR_ResNet50_T0.1_withClassifier_dim4_exclude5/lightning_logs/lhxchpq6/checkpoints/epoch=6-step=616.ckpt",
    6:"/n/home11/sambt/phlab-neurips25/runs/cifar10_simCLR_ResNet50_T0.1_withClassifier_dim4_exclude6/lightning_logs/pvt7snek/checkpoints/epoch=2-step=264.ckpt"
}
classes = sorted(checkpoints)

# load_from_checkpoint does not switch the module to eval mode (Lightning's docs say to
# call .eval() after loading). Without it, BatchNorm layers (the encoder is a ResNet-50
# per the checkpoint names) normalise with per-batch statistics and update their running
# stats, which would make the saved embeddings depend on batch composition.
models = {l: SimCLRModel.load_from_checkpoint(checkpoints[l]).to(device).eval() for l in classes}

# One output directory per excluded class.
outdirs = {}
for label in classes:
    outdir = f"embeddings_noShift_spaceWithClassifier/cifar_excludeClass{label}/"
    os.makedirs(outdir,exist_ok=True)
    outdirs[label] = outdir

/n/holystore01/LABS/iaifi_lab/Users/sambt/mamba/envs/torch_gpu/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:209: Attribute 'encoder' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['encoder'])`.
/n/holystore01/LABS/iaifi_lab/Users/sambt/mamba/envs/torch_gpu/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:209: Attribute 'projector' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['projector'])`.
/n/holystore01/LABS/iaifi_lab/Users/sambt/mamba/envs/torch_gpu/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:209: Attribute 'classifier' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['classifier'])`.


In [4]:
def embed_and_save(loader, filename):
    """Embed every batch of `loader` with each excluded-class encoder and save one .npz per model.

    For each label ``l`` in ``classes`` this writes ``os.path.join(outdirs[l], filename)``
    containing ``data`` (the concatenated ``models[l].encoder`` outputs) and ``labels``
    (the loader's labels, in the same order as ``data``).

    Relies on the notebook globals ``classes``, ``models``, ``outdirs`` and ``device``.
    """
    for l in classes:
        models[l].eval()  # never embed with BatchNorm/dropout in training mode
    embeds = {l: [] for l in classes}
    batch_labels = []  # labels are identical for every model, so collect them once
    with torch.no_grad():
        for x, labs in tqdm(loader):
            x = x.to(device)  # one host->device copy per batch, shared by all models
            for l in classes:
                embeds[l].append(models[l].encoder(x).cpu().numpy())
            batch_labels.append(labs.numpy())
    labels = np.concatenate(batch_labels)
    for l in classes:
        np.savez(os.path.join(outdirs[l], filename),
                 data=np.concatenate(embeds[l]),
                 labels=labels)


for _loader, _filename in [
    (cifar_train_loader, "cifar10_train.npz"),
    (cifar_test_loader, "cifar10_test.npz"),
    (cifar5m_loader, f"cifar5m_N{N_cifar5m_load}.npz"),
]:
    embed_and_save(_loader, _filename)
    torch.cuda.empty_cache()  # release cached GPU blocks before the next dataset

100%|██████████| 98/98 [03:37<00:00,  2.22s/it]
100%|██████████| 20/20 [00:43<00:00,  2.17s/it]
100%|██████████| 196/196 [06:59<00:00,  2.14s/it]


# Evaluation on CIFAR-10.1 (v6): embed the images with each excluded-class model

In [5]:
# Reload the excluded-class models so this CIFAR-10.1 section can run on its own,
# without re-running the (slow) CIFAR-10 / CIFAR-5M extraction above.
# Key = excluded CIFAR-10 class index; these are the "withClassifier" checkpoints.
checkpoints = {
    1:"/n/home11/sambt/phlab-neurips25/runs/cifar10_simCLR_ResNet50_T0.1_withClassifier_dim4_exclude1/lightning_logs/a9v4e8zh/checkpoints/epoch=4-step=440.ckpt",
    2:"/n/home11/sambt/phlab-neurips25/runs/cifar10_simCLR_ResNet50_T0.1_withClassifier_dim4_exclude2/lightning_logs/rlzxkflo/checkpoints/epoch=5-step=528.ckpt",
    3:"/n/home11/sambt/phlab-neurips25/runs/cifar10_simCLR_ResNet50_T0.1_withClassifier_dim4_exclude3/lightning_logs/wi6idr9q/checkpoints/epoch=5-step=528.ckpt",
    4:"/n/home11/sambt/phlab-neurips25/runs/cifar10_simCLR_ResNet50_T0.1_withClassifier_dim4_exclude4/lightning_logs/nfmxrtyd/checkpoints/epoch=3-step=352.ckpt",
    5:"/n/home11/sambt/phlab-neurips25/runs/cifar10_simCLR_ResNet50_T0.1_withClassifier_dim4_exclude5/lightning_logs/lhxchpq6/checkpoints/epoch=6-step=616.ckpt",
    6:"/n/home11/sambt/phlab-neurips25/runs/cifar10_simCLR_ResNet50_T0.1_withClassifier_dim4_exclude6/lightning_logs/pvt7snek/checkpoints/epoch=2-step=264.ckpt"
}
classes = sorted(checkpoints)

# load_from_checkpoint leaves the module in training mode; .eval() makes BatchNorm use
# its running statistics so embeddings do not depend on batch composition.
models = {l: SimCLRModel.load_from_checkpoint(checkpoints[l]).to(device).eval() for l in classes}

# One output directory per excluded class (same layout as the CIFAR-10 extraction).
outdirs = {}
for label in classes:
    outdir = f"embeddings_noShift_spaceWithClassifier/cifar_excludeClass{label}/"
    os.makedirs(outdir,exist_ok=True)
    outdirs[label] = outdir

In [6]:
CIFAR10P1_SEED = 0  # fixes the shuffle so the order saved to cifar10.1.npz is reproducible

# transpose(0,3,1,2) moves the last axis to position 1 (NHWC -> NCHW) before the transform.
cifar10p1_images_np = np.load("cifar10.1_data/cifar10.1_v6_data.npy").transpose(0,3,1,2)
cifar10p1_labels_np = np.load("cifar10.1_data/cifar10.1_v6_labels.npy")
assert len(cifar10p1_images_np) == len(cifar10p1_labels_np), "CIFAR-10.1 images/labels length mismatch"

# Shuffle images and labels with the same permutation (kept aligned).
shuf = np.random.default_rng(CIFAR10P1_SEED).permutation(len(cifar10p1_labels_np))
cifar10p1_data = torch.tensor(cifar10p1_images_np[shuf])
cifar10p1_labels = torch.tensor(cifar10p1_labels_np[shuf])

transform = dutils.ResNet50Transform(resnet_type='resnet50',grayscale=False,from_pil=False)
dataset = dutils.TransformDataset(transform,cifar10p1_data,cifar10p1_labels)
loader = DataLoader(dataset,batch_size=512,shuffle=False)  # order already fixed above

In [7]:
# Embed the (shuffled) CIFAR-10.1 images with every excluded-class encoder and save
# one cifar10.1.npz per model, holding `data` (embeddings) and `labels` in the same order.
for l in classes:
    models[l].eval()  # never embed with BatchNorm/dropout in training mode
embeds = {l: [] for l in classes}
batch_labels = []  # labels are identical for every model, so collect them once
with torch.no_grad():
    for x, labs in tqdm(loader):
        x = x.to(device)  # one host->device copy per batch, shared by all models
        for l in classes:
            embeds[l].append(models[l].encoder(x).cpu().numpy())
        batch_labels.append(labs.numpy())
labels = np.concatenate(batch_labels)
for l in classes:
    np.savez(os.path.join(outdirs[l], "cifar10.1.npz"),
             data=np.concatenate(embeds[l]),
             labels=labels)
del embeds, labels, batch_labels
torch.cuda.empty_cache()

100%|██████████| 4/4 [00:08<00:00,  2.16s/it]
