In [1]:
import os
import numpy as np
from tqdm import tqdm

import torch
from torchvision import transforms as T
from torchvision.datasets import ImageFolder
from models.resnet_imagenet import *
from models.vgg_imagenet import *
from zca_conv0 import get_conv0_weights
import random

import utils

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

#### Seed everything
# Seed every RNG source this notebook imports: Python's `random`, NumPy's global
# RNG, and torch (CPU generator plus all CUDA devices; the CUDA call is silently
# ignored when CUDA is unavailable).
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

#### Variables
mode = "iid" # iid, non_iid (selects which index files are loaded)
extract_mode = "gap" # flatten, channel, gap
network_name = "vgg11" # resnet18 , vgg11
num_classes = 50
save_dir = './MMD_values'
ZCA_conv0 = True
addgray = True
lmscone = True
# ImageNet root (must contain a 'val' folder); override with the IMAGENET_DIR
# environment variable instead of editing this hard-coded default.
data_path = os.environ.get('IMAGENET_DIR', '/data/datasets/ImageNet2012/')
indices_path = f'./indices/{num_classes}_classes'

# conv0 emits one extra output channel when a gray channel is added.
conv0_outchannels = 4 if addgray else 3

#### Create save folder
# The leaf folder name encodes the run configuration, e.g. with the settings
# above: ./MMD_values/50_classes/vgg11_randinit_lmscone_ZCAconv0_addgray
# ('_addgray' is only appended when ZCA_conv0 is enabled).
run_tag = f'{network_name}_randinit'
if lmscone:
    run_tag += '_lmscone'
if ZCA_conv0:
    run_tag += '_ZCAconv0'
    if addgray:
        run_tag += '_addgray'
save_dir = os.path.join(save_dir, f'{num_classes}_classes', run_tag)
# exist_ok avoids the check-then-create race of exists() + makedirs().
os.makedirs(save_dir, exist_ok=True)

#### Load network
# Resolve the model constructor by name from the notebook namespace (populated by
# the models.* star imports) with a plain lookup instead of eval() on a config string.
try:
    model_constructor = globals()[network_name]
except KeyError:
    raise ValueError(f"Unknown network_name {network_name!r}: no such constructor is in scope") from None

model_kwargs = dict(num_classes=num_classes)
if ZCA_conv0:
    # The ZCA variant adds a leading conv0 layer with conv0_outchannels outputs.
    model_kwargs.update(conv0_flag=True, conv0_outchannels=conv0_outchannels)
model = model_constructor(**model_kwargs)

#### Load complete dataset and transform function
# Per-channel mean/std computed over the whole of ImageNet, in LMS-cone space
# when lmscone is set, otherwise in RGB space.
if lmscone:
    mean, std = [0.5910, 0.5758, 0.5298], [0.2657, 0.2710, 0.2806]
else:
    mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
if ZCA_conv0:
    # With ZCA the inputs are only mean-centred here (unit std).
    std = [1.0, 1.0, 1.0]

# Resize -> crop -> tensor, optionally convert to LMS cones, then normalize.
transform_steps = [T.Resize(256), T.CenterCrop(224), T.ToTensor()]
if lmscone:
    transform_steps.append(utils.ConeTransform())
transform_steps.append(T.Normalize(mean=mean, std=std))
data_transform = T.Compose(transform_steps)

# ImageNet validation split, read from <data_path>/val.
dataset = ImageFolder(root=os.path.join(data_path,'val'), transform=data_transform)

#### Load task datasets and dataloaders
# Index files live at "<indices_path>/<prefix>_indices_<A|B>_<train|val>.npy".
index_file_prefixes = {"iid": "IID", "non_iid": "Non_IID"}
if mode not in index_file_prefixes:
    # Fail here with a clear message instead of a NameError on indices_A_val below.
    raise ValueError(f"Unknown mode {mode!r}; expected one of {list(index_file_prefixes)}")
index_prefix = index_file_prefixes[mode]


def _load_task_indices(task, split):
    """Load the saved sample indices for `task` ('A' or 'B') and `split` ('train' or 'val')."""
    return np.load(f'{indices_path}/{index_prefix}_indices_{task}_{split}.npy')


indices_A_train = _load_task_indices('A', 'train')
indices_B_train = _load_task_indices('B', 'train')
indices_A_val = _load_task_indices('A', 'val')
indices_B_val = _load_task_indices('B', 'val')

# get datasets A and B with subset and create loaders.
# `dataset` is built from the 'val' folder, so only the *_val indices index into it here.
loader_batch_size = 256
data_A = torch.utils.data.Subset(dataset, indices_A_val)
data_B = torch.utils.data.Subset(dataset, indices_B_val)
# shuffle=False keeps a fixed sample order across runs.
loader_A = torch.utils.data.DataLoader(data_A, batch_size=loader_batch_size, shuffle=False)
loader_B = torch.utils.data.DataLoader(data_B, batch_size=loader_batch_size, shuffle=False)


In [3]:
import pandas as pd

def create_dataframe_for_subset(subset, dataset, subset_name):
    """Tabulate which samples, and which classes, a dataset subset contains.

    Args:
        subset: a ``torch.utils.data.Subset``-like object exposing ``.indices``
            (positions into ``dataset``).
        dataset: the full dataset; must expose ``.classes`` (class names) and
            ``.targets`` (class index per sample), as ``ImageFolder`` does.
        subset_name: label written into every row of the ``Subset`` column.

    Returns:
        pd.DataFrame with columns ``Index``, ``Class Name`` and ``Subset``, one row
        per entry of ``subset.indices``, in the same order.
    """
    indices = list(subset.indices)
    class_names = [dataset.classes[dataset.targets[idx]] for idx in indices]
    return pd.DataFrame({'Index': indices, 'Class Name': class_names}).assign(Subset=subset_name)

# Tabulate the class membership of the two task subsets: data_A -> 'Subset 1',
# data_B -> 'Subset 2', both indexing into the ImageNet val `dataset`.
df_subset1 = create_dataframe_for_subset(data_A, dataset, 'Subset 1')
df_subset2 = create_dataframe_for_subset(data_B, dataset, 'Subset 2')

In [4]:
# Map qualified module name -> module for every Conv2d / Linear / CosineLinear
# layer in the model, skipping any module whose name contains "downsample"
# (presumably ResNet shortcut projections).
tracked_layer_types = (torch.nn.Conv2d, torch.nn.Linear, CosineLinear)
layers = {
    layer_name: layer
    for layer_name, layer in model.named_modules()
    if isinstance(layer, tracked_layer_types) and "downsample" not in layer_name
}

In [6]:
# Print each collected layer name, one per line, in the order it was collected.
for layer_name in layers.keys():
    print(layer_name)

conv0
features.0
features.2
features.5
features.7
features.10
features.12
features.15
features.17
features.20
features.22
classifier
