In [1]:
import torch
import torch.nn as nn
import torchvision.models as models   
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
from numpy import asarray, percentile, tile
import json
import requests
import math
import random
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter
import warnings
warnings.filterwarnings('ignore')
%matplotlib inline

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# device = torch.device("cpu")
print(f'Using {device} for inference')

resnet50 = models.resnet50(pretrained = True)
utils = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_convnets_processing_utils')

resnet50.eval().to(device)

alexnet = models.alexnet(weights='IMAGENET1K_V1')
alexnet.eval().to(device)

# Normalize images for final displaying
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
denormalize = transforms.Normalize(mean = [-0.485/0.229, -0.456/0.224, -0.406/0.225], std = [1/0.229, 1/0.224, 1/0.225] )
def image_converter(im):
    im_copy = im.cpu()
    
    im_copy = denormalize(im_copy.clone().detach()).numpy()
    im_copy = im_copy.transpose(1,2,0)
    im_copy = im_copy.clip(0, 1) 
    return im_copy

Using cuda for inference


Using cache found in C:\Users\phili/.cache\torch\hub\NVIDIA_DeepLearningExamples_torchhub


In [2]:
# Compute jitter loss
def jitter(model, img, jitter_t, layer_activation, layer_name, unit):
    sum = 0
    for i in range(10):
        temp_img = img
        tao = random.randint(0, jitter_t)
        temp_img = torch.add(temp_img, tao)
        model(temp_img)
        layer_out = layer_activation[layer_name]
        sum = torch.add(sum, layer_out[0][unit])
    jitter_loss = sum / 10
    jitter_loss.requires_grad_(True)
    jitter_loss.retain_grad() 
    jitter_loss.backward(retain_graph=True)
    jitter_grad = jitter_loss.grad.detach()
    return jitter_loss.detach(), jitter_grad
# Compute TV loss
def tv(img, img_grad):
    bs_img, c_img, h_img, w_img = img.size()
    w = torch.sum(torch.pow(img[:,:,:,:-1] - img[:,:,:,1:], 2))
    h = torch.sum(torch.pow(img[:,:,:-1,:] - img[:,:,1:,:], 2))
    tv_loss = (1/(torch.norm(img_grad) * h_img * w_img) * (h + w))
    tv_loss.requires_grad_(True)
    tv_loss.retain_grad() 
    tv_loss.backward(retain_graph=True)
    tv_grad = tv_loss.grad.detach()
    return tv_loss.detach(), tv_grad

In [3]:
# Maximize activation of a single input image with regularization
def act_max(model, 
    inp_img, 
    layer_activation, 
    layer_name, 
    unit, 
    steps=100, 
    alpha=torch.tensor(1),
    TV = False,
    Jitter = False,
    Regular = False,
    jitter_t = 20,
    jitter_alpha = 0.05,
    tv_alpha = 0.05,
    show_img = False
    ):

    best_activation = -float('inf')
    min_loss = float('inf')
    best_img = inp_img
    for k in range(steps):
        inp_img.requires_grad_(True)
        inp_img.retain_grad() 
        inp_img = inp_img.to(device)
        old_norm = torch.norm(inp_img)
        # Propagate image
        model(inp_img)
        layer_out = layer_activation[layer_name]
        # Compute gradients
        layer_out[0][unit].backward(retain_graph=True)
        img_grad = inp_img.grad
            
        # Gradient Step
        inp_img = torch.add(inp_img, torch.mul(img_grad.detach(), alpha))

        act_loss = layer_out[0][unit]
        #Jitter
        jitter_loss = torch.tensor(0)
        jitter_grad = 0
        if Jitter and k % 10 == 0:
            jitter_loss, jitter_grad = jitter(model, inp_img, jitter_t, layer_activation, layer_name, unit)
        if TV:
            tv_loss, tv_grad = tv(inp_img, img_grad)

        # Keep highest activation
        loss = -1 * act_loss
        if Jitter:
            loss -= jitter_alpha * jitter_loss
        if TV:
            loss += tv_alpha * tv_loss
        if Regular:
            if Jitter:
              inp_img = torch.add(inp_img, torch.mul(jitter_grad, alpha*jitter_alpha))
            if TV:
              inp_img = torch.add(inp_img, torch.mul(tv_grad, -alpha*tv_alpha))
        
        new_norm = torch.norm(inp_img)
        inp_img = torch.mul(inp_img, old_norm/new_norm)
        if loss < min_loss:
            if not Jitter or k % 10 != 10:
                jitter_loss, jitter_grad = jitter(model, inp_img, jitter_t, layer_activation, layer_name, unit)
            if not TV:
                tv_loss, tv_grad = tv(inp_img, img_grad)
            best_activation = act_loss, jitter_loss, tv_loss, loss
            min_loss = loss
            best_img = inp_img

        if show_img and k == steps-1:
            final_image = image_converter(inp_img.squeeze(0))
            plt.imshow(final_image)
            plt.show()        
            print('step: ', k, 'activation: ', layer_out[0][unit])
        
    return (best_activation, best_img)

In [4]:
# Needed for getting gradients
def layer_hook(act_dict, layer_name):
    def hook(module, input, output):
        act_dict[layer_name] = output
    return hook
# Random image initialization
def reset_img():
  inp = torch.rand((1, 3, 227, 227))
  inp.requires_grad_(True)
  return inp.to(device)
# Convert (1, 3, 227, 227) Torch tensor into 227*227 element numpy array, averaging across RGB channels
def np_data(img):
    img = denormalize(img.squeeze().detach().cpu())
    img = torch.mean(img, 0)
    img = torch.flatten(img)
    img = img.numpy()
    return img
# Displaying data in a histogram to visualize activation distribution
def get_hist(arr, title):
    hist, bin = np.histogram(arr)
    plt.hist(arr, bins=bin)
    plt.title(title)
    plt.show()

In [5]:
unit = 130
steps = 200
alpha = torch.tensor(1.5)
# Main method: activation maximization over 'trials' times, computing all 4 kinds of 
# losses (activation score, jitter loss, TV loss, and total (weighted) loss)
def experiment(model, TV, Jitter, jitter_t=0, jitter_alpha=0, tv_alpha=0, trials=10):
    # In order: activation, jitter, tv, and total losses
    losses = [], [], [], []
    for t in range(trials):
        # starting image
        orig_img = reset_img()
        inp = orig_img
        # outputs of image through both neural nets
        results = model(inp)
        value = results.detach().cpu().numpy()
        # max outputs of image through both neural nets
        k = max(value[0])
        act_dict = {}
        layer_name = 'classifier_final'
        list(model.children())[-1].register_forward_hook(layer_hook(act_dict, layer_name))
        
        activation, output = act_max(model=model,
                    inp_img=inp,
                    layer_activation=act_dict,
                    layer_name=layer_name,
                    unit=unit,
                    steps=steps,
                    alpha=alpha,
                    TV=TV,
                    Jitter=Jitter,
                    Regular=True,
                    jitter_t=jitter_t,
                    jitter_alpha=jitter_alpha,
                    tv_alpha=tv_alpha,
                    show_img=False,
                    )
        for i in range(4):
            if isinstance(activation[i], int):
                print(i)
            losses[i].append(activation[i].detach().cpu().numpy().item())
        out = np_data(output)
        torch.cuda.empty_cache()
    names = ["Activations:", "Jitter losses:", "TV losses:", "Total losses:"], ["Average activation:", "Average jitter loss:", "Average TV loss:", "Average total loss:"]
    # for i in range(4):
    #     # print(names[0][i], str(losses[i]))
    #     # print(names[1][i], sum(losses[i])/trials)
    data = [losses[0][0] ,losses[1][0],losses[2][0],losses[3][0]]
    return data

#### Example results from an experiment
resnet50, TV=False, Jitter=True, jitter_t=10, jitter_alpha=0.1, tv_alpha=0, trials=10  
 - Activations: [163.57022094726562, 171.38491821289062, 177.36441040039062, 184.45797729492188, 174.5690460205078, 162.1881103515625, 173.3818359375, 167.202392578125, 176.84043884277344, 178.08633422851562]
 - Average activation: 172.90456848144532
 - Jitter losses: [22.103370666503906, 34.036346435546875, 8.708541870117188, 19.32915496826172, 38.63886260986328, 27.92671775817871, 14.243896484375, 23.848962783813477, 12.266249656677246, 39.02633285522461]
 - Average jitter loss: 24.0128436088562
 - TV losses: [0.02318560890853405, 0.027121976017951965, 0.028241293504834175, 0.04551585391163826, 0.025176187977194786, 0.026284033432602882, 0.026428097859025, 0.02534600719809532, 0.028803151100873947, 0.029583383351564407]
 - Average TV loss: 0.02856855932623148
 - Total losses: [-167.22222900390625, -171.38491821289062, -180.71128845214844, -184.45797729492188, -174.5690460205078, -164.54434204101562, -178.4657440185547, -170.62872314453125, -180.0749053955078, -181.38262939453125]
 - Average total loss: -175.34418029785155

In [11]:
import csv
# Experiment we ran for AlexNet
with open('alexnet_reg.csv', 'w', encoding='UTF8', newline='') as f:
    writer = csv.writer(f)
    
    jitter_t_values = [5, 8, 10, 15]
    jitter_alpha_values = [0.005,0.01,0.05,0.1]
    print("printing for alexnet")
    for i in range(4):
        for j in range(4):
            for k in range(4):
                data = experiment(alexnet, TV=True, Jitter=True, jitter_t=10, jitter_alpha=jitter_t_values[j], tv_alpha=jitter_alpha_values[k], trials=1)
                writer.writerow(data)
                print(str(data[0])+","+str(data[1])+","+str(data[2])+","+str(data[3]))

printing for alexnet
209.92312622070312,-0.07909908145666122,0.01727479323744774,-963.8392944335938
254.4646759033203,2.6975669860839844,0.022933119907975197,-1035.287109375
507.0416564941406,4.948776721954346,0.12290405482053757,-2624.91796875
836.6007080078125,14.898063659667969,0.32265886664390564,-4679.0009765625
194.5734405517578,-0.2849178910255432,0.016341088339686394,-1260.2451171875
218.1263427734375,-1.0432029962539673,0.021373311057686806,-1501.546875
397.50030517578125,6.869326114654541,0.10230516642332077,-3092.896484375
726.5560913085938,5.282922267913818,0.335868239402771,-6080.73828125
115.45935821533203,-1.5717108249664307,0.008543641306459904,-700.4129028320312
122.19690704345703,-1.8676187992095947,0.010542846284806728,-1001.8298950195312
378.70477294921875,1.9961864948272705,0.08805258572101593,-3241.064208984375
624.8535766601562,5.756803035736084,0.3534456491470337,-6178.48095703125
67.59841918945312,-1.7852413654327393,0.005208959802985191,-623.195556640625
55.31

In [12]:
# Experiment we ran for ResNet
print("printing for resnet50")
with open('resnet50_reg.csv', 'w', encoding='UTF8', newline='') as f:
    writer = csv.writer(f)
    
    jitter_t_values = [5, 8, 10, 15]
    jitter_alpha_values = [0.005,0.01,0.05,0.1]
    print("printing for resnet50")
    for i in range(4):
        for j in range(4):
            for k in range(4):
                data = experiment(resnet50, TV=True, Jitter=True, jitter_t=10, jitter_alpha=jitter_t_values[j], tv_alpha=jitter_alpha_values[k], trials=1)
                writer.writerow(data)
                print(str(data[0])+","+str(data[1])+","+str(data[2])+","+str(data[3]))

printing for resnet50
printing for resnet50
52.32490158081055,-6.062312602996826,0.02914789691567421,-121.46125030517578
60.06840515136719,-4.39596700668335,0.03160114958882332,-155.46815490722656
75.02816772460938,-1.404200792312622,0.2588936388492584,-159.3423614501953
67.95729064941406,-0.014199132099747658,0.6074260473251343,-212.98597717285156
39.4609375,-4.094005107879639,0.015002685599029064,-177.2740478515625
62.68140411376953,-5.944693565368652,0.021997181698679924,-271.08978271484375
59.43431854248047,-3.8177971839904785,0.2237296849489212,-233.4839630126953
58.95433044433594,0.6091572642326355,0.6811813712120056,-269.3022155761719
37.40721130371094,-5.779452323913574,0.01546125765889883,-219.0736846923828
58.025535583496094,-4.8475260734558105,0.01859930530190468,-289.9904479980469
60.6517219543457,-5.749935150146484,0.15406280755996704,-277.5993347167969
56.377899169921875,-0.10409469902515411,0.6986702084541321,-335.70648193359375
41.665061950683594,-4.39353084564209,0.015