### Activation Maximization

In [None]:
# These two lines ensure that we always import the latest version of a package, in case it has been modified.
%load_ext autoreload
%autoreload 2

import timm
import torch
import detectors
from torchvision import transforms
from utils import vis as vis

In [2]:
import os
from tqdm import tqdm

from data_utils.data_stats import *
from models.networks import get_model
from data_utils.dataloader import get_loader
from data_utils.dataset_to_beton import get_dataset

In [3]:
# Experiment configuration shared by the cells below.

# Dataset and storage locations.
dataset     = 'cifar10'               # One of cifar10, cifar100, stl10, imagenet or imagenet21
num_classes = CLASS_DICT[dataset]
data_path   = '/scratch/ffcv'
model_path  = '/tmp/zooming_in_on_mlps/'

# Evaluation / data-loading settings.
eval_batch_size = 32
data_resolution = 32
device          = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Model selection.
model_type   = 'mlp'
architecture = 'B_12-Wi_1024'
checkpoint   = 'in21k_cifar10'

# NOTE: when the cells are run in order, this is the value `get_models_full`
# captures as its default `resolution` at definition time.
crop_resolution = 64

In [4]:
def get_models_full(model_type, 
                    architecture, 
                    resolution  = crop_resolution, 
                    num_classes = CLASS_DICT[dataset], 
                    checkpoint  = checkpoint, 
                    model_path   = model_path):
    """Build the model to visualise for the requested model family.

    NOTE: the keyword defaults are evaluated once, when this ``def`` statement
    runs. Re-assigning the globals ``crop_resolution``, ``dataset``,
    ``checkpoint`` or ``model_path`` afterwards does not change them; pass the
    values explicitly (or re-run this cell) to use different ones.

    Args:
        model_type: One of ``'mlp'``, ``'cnn'`` or ``'vit'``.
        architecture: For ``'mlp'`` it is forwarded to ``get_model``; for
            ``'cnn'`` it is the name given to ``timm.create_model``; for
            ``'vit'`` it is the file name joined onto ``model_path``.
        resolution: Forwarded to ``get_model`` (``'mlp'`` only).
        num_classes: Forwarded to ``get_model`` (``'mlp'`` only).
        checkpoint: Forwarded to ``get_model`` (``'mlp'`` only).
        model_path: Directory containing the saved model (``'vit'`` only).

    Returns:
        The constructed model. For ``'mlp'`` and ``'vit'`` it is wrapped in a
        ``torch.nn.Sequential`` that applies ``vis.Reshape`` first.

    Raises:
        ValueError: If ``model_type`` is not one of the supported values.
    """
    if model_type == 'mlp':
        model = get_model(architecture=architecture, resolution=resolution,
                          num_classes=num_classes, checkpoint=checkpoint)
        # NOTE(review): the size passed to Reshape is hard-coded to 64 and does
        # not follow `resolution` -- confirm this is intended.
        model = torch.nn.Sequential(vis.Reshape(64), model)
    elif model_type == 'cnn':
        model = timm.create_model(architecture, pretrained=True)
    elif model_type == 'vit':
        # SECURITY: torch.load unpickles the file; only load files you trust.
        model = torch.load(os.path.join(model_path, architecture))
        model = torch.nn.Sequential(vis.Reshape(224), model)
    else:
        # Previously an unknown type fell through to `return model` and failed
        # with an UnboundLocalError that did not name the offending value.
        raise ValueError(
            f"Unknown model_type {model_type!r}; expected 'mlp', 'cnn' or 'vit'."
        )
    return model

In [None]:

# Select the model to visualise; exactly one configuration should be active.
# NOTE: `get_models_full` is called below without `resolution`/`checkpoint`,
# so it uses the defaults captured when it was defined. Re-assigning
# `checkpoint` or `crop_resolution` here does not change the model that is
# built; `crop_resolution` is still used by the loader and as `noise_size`.
model_type      = 'mlp'
checkpoint      = 'in21k_cifar10'
architecture    = 'B_12-Wi_1024'
crop_resolution = 32

#model_type      = 'vit'
#architecture    = 'vit_small_patch16_224_' + dataset + '_7.pth'
#crop_resolution = 32

#model_type      = 'cnn'
#architecture    = 'resnet18_' + dataset
#crop_resolution = 32


# load the model
torch.backends.cuda.matmul.allow_tf32 = True
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = get_models_full(model_type, architecture)
model.to(device)
# eval() only switches the module to evaluation mode; wrapping it in
# torch.no_grad() had no effect, so the wrapper was removed.
model.eval()

# initialize loader
loader = get_loader(
        dataset,
        bs=eval_batch_size,
        mode="test",
        augment=False,
        dev=device,
        mixup=0.0,
        data_path=data_path,
        data_resolution=data_resolution,
        crop_resolution=crop_resolution,
    )

# Activation maximization can start either from total random noise or from
# an initial image taken from the loader.
use_init_image = True

# Take the first batch and pick its second sample (index 1); unsqueeze(0)
# adds a leading dimension of size 1.
dataiter    = iter(loader)
ims, labels = next(dataiter)
img         = ims[1].unsqueeze(0)
label       = labels[1]

if not use_init_image:
    # No initial image: pass img=None and target class 0.
    img   = None
    label = 0
else:
    # .item() extracts the plain Python value used as the target class.
    label = label.item()

# generate image using activation maximization
init_image1, synthetic_image1 = vis.generate_image(model        = model, 
                                     target_class = label,
                                     epochs       = 250, 
                                     min_prob     = 0.9, 
                                     lr           = .01, 
                                     weight_decay = 5e-2, 
                                     step_size    = 100, 
                                     gamma        = 0.9,
                                     noise_size   = crop_resolution,
                                     model_type   = model_type,
                                     img          = img,
                                     dataset      = dataset)

### Feature Inversion

In [None]:
# These two lines ensure that we always import the latest version of a package, in case it has been modified.
%load_ext autoreload
%autoreload 2

import timm
import torch
import detectors
from torchvision import transforms
from utils import vis as vis

In [None]:
import os
from tqdm import tqdm

from data_utils.data_stats import *
from models.networks import get_model
from data_utils.dataloader import get_loader
from data_utils.dataset_to_beton import get_dataset

In [3]:
# Experiment configuration shared by the cells below.

# Dataset and storage locations.
dataset     = 'cifar10'               # One of cifar10, cifar100, stl10, imagenet or imagenet21
num_classes = CLASS_DICT[dataset]
data_path   = '/scratch/ffcv'
model_path  = '/scratch/zooming_in_on_mlps'

# Evaluation / data-loading settings.
eval_batch_size = 32
data_resolution = 32
device          = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Model selection.
model_type   = 'mlp'
architecture = 'B_12-Wi_1024'
checkpoint   = 'in21k_cifar10'

# NOTE: when the cells are run in order, this is the value `get_models_full`
# captures as its default `resolution` at definition time.
crop_resolution = 64

In [4]:
def get_models_full(model_type, 
                    architecture, 
                    resolution  = crop_resolution, 
                    num_classes = CLASS_DICT[dataset], 
                    checkpoint  = checkpoint, 
                    model_path   = model_path):
    """Build the model to visualise for the requested model family.

    NOTE: the keyword defaults are evaluated once, when this ``def`` statement
    runs. Re-assigning the globals ``crop_resolution``, ``dataset``,
    ``checkpoint`` or ``model_path`` afterwards does not change them; pass the
    values explicitly (or re-run this cell) to use different ones.

    Args:
        model_type: One of ``'mlp'``, ``'cnn'`` or ``'vit'``.
        architecture: For ``'mlp'`` it is forwarded to ``get_model``; for
            ``'cnn'`` it is the name given to ``timm.create_model``; for
            ``'vit'`` it is the file name joined onto ``model_path``.
        resolution: Forwarded to ``get_model`` (``'mlp'`` only).
        num_classes: Forwarded to ``get_model`` (``'mlp'`` only).
        checkpoint: Forwarded to ``get_model`` (``'mlp'`` only).
        model_path: Directory containing the saved model (``'vit'`` only).

    Returns:
        The constructed model. For ``'mlp'`` and ``'vit'`` it is wrapped in a
        ``torch.nn.Sequential`` that applies ``vis.Reshape`` first.

    Raises:
        ValueError: If ``model_type`` is not one of the supported values.
    """
    if model_type == 'mlp':
        model = get_model(architecture=architecture, resolution=resolution,
                          num_classes=num_classes, checkpoint=checkpoint)
        # NOTE(review): the size passed to Reshape is hard-coded to 64 and does
        # not follow `resolution` -- confirm this is intended.
        model = torch.nn.Sequential(vis.Reshape(64), model)
    elif model_type == 'cnn':
        model = timm.create_model(architecture, pretrained=True)
    elif model_type == 'vit':
        # SECURITY: torch.load unpickles the file; only load files you trust.
        model = torch.load(os.path.join(model_path, architecture))
        model = torch.nn.Sequential(vis.Reshape(224), model)
    else:
        # Previously an unknown type fell through to `return model` and failed
        # with an UnboundLocalError that did not name the offending value.
        raise ValueError(
            f"Unknown model_type {model_type!r}; expected 'mlp', 'cnn' or 'vit'."
        )
    return model

In [None]:
# Select the model to visualise; exactly one configuration should be active.
# NOTE: `get_models_full` is called below without `resolution`/`checkpoint`,
# so it uses the defaults captured when it was defined. Re-assigning
# `checkpoint` or `crop_resolution` here does not change the model that is
# built; `crop_resolution` is still used by the loader below.

#model_type      = 'vit'
#architecture    = 'vit_small_patch16_224_' + dataset + '_7.pth'
#crop_resolution = 32

#model_type      = 'cnn'
#architecture    = 'resnet18_' + dataset
#crop_resolution = 32

model_type      = 'mlp'
checkpoint      = 'in21k_cifar10'
architecture    = 'B_12-Wi_1024'
crop_resolution = 32

# load the model
torch.backends.cuda.matmul.allow_tf32 = True
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = get_models_full(model_type, architecture)
model.to(device)
# eval() only switches the module to evaluation mode; wrapping it in
# torch.no_grad() had no effect, so the wrapper was removed.
model.eval()

# initialize loader
loader = get_loader(
        dataset,
        bs=eval_batch_size,
        mode="test",
        augment=False,
        dev=device,
        mixup=0.0,
        data_path=data_path,
        data_resolution=data_resolution,
        crop_resolution=crop_resolution,
    )

In [6]:
# Pick the module names handed to `vis.feature_inversion` as `modules_names`.
# The commented lists are alternative selections covering more than one module.
if model_type == 'cnn':
    #modules = ['conv1', 'layer1.0.conv1']
    modules = ['layer1.0.conv1']
elif model_type == 'vit':
    #modules = ['1.blocks.0.mlp.fc1', '1.blocks.1.mlp.fc1']
    modules = ['1.blocks.1.mlp.fc1']
elif model_type == 'mlp':
    #modules = ['1.blocks.0.block.0', '1.blocks.2.block.0']
    modules = ['1.blocks.0.block.0']
else:
    # Without this branch an unknown type left `modules` undefined, or worse,
    # silently kept the value from an earlier run of this cell.
    raise ValueError(
        f"Unknown model_type {model_type!r}; expected 'cnn', 'vit' or 'mlp'."
    )

In [7]:
# Pull the first batch from the loader.
dataiter = iter(loader)
ims, labels = next(dataiter)

# Use the second sample of the batch (index 1); unsqueeze(0) adds a leading
# dimension of size 1 to the image.
img, label = ims[1].unsqueeze(0), labels[1]

In [None]:
# Run feature inversion for the selected image and module names.
vis.feature_inversion(
    # what to invert
    model=model,
    model_type=model_type,
    modules_names=modules,
    img=img,
    noise_size=crop_resolution,
    device=device,
    mode=1,
    # optimisation settings (presumably learning rate and its step schedule;
    # TODO confirm against vis.feature_inversion)
    epochs=350,
    lr=1500,
    step_size=100,
    gamma=0.6,
    mu=1e-1,
)