In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np
import time
import os
import matplotlib.pyplot as plt
from collections import defaultdict

import color_regions, network, visualizations, utils
from color_regions import *
from network import *
from visualizations import *
from utils import *
from hooks import *

# Let cuDNN benchmark convolution algorithms and cache the fastest one; this pays off
# when input shapes stay the same from batch to batch.
torch.backends.cudnn.benchmark = True
# Use the first GPU when one is available, otherwise fall back to the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [None]:
# set up autoreloading of shared code.
# With `%autoreload 1`, only the modules registered through %aimport are reloaded
# before each cell runs.
%load_ext autoreload
%autoreload 1
%aimport color_regions,network,visualizations,utils,hooks
# With no arguments, %aimport lists which modules are registered for autoreload.
%aimport

In [None]:
# Images are only converted to tensors. Normalization
# (transforms.Normalize((0.5), (0.5))) is deliberately left out.
transform = transforms.Compose([transforms.ToTensor()])

batch_size = 128  # seems to be the fastest batch size

# (start, end) index range for each split, passed to ColorDatasetGenerator as image_indices.
train_indices = (0, 200_000)  # size of training set
valid_indices = (1_250_000, 1_255_000)
test_indices = (260_000, 310_000)

def color_classifier(color):
    """Map a grey-level colour value to its class label for the "medium" task.

    Bands (the dataset generates colours in [25, 250]):
        color <= 100          -> 0  (width 75, medium difficulty)
        100 < color <= 150    -> 1  (width 50, hard)
        150 < color <= 200    -> 2  (width 50, hard)
        color > 200           -> 1  (width 50, hard)

    Values that compare false against every bound (e.g. NaN) fall through and give None.
    """
    if color <= 100:
        return 0
    if color <= 150:
        return 1
    if color <= 200:
        return 2
    if color > 200:
        return 1
    return None
# Upper edges of the bands above, i.e. the colours at which the class changes.
critical_color_values = [100, 150, 200]

def set_loader_helper(indices, shuffle=True):
    """Build a ColorDatasetGenerator over `indices` and a DataLoader around it.

    Args:
        indices: (start, end) range passed to the dataset as ``image_indices``.
        shuffle: whether the DataLoader reshuffles every epoch. Only the training
            loader needs this. The evaluation loaders below use ``shuffle=False``, so
            "first batch" inspections of them are reproducible.

    Returns:
        (dataset, loader)
    """
    data_set = ColorDatasetGenerator(color_classifier=color_classifier,
                                     image_indices=indices,
                                     transform=transform,
                                     color_range=(25, 250),
                                     noise_size=(1, 7),
                                     num_classes=3,
                                     size=128,
                                     radius=(128 // 6, 128 // 3))
    loader = torch.utils.data.DataLoader(data_set, batch_size=batch_size,
                                         shuffle=shuffle, num_workers=4, pin_memory=True)
    return data_set, loader
train_set, train_loader = set_loader_helper(train_indices)
valid_set, valid_loader = set_loader_helper(valid_indices, shuffle=False)
test_set, test_loader = set_loader_helper(test_indices, shuffle=False)

In [None]:
# the "medium" task: class label as a function of the colour value
# Probe every integer grey level 0..255. The old np.linspace(0, 255, 255) stepped by
# 255/254 and so never hit integer colours.
color_probe = np.arange(256)
color_class = [color_classifier(x) for x in color_probe]
plt.plot(color_probe, color_class)
plt.title('Class boundaries of the "medium" task')
plt.xlabel("Color")
plt.yticks([0, 1, 2])
plt.ylabel("Class");

In [None]:
small_net = ResNet([[16, 3, 2],  # num_channels (input and output), kernel_size, stride
                    [32, 3, 2],
                    [64, 3, 2]], 3, [128, 128, 1], 
                   "small_net_noise_medium_grey.dict", fc_layers=[32]).to(device)
loss_func = nn.CrossEntropyLoss()
small_optim = torch.optim.Adam(small_net.parameters())
print(small_net.num_params())
small_net.load_model_state_dict(optim=small_optim)

In [None]:
small_net2 = ResNet([[16, 3, 2],  # num_channels (input and output), kernel_size, stride
                    [32, 3, 2],
                    [64, 3, 2]], 3, [128, 128, 1], 
                   "small_net_noise_medium_grey_2.dict", fc_layers=[32]).to(device)
# if 0:
#     small_net2.save_model_state_dict("small_net_init_noise_medium_grey.dict")
loss_func = nn.CrossEntropyLoss()
small_optim2 = torch.optim.Adam(small_net2.parameters())
print(small_net2.num_params())
small_net2.load_model_state_dict(optim=small_optim2)

In [None]:
tiny_net_med = ResNet([[2, 3, 4],  # num_channels (input and output), kernel_size, stride
                   [6, 3, 4]], 3, [128, 128, 1], 
                   "tiny_net_noise_medium_grey.dict", fc_layers=[32]).to(device)
# if 0:
#     tiny_net_med.save_model_state_dict("tiny_net_noise_init_medium_grey.dict")
loss_func = nn.CrossEntropyLoss()
tiny_optim_med = torch.optim.Adam(tiny_net_med.parameters())
print(tiny_net_med.num_params())
tiny_net_med.load_model_state_dict(optim=tiny_optim_med)

In [None]:
tiny_net_med2 = ResNet([[2, 3, 4],  # num_channels (input and output), kernel_size, stride
                   [6, 3, 4]], 3, [128, 128, 1], 
                   "tiny_net_noise_medium_grey_2.dict", fc_layers=[32]).to(device)
tiny_net_med2.load_model_state_dict("tiny_net_noise_init_medium_grey.dict")
loss_func = nn.CrossEntropyLoss()
tiny_optim_med2 = torch.optim.Adam(tiny_net_med2.parameters())
print(tiny_net_med2.num_params())
tiny_net_med2.load_model_state_dict(optim=tiny_optim_med2)
# didnt actually train this one for some reason

In [None]:
tiny_net_med3 = ResNet([[2, 3, 4],  # num_channels (input and output), kernel_size, stride
                   [6, 3, 4]], 3, [128, 128, 1], 
                   "tiny_net_noise_medium_grey_3.dict", fc_layers=[32]).to(device)
tiny_net_med3.load_model_state_dict("tiny_net_noise_init_medium_grey.dict")
loss_func = nn.CrossEntropyLoss()
tiny_optim_med3 = torch.optim.Adam(tiny_net_med3.parameters())
print(tiny_net_med3.num_params())
tiny_net_med3.load_model_state_dict(optim=tiny_optim_med3)

In [None]:
unstrided_net = ResNet([[16, 3, 1],  # num_channels (input and output), kernel_size, stride
                        [32, 3, 1],
                        [64, 3, 2]], 3, [128, 128, 1], 
                   "unstrided_net_noise_medium_grey.dict", fc_layers=[32]).to(device)
loss_func = nn.CrossEntropyLoss()
unstrided_optim = torch.optim.Adam(unstrided_net.parameters())
print(unstrided_net.num_params())
unstrided_net.load_model_state_dict(optim=unstrided_optim)

In [None]:
feature_gram = visualizations.fc_conv_feature_angles(unstrided_net, 
                            "fully_connected.0.act_func", num_embed=50, normalize=False)

In [None]:
visualizations.display_fc_conv_grams(feature_gram, selection=[1,4,17])

In [None]:
feature_gram.shape

In [None]:
imshow_centered_colorbar(np.arcsinh(feature_gram[5]), cmap="bwr")

In [None]:
random_net = ResNet([[16, 3, 1],  # num_channels (input and output), kernel_size, stride
                        [32, 3, 1],
                        [64, 3, 2]], 3, [128, 128, 1], 
                   "random_init_net_noise_medium_grey.dict", fc_layers=[32]).to(device)
# if 0:  # changed right away, so its good
#     random_net.save_model_state_dict()
random_net.path = "random_net_noise_medium_grey.dict"
loss_func = nn.CrossEntropyLoss()
random_optim = torch.optim.Adam(random_net.parameters())
print(random_net.num_params())
random_net.load_model_state_dict(optim=random_optim)

In [None]:
random_net2 = ResNet([[16, 3, 1],  # num_channels (input and output), kernel_size, stride
                        [32, 3, 1],
                        [64, 3, 2]], 3, [128, 128, 1], 
                   "random_init_net_noise_medium_grey.dict", fc_layers=[32]).to(device)
random_net2.load_model_state_dict()  # start from same initialization
random_net2.path = "random_net_noise_medium_grey_2.dict"
loss_func = nn.CrossEntropyLoss()
random_optim2 = torch.optim.Adam(random_net2.parameters())
print(random_net2.num_params())
random_net2.load_model_state_dict(optim=random_optim2)

In [None]:
random_net3 = ResNet([[16, 3, 1],  # num_channels (input and output), kernel_size, stride
                      [32, 3, 1],
                      [64, 3, 2]], 3, [128, 128, 1], 
                   "random_init_net_noise_medium_grey.dict", fc_layers=[32]).to(device)
random_net3.load_model_state_dict()  # start from same initialization
random_net3.path = "random_net_noise_medium_grey_3.dict"
loss_func = nn.CrossEntropyLoss()
random_optim3 = torch.optim.Adam(random_net3.parameters())
print(random_net3.num_params())
random_net3.load_model_state_dict(optim=random_optim3)

In [None]:
random_net4 = ResNet([[16, 3, 1],  # num_channels (input and output), kernel_size, stride
                      [32, 3, 1],
                      [64, 3, 2]], 3, [128, 128, 1], 
                   "random_init_net_noise_medium_grey.dict", fc_layers=[32]).to(device)
random_net4.load_model_state_dict()  # start from same initialization
random_net4.path = "random_net_noise_medium_grey_4.dict"
loss_func = nn.CrossEntropyLoss()
random_optim4 = torch.optim.Adam(random_net4.parameters())
print(random_net4.num_params())
random_net4.load_model_state_dict(optim=random_optim4)

In [None]:
random_net5 = ResNet([[16, 3, 1],  # num_channels (input and output), kernel_size, stride
                      [32, 3, 1],
                      [64, 3, 2]], 3, [128, 128, 1], 
                   "random_net_noise_medium_grey_5.dict", fc_layers=[32]).to(device)
random_net5.load_model_state_dict("random_init_net_noise_medium_grey.dict")  # start from same initialization
loss_func = nn.CrossEntropyLoss()
random_optim5 = torch.optim.Adam(random_net5.parameters())
print(random_net5.num_params())
random_net5.load_model_state_dict(optim=random_optim5)

In [None]:
random_net6 = ResNet([[16, 3, 1],  # num_channels (input and output), kernel_size, stride
                      [32, 3, 1],
                      [64, 3, 2]], 3, [128, 128, 1], 
                   "random_net_noise_medium_grey_6.dict", fc_layers=[32]).to(device)
random_net6.load_model_state_dict("random_init_net_noise_medium_grey.dict")  # start from same initialization
loss_func = nn.CrossEntropyLoss()
random_optim6 = torch.optim.Adam(random_net6.parameters())
print(random_net6.num_params())
random_net6.load_model_state_dict(optim=random_optim6)

In [None]:
random_net7 = ResNet([[16, 3, 1],  # num_channels (input and output), kernel_size, stride
                      [32, 3, 1],
                      [64, 3, 2]], 3, [128, 128, 1], 
                   "random_net_noise_medium_grey_7.dict", fc_layers=[32]).to(device)
loss_func = nn.CrossEntropyLoss()  # dont start from same initialization
random_optim7 = torch.optim.Adam(random_net7.parameters())
print(random_net7.num_params())
random_net7.load_model_state_dict(optim=random_optim7)

In [None]:
# random_net7 (its own random init): 100 epochs with final-activation tracking and
# logit-usage summaries, logging to logs/random_net7.pkl and also passing test_loader.
results = train(random_net7, random_optim7, loss_func, 100, train_loader, valid_loader, device=device,
                track_stat=final_activation_tracker, summarize=count_logit_usage, 
                log_file="logs/random_net7.pkl", test_loader=test_loader)

In [None]:
plot_results(results, "RandomNet7")

In [None]:
# random_net6 (shared init): same setup as random_net7.
results = train(random_net6, random_optim6, loss_func, 100, train_loader, valid_loader, device=device,
                track_stat=final_activation_tracker, summarize=count_logit_usage, 
                log_file="logs/random_net6.pkl", test_loader=test_loader)

# TODO: double check the correlation computation

In [None]:
plot_results(results, "RandomNet6", size=0.2, alpha=1)

In [None]:
# NOTE(review): unlike random_net6/7, this run passes no test_loader. The results[-2]
# indexing below depends on train()'s return layout for this exact call, so confirm it
# before adding arguments.
results = train(random_net5, random_optim5, loss_func, 100, train_loader, valid_loader, device=device,
                track_stat=final_activation_tracker, summarize=count_logit_usage, 
                log_file="logs/random_net5.pkl")

In [None]:
plot_results(results, "RandomNet5")

In [None]:
plot_corr_grad_logit_info(results[-2], "RandomNet5")

In [None]:
results = train(tiny_net_med3, tiny_optim_med3, loss_func, 100, train_loader, valid_loader, device=device,
                track_stat=final_activation_tracker, summarize=count_logit_usage)

In [None]:
plot_results(results, "TinyNetMedium3") 

In [None]:
network.plot_corr_grad_logit_info(results[-2], "TinyNetMedium3")

In [None]:
# NOTE(review): this and the next two runs track final_activation_histogram rather than
# final_activation_tracker (used above); presumably these are older runs.
tiny_results = train(tiny_net_med, tiny_optim_med, loss_func, 100, train_loader, valid_loader, device=device,
                track_stat=final_activation_histogram, summarize=count_logit_usage)

In [None]:
plot_results(tiny_results, "TinyNetMedium") # initial value (presumably num_logits used) is 76998

In [None]:
results = train(small_net2, small_optim2, loss_func, 100, train_loader, valid_loader, device=device,
                track_stat=final_activation_histogram, summarize=count_logit_usage)

In [None]:
plot_results(results, "SmallNet2") # initial value (presumably num_logits used) is 36031

In [None]:
results = train(random_net4, random_optim4, loss_func, 100, train_loader, valid_loader, device=device,
                track_stat=final_activation_histogram, summarize=count_logit_usage)

In [None]:
plot_results(results, "RandomNet4") # starts to increase num_logits used when tr_loss drops below 3
                   # initial num_logits: 61498
                   # doesn't seem to directly correspond to generalization capacity
                   # (overfitting, EMC curve)
# Research notes:
# - human loss curve = concave; AI should be that too

# - bad hidden units could be the same across inputs, or amplify noise in the input:
#     already converged and happy (low gradient)
#     not converged but still ok (high gradient)
#     completely useless (gradient is 0 again)

# - use correlation with the outputs plus gradient magnitude to tell these cases apart
# - do the correlation before the ReLU so there is at least some signal
# - guess: start with useful units;
#   try more times

# - somewhat related to grokking?
#   Fourier components: need little data to get the first components, more to get later ones
#   => physicist!
#   also inverse graphics (learn a model from the image, cf. Joshua B. Tenenbaum, then do physics on it)

In [None]:
results = train(random_net3, random_optim3, loss_func, 200, train_loader, valid_loader, device=device)

In [None]:
results = train(random_net2, random_optim2, loss_func, 200, train_loader, valid_loader, device=device)

In [None]:
results = train(random_net, random_optim, loss_func, 200, train_loader, valid_loader, device=device)

In [None]:
results = train(unstrided_net, unstrided_optim, loss_func, 200)
# never actually reached the tr_loss < 3 regime, so wouldnt have started to increase 
# its num_logits_used!

In [None]:
results = train(small_net, small_optim, loss_func, 200)

In [None]:
rate_distribution(unstrided_net, valid_loader, valid_set,
                  critical_values=critical_color_values, device=device)

In [None]:
response_graph(unstrided_net, valid_set, device=device)

In [None]:
unstrided_net.eval()
for sample in valid_loader:
    #explain_img, explain_target_logit, *__ = valid_set.generate_one()
    imgs = sample["image"].to(device).float()
    labels = sample["label"].to(device).float()
    model_outpt = unstrided_net(imgs)
    print(model_outpt.argmax(dim=1))
    print(labels.argmax(dim=1))
    print(correct(model_outpt, labels))
    break
    #plt.imshow(explain_img)

In [None]:
# to further test the "using 1 image => bad batchnorm estimates" hypothesis, do the same test
# but average the response over a stack of different images of the same colour
small_net.eval()  # very important! (BatchNorm must use its running statistics)
stack_size = 32
# NOTE: unseeded, so each run draws a different set of background images.
sampled_indices = 1_250_000 + np.random.choice(1000, stack_size, replace=False)
counterfactual_color_values = np.linspace(0, 255, 255)
total_images = stack_size * len(counterfactual_color_values)  # derived, not hard-coded
correct_num = 0
responses = []
with torch.no_grad():
    for color in tqdm(counterfactual_color_values):
        stacked_generated_img = []
        for sampled_index in sampled_indices:
            # re-seeding presumably regenerates image `sampled_index`, with its colour overridden
            np.random.seed(sampled_index)
            generated_img, lbl, *__ = valid_set.generate_one(set_color=color)
            stacked_generated_img.append(generated_img)
        stacked_generated_img = np.array(stacked_generated_img).transpose(0, 3, 1, 2)  # NHWC -> NCHW
        generated_img = torch.tensor(stacked_generated_img).to(device).float()
        response = small_net(generated_img, logits=True)
        # every image in the stack has the same colour, so (presumably) the same label
        stacked_lbl = torch.tensor(np.repeat(np.expand_dims(lbl, 0), stack_size, axis=0)).to(device)
        correct_num += correct(response, stacked_lbl).sum().item()
        # Average over the stack axis first, then drop singleton dims. This stays
        # (num_classes,) even when stack_size == 1; squeeze-then-mean collapsed the class axis.
        responses.append(response.mean(dim=0).squeeze().cpu().numpy())
print(correct_num / total_images, "total accuracy")

In [None]:
# arcsinh compresses large logits but keeps their sign. The result gets a new name, so
# re-running this cell does not apply arcsinh twice.
# Observations: the graph is quite robust to changes in batch size with .eval(), but it
# varies with .train(). Probably .train() gives biased BatchNorm estimates here because
# every image in a batch has the same colour; in .train() the logit graphs look bad and
# accuracy is lower.
responses_arcsinh = np.arcsinh(np.array(responses))
for output_logit in range(responses_arcsinh.shape[1]):
    plt.plot(counterfactual_color_values, responses_arcsinh[:, output_logit], label=f"class {output_logit}")
# Mark every class boundary (100, 150 and 200). Draw them before plt.legend() so the
# "decision boundary" entry actually appears in the legend.
plt.vlines(critical_color_values, np.min(responses_arcsinh), np.max(responses_arcsinh), linewidth=0.8,
           colors="r", label="decision boundary", linestyles="dashed")
plt.legend()
plt.xlabel("Color value")
plt.ylabel("arcsinh(network output logit)")

In [None]:
res_net.eval()   # ---> without this line, it fails, especially with small colors
with torch.no_grad():  # => batchnorm updates are very inaccurate if just one image
    idx = 1_250_026   # => network expects batchnorm updates to basically be exactly in the "middle" 127
    print(valid_set[idx])  # => fix the logit response graph by requiring it to be in eval mode
    print(torch.softmax(res_net(torch.unsqueeze(valid_set[idx]["image"], 0).to(device).float()), 1))

In [None]:
# to test the above hypothesis, if we just stack the same image a bunch of times, and do .train()
# the estimates should still be bad because batchnorm estimates would be just as bad as with 
# a single image in the batch
res_net.train()
idx = 1_250_024  # => hypothesis seems to be confirmed
test_image = valid_set[idx]["image"].numpy()
stacked_test = np.repeat(np.expand_dims(test_image, 0), 32, axis=0)
print(valid_set[idx])
print(torch.softmax(res_net(torch.tensor(stacked_test).to(device).float()), 1)[0])

In [None]:
plt.imshow(np.squeeze(generated_img.cpu().numpy()), cmap="gray")

In [None]:
res_net.train()
with torch.no_grad():
    for i, sample in enumerate(valid_loader):
        imgs = sample["image"].to(device).float()
        labels = sample["label"].to(device).float()
        print(sample["color"])
        outputs = res_net(imgs)
        print(loss_func(outputs, labels).item())
        print(torch.argmax(labels,dim=1), torch.argmax(outputs, dim=1))
        print(correct(outputs, labels).sum().item())
        break

In [None]:
# Finite-differences saliency map of one validation image with respect to its true class.
# The seed presumably selects which image generate_one() produces.
np.random.seed(500_001)
explain_img, lbl, *_ = valid_set.generate_one()
heat_map = finite_differences_map(small_net, valid_set, lbl.argmax(), explain_img, device=device)

In [None]:
plt.subplot(1, 2, 1)
plt.imshow(explain_img, cmap="gray")
plt.subplot(1, 2, 2)
plt.imshow(heat_map, cmap="bwr", interpolation="bilinear")
plt.colorbar()
# recorded output: generated with a network using strides = 2 everywhere

In [None]:
plt_grid_figure([explain_img, heat_map], transpose=False, colorbar=True)
# recorded output: generated with strides = 1, and strides = 8 for the last layer

In [None]:
image_ids = [20_000, 25_000, 30_000, 600_000, 600_001]
heat_maps = []
explain_imgs = []
for i, image_id in enumerate(image_ids):
    np.random.seed(image_id)
    explain_img_i, target_i, *__ = valid_set.generate_one()
    heat_map_i = finite_differences_map(small_net, valid_set, target_i.argmax(), explain_img_i)
    heat_maps.append(heat_map_i)
    explain_imgs.append(explain_img_i)

In [None]:
plt_grid_figure([explain_imgs, heat_maps], transpose=True, colorbar=True)
# generated with strides=2 everywhere

In [None]:
plt_grid_figure([explain_imgs, heat_maps], transpose=True, colorbar=True)
# generated with strides = 1, strides = 8 for lats layer

What if we run the same experiment, but cheat with a prior on pixel values that we know *should* be informative to the output logit, namely values closest to the decision boundaries?

In [None]:
# FD maps that only probe colour values next to the class boundaries.
unfair_prior = np.array([90, 110, 140, 160])  #  close to the critical values of 100, 150 (200 is not probed directly)
unfair_heat_maps = []
plt.figure(figsize=(12, 5*len(image_ids)))
for i, image_id in enumerate(image_ids):
    np.random.seed(image_id)
    explain_img_i, target_i, *__ = valid_set.generate_one()
    # NOTE(review): these maps use res_net, but heat_maps (third column) came from
    # small_net. Make sure both refer to the same network before comparing them.
    unfair_map_i = finite_differences_map(res_net, valid_set, target_i.argmax(), explain_img_i,
                                          unfairness="unfair", values_prior=unfair_prior, device=device)
    unfair_heat_maps.append(unfair_map_i)
    plt.subplot(len(image_ids), 3, 3*i+1)
    plt.imshow(explain_img_i, cmap="gray")
    plt.subplot(len(image_ids), 3, 3*i+2)
    heat_max = np.max(abs(unfair_map_i))
    plt.imshow(unfair_map_i, cmap="bwr", vmax=heat_max, vmin=-heat_max)
    plt.colorbar(shrink=0.5)
    plt.subplot(len(image_ids), 3, 3*i+3)
    heat_max = np.max(abs(heat_maps[i]))
    plt.imshow(heat_maps[i], cmap="bwr", vmax=heat_max, vmin=-heat_max)
    plt.colorbar(shrink=0.5)
plt.show()  # => very similar results, but with a 4x speedup

In [None]:
# regenerate unfair FD maps
unfair_prior = np.array([90, 110, 140, 160])  #  close to the critical values of 100, 150
image_ids = [20_000, 25_000, 30_000, 600_000, 600_001, 227_662, 998_102, 106_758]
unfair_heat_maps = []
#plt.figure(figsize=(12, 5*len(image_ids)))
for i, image_id in enumerate(image_ids):
    np.random.seed(image_id)
    explain_img_i, target_i, __ = valid_set.generate_one()
    unfair_map_i = finite_differences_map(res_net, valid_set, target_i.argmax(), explain_img_i, unfairness="unfair", values_prior=unfair_prior)
    unfair_heat_maps.append(unfair_map_i)

We can do even better by taking, as our prior, the "closest value in a different class" for each image.

In [None]:
very_unfair_heat_maps = []
plt.figure(figsize=(20, 5*len(image_ids)))
for i, image_id in enumerate(image_ids):
    np.random.seed(image_id)
    explain_img_i, target_i, color_i, *_ = valid_set.generate_one()
    # Use all three class boundaries, so colours above 200 can find their nearest
    # different class (200) instead of jumping back to 150 or 100.
    very_unfair_map_i = finite_differences_map(res_net, valid_set, target_i.argmax(), explain_img_i,
                                               unfairness="very unfair", values_prior=critical_color_values,
                                               device=device)
    very_unfair_heat_maps.append(very_unfair_map_i)
    
    plt.subplot(len(image_ids), 5, 5*i+1)
    if i == 0:
        plt.title("Image")
    plt.imshow(explain_img_i, cmap="gray")
    
    plt.subplot(len(image_ids), 5, 5*i+2)
    if i == 0:
        plt.title("Very unfair FD map")
    heat_max = np.max(abs(very_unfair_map_i))
    plt.imshow(very_unfair_map_i, cmap="bwr", vmax=heat_max, vmin=-heat_max)
    plt.colorbar(shrink=0.5)
    
    plt.subplot(len(image_ids), 5, 5*i+3)
    if i == 0:
        plt.title("Unfair FD map")
    heat_max = np.max(abs(unfair_heat_maps[i]))
    plt.imshow(unfair_heat_maps[i], cmap="bwr", vmax=heat_max, vmin=-heat_max)
    plt.colorbar(shrink=0.5)
    
    plt.subplot(len(image_ids), 5, 5*i+4)
    if i == 0:
        plt.title("FD map")
    # heat_maps only exists for the images of the multi-image FD cell (5 ids), while
    # image_ids may now be longer. Leave the panel empty instead of raising IndexError.
    if i < len(heat_maps):
        heat_max = np.max(abs(heat_maps[i]))
        plt.imshow(heat_maps[i], cmap="bwr", vmax=heat_max, vmin=-heat_max)
        plt.colorbar(shrink=0.5)
    
    plt.subplot(len(image_ids), 5, 5*i+5)
    if i == 0:
        plt.title("Location in color space")
    plt.plot(color_probe, color_class) 
    plt.vlines([color_i], 0, valid_set.num_classes-1, linewidth=0.8,
           colors="r", label="color value",
           linestyles="dashed")
plt.show()  # => somewhat similar results (see image 2), but with an overall ~11x speedup

In [None]:
# Gradient and Input*Gradient saliency maps, shown next to the unfair FD maps and each
# image's position in colour space.
grad_heat_maps = []
plt.figure(figsize=(20, 5*len(image_ids)))
for i, image_id in enumerate(tqdm(image_ids)):  # tqdm over the list so it knows the total
    np.random.seed(image_id)
    explain_img_i, target_i, color_i, *_ = valid_set.generate_one()
    # Build the input directly as a float32 leaf tensor on `device` that requires grad
    # (HWC image -> 1xCxHxW). The old `.to(device).float()` chain differentiated with
    # respect to a non-leaf copy, and requires_grad fails for integer-typed arrays.
    batched_explain_img_i = torch.tensor(np.expand_dims(explain_img_i, 0).transpose(0, 3, 1, 2),
                                         dtype=torch.float32, device=device, requires_grad=True)
    output_logit_i = res_net(batched_explain_img_i)[0, target_i.argmax()]
    
    img_grad_i = torch.autograd.grad(output_logit_i, batched_explain_img_i)[0].squeeze().cpu().numpy()
    grad_times_input_i = img_grad_i * np.squeeze(explain_img_i)
    grad_heat_maps.append(img_grad_i)  # the list was created but never filled before
    
    plt.subplot(len(image_ids), 5, 5*i+1)
    if i == 0:
        plt.title("Image")
    plt.imshow(explain_img_i, cmap="gray")
    
    plt.subplot(len(image_ids), 5, 5*i+2)
    if i == 0:
        plt.title("Input*Gradient explanation")
    heat_max = np.max(abs(grad_times_input_i))
    plt.imshow(grad_times_input_i, cmap="bwr", vmax=heat_max, vmin=-heat_max)
    plt.colorbar(shrink=0.5)
    
    plt.subplot(len(image_ids), 5, 5*i+3)
    if i == 0:
        plt.title("Gradient explanation")
    heat_max = np.max(abs(img_grad_i))
    plt.imshow(img_grad_i, cmap="bwr", vmax=heat_max, vmin=-heat_max)
    plt.colorbar(shrink=0.5)
    
    plt.subplot(len(image_ids), 5, 5*i+4)
    if i == 0:
        plt.title("FD explanation (unfair)")
    heat_max = np.max(abs(unfair_heat_maps[i]))
    plt.imshow(unfair_heat_maps[i], cmap="bwr", vmax=heat_max, vmin=-heat_max)
    plt.colorbar(shrink=0.5)
    
    plt.subplot(len(image_ids), 5, 5*i+5)
    if i == 0:
        plt.title("Location in color space")
    plt.plot(color_probe, color_class) 
    plt.vlines([color_i], 0, valid_set.num_classes-1, linewidth=0.8,
           colors="r", label="color value",
           linestyles="dashed")
plt.show()
# TODO: gradient should be zero, so double check computations, fix scale on cmap

# PCA Direction Tests

In [None]:
class DummyNet(nn.Module):
    """Minimal 2 -> 2 linear layer followed by a ReLU, used to sanity-check GuidedBackprop.

    forward() prints the pre-activation values, so the effect of the ReLU can be followed
    by eye.
    """

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 2)
        self.relu = nn.ReLU()

    def forward(self, x):
        pre_activation = self.linear(x)
        print("\tBefore ReLUs", pre_activation)
        return self.relu(pre_activation)
# Toy check of GuidedBackprop. Identity weights with a +200 bias keep both ReLUs active
# for the input [1, -1].
dummy_net = DummyNet()
with torch.no_grad():  # set values in place instead of swapping .data for a new Parameter
    dummy_net.linear.weight.copy_(torch.tensor([[1., 0], [0, 1]]))
    dummy_net.linear.bias.copy_(torch.tensor([200., 200]))
print("Network parameters", dummy_net.linear._parameters)


def report_toy_loss(output, inputs):
    """Print `output`, the toy loss -output[0] + output[1], and its gradient w.r.t. `inputs`."""
    toy_loss = -output[0] + output[1]
    print("\tNetwork output", output)
    print("\t'Loss'", toy_loss)
    print("\tResulting gradients", torch.autograd.grad(toy_loss, inputs))


print("WITHOUT GUIDED BACKPROP")
inpt = torch.tensor([1., -1.], requires_grad=True)
result = dummy_net(inpt)
report_toy_loss(result, inpt)
print("WITH GUIDED BACKPROP")
guided_dummy = GuidedBackprop(dummy_net)
result = guided_dummy(inpt, preserve_hooks=False)
report_toy_loss(result, inpt)
print("GUIDED BACKPROP AGAIN (should auto-clean now)")
inpt = torch.tensor([1., -1.], requires_grad=True)
result = dummy_net(inpt)
report_toy_loss(result, inpt)

In [None]:
default_scales = [3,5,7,9,13,15]
if 1: 
    %store -r color_pca_directions_1_stride color_pca_directions_s_stride
else:
    color_pca_directions_1_stride = find_pca_directions(valid_set, 16384, default_scales, 1)
    color_pca_directions_s_stride = find_pca_directions(valid_set, 16384, default_scales, default_scales)
    %store color_pca_directions_1_stride color_pca_directions_s_stride

In [None]:
# Visualize the cached PCA directions for every scale (strides = 1).
visualize_pca_directions(color_pca_directions_1_stride, "Strides=1", default_scales, lines=False)

In [None]:
# Same, for the directions computed with strides = scales.
visualize_pca_directions(color_pca_directions_s_stride, "Strides=scales", default_scales, lines=True)

In [None]:
# PCA-direction saliency maps for one seeded validation image, computed with
# strides == scales and with strides == 1. Both calls take the 1-stride directions;
# the stride is presumably applied inside pca_direction_grids (cf. the "unittest" cell below).
np.random.seed(200_010)
generated_img, label, *__ = valid_set.generate_one()
pca_map_strided = pca_direction_grids(small_net, valid_set, label.argmax(), generated_img, 
                                      default_scales, color_pca_directions_1_stride, strides=default_scales,
                                      device=device, batch_size=128, component=0)
pca_map_1_stride = pca_direction_grids(small_net, valid_set, label.argmax(), generated_img, 
                                      default_scales, color_pca_directions_1_stride, component=0, 
                                      device=device, batch_size=128, strides=1)

In [None]:
plt_grid_figure([generated_img, pca_map_strided])
# recorded output generated on
# small_net = ResNet([[16, 3, 2],  # num_channels (input and output), kernel_size, stride
#                     [32, 3, 2],
#                     [64, 3, 2]], 3, [128, 128, 1], 
#                    "small_net_noise_medium_grey.dict",
# => with strides == scales, unguided

In [None]:
plt_grid_figure([generated_img, pca_map_1_stride])
# recorded output generated on
# small_net = ResNet([[16, 3, 2],  # num_channels (input and output), kernel_size, stride
#                     [32, 3, 2],
#                     [64, 3, 2]], 3, [128, 128, 1], 
#                    "small_net_noise_medium_grey.dict",
# => with strides == 1, unguided

In [None]:
# 20 fixed image seeds (seeds[3] is also used directly further below).
seeds = [12_123, 140_124, 1_508_559, 15_019_258, 12_429_852, 9032, 5832, 12, 5014, 92, 42, 52,
         52_934, 935_152, 1_000_000, 1_000_001, 27, 24, 512, 999_105]
# NOTE(review): the signature previously copied here,
#   generate_many_pca(net, seeds, pca_directions_1_stride, scales, dataset,
#                     component=0, batch_size=128, strides=None, skip_1_stride=False, device=None),
# does not match the calls below, which pass strided_scales= and omit seeds etc.
# Check the current definition in the project code.

In [None]:
# "unittest" for pca_direction_maps
fake_pca_directions_s_strides = []
for i, scale in enumerate(default_scales):
    fake_pca_directions_s_strides.append(color_pca_directions_1_stride[i][::scale, ::scale])
np.random.seed(12123)
test_img, label, *__ = valid_set.generate_one()
new_pca_map = pca_direction_grids(unstrided_net, valid_set, label.argmax(), test_img, 
                    default_scales, color_pca_directions_1_stride, strides=default_scales,
                    device=device, batch_size=128, component=0)
old_pca_map = old_old_pca_direction_grids(unstrided_net, valid_set, label.argmax(), test_img, 
                    pca_direction_grids=fake_pca_directions_s_strides, 
                    scales=default_scales, device=device)
print(abs(new_pca_map - old_pca_map).max())  # should be very low
print(abs(new_pca_map - old_pca_map).mean())

In [None]:
# PCA maps (strides=3 and strides=1) and plain gradient maps on the unguided network.
# NOTE(review): generate_many_pca gets no seeds/directions/scales/dataset here, so those
# presumably default to values set in the project code; confirm.
pca_map_s_strides, pca_map_1_strides, grad_maps, explain_imgs = generate_many_pca(unstrided_net, component=0, strided_scales=3)

In [None]:
# The same maps, computed through the GuidedBackprop wrapper.
guided_net = GuidedBackprop(unstrided_net)
guided_pca_map_s_strides, guided_pca_map_1_strides, guided_grad_maps, explain_imgs = generate_many_pca(guided_net, component=0, strided_scales=3)

In [None]:
plt_grid_figure([explain_imgs, guided_pca_map_s_strides, guided_pca_map_1_strides, guided_grad_maps, pca_map_s_strides, pca_map_1_strides, grad_maps], transpose=True, titles=["Image", "Guided Strides=scale", "Guided strides=1", "Guided Gradient", "Strides=scale", "strides=1", "Gradient"])
# recorded output: on the strides=2 final layer network (unstrided_net),
# with the new, correct PCA implementation
# and a sample size of 16384 for the PCA directions.
# Also, it isn't actually strides=scales but strides=3.

# Open questions:
# - how do we use this to make a better classifier (the binary code behaviour)?
# - dead neurons caused by no gradient (not unusual)
# - cutting them out doesn't work that well

In [None]:
# what if we disable guided backprop for the last layer?
# (exceptions= presumably names the activation modules that are left un-guided)
np.random.seed(seeds[3])
generated_img, label, *__ = valid_set.generate_one()
# NOTE(review): tensored_img is not used by any visible cell.
tensored_img = tensorize(generated_img, device=device, requires_grad=True)
guided_net = GuidedBackprop(unstrided_net, exceptions=["fully_connected.0.act_func"])
pca_map_strided = pca_direction_grids(guided_net, valid_set, label.argmax(), generated_img, 
                                      default_scales, color_pca_directions_1_stride, strides=3,
                                      device=device, batch_size=128, component=0)

In [None]:
plt_grid_figure([generated_img, pca_map_strided])
# initially seems promising

In [None]:
# now do the same for all seeds
guided_net = GuidedBackprop(unstrided_net, exceptions=["fully_connected.0.act_func"])
guided_pca_map_s_strides, _, guided_grad_maps, explain_imgs = generate_many_pca(guided_net, component=0, strided_scales=3, skip_1_stride=True)

In [None]:
plt_grid_figure([explain_imgs, guided_pca_map_s_strides, guided_grad_maps], transpose=True, titles=["Image", "EGuided Strides=3", "EGuided Gradient"])
# "EGuided" presumably = guided backprop with an exception for the last activation layer.
# Notes:
# - related to when you start getting non-monotonic functions
# - GB only cares about positive paths = only monotonic (increasing) functions
# - does it make sense to put negative weights at the end?
# - sample how many weights in each layer are +ve/-ve?

In [None]:
# NOTE(review): the caption below describes an older run. guided_pca_map_s_strides and
# guided_grad_maps were since overwritten by the "exceptions" cell above, so re-running
# this cell plots those instead and the figure will not match its caption.
plt_grid_figure([explain_imgs, guided_pca_map_s_strides, guided_pca_map_1_strides, guided_grad_maps, pca_map_s_strides, pca_map_1_strides, grad_maps], transpose=True, titles=["Image", "Guided Strides=scale", "Guided strides=1", "Guided Gradient", "Strides=scale", "strides=1", "Gradient"])
# recorded output: on the strides=2 final layer network (unstrided_net),
# with the old, slightly incorrect PCA implementation
# and a sample size of only 2048 for the PCA directions

# Investigation into Guided Backprop Failure Cases

In [None]:
def draw_classes(min_y, max_y, ax=None, alpha=0.25, dataset=None):
    """Overlay the class-vs-colour step function on an existing plot.

    Class labels 0..num_classes-1 are mapped linearly onto [min_y, max_y], so the step
    function can be drawn on top of e.g. a colour histogram.

    Args:
        min_y, max_y: y-range the class labels are mapped onto.
        ax: matplotlib axes to draw on (defaults to the current axes).
        alpha: line transparency.
        dataset: object providing ``color_classifier`` and ``num_classes``; defaults to
            the notebook-global ``valid_set``.
    """
    if ax is None:
        ax = plt.gca()
    if dataset is None:
        dataset = valid_set
    colors = np.arange(256)  # every grey level 0..255 (np.arange(255) stopped at 254)
    classified = np.vectorize(dataset.color_classifier)(colors)
    classified = classified / (dataset.num_classes - 1) * (max_y - min_y) + min_y
    ax.plot(colors, classified, c="k", alpha=alpha)

In [None]:
@torch.no_grad()
def top_k_activating(net, loader, activation_getter, k=50, device=None):
    """Find the k samples in `loader` with the largest scalar activation score.

    Args:
        net: model whose forward pass stores intermediate activations in
            ``net._features`` (assumes AllActivations hooks are already attached).
        loader: iterable of dicts with "image", "seeds" and "color" entries (one value
            per image for the last two).
        activation_getter: maps ``net._features`` to one score per image in the batch.
        k: number of top samples to keep (must be >= 1).
        device: device the images are moved to.

    Returns:
        (best_seeds, best_activations): length-k float arrays sorted by decreasing
        activation. If the loader yields fewer than k samples, the unfilled tail holds
        seed 0 and activation -inf.

    Side effects: puts `net` in eval mode and draws a histogram of the colours of the
    top-k samples, with the class step function overlaid, on the current axes.

    Raises:
        ValueError: if k < 1.
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")
    net.eval()
    best_seeds = np.zeros(k)
    best_colors = np.zeros(k)
    best_activations = np.full(k, -np.inf)
    for sample in tqdm(loader):
        net(sample["image"].to(device).float())
        batch_activations = np.asarray(activation_getter(net._features), dtype=float).reshape(-1)
        batch_seeds = np.asarray(sample["seeds"], dtype=float).reshape(-1)
        batch_colors = np.asarray(sample["color"], dtype=float).reshape(-1)
        # Merge the running top-k with this batch and keep the k largest: one vectorised
        # argpartition per batch instead of a Python min/argmin scan per sample.
        pooled_activations = np.concatenate([best_activations, batch_activations])
        pooled_seeds = np.concatenate([best_seeds, batch_seeds])
        pooled_colors = np.concatenate([best_colors, batch_colors])
        keep = np.argpartition(pooled_activations, -k)[-k:]
        best_activations = pooled_activations[keep]
        best_seeds = pooled_seeds[keep]
        best_colors = pooled_colors[keep]
    order = best_activations.argsort()[::-1]
    # Only real samples go into the histogram; unfilled slots would show up as colour 0.
    filled = np.isfinite(best_activations)
    data, *_ = plt.hist(best_colors[filled], bins=255)
    draw_classes(0, max(data))
    return best_seeds[order], best_activations[order]
    
interp_net = AllActivations(unstrided_net)


def logit_1(features):
    """Score used to rank images: sum over all units of "fully_connected.0.act_func".

    NOTE(review): despite the name this sums every hidden unit, not unit 1 -- confirm intent.
    """
    fc_activations = features["fully_connected.0.act_func"]
    return fc_activations.sum(axis=1).detach().cpu().numpy()


best_seeds, _ = top_k_activating(interp_net, valid_loader, logit_1, device=device)

In [None]:
# top-k seeds and their activations (`_` was bound by the unpacking in the cell above)
(best_seeds, _)

In [None]:
# Print the color of each top-k activating image found above.
# (`result` was never defined here; the seeds returned by top_k_activating are `best_seeds`.)
for seed in best_seeds:
    np.random.seed(int(seed))
    explain_img, label, color, *_ = valid_set.generate_one()
    print(color)

In [None]:
# random_net => did NOT learn the same 3 logit structure
uniq = set().union(*(set(per_color) for per_color in color_distrib.values()))
print(len(uniq))
uniq

In [None]:
# random_net2: learned a more compact structure??
uniq = set().union(*(set(per_color) for per_color in color_distrib.values()))
print(len(uniq))
uniq

In [None]:
# random_net3: similarly fails to replicate unstrided_net
uniq = set().union(*(set(per_color) for per_color in color_distrib.values()))
print(len(uniq))
uniq

In [None]:
uniq = set().union(*(set(per_color) for per_color in color_distrib.values()))
# Name each pattern by the characters at index 1 + 2*unit for hidden units 1, 4 and 17
# (assumes patterns are strings with one unit value every two characters -- TODO confirm)
pattern_to_names = {
    pattern: pattern[1 + 2 * 1] + pattern[1 + 2 * 4] + pattern[1 + 2 * 17]
    for pattern in uniq
}
# these were the only non-zero logits across ALL cases!
# seems to be some sort of binary code
pattern_to_names

In [None]:
pattern_totals = defaultdict(int)
every_count = (item for per_color in color_distrib.values() for item in per_color.items())
for pattern, count in every_count:
    pattern_totals[pattern] += count
# counts of the different "binary codes"
# top 4 patterns are the 3 "pure codes"
# 1. pure
# 2. all 3
# 3. pure
# 4. pure
pattern_totals

In [None]:
# Number of images showing each activation pattern at every color, overlaid with the
# class step function (class id * 400, a purely visual scale) for reference.
for pattern in uniq:
    amounts = np.zeros((255,))
    for color, distrib in color_distrib.items():
        # a color where this pattern never occurred contributes 0 (plain dicts would KeyError)
        amounts[color] = distrib.get(pattern, 0)
    plt.plot(amounts, label=pattern_to_names[pattern])
# integer probe so the class boundaries line up with the integer-indexed counts above
# (linspace(0, 255, 255) has a step of 255/254 and shifts the boundaries)
color_probe = np.arange(255)
color_class = [color_classifier(x) * 400 for x in color_probe]
plt.plot(color_probe, color_class, label="classes", linestyle="dotted")
plt.xlabel("color")
plt.ylabel("count")
plt.legend()
# transition regimes
# pure regimes
# underlying "noise"
# next plot to do is just the raw activations of the 3 classes
# predicts which images will be good and which will not

# [0, 100] => Varying hints of a circle, square-ish/sparse behaviour (pure 100, class 0)
# [100, 150] => Square-ish/sparse behaviour (all 3, class 1)
# [150, 200] => all zeros  (pure 001, class 2)
# [200, 250] => excellent quality  (pure 010, class 1)

# do activation thing

In [None]:
profile_plots, _ = activation_color_profile(
    AllActivations(unstrided_net),
    valid_loader,
    valid_set,
    device=device,
)

In [None]:
fc_unit_keys = [f"fully_connected.0.act_func_{unit}" for unit in (1, 4, 17)]
show_profile_plots(profile_plots, fc_unit_keys, fixed_height=True)
# color profiles of hidden units 1, 4, 17, for unstrided_net
# see that in the '111' regime, its "equally units 17 and 4, small unit 1"
# class 2 is very strongly "no 1 or 4, only 17"
# class 0 is almost entirely unit 1 (some small amount of unit 17, but basically no unit 4)
# class 1 has a different expression in the >200 color regime (as compared to its expression
#    in the '111', aka 100 < color < 150 regime), in being 
#    "mostly unit 4, some unit 17, little unit 1" (4 increase, 17 decrease)

In [None]:
show_conv_weights(
    unstrided_net,
    "conv_blocks.0.act_func1",
    color_profile=profile_plots,
    size_mul=(6, 12),
    fixed_height=True,
)

In [None]:
show_conv_weights(
    unstrided_net,
    "conv_blocks.0.act_func2",
    color_profile=profile_plots,
    size_mul=(3, 6),
)

In [None]:
# Response of conv_blocks[0] to a spatially uniform input (value 200 -- presumably on the
# same raw color scale as the dataset, TODO confirm), showing output channel 12.
# eval() + no_grad(): in train mode the block's BatchNorm would normalise with this constant
# batch's own statistics and overwrite its running mean/var as a side effect.
unstrided_net.eval()
with torch.no_grad():
    uniform_inpt = torch.full((1, 1, 32, 32), 200.0, device=device)
    uniform_out = unstrided_net.conv_blocks[0](uniform_inpt)
plt.imshow(uniform_out[0, 12].cpu().numpy())

In [None]:
# Closed-form eval-mode response of output channel `c` of conv_blocks[0].conv2 -> batch_norm2
# to a spatially uniform input: each input channel's kernel collapses to its sum, so
#   (x * conv_scale + conv_shift - bn_mean) / sqrt(bn_var + eps) * bn_scale + bn_shift
# is affine in the first-layer activation profiles.
# NOTE(review): ignores zero-padding effects at the image border.
unstrided_net.eval()
c = 12  # output channel of conv2 under analysis

block0 = unstrided_net.conv_blocks[0]
bn = block0.batch_norm2
conv_maps = block0.conv2.weight[c, :]  # (in_channels, kh, kw)
imshow_centered_colorbar(conv_maps[7].detach().cpu().numpy(), cmap="bwr")
conv_scale = conv_maps.sum(dim=(-2, -1))
conv_shift = block0.conv2.bias[c]
bn_scale = bn.weight[c]
bn_shift = bn.bias[c]
bn_var = bn.running_var[c]
bn_mean = bn.running_mean[c]
print(conv_shift, bn_scale, bn_shift, bn_var, bn_mean)
# eval-mode BatchNorm divides by sqrt(running_var + eps); eps was previously missing
bn_std = torch.sqrt(bn_var + bn.eps)
slope = (conv_scale / bn_std * bn_scale).detach().cpu().numpy()
bias = ((conv_shift - bn_mean) / bn_std * bn_scale + bn_shift).detach().cpu().numpy()

# first-layer activation profiles, one curve per input channel of conv2
lines = np.asarray([profile_plots[f"conv_blocks.0.act_func1_{x}"][0]
                    for x in range(conv_maps.shape[0])])

uniform_scaling = slope.dot(lines) + bias


In [None]:
# ReLU the analytic response in place, then plot it
np.clip(uniform_scaling, 0, None, out=uniform_scaling)
plt.plot(uniform_scaling)

In [None]:
# Final FC layer weights and biases, with hidden units 1, 4 and 17 marked by vertical lines.
last_layer_weight = unstrided_net._modules["fully_connected"][-1].fully_connected.weight.detach().cpu().numpy()
imshow_centered_colorbar(last_layer_weight, cmap="bwr", title="Last Layer FC weights", colorbar=False)
# imshow centres rows on 0..n-1, so the image spans [-0.5, n-0.5]; ymin=0 only covered
# half of the top row. (Assumes imshow_centered_colorbar uses imshow's default extent.)
n_rows = last_layer_weight.shape[0]
plt.vlines([1, 4, 17], ymin=-0.5, ymax=n_rows - 0.5)

last_layer_bias = unstrided_net._modules["fully_connected"][-1].fully_connected.bias.detach().cpu().numpy()
print(last_layer_bias)
# in column 17, all weights are negative, so guided backprop means we immediately zero everything out
# relevant columns have been highlighted

# column 1 mostly means class 0, very strongly not class 2
# column 4 mostly means class 1, equally strongly not class 0 and 2
# column 17 means not class 0, not class 1, barely class 2 => problematic

# since class 2 is "default class" (largest bias), the negative weights in column 17 are fine
# rightmost is easiest 

In [None]:
# Fixed validation seeds reused by the activation visualisations below (20 in total).
seeds = [
    12_123, 140_124, 1_508_559, 15_019_258, 12_429_852,
    9032, 5832, 12, 5014, 92, 42, 52,
    52_934, 935_152, 1_000_000, 1_000_001, 27, 24, 512, 999_105,
]
for idx, seed in enumerate(seeds):
    np.random.seed(seed)
    # third value returned by generate_one is the image's color
    print(idx, valid_set.generate_one()[2])

In [None]:
# First FC layer weights for every hidden unit: all 32 are shown (units 1, 4 and 17 are the
# useful ones, the rest are for comparison). Each unit's weight vector is reshaped to
# (8, 8, 63, 63) and tiled into a single 8x8 mosaic of 63x63 maps.
n_cols, n_rows = 4, 8
plt.figure(figsize=(6 * n_cols, 6 * n_rows))
first_fc = unstrided_net._modules["fully_connected"][0].fully_connected.weight.detach().cpu().numpy()
relevant_outputs = range(32)
for panel, output_col in enumerate(relevant_outputs, start=1):
    fc_weights = np.concatenate(np.concatenate(first_fc[output_col].reshape(8, 8, 63, 63), 1), 1)
    plt.subplot(n_rows, n_cols, panel)
    imshow_centered_colorbar(fc_weights, cmap="bwr", title=f"FC weights of {output_col}")
# weights of final conv -> first fully connected

# 4-5 shared "empty" maps across useful logits => useless channels??

# antipodal-superposition-like structure in weights of useful logits
# eg. logit 4 and logit 17, map (0,1)
# eg. logit 1 and logit 4, map (1,2), map (-2, -1)

# "useful logits" are of much higher overall norm (~2) and more positive than other logits
# other maps are usually in range of ~0.02, exception is logit 16, which is ~0.2
# could indicate a logit that was used in early stages of training but was eventually dropped?


# epicycles

# how to decide when to use other PCA components?
# large checkerboard separated by a diagonal: on top it's light, on the bottom it's dark
# the component you want is  (TODO: unfinished note)

# narrative:
# investigating classifying by color
# investigating attribution methods in cases when global context is important
# generalization of eigenfaces to more context windows
# guided backprop = edges (conv nets = edges too)
# force dataset to not be edges, what does conv net do?
# it can do, and heres how it does it
# logits not monotonic on color
#    if hidden units also non-monotonic, then what is the purpose of last layer?
#     initial conv layers are monotonic (almost linear transformation)

# classifying with conv nets when global context matters
# object detection = hierarchical, but still template matching
# this task is not template matching at all
# patches of color can't be captured with edges/templates at all
# write up mechanistic_interp.ipynb, combine with this stuff
# in addition, if we can find a good way to use PCA, add it to this paper or to a future one

# in a large network with random weights, some units get lucky and turn out to be useful

# 2 pages, unlimited appendix
# template for ICLR

# claim: learning non-monotonic functions is difficult; superposition
# reason: most of the neurons are initialized so poorly that they can't become useful through
# gradient descent (or at least the other ones, [1, 4, 17], were close enough that they
# came to dominate over them)
# how to check: keep a snapshot of the initial random network, then see what happens if you
# zero out the weights that ended up being good

In [None]:
# Same mosaic as above, but only for the three useful hidden units (1, 4, 17), side by side.
# The figure is sized for a 1xN row; the previous (24, 54) size was copied from the 8x4 grid
# and left the three panels surrounded by empty space.
relevant_outputs = [1, 4, 17]
plt.figure(figsize=(6 * len(relevant_outputs), 6))
first_fc = unstrided_net._modules["fully_connected"][0].fully_connected.weight.detach().cpu().numpy()
for i, output_col in enumerate(relevant_outputs):
    fc_weights = np.concatenate(np.concatenate(first_fc[output_col].reshape(8, 8, 63, 63), 1), 1)
    plt.subplot(1, len(relevant_outputs), i + 1)
    imshow_centered_colorbar(fc_weights, cmap="bwr", title=f"FC weights of {output_col}")
# same figure but only the relevant ones
# maybe the zero-ed out channels are dead neurons (old idea)
# without batchnorm/dropout style things, dead will stay dead (so some get unlucky)
# batchnorm is supposed to resample => sometimes allows positive grads

In [None]:
# evaluate network activations, display tanh(activations) (for better scaling)
def visualize_network_activations(net, seed):
    net.eval()
    np.random.seed(seed)
    working_img, label, color, *_____ = valid_set.generate_one()
    print(color, label.argmax())
    tensored_img = tensorize(working_img, device=device)
    debug_net = AllActivations(net)
    debug_net(tensored_img)

    plt.figure(figsize=(6*4, 12))
    final_relu = debug_net._features["conv_blocks.2.act_func2"].detach().cpu().numpy()[0].reshape(8,8,63,63)
    print("Logits 1,4,17", debug_net._features["fully_connected.0.act_func"][0, (1,4,17)])
    compressed_results = np.tanh(np.concatenate(np.concatenate(final_relu, 1), 1))

    imshow_centered_colorbar(compressed_results, cmap="bwr", title=f"Post ReLU final conv layer activations (color={color})")
    plt.show()

In [None]:
visualize_network_activations(net=unstrided_net, seed=seeds[9]);
# 111 regime, brightest map is (2, -1)

In [None]:
visualize_network_activations(net=unstrided_net, seed=seeds[19]);
# good regime, 010
# bright areas are (4, 2) and (6, 4)  => matches FC weights

In [None]:
visualize_network_activations(net=unstrided_net, seed=seeds[0]);
# in 100 regime, bright maps are (2, 1), (3, 2), (1, 3), and (-2, -1)

In [None]:
visualize_network_activations(net=unstrided_net, seed=seeds[11]);
# in 001 regime, bright maps are  (-2, 4)
# both logit 1 and logit 17 have big negatives there

In [None]:
# `debug_net` inside visualize_network_activations is a local variable, so this cell used to
# raise NameError on a fresh kernel. Rebuild the activation recorder explicitly.
# NOTE(review): assumes the image of interest is seeds[11] (the last one visualized above).
probe_seed = seeds[11]
np.random.seed(probe_seed)
probe_img, *_ = valid_set.generate_one()
unstrided_net.eval()
debug_net = AllActivations(unstrided_net)
with torch.no_grad():
    debug_net(tensorize(probe_img, device=device))

plt.figure(figsize=(6*4, 12))
final_relu = debug_net._features["conv_blocks.2.act_func2"].detach().cpu().numpy()[0].reshape(8,8,63,63)
compressed_results = np.tanh(np.concatenate(np.concatenate(final_relu, 1), 1))

imshow_centered_colorbar(compressed_results, cmap="bwr", title="Post ReLU final conv layer activations")
# with tanh applied to activations (note that ReLU is still the actual activation function)
# only certain maps seem important, this is ... (TODO: unfinished note)

In [None]:
# Rebuild the activation recorder explicitly (`debug_net` is otherwise only a local of
# visualize_network_activations and is undefined on a fresh kernel).
# NOTE(review): assumes the image of interest is seeds[11] (the last one visualized above).
probe_seed = seeds[11]
np.random.seed(probe_seed)
probe_img, *_ = valid_set.generate_one()
unstrided_net.eval()
debug_net = AllActivations(unstrided_net)
with torch.no_grad():
    debug_net(tensorize(probe_img, device=device))

plt.figure(figsize=(6*4, 12))
final_conv_map = debug_net._features["conv_blocks.2.batch_norm2"].detach().cpu().numpy()[0].reshape(8,8,63,63)

compressed_results = np.concatenate(np.concatenate(final_conv_map, 1), 1)
# fixed typo: imshow_cenered_colorbar -> imshow_centered_colorbar (was a NameError)
imshow_centered_colorbar(compressed_results, cmap="bwr", title="Pre ReLU final conv layer activations")
# useless

In [None]:
def adversarial_generate(img, lbl, net, alpha, lr, runs, target_class=2):
    """Gradient-descent search for a perturbation pushing ``net`` toward ``target_class``.

    Args:
        img: HWC numpy image (transposed to CHW here).
        lbl: label array of ``img``; only its shape/dtype is used to build the one-hot target.
        net: classifier, called on a (1, C, H, W) float tensor on ``device``.
        alpha: perturbation bound. The random start is rescaled to L2 norm ``alpha``; after
            every step the perturbation is clamped element-wise to [-alpha, alpha].
        lr: step size of each gradient step.
        runs: number of gradient steps.
        target_class: class the perturbation pushes toward (default 2, the class the
            experiments so far have used).

    Returns:
        Detached tensor ``img + perturbation`` of shape (1, C, H, W).
    """
    net.eval()
    height, width = img.shape[0], img.shape[1]
    adversarial_direction = np.random.uniform(-alpha, alpha, size=(1, height, width))
    adversarial_direction = adversarial_direction / np.linalg.norm(adversarial_direction) * alpha

    target = torch.tensor(np.zeros_like(lbl)).to(device).unsqueeze(0).float()
    target[0, target_class] = 1.
    print(target, lbl)

    tensor_img = torch.tensor(img.transpose(2, 0, 1)).unsqueeze(0).to(device).float()
    # build the perturbation as a float32 leaf so it can be updated in place under no_grad
    # (previously in-place updates on a non-leaf kept extending the autograd graph)
    tensor_adv_dir = torch.tensor(adversarial_direction[None], dtype=torch.float32,
                                  device=device, requires_grad=True)

    loss_func = nn.CrossEntropyLoss()
    report_every = max(runs // 5, 1)  # runs < 5 used to raise ZeroDivisionError
    for i in range(runs):
        curr_img = tensor_img + tensor_adv_dir
        curr_net_out = net(curr_img)
        curr_loss = loss_func(curr_net_out, target)
        grad_dir = torch.autograd.grad(curr_loss, tensor_adv_dir)[0]
        with torch.no_grad():
            tensor_adv_dir -= lr * grad_dir
            tensor_adv_dir.clamp_(min=-alpha, max=alpha)

        if i % report_every == report_every - 1:
            print(curr_loss.item(), curr_net_out)
    return (tensor_img + tensor_adv_dir).detach()
# seed 58 -> color 60 image (see notes below); alpha=6.4, lr=1e2, 500 steps
np.random.seed(58)
generated_img, gen_label, color, *_ = valid_set.generate_one()
print(color)
adv_example = adversarial_generate(
    generated_img, gen_label, unstrided_net,
    alpha=6.4, lr=1e2, runs=500,
)

In [None]:
# highly variable results in terms of alpha (max pixel diff)
# seems to work "best" (lowest alpha needed) when going to class 2??? (shouldn't it be class 1,
# since that one is the most polysemantic, at least in one of its regimes?)

# 230 color -> alpha very close to 5, works on 1e2, 500 (seed 55)
# 136 color -> alpha of 3 (seed 54)
# 50 color -> alpha of 9-10 (seed 53)
# 181 color (in class 2 already) -> adv_dir is basically random noise, weird cyclic structure to it (seed 52)
# 82 color -> alpha of 3.8 (high) (seed 51)
# 201 color -> alpha of 0-1 (high) (seed 50)
# 110 color -> alpha of 2-3 (low) (seed 56)
# 232 color -> alpha close to 5 (seed 57)
# 60 color -> alpha 6-7 (mid) (seed 58)

In [None]:
# Original image | adversarial image | perturbation.
np_adv = adv_example.detach().cpu().numpy().squeeze()
plt.subplot(1,3,1)
plt.imshow(generated_img, cmap="gray")
plt.subplot(1,3,2)
plt.imshow(np_adv, cmap="gray")
plt.subplot(1,3,3)
# adversarial = original + direction, so the direction is adversarial - original
# (original - adversarial, as before, showed the negated direction)
imshow_centered_colorbar(np_adv - generated_img.squeeze(), cmap="bwr", title="Adv direction")

In [None]:
plt_grid_figure(
    [explain_imgs, pca_map_s_strides, pca_map_1_strides],
    transpose=True,
    titles=["Image", "Strides=scale", "strides=1"],
)
# add comparison to regular gradient
# smaller circles = bad?
# manually test if the gradient changes make sense
# model learns weird stuff about the noise
# interpretation of its algorithm is interesting

# make texture dataset and test the methods on it
# texture generation: emerging conv?
# heuristic = lots of code
# dataset
# test on natural images eventually

# show it works when edges are important too (guided backprop first)
# could do saliency checks

# fourier transform could work for texture, if its the whole image (to get window size)
# do fft on quadrants of image to guess at scale, look at max fourier coeff =>
# should give rough idea of window size (top k coeffs?)

In [None]:
plt_grid_figure(
    [explain_imgs, pca_map_s_strides, pca_map_1_strides, grad_maps],
    transpose=True,
    titles=["Image", "Strides=scale", "strides=1", "Gradient"],
)
# on the strides=8 final layer network

In [None]:
plt_grid_figure(
    [explain_imgs, pca_map_s_strides, pca_map_1_strides, grad_maps],
    transpose=True,
    titles=["Image", "Strides=scale", "strides=1", "Gradient"],
)
# on the strides=2 final layer network

In [None]:
# one bare sample image: no ticks, no borders
sample_img = valid_set.generate_one()[0]
plt.imshow(sample_img, cmap="gray")
ax = plt.gca()
ax.set_xticks([])
ax.set_yticks([])
remove_borders(ax)

In [None]:
# Brute-force search for a transpose/concatenate sequence that tiles a (5, 5, 4, 4) stack of
# 4x4 blocks into a 20x20 mosaic whose first row starts 0, 1, 2, 3 and continues with 16
# (i.e. block (0, 1) sits to the right of block (0, 0)).
# Every transpose/concatenate here is valid for these shapes, so no try/except is needed
# (the old bare `except:` would also have swallowed KeyboardInterrupt).
from itertools import permutations

x = np.arange(400).reshape(5, 5, 4, 4)
tiling_matches = []  # (transp_1, axis1, transp_2, axis2, transp_3) tuples that pass the check
for transp_1 in permutations(range(4)):
    for axis1 in range(3):
        for transp_2 in permutations(range(3)):
            for axis2 in range(2):
                for transp_3 in permutations(range(2)):
                    t2 = np.concatenate(x.transpose(*transp_1), axis1)
                    t4 = np.concatenate(t2.transpose(*transp_2), axis2)
                    t5 = t4.transpose(*transp_3)
                    if (t5.shape == (20, 20)
                            and np.array_equal(t5[0, :4], np.arange(4))
                            and t5[0, 4] == 16):
                        tiling_matches.append((transp_1, axis1, transp_2, axis2, transp_3))
np.concatenate(np.concatenate(x, 1), 1)

So the question becomes: how do we search for useful reference images/pixel values in general? We want the reference to be close to the image (a small distance, i.e. a small denominator), but also to lead to large differences in the output logits. This is dangerously close to finding adversarial directions, so we need to make sure we stay on the data manifold. That means establishing some sort of distance metric, and potentially a way of detecting whether we are on the manifold or not, so that we can project back onto it if needed. This would also let us switch to a DeepLIFT-style attribution method.

# Model Optimization Stuff

In [None]:
# Save res_net's state dict together with the optimizer `optim`
# (save location/format is decided by save_model_state_dict, which is not visible here).
res_net.save_model_state_dict(optim=optim)

In [None]:
# Profile 1000 forward passes of small_net on a single validation image.
# generate_one returns more than the image (other cells unpack it with `*_`), so take
# only the first value instead of unpacking into exactly three names.
generated_img, *_ = valid_set.generate_one()
generated_img = torch.tensor(generated_img.transpose(2,0,1)).to(device).unsqueeze(0).float()
for _ in range(1000):
    small_net.forward(generated_img, profile=True)

In [None]:
# Percentage share of each profiled entry in `stats`
total = sum(stats.values())  # --> gave 3x speed! (Fast and Accurate Model scaling?)
for name, value in stats.items():  # --> the 3x speedup caused underfitting though, so switched to 2x
    share = 100. * value / total
    print(name, share)
