In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np
import time
import os
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
from collections import defaultdict
import pickle

from color_regions import *
from network import *
from visualizations import *
from utils import *
from hooks import *
from config_objects import *
from training import *

# set up autoreloading of shared code
%load_ext autoreload
%autoreload 1
%aimport color_regions,network,visualizations,utils,hooks,config_objects,training
%aimport

# Let cuDNN benchmark the available convolution algorithms and reuse the fastest one.
torch.backends.cudnn.benchmark = True
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [None]:
# Index pairs handed to ColorDatasetGenerator for each split; the three ranges do not overlap.
# presumably (start, stop) sample indices — confirm against ColorDatasetGenerator
train_indices = (0, 250_000) # size of training set
valid_indices = (1_250_000, 1_275_000)
test_indices = (3_260_000, 3_560_000)

# 0, 30, 60, ..., 240
critical_color_values = list(range(0,241,30))

# Dataset settings shared by all three splits; `device` and `batch_size` travel with the config.
dset_config = ColorDatasetConfig(task_difficulty="hard",
                                 noise_size=(1,9),
                                 num_classes=3,
                                 num_objects=1,  # => permuted
                                 radius=(1/8., 1/7.),
                                 device=device,
                                 batch_size=128)

# copies the config each time
train_set = ColorDatasetGenerator(train_indices, dset_config)
valid_set = ColorDatasetGenerator(valid_indices, dset_config)
test_set = ColorDatasetGenerator(test_indices, dset_config)
# train_set.cfg.infinite = True

In [None]:
# the "hard" task
# Evaluate color_classifier on 255 evenly spaced values spanning [0, 255] and plot the
# class it assigns to each one.
# NOTE(review): linspace(0, 255, 255) steps by 255/254, so the probes are not integers;
# use 256 points if the integer gray levels 0..255 were intended.
color_probe = np.linspace(0, 255, 255)
color_class = [color_classifier(x) for x in color_probe]
plt.plot(color_probe, color_class)
plt.xlabel("Color")
plt.yticks([0, 1, 2])
plt.ylabel("Class")

In [None]:
# Configuration and construction of the small network studied in the rest of the notebook.
# The meaning of each layer_sizes triple is defined by ExperimentConfig/ResNet (not shown here).
tiny_config = ExperimentConfig(layer_sizes=[[2, 3, 4], [6, 3, 4]], 
                                    learn_rate=0.01, weight_decay=2e-03, 
                                    gain=0.05, epochs=50)
# The first argument is the checkpoint path used by load_model_state_dict below — TODO confirm.
tiny_net = ResNet("tiny_net_small_circles.dict", tiny_config, dset_config).to(device)
loss_func = nn.CrossEntropyLoss()
# NOTE(review): Adam is built with PyTorch defaults here, so learn_rate / weight_decay from
# tiny_config are not passed to the optimizer in this cell — confirm that is intended
# (load_model_state_dict may restore optimizer state).
tiny_optim = torch.optim.Adam(tiny_net.parameters())
print(tiny_net.num_params())
tiny_net.load_model_state_dict(optim=tiny_optim)

In [None]:
unstrided_config = ExperimentConfig(layer_sizes=[[16, 3, 1], [32, 3, 1]], 
                                    learn_rate=0.01, weight_decay=2e-03, 
                                    gain=0.05, epochs=50)

# unstrided_net = ResNet("models/corrected_unstrided_small_circles.dict", unstrided_config, dset_config).to(device)
# loss_func = nn.CrossEntropyLoss()
# unstrided_optim = torch.optim.Adam(unstrided_net.parameters())
# print(unstrided_net.num_params())
# unstrided_net.load_model_state_dict(optim=unstrided_optim)

In [None]:
# Build the unstrided network from unstrided_config, load its weights (no optimizer is
# passed here), then move it to the device stored in the dataset config.
unstrided_net = ResNet("small_circles_unpermuted/0.0001_8.858667904100833e-07_True_medium_size_0.1_25.dict",
                      unstrided_config, dset_config)
unstrided_net.load_model_state_dict()
unstrided_net.to(dset_config.device)

In [None]:
evaluate(tiny_net, nn.CrossEntropyLoss(), test_set)

In [None]:
evaluate(unstrided_net, nn.CrossEntropyLoss(), test_set)

In [None]:
import copy

# Shallow copy, so switching off global average pooling does not alter unstrided_config.
old_config = copy.copy(unstrided_config)
old_config.global_avg_pooling = False
# Build the older network from old_config so the override above actually takes effect
# (presumably the older checkpoint was trained without global average pooling — confirm).
old_unstrided_net = ResNet("models/unstrided_small_circles.dict",
                      old_config, dset_config)
print(old_unstrided_net)
old_unstrided_net.load_model_state_dict()
old_unstrided_net.to(dset_config.device)

In [None]:
print(unstrided_net)

In [None]:
# NOTE(review): `unstrided_optim`, `train_loader` and `valid_loader` are not assigned in any
# visible cell (`unstrided_optim` appears only in commented-out code) — confirm they exist
# before running this on a fresh kernel.
results = train(unstrided_net, unstrided_optim, loss_func, 40, train_loader, valid_loader, device=device,)

In [None]:
# NOTE(review): the earlier evaluate cells pass a dataset generator (test_set) and no `device`;
# this call passes `valid_loader` (not assigned in any visible cell) plus `device=` — confirm
# which call convention the current `evaluate` expects.
evaluate(unstrided_net, loss_func, valid_loader, device=device)

In [None]:
evaluate(tiny_net, loss_func, valid_loader, device=device)

In [None]:
results = train(tiny_net, tiny_optim, loss_func, 1000, train_loader, valid_loader, device=device)

In [None]:
interp_net = AllActivations(tiny_net)

In [None]:
# Fix NumPy's global RNG so the same validation image is drawn on every run.
np.random.seed(5_123_456)
# generate_one() returns the image first, then label and color; the last element is bound to
# `pos` and everything in between is collected into the list `size`.
test_img, lbl, color, *size, pos  = valid_set.generate_one()
print(color)
plt.imshow(test_img, cmap="gray")
# Tensor version of the image on the compute device, used as network input in later cells.
tensor_test_img = tensorize(test_img, device=device)
plt.xticks([])
plt.yticks([])

In [None]:
interp_net.eval()
interp_net(tensor_test_img)

In [None]:
# Activations stored by interp_net for block 0 / conv1 (presumably filled by the forward
# pass in the previous cell) and that layer's kernels; squeeze() drops size-1 axes.
first_conv = interp_net._features["conv_blocks.0.conv1"].detach().cpu().numpy().squeeze()
first_conv_weights = dict(tiny_net.named_modules())["conv_blocks.0.conv1"].weight.detach().cpu().numpy().squeeze()
print(dict(tiny_net.named_modules())["conv_blocks.0.conv1"].bias)
fig = plt.figure(figsize=(4*2, 5*2))
# 3x2 grid: input image (top-left), conv outputs (middle row), conv kernels (bottom row).
plt.subplot(3,2,1)
imshow_centered_colorbar(test_img, "bwr", "original_image")
plt.subplot(3,2,3)
imshow_centered_colorbar(first_conv[0], "bwr", "output conv1_0.0")
plt.subplot(3,2,4)
imshow_centered_colorbar(first_conv[1], "bwr", "output conv1_0.1")
plt.subplot(3,2,5)
imshow_centered_colorbar(first_conv_weights[0], "bwr", "weights of conv1_0.0")
plt.subplot(3,2,6)
imshow_centered_colorbar(first_conv_weights[1], "bwr", "weights of conv1_0.1")
# => the first layer basically just computes a compressed version of the original, twice

In [None]:
# Compare block 0 / conv1 outputs (left column) with the outputs of the batch norm that
# follows it (right column), and print that batch norm's learned and running statistics.
# Uses `first_conv` from the previous cell.
bn1_params = dict(tiny_net.named_modules())["conv_blocks.0.batch_norm1"]
print(bn1_params.weight, bn1_params.bias)
print(bn1_params.running_mean, bn1_params.running_var)
first_batchnorms = interp_net._features["conv_blocks.0.batch_norm1"].detach().cpu().numpy().squeeze()
fig = plt.figure(figsize=(6, 5))
plt.subplot(2,2,1)
imshow_centered_colorbar(first_conv[0], "bwr", "output conv1_0.0")
plt.subplot(2,2,3)  # conv{1,2}_{layer_num}.{channel_index}
imshow_centered_colorbar(first_conv[1], "bwr", "output of conv1_0.1")

plt.subplot(2,2,2)
imshow_centered_colorbar(first_batchnorms[0], "bwr", "output batchnorm1_0.0")
plt.subplot(2,2,4)
imshow_centered_colorbar(first_batchnorms[1], "bwr", "output batchnorm1_0.1")
# The conv response to the circle (conv1 just preserves the circle's shape) must exceed
# the bias, otherwise it gets zeroed out => this gives us one boundary on the color.
# E.g. look at channel 1: we multiply the raw value by 7.5, then subtract 270
# (note that the bias on channel 1 is basically 0), and divide by 553,
# then multiply by 1 and subtract 0.8176 => any color value above -23 will be > 0.
# For channel 0, it turns out any color value above +25 will be > 0 => we are already
# separating on that first non-linearity.

In [None]:
# Activations and kernels of the second convolution in block 0.
second_conv = interp_net._features["conv_blocks.0.conv2"].detach().cpu().numpy().squeeze()
second_conv_weights = dict(tiny_net.named_modules())["conv_blocks.0.conv2"].weight.detach().cpu().numpy().squeeze()
fig = plt.figure(figsize=(6, 7))

# One column per output channel m: its activation map on top, then the kernels applied to
# input channels 0 and 1.
for m in range(2):
    plt.subplot(3,2,m+1)
    # Titles follow conv{1,2}_{layer_num}.{channel_index}; this layer lives in block 0, so the
    # label is conv2_0 (conv2_1 is the block-1 layer plotted in a later cell).
    imshow_centered_colorbar(second_conv[m], "bwr", f"out conv2_0.{m}")
    plt.subplot(3,2,m+3)
    imshow_centered_colorbar(second_conv_weights[m][0], "bwr", f"w 2_0.0->{m}")
    plt.subplot(3,2,m+5)
    imshow_centered_colorbar(second_conv_weights[m][1], "bwr", f"w 2_0.1->{m}")
# both paths have a "just recompute/compress the image (identity mapping learned?)", though
# 1 shifts it up a bit (not sure how relevant this is, but you can actually see it in the image)
# very curve detector-like filters as well in both paths
# so channel 0 is upper-right curves, unsure what the bright pixel in lower left of w_2_0.0 is
# but the other path doesn't have it, so maybe not important?

In [None]:
# Fold conv2 + batch_norm2 of one block into a single affine map (slope, bias) for output
# channel `c`, then apply it to the per-channel color profiles of the preceding activation.
# NOTE(review): `profile_plots` is only assigned by later cells, so this cell fails on a
# fresh top-to-bottom run.
c = 0
block = 1

#uniform_inpt = torch.full((1,16,32,32), 100.0).to(device)
#plt.imshow(tiny_net.conv_blocks[0].conv2.weight[c, in_c].detach().cpu().numpy(), cmap="bwr")
# Kernels feeding output channel c, one per input channel.
conv_maps = tiny_net.conv_blocks[block].conv2.weight[c, :]
#imshow_centered_colorbar(conv_maps[7].detach().cpu().numpy(), cmap="bwr")
# Largest weight of each kernel (max over both spatial axes), used as that channel's gain.
# NOTE(review): for a spatially uniform input the conv gain would be the kernel sum rather
# than its max — confirm max is what is wanted here.
conv_scale = conv_maps.max(axis=-1).values.max(axis=-1).values
conv_shift = tiny_net.conv_blocks[block].conv2.bias[c]
bn_scale = tiny_net.conv_blocks[block].batch_norm2.weight[c]
bn_shift = tiny_net.conv_blocks[block].batch_norm2.bias[c]
bn_var = tiny_net.conv_blocks[block].batch_norm2.running_var[c]
bn_mean = tiny_net.conv_blocks[block].batch_norm2.running_mean[c]
print(conv_shift, bn_scale, bn_shift, bn_var, bn_mean)
#(c*conv_scale + conv_shift - bn_mean) / torch.sqrt(bn_var) * bn_scale + bn_shift
# The batch norm's eps term is not included in the sqrt below.
slope = (conv_scale/torch.sqrt(bn_var)*bn_scale).detach().cpu().numpy()
bias = ((conv_shift - bn_mean)/torch.sqrt(bn_var)*bn_scale + bn_shift).detach().cpu().numpy()

# First entry of each channel's profile for the block's first activation; the 6 is the
# hard-coded channel count of that layer.
lines = np.asarray([profile_plots[f"conv_blocks.{block}.act_func1_{x}"][0] for x in range(6)])

uniform_scaling = slope.dot(lines) + bias


In [None]:
unstrided_net.eval()
profile_plots,_ = activation_color_profile(AllActivations(unstrided_net), valid_set)

In [None]:
interp_net.eval()
profile_plots,_ = activation_color_profile(AllActivations(tiny_net), valid_set)

In [None]:
show_conv_weights(tiny_net, "conv_blocks.0.act_func1", color_profile=profile_plots)

In [None]:
show_conv_weights(interp_net, "conv_blocks.0.act_func1", color_profile=profile_plots)

In [None]:
show_conv_weights(interp_net, "conv_blocks.0.act_func2", color_profile=profile_plots)

In [None]:
show_conv_weights(interp_net, "conv_blocks.1.act_func1", color_profile=profile_plots)

In [None]:
plt.rcParams.update({'font.size': 12})
show_profile_plots(profile_plots, "conv_blocks.1.act_func2", size_mul=0.7, 
                   fixed_height=False, rm_border=False, hide_ticks=False)
plt.savefig("mid_conv_intensity.png")

In [None]:
show_conv_weights(interp_net, "conv_blocks.1.act_func2", color_profile=profile_plots, size_mul=2.67, fixed_height=True)

In [None]:
show_conv_layer(interp_net, "conv_blocks.1.act_func2")

In [None]:
# Activations and kernels of block 1 / conv1.
# NOTE: this re-binds `second_conv` / `second_conv_weights`, which earlier held block-0 data;
# later cells read whichever assignment ran last.
second_conv = interp_net._features["conv_blocks.1.conv1"].detach().cpu().numpy().squeeze()
second_conv_weights = dict(tiny_net.named_modules())["conv_blocks.1.conv1"].weight.detach().cpu().numpy().squeeze()
fig = plt.figure(figsize=(14, 7))

# One column per output channel m: activation map, then the kernels for input channels 0 and 1.
for m in range(6):
    plt.subplot(3,6,m+1)
    imshow_centered_colorbar(second_conv[m], "bwr", f"out conv1_1.{m}")
    plt.subplot(3,6,m+7)
    imshow_centered_colorbar(second_conv_weights[m][0], "bwr", f"w 1_1.0->{m}")
    plt.subplot(3,6,m+13)
    imshow_centered_colorbar(second_conv_weights[m][1], "bwr", f"w 1_1.1->{m}")


In [None]:
# Block 1 / batch_norm1: print its parameters and running statistics, then plot the conv1
# outputs (top row, from `second_conv` set in the previous cell) above the batch-norm
# outputs (bottom row).
bn3_params = dict(tiny_net.named_modules())["conv_blocks.1.batch_norm1"]
print(bn3_params.weight, bn3_params.bias)
print(bn3_params.running_mean, bn3_params.running_var)
third_batchnorms = interp_net._features["conv_blocks.1.batch_norm1"].detach().cpu().numpy().squeeze()
fig = plt.figure(figsize=(16, 5))
for m in range(6):
    plt.subplot(2,6,m+1)
    imshow_centered_colorbar(second_conv[m], "bwr", f"output conv1_1.{m}")
    plt.subplot(2,6,m+7)
    imshow_centered_colorbar(third_batchnorms[m], "bwr", f"output batchnorm1_1.{m}")
# It appears the only important channel at this point is 2. Channels 0 and 1 look like they
# were close to being important, but failed some color check. Channel 5 I don't really
# understand, since it appears to have picked up some signal that wasn't there before
# (I suppose the mean is negative and the scale is larger than 1, so it would expand any
# slight differences that existed but weren't visible?). Channel 4, I think, is also trying
# to be a circle finder (upper right?), but failed the color check as well. Channel 3 also
# looks like it just barely failed the color check. Actually, looking at channel 1 again,
# I expect its output after a ReLU would look exactly like channel 4 does right now, so
# channel 4 is definitely a "failed color check".

In [None]:
# Activations and kernels of block 1 / conv2 (re-binds `second_conv` again).
second_conv = interp_net._features["conv_blocks.1.conv2"].detach().cpu().numpy().squeeze()
second_conv_weights = dict(tiny_net.named_modules())["conv_blocks.1.conv2"].weight.detach().cpu().numpy().squeeze()
fig = plt.figure(figsize=(15, 10))

# 7x6 grid: column m is output channel m; row 0 is its activation map and row k+1 is the
# kernel applied to input channel k.
for m in range(6):
    plt.subplot(7,6,m+1)
    imshow_centered_colorbar(second_conv[m], "bwr", f"out conv2_1.{m}")
    for k in range(6):
        plt.subplot(7,6,m+1+(k+1)*6)
        imshow_centered_colorbar(second_conv_weights[m][k], "bwr", f"w 2_1.{k}->{m}",
                                colorbar=True)
    # Channel 0 is basically saying "cancel out everything except for channel 4 in the prev layer",
    # so it should basically copy its value (which it does). Channel 2 is similar, though it
    # appears to copy from channel 1, and from 4 a bit. At the end of it, channel 4 ends
    # up being the most active, since it has that strong positive edge detector with
    # channel 2 in the previous layer. Channel 1 also does decently well, but its circle
    # has been thoroughly zeroed out, and only an "artifact-like" row of brightness
    # remains at the top edge.

In [None]:
# Block 1 / batch_norm2: print its parameters and running statistics, then plot the conv2
# outputs (top row, from `second_conv` set in the previous cell) above the batch-norm
# outputs (bottom row).
bn4_params = dict(tiny_net.named_modules())["conv_blocks.1.batch_norm2"]
print(bn4_params.weight, bn4_params.bias)
print(bn4_params.running_mean, bn4_params.running_var)
fourth_batchnorms = interp_net._features["conv_blocks.1.batch_norm2"].detach().cpu().numpy().squeeze()
fig = plt.figure(figsize=(16, 5))
for m in range(6):
    plt.subplot(2,6,m+1)
    imshow_centered_colorbar(second_conv[m], "bwr", f"output conv2_1.{m}")
    plt.subplot(2,6,m+7)
    imshow_centered_colorbar(fourth_batchnorms[m], "bwr", f"output batchnorm2_1.{m}")
# So again, we see somewhat of a "direction reversal" in channel 5 (pretty much because of
# the positive bias, compared to the other biases, which are all negative), but channel 4
# mostly seems to be the winner here. The "artifact" bright top row of channel 1 is mostly
# negated (we actually see those weird rows in multiple conv maps here; could be an artifact
# of the padding/striding method maybe?)

In [None]:
# Weights of the first fully connected layer, viewed per class as 6 channel maps of 8x8
# (the reshape below hard-codes that layout).
fc_weights = dict(tiny_net.named_modules())["fully_connected.0.fully_connected"].weight.detach().cpu().numpy().squeeze()
fig = plt.figure(figsize=(14, 7))

# 4x6 grid: row 0 is the batch_norm2 output of channel m; rows 1-3 are the FC weight maps of
# classes 0-2 for that channel, titled with the channel's contribution to that class.
for m in range(6):
    plt.subplot(4,6,m+1)
    imshow_centered_colorbar(fourth_batchnorms[m], "bwr", f"out batchnorm2_1.{m}") # no ReLU
    for fc_m in range(3):
        fc_shaped = fc_weights[fc_m].reshape(6,8,8)[m]
        # np.where applies the ReLU by hand before weighting and summing.
        # NOTE: this binds the notebook-global name `result` to a scalar.
        result = (np.where(fourth_batchnorms[m]>0, fourth_batchnorms[m], 0)*fc_shaped).sum()
        plt.subplot(4,6,m+1+(fc_m+1)*6)
        imshow_centered_colorbar(fc_shaped, "bwr", f"{result:.2f}")
    # The stupid edge lines actually seem to be getting used somehow (see the bottom row, which
    # is used to predict class 2). Some of these maps are just "find a circle-ish thing in
    # the center". It probably makes sense that the "best" place to put your circle checker is
    # right in the middle, because most circles at least overlap the middle, due
    # to the data generation process. Some of these maps appear to do nothing, e.g.
    # the map for predicting class 1 ignores channel 4. Although maybe there is some
    # "antipodal" symmetry between class 1 channel 4 and class 0 channel 4 => channel 4
    # gives a lot of info for class 1??, though I'm not sure why you would only highlight
    # one pixel inside them (we see the same pattern used in channel 2, and actually in a lot of
    # the channels) => channel 0 is like "positive evidence for class 1, negative evidence for
    # class 0"

    # Note that it's actually the same classes that are in superposition:
    # e.g. for classes 0,1 we have superposition in channels 0 and 4
    #      for classes 0,2 we have superposition in channels 1,2,3,5

    # Also, we arguably have a "1-map" type thing occurring in many of the channels. For example,
    # in channel 4, it sort of looks like that for class 0 and class 2,

In [None]:
show_fc_conv(interp_net, color_profile=profile_plots, fixed_height=True, full_gridspec=True)

In [None]:
tiny_net.fully_connected[0].fully_connected.bias

In [None]:
plt.rcParams.update({'font.size': 12})
plt.figure(figsize=(8,8))
show_fc_conv(interp_net, color_profile=profile_plots, fixed_height=True, full_gridspec=True)
plt.savefig("intensity_profile.pdf")

In [None]:
%matplotlib notebook
feature_gram, projected_weights = visualizations.fc_conv_feature_angles(tiny_net, 
                            "fully_connected.0.act_func", num_embed=3, normalize=True)

# Which pixels do we care about more?

In [None]:
%matplotlib inline
def patch_pixel_response():
    """Gradient of the true-class logit w.r.t. one random pixel inside each noise region.

    Draws one image from ``valid_set``, picks one random location inside every noise region
    whose color is above 5 and whose size is above 1, discards locations too close to the
    image border, and differentiates ``tiny_net``'s logit for the labelled class with
    respect to a zero-valued perturbation added at each remaining location.

    Returns:
        (sizes, grads): ``sizes`` holds the size of the noise region each kept location was
        drawn from; ``grads`` is a ``(num_locs, 1)`` tensor with the absolute gradient at
        each location.
    """
    img, lbl, *_, (noise_size, noise_clr, noise_loc) = valid_set.generate_one()
    # Regions that pass both the color threshold and the minimum-size threshold.
    mask = np.where((noise_clr > 5) & (noise_size > 1))
    # Random 2-D offset inside each selected region, bounded by that region's size.
    offsets = np.random.randint(0, noise_size[mask,None],
                                (len(noise_size[mask]),2))
    selection = noise_loc[mask] + offsets
    # Keep only locations whose two coordinates both lie in
    # [2*radius[1], size - 2*radius[1]].
    unif = 2*valid_set.radius[1], valid_set.size-2*valid_set.radius[1]
    interior_mask = np.where((unif[0] <= selection[:,0]) & (selection[:,0] <= unif[1]) &
                             (unif[0] <= selection[:,1]) & (selection[:,1] <= unif[1]))
    selection = selection[interior_mask]

    # One copy of the image per probed location, so every location gets its own gradient.
    num_locs = len(selection)
    stacked_img = np.repeat(img[None,...], num_locs, axis=0)
    tensor_img = tensorize(stacked_img, device=device)
    # Zero perturbation created directly on the target device, so it is a leaf tensor there
    # (creating it on the CPU and calling .to() yields a non-leaf copy on CUDA).
    zeros = torch.zeros((num_locs,1), requires_grad=True, device=device)
    # Adding zeros leaves the image unchanged but links each probed pixel to `zeros` in the
    # autograd graph.
    tensor_img[np.arange(num_locs), :, selection[:,0], selection[:,1]] += zeros
    response = tiny_net(tensor_img, logits=True)[:,lbl.argmax()].sum()
    grads = torch.autograd.grad(response, zeros)[0]
    sizes = noise_size[mask][interior_mask]
    return sizes, abs(grads)

def avg_patch_response(runs=1000):
    """Collect the values returned by ``patch_pixel_response``, bucketed by region size.

    Puts ``tiny_net`` in eval mode, calls ``patch_pixel_response`` ``runs`` times and files
    every returned value under the size it was reported with.

    Args:
        runs: number of calls to ``patch_pixel_response``.

    Returns:
        A list of lists; bucket ``k`` holds the values recorded for size ``k + 2``.
    """
    tiny_net.eval()
    size_bounds = valid_set.noise_size
    buckets = [[] for _ in range(size_bounds[1] - size_bounds[0] - 1)]

    for _ in tqdm(range(runs)):
        region_sizes, responses = patch_pixel_response()
        responses_np = responses.cpu().numpy()
        for region_size, response in zip(region_sizes, responses_np):
            # sizes 0 and 1 are ignored, so size 2 lands in bucket 0
            buckets[region_size-2].append(response)
    return buckets
result4 = avg_patch_response(runs=1_000_000)

In [None]:
# Append the samples of `result3` to the matching bucket of `result`, in place.
# NOTE(review): `result3` is not assigned in any visible cell, and the only visible
# assignment to `result` is a scalar in the FC-weights cell — confirm where both lists
# come from. The cell is also not idempotent: re-running it appends `result3` again.
for i, x in enumerate(result):
    x.extend(result3[i])

In [None]:
# Mean and standard error of the mean for each bucket of `result`, plotted against sizes 2..8.
# NOTE(review): `result` is not assigned as a list of buckets in any visible cell, and the
# x-axis hard-codes 7 buckets (range(2, 9)) — confirm both.
std_error_mean = [np.std(x)/np.sqrt(len(x)) for x in result]
means = [np.mean(x) for x in result]
plt.errorbar(list(range(2,9)), means, std_error_mean, capsize=2)
plt.xlabel("Region Size")
plt.ylabel("Average network response diff")

In [None]:
# Per-bucket mean of `result4` with its standard error of the mean, plotted against the
# region sizes 2..8 that the buckets correspond to.
means = [np.mean(bucket) for bucket in result4]
std_error_mean = [np.std(bucket) / np.sqrt(len(bucket)) for bucket in result4]
region_sizes = list(range(2, 9))
plt.errorbar(region_sizes, means, std_error_mean, capsize=2)
plt.xlabel("Region Size")
plt.ylabel("Average absolute gradient")

In [None]:
both_pixels_response(tiny_net, valid_set, 1, 5,# img_id=5125,
                     one_class=True, outer=False, device=device)

In [None]:
result2 = visualizations.region_importance(tiny_net, valid_set, batch_size=512, device=device)

In [None]:
result2 = visualizations.region_importance(tiny_net, valid_set, batch_size=512, 
                                           device=device, runs=10_000)

In [None]:
visualizations.both_pixels_response(tiny_net, valid_set, 1, 3, one_class=True, 
                                    img_id=987_650, outer=True, device=device, batch_size=526)

In [None]:
visualizations.both_pixels_response(tiny_net, valid_set, 650, 10, one_class=True, 
                                    img_id=987_650, outer=False, device=device, batch_size=526)

In [None]:
visualizations.plot_region_importance(*result2)

# PCA Saliency

In [None]:
default_scales = [3,5,7,9,13,15]
if 1: 
    %store -r small_pca_directions_1_stride small_pca_directions_s_stride
else:
    small_pca_directions_1_stride = find_pca_directions(valid_set, 16384, default_scales, 1)
    small_pca_directions_s_stride = find_pca_directions(valid_set, 16384, default_scales, default_scales)
    %store small_pca_directions_1_stride small_pca_directions_s_stride

In [None]:
# Full list of 20 seeds for the saliency comparison (presumably one image per seed — confirm
# against generate_many_pca).
seeds = [1_2123, 1_40_124, 1_508_559, 1_5_019_258, 1_2_429_852, 9032, 5832, 12, 5014, 92, 42, 52, 
         52_934, 935_152, 1_000_000, 1_000_001, 27, 24, 512, 999_105]  # 20 

In [None]:
# Four-seed subset; when run after the 20-seed cell this replaces that list.
seeds = [1_2123, 1_40_124, 9032, 1_5_019_258]

In [None]:
pca_map_s_strides, _, grad_maps, explain_imgs = generate_many_pca(tiny_net, seeds, 
                small_pca_directions_1_stride, default_scales, valid_set, component=0, 
                batch_size=512, strides=3, skip_1_stride=True, device=device)

In [None]:
plt_grid_figure([explain_imgs, pca_map_s_strides, grad_maps], transpose=True, titles=["Image", "Strides=3", "Gradient"])

In [None]:
pca_map_s_strides, _, grad_maps, explain_imgs = generate_many_pca(unstrided_net, seeds, 
                small_pca_directions_1_stride, default_scales, valid_set, component=0, 
                batch_size=128, strides=3, skip_1_stride=True, device=device)

In [None]:
plt_grid_figure([explain_imgs, pca_map_s_strides, grad_maps], transpose=True, titles=["Image", "Strides=3", "Gradient"])

In [None]:
guided_net = GuidedBackprop(unstrided_net)
guided_pca_map_s_strides, _, guided_grad_maps, explain_imgs = generate_many_pca(guided_net, seeds, 
                small_pca_directions_1_stride, default_scales, valid_set, component=0, 
                batch_size=128, strides=3, skip_1_stride=True, device=device)

In [None]:
plt_grid_figure([explain_imgs, guided_pca_map_s_strides, guided_grad_maps], transpose=True, titles=["Image", "PCA Strides=3", "Guided Backprop"])

In [None]:
# cherry_picked = [0, 1, 5, 3]
# cexplain_imgs = [explain_imgs[c] for c in cherry_picked]
# cguided_pca_map_s_strides = [guided_pca_map_s_strides[c] for c in cherry_picked]
# cguided_grad_maps = [guided_grad_maps[c] for c in cherry_picked]
plt_grid_figure([explain_imgs, guided_pca_map_s_strides, guided_grad_maps], transpose=True, colorbar=False)
#plt.rcParams.update({'font.size': 25})
plt.savefig("saliency.png")

In [None]:
plt_grid_figure([[explain_imgs[0]], [guided_pca_map_s_strides[0]]], transpose=True, colorbar=False)
#plt.rcParams.update({'font.size': 25})
plt.savefig("saliency.pdf")