In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np
import time
import os
import matplotlib.pyplot as plt
from collections import defaultdict

import color_regions, network, visualizations, utils
from color_regions import *
from network import *
from visualizations import *
from utils import *

# Let cuDNN auto-tune and cache convolution algorithms for the input shapes it sees.
torch.backends.cudnn.benchmark = True
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [None]:
# set up autoreloading of shared code
%load_ext autoreload
# mode 1: only modules registered through %aimport are reloaded before executing code
%autoreload 1
%aimport color_regions,network,visualizations,utils
# a bare %aimport lists which modules are currently marked for auto-reload
%aimport

In [None]:
import sys
prev_time = 0  # perf_counter() reading taken at the previous benchmark() call
gamma = 0.99  # decay factor of the exponentially weighted moving average
stats = {}  # tracks ewma running average
def benchmark(point=None, profile=True, verbose=True, cuda=True): # not thread safe at all
    """Record the wall-clock time elapsed since the previous call.

    Call once with ``point=None`` to reset the reference time, then call with a
    label at each point of interest. Timings are stored in the module-level
    ``stats`` dict under ``"<calling function name>-<point>"`` as an
    exponentially weighted moving average with decay ``gamma``.

    Args:
        point: label of the checkpoint; ``None`` only resets the reference time.
        profile: when False the call is a no-op (the reference time is untouched).
        verbose: print the raw timing and its running average.
        cuda: wait for queued CUDA work before reading the clock. Ignored when
            no CUDA device is available, so the helper also works on CPU-only runs.
    """
    global prev_time
    if not profile:
        return
    # torch.cuda.synchronize() raises without a usable CUDA device, and the notebook
    # explicitly supports a "cpu" fallback, so only synchronize when CUDA exists.
    if cuda and torch.cuda.is_available():
        torch.cuda.synchronize()
    time_now = time.perf_counter()
    if point is not None:
        # prefix the label with the name of the function that called benchmark()
        point = f"{sys._getframe().f_back.f_code.co_name}-{point}"
        time_taken = time_now - prev_time
        if point not in stats:
            stats[point] = time_taken  # seed the average with the first sample
        stats[point] = stats[point]*gamma + time_taken*(1-gamma)
        if verbose:
            print(f"took {time_taken} to reach {point}, ewma={stats[point]}")
    prev_time = time_now

In [None]:
# Only convert images to tensors; the Normalize step is currently commented out.
transform = transforms.Compose(
     [transforms.ToTensor()])#,
    #transforms.Normalize((0.5), (0.5))])

batch_size = 32  # seems to be the fastest batch size
# Index pairs passed to ColorDatasetGenerator as `image_indices`
# (presumably (start, end) ranges — confirm against color_regions).
train_indices = (0, 200_000) # size of training set
valid_indices = (1_250_000, 1_260_000)
test_indices = (260_000, 460_000)

def color_classifier(color):
    """Map a grey-scale color value to its class label.

    Bands: ``color <= 100`` -> 0, ``100 < color <= 150`` -> 1,
    ``150 < color <= 200`` -> 2 and ``color > 200`` -> 1 again.
    Values that satisfy none of the comparisons (e.g. NaN) yield ``None``.
    """
    # (inclusive upper bound, label) for the three lower bands, in ascending order
    for upper_bound, label in ((100, 0), (150, 1), (200, 2)):
        if color <= upper_bound:
            return label
    if color > 200:  # the top band shares its label with the 100-150 band
        return 1
# color values at which the class label changes
critical_color_values = [100, 150, 200]

def set_loader_helper(indices):
    """Build a ColorDatasetGenerator for `indices` and a shuffled DataLoader over it.

    Returns:
        (data_set, loader): the generated dataset and its DataLoader.
    """
    image_side = 128
    generator_kwargs = dict(
        color_classifier=color_classifier,
        image_indices=indices,
        transform=transform,
        color_range=(25, 250),
        noise_size=(1,7),
        num_classes=3,
        size=image_side,
        radius=(image_side//6, image_side//3),
    )
    data_set = ColorDatasetGenerator(**generator_kwargs)
    loader_kwargs = dict(batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
    loader = torch.utils.data.DataLoader(data_set, **loader_kwargs)
    return data_set, loader
# Build the dataset/loader pair for each split (every loader is created with shuffle=True).
train_set, train_loader = set_loader_helper(train_indices)
valid_set, valid_loader = set_loader_helper(valid_indices)
test_set, test_loader = set_loader_helper(test_indices)

In [None]:
# the "medium" task
# Plot the class label as a function of color value: 255 evenly spaced probes over [0, 255].
color_probe = np.linspace(0, 255, num=255)
color_class = list(map(color_classifier, color_probe))
plt.plot(color_probe, color_class)

In [None]:
# Network with stride 2 in all three conv stages.
# The positional arguments `3` and `[128, 128, 1]` are presumably the number of classes and
# the input image shape — confirm against network.ResNet.
small_net = ResNet([[16, 3, 2],  # num_channels (input and output), kernel_size, stride
                    [32, 3, 2],
                    [64, 3, 2]], 3, [128, 128, 1], 
                   "small_net_noise_medium_grey.dict", fc_layers=[32]).to(device)
loss_func = nn.CrossEntropyLoss()
small_optim = torch.optim.Adam(small_net.parameters())
print(small_net.num_params())
# Restore a saved model (and optimizer) state; presumably read from the .dict path above — confirm.
small_net.load_model_state_dict(optim=small_optim)

In [None]:
# Same layout as small_net, but the first two conv stages use stride 1.
unstrided_net = ResNet([[16, 3, 1],  # num_channels (input and output), kernel_size, stride
                        [32, 3, 1],
                        [64, 3, 2]], 3, [128, 128, 1], 
                   "unstrided_net_noise_medium_grey.dict", fc_layers=[32]).to(device)
# NOTE: rebinds loss_func, which the small_net cell already set to the same loss type.
loss_func = nn.CrossEntropyLoss()
unstrided_optim = torch.optim.Adam(unstrided_net.parameters())
print(unstrided_net.num_params())
unstrided_net.load_model_state_dict(optim=unstrided_optim)

In [None]:
# Train the unstrided network; `train` is not defined in this notebook (presumably star-imported),
# and the last argument is presumably the number of epochs — TODO confirm.
results = train(unstrided_net, unstrided_optim, loss_func, 200)

In [None]:
# Train the strided network. NOTE: overwrites `results` from the unstrided run above.
results = train(small_net, small_optim, loss_func, 200)

In [None]:
# Evaluate small_net on the validation split, passing the class-boundary color values.
rate_distribution(small_net, valid_loader, valid_set,
                  critical_values=critical_color_values, device=device)

In [None]:
# Logit-response graph of small_net over the validation generator (helper is defined outside this notebook).
response_graph(small_net, valid_set, device=device)

In [None]:
# to further test the "using 1 image => bad batchnorm estimates" lets do the same test
# but instead we will average over a sample of responses
small_net.eval() # very important!
stack_size = 32
# stack_size distinct seeds drawn from the first 1000 values at the start of the validation index range
sampled_indices = 1_250_000 + np.random.choice(1000, stack_size, replace=False)
total_images = stack_size * 255  # 255 = number of probed color values below
correct_num = 0
with torch.no_grad():
    counterfactual_color_values = np.linspace(0, 255, 255)
    responses = []  # one batch-averaged logit vector per probed color
    for color in tqdm(counterfactual_color_values):
        stacked_generated_img = []
        for sampled_index in sampled_indices:
            # re-seed so each stack slot is generated from the same seed for every color
            np.random.seed(sampled_index)
            generated_img, lbl, *__ = valid_set.generate_one(set_color=color)
            stacked_generated_img.append(generated_img)
        # move the last axis to position 1 (channels-last -> channels-first, assuming HWC images)
        stacked_generated_img = np.array(stacked_generated_img).transpose(0, 3, 1, 2)
        generated_img = torch.tensor(stacked_generated_img).to(device).float()
        response = small_net(generated_img, logits=True)
        # the label of the last generated image is reused for the whole stack (all share `color`)
        stacked_lbl = torch.tensor(np.repeat(np.expand_dims(lbl, 0), stack_size, axis=0)).to(device)
        correct_num += correct(response, stacked_lbl).sum()
        responses.append(np.squeeze(response.cpu().numpy()).mean(axis=0))
print(correct_num/total_images, "total accuracy")

In [None]:
# Plot the batch-averaged logits against the probed color value.
# arcsinh compresses large magnitudes while keeping the sign. The transformed array gets its
# own name so that re-running this cell does not apply arcsinh to `responses` a second time.
scaled_responses = np.arcsinh(np.array(responses))  # this graph is quite robust to changes in batch size
for output_logit in range(scaled_responses.shape[1]): # if we do .eval(), but varies if we do .train()
    plt.plot(counterfactual_color_values, scaled_responses[:, output_logit], label=f"class {output_logit}")
plt.xlabel("Color value")
plt.ylabel("Network output logit")
# Draw every class boundary (critical_color_values also contains 200, not only 100 and 150).
plt.vlines(critical_color_values, np.min(scaled_responses), np.max(scaled_responses), linewidth=0.8,
           colors="r", label="decision boundary", # probably because .train() in this case actually gives biased estimates because all the colors are the same
           linestyles="dashed")  # logit graphs look bad if doing .train(), and accuracy is lower?
# legend() must come after vlines, otherwise the "decision boundary" entry is not listed
plt.legend()

In [None]:
# Single-image prediction in eval mode.
# NOTE(review): `res_net` is not defined in any cell visible in this notebook (only small_net and
# unstrided_net are) — confirm where it comes from, otherwise this fails on a fresh kernel.
res_net.eval()   # ---> without this line, it fails, especially with small colors
with torch.no_grad():  # => batchnorm updates are very inaccurate if just one image
    idx = 1_250_026   # => network expects batchnorm updates to basically be exactly in the "middle" 127
    print(valid_set[idx])  # => fix the logit response graph by requiring it to be in eval mode
    print(torch.softmax(res_net(torch.unsqueeze(valid_set[idx]["image"], 0).to(device).float()), 1))

In [None]:
# to test the above hypothesis, if we just stack the same image a bunch of times, and do .train()
# the estimates should still be bad because batchnorm estimates would be just as bad as with 
# a single image in the batch
# NOTE(review): `res_net` is not defined in any visible cell, and this cell leaves it in train mode.
res_net.train()
idx = 1_250_024  # => hypothesis seems to be confirmed
test_image = valid_set[idx]["image"].numpy()
# 32 identical copies of the image along a new leading batch axis
stacked_test = np.repeat(np.expand_dims(test_image, 0), 32, axis=0)
print(valid_set[idx])
print(torch.softmax(res_net(torch.tensor(stacked_test).to(device).float()), 1)[0])

In [None]:
# Show the most recently generated image.
# NOTE(review): if the stacked-response cell above was the last to assign `generated_img`, it is a
# whole batch tensor and squeeze() will not produce a single 2-D image — confirm the intended input.
plt.imshow(np.squeeze(generated_img.cpu().numpy()), cmap="gray")

In [None]:
# Inspect loss and predictions on a single validation batch with the network in train mode.
# NOTE(review): `res_net` is not defined in any visible cell.
res_net.train()
with torch.no_grad():
    for i, sample in enumerate(valid_loader):
        imgs = sample["image"].to(device).float()
        labels = sample["label"].to(device).float()
        print(sample["color"])
        outputs = res_net(imgs)
        print(loss_func(outputs, labels).item())
        # target class vs predicted class for every image in the batch
        print(torch.argmax(labels,dim=1), torch.argmax(outputs, dim=1))
        print(correct(outputs, labels).sum().item())
        break  # only the first batch is examined

In [None]:
# Generate one image from a fixed seed and compute its finite-differences map for the true class.
np.random.seed(500_001)
explain_img, lbl, *_ = valid_set.generate_one()
heat_map = finite_differences_map(small_net, valid_set, lbl.argmax(), explain_img, device=device)

In [None]:
# Left: the explained image. Right: its finite-differences heat map.
plt.subplot(1, 2, 1)
plt.imshow(explain_img, cmap="gray")
plt.subplot(1, 2, 2)
plt.imshow(heat_map, cmap="bwr", interpolation="bilinear")
plt.colorbar()
# generated with strides = 2 everywhere

In [None]:
# Same image/heat-map pair, drawn with the shared grid helper.
plt_grid_figure([explain_img, heat_map], transpose=False, colorbar=True)
# generated with strides = 1, strides = 8 for last layer

In [None]:
# Finite-differences maps for several seeded images (small_net, true class of each image).
image_ids = [20_000, 25_000, 30_000, 600_000, 600_001]
heat_maps = []
explain_imgs = []
for i, image_id in enumerate(image_ids):
    np.random.seed(image_id)  # the id doubles as the RNG seed, so images can be regenerated later
    explain_img_i, target_i, *__ = valid_set.generate_one()
    # NOTE(review): unlike the single-image cell, no device= argument is passed here — confirm the default.
    heat_map_i = finite_differences_map(small_net, valid_set, target_i.argmax(), explain_img_i)
    heat_maps.append(heat_map_i)
    explain_imgs.append(explain_img_i)

In [None]:
# One image / heat-map pair per seeded image.
plt_grid_figure([explain_imgs, heat_maps], transpose=True, colorbar=True)
# generated with strides=2 everywhere

In [None]:
# Same call as the previous cell; kept to hold the output produced with a different network.
plt_grid_figure([explain_imgs, heat_maps], transpose=True, colorbar=True)
# generated with strides = 1, strides = 8 for last layer

What if we run the same experiment, but cheat with a prior on pixel values that we know *should* be informative to the output logit, namely values closest to the decision boundary?

In [None]:
# Finite-differences maps restricted to a hand-picked prior of pixel values, compared against
# the unrestricted maps in `heat_maps` (row layout: image | unfair map | original map).
# NOTE(review): `res_net` is not defined in any visible cell, while `heat_maps` was computed with small_net.
unfair_prior = np.array([90, 110, 140, 160])  #  close to the critical values of 100, 150
unfair_heat_maps = []
plt.figure(figsize=(12, 5*len(image_ids)))
for i, image_id in enumerate(image_ids):
    np.random.seed(image_id)  # regenerate the same image as in the cell that built heat_maps
    explain_img_i, target_i, __ = valid_set.generate_one()
    unfair_map_i = finite_differences_map(res_net, valid_set, target_i.argmax(), explain_img_i, unfairness="unfair", values_prior=unfair_prior)
    unfair_heat_maps.append(unfair_map_i)
    plt.subplot(len(image_ids), 3, 3*i+1)
    plt.imshow(explain_img_i, cmap="gray")
    plt.subplot(len(image_ids), 3, 3*i+2)
    # symmetric color limits so that zero maps to the middle of the diverging colormap
    heat_max = np.max(abs(unfair_map_i))
    plt.imshow(unfair_map_i, cmap="bwr", vmax=heat_max, vmin=-heat_max)
    plt.colorbar(shrink=0.5)
    plt.subplot(len(image_ids), 3, 3*i+3)
    heat_max = np.max(abs(heat_maps[i]))
    plt.imshow(heat_maps[i], cmap="bwr", vmax=heat_max, vmin=-heat_max)
    plt.colorbar(shrink=0.5)
plt.show()  # => very similar results, but with a 4x speedup

In [None]:
# regenerate unfair FD maps
unfair_prior = np.array([90, 110, 140, 160])  #  close to the critical values of 100, 150
image_ids = [20_000, 25_000, 30_000, 600_000, 600_001, 227_662, 998_102, 106_758]
unfair_heat_maps = []
#plt.figure(figsize=(12, 5*len(image_ids)))
for i, image_id in enumerate(image_ids):
    np.random.seed(image_id)
    explain_img_i, target_i, __ = valid_set.generate_one()
    unfair_map_i = finite_differences_map(res_net, valid_set, target_i.argmax(), explain_img_i, unfairness="unfair", values_prior=unfair_prior)
    unfair_heat_maps.append(unfair_map_i)

We can do even better by taking the "closest value in a different class" as our prior.

In [None]:
# Compare three finite-differences variants per image. Row layout:
# image | very unfair FD map | unfair FD map | plain FD map | position of the image's color.
very_unfair_heat_maps = []
num_rows = len(image_ids)
plt.figure(figsize=(20, 5*num_rows))


def _map_or_none(maps, index):
    """Return maps[index], or None when no map was precomputed for that row."""
    return maps[index] if index < len(maps) else None


def _draw_fd_panel(row, column, heat_map, title):
    """Draw `heat_map` in the 1-based `column` of `row` with color limits symmetric around zero.

    When `heat_map` is None the panel is left blank instead of raising.
    """
    plt.subplot(num_rows, 5, 5*row+column)
    if row == 0:
        plt.title(title)
    if heat_map is None:
        plt.axis("off")
        return
    heat_max = np.max(abs(heat_map))
    plt.imshow(heat_map, cmap="bwr", vmax=heat_max, vmin=-heat_max)
    plt.colorbar(shrink=0.5)


for i, image_id in enumerate(image_ids):
    np.random.seed(image_id)  # regenerate exactly the image the precomputed maps belong to
    explain_img_i, target_i, color_i = valid_set.generate_one()
    very_unfair_map_i = finite_differences_map(res_net, valid_set, target_i.argmax(), explain_img_i, unfairness="very unfair", values_prior=[100, 150])
    very_unfair_heat_maps.append(very_unfair_map_i)

    plt.subplot(num_rows, 5, 5*i+1)
    if i == 0:
        plt.title("Image")
    plt.imshow(explain_img_i, cmap="gray")

    _draw_fd_panel(i, 2, very_unfair_map_i, "Very unfair FD map")
    # image_ids may have been extended after heat_maps / unfair_heat_maps were computed (heat_maps
    # is built from the original 5 ids); index defensively instead of raising IndexError.
    _draw_fd_panel(i, 3, _map_or_none(unfair_heat_maps, i), "Unfair FD map")
    _draw_fd_panel(i, 4, _map_or_none(heat_maps, i), "FD map")

    plt.subplot(num_rows, 5, 5*i+5)
    if i == 0:
        plt.title("Location in color space")
    plt.plot(color_probe, color_class)
    plt.vlines([color_i], 0, valid_set.num_classes-1, linewidth=0.8,
           colors="r", label="color value",
           linestyles="dashed")
plt.show()  # => somewhat similar results (see image 2), but with an overall ~11x speedup

In [None]:
# Gradient-based explanations next to the unfair FD maps. Row layout:
# image | input*gradient | gradient | unfair FD map | position of the image's color.
# NOTE(review): `res_net` is not defined in any visible cell; `grad_heat_maps` is never filled.
grad_heat_maps = []
plt.figure(figsize=(20, 5*len(image_ids)))
for i, image_id in tqdm(enumerate(image_ids)):
    np.random.seed(image_id)
    explain_img_i, target_i, color_i = valid_set.generate_one()
    # add a batch axis and move the last axis to position 1 before handing the image to the network
    batched_explain_img_i = torch.tensor(np.expand_dims(explain_img_i, 0).transpose(0, 3, 1, 2), requires_grad=True).to(device).float()
    # logit of the image's true class
    output_logit_i = res_net(batched_explain_img_i)[0, target_i.argmax()]
    
    # gradient of that logit with respect to the input pixels
    img_grad_i = torch.autograd.grad(output_logit_i, batched_explain_img_i)[0].squeeze().cpu().numpy()
    grad_times_input_i = img_grad_i * np.squeeze(explain_img_i)
    
    plt.subplot(len(image_ids), 5, 5*i+1)
    if i == 0:
        plt.title("Image")
    plt.imshow(explain_img_i, cmap="gray")
    
    plt.subplot(len(image_ids), 5, 5*i+2)
    if i == 0:
        plt.title("Input*Gradient explanation")
    # symmetric color limits so that zero maps to the middle of the diverging colormap
    heat_max = np.max(abs(grad_times_input_i))
    plt.imshow(grad_times_input_i, cmap="bwr", vmax=heat_max, vmin=-heat_max)
    plt.colorbar(shrink=0.5)
    
    plt.subplot(len(image_ids), 5, 5*i+3)
    if i == 0:
        plt.title("Gradient explanation")
    heat_max = np.max(abs(img_grad_i))
    plt.imshow(img_grad_i, cmap="bwr", vmax=heat_max, vmin=-heat_max)
    plt.colorbar(shrink=0.5)
    
    plt.subplot(len(image_ids), 5, 5*i+4)
    if i == 0:
        plt.title("FD explanation (unfair)")
    heat_max = np.max(abs(unfair_heat_maps[i]))
    plt.imshow(unfair_heat_maps[i], cmap="bwr", vmax=heat_max, vmin=-heat_max)
    plt.colorbar(shrink=0.5)
    
    plt.subplot(len(image_ids), 5, 5*i+5)
    if i == 0:
        plt.title("Location in color space")
    plt.plot(color_probe, color_class) 
    plt.vlines([color_i], 0, valid_set.num_classes-1, linewidth=0.8,
           colors="r", label="color value",
           linestyles="dashed")
plt.show()
# gradient should be zero, so double check computations, fix scale on cmap
# gradient should be zero, so double check computations, fix scale on cmap

# PCA Direction Tests

In [None]:
class DummyNet(nn.Module):
    """Minimal 2 -> 2 linear layer followed by a ReLU, used for gradient sanity checks."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 2)
        # kept as a module (not a functional call) so it stays visible as a submodule
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply the linear layer, print the pre-activation values, and return the ReLU output."""
        pre_activation = self.linear(x)
        print("\tBefore ReLUs", pre_activation)
        activated = self.relu(pre_activation)
        return activated
# Sanity check: compare plain gradients with the gradients obtained through GuidedBackprop,
# then check that the plain network behaves normally again afterwards.
dummy_net = DummyNet()
# identity weights and a large positive bias, so both ReLUs are active for the input used below
dummy_net.linear._parameters["weight"].data = torch.nn.Parameter(torch.tensor([[1., 0], [0, 1]]))
dummy_net.linear._parameters["bias"].data = torch.nn.Parameter(torch.tensor([200., 200]))
print("Network parameters", dummy_net.linear._parameters)
print("WITHOUT GUIDED BACKPROP")
inpt = torch.tensor([1., -1.], requires_grad=True)
result = dummy_net(inpt)
print("\tNetwork output", result)
# the "loss" is chosen so the plain gradient has one negative and one positive component
print("\t'Loss'", -result[0]+result[1])
print("\tResulting gradients", torch.autograd.grad(-result[0]+result[1], inpt))
print("WITH GUIDED BACKPROP")
guided_dummy = GuidedBackprop(dummy_net)
result = guided_dummy(inpt, preserve_hooks=False)
print("\tNetwork output", result)
print("\t'Loss'", -result[0]+result[1])
print("\tResulting gradients", torch.autograd.grad(-result[0]+result[1], inpt))
print("GUIDED BACKPROP AGAIN (should auto-clean now)")
# this pass calls dummy_net directly (not guided_dummy)
inpt = torch.tensor([1., -1.], requires_grad=True)
result = dummy_net(inpt)
print("\tNetwork output", result)
print("\t'Loss'", -result[0]+result[1])
print("\tResulting gradients", torch.autograd.grad(-result[0]+result[1], inpt))

In [None]:
# PCA directions from the validation generator, once with stride 1 and once with the stride
# equal to each scale (1024 is presumably the sample count — confirm against find_pca_directions).
default_scales = [3,5,7,9,13,15]
pca_directions_1_stride = find_pca_directions(valid_set, 1024, default_scales, 1)
pca_directions_s_stride = find_pca_directions(valid_set, 1024, default_scales, default_scales)

In [None]:
# One panel per entry of the strided-window PCA results.
plt.figure(figsize=(6*4, 12))
for i, res in enumerate(pca_directions_s_stride):
    # tile the grid of patches into one 2-D image: concatenating twice along axis 1
    # collapses the two leading axes (assumes res is 4-D — TODO confirm)
    compressed_results = np.concatenate(np.concatenate(res, 1), 1)
    plt.subplot(1,len(pca_directions_s_stride),i+1)
    if i == 0:
        plt.title("Strided windows")
    plt.imshow(compressed_results, cmap="gray")

In [None]:
# Same visualisation for the stride-1 PCA results.
plt.figure(figsize=(6*4, 12))
for i, res in enumerate(pca_directions_1_stride):
    # tile the grid of patches into one 2-D image (assumes res is 4-D — TODO confirm)
    compressed_results = np.concatenate(np.concatenate(res, 1), 1)
    plt.subplot(1,len(pca_directions_1_stride),i+1)
    if i == 0:
        plt.title("Stride=1")
    plt.imshow(compressed_results, cmap="gray")

In [None]:
# PCA-direction explanation maps for one seeded image, with strided and stride-1 windows.
np.random.seed(200_010)
generated_img, label, *__ = valid_set.generate_one()
pca_map_strided = pca_direction_grids(small_net, valid_set, label.argmax(), generated_img, 
                                      pca_direction_grids=pca_directions_s_stride)
pca_map_1_stride = pca_direction_grids(small_net, valid_set, label.argmax(), generated_img, 
                                      pca_direction_grids=pca_directions_1_stride, strides=1)

In [None]:
# Show the image next to its strided-window PCA map.
# The original plotted `result`, which in top-to-bottom execution is the DummyNet output from
# the guided-backprop sanity check, not the PCA map computed just above.
plt_grid_figure([generated_img, pca_map_strided])  # => with strides == scales
# I believe it was generated on 
# small_net = ResNet([[16, 3, 2],  # num_channels (input and output), kernel_size, stride
#                     [32, 3, 2],
#                     [64, 3, 2]], 3, [128, 128, 1], 
#                    "small_net_noise_hard_grey.dict",

In [None]:
# Show the image next to its stride-1 PCA map (the original plotted the stale `result` variable).
plt_grid_figure([generated_img, pca_map_1_stride])  # => with strides == 1

In [None]:
# RNG seeds; each one deterministically regenerates one image inside generate_many_pca
seeds = [1_2123, 1_40_124, 1_508_559, 1_5_019_258, 1_2_429_852, 9032, 5832, 12, 5014, 92, 42, 52, 
         52_934, 935_152, 1_000_000, 1_000_001, 27, 24, 512, 999_105]  # 20 
def generate_many_pca(net):
    """Compute gradient and PCA-direction maps of `net` for every seed in `seeds`.

    Reads the globals `seeds`, `valid_set`, `device`, `pca_directions_s_stride` and
    `pca_directions_1_stride`, and re-seeds numpy's global RNG once per image.

    Returns:
        Four lists, one entry per seed: strided PCA maps, stride-1 PCA maps,
        input-gradient maps, and the generated images.
    """
    _pca_map_s_strides = []
    _pca_map_1_strides = []
    _grad_maps = []
    _explain_imgs = []
    for seed in seeds:
        np.random.seed(seed)
        generated_img, label, *__ = valid_set.generate_one()
        # move the last axis first and add a batch axis before handing the image to the network
        tensored_img = torch.tensor(generated_img.transpose(2,0,1), requires_grad=True).unsqueeze(0).float().to(device)
        # gradient of the true-class output with respect to the input pixels
        grad_map = torch.autograd.grad(net(tensored_img)[0,label.argmax()], tensored_img)[0]
        pca_map_strided = pca_direction_grids(net, valid_set, label.argmax(), generated_img, 
                                              pca_direction_grids=pca_directions_s_stride)
        pca_map_1_stride = pca_direction_grids(net, valid_set, label.argmax(), generated_img, 
                                          pca_direction_grids=pca_directions_1_stride, strides=1)
        _explain_imgs.append(generated_img)
        # drop the batch axis and move the channel axis back to the end
        _grad_maps.append(grad_map.detach().cpu().squeeze(0).numpy().transpose(1,2,0))
        # copies are stored, presumably in case the helper reuses its output buffer — confirm
        _pca_map_s_strides.append(pca_map_strided.copy())
        _pca_map_1_strides.append(pca_map_1_stride.copy())
    return _pca_map_s_strides, _pca_map_1_strides, _grad_maps, _explain_imgs

In [None]:
# Maps for the plain network. NOTE: rebinds `explain_imgs`, previously the FD-map image list.
pca_map_s_strides, pca_map_1_strides, grad_maps, explain_imgs = generate_many_pca(unstrided_net)

In [None]:
# Same maps, computed through the GuidedBackprop wrapper around the same network.
guided_net = GuidedBackprop(unstrided_net)
guided_pca_map_s_strides, guided_pca_map_1_strides, guided_grad_maps, explain_imgs = generate_many_pca(guided_net)

In [None]:
# Image, then the guided maps, then the plain maps; `titles` follow the same order.
plt_grid_figure([explain_imgs, guided_pca_map_s_strides, guided_pca_map_1_strides, guided_grad_maps, pca_map_s_strides, pca_map_1_strides, grad_maps], transpose=True, titles=["Image", "Guided Strides=scale", "Guided strides=1", "Guided Gradient", "Strides=scale", "strides=1", "Gradient"])
# on the strides=2 final layer network

In [None]:
# Visualise the weights of the last fully connected layer and print its bias.
last_layer_weight = unstrided_net._modules["fully_connected"][-1].fully_connected.weight.detach().cpu().numpy()
imshow_centered_colorbar(last_layer_weight, cmap="bwr", title="Last Layer FC weights", colorbar=False)
# mark the columns discussed in the notes below
plt.vlines([1, 4, 17], ymin=0, ymax=2.5)

last_layer_bias = unstrided_net._modules["fully_connected"][-1].fully_connected.bias.detach().cpu().numpy()
print(last_layer_bias)
# in column 17, all weights are negative, so guided backprop means we immediately zero everything out
# would also happen for columns 3(?), 10, 13, 16, 17, 18, 20, 30(?), 31
# basically 9/32 = 28% of all images will be completely zero, if the assumption of
# "only 1 non-zero logit in the final hidden layer" holds true
# relevant columns have been highlighted

# column 1 mostly means class 0, very strongly not class 2
# column 4 mostly means class 1, equally strongly not class 0 and 2
# column 17 means not class 0, not class 1, barely class 2 => problematic

# since class 2 is "default class" (largest bias), the negative weights in column 17 are fine

In [None]:
# Visualise the first fully connected layer: one panel per output unit.
plt.figure(figsize=(6*4, 6*8))
first_fc = unstrided_net._modules["fully_connected"][0].fully_connected.weight.detach().cpu().numpy()
# only bother visualizing outputs 1, 4, and 17 (and add 13,12,15,0,9,27 to compare)
relevant_outputs = range(32)#[1, 4, 17, 13, 12, 15, 0, 9, 27]
for i, output_col in enumerate(relevant_outputs):
    # view the unit's weight vector as an 8x8 grid of 63x63 maps, then tile it into one 2-D image
    fc_weights = np.concatenate(np.concatenate(first_fc[output_col].reshape(8, 8, 63, 63),1),1)
    plt.subplot(8,4,i+1)
    imshow_centered_colorbar(fc_weights, cmap="bwr", title=f"FC weights of {output_col}")
# seemingly "empty" maps in the actual useful columns is just very low norm compared to the
# other weights (~0.2, whereas other weight maps range up to ~2)

In [None]:
# Visualise the recorded "conv_blocks.2.batch_norm2" activations for the first image in the batch.
# NOTE(review): `debug_net` is only assigned in later cells, so this cell fails when the
# notebook is run top to bottom on a fresh kernel.
plt.figure(figsize=(6*4, 12))
final_conv_map = debug_net._features["conv_blocks.2.batch_norm2"].detach().cpu().numpy()[0].reshape(8,8,63,63)
conv_max = abs(final_conv_map).max()  # symmetric color limits around zero

# tile the 8x8 grid of 63x63 maps into one 2-D image
compressed_results = np.concatenate(np.concatenate(final_conv_map, 1), 1)
plt.imshow(compressed_results, cmap="bwr", vmin=-conv_max, vmax=conv_max)
plt.colorbar()

In [None]:
# Same visualisation for the recorded "conv_blocks.2.act_func2" activations.
# NOTE(review): depends on `debug_net`, which is only assigned in later cells.
plt.figure(figsize=(6*4, 12))
final_relu = debug_net._features["conv_blocks.2.act_func2"].detach().cpu().numpy()[0].reshape(8,8,63,63)
relu_max = abs(final_relu).max()  # computed but not used below (fixed limits of +-1 are used)

# tanh squashes the values into (-1, 1), matching the fixed color limits
compressed_results = np.tanh(np.concatenate(np.concatenate(final_relu, 1), 1))
plt.imshow(compressed_results, cmap="bwr", vmin=-1, vmax=1)
plt.colorbar()
# with tanh

In [None]:
# Regenerate one of the seeded images, take its guided-backprop gradient, and record the
# activations of unstrided_net for it.
np.random.seed(seeds[9])
working_img, label, color, *_____ = valid_set.generate_one()
print(color, label.argmax())
tensored_img = torch.tensor(working_img.transpose(2,0,1), requires_grad=True).unsqueeze(0).float().to(device)
guided_net = GuidedBackprop(unstrided_net)
grad_map = torch.autograd.grad(guided_net(tensored_img)[0,label.argmax()], tensored_img)[0]
# AllActivations exposes intermediate outputs via `_features` (read by other cells)
debug_net = AllActivations(unstrided_net)
debug_net(tensored_img)

In [None]:
@torch.no_grad()
def final_activation_histogram(net):
    """Count, per color value, which on/off patterns occur in the first FC layer's activations.

    Runs `net` (switched to eval mode as a side effect) over the global `test_loader` and,
    for every image, turns the "fully_connected.0.act_func" activations into a 0/1 vector
    (1 where the activation is positive). The string form of that vector is the pattern key.

    Returns:
        defaultdict mapping ``int(color)`` -> defaultdict mapping pattern string -> count.
    """
    net.eval()
    pattern_counts = defaultdict(lambda: defaultdict(int))
    debug_net = AllActivations(net)
    # iterate the loader directly so tqdm can show a total; the batch index was never used
    for sample in tqdm(test_loader):
        imgs = sample["image"].to(device).float()
        colors = sample["color"]
        debug_net(imgs)
        final_layer = debug_net._features["fully_connected.0.act_func"]
        nonzero = torch.where(final_layer > 0, 1, 0).detach().cpu().numpy()
        for row, color in zip(nonzero, colors):
            # str(ndarray) is the key, so it depends on numpy's print options
            pattern_counts[int(color)][str(row)] += 1
    return pattern_counts
color_distrib = final_activation_histogram(unstrided_net)

In [None]:
# Collect every activation pattern seen for any color.
uniq = set()
for per_color_counts in color_distrib.values():
    uniq |= set(per_color_counts)
# Short name per pattern: the characters for units 1, 4 and 17. In a pattern string such as
# "[0 1 0 ...]" the entry for unit k sits at character index 1 + 2*k.
pattern_to_names = {
    pattern: "".join(pattern[1 + 2*unit] for unit in (1, 4, 17))
    for pattern in uniq
}
pattern_to_names

In [None]:
# Total number of occurrences of each activation pattern, summed over all colors.
pattern_totals = defaultdict(int)
for color, values in color_distrib.items():
    for pattern, count in values.items():
        pattern_totals[pattern] += count
pattern_totals

In [None]:
# For every activation pattern, plot how often it occurs at each color value, with the
# (rescaled) class profile overlaid for reference.
for pattern in uniq:
    amounts = np.zeros((256,))  # one slot per color value 0..255, so color 255 cannot overflow
    for color, distrib in color_distrib.items():
        # .get() instead of indexing: indexing a defaultdict would insert a zero entry for
        # every (color, pattern) pair and silently grow color_distrib
        amounts[color] = distrib.get(pattern, 0)
    plt.plot(amounts, label=pattern_to_names[pattern])
color_probe = np.linspace(0, 255, 255)
# scaled copy under its own name so the global `color_class` used by other cells stays unscaled
scaled_color_class = [color_classifier(x)*400 for x in color_probe]
plt.plot(color_probe, scaled_color_class, label="classes", linestyle="dotted")
plt.legend()

In [None]:
def adversarial_generate(img, lbl, net, alpha, lr, runs, target_class=2):
    """Search for a bounded additive perturbation that pushes `net` towards `target_class`.

    Starts from a random direction rescaled to L2 norm `alpha` and runs `runs` steps of
    gradient descent on the cross-entropy towards the target, clamping every element of
    the perturbation to [-alpha, alpha] after each step. Reads the globals `valid_set`
    (for the image size) and `device`.

    Args:
        img: image array; it is transposed with (2, 0, 1) before use, i.e. last axis first.
        lbl: one-hot label array of the image (only used for its shape and for printing).
        net: network to attack.
        alpha: bound on the magnitude of each perturbation element.
        lr: step size of the gradient update.
        runs: number of update steps.
        target_class: index of the class the perturbation should produce
            (defaults to 2, the value that was previously hard-coded).

    Returns:
        Detached tensor holding the perturbed image, with a leading batch axis.
    """
    # alpha is maximum norm that the adversarial can be
    adversarial_direction = np.random.uniform(-alpha, alpha, size=(1, valid_set.size, valid_set.size))
    adversarial_direction = adversarial_direction/np.linalg.norm(adversarial_direction)*alpha

    # one-hot target for the requested class
    target = torch.tensor(np.zeros_like(lbl)).to(device).unsqueeze(0).float()
    target[0, target_class] = 1.
    print(target, lbl)

    tensor_img = torch.tensor(img.transpose(2,0,1)).unsqueeze(0).to(device).float()
    # make the perturbation a leaf tensor so that each step differentiates a fresh graph
    tensor_adv_dir = torch.tensor(adversarial_direction).unsqueeze(0).to(device).float().requires_grad_(True)

    loss_func = nn.CrossEntropyLoss()
    report_every = max(runs//5, 1)  # avoids a modulo by zero when runs < 5
    for i in range(runs):
        curr_img = tensor_img + tensor_adv_dir
        curr_net_out = net(curr_img)
        curr_loss = loss_func(curr_net_out, target)# + 5e-10*torch.linalg.norm(tensor_adv_dir)
        grad_dir = torch.autograd.grad(curr_loss, tensor_adv_dir)[0]
        # update outside of autograd: otherwise every step is chained onto the previous
        # ones and the graph (and its memory) grows for the whole run
        with torch.no_grad():
            tensor_adv_dir = torch.clamp(tensor_adv_dir - lr*grad_dir, min=-alpha, max=alpha)
        tensor_adv_dir.requires_grad_(True)

        if i % report_every == report_every - 1:
            print(curr_loss.item(), curr_net_out)
    return (tensor_img + tensor_adv_dir).detach()
# Generate one seeded image and search for an adversarial version of it
# (alpha=6.4, lr=1e2, 500 steps).
np.random.seed(58)
generated_img, gen_label, color, *_ = valid_set.generate_one()
print(color)
adv_example = adversarial_generate(generated_img, gen_label, unstrided_net, 6.4, 1e2, 500)

In [None]:
# 230 color -> alpha very close to 5, works on 1e2, 500 (seed 55)
# 136 color -> alpha of 3 (seed 54)
# 50 color -> alpha of 9-10 (seed 53)
# 181 color (in class 2 already) -> adv_dir is basically random noise, weird cyclic structure to it (seed 52)
# 82 color -> alpha of 3.8 (high) (seed 51)
# 201 color -> alpha of 0-1 (high) (seed 50)
# 110 color -> alpha of 2-3 (low) (seed 56)
# 232 color -> alpha close to 5 (seed 57)
# 60 color -> alpha 6-7 (mid) (seed 58)

In [None]:
# Original image | adversarial image | perturbation that was added.
plt.subplot(1,3,1)
plt.imshow(generated_img, cmap="gray")
plt.subplot(1,3,2)
np_adv = adv_example.detach().cpu().numpy().squeeze()
plt.imshow(np_adv, cmap="gray")
plt.subplot(1,3,3)
# adv_example = image + direction, so the direction is adversarial - original
# (the reverse order would show the negated direction on the diverging colormap)
imshow_centered_colorbar(np_adv - generated_img.squeeze(), cmap="bwr", title="Adv direction")

In [None]:
# Record the activations of unstrided_net for the adversarial example.
debug_net = AllActivations(unstrided_net)
debug_net(adv_example)

In [None]:
# Activations recorded under "fully_connected.0.act_func" during the last debug_net call.
debug_net._features["fully_connected.0.act_func"]

In [None]:
plt_grid_figure([explain_imgs, pca_map_s_strides, pca_map_1_strides], transpose=True, titles=["Image", "Strides=scale", "strides=1"])
# add comparison to regular gradient
# smaller circles = bad?
# manually test if the gradient changes make sense
# model learns weird stuff about the noise
# interpretation of its algo is interesting

# make texture dataset and test the methods on it
# texture generation: emerging conv?
# heuristic = lots of code
# dataset
# test on natural images eventually

# show it works when edges are important too (guided backprop first)
# could do saliency checks

# fourier transform could work for texture, if it's the whole image (to get window size)
# do fft on quadrants of image to guess at scale, look at max fourier coeff =>
# should give rough idea of window size (top k coeffs?)

In [None]:
# Image, both PCA maps and the plain gradient side by side.
plt_grid_figure([explain_imgs, pca_map_s_strides, pca_map_1_strides, grad_maps], transpose=True, titles=["Image", "Strides=scale", "strides=1", "Gradient"])
# on the strides=8 final layer network

In [None]:
# Same call as the previous cell; kept to hold the output produced with a different network.
plt_grid_figure([explain_imgs, pca_map_s_strides, pca_map_1_strides, grad_maps], transpose=True, titles=["Image", "Strides=scale", "strides=1", "Gradient"])
# on the strides=2 final layer network

In [None]:
# surely this is the best way to do this :)
# Brute-force search for every transpose/concatenate recipe that tiles a (5, 5, 4, 4) array of
# 4x4 patches into a (20, 20) image whose first row starts 0, 1, 2, 3 and continues with 16.
x = np.arange(400).reshape(5,5,4,4)
from itertools import permutations
matching_recipes = []  # (transp_1, axis1, transp_2, axis2, transp_3) of every recipe that matches
for transp_1 in permutations([0,1,2,3]):
    for axis1 in range(3):
        for transp_2 in permutations([0,1,2]):
            for axis2 in range(2):
                for transp_3 in permutations([0,1]):
                    try:
                        t1 = x.transpose(*transp_1)
                        t2 = np.concatenate(t1, axis1)
                        t3 = t2.transpose(*transp_2)
                        t4 = np.concatenate(t3, axis2)
                        t5 = t4.transpose(*transp_3)
                    except ValueError:
                        # shape/axis mismatches are expected for some recipes; anything else
                        # (including KeyboardInterrupt) should not be swallowed
                        continue
                    if t5.shape == (20,20):
                        if all(t5[0,:4] == np.arange(4)) and t5[0,4] == 16:
                            matching_recipes.append((transp_1, axis1, transp_2, axis2, transp_3))
# the simplest matching recipe: no transposes, concatenate along axis 1 twice
np.concatenate(np.concatenate(x,1),1)

So the question then becomes: how do we search for useful reference images/pixel values in general? We want the reference to be close to the image (small denominator in the finite difference), but also to lead to large differences in the output logits. This is dangerously close to finding adversarial directions, so we need to make sure we stay on the data manifold => we need to establish some sort of distance metric, and potentially a way of detecting whether we are on the manifold or not, so that we can project back onto it if needed. This would also allow us to switch to a DeepLIFT-style approach.

# Model Optimization Stuff

In [None]:
# Persist the model and optimizer state.
# NOTE(review): neither `res_net` nor `optim` is defined in any cell visible in this notebook —
# confirm which network/optimizer pair is meant (e.g. small_net / small_optim).
res_net.save_model_state_dict(optim=optim)

In [None]:
# Run the forward pass 1000 times on one image with profile=True
# (presumably this makes forward() call benchmark(), filling `stats` — confirm in network.py).
generated_img, _, __ = valid_set.generate_one()
generated_img = torch.tensor(generated_img.transpose(2,0,1)).to(device).unsqueeze(0).float()
for _ in range(1000):
    small_net.forward(generated_img, profile=True)

In [None]:
# Print each benchmark point's share of the summed running averages, in percent.
total = sum(stats.values())  # --> gave 3x speed! (Fast and Accurate Model scaling?)
for k,v in stats.items():    # --> the 3x speedup caused underfitting though, so switched to 2x
    print(k,(100.*v/total))
