In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np
import time
import os
import matplotlib.pyplot as plt
from collections import defaultdict

import color_regions, network, visualizations, utils
from color_regions import *
from network import *
from visualizations import *
from utils import *
from hooks import *

# Pick the first GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
# Let cuDNN benchmark convolution algorithms and reuse the fastest one
# (every split below uses the same fixed image size).
torch.backends.cudnn.benchmark = True

In [None]:
# set up autoreloading of shared code
%load_ext autoreload
# mode 1: before running each cell, reload only the modules registered with %aimport
%autoreload 1
# register the project modules for reloading; `hooks` is only star-imported in the first
# cell, so this also binds the module name `hooks`
%aimport color_regions,network,visualizations,utils,hooks
# with no arguments: list which modules are (and are not) being auto-reloaded
%aimport

In [None]:
import sys
prev_time = 0
gamma = 0.99
stats = {}  # tracks ewma running average
def benchmark(point=None, profile=True, verbose=True, cuda=True): # not thread safe at all
    """Record the wall-clock time elapsed since the previous call, keyed by caller and checkpoint.

    Call once with ``point=None`` to (re)start the timer, then with a checkpoint label to
    measure the interval since the previous call. Each interval is folded into the
    module-level ``stats`` dict as an exponentially weighted moving average (weight
    ``gamma`` on the previous value), keyed by ``"<calling function name>-<point>"``.

    Args:
        point: checkpoint label; ``None`` only resets the timer.
        profile: if False, return immediately without touching the clock or ``stats``.
        verbose: print the raw interval and the running average.
        cuda: synchronize CUDA before reading the clock so queued GPU work is counted.
            Skipped when CUDA is unavailable, so profiling also works on the CPU device.

    Note: not thread-safe (mutates the module-level ``prev_time`` and ``stats``). On the
    very first call ``prev_time`` is still 0, so an interval recorded then is meaningless.
    """
    global prev_time
    if not profile:
        return
    # CUDA kernels run asynchronously; without a sync only the launch time would be measured.
    # torch.cuda.synchronize() raises when CUDA is unavailable, hence the guard.
    if cuda and torch.cuda.is_available():
        torch.cuda.synchronize()
    time_now = time.perf_counter()
    if point is not None:
        # Prefix with the caller's function name so equal labels in different functions don't collide.
        point = f"{sys._getframe().f_back.f_code.co_name}-{point}"
        time_taken = time_now - prev_time
        if point not in stats:
            stats[point] = time_taken  # seed the average with the first sample
        stats[point] = stats[point]*gamma + time_taken*(1-gamma)
        if verbose:
            print(f"took {time_taken} to reach {point}, ewma={stats[point]}")
    prev_time = time_now

In [None]:
# Build the base texture generator from ./data/dtd once; set_loader_helper wraps this object
# for every split, so the source data is only loaded a single time.
main_data_set = TextureDatasetGenerator("./data/dtd")  # -> do this so we only load once

In [None]:
# Per-sample transform: convert to a tensor, then rotate by a random angle in [-90, 90] degrees.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomRotation(90),
])

batch_size = 128  # seems to be the fastest batch size
# Presumably (start, end) ranges handed to the generator as `image_indices`; the three are disjoint.
train_indices = (0, 500_000)  # size of training set
valid_indices = (1_250_000, 1_260_000)
test_indices = (2_260_000, 2_270_000)


def set_loader_helper(indices):
    """Build one data split and its DataLoader.

    Wraps the shared ``main_data_set`` in a TextureDatasetGenerator restricted to
    ``indices`` (using the shared ``transform``) and returns ``(dataset, loader)``.
    The loader always shuffles and uses 4 worker processes with pinned memory.

    NOTE(review): the validation and test splits built here also shuffle and also receive
    the RandomRotation augmentation — confirm that is intended for evaluation.
    """
    split_set = TextureDatasetGenerator(
        main_data_set,
        transform=transform,
        noise_size=(5, 15),
        size=128,
        radius_frac=(1 / 3, 1 / 2.1),
        image_indices=indices,
    )
    split_loader = torch.utils.data.DataLoader(
        split_set,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
    )
    return split_set, split_loader


train_set, train_loader = set_loader_helper(train_indices)
valid_set, valid_loader = set_loader_helper(valid_indices)
test_set, test_loader = set_loader_helper(test_indices)

In [None]:
# Texture classifier. Each row of the layer spec is [num_channels, kernel_size, stride] for one
# conv layer; the remaining positional arguments are presumably (number of classes,
# input shape (128, 128, 3) matching size=128 above, checkpoint path) — TODO confirm in network.py.
res_net = ResNet([[32, 7, 2],  # num_channels (input and output), kernel_size, stride
                  [64, 3, 1],
                  [64, 3, 1],
                  [128, 3, 2],
                  [128, 3, 1],
                  [128, 3, 1],
                  [256, 3, 2],
                  [256, 3, 1],
                  [512, 3, 2],
                  [512, 3, 1]], valid_set.num_classes, [128, 128, 3], 
                   "texture_net.dict", fc_layers=[160]).to(device)

loss_func = nn.CrossEntropyLoss()
# Adam with PyTorch's default hyperparameters.
optim = torch.optim.Adam(res_net.parameters())
print(res_net.num_params())
# Restore weights and optimizer state from the checkpoint (presumably "texture_net.dict").
# NOTE(review): what happens when the file does not exist is defined in network.py — not visible here.
res_net.load_model_state_dict(optim=optim)

In [None]:
# Train the network; 200 is presumably the number of epochs.
# NOTE(review): `train` is not defined in this notebook (presumably it comes from a star import)
# and is not given any loader — confirm which data it actually trains and validates on.
results = train(res_net, optim, loss_func, 200)

In [None]:
# Draw one reproducible validation example (assuming generate_one uses NumPy's global RNG)
# and compute its finite-difference attribution map for the target class (argmax of the target).
np.random.seed(500_001)
explain_img, explain_target_logit, *__ = valid_set.generate_one()
heat_map = finite_differences_map(
    res_net,
    valid_set,
    explain_target_logit.argmax(),
    explain_img,
    device=device,
    batch_size=127,
)

In [None]:
# Left panel: the explained image, divided by 255 for imshow (pixel values presumably 0-255).
plt.subplot(121)
plt.imshow(explain_img / 255.)
# Right panel: the FD map summed over axis 2 (presumably colour channels), on a diverging colormap.
plt.subplot(122)
imshow_centered_colorbar(heat_map.sum(axis=2), cmap="bwr", title="FD Map")
# Name of the target texture class.
print(valid_set.idx_to_texname[explain_target_logit.argmax()])

In [None]:
default_scales = [3,5,7,9,13,15]
if 1: 
    %store -r pca_directions_1_stride pca_directions_s_stride
else:
    pca_directions_1_stride = find_pca_directions(valid_set, 16384, default_scales, 1)
    # s_stride not used for pca_map calculations, just for visualizing better what the 
    # PCA directions end up looking like (technically they are slightly different from
    # just accessing 1_stride in a strided manner, since they were computed with different
    # samples (though they are very close due to large sample size)
    pca_directions_s_stride = find_pca_directions(valid_set, 16384, default_scales, default_scales)
    %store pca_directions_1_stride pca_directions_s_stride

In [None]:
# PCA directions estimated with stride 1 at every scale.
visualize_pca_directions(pca_directions_1_stride, "Strides=1", default_scales, lines=False)

In [None]:
# PCA directions estimated with stride equal to each scale.
visualize_pca_directions(pca_directions_s_stride, "Strides=scales", default_scales, lines=True)

In [None]:
# Same seed as the finite-differences cell, so this should explain the same example
# (assuming generate_one draws from NumPy's global RNG).
np.random.seed(500_001)
explain_img, explain_target_logit, *__ = valid_set.generate_one()
# NOTE(review): this passes pca_directions_s_stride, while the cell that computes the directions
# says s_stride is only meant for visualization — confirm which set pca_direction_grids should use.
result = pca_direction_grids(
    res_net,
    valid_set,
    explain_target_logit.argmax(),
    explain_img,
    default_scales,
    pca_directions_s_stride,
    device=device,
    batch_size=128,
)

In [None]:
# Visualize the PCA-direction attribution from the cell above: the explained image, then
# result[..., c] for each of the three channels.
plt.subplot(1,4,1)
# Divide by 255 as the finite-differences plot does: imshow clips float RGB data to [0, 1], so
# an unscaled float image in 0-255 would render almost entirely white (uint8 input is unaffected).
plt.imshow(explain_img.squeeze()/255.)
for c in range(3):
    plt.subplot(1,4,c+2)
    imshow_centered_colorbar(result[...,c], cmap="bwr", title=f"Channel {c}")

In [None]:
# Presumably the seeds used to draw the example images for the map comparisons below.
# Digit groups are written in thousands; the integer values are unchanged.
seeds = [12_123, 140_124, 1_508_559, 15_019_258, 12_429_852, 9_032, 5_832, 12, 5_014, 92, 42, 52,
         52_934, 935_152, 1_000_000, 1_000_001, 27, 24, 512, 999_105]  # 20
# NOTE(review): the generate_many_pca calls below pass only (net, component=..., strided_scales=...),
# none of which match this recorded signature exactly — it looks stale; confirm against the definition.
# def generate_many_pca(net, seeds, pca_directions_1_stride, scales, dataset,
#         component=0, batch_size=128, strides=None, skip_1_stride=False, device=None):

In [None]:
# Component-0 PCA maps (strided with strided_scales=3, and stride 1), gradient maps and the
# example images, computed on the plain network.
pca_map_s_strides, pca_map_1_strides, grad_maps, explain_imgs = generate_many_pca(res_net, component=0, strided_scales=3)

In [None]:
# Same outputs computed through the GuidedBackprop wrapper.
# NOTE(review): if GuidedBackprop installs hooks on res_net itself, they stay active for later
# cells that pass res_net directly — confirm they are removed.
guided_net = GuidedBackprop(res_net)
guided_pca_map_s_strides, guided_pca_map_1_strides, guided_grad_maps, explain_imgs = generate_many_pca(guided_net, component=0, strided_scales=3)

In [None]:
plt_grid_figure([explain_imgs, guided_pca_map_s_strides, guided_pca_map_1_strides, guided_grad_maps, pca_map_s_strides, pca_map_1_strides, grad_maps], 
                transpose=True, 
                titles=["Image", "Guided Strides=3", "Guided strides=1", "Guided Gradient", "Strides=3", "strides=1", "Gradient"], 
                channel_mode="split")

In [None]:
plt_grid_figure([explain_imgs, guided_pca_map_s_strides, guided_pca_map_1_strides, guided_grad_maps, pca_map_s_strides, pca_map_1_strides, grad_maps], 
                transpose=True, 
                titles=["Image", "Guided Strides=3", "Guided strides=1", "Guided Gradient", "Strides=3", "strides=1", "Gradient"], 
                channel_mode="collapse")

In [None]:
# Largest absolute difference between each strided (strides=3) and stride-1 guided PCA map.
# zip pairs the entries directly; the assert makes a length mismatch fail loudly instead of
# being silently truncated or surfacing as a bare IndexError.
assert len(guided_pca_map_s_strides) == len(guided_pca_map_1_strides)
[abs(s_map - one_map).max() for s_map, one_map in zip(guided_pca_map_s_strides, guided_pca_map_1_strides)]
# no major changes really? Can eliminate channels, can do strides=3 for 9x speedup
# little cost to quality (and speeds up pca calculations)
# for component = 0

# TODO: add these results to Overleaf (the project under the UofT email account)

In [None]:
# Same component-0 grid with channel_mode=0 (presumably showing only channel 0).
plt_grid_figure(
    [explain_imgs,
     guided_pca_map_s_strides, guided_pca_map_1_strides, guided_grad_maps,
     pca_map_s_strides, pca_map_1_strides, grad_maps],
    transpose=True,
    titles=["Image",
            "Guided Strides=3", "Guided strides=1", "Guided Gradient",
            "Strides=3", "strides=1", "Gradient"],
    channel_mode=0,
)

In [None]:
# now let's do the same but for component=1
# NOTE(review): res_net was already wrapped by GuidedBackprop in the component-0 cells; if that
# wrapper modified res_net in place (e.g. hooks), these "plain" maps may not be plain — confirm.
pca_map_s_strides, pca_map_1_strides, grad_maps, explain_imgs = generate_many_pca(res_net, component=1, strided_scales=3)

In [None]:
# now let's do the same but for component=1
# NOTE(review): this wraps res_net a second time; if GuidedBackprop registers hooks on res_net,
# they may now be installed twice — confirm.
guided_net = GuidedBackprop(res_net)
guided_pca_map_s_strides, guided_pca_map_1_strides, guided_grad_maps, explain_imgs = generate_many_pca(guided_net, component=1, strided_scales=3)

In [None]:
plt_grid_figure([explain_imgs, guided_pca_map_s_strides, guided_pca_map_1_strides, guided_grad_maps, pca_map_s_strides, pca_map_1_strides, grad_maps], 
                transpose=True, 
                titles=["Image", "Guided Strides=3", "Guided strides=1", "Guided Gradient", "Strides=3", "strides=1", "Gradient"], 
                channel_mode="split")

In [None]:
plt_grid_figure([explain_imgs, guided_pca_map_s_strides, guided_pca_map_1_strides, guided_grad_maps, pca_map_s_strides, pca_map_1_strides, grad_maps], 
                transpose=True, 
                titles=["Image", "Guided Strides=3", "Guided strides=1", "Guided Gradient", "Strides=3", "strides=1", "Gradient"], 
                channel_mode="collapse")

In [None]:
plt_grid_figure([explain_imgs, guided_pca_map_s_strides, guided_pca_map_1_strides, guided_grad_maps, pca_map_s_strides, pca_map_1_strides, grad_maps], 
                transpose=True, 
                titles=["Image", "Guided Strides=3", "Guided strides=1", "Guided Gradient", "Strides=3", "strides=1", "Gradient"], 
                channel_mode=0)

In [None]:
# Largest absolute difference between each strided (strides=3) and stride-1 guided PCA map.
# zip pairs the entries directly; the assert makes a length mismatch fail loudly instead of
# being silently truncated or surfacing as a bare IndexError.
assert len(guided_pca_map_s_strides) == len(guided_pca_map_1_strides)
[abs(s_map - one_map).max() for s_map, one_map in zip(guided_pca_map_s_strides, guided_pca_map_1_strides)]
# no major changes really? Can eliminate channels, can do strides=3 for 9x speedup
# little cost to quality (and speeds up pca calculations)
# for component = 1 (largest diff is in same location, idx 6)

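# PCA direction convergence experiments

How stable are the estimated PCA directions as the sample size (512 up to 32768), the random seed, and the principal component change?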

In [None]:
# Small-sample (512) estimates, for comparison with the large-sample directions restored earlier.
# default_scales is re-declared with the same value, presumably so this section can run on its own.
# NOTE(review): no np.random.seed here, so these small-sample estimates are not reproducible.
default_scales = [3,5,7,9,13,15]
small_pca_directions_1_stride = find_pca_directions(valid_set, 512, default_scales, 1)
small_pca_directions_s_stride = find_pca_directions(valid_set, 512, default_scales, default_scales)

In [None]:
# Seeded, very large (8192*4 = 32768 samples) estimate for component 0, stride equal to each scale.
np.random.seed(510)
test_directions = find_pca_directions(valid_set, 8192*4, default_scales, default_scales, component=0)

In [None]:
# Directions from the 512-sample stride-1 estimate computed in the convergence setup cell.
visualize_pca_directions(small_pca_directions_1_stride, "Strides=1", default_scales)
# component 0
# small sample (512)
# to get not all 1s: generate images with PCA, see if recoverable
# should be fourier basis (test on natural images?)

# do sanity checks next

In [None]:
# Stride-1 directions restored (or recomputed) by the earlier %store cell.
# NOTE(review): that cell's recompute branch uses 16384 samples, not 2048 — the note below may
# describe an older stored value; reconcile.
visualize_pca_directions(pca_directions_1_stride, "Strides=1", default_scales)
# component 0
# large sample (2048)

In [None]:
# 512-sample estimate with stride equal to each scale.
visualize_pca_directions(small_pca_directions_s_stride, "Strides=scales", default_scales)
# component 0
# small sample (512)

In [None]:
# NOTE(review): restored by the %store cell, whose recompute branch uses 16384 samples, not the
# 2048 noted below; reconcile.
visualize_pca_directions(pca_directions_s_stride, "Strides=scales", default_scales)
# component 0
# large sample (2048)

In [None]:
# Relies on `test_directions` from the seed-510, component-0, 8192*4-sample cell above.
visualize_pca_directions(test_directions, "Strides=scales", default_scales) 
# component 0
# seed 510 gargantuan (32768) sample

In [None]:
# Recompute the directions these notes describe instead of plotting whatever `test_directions`
# was last left in the kernel (the seeded setup cell computes component 0).
np.random.seed(508)
test_directions = find_pca_directions(valid_set, 512, default_scales, default_scales, component=1)
visualize_pca_directions(test_directions, "Strides=scales", default_scales)
# component 1
# seed 508, small sample (512)

# unlikely to be different from guided backprop since it's already basically an edge detector (comp 1)
# advantage of the PCA method is that it can take into account more than just the pixel:
# the unit of attribution isn't just a pixel

In [None]:
# Recompute the directions these notes describe instead of plotting whatever `test_directions`
# was last left in the kernel (the seeded setup cell computes component 0).
np.random.seed(507)
test_directions = find_pca_directions(valid_set, 512, default_scales, default_scales, component=1)
visualize_pca_directions(test_directions, "Strides=scales", default_scales)
# component 1
# seed 507, small sample (512)

In [None]:
# Recompute the directions these notes describe instead of plotting whatever `test_directions`
# was last left in the kernel (the seeded setup cell computes component 0).
np.random.seed(507)
test_directions = find_pca_directions(valid_set, 2048, default_scales, default_scales, component=1)
visualize_pca_directions(test_directions, "Strides=scales", default_scales)
# component 1
# seed 507, large (2048) sample

In [None]:
# Recompute the directions these notes describe instead of plotting whatever `test_directions`
# was last left in the kernel (the seeded setup cell computes component 0).
np.random.seed(508)
test_directions = find_pca_directions(valid_set, 2048, default_scales, default_scales, component=1)
visualize_pca_directions(test_directions, "Strides=scales", default_scales)
# component 1
# seed 508, large (2048) sample

In [None]:
# Recompute the directions these notes describe instead of plotting whatever `test_directions`
# was last left in the kernel (the seeded setup cell computes component 0).
np.random.seed(507)
test_directions = find_pca_directions(valid_set, 8192, default_scales, default_scales, component=1)
visualize_pca_directions(test_directions, "Strides=scales", default_scales)
# component 1
# seed 507, huge (8192) sample

In [None]:
# Recompute the directions these notes describe instead of plotting whatever `test_directions`
# was last left in the kernel (the seeded setup cell computes component 0).
np.random.seed(507)
test_directions = find_pca_directions(valid_set, 8192*4, default_scales, default_scales, component=1)
visualize_pca_directions(test_directions, "Strides=scales", default_scales)
# component 1
# seed 507 gargantuan (32768) sample

In [None]:
# Recompute the directions these notes describe instead of plotting whatever `test_directions`
# was last left in the kernel (the seeded setup cell computes component 0 with this seed and size).
np.random.seed(510)
test_directions = find_pca_directions(valid_set, 8192*4, default_scales, default_scales, component=1)
visualize_pca_directions(test_directions, "Strides=scales", default_scales)
# component 1
# seed 510 gargantuan (32768) sample

What if we rerun the same experiment but cheat, using a prior over pixel values that we know *should* be informative to the output logit — namely, the values closest to the decision boundary?

# Model optimization and profiling

In [None]:
# Save the current weights and optimizer state (presumably to the "texture_net.dict" path given to ResNet).
res_net.save_model_state_dict(optim=optim)

In [None]:
# `generated_img` is not defined in any earlier cell, so this loop failed on a fresh kernel.
# Profile on a real batch from the validation loader instead.
# NOTE(review): assumes the loader yields (images, ...) batches and that forward() takes the
# image batch as its first argument, as in training — confirm.
generated_img = next(iter(valid_loader))[0].to(device)
for _ in range(1000):
    # profile=True presumably makes forward() record checkpoint timings summarized in the next cell.
    res_net.forward(generated_img, profile=True)

In [None]:
# Percentage breakdown: each checkpoint's EWMA time in `stats` divided by the sum over all checkpoints.
# Notes from earlier runs: "gave 3x speed! (Fast and Accurate Model scaling?)" —
# "the 3x speedup caused underfitting though, so switched to 2x".
total = sum(stats.values())
for k, v in stats.items():
    print(k, 100. * v / total)
