# Sensitivity of deep networks to hue (HSV)

In [None]:
import numpy as np
import torch
from matplotlib import pyplot as plt
from tqdm import tqdm as tqdm
from matplotlib import cm
import matplotlib as mpl
import pickle
import pandas as pd
%matplotlib inline

import seaborn as sns

# Matplotlib 3.6 renamed its bundled seaborn styles to "seaborn-v0_8-*" and
# later releases dropped the old names, so use whichever this install provides.
white_style = ('seaborn-v0_8-white' if 'seaborn-v0_8-white' in plt.style.available
               else 'seaborn-white')
plt.style.use(white_style)
plt.rcParams['axes.labelsize'] = 25
plt.rcParams['ytick.labelsize'] = 15.0
plt.rcParams['xtick.labelsize'] = 15.0

import sys
# Put the parent directory first on the import path so the project-local
# `scripts.fisher_calculators` package can be imported; assumes the notebook is
# run from a subdirectory of the repository root — TODO confirm.
sys.path.insert(0, "..")

from scripts.fisher_calculators import get_fisher_orientations, get_fisher_hues

from torchvision import models, transforms, utils

import colorsys
from colorspacious import cspace_convert

In [None]:
def generator_hsv(hue):
    """Return a uniform 3x224x224 float32 image of a single HSV colour.

    ``hue`` is given in degrees; saturation is fixed at 1 and value at 0.8.
    """
    rgb = colorsys.hsv_to_rgb(hue / 360., 1., .8)
    # Build the image as height x width x channel, then move channels first.
    pixels = np.full((224, 224, 3), rgb)
    return torch.from_numpy(pixels).permute(2, 0, 1).float()

In [None]:
def generator_hsv_marginalized(hue):
    """Return a batch of 64 uniform 3x224x224 float32 images sharing one hue.

    ``hue`` is given in degrees. Saturation and value each take 8 evenly spaced
    levels in [0.5, 0.9]; batch index ``8 * v_index + s_index`` holds the image
    with the ``v_index``-th value and ``s_index``-th saturation.
    """
    levels = np.linspace(.5, .9, 8)
    h = hue / 360.
    grid = np.array([[colorsys.hsv_to_rgb(h, sat, val) for sat in levels]
                     for val in levels])
    # (8, 8, 3) colour grid -> (64, 3, 1, 1), broadcast up to full images.
    batch = grid.reshape(64, 3, 1, 1) * np.ones((3, 224, 224))
    return torch.from_numpy(batch).float()

In [None]:
# Preview the stimulus colours every 2 degrees of hue.
hue_swatches = [generator_hsv(hue)[:, 0, 0].numpy() for hue in range(0, 360, 2)]
sns.palplot(hue_swatches)

In [None]:
def simpleaxis(ax):
    """Reduce ``ax`` to a bare bottom x axis.

    Hides the top, right and left spines and the whole y axis, keeps ticks on
    the bottom only, and sets the x tick length to 6.
    """
    for side in ('top', 'right', 'left'):
        ax.spines[side].set_visible(False)
    ax.get_xaxis().tick_bottom()
    ax.get_yaxis().set_visible(False)
    ax.xaxis.set_tick_params(size=6)

In [None]:
def get_fisher(model, title, N=10, n_colors=360, scale=None, savefig=None,
               generator=generator_hsv_marginalized):
    """Return the normalised sqrt-Fisher-information curve over hue for each layer.

    For every layer index in ``range(N)`` the per-hue values from
    ``get_fisher_hues`` are stacked, square-rooted and divided by their sum.

    Args:
        model: network forwarded to ``get_fisher_hues``.
        title, scale, savefig: unused; accepted so the signature matches
            ``plot_fisher``.
        N: number of layers (indices ``0..N-1``).
        n_colors: number of hue samples requested from ``get_fisher_hues``.
        generator: stimulus generator forwarded to ``get_fisher_hues``.

    Returns:
        A list of ``N`` numpy arrays, each summing to 1 (presumably one entry
        per hue sample — depends on ``get_fisher_hues``).
    """
    def _layer_curve(layer):
        per_hue = get_fisher_hues(model, layer, n_colors, delta=.1, generator=generator)
        root = np.sqrt(torch.stack(per_hue).cpu().numpy())
        return root / np.sum(root)

    return [_layer_curve(layer) for layer in tqdm(range(N))]


In [None]:
def plot_fisher(model, title, N=10, n_colors=360, scale=None, savefig=None,
                generator=generator_hsv_marginalized):
    """Compute and plot the normalised sqrt-Fisher-information curve over hue per layer.

    For each layer index ``i`` in ``range(N)`` the per-hue values from
    ``get_fisher_hues`` are stacked, square-rooted, normalised to sum to 1 and
    drawn against hue (0-360 degrees), coloured by layer on a plasma palette.

    Args:
        model: network forwarded to ``get_fisher_hues``.
        title: figure title.
        N: number of layers (indices ``0..N-1``) to plot.
        n_colors: number of hue samples requested from ``get_fisher_hues``.
        scale: optional ``(bottom, top)`` y-limits overriding the automatic ones.
        savefig: optional output path; the figure is saved there when given.
        generator: stimulus generator forwarded to ``get_fisher_hues``.
    """
    cs = sns.color_palette('plasma', N)
    cmap = mpl.colors.LinearSegmentedColormap.from_list('Custom cmap', cs, len(cs))
    hue_axis = np.linspace(0, 360, n_colors)

    top = 0.
    for i in tqdm(range(N)):
        sqrt_fisher = np.sqrt(torch.stack(get_fisher_hues(model, i, n_colors, delta=.1,
                                                          generator=generator)).cpu().numpy())
        normed = sqrt_fisher / np.sum(sqrt_fisher)
        plt.plot(hue_axis, normed, "-", label="Layer {}".format(i + 1), c=cs[i])
        # Track the y-limit over *all* layers: sizing it from the last layer
        # alone clipped earlier layers whose curves peak higher.
        top = max(top, np.max(normed), 2 * np.mean(normed))
    plt.ylim(bottom=0, top=top)
    plt.ylabel(r"$\sqrt{J(\theta)}$ (normalized)", fontsize=15)
    plt.xlabel("Hue (°)", fontsize=15)
    plt.xticks(np.linspace(0, 360, 5))

    if scale is not None:
        plt.ylim(scale)

    plt.title(title, fontsize=20)
    # Centre each of the N colour bands on its 1-based layer number, and attach
    # the colorbar to the current axes explicitly: the ScalarMappable is not on
    # any Axes, and newer Matplotlib raises instead of guessing where it goes.
    norm = mpl.colors.Normalize(vmin=.5, vmax=N + .5)
    clb = plt.colorbar(cm.ScalarMappable(norm=norm, cmap=cmap), ax=plt.gca())
    clb.ax.set_title('Layer')
    plt.tight_layout()
    if savefig is not None:
        plt.savefig(savefig)
    plt.show()

In [None]:
def plot_fisher_spectral(model, title, N=10, scale=None, savefig=None, generator=None):
    """Plot the normalised sqrt-Fisher-information curve over approximate wavelength per layer.

    Wavelengths 450-645 nm in 5 nm steps are passed to ``get_fisher_hues`` as
    ``hues``. For each layer the result is stacked, square-rooted, normalised to
    sum to 1 and drawn, coloured by layer on a plasma palette.

    Args:
        model: network forwarded to ``get_fisher_hues``.
        title: figure title.
        N: number of layers (indices ``0..N-1``) to plot.
        scale: optional ``(bottom, top)`` y-limits overriding the automatic ones.
        savefig: optional output path; the figure is saved there when given.
        generator: stimulus generator forwarded to ``get_fisher_hues``. When
            omitted, a global ``generator_spectral`` is used. That name is not
            defined or imported anywhere in this notebook, so either define it
            first or pass a generator explicitly; otherwise NameError is raised.
    """
    if generator is None:
        generator = generator_spectral  # noqa: F821 - must exist in the notebook namespace
    cs = sns.color_palette('plasma', N)
    cmap = mpl.colors.LinearSegmentedColormap.from_list('Custom cmap', cs, len(cs))
    wavelengths = range(450, 650, 5)

    top = 0.
    for i in tqdm(range(N)):
        sqrt_fisher = np.sqrt(torch.stack(get_fisher_hues(model, i, delta=.1, hues=wavelengths,
                                                          generator=generator)).cpu().numpy())
        normed = sqrt_fisher / np.sum(sqrt_fisher)
        plt.plot(wavelengths, normed, "-", label="Layer {}".format(i + 1), c=cs[i])
        # y-limit over all layers, not just the last one plotted.
        top = max(top, np.max(normed), 2 * np.mean(normed))
    plt.ylim(bottom=0, top=top)
    plt.ylabel(r"$\sqrt{J(\theta)}$ (normalized)", fontsize=15)
    plt.xlabel("Approx. wavelength (nm)", fontsize=15)
    plt.xticks(range(450, 651, 50))

    if scale is not None:
        plt.ylim(scale)

    plt.title(title, fontsize=20)
    # Centre each colour band on its 1-based layer number; pass ax explicitly
    # because the ScalarMappable is not attached to any Axes.
    norm = mpl.colors.Normalize(vmin=.5, vmax=N + .5)
    clb = plt.colorbar(cm.ScalarMappable(norm=norm, cmap=cmap), ax=plt.gca())
    clb.ax.set_title('Layer')
    plt.tight_layout()
    if savefig is not None:
        plt.savefig(savefig)
    plt.show()

## Perceptual distance between neighbouring hues in CIELAB space

In [None]:
# Perceptual (CIELAB) distance between neighbouring fully saturated, full-value
# HSV hues. Sample whole degrees 0..360 inclusive so that diff[k] is the
# distance between hue k° and hue (k+1)°. The former np.linspace(0, 1, 360)
# stepped by 360/359°, so the x axis did not match the hue in degrees.
hue_deg = np.arange(361)
colors = np.ones((len(hue_deg), 3))  # columns: H, S, V with S = V = 1
colors[:, 0] = hue_deg / 360.
# HSV -> sRGB in [0, 1] -> CIELab
colors_rgb = [colorsys.hsv_to_rgb(*c) for c in colors]
colors = cspace_convert(colors_rgb, "sRGB1", "CIELab")
# Euclidean distance between consecutive Lab colours (CIE76 Delta E).
diff = np.linalg.norm(np.diff(colors, axis=0), axis=1)

# Plot each step at the midpoint of the two hues it separates.
plt.plot(hue_deg[:-1] + .5, diff)
plt.ylabel("Perceptual distance", fontsize=20)
plt.xlabel("Hue in HSV (°)", fontsize=20)
plt.xticks(np.linspace(0, 360, 5))
plt.yticks([])
simpleaxis(plt.gca())
plt.xlim([0, 360])
plt.tight_layout()
plt.savefig("CIELAB_hue.pdf")
plt.show()

# ResNet-18

In [None]:
# torchvision's pretrained ResNet-18, moved to the GPU and put in eval mode.
model = models.resnet18(pretrained=True)
model = model.cuda().eval()
fishers = get_fisher(model, "Resnet18", N=10, n_colors=360, generator=generator_hsv)

In [None]:

def plot_precomputed_fishers(fishers, title, N=10, n_colors=360, scale=None, savefig=None,
                             generator=generator_hsv_marginalized):
    """Plot already-computed per-layer sqrt-Fisher curves against hue.

    Args:
        fishers: sequence of at least ``N`` 1-D curves (e.g. from ``get_fisher``),
            each of length ``n_colors``.
        title: figure title (``None`` leaves it empty).
        N: number of layers (``fishers[0..N-1]``) to plot.
        n_colors: number of hue samples per curve; hues span 0-360 degrees.
        scale: optional ``(bottom, top)`` y-limits overriding the automatic ones.
        savefig: optional output path; the figure is saved there when given.
        generator: unused; kept so the signature matches ``plot_fisher``.
    """
    cs = sns.color_palette('plasma', N)
    cmap = mpl.colors.LinearSegmentedColormap.from_list('Custom cmap', cs, len(cs))
    hue_axis = np.linspace(0, 360, n_colors)

    top = 0.
    for i in range(N):
        curve = np.asarray(fishers[i])
        plt.plot(hue_axis, curve, "-", label="Layer {}".format(i + 1), c=cs[i])
        # y-limit over all layers; using only the last one clipped the others.
        top = max(top, curve.max(), 2 * curve.mean())
    plt.ylim(bottom=0, top=top)
    plt.ylabel(r"$\sqrt{J(\theta)}$ (normalized)", fontsize=15)
    plt.xlabel("Hue (°)", fontsize=15)
    plt.xticks(np.linspace(0, 360, 5))

    if scale is not None:
        plt.ylim(scale)
    simpleaxis(plt.gca())
    plt.xlim([0, 360])
    plt.title(title, fontsize=20)
    # Centre each colour band on its 1-based layer number; pass ax explicitly
    # because the ScalarMappable is not attached to any Axes.
    norm = mpl.colors.Normalize(vmin=.5, vmax=N + .5)
    clb = plt.colorbar(cm.ScalarMappable(norm=norm, cmap=cmap), ax=plt.gca())
    clb.ax.set_title('Layer')
    plt.tight_layout()
    if savefig is not None:
        plt.savefig(savefig)
    plt.show()

In [None]:
plot_precomputed_fishers(fishers, None, N=10, n_colors=360,
                         generator=generator_hsv, savefig=None)

## ResNet-18 retrained on rotated hues

This section requires a retrained network; see the training script `train_rotated.py`.

In [None]:
# Load the epoch-90 checkpoint produced by train_rotated.py into the ResNet-18
# `model` created earlier, then recompute the per-layer hue curves.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
epoch = 90
path = ('/data/abenjamin/DNN_illusions/resnet18_rotated_hue/'
        'checkpoint_epoch_{}.pth.tar').format(epoch)
checkpoint = torch.load(path)
model.load_state_dict(checkpoint['state_dict'])
rotated_fishers = get_fisher(model, "Resnet18", N=10, n_colors=360,
                             generator=generator_hsv_marginalized)


In [None]:
plot_precomputed_fishers(
    rotated_fishers,
    "Resnet18",
    N=10,
    n_colors=360,
    generator=generator_hsv_marginalized,
    savefig="figures/resnet18_hue_rotated_marginalized.pdf",
)