# Sensitivity to orientation of deep networks

This notebook reproduces subpanels of Fig. 2, as well as the control analyses shown in the supplementary material.

All trained networks are downloaded from Torchvision's distribution. In the controls, some networks are trained on a modified (rotated) version of ImageNet. Training scripts can be found in `train_rotated.py`. 

This notebook can easily be modified to test the orientation sensitivity of any Torchvision network: simply change the `model` that is constructed in the relevant cell.

In [None]:
import numpy as np
import torch
from matplotlib import pyplot as plt
import pickle
import pandas as pd
%matplotlib inline

from tqdm import tqdm as tqdm
from matplotlib import cm
import matplotlib as mpl

import seaborn as sns

# Matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8-*' and later
# releases dropped the old names, so use whichever spelling this install provides.
plt.style.use('seaborn-v0_8-white' if 'seaborn-v0_8-white' in plt.style.available
              else 'seaborn-white')
# Font sizes for axis labels and tick labels used by every figure below.
plt.rcParams['axes.labelsize'] = 25
plt.rcParams['ytick.labelsize'] = 15.0
plt.rcParams['xtick.labelsize'] = 15.0
# Font type 42 (TrueType) in PDF/PS output.
plt.rcParams['pdf.fonttype'] = 42
plt.rcParams['ps.fonttype'] = 42

import sys
sys.path.insert(0, "..")


from scripts.fisher_calculators import get_fisher_orientations
from scripts.orientation_stim import broadband_noise, gabor, grating, circular_mask
from torchvision import models, transforms, utils


##### Define some stimulus generators for which to compute the Fisher information.

In [None]:
# Orientation used for the two preview images in this cell.
a = np.pi/4


def generator_gabor(a):
    """Build a Gabor patch for orientation ``a`` and expand it to a (3, 224, 224) tensor."""
    patch = gabor(size=224, pixelsPerDegree=100, spatial_freq=2, spatial_phase=.1,
                  orientation=-a - np.pi/2, contrast=1, sigma=.5, spatial_aspect_ratio=1)
    return torch.from_numpy(patch).expand(3, 224, 224)


# Show the first channel of the Gabor stimulus.
plt.imshow(generator_gabor(a)[0])
plt.axis('off')
plt.show()

# Mask that is multiplied element-wise into the grating below.
mask = circular_mask(224, 100, radius=1, polarity_out=0, polarity_in=1,
                     if_filtered=True, filter_size=(50, 50), filter_width=5)


def generator_grating(a):
    """Build a masked grating for orientation ``a`` and expand it to a (3, 224, 224) tensor."""
    stripes = grating(size=224, pixelsPerDegree=100, spatial_freq=3, spatial_phase=0,
                      orientation=a, contrast=1)
    return torch.from_numpy(np.multiply(stripes, mask)).expand(3, 224, 224)


# Show the first channel of the masked grating stimulus.
plt.imshow(generator_grating(a)[0])
plt.axis('off')
plt.show()


$+\Delta\theta$

In [None]:
def generator_gabor(a, spatial_freq=2, spatial_phase=0):
    """Return a Gabor stimulus for orientation ``a`` as a (3, 224, 224) tensor.

    ``spatial_freq`` and ``spatial_phase`` are forwarded unchanged to ``gabor``;
    the orientation handed to ``gabor`` is ``-a - pi/2``. The array returned by
    ``gabor`` is wrapped with ``torch.from_numpy`` and expanded (not copied)
    across three channels.
    """
    patch = gabor(
        size=224,
        pixelsPerDegree=100,
        spatial_freq=spatial_freq,
        spatial_phase=spatial_phase,
        orientation=-a - np.pi/2,
        contrast=1,
        sigma=.5,
        spatial_aspect_ratio=1,
    )
    return torch.from_numpy(patch).expand(3, 224, 224)

In [None]:


def plot_fisher(model, title, N=10, n_angles=40, n_phases=1, generator=None, scale=None, savefig=None):
    """Plot the normalized square root of the Fisher information per layer.

    For every layer index ``i`` in ``range(N)`` this calls
    ``get_fisher_orientations(model, i, n_angles, n_images=n_phases,
    generator=generator, delta=1e-2)``, stacks the returned tensors, takes the
    element-wise square root, divides by the sum so each curve sums to one,
    and draws it against ``n_angles`` evenly spaced values on [0, 180].

    Args:
        model: network handed to ``get_fisher_orientations``.
        title: figure title.
        N: number of layer indices to plot; also the number of palette colours.
        n_angles: number of orientations passed to ``get_fisher_orientations``
            and number of x positions.
        n_phases: forwarded to ``get_fisher_orientations`` as ``n_images``.
        generator: stimulus generator forwarded to ``get_fisher_orientations``.
            When ``None`` the notebook-level ``generator_gabor`` is used.
        scale: optional y-limits passed to ``plt.ylim``; overrides the
            automatic limits.
        savefig: optional path; when given the figure is saved there.
    """
    # The argument used to be ignored in favour of the global generator_gabor;
    # defaulting to it keeps every existing call unchanged.
    if generator is None:
        generator = generator_gabor

    cs = sns.color_palette('plasma', N)
    cmap = mpl.colors.LinearSegmentedColormap.from_list(
        'Custom cmap', cs, len(cs))
    angles = np.linspace(0, 180, n_angles)

    normed = None
    for i in tqdm(range(N)):
        fishers = get_fisher_orientations(model, i, n_angles, n_images=n_phases,
                                          generator=generator, delta=1e-2)
        sqrt_fishers = np.sqrt(torch.stack(fishers).cpu().numpy())
        normed = sqrt_fishers / np.sum(sqrt_fishers)
        plt.plot(angles, normed, "-", label="Layer {}".format(i + 1), c=cs[i])

    # NOTE(review): the automatic upper limit is derived from the last plotted
    # layer only, as in the original; pass `scale` to fix the limits explicitly.
    if normed is not None:
        plt.ylim(bottom=0, top=max(max(normed), 2 * np.mean(normed)))
    plt.ylabel(r"$\sqrt{J(\theta)}$ (normalized)", fontsize=15)
    plt.xlabel("Angle (º)", fontsize=15)
    plt.xticks(np.linspace(0, 180, 5))

    if scale is not None:
        plt.ylim(scale)

    plt.title(title, fontsize=20)
    # The mappable is not attached to any axes, so name the axes to take space
    # from explicitly; newer Matplotlib releases refuse to guess.
    layer_norm = mpl.colors.Normalize(vmin=0, vmax=N)
    clb = plt.colorbar(cm.ScalarMappable(norm=layer_norm, cmap=cmap), ax=plt.gca())
    clb.ax.set_title('Layer')
    plt.tight_layout()
    if savefig is not None:
        plt.savefig(savefig)
    plt.show()
    
def shuffle_all_weights(model):
    """Randomly permute every parameter of ``model`` in place and return ``model``.

    Note: this destroys the model.

    * Parameters with more than three dimensions (e.g. Conv2d kernels): for
      each (row, col) position of dimensions 2 and 3, the slices along
      dimension 0 are permuted with a fresh random permutation.
    * Parameters with one, two or three dimensions: each dimension is permuted
      in turn, starting with dimension 0.
    """
    for param in model.parameters():
        ndim = param.data.dim()

        if ndim > 3:
            n_channels = param.data.size(0)
            height, width = param.data.size(2), param.data.size(3)
            # Visit the spatial positions in row-major order, drawing one
            # independent permutation of dimension 0 per position.
            for position in range(height * width):
                row, col = divmod(position, width)
                perm = torch.randperm(n_channels)
                param.data[:, :, row, col] = param.data[perm, :, row, col]
            continue

        # 1-D, 2-D and 3-D parameters: permute dimension 0, then 1, then 2.
        for axis in range(ndim):
            perm = torch.randperm(param.data.size(axis))
            selector = (slice(None),) * axis + (perm,)
            param.data = param.data[selector]

    return model
            


## On Resnet

In [None]:
# ResNet-18 at initialization (no pretrained weights), on the GPU in eval mode.
model = models.resnet18(pretrained=False)
model.cuda().eval()
plot_fisher(model, "Resnet18, initialization", N=10, n_angles=180, n_phases=10)

In [None]:
# ResNet-18 with Torchvision's pretrained weights; the figure is shown but not saved.
model = models.resnet18(pretrained=True)
model.cuda().eval()
plot_fisher(model, "Resnet18", N=10, n_angles=180, n_phases=10, savefig=None)



In [None]:
# Pretrained ResNet-18 whose weights are shuffled in place before plotting.
model = models.resnet18(pretrained=True)
model.cuda().eval()
shuffle_all_weights(model)
plot_fisher(model, "Shuffled Resnet18", N=10, n_angles=180, n_phases=10)


## Resnet34

In [None]:
# cs = sns.color_palette('plasma', 10)
# model = models.resnet34(pretrained=True)
# # checkpoint = torch.load(which_alexnet)
# # model.load_state_dict(checkpoint['state_dict'])

# model.cuda().eval()
# plot_fisher(model, "Resnet34", 10, n_angles=180, n_phases=1, savefig="figures/resnet34.pdf")


## Resnet50

In [None]:
# cs = sns.color_palette('plasma', 10)
# model = models.resnet50(pretrained=True)
# # checkpoint = torch.load(which_alexnet)
# # model.load_state_dict(checkpoint['state_dict'])

# model.cuda().eval()
# plot_fisher(model, "resnet50", 10, n_angles=180, n_phases=1, savefig="figures/resnet50.pdf")


## Now on VGG

In [None]:
# cs = sns.color_palette('plasma', 21)

# model = models.vgg11(pretrained=True)
# # checkpoint = torch.load(which_alexnet)
# # model.load_state_dict(checkpoint['state_dict'])

# model.cuda().eval()
# plot_fisher(model, "vgg11", 20, n_angles=180, n_phases=1, savefig="figures/vgg11.pdf")


In [None]:
# cs = sns.color_palette('plasma', 29)

# model = models.vgg11_bn(pretrained=True)
# # checkpoint = torch.load(which_alexnet)
# # model.load_state_dict(checkpoint['state_dict'])
# model.cuda().eval()
# plot_fisher(model, "vgg11_bn", 20, n_angles=180, n_phases=1, savefig="figures/vgg11_bn.pdf")


In [None]:
# cs = sns.color_palette('plasma', 35)

# model = models.vgg13_bn(pretrained=True)
# # checkpoint = torch.load(which_alexnet)
# # model.load_state_dict(checkpoint['state_dict'])

# model.cuda().eval()
# plot_fisher(model, "vgg13_bn", 35, n_angles=180, n_phases=1, savefig="figures/vgg13bn.pdf")


In [None]:

# VGG16-BN without pretrained weights.
# NOTE(review): the title is identical to the one used for the pretrained network
# in the next cell — consider distinguishing them.
model = models.vgg16_bn(pretrained=False).cuda().eval()
plot_fisher(model, "vgg16_bn", N=44, n_angles=180, n_phases=10)


In [None]:

# VGG16-BN with Torchvision's pretrained weights.
model = models.vgg16_bn(pretrained=True).cuda().eval()
plot_fisher(model, "vgg16_bn", N=44, n_angles=180, n_phases=10)


In [None]:

# Pretrained VGG16-BN with its weights shuffled in place (before the move to the GPU).
model = models.vgg16_bn(pretrained=True)
shuffle_all_weights(model)

model.cuda().eval()
# NOTE(review): the savefig string was left unterminated (a syntax error); the
# file name follows the pattern of the other figure cells — confirm the intended path.
plot_fisher(model, "vgg16_bn_shuff", 44, n_angles=180, n_phases=10,
            savefig="figures/vgg16_bn_shuff.pdf")


# Controls

These subpanels appear in the supplementary material. 

Note: The rotation controls require retraining a network on ImageNet, which can be expensive. The script for training a network on a modified (rotated) ImageNet is `train_rotated.py`. After training, point the checkpoint paths in the cells below at the saved model snapshots.

## Alexnet

In [None]:
# Palette kept at notebook level, as in the original cell (plot_fisher builds its own).
cs = sns.color_palette('plasma', 12)

# Pretrained AlexNet, weights shuffled in place before moving to the GPU.
model = shuffle_all_weights(models.alexnet(pretrained=True))
model.cuda().eval()
plot_fisher(model, "alexnet_shuff", N=12, n_angles=180, n_phases=10)

In [None]:
cs = sns.color_palette('plasma', 12)

# Pretrained AlexNet, left intact: this cell is titled "alexnet" and sits between
# the shuffled and the initialization controls, so the weights must not be shuffled here.
model = models.alexnet(pretrained=True)

model.cuda().eval()
plot_fisher(model, "alexnet", 12, n_angles=180, n_phases=10)

In [None]:
cs = sns.color_palette('plasma', 12)

# AlexNet without pretrained weights, plotted with fixed y-limits.
model = models.alexnet(pretrained=False)
model.cuda().eval()
plot_fisher(model, "alexnet_init", N=12, n_angles=180, n_phases=10, scale=[.004,.008])

## Alexnet with no pooling.

What happens for Alexnet with no pooling? The cells below require the models defined in `Alexnet_nooverlap.py`.

In [None]:
from Alexnet_nooverlap import AlexNet_nooverlap, AlexNet_nopool

In [None]:
cs = sns.color_palette('plasma', 12)

# Freshly constructed AlexNet_nooverlap (no weights loaded), fixed y-limits, saved to PDF.
model = AlexNet_nooverlap()
model.cuda().eval()
plot_fisher(
    model,
    "alexnet_init_nooverlap",
    N=12,
    n_angles=180,
    n_phases=10,
    scale=[.004,.008],
    savefig="figures/alexnet_init_nooverlap2.pdf",
)

In [None]:
# Re-plot whatever `model` currently holds with automatic y-limits and no file output.
plot_fisher(model, "alexnet_init_nooverlap", N=12, n_angles=180, n_phases=10)

In [None]:
cs = sns.color_palette('plasma', 12)

# Freshly constructed AlexNet_nopool (no weights loaded).
model = AlexNet_nopool()
model.cuda().eval()
plot_fisher(model, "alexnet_init_nopool", N=12, n_angles=180, n_phases=10,
            savefig="figures/alexnet_init_nopool.pdf")

In [None]:

# AlexNet_nooverlap carrying the weights of Torchvision's pretrained AlexNet.
model = AlexNet_nooverlap().cuda().eval()
state_dict = models.alexnet(pretrained=True).state_dict()
model.load_state_dict(state_dict)

# The weights are NOT shuffled in this cell, so the title now matches the output
# file name ("loaded") instead of the copy-pasted "_shuff" label.
plot_fisher(model, "alexnet_nooverlap_loaded", 12, n_angles=180, n_phases=10,
            savefig="figures/alexnet_nooverlap_loaded.pdf")


In [None]:

# AlexNet_nooverlap with the pretrained AlexNet weights loaded and then shuffled in place.
model = AlexNet_nooverlap()
model.cuda().eval()
pretrained_weights = models.alexnet(pretrained=True).state_dict()
model.load_state_dict(pretrained_weights)
shuffle_all_weights(model)

plot_fisher(model, "alexnet_nooverlap_shuff", N=12, n_angles=180, n_phases=10,
            savefig="figures/alexnet_nooverlap_shuff.pdf")


## Rotated

In [None]:

# VGG16-BN restored from a local training checkpoint (see the 'state_dict' entry).
model = models.vgg16_bn(pretrained=False)
model.cuda().eval()
checkpoint = torch.load('/data/abenjamin/DNN_illusions/vgg16_bn_rotated/checkpoint_epoch_90.pth.tar')
model.load_state_dict(checkpoint['state_dict'])

plot_fisher(model, "rotated vgg", N=44, n_angles=180, n_phases=10,
            savefig="figures/vgg16_bn_rotated.pdf")


In [None]:

# ResNet-18 restored from a local training checkpoint (see the 'state_dict' entry).
model = models.resnet18(pretrained=False)
model.cuda().eval()
checkpoint = torch.load('/data/abenjamin/DNN_illusions/resnet18_rotated/checkpoint_epoch_90.pth.tar')
model.load_state_dict(checkpoint['state_dict'])

plot_fisher(model, "rotated Resnet18", N=10, n_angles=180, n_phases=10,
            savefig="figures/resnet18_rotated.pdf")


In [None]:
def plot_rotated_difference(model, model_rotated, N=10, n_angles=40, n_phases=1, generator=None, savefig=None):
    """Plot, per layer, the difference in normalized sqrt Fisher information between two models.

    For each layer index ``i`` in ``range(N)`` the normalized square-root
    Fisher curve (element-wise sqrt of the stacked output of
    ``get_fisher_orientations``, divided by its sum) is computed for ``model``
    and then for ``model_rotated``; the plotted quantity is
    ``model_rotated - model`` against ``n_angles`` evenly spaced values on
    [0, 180].

    Args:
        model: reference network.
        model_rotated: network whose curve the reference is subtracted from.
        N: number of layer indices to plot; also the number of palette colours.
        n_angles: number of orientations passed to ``get_fisher_orientations``
            and number of x positions.
        n_phases: forwarded to ``get_fisher_orientations`` as ``n_images``.
        generator: stimulus generator forwarded to ``get_fisher_orientations``.
            When ``None`` the notebook-level ``generator_gabor`` is used.
        savefig: optional path; when given the figure is saved there.
    """
    # The argument used to be ignored in favour of the global generator_gabor;
    # defaulting to it keeps every existing call unchanged.
    if generator is None:
        generator = generator_gabor

    cs = sns.color_palette('plasma', N)
    cmap = mpl.colors.LinearSegmentedColormap.from_list(
        'Custom cmap', cs, len(cs))
    angles = np.linspace(0, 180, n_angles)

    def _normalized_sqrt_fisher(net, layer):
        # Normalized sqrt Fisher curve of one layer of one network.
        fishers = get_fisher_orientations(net, layer, n_angles, n_images=n_phases,
                                          generator=generator, delta=1e-2)
        sqrt_fishers = np.sqrt(torch.stack(fishers).cpu().numpy())
        return sqrt_fishers / np.sum(sqrt_fishers)

    diff = None
    for i in tqdm(range(N)):
        # Evaluate the reference model first, then the rotated one (original order).
        normed = _normalized_sqrt_fisher(model, i)
        normed_rot = _normalized_sqrt_fisher(model_rotated, i)
        diff = normed_rot - normed
        plt.plot(angles, diff, "-", label="Layer {}".format(i + 1), c=cs[i])

    # NOTE(review): the symmetric y-limits are derived from the last plotted
    # layer only, as in the original.
    if diff is not None:
        bound = max(max(diff), -min(diff))
        plt.ylim(bottom=-bound, top=bound)
    plt.plot([0, 180], [0, 0], 'k-')
    plt.ylabel(r"$\Delta\sqrt{J(\theta)}$ (normalized)", fontsize=15)
    plt.xlabel("Angle (º)", fontsize=15)
    plt.xticks(np.linspace(0, 180, 5))

    # The mappable is not attached to any axes, so name the axes to take space
    # from explicitly; newer Matplotlib releases refuse to guess.
    layer_norm = mpl.colors.Normalize(vmin=0, vmax=N)
    clb = plt.colorbar(cm.ScalarMappable(norm=layer_norm, cmap=cmap), ax=plt.gca())
    clb.ax.set_title('Layer')
    plt.tight_layout()
    if savefig is not None:
        plt.savefig(savefig)
    plt.show()

In [None]:
# Reference: ResNet-18 with Torchvision's pretrained weights.
model = models.resnet18(pretrained=True)
model.cuda().eval()

# Comparison: ResNet-18 restored from a local training checkpoint.
model_rotated = models.resnet18(pretrained=False)
model_rotated.cuda().eval()
checkpoint = torch.load('/data/abenjamin/DNN_illusions/resnet18_rotated/checkpoint_epoch_90.pth.tar')
model_rotated.load_state_dict(checkpoint['state_dict'])

plot_rotated_difference(model, model_rotated, N=10, n_angles=180, n_phases=10,
                        savefig="figures/resnet18_rotated_diff.pdf")


In [None]:
# Reference: VGG16-BN with Torchvision's pretrained weights.
model = models.vgg16_bn(pretrained=True)
model.cuda().eval()

# Comparison: VGG16-BN restored from a local training checkpoint.
model_rotated = models.vgg16_bn(pretrained=False)
model_rotated.cuda().eval()
checkpoint = torch.load('/data/abenjamin/DNN_illusions/vgg16_bn_rotated/checkpoint_epoch_90.pth.tar')
model_rotated.load_state_dict(checkpoint['state_dict'])

plot_rotated_difference(model, model_rotated, N=44, n_angles=180, n_phases=10,
                        savefig="figures/vgg16_bn_rotated_diff.pdf")
