In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np
import time
import os
import matplotlib.pyplot as plt
from collections import defaultdict

import color_regions, network, visualizations, utils
from color_regions import *
from network import *
from visualizations import *
from utils import *

# Let cuDNN benchmark convolution algorithms and keep the fastest one
# (pays off when input sizes stay fixed, as they do here).
torch.backends.cudnn.benchmark = True
# Prefer the first GPU; fall back to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

In [None]:
# set up autoreloading of shared code
%load_ext autoreload
# Mode 1: before running code, reload only the modules registered with %aimport.
%autoreload 1
%aimport color_regions,network,visualizations,utils
# With no arguments, %aimport lists which modules are marked for / excluded from reloading.
# NOTE(review): the first cell also does `from <module> import *`. Autoreload upgrades
# existing functions/classes in place, but names newly added to a module only appear here
# after that star import is re-run.
%aimport

In [None]:
import sys
prev_time = None  # perf_counter() value from the previous call; None until a reference time exists
gamma = 0.99  # EWMA decay: weight kept by the old average on each new sample
stats = {}  # tracks ewma running average: "<caller name>-<point>" -> seconds
def benchmark(point=None, profile=True, verbose=True, cuda=True): # not thread safe at all
    """Record the time elapsed since the previous call under a named checkpoint.

    Each call measures the interval since the previous call. When ``point`` is given,
    that interval is folded into an exponentially weighted moving average stored in
    the module-level ``stats`` dict under the key ``"<calling function name>-<point>"``.
    The very first call only sets the reference time, because there is no earlier call
    to measure from.

    Args:
        point: checkpoint label. ``None`` only resets the reference time.
        profile: when False the call is a no-op (the reference time is not touched).
        verbose: print the raw interval and the updated average.
        cuda: call ``torch.cuda.synchronize()`` first so queued GPU work is included
            in the interval. Skipped automatically when CUDA is not available, so the
            default also works on CPU-only machines.

    Not thread safe: all state lives in module-level globals.
    """
    global prev_time
    if not profile:
        return
    if cuda and torch.cuda.is_available():
        # CUDA kernel launches are asynchronous; without a sync we would time the launch, not the work.
        torch.cuda.synchronize()
    time_now = time.perf_counter()
    # Only record when there is a real previous timestamp. Measuring from 0 would store
    # "seconds since the perf_counter epoch" and poison the EWMA for hundreds of calls.
    if point is not None and prev_time is not None:
        point = f"{sys._getframe(1).f_code.co_name}-{point}"
        time_taken = time_now - prev_time
        # Seed a new key's average with its first sample so it does not start biased toward 0.
        previous = stats.get(point, time_taken)
        stats[point] = previous*gamma + time_taken*(1-gamma)
        if verbose:
            print(f"took {time_taken} to reach {point}, ewma={stats[point]}")
    prev_time = time_now

In [None]:
# Built once and shared. set_loader_helper passes this object to each split's
# TextureDatasetGenerator, so the base texture data is only loaded one time.
main_data_set = TextureDatasetGenerator("./data/dtd")  # -> do this so we only load once

In [None]:
# Per-sample transform handed to every split's generator below.
# NOTE(review): RandomRotation is random augmentation and is also applied to the
# validation and test splits — confirm that is intended.
transform = transforms.Compose(
     [transforms.ToTensor(),
      transforms.RandomRotation(90)])

batch_size = 128  # seems to be the fastest batch size
# (start, end) ranges forwarded as `image_indices`; the three ranges do not overlap.
# NOTE(review): presumably each index selects/seeds one generated sample — confirm.
train_indices = (0, 500_000) # size of training set
valid_indices = (1_250_000, 1_260_000)
test_indices = (2_260_000, 2_270_000)

def set_loader_helper(indices, shuffle=True):
    """Create a TextureDatasetGenerator over ``indices`` and a DataLoader for it.

    Args:
        indices: (start, end) tuple forwarded as ``image_indices``.
        shuffle: shuffle the sample order each epoch. Defaults to True (the original
            behaviour); evaluation splits pass False so their batch order is reproducible.

    Returns:
        (data_set, loader) tuple.
    """
    data_set = TextureDatasetGenerator(main_data_set,
                                       transform=transform,
                                       noise_size=(5,15),
                                       size=128,
                                       radius_frac=(1/3, 1/2.1),
                                       image_indices=indices)
    loader = torch.utils.data.DataLoader(data_set, batch_size=batch_size,
                                         shuffle=shuffle, num_workers=4, pin_memory=True)
    return data_set, loader

train_set, train_loader = set_loader_helper(train_indices)
# Shuffling buys nothing for evaluation; a fixed order keeps per-batch results comparable.
valid_set, valid_loader = set_loader_helper(valid_indices, shuffle=False)
test_set, test_loader = set_loader_helper(test_indices, shuffle=False)

In [None]:
# One entry per layer: [num_channels (input and output), kernel_size, stride].
res_net_layers = [
    [32, 7, 2],
    [64, 3, 1],
    [64, 3, 1],
    [128, 3, 2],
    [128, 3, 1],
    [128, 3, 1],
    [256, 3, 2],
    [256, 3, 1],
    [512, 3, 2],
    [512, 3, 1],
]
# NOTE(review): [128, 128, 3] is presumably the input image shape (matches size=128 in the
# loaders) and "texture_net.dict" the checkpoint path — confirm against ResNet.
res_net = ResNet(
    res_net_layers,
    valid_set.num_classes,
    [128, 128, 3],
    "texture_net.dict",
    fc_layers=[160],
).to(device)

optim = torch.optim.Adam(res_net.parameters())
loss_func = nn.CrossEntropyLoss()
print(res_net.num_params())
# Restore the saved model state; optim is passed so its state is presumably restored too.
res_net.load_model_state_dict(optim=optim)

In [None]:
# Run training and keep whatever `train` returns in `results`.
# NOTE(review): `train` is not defined in this notebook; presumably it comes from one of
# the star imports in the first cell (utils?). 200 is presumably the epoch count — confirm.
results = train(res_net, optim, loss_func, 200)

In [None]:
# Fix the global NumPy RNG so the example explained below is reproducible.
# NOTE(review): 500_001 lies outside valid_indices (1_250_000-1_260_000); presumably
# generate_one() draws from the global NumPy RNG rather than that range — confirm.
np.random.seed(500_001)
# Collect the extra return values in `_extras`, not `__`. In IPython `__` is an
# output-history variable, and assigning to it makes IPython stop updating `_`, `__` and `___`.
explain_img, explain_target_logit, *_extras = valid_set.generate_one()
# Explain the output at the target vector's argmax (presumably the true class index — confirm).
heat_map = finite_differences_map(res_net, valid_set, explain_target_logit.argmax(), explain_img, device=device)

In [None]:
# Largest absolute value in the finite-difference heat map (a quick check of its scale).
abs(heat_map).max()

In [None]:
# Left panel: the example image. Right panel: its finite-difference heat map drawn
# with the diverging "bwr" colormap.
plt.subplot(1, 2, 1)
plt.imshow(explain_img/255.)  # NOTE(review): assumes raw pixel values in [0, 255] — confirm
plt.subplot(1, 2, 2)
imshow_centered_colorbar(heat_map, cmap="bwr", title="FD Map")
print(explain_target_logit)
# TODO: draw one map per colour channel, and maybe absolute values too
#       (signed contributions from different channels might cancel each other out).
# TODO: maybe build cubes of each colour class and visualize the network logit as a
#       point cloud (only one class at a time, or use multiple maps); probably better
#       for capturing step-function-like behaviour.

In [None]:
# Scales to analyse (presumably PCA window sizes — confirm against find_pca_directions).
default_scales = [3, 5, 7, 9, 13, 15]
# Run the PCA twice, in this order: with stride 1, then with the scales list as the
# stride argument (presumably stride == window size for each scale).
pca_directions_1_stride, pca_directions_s_stride = (
    find_pca_directions(valid_set, 1024, default_scales, stride, sample_size=2048)
    for stride in (1, default_scales)
)

In [None]:
# Show each strided-window PCA result as one grayscale mosaic, side by side.
# Uses the explicit Figure/Axes API instead of the pyplot state machine.
# squeeze=False keeps `axes` 2-D even with a single panel, and max(..., 1) avoids
# plt.subplots(1, 0) raising when there are no results.
fig, axes = plt.subplots(1, max(len(pca_directions_s_stride), 1), figsize=(6*4, 12), squeeze=False)
for i, (ax, res) in enumerate(zip(axes[0], pca_directions_s_stride)):
    # NOTE(review): assumes `res` is 4-D (grid_rows, grid_cols, h, w) — TODO confirm. The nested
    # concatenate then tiles it into one (grid_rows*h, grid_cols*w) image.
    compressed_results = np.concatenate(np.concatenate(res, 1), 1)
    ax.imshow(compressed_results, cmap="gray")
    if i == 0:
        ax.set_title("Strided windows")

In [None]:
# Show each stride-1 PCA result as one grayscale mosaic, side by side.
# Uses the explicit Figure/Axes API instead of the pyplot state machine.
# squeeze=False keeps `axes` 2-D even with a single panel, and max(..., 1) avoids
# plt.subplots(1, 0) raising when there are no results.
fig, axes = plt.subplots(1, max(len(pca_directions_1_stride), 1), figsize=(6*4, 12), squeeze=False)
for i, (ax, res) in enumerate(zip(axes[0], pca_directions_1_stride)):
    # NOTE(review): assumes `res` is 4-D (grid_rows, grid_cols, h, w) — TODO confirm. The nested
    # concatenate then tiles it into one (grid_rows*h, grid_cols*w) image.
    compressed_results = np.concatenate(np.concatenate(res, 1), 1)
    ax.imshow(compressed_results, cmap="gray")
    if i == 0:
        ax.set_title("Stride=1")

What if we rerun the same experiment, but "cheat" by using a prior over pixel values that we know *should* be informative for the output logit — namely, the values closest to the decision boundary?

# Model Optimization and Profiling

In [None]:
# Save the current model state; optim is passed so its state is presumably saved too.
# NOTE(review): destination is presumably the "texture_net.dict" path given to ResNet — confirm.
res_net.save_model_state_dict(optim=optim)

In [None]:
# Run 1000 forward passes with profile=True to collect per-layer timings.
# `generated_img` was never defined in this notebook, so on a fresh kernel this cell
# raised NameError. Take one batch from the validation loader as the profiling input.
# NOTE(review): assumes each batch is a sequence whose first element is the image
# tensor — confirm against TextureDatasetGenerator.__getitem__.
generated_img = next(iter(valid_loader))[0].to(device)
# Loop variable is not `_`: binding `_` in IPython stops its output-history variables updating.
for _step in range(1000):
    # Call the module rather than .forward() so nn.Module hooks run; the kwarg is passed through.
    res_net(generated_img, profile=True)

In [None]:
# Print each checkpoint's share (%) of the total profiled time, largest first.
# Notes from earlier runs: "--> gave 3x speed! (Fast and Accurate Model scaling?)";
# "--> the 3x speedup caused underfitting though, so switched to 2x".
# NOTE(review): `stats` is the dict created by this notebook's benchmark cell, which shadows
# any `stats` pulled in by the star imports. If forward(profile=True) records timings in its
# own module's globals, this dict stays empty — confirm which one gets populated.
total = sum(stats.values())
if total > 0:
    for point, ewma in sorted(stats.items(), key=lambda item: item[1], reverse=True):
        print(point, (100.*ewma/total))
else:
    print("no profiling stats recorded")
