In [1]:
# Re-import edited modules (e.g. the project's model / aux.utils) automatically before each cell runs.
%load_ext autoreload
%autoreload 2
# Render inline matplotlib figures as SVG.
%config InlineBackend.figure_format = 'svg'

## Import modules

In [2]:
import os
import sys
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch
import os
from torchsummary import summary
from torchvision import transforms                                                                                                                                        
import matplotlib as mpl
mpl.rcParams['figure.dpi'] = 600
from PIL import Image
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.ticker import LinearLocator, FormatStrFormatter
from mpl_toolkits.mplot3d import Axes3D
import pandas as pd
from scipy.interpolate import griddata
import numpy as np
import math
import plotly.graph_objects as go

sys.path.append("../../../src/")
sys.path.append("../../")
import model
from datasets import imagenet
from loss import FileterLoss
import config
from aux.utils import obtain_features_map, load_imgs, zscore, extract_valid
from aux.visualization import visualize_features_map_for_comparision
from utils.function import timethis

# 1. Load Model

In [3]:
# Checkpoint to restore, written as "<experiment>-<epoch>".
resume = "037-0"
# Directory holding saved checkpoints. Set the MODEL_DIR environment variable to
# point elsewhere instead of editing this machine-specific default.
model_dir = os.environ.get("MODEL_DIR", "/home/lincolnzjx/Desktop/Interpretation/saved/models")
backbone = "vgg16"
# Load model
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net = model.Network(backbone=backbone, num_classes=1000)
net.to(device)
# The network is only used for inference below; switch any dropout /
# batch-norm layers to evaluation behaviour (loading weights does not change the mode).
net.eval()
# resume from model: split once, keeping the first two "-"-separated parts
resume_exp, resume_epoch = resume.split("-")[:2]
print("Resume from model from exp: {} at epoch {}".format(resume_exp, resume_epoch))
resume_path = os.path.join(model_dir, resume_exp, resume_epoch)
ckpt = torch.load(resume_path, map_location=device)
net.load_state_dict(ckpt)

> Use original fc
Resume from model from exp: 037 at epoch 0


<All keys matched successfully>

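# Color Point Generation
* RGB points: a regular grid over the 256 × 256 × 256 colour cube, sampled every `step` levels per channel
* Feature-map points: the selected filter's response to each solid-colour image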

In [4]:
def fill_AArray_under_resolution(value, width, height):
    """Build a constant 2-D plane.

    Returns an array of shape ``(width, height)`` in which every element is
    ``value``; the dtype is whatever NumPy infers for ``value``.
    """
    return np.full((width, height), value)

def get_rgb_points_by_batch(batch_size=96, width=224, height=224, step=8):
    """Yield batches of solid-colour images sampled on a grid over the RGB cube.

    Each channel takes the levels ``range(0, 256, step)``. Colours are
    enumerated with red as the outermost loop and blue as the innermost.

    Args:
        batch_size: Maximum number of colours per batch. The last batch holds
            whatever remains, so every grid colour is yielded exactly once.
        width, height: Spatial size of each image; images have shape
            ``(width, height, 3)``.
        step: Grid spacing per channel.

    Yields:
        ``(imgs, xs, ys, zs)``:
        - ``imgs`` is a float32 array of shape ``(n, width, height, 3)`` whose
          pixels all equal the colour's ``(R, G, B)`` levels (0-255).
        - ``xs``, ``ys``, ``zs`` are lists of the ``n`` red, green and blue
          levels (Python ints).

    Raises:
        ValueError: if ``batch_size`` < 1 (raised on first iteration).
    """
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1, got {}".format(batch_size))
    channel_max = 256
    levels = range(0, channel_max, step)

    def _to_batch(colors):
        # Broadcast every (R, G, B) triple across the spatial grid, then copy so
        # the result is an ordinary contiguous float32 array.
        rgb = np.asarray(colors, dtype=np.float32).reshape(-1, 1, 1, 3)
        imgs = np.broadcast_to(rgb, (len(colors), width, height, 3)).copy()
        reds, greens, blues = (list(channel) for channel in zip(*colors))
        return imgs, reds, greens, blues

    pending = []
    for R in levels:
        for G in levels:
            for B in levels:
                pending.append((R, G, B))
                if len(pending) == batch_size:
                    yield _to_batch(pending)
                    pending = []
    # Flush the final partial batch. A float estimate of the grid size is not
    # reliable when step does not divide 255, so those colours would otherwise be lost.
    if pending:
        yield _to_batch(pending)

In [5]:
def obtain_selected_4D_featureMap(layer_output_indexes=None, 
                                  selected_filter=None, 
                                  batch_size=96, 
                                  step=8,
                                  method="max"):
    """Sweep the RGB grid and summarise one filter's response to each solid colour.

    Every colour produced by ``get_rgb_points_by_batch(batch_size, step=step)``
    is z-scored, pushed through ``net.model.features`` via ``obtain_features_map``,
    and the chosen channel of the first returned layer output is reduced over
    axes (1, 2) to one number per colour.

    Args:
        layer_output_indexes: Forwarded to ``obtain_features_map``; only the
            first returned output (``layer_output[0]``) is used.
        selected_filter: Channel picked with ``[:, selected_filter]``.
            NOTE(review): assumes outputs are (batch, channel, H, W) - confirm.
        batch_size: Colours per forward pass.
        step: RGB grid spacing per channel.
        method: Reduction over the spatial map: "max", "median" or "mean".

    Returns:
        ``(xs, ys, zs, rets)``:
        - ``xs``, ``ys``, ``zs`` are float32 arrays of the red, green and blue
          levels.
        - ``rets`` is the array of per-colour reductions, aligned with them.

    Raises:
        ValueError: if ``method`` is not one of the supported reductions.
    """
    reducers = {"max": np.max, "median": np.median, "mean": np.mean}
    if method not in reducers:
        raise ValueError("No method {!r}; expected one of {}".format(method, sorted(reducers)))
    sel = reducers[method]

    # ImageNet channel statistics handed to zscore.
    # NOTE(review): these are [0, 1]-scale statistics while the generated images
    # hold 0-255 levels; confirm zscore rescales its input accordingly.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    dataloader = get_rgb_points_by_batch(batch_size=batch_size, step=step)
    # Exact number of grid colours, used only for the progress display.
    n_points = len(range(0, 256, step)) ** 3
    all_iterations = math.ceil(n_points / batch_size)
    # xs => r, ys => g, zs => b
    xs, ys, zs, rets = [], [], [], []
    for index, (imgs, x, y, z) in enumerate(dataloader):
        if index % 10 == 0:
            print("[{}/{}]".format(index, all_iterations))
        xs.extend(x)
        ys.extend(y)
        zs.extend(z)
        data = torch.tensor(zscore(imgs, mean, std)).to(device)
        layer_output, _ = obtain_features_map(data, 
                                              net.model.features, 
                                              layer_output_indexes=layer_output_indexes) 
        rets.extend(sel(layer_output[0][:, selected_filter], axis=(1, 2)))
        # Drop this batch's tensors before the next batch is built, bounding memory use.
        del layer_output, data
    rets = np.array(rets)
    xs = np.array(xs, dtype=np.float32)
    ys = np.array(ys, dtype=np.float32)
    zs = np.array(zs, dtype=np.float32)
    return xs, ys, zs, rets

In [65]:
def plot4DFigure(X=None, Y=None, Z=None, values=None,color="RdBu", 
                 normalize=False, set_one=0, exp_activation=False,
                 title=None, opacity=0.9):
    """Render per-colour values as a plotly volume over the RGB cube.

    Args:
        X, Y, Z: Red, green and blue coordinates of each sample (NumPy arrays,
            flattened before plotting; not modified).
        values: Value for each (X, Y, Z) sample (NumPy array; not modified).
        color: Plotly colour-scale name.
        normalize: Min-max scale ``values`` to [0, 1] first. A constant array
            is mapped to all zeros rather than NaNs.
        set_one: If truthy, replace every non-zero value with 1. This is applied
            after ``normalize`` and ``exp_activation``.
        exp_activation: Apply ``np.exp`` to the values (after ``normalize``).
        title: Figure title.
        opacity: Maximum opacity of the volume rendering.
    """
    values = values.copy()  # work on a copy so the caller's array is untouched
    if normalize:
        min_value = np.min(values)
        value_range = np.max(values) - min_value
        if value_range:
            values = (values - min_value) / value_range
        else:
            # Constant input: (v - min) / 0 would turn every value into NaN.
            values = np.zeros_like(values, dtype=float)
    if exp_activation:
        values = np.exp(values)
    if set_one:
        values[values != 0] = 1
    max_value = np.max(values)
    fig = go.Figure(
        data=go.Volume(
            x=X.flatten(),
            y=Y.flatten(),
            z=Z.flatten(),
            value=values.flatten(),
            isomin=0.0,
            isomax=max_value,
            opacity=opacity,  # max opacity
            colorscale=color,
            surface_count=100,  # needs to be a large number for good volume rendering
            caps=dict(x_show=False, y_show=False, z_show=False),  # no caps
        ))
    # Every axis spans the 0-255 channel range with the same tick density.
    fig.update_layout(
        title=title,
        scene=dict(
            xaxis=dict(title='Red', nticks=4, range=[0, 255]),
            yaxis=dict(title='Green', nticks=4, range=[0, 255]),
            zaxis=dict(title='Blue', nticks=4, range=[0, 255])))
    fig.show()

# Plot 4D Figure

### Hyper-parameter settings

In [32]:
# Shared by every sweep and plot below.
step = 15  # RGB grid spacing per channel; values tried: 5, 15, 17, 51 (all divide 255 evenly)
batch_size = 256  # colours per forward pass
selected_filter = 20  # output channel to summarise
layer_output_indexes = [3]  # presumably position(s) in net.model.features to read - confirm against obtain_features_map

### Max

In [10]:
# Sweep the RGB grid, reducing the selected filter's map to its maximum.
method = "max"
xs, ys, zs, rets = obtain_selected_4D_featureMap(
    layer_output_indexes=layer_output_indexes,
    selected_filter=selected_filter,
    batch_size=batch_size,
    step=step,
    method=method,
)

[0/23]
[10/23]
[20/23]


In [63]:
# Title is built from the hyper-parameter cell so it names the layer/filter actually plotted
# (the sweep uses the first requested layer output).
plot4DFigure(X=xs, Y=ys, Z=zs, values=rets, set_one=False, normalize=False,
             title="RGB-Activation by layer {} filter {} via {} and set one {}".format(
                 layer_output_indexes[0], selected_filter, method, False))

In [67]:
# Title is built from the hyper-parameter cell so it names the layer/filter actually plotted
# (the sweep uses the first requested layer output).
plot4DFigure(X=xs, Y=ys, Z=zs, values=rets, set_one=True, opacity=1,
             title="RGB-Activation by layer {} filter {} via {} and set one {}".format(
                 layer_output_indexes[0], selected_filter, method, True))

### Mean

In [13]:
# Sweep the RGB grid, reducing the selected filter's map to its mean.
method = "mean"
xs, ys, zs, rets = obtain_selected_4D_featureMap(
    layer_output_indexes=layer_output_indexes,
    selected_filter=selected_filter,
    batch_size=batch_size,
    step=step,
    method=method,
)

[0/23]
[10/23]
[20/23]


In [14]:
# Title is built from the hyper-parameter cell so it names the layer/filter actually plotted
# (the sweep uses the first requested layer output).
plot4DFigure(X=xs, Y=ys, Z=zs, values=rets, set_one=False,
             title="RGB-Activation by layer {} filter {} via {} and set one {}".format(
                 layer_output_indexes[0], selected_filter, method, False))

In [15]:
# Title is built from the hyper-parameter cell so it names the layer/filter actually plotted
# (the sweep uses the first requested layer output).
plot4DFigure(X=xs, Y=ys, Z=zs, values=rets, set_one=True,
             title="RGB-Activation by layer {} filter {} via {} and set one {}".format(
                 layer_output_indexes[0], selected_filter, method, True))

### Median

In [16]:
# Sweep the RGB grid, reducing the selected filter's map to its median.
method = "median"
xs, ys, zs, rets = obtain_selected_4D_featureMap(
    layer_output_indexes=layer_output_indexes,
    selected_filter=selected_filter,
    batch_size=batch_size,
    step=step,
    method=method,
)

[0/23]
[10/23]
[20/23]


In [17]:
# Title is built from the hyper-parameter cell so it names the layer/filter actually plotted
# (the sweep uses the first requested layer output).
plot4DFigure(X=xs, Y=ys, Z=zs, values=rets, set_one=False,
             title="RGB-Activation by layer {} filter {} via {} and set one {}".format(
                 layer_output_indexes[0], selected_filter, method, False))

In [18]:
# Title is built from the hyper-parameter cell so it names the layer/filter actually plotted
# (the sweep uses the first requested layer output).
plot4DFigure(X=xs, Y=ys, Z=zs, values=rets, set_one=True,
             title="RGB-Activation by layer {} filter {} via {} and set one {}".format(
                 layer_output_indexes[0], selected_filter, method, True))