In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np
import time
import os
import matplotlib.pyplot as plt
from collections import defaultdict

import color_regions, network, visualizations, utils
from color_regions import *
from network import *
from visualizations import *
from utils import *
from hooks import *

torch.backends.cudnn.benchmark = True
device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [None]:
# set up autoreloading of shared code
%load_ext autoreload
%autoreload 1
%aimport color_regions,network,visualizations,utils,hooks
%aimport

In [None]:
# Image pipeline: only ToTensor is applied; the Normalize step is left commented out.
transform = transforms.Compose(
     [transforms.ToTensor()])#,
    #transforms.Normalize((0.5), (0.5))])

batch_size = 512 # seems to be the fastest batch size
# (start, end) pairs passed to the dataset generator as `image_indices`.
# NOTE(review): presumably ranges of generated-image ids, which would make the three
# splits disjoint -- confirm against ColorDatasetGenerator.
train_indices = (0, 250_000) # size of training set
valid_indices = (1_250_000, 1_270_000)
test_indices = (2_260_000, 2_270_000)

def color_classifier(color):
    """Map a grey-level intensity to one of three class labels (0, 1 or 2).

    The intensity axis is cut into 30-wide bands whose labels do not follow a
    monotone pattern; everything above 240 is class 2. Over 0..255 this gives
    roughly 90/255 of the values to class 0, 90/255 to class 1 and 75/255 to
    class 2.

    Returns None when no comparison succeeds (e.g. for NaN).
    """
    # (inclusive upper bound of the band, class label), in ascending order
    banded_labels = (
        (30, 0),
        (60, 1),
        (90, 2),
        (120, 1),
        (150, 0),
        (180, 1),
        (210, 2),
        (240, 0),
    )
    for upper_bound, label in banded_labels:
        if color <= upper_bound:
            return label
    if color > 240:
        return 2
critical_color_values = list(range(0,241,30))

def set_loader_helper(indices):
    """Create the colour dataset for `indices` together with a DataLoader over it.

    Args:
        indices: value forwarded to ColorDatasetGenerator as `image_indices`.

    Returns:
        (dataset, loader): the generator dataset and a shuffling DataLoader that
        uses the notebook-level `batch_size`, 6 workers and pinned memory.
    """
    image_size = 128
    dataset = ColorDatasetGenerator(
        color_classifier=color_classifier,
        image_indices=indices,
        transform=transform,
        color_range=(5, 255),
        noise_size=(1, 9),
        num_classes=3,
        size=image_size,
        # radius bounds are tied to the image size: between 1/6 and 1/3 of it
        radius=(image_size // 6, image_size // 3),
    )
    loader_options = dict(batch_size=batch_size, shuffle=True,
                          num_workers=6, pin_memory=True)
    return dataset, torch.utils.data.DataLoader(dataset, **loader_options)
train_set, train_loader = set_loader_helper(train_indices)
valid_set, valid_loader = set_loader_helper(valid_indices)
test_set, test_loader = set_loader_helper(test_indices)

In [None]:
# the "hard" task
# Probe every integer grey level 0..255 exactly once. 256 samples are needed for a
# step of exactly 1; 255 samples give a step of 255/254 and therefore never land on
# the integer class boundaries.
color_probe = np.linspace(0, 255, 256)
color_class = [color_classifier(x) for x in color_probe]
plt.plot(color_probe, color_class)
plt.xlabel("Color")
plt.yticks([0, 1, 2])
plt.ylabel("Class");  # trailing ';' keeps the Text repr out of the cell output

In [None]:
# Build the small network under study and move it to `device`.
# Each row of the first argument describes one conv block (see the inline comment);
# the middle row is commented out, leaving two blocks.
# NOTE(review): the positional 3 and [128, 128, 1] are presumably the class count and
# the input image shape -- confirm against network.ResNet.
tiny_net = ResNet([[2, 3, 4],  # num_channels (input and output), kernel_size, stride
                   #[4, 3, 2],
                   [6, 3, 4]], 3, [128, 128, 1], 
                   "tiny_net_noise_hard_grey.dict", fc_layers=[]).to(device)
loss_func = nn.CrossEntropyLoss()
tiny_optim = torch.optim.Adam(tiny_net.parameters())  # Adam with its default hyper-parameters
print(tiny_net.num_params())
# NOTE(review): presumably restores weights and optimizer state from the .dict file
# named above -- confirm what happens when that file does not exist.
tiny_net.load_model_state_dict(optim=tiny_optim)

In [None]:
results = train(tiny_net, tiny_optim, loss_func, 1000)

In [None]:
interp_net = AllActivations(tiny_net)

In [None]:
np.random.seed(5_123_456)
test_img, lbl, color, size, pos  = valid_set.generate_one()
print(color)
plt.imshow(test_img, cmap="gray")
tensor_test_img = tensorize(test_img, device=device)

In [None]:
interp_net.eval()
interp_net(tensor_test_img)

In [None]:
# Activations captured by the hook wrapper for the first conv layer, plus that layer's kernels.
first_conv = interp_net._features["conv_layers1_0"].detach().cpu().numpy().squeeze()
conv1_0_module = dict(tiny_net.named_modules())["conv_layers1.0"]
first_conv_weights = conv1_0_module.weight.detach().cpu().numpy().squeeze()
print(conv1_0_module.bias)

fig = plt.figure(figsize=(8, 10))

# top-left: the input image
plt.subplot(3, 2, 1)
imshow_centered_colorbar(test_img, "bwr", "original_image")

# middle row: activation map of each of the two output channels
for ch_idx in range(2):
    plt.subplot(3, 2, 3 + ch_idx)
    imshow_centered_colorbar(first_conv[ch_idx], "bwr", f"output conv1_0.{ch_idx}")

# bottom row: the kernel that produced each of those maps
for ch_idx in range(2):
    plt.subplot(3, 2, 5 + ch_idx)
    imshow_centered_colorbar(first_conv_weights[ch_idx], "bwr", f"weights of conv1_0.{ch_idx}")
# => first layer basically just computes a compressed version of original, twice

In [None]:
bn1_params.running_mean, bn1_params.running_var

In [None]:
bn1_params = dict(tiny_net.named_modules())["batch_norms1.0"]
print(bn1_params.weight, bn1_params.bias)
first_batchnorms = interp_net._features["batch_norms1_0"].detach().cpu().numpy().squeeze()
fig = plt.figure(figsize=(6, 5))
plt.subplot(2,2,1)
imshow_centered_colorbar(first_conv[0], "bwr", "output conv1_0.0")
plt.subplot(2,2,3)  # conv{1,2}_{layer_num}.{channel_index}
imshow_centered_colorbar(first_conv[1], "bwr", "output of conv1_0.1")

plt.subplot(2,2,2)
imshow_centered_colorbar(first_batchnorms[0], "bwr", "output batchnorm1_0.0")
plt.subplot(2,2,4)
imshow_centered_colorbar(first_batchnorms[1], "bwr", "output batchnorm1_0.1")
# conv of circle (which we just preserve its shape with our conv1) must exceed
# the bias else it gets zero-ed out => gives us 1 boundary on the color. 
# eg. look at channel 1. we multiply the raw value by 7.5, and then subtract 270
# (note that the bias on channel 1 is basically 0), and divide by 553
# then multiply by 1, and subtract 0.8176 => any color value above -23 will be > 0
# for channel 0, it turns out any color value above +25 will be > 0 => already
# separating on that first non-linearity

In [None]:
second_conv = interp_net._features["conv_layers2_0"].detach().cpu().numpy().squeeze()
second_conv_weights = dict(tiny_net.named_modules())["conv_layers2.0"].weight.detach().cpu().numpy().squeeze()
fig = plt.figure(figsize=(6, 7))

# Column m is output channel m: its activation map on top, then the kernels it applies
# to input channel 0 and input channel 1.
for m in range(2):
    plt.subplot(3,2,m+1)
    # These activations come from conv_layers2_0, so the title must say conv2_0
    # (it previously read conv2_1, which is a different layer plotted further down).
    imshow_centered_colorbar(second_conv[m], "bwr", f"out conv2_0.{m}")
    plt.subplot(3,2,m+3)
    imshow_centered_colorbar(second_conv_weights[m][0], "bwr", f"w 2_0.0->{m}")
    plt.subplot(3,2,m+5)
    imshow_centered_colorbar(second_conv_weights[m][1], "bwr", f"w 2_0.1->{m}")
# both paths have a "just recompute/compress the image (identity mapping learned?)", though
# 1 shifts it up a bit (not sure how relevant this is, but you can actually see it in the image)
# very curve detector-like filters as well in both paths
# so channel 0 is upper-right curves, unsure what the bright pixel in lower left of w_2_0.0 is
# but the other path doesn't have it, so maybe not important?

In [None]:
batchnorms_2 = dict(tiny_net.named_modules())["batch_norms2.0"]
print(batchnorms_2.weight, batchnorms_2.bias)
print(batchnorms_2.running_mean, batchnorms_2.running_var) 
# take a look at channel 0 (which separated at +25 before), the equation is now
# ([(x*6.9-bn1.bias0)/bn1.var0*bn1.scale0+bn1.shift0]*4-1.2290)/sqrt(5.97)*0.7737-0.7180 +
# ([(x*7.5-bn1.bias1)/bn1.var1*bn1.scale1+bn1.shift1]*1-1.1275)/sqrt(3.9812)*0.8932+1.1430 = 0
# after rearranging, 0.028507*x-1.9267244 => has its 0 at 67, (so any color > 67)
# will leave channel 0 here with activation > 0 (post-ReLU), which isn't particularly
# close to any critical value, but I guess it just approximates the boundaries with a
# bunch of piecewise linear functions like this, so you get the idea 

In [None]:
tiny_net.conv_blocks[block].conv2.weight[c, :].max(axis=-1).values.max(axis=-1).values

In [None]:
c = 0
block = 1

#uniform_inpt = torch.full((1,16,32,32), 100.0).to(device)
#plt.imshow(tiny_net.conv_blocks[0].conv2.weight[c, in_c].detach().cpu().numpy(), cmap="bwr")
conv_maps = tiny_net.conv_blocks[block].conv2.weight[c, :]
#imshow_centered_colorbar(conv_maps[7].detach().cpu().numpy(), cmap="bwr")
conv_scale = conv_maps.max(axis=-1).values.max(axis=-1).values
conv_shift = tiny_net.conv_blocks[block].conv2.bias[c]
bn_scale = tiny_net.conv_blocks[block].batch_norm2.weight[c]
bn_shift = tiny_net.conv_blocks[block].batch_norm2.bias[c]
bn_var = tiny_net.conv_blocks[block].batch_norm2.running_var[c]
bn_mean = tiny_net.conv_blocks[block].batch_norm2.running_mean[c]
print(conv_shift, bn_scale, bn_shift, bn_var, bn_mean)
#(c*conv_scale + conv_shift - bn_mean) / torch.sqrt(bn_var) * bn_scale + bn_shift
slope = (conv_scale/torch.sqrt(bn_var)*bn_scale).detach().cpu().numpy()
bias = ((conv_shift - bn_mean)/torch.sqrt(bn_var)*bn_scale + bn_shift).detach().cpu().numpy()

lines = np.asarray([profile_plots[f"conv_blocks.{block}.act_func1_{x}"][0] for x in range(6)])

uniform_scaling = slope.dot(lines) + bias


In [None]:
plt.plot(np.maximum(uniform_scaling, 0))

In [None]:
tiny_net.eval()
profile_plots,_ = activation_color_profile(AllActivations(tiny_net), valid_loader, valid_set, device=device)

In [None]:
show_profile_plot(profile_plots["conv_blocks.1.act_func2_4"])

In [None]:
# lets attempt to somewhat automate this process
def fetch_layer_params(layer_idx, one_or_two):
    """Return the (conv, batch-norm) module pair of `tiny_net` for one layer.

    The modules are looked up by name as ``conv_layers{one_or_two}.{layer_idx}`` and
    ``batch_norms{one_or_two}.{layer_idx}``; a missing name raises KeyError.
    """
    modules_by_name = dict(tiny_net.named_modules())
    suffix = f"{one_or_two}.{layer_idx}"
    conv_module = modules_by_name[f"conv_layers{suffix}"]
    batchnorm_module = modules_by_name[f"batch_norms{suffix}"]
    return conv_module, batchnorm_module

def recurse_build_func():
    """Unfinished sketch: scan the kernels of one conv channel for evenly spaced weights.

    NOTE(review): `layer_idx`, `one_or_two` and `channel` are neither parameters nor
    assigned inside this function, so they must exist as globals when it is called;
    no definition of them is visible in this notebook. The function has no `return`
    statement and stores nothing, so a completed call yields None.
    """
    conv, bn = fetch_layer_params(layer_idx, one_or_two)
    for conv_map in conv.weight[channel]:
        # NOTE(review): torch.sort returns a (values, indices) pair, so the [0]/[1]
        # indexing below selects those two tensors rather than the two smallest
        # weights -- presumably `.values` of a flattened kernel was intended; confirm.
        sorted_map = torch.sort(conv_map)
        first_diff = sorted_map[0] - sorted_map[1]
        if first_diff > 0.8:
            last_diff = None
            # zip stops at the shorter range, so this yields the pairs (1,2) ... (7,8)
            for i,j in zip(range(1,9), range(2,9)):
                diff = sorted_map[i] - sorted_map[j]
                # keep going while the gaps stay below the first gap and within 0.1 of each other
                if diff < first_diff and (last_diff is None or abs(diff - last_diff) < 0.1):
                    last_diff = diff
                else:
                    break
    else:  # TODO: finish this (done with color_profile_plots instead)
        # for/else: the only `break` belongs to the inner loop, so this branch runs
        # every time the outer loop finishes
        pass

In [None]:
show_conv_weights(interp_net, "conv_blocks.1.act_func1", color_profile=profile_plots)

In [None]:
show_conv_weights(interp_net, "conv_blocks.1.act_func2", color_profile=profile_plots)

In [None]:
show_conv_layer(interp_net, "conv_blocks.1.act_func2")

In [None]:
# NOTE: `second_conv` / `second_conv_weights` are rebound here to the conv_layers1_1 data
# (they held conv_layers2_0 data earlier); the batch-norm cell that follows reads
# `second_conv` as set by this cell.
second_conv = interp_net._features["conv_layers1_1"].detach().cpu().numpy().squeeze()
second_conv_weights = dict(tiny_net.named_modules())["conv_layers1.1"].weight.detach().cpu().numpy().squeeze()
fig = plt.figure(figsize=(14, 7))

# 3x6 grid, one column per output channel m: activation map on top, then the kernels
# applied to input channel 0 and input channel 1.
for m in range(6):
    plt.subplot(3,6,m+1)
    imshow_centered_colorbar(second_conv[m], "bwr", f"out conv1_1.{m}")
    plt.subplot(3,6,m+7)
    imshow_centered_colorbar(second_conv_weights[m][0], "bwr", f"w 1_1.0->{m}")
    plt.subplot(3,6,m+13)
    imshow_centered_colorbar(second_conv_weights[m][1], "bwr", f"w 1_1.1->{m}")


In [None]:
bn3_params = dict(tiny_net.named_modules())["batch_norms1.1"]
print(bn3_params.weight, bn3_params.bias)
print(bn3_params.running_mean, bn3_params.running_var)
third_batchnorms = interp_net._features["batch_norms1_1"].detach().cpu().numpy().squeeze()
fig = plt.figure(figsize=(16, 5))
for m in range(6):
    plt.subplot(2,6,m+1)
    imshow_centered_colorbar(second_conv[m], "bwr", f"output conv1_1.{m}")
    plt.subplot(2,6,m+7)
    imshow_centered_colorbar(third_batchnorms[m], "bwr", f"output batchnorm1_1.{m}")
# it appears the only important channel at this point is 2. channels 0,1 looks like it was
# close to being important, but failed some color check. channel 5 I don't really
# understand since it appears to have picked up some signal that wasnt there before?
# (I suppose the mean is negative, and the scale is larger than 1 so it would expand any
# slight differences that existed but weren't visible?). Channel 4 I think is also trying
# to be a circle finder (upper right?), but failed color check as well. Channel 3 is also
# looking like it just barely failed the color check. Actually, looking at channel 1 again, 
# its output after a ReLU I expect would look exactly like channel 4 right now, so
# channel 4 is definitely a "failed color check"

In [None]:
second_conv = interp_net._features["conv_layers2_1"].detach().cpu().numpy().squeeze()
second_conv_weights = dict(tiny_net.named_modules())["conv_layers2.1"].weight.detach().cpu().numpy().squeeze()
fig = plt.figure(figsize=(15, 10))

for m in range(6):
    plt.subplot(7,6,m+1)
    imshow_centered_colorbar(second_conv[m], "bwr", f"out conv2_1.{m}")
    for k in range(6):
        plt.subplot(7,6,m+1+(k+1)*6)
        imshow_centered_colorbar(second_conv_weights[m][k], "bwr", f"w 2_1.{k}->{m}",
                                colorbar=True)
    # Channel 0 is basically saying "cancel out everything except for channel 4 in prev layer"
    # So it should basically copy its value (which it does). Channel 2 is similar, though it 
    # appears to copy from channel 1, and 4 a bit. At the end of it, channel 4 ends
    # up being the most active, since it has that strong positive edge detector with 
    # channel 2 in the previous layer. Channel 1 also does decently well, but its circle
    # has been thoroughly zeroed out, and only an "artifact-like" row of brightness 
    # remains at the top edge

In [None]:
bn4_params = dict(tiny_net.named_modules())["batch_norms2.1"]
print(bn4_params.weight, bn4_params.bias)
print(bn4_params.running_mean, bn4_params.running_var)
fourth_batchnorms = interp_net._features["batch_norms2_1"].detach().cpu().numpy().squeeze()
fig = plt.figure(figsize=(16, 5))
for m in range(6):
    plt.subplot(2,6,m+1)
    imshow_centered_colorbar(second_conv[m], "bwr", f"output conv2_1.{m}")
    plt.subplot(2,6,m+7)
    imshow_centered_colorbar(fourth_batchnorms[m], "bwr", f"output batchnorm2_1.{m}")
# so again, we see somewhat of a "direction reversal" in channel5 (pretty much because of 
# the positive bias (compared to the other biases, which are all negative), but channel 4 mostly
# seems to be the winner here. The "artifact" bright top row of channel 1 is mostly negated, 
# (we actually see those weird rows in multiple conv maps here, could be an artifact
# of the padding/striding method maybe?)

In [None]:
# Relates the last batch-norm maps to the fully connected weights.
# `fourth_batchnorms` is produced by the batch_norms2_1 cell above.
fc_weights = dict(tiny_net.named_modules())["fully_connected.0"].weight.detach().cpu().numpy().squeeze()
fig = plt.figure(figsize=(14, 7))

# 4x6 grid: row 0 = batch-norm output per channel m, rows 1-3 = weight map per class fc_m
for m in range(6):
    plt.subplot(4,6,m+1)
    imshow_centered_colorbar(fourth_batchnorms[m], "bwr", f"out batchnorm2_1.{m}") # no ReLU
    for fc_m in range(3):
        # view class fc_m's weight vector as 6 maps of 8x8 and take the one for channel m
        fc_shaped = fc_weights[fc_m].reshape(6,8,8)[m]
        # zero out negatives by hand (the stored map is pre-ReLU), weight, and sum;
        # the sum is shown as the subplot title.
        # NOTE(review): presumably channel m's share of the class-fc_m logit, bias excluded.
        result = (np.where(fourth_batchnorms[m]>0, fourth_batchnorms[m], 0)*fc_shaped).sum()
        plt.subplot(4,6,m+1+(fc_m+1)*6)
        imshow_centered_colorbar(fc_shaped, "bwr", f"{result:.2f}")
    # the stupid edge lines actually seem to be getting used somehow (see bottom row, which
    # is used to predict class 2). Some of these maps are just "find a circle-ish thing in 
    # the center". Probably makes sense that the "best" place to put your circle checker is 
    # right in the middle, because most circles are at least overlapping the middle, due
    # to the data generation process. Some of these maps appear to do nothing, eg.
    # the map for predicting class 1 ignores channel 4. Although maybe there is some
    # "antipodal" symmetry between class 1 channel 4 and class 0 channel 4 => channel 4 
    # gives a lot of info for class 1??, though im not sure why you would only highlight
    # one pixel inside them (we see the same pattern used in channel 2, and actually in a lot of
    # the channels) => channel 0 is like "positive evidence for class 1, negative evidence for
    # class 0"
    
    # note that it's actually the same classes that are in superposition:
    # eg. for class 0,1 we have superposition in channel 0, channel 4
    #     for class 0,2 we have superposition in channel 1,2,3,5
    
    # also we arguably have a "1-map" type thing occurring in many of the channels. For example,
    # in channel 4, it sort of looks like that for class 0 and class 2 (TODO: this note was left unfinished)

In [None]:
show_fc_conv(interp_net, color_profile=profile_plots, fixed_height=True, full_gridspec=True)

In [None]:
show_fc_conv(interp_net, color_profile=profile_plots, fixed_height=False, full_gridspec=False)

In [None]:
tiny_net.fully_connected[0].fully_connected.bias

In [None]:
mlab.init_notebook()

In [None]:
mlab.clf()

In [None]:
mlab_fc_conv_feature_angles(tiny_net, 
                            "fully_connected.0.act_func", num_embed=3, normalize=True)

In [None]:
# NOTE(review): two matplotlib backends are selected back to back, so the second magic
# is the one left active. Presumably the `inline` call is a workaround for the notebook
# backend not rendering on first use -- confirm, or keep only the backend actually wanted.
%matplotlib inline  
%matplotlib notebook
# Computed with normalize=True; later cells display `feature_gram` from this call.
feature_gram, projected_weights = visualizations.fc_conv_feature_angles(tiny_net, 
                            "fully_connected.0.act_func", num_embed=3, normalize=True)

In [None]:
mlab_fc_conv_feature_angles(tiny_net, "fully_connected.0.act_func", num_embed=3, normalize=True)

In [None]:
feature_gram # normalized version

In [None]:
# Render mlab's built-in contour demo and display the returned object.
s = mlab.test_contour3d()
s

In [None]:
# NOTE(review): bare `_` depends on execution history -- it is either IPython's
# last-output variable or the value bound by the earlier `profile_plots,_ = ...`
# unpacking. Looks like leftover debugging; confirm it can be removed.
_

In [None]:
feature_gram  # unnormalized version
# in terms of visualization, should compress the number down to d = n/2 + 1
# d = num dimensions to embed into, n = num features or classes
# since num_equations is the number of entries in the strict upper right triangle of the
# n x n gram matrix, ie. n(n-1)/2 (an earlier version of this note used (n-1)(n-2)/2,
# which is one row short)
# and num unknowns is (d-1)*(n-1), since d-1 angles to choose per point, and get to pick
# angles for n-1 points (the system is rotation invariant so position of first one doesn't add
# any dof)
# thus d-1 = n/2 => d = n/2 + 1 => for n=3 that is 2.5, so we need d=3 and 2 is NOT enough
# this matches the intuition that d=2 feels wrong: the first dot limits us to 2 locations
# for the second feature and 2 locations for the 3rd feature, but the 2nd feature / 3rd
# feature dot is unlikely to be satisfied by any combination of those locations

In [None]:
show_fc(interp_net, "fully_connected.0.act_func", size_mul=(8,25), color_profile=profile_plots)
# this was subsampled by taking (arbitrarily) the first 32 weights of each 384 weight vector

In [None]:
show_conv_layer(interp_net, "conv_blocks.1.act_func2")