In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np
import time
import os
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
from collections import defaultdict

import color_regions, network, visualizations, utils
from color_regions import *
from network import *
from visualizations import *
from utils import *
from hooks import *

# Enable cuDNN's autotuner so convolution algorithms are benchmarked and cached per input shape.
torch.backends.cudnn.benchmark = True
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [None]:
# set up autoreloading of shared code
# Mode 1 only reloads the modules registered through %aimport below.
# (Comments must stay on their own lines: text after a line magic is passed to it as arguments.)
%load_ext autoreload
%autoreload 1
%aimport color_regions,network,visualizations,utils,hooks
# A bare %aimport prints which modules are currently marked for auto-reload.
%aimport

In [None]:
# Only ToTensor is applied to the generated images; the Normalize step below is left disabled.
transform = transforms.Compose(
     [transforms.ToTensor()])#,
    #transforms.Normalize((0.5), (0.5))])

batch_size = 128 # seems to be the fastest batch size
# (start, stop) index pairs handed to ColorDatasetGenerator as `image_indices`.
# The spans are 250k (train), 20k (validation) and 300k (test) wide.
# NOTE(review): whether `stop` is inclusive is decided inside color_regions.py -- confirm.
train_indices = (0, 250_000) # size of training set
valid_indices = (1_250_000, 1_270_000)
test_indices = (2_260_000, 2_560_000)

def color_classifier(color):
    """Return the class (0, 1 or 2) assigned to a grayscale intensity.

    The intensity axis is split into 30-wide bands whose labels deliberately
    alternate, so the class is not a monotonic function of the color:

        <=30 -> 0, (30, 60] -> 1, (60, 90] -> 2, (90, 120] -> 1, (120, 150] -> 0,
        (150, 180] -> 1, (180, 210] -> 2, (210, 240] -> 0, >240 -> 2

    Over the 0-255 range this gives roughly 90/255 of the values to class 0,
    90/255 to class 1 and 75/255 to class 2.

    A value for which every comparison is false (e.g. NaN) yields None.
    """
    # (inclusive upper bound, label) pairs in ascending order; the first
    # band whose bound is not exceeded decides the class.
    banded_labels = (
        (30, 0),
        (60, 1),
        (90, 2),
        (120, 1),
        (150, 0),
        (180, 1),
        (210, 2),
        (240, 0),
    )
    for upper_bound, label in banded_labels:
        if color <= upper_bound:
            return label
    if color > 240:
        return 2
    return None
# Band edges of the classifier, used as x-ticks when plotting the class layout.
critical_color_values = [30 * step for step in range(9)]

def set_loader_helper(indices, infinite=False):
    """Create a ColorDatasetGenerator for `indices` and a DataLoader wrapping it.

    Args:
        indices: value forwarded as `image_indices` to ColorDatasetGenerator.
        infinite: forwarded unchanged to ColorDatasetGenerator.

    Returns:
        (data_set, loader) -- the generator and a shuffling DataLoader that uses
        the notebook-level `batch_size`, 6 workers and pinned memory.
    """
    # All generator settings in one place; only `indices` and `infinite` vary per split.
    generator_options = dict(
        color_classifier=color_classifier,
        image_indices=indices,
        transform=transform,
        color_range=(5, 255),
        noise_size=(1, 9),
        num_classes=3,
        infinite=infinite,
        size=128,
        num_objects=0,
        radius=(128 // 8, 128 // 7),
    )
    data_set = ColorDatasetGenerator(**generator_options)
    loader = torch.utils.data.DataLoader(
        data_set,
        batch_size=batch_size,
        shuffle=True,
        num_workers=6,
        pin_memory=True,
    )
    return data_set, loader
# Build the three splits. `infinite` is False for all of them (explicitly for train,
# through the helper's default for the other two).
train_set, train_loader = set_loader_helper(train_indices, infinite=False)
valid_set, valid_loader = set_loader_helper(valid_indices)
test_set, test_loader = set_loader_helper(test_indices)

In [None]:
# the "hard" task
# Top panel: class assigned by color_classifier across the intensity range.
plt.figure(figsize=(6,6))
# NOTE(review): linspace(0, 255, 255) yields 255 probes spaced 255/254 apart, i.e. not the
# integer intensities; np.arange(256) may have been intended -- confirm.
color_probe = np.linspace(0, 255, 255)
color_class = [color_classifier(x) for x in color_probe]
plt.subplot(2,1,1)
plt.plot(color_probe, color_class)
plt.xticks(critical_color_values)
plt.yticks([0, 1, 2])
plt.ylabel("Class")
def medium_color_classifier(color):
    """Return the class of a grayscale intensity for the coarser "medium" task.

    Bands: <=100 -> 0, (100, 150] -> 1, (150, 200] -> 2, >200 -> 1.
    A value for which every comparison is false (e.g. NaN) yields None.
    """
    # (inclusive upper bound, label) pairs, checked in ascending order.
    for upper_bound, label in ((100, 0), (150, 1), (200, 2)):
        if color <= upper_bound:
            return label
    return 1 if color > 200 else None
# Bottom panel: the same probe intensities run through medium_color_classifier,
# with x-ticks placed on its three band edges.
med_color_class = [medium_color_classifier(x) for x in color_probe]
plt.subplot(2,1,2)
plt.plot(color_probe, med_color_class)
plt.xlabel("Color")
plt.xticks([100, 150, 200])
plt.yticks([0, 1, 2])
plt.ylabel("Class")

In [None]:
# change dataset to uniformly change background (noise)
# evidence that we care about color classification
# color vs shape 
# GuidedBackprop shows that conv nets rely mostly on edges
# What happens when the network cannot rely only on the edges
# practical examples of beach orientation detection, shadow xray?

In [None]:
# Draw a num_x by num_y grid of freshly generated validation samples.
num_x = 4
num_y = 4
plt.figure(figsize=(3*num_x, 3*num_y))
# back_probs = [0.25]
# This mutates the shared validation dataset; every later generate_one() call sees back_p = 0.25.
# NOTE(review): back_p is presumably the background-noise probability -- confirm in color_regions.py.
valid_set.back_p = 0.25
for i in range(num_x*num_y):

    #valid_set.back_p = back_probs[i % 3]
    while not (80 < (img_gen := valid_set.generate_one())[2] < 150): # resample until the target color is strictly between 80 and 150
        pass
    plt.subplot(num_y, num_x, i+1)
#     if i // num_x == 0:
#         plt.title(f"p={valid_set.back_p}")
    imshow_centered_colorbar(img_gen[0], cmap="gray", colorbar=False)
#     plt.subplot(num_x, num_y, i*2+2)
#     plot_color_classes(valid_set, (0, 128), alpha=1.0)
#     plt.vlines([clr], 0, 128)
# `img_gen` keeps the last generated sample and is reused by the pixel-shuffle cell below.

In [None]:
# Show the last generated sample with its pixels randomly permuted.
#
# np.mgrid[:H, :W] has shape (2, H, W): the row-index grid followed by the column-index
# grid. Reshaping that array straight to (-1, 2) pairs up neighbouring entries of the
# *same* grid, which does not enumerate the pixel coordinates, so the result would not
# be a permutation of the image. Flattening each grid and transposing gives one
# (row, col) pair per pixel.
height, width = img_gen[0].shape[:2]
idxs = np.mgrid[:height, :width].reshape(2, -1).T.copy()
np.random.shuffle(idxs)  # shuffles the (row, col) pairs in place along axis 0
idxs = idxs.reshape(height, width, 2)
print(idxs.shape)
plt.imshow(img_gen[0][idxs[...,0], idxs[...,1]], cmap="gray")

In [None]:
# Model zoo. Every cell below follows the same recipe: build a ResNet, move it to `device`,
# create an Adam optimizer over its parameters and call load_model_state_dict with that optimizer.
# `loss_func` is re-assigned in each cell, always to a fresh nn.CrossEntropyLoss().
# NOTE(review): the positional ResNet arguments `3` and `[128, 128, 1]` are presumably the number
# of classes and the input shape, and the string the checkpoint file -- confirm in network.py.
noise_net = ResNet([[2, 3, 4],  # num_channels (input and output), kernel_size, stride
                   [6, 3, 4]], 3, [128, 128, 1], 
                   "noise_net_hard_tiny.dict", fc_layers=[]).to(device)
loss_func = nn.CrossEntropyLoss()
noise_optim = torch.optim.Adam(noise_net.parameters())
print(noise_net.num_params())
noise_net.load_model_state_dict(optim=noise_optim)

In [None]:
# Same tiny architecture as noise_net, stored under "permuted_hard_tiny.dict".
permuted_net = ResNet([[2, 3, 4],  # num_channels (input and output), kernel_size, stride
                   [6, 3, 4]], 3, [128, 128, 1], 
                   "permuted_hard_tiny.dict", fc_layers=[]).to(device)
loss_func = nn.CrossEntropyLoss()
permuted_optim = torch.optim.Adam(permuted_net.parameters())
print(permuted_net.num_params())
permuted_net.load_model_state_dict(optim=permuted_optim)

In [None]:
# Larger variant: 16 and 32 channels with stride 1 in both blocks.
permuted_large_net = ResNet([[16, 3, 1],  # num_channels (input and output), kernel_size, stride
                      [32, 3, 1]], 3, [128, 128, 1], 
                   "permuted_hard_large.dict", fc_layers=[]).to(device)
loss_func = nn.CrossEntropyLoss()  # dont start from same initialization
permuted_large_optim = torch.optim.Adam(permuted_large_net.parameters())
print(permuted_large_net.num_params())
permuted_large_net.load_model_state_dict(optim=permuted_large_optim)
#set_initializers(permuted_large_net, 0.1)

In [None]:
# Second copy of the large architecture with its own checkpoint file.
permuted_large_net2 = ResNet([[16, 3, 1],  # num_channels (input and output), kernel_size, stride
                      [32, 3, 1]], 3, [128, 128, 1], 
                   "permuted_hard_large2.dict", fc_layers=[]).to(device)
loss_func = nn.CrossEntropyLoss()  # dont start from same initialization
permuted_large_optim2 = torch.optim.Adam(permuted_large_net2.parameters())
permuted_large_net2.load_model_state_dict(optim=permuted_large_optim2)
#set_initializers(permuted_large_net2, 0.05)

In [None]:
# Third copy of the large architecture with its own checkpoint file.
permuted_large_net3 = ResNet([[16, 3, 1],  # num_channels (input and output), kernel_size, stride
                      [32, 3, 1]], 3, [128, 128, 1], 
                   "permuted_hard_large3.dict", fc_layers=[]).to(device)
loss_func = nn.CrossEntropyLoss()  # dont start from same initialization
permuted_large_optim3 = torch.optim.Adam(permuted_large_net3.parameters())
permuted_large_net3.load_model_state_dict(optim=permuted_large_optim3)
# NOTE(review): the disabled call below names permuted_large_net2 although this cell builds
# permuted_large_net3 -- looks like a copy-paste slip; confirm before re-enabling.
#set_initializers(permuted_large_net2, 0.2)

In [None]:
# Tiny architecture again, checkpoint "noise_net_hard_tiny_low_p.dict".
low_noise_net = ResNet([[2, 3, 4],  # num_channels (input and output), kernel_size, stride
                   [6, 3, 4]], 3, [128, 128, 1], 
                   "noise_net_hard_tiny_low_p.dict", fc_layers=[]).to(device)
loss_func = nn.CrossEntropyLoss()
low_noise_optim = torch.optim.Adam(low_noise_net.parameters())
print(low_noise_net.num_params())
low_noise_net.load_model_state_dict(optim=low_noise_optim)

In [None]:
# Tiny architecture again, checkpoint "noise_net_hard_tiny_tiny_p.dict" (author's note: p = 0.25).
tiny_noise_net = ResNet([[2, 3, 4],  # num_channels (input and output), kernel_size, stride
                   [6, 3, 4]], 3, [128, 128, 1], # p = 0.25
                   "noise_net_hard_tiny_tiny_p.dict", fc_layers=[]).to(device)
loss_func = nn.CrossEntropyLoss()
tiny_noise_optim = torch.optim.Adam(tiny_noise_net.parameters())
print(tiny_noise_net.num_params())
tiny_noise_net.load_model_state_dict(optim=tiny_noise_optim)

In [None]:
# Evaluation runs. `evaluate` comes from the star-imported project modules; the trailing
# comment in each cell is the author's record of which experiment the run belonged to.
evaluate(tiny_noise_net, loss_func, valid_loader, device=device)
# with squares intact

In [None]:
evaluate(permuted_net, loss_func, valid_loader, device=device) # sample size 20k
# small net, small sample size

In [None]:
evaluate(permuted_net, loss_func, test_loader, device=device) # larger sample size (300k)
# small net, large sample size

In [None]:
evaluate(permuted_large_net, loss_func, test_loader, device=device) 
# finite dataset, gain 0.1

In [None]:
evaluate(permuted_large_net2, loss_func, test_loader, device=device) 
# infinite dataset, gain 0.05

In [None]:
# NOTE(review): this cell and the next one issue the identical call on permuted_large_net3;
# the differing notes presumably refer to different checkpoints loaded at the time -- confirm.
evaluate(permuted_large_net3, loss_func, test_loader, device=device) 
# finite dataset, gain 0.01 (weights are saved as permuted_large4)

In [None]:
evaluate(permuted_large_net3, loss_func, test_loader, device=device) 
# finite dataset, gain 0.2

In [None]:
# Training runs. Each cell calls the project's `train` with a model, its optimizer, the shared
# loss, the literal 1000 (presumably the epoch count -- confirm in the project code) and the
# train/validation loaders. Several cells store the return value in the same name `results`,
# so only the most recently executed run is kept.
train(permuted_large_net2, permuted_large_optim2, loss_func, 1000, train_loader, valid_loader, device=device)
# infinite data, gain 0.05

In [None]:
results = train(permuted_large_net, permuted_large_optim, loss_func, 1000, train_loader, valid_loader, device=device)
# finite data test, gain 0.1

In [None]:
# NOTE(review): this cell and the next one are the same call on permuted_large_net3 with
# different notes (gain 0.2 vs 0.01); which initialisation was active is not recorded in code.
train(permuted_large_net3, permuted_large_optim3, loss_func, 1000, train_loader, valid_loader, device=device)
# finite dataset, gain 0.2

In [None]:
train(permuted_large_net3, permuted_large_optim3, loss_func, 1000, train_loader, valid_loader, device=device)
# finite dataset, gain 0.01

In [None]:
# NOTE(review): same call as the earlier permuted_large_net cell. The note says "infinite data",
# but train_loader is built with infinite=False in this notebook -- confirm the setup used.
results = train(permuted_large_net, permuted_large_optim, loss_func, 1000, train_loader, valid_loader, device=device)
# infinite data test (no initialization changes)

In [None]:
results = train(permuted_net, permuted_optim, loss_func, 1000, train_loader, valid_loader, device=device)
# small net

In [None]:
results = train(noise_net, noise_optim, loss_func, 1000, train_loader, valid_loader, device=device)
# trained with squares still visible

In [None]:
results = train(low_noise_net, low_noise_optim, loss_func, 1000, train_loader, valid_loader, device=device)
# trained with squares still visible

In [None]:
results = train(tiny_noise_net, tiny_noise_optim, loss_func, 1000, train_loader, valid_loader, device=device)
# trained with squares still visible

In [None]:
# Probe permuted_net with uniform images: one forward pass per intensity 0..254,
# collecting the network output for each.
permuted_net.eval()
avg_img = np.ones((valid_set.size, valid_set.size))
tensor_avg_img = tensorize(avg_img, device=device)
responses = []
for color in np.arange(255):
    # Overwrite the whole probe tensor in place with the current intensity.
    tensor_avg_img[...] = color
    responses.append(permuted_net(tensor_avg_img).detach().cpu().numpy())
# Stack the per-intensity outputs and drop singleton axes; the next cell indexes the
# result as responses[:, i] for i in 0..2.
responses = np.asarray(responses).squeeze()

In [None]:
# Plot each of the three logits against the uniform-image intensity, then overlay the
# class regions (plot_color_classes) spanning the observed logit range.
for i in range(3):
    plt.plot(np.arange(255), responses[:,i], label=f"logit {i}")
plt.legend()
plot_color_classes(valid_set, (responses.min(), responses.max()))

In [None]:
def averaging_test(dataset, sample, edge_width=10, model=None):
    """Compare the network's answer on real samples with its answer on simplified probes.

    For each of `sample` images drawn with ``dataset.generate_one()`` the model
    classifies six inputs, and each prediction is scored against the sample's label:

    * ``base``           -- the generated image itself
    * ``background_set`` -- the image with every zero-valued pixel replaced by
                            ``(color + 30) % 255``
    * ``color_set``      -- a uniform image filled with the target color
    * ``naive``          -- a uniform image filled with the whole-image mean
    * ``calibrated``     -- a uniform image filled with the noise-corrected mean of
                            the foreground (pixels > 2)
    * ``edge_set``       -- a black image with an `edge_width`-pixel border in the
                            target color

    Samples whose foreground is empty (the estimate is NaN) are skipped.

    Args:
        dataset: object exposing ``radius`` (pair), ``size`` and ``generate_one()``.
        sample: number of images to draw.
        edge_width: border thickness, in pixels, of the edge-set probe.
        model: network to probe. Defaults to the notebook-level ``permuted_net``,
            which is what this function always used before the parameter existed.

    Returns:
        dict mapping probe name to the fraction of answered samples on which the
        prediction matched the label (NaN for every probe if nothing was answered).
        The same figures are printed.
    """
    if model is None:
        model = permuted_net

    avg_area = np.pi/3*(dataset.radius[1]**2+dataset.radius[0]**2+dataset.radius[0]*dataset.radius[1])
    pct_area = avg_area / (dataset.size**2)
    print(f"Targets are on average {pct_area:.1%} of the image")

    # Calibration model: foreground_mean = color*(1 - pct) + noise_level*pct.
    # NOTE(review): 36.9 is presumably the empirically measured noise_level*pct
    # (mean non-target foreground contribution) -- confirm against error_by_color.
    noise_offset = 36.9
    noise_level = 128

    probe_names = ("calibrated", "naive", "color_set", "edge_set", "background_set", "base")
    right = dict.fromkeys(probe_names, 0)
    total_answered = 0

    avg_img = np.ones((dataset.size, dataset.size))
    tensor_avg_img = tensorize(avg_img, device=device)

    # Pure inference: no autograd graph is needed for any of the forward passes.
    with torch.no_grad():
        for _ in tqdm(range(sample)):
            img_gen, lbl, color, *_ = dataset.generate_one()
            color = color[0]
            foreground_mask = np.where(img_gen>2)
            other_space = img_gen[(img_gen > 2) & (img_gen != color)].sum() / foreground_mask[0].shape[0]

            prediction = (img_gen[foreground_mask].mean() - noise_offset)/(1-noise_offset/noise_level)
            if np.isnan(prediction) or np.isnan(other_space):
                continue

            predicted = {}

            tensor_avg_img[...] = color  # color setting
            predicted["color_set"] = model(tensor_avg_img).argmax()

            tensor_avg_img[...] = img_gen.mean()  # naive averaging
            predicted["naive"] = model(tensor_avg_img).argmax()

            tensor_avg_img[...] = prediction  # calibrated averaging
            predicted["calibrated"] = model(tensor_avg_img).argmax()

            tensor_img_gen = tensorize(img_gen, device=device)
            predicted["base"] = model(tensor_img_gen).argmax()  # regular classification

            # Move the background into a different 30-wide class band than the target.
            tensor_img_gen[tensor_img_gen == 0] = (color + 30) % 255
            predicted["background_set"] = model(tensor_img_gen).argmax()

            # Edge-set probe: black image with a target-colored frame
            # (top rows, bottom rows, left columns, right columns).
            tensor_avg_img[...] = 0
            tensor_avg_img[0,0, 0:edge_width] = color
            tensor_avg_img[0,0, -edge_width:] = color
            tensor_avg_img[0,0,:, 0:edge_width] = color
            tensor_avg_img[0,0,:, -edge_width:] = color
            predicted["edge_set"] = model(tensor_avg_img).argmax()

            target = lbl.argmax()
            total_answered += 1
            for name, classif in predicted.items():
                # int() keeps the tallies as plain Python ints instead of tensors.
                right[name] += int(target == classif)

    if total_answered == 0:
        # Avoid dividing by zero when every sample was skipped (or sample == 0).
        print("No samples were answered")
        return dict.fromkeys(probe_names, float("nan"))

    accuracy = {name: hits / total_answered for name, hits in right.items()}
    print(f"Calibrated got {accuracy['calibrated']:.2%} correct")
    print(f"Naive got {accuracy['naive']:.2%} correct")
    print(f"Color setting got {accuracy['color_set']:.2%} correct")
    print(f"Edge setting got {accuracy['edge_set']:.2%} correct")
    print(f"Background setting got {accuracy['background_set']:.2%} correct")
    print(f"Base got {accuracy['base']:.2%} correct")
    return accuracy
    
# Run the probe comparison on 100k freshly generated validation samples.
result = averaging_test(valid_set, 100_000)
# Follow-up ideas:
# PCA map to see edge behaviour (average a bunch of them?)
# color set edge test
# background only set test? (do it maliciously) (see how badly it hurts performance)

In [None]:
def error_by_color(dataset, sample=100_000):
    """Measure how well plain pixel averaging recovers the target color (no network).

    For each generated image three color estimates are pushed through
    ``color_classifier`` and scored against the sample's label:

    * calibrated   -- foreground mean corrected with the noise model below
    * naive        -- raw mean of the foreground (pixels > 2)
    * really naive -- whole-image mean divided by the expected target area fraction

    Noise model: ``foreground_mean = color*(1 - pct) + 128*pct``, where ``128*pct`` is
    the average contribution of non-target, non-background pixels to the foreground.
    NOTE(review): the constant 36.9 is presumably that measured ``128*pct`` (compare
    the mean of the second returned array) -- confirm.

    Samples whose foreground is empty (the estimates are NaN) are skipped.

    Args:
        dataset: object exposing ``radius`` (pair), ``size`` and ``generate_one()``.
        sample: number of images to draw.

    Returns:
        (points, other_points): two arrays with one row per answered sample,
        ``(color, calibrated_estimate)`` and ``(color, non_target_foreground_mean)``.
        Both have shape (0, 2) if no sample was answered.
    """
    avg_area = np.pi/3*(dataset.radius[1]**2+dataset.radius[0]**2+dataset.radius[0]*dataset.radius[1])
    pct_area = avg_area / (dataset.size**2)
    print(f"Targets are on average {pct_area:.1%} of the image")

    noise_offset = 36.9
    noise_level = 128

    points = []
    other_points = []
    total_answered = 0
    right_calibrated = 0
    right_naive = 0
    right_really_naive = 0
    for _ in tqdm(range(sample)):
        img_gen, lbl, color, *_ = dataset.generate_one()
        # averaging_test indexes color[0] on the same generate_one() output, so the color
        # may arrive as a 1-element sequence. Normalise to a scalar (a no-op for scalars)
        # so the returned (color, value) rows form a regular (N, 2) array.
        color = np.asarray(color).ravel()[0]

        foreground_mask = np.where(img_gen>2)
        foreground_mean = img_gen[foreground_mask].mean()
        # Sum of non-target foreground pixels over the foreground size => 128*pct.
        other_space = img_gen[(img_gen > 2) & (img_gen != color)].sum() / foreground_mask[0].shape[0]

        prediction = (foreground_mean - noise_offset)/(1-noise_offset/noise_level)
        if np.isnan(prediction) or np.isnan(other_space):
            continue

        target = lbl.argmax()
        total_answered += 1
        right_calibrated += int(target == color_classifier(prediction))
        right_naive += int(target == color_classifier(foreground_mean))
        right_really_naive += int(target == color_classifier(img_gen.mean()/pct_area))
        points.append((color, prediction))
        other_points.append((color, other_space))

    if total_answered == 0:
        # Avoid dividing by zero when every sample was skipped (or sample == 0).
        print("No samples were answered")
        return np.empty((0, 2)), np.empty((0, 2))

    print(f"Calibrated got {right_calibrated/total_answered:.2%} correct")
    print(f"Naive got {right_naive/total_answered:.2%} correct")
    print(f"Really naive got {right_really_naive/total_answered:.2%} correct")

    return np.asarray(points), np.asarray(other_points)
# `result` is the (points, other_points) pair returned by error_by_color.
result = error_by_color(valid_set, sample=100_000)

In [None]:
# Second array: true color (x) against the second column of other_points (y),
# with the identity line in red for reference.
plt.scatter(result[1][:,0], result[1][:,1], s=0.05)
plt.plot(np.arange(255), c="r")
# Displayed as the cell output: the mean of that second column.
result[1][:,1].mean()

In [None]:
# First array: true color (x) against the calibrated estimate (y), again with the identity line.
plt.scatter(result[0][:,0], result[0][:,1], s=0.05)
plt.plot(np.arange(255), c="r")
In [None]:
# Wrap permuted_net in AllActivations (from the project modules).
# NOTE(review): presumably this records intermediate activations for the show_* helpers -- confirm in hooks.py.
interp_net = AllActivations(permuted_net)

In [None]:
# The next three cells re-seed NumPy with the same value and draw one validation sample each;
# the trailing notes record the background probability that was active when each was run.
# NOTE(review): they unpack generate_one() differently (`*_`, exactly five names, `*pos`);
# the five-name form fails if generate_one returns more than five values, as a later cell
# that unpacks seven suggests -- confirm which signature is current.
np.random.seed(5_123_456)
test_img, lbl, color, size, *_  = valid_set.generate_one()
print(color)
plt.imshow(test_img, cmap="gray")
tensor_test_img = tensorize(test_img, device=device)
# with p=0.8 (ignore this, switched away from this approach)

In [None]:
np.random.seed(5_123_456)
test_img, lbl, color, size, pos  = valid_set.generate_one()
print(color)
plt.imshow(test_img, cmap="gray")
# with p = 0.4

In [None]:
np.random.seed(5_123_456)
test_img, lbl, color, size, *pos  = valid_set.generate_one()
print(color)
plt.imshow(test_img, cmap="gray")
tensor_test_img = tensorize(test_img, device=device)
# with p = 0.25

In [None]:
# Run the wrapped network once on the sample tensor, in eval mode.
interp_net.eval()
interp_net(tensor_test_img)

In [None]:
# Hand-derived approximation of output channel `c` of conv2 + batch_norm2 in block `block`.
# NOTE(review): `tiny_net` is not defined anywhere in the visible notebook and `profile_plots`
# is only assigned by a later cell (the noise_net profile), so this cell fails on a fresh
# Restart & Run All. `tiny_net` presumably refers to noise_net under an older name -- confirm.
c = 0
block = 1

#uniform_inpt = torch.full((1,16,32,32), 100.0).to(device)
#plt.imshow(tiny_net.conv_blocks[0].conv2.weight[c, in_c].detach().cpu().numpy(), cmap="bwr")
# All input-channel kernels feeding output channel c.
conv_maps = tiny_net.conv_blocks[block].conv2.weight[c, :]
#imshow_centered_colorbar(conv_maps[7].detach().cpu().numpy(), cmap="bwr")
# Per input channel: the largest weight over the kernel's last two axes.
conv_scale = conv_maps.max(axis=-1).values.max(axis=-1).values
conv_shift = tiny_net.conv_blocks[block].conv2.bias[c]
bn_scale = tiny_net.conv_blocks[block].batch_norm2.weight[c]
bn_shift = tiny_net.conv_blocks[block].batch_norm2.bias[c]
bn_var = tiny_net.conv_blocks[block].batch_norm2.running_var[c]
bn_mean = tiny_net.conv_blocks[block].batch_norm2.running_mean[c]
print(conv_shift, bn_scale, bn_shift, bn_var, bn_mean)
#(c*conv_scale + conv_shift - bn_mean) / torch.sqrt(bn_var) * bn_scale + bn_shift
# Fold the conv bias and the batch-norm running statistics into one affine map per input channel.
# NOTE(review): batch norm's epsilon is not included under the square root here.
slope = (conv_scale/torch.sqrt(bn_var)*bn_scale).detach().cpu().numpy()
bias = ((conv_shift - bn_mean)/torch.sqrt(bn_var)*bn_scale + bn_shift).detach().cpu().numpy()

# First entry of the stored profile for each of the six act_func1 channels of this block.
lines = np.asarray([profile_plots[f"conv_blocks.{block}.act_func1_{x}"][0] for x in range(6)])

# Weighted sum of the six profiles plus the folded bias.
uniform_scaling = slope.dot(lines) + bias


In [None]:
# Clamp the hand-derived response at zero before plotting.
plt.plot(np.maximum(uniform_scaling, 0))

In [None]:
# Activation-vs-color profiles for each trained network: put the network in eval mode, wrap it
# in AllActivations and pass it to activation_color_profile with the validation loader and set.
# Only the first returned value is kept.
noise_net.eval()
profile_plots,_ = activation_color_profile(AllActivations(noise_net), valid_loader, valid_set, device=device)

In [None]:
low_noise_net.eval()
low_profile_plots,_ = activation_color_profile(AllActivations(low_noise_net), valid_loader, valid_set, device=device)

In [None]:
tiny_noise_net.eval()
tiny_profile_plots,_ = activation_color_profile(AllActivations(tiny_noise_net), valid_loader, valid_set, device=device)

In [None]:
permuted_net.eval()
permuted_plots,_ = activation_color_profile(AllActivations(permuted_net), valid_loader, valid_set, device=device)

In [None]:
# Static (inline) figures for the profile and weight views below.
%matplotlib inline  
show_profile_plot(low_profile_plots["conv_blocks.1.act_func2_4"])

In [None]:
show_profile_plot(profile_plots["conv_blocks.1.act_func2_4"])

In [None]:
# NOTE(review): at this point interp_net wraps permuted_net, while the color profiles passed in
# these cells come from tiny_noise_net / noise_net -- confirm that the pairing is intended.
show_conv_weights(interp_net, "conv_blocks.1.act_func1", color_profile=tiny_profile_plots)

In [None]:
show_conv_weights(interp_net, "conv_blocks.1.act_func2", color_profile=tiny_profile_plots)

In [None]:
show_conv_layer(interp_net, "conv_blocks.1.act_func2")  

In [None]:
show_fc_conv(interp_net, color_profile=tiny_profile_plots, fixed_height=True, full_gridspec=True)

In [None]:
# Bare expression: displays the bias of noise_net's first fully connected layer.
noise_net.fully_connected[0].fully_connected.bias

In [None]:
show_fc_conv(interp_net, color_profile=profile_plots, fixed_height=False, full_gridspec=False)

In [None]:
# Switch to the interactive notebook backend for the feature-angle view (num_embed=3).
%matplotlib notebook
feature_gram, projected_weights = visualizations.fc_conv_feature_angles(noise_net, 
                            "fully_connected.0.act_func", num_embed=3, normalize=True)

# Permuted Pixels (no squares) networks

In [None]:
# Index test_img through the reversed row/column index maps of the dataset's permutation #10.
# NOTE(review): presumably this undoes or re-applies the pixel permutation -- confirm against
# how color_regions.py builds `random_permutes`.
plt.imshow(test_img[valid_set.random_permutes[10][...,0][::-1, ::-1], valid_set.random_permutes[10][...,1][::-1, ::-1]], cmap="gray")


In [None]:
# NOTE(review): `orig_img` is only assigned by the next cell, so this cell fails on a fresh
# Restart & Run All unless the order is changed.
plt.imshow(orig_img, cmap="gray")

In [None]:
# Draw one reproducible validation sample; this form expects generate_one() to return seven values.
np.random.seed(5_13_46)
test_img, lbl, color, size, pos, noise, orig_img  = valid_set.generate_one()
print(color)

# Left: the sample as generated. Right: a "denoised" copy that keeps only the pixels equal to
# the target color and zeroes everything else.
plt.figure(figsize=(12,16))
plt.subplot(1,2,1)
plt.imshow(test_img, cmap="gray")
plt.subplot(1,2,2)
denoised_img = np.where(test_img == color, color, 0)
plt.imshow(denoised_img, cmap="gray")

tensor_test_img = tensorize(test_img, device=device)
denoised_tensor_img = tensorize(denoised_img, device=device)

# Two AllActivations wrappers around the same permuted_net: one is run on the original
# sample, the other on the denoised copy.
interp_net = AllActivations(permuted_net)
interp_net.eval()
interp_net(tensor_test_img)

de_interp_net = AllActivations(permuted_net)
de_interp_net.eval()
de_interp_net(denoised_tensor_img)

In [None]:
# Layer-by-layer inspection of permuted_net through the project's show_* helpers.
# `interp_net` was run on the original sample and `de_interp_net` on the denoised copy;
# the string argument names the activation layer to display.
show_conv_weights(interp_net, "conv_blocks.0.act_func1", color_profile=permuted_plots)

In [None]:
show_conv_layer(interp_net, "conv_blocks.0.act_func1")

In [None]:
show_conv_weights(interp_net, "conv_blocks.0.act_func2", color_profile=permuted_plots)

In [None]:
show_conv_layer(interp_net, "conv_blocks.0.act_func2")

In [None]:
show_conv_layer(de_interp_net, "conv_blocks.0.act_func2")


In [None]:
#print(interp_net.model.conv_blocks[1].conv1.bias)
show_conv_weights(interp_net, "conv_blocks.1.act_func1", color_profile=permuted_plots)

In [None]:
show_conv_layer(interp_net, "conv_blocks.1.act_func1")

In [None]:
show_conv_layer(de_interp_net, "conv_blocks.1.act_func1")

In [None]:
show_conv_weights(interp_net, "conv_blocks.1.act_func2", color_profile=permuted_plots)

In [None]:
show_conv_layer(interp_net, "conv_blocks.1.act_func2")

In [None]:
show_conv_layer(de_interp_net, "conv_blocks.1.act_func2")

In [None]:
show_fc_conv(interp_net, color_profile=permuted_plots, fixed_height=True, full_gridspec=True)

In [None]:
# The remaining cells in this run repeat calls already made above.
show_conv_layer(interp_net, "conv_blocks.1.act_func1")

In [None]:
show_conv_layer(interp_net, "conv_blocks.1.act_func2")

In [None]:
show_conv_layer(de_interp_net, "conv_blocks.1.act_func2")

In [None]:
# Fetch the weight of the first fully connected layer through the project's get_weight helper.
fc_mapper = get_weight(interp_net, "fully_connected.0.fully_connected")

In [None]:
# Bare expression: displays the large network's final_img_shape attribute.
permuted_large_net.final_img_shape

In [None]:
# NOTE(review): the cells below repeat show_* calls that already appear earlier in the
# notebook -- presumably kept from runs against a different checkpoint; confirm before pruning.
show_conv_weights(interp_net, "conv_blocks.0.act_func2", color_profile=permuted_plots)

In [None]:
show_conv_layer(interp_net, "conv_blocks.0.act_func2")

In [None]:
show_conv_weights(interp_net, "conv_blocks.1.act_func1", color_profile=permuted_plots)

In [None]:
show_conv_layer(interp_net, "conv_blocks.1.act_func1")

In [None]:
show_conv_weights(interp_net, "conv_blocks.1.act_func2", color_profile=permuted_plots)

In [None]:
show_conv_layer(interp_net, "conv_blocks.1.act_func2")
# Follow-up ideas:
# uniform image of average and pass into network
# pasting images onto each other
# send in images that have only target pixels

In [None]:
show_fc_conv(interp_net, color_profile=permuted_plots, fixed_height=True, full_gridspec=True)

# PCA Direction Analysis

In [None]:
# PCA directions computed from validation samples at several scales.
# The `if 0` switch is a manual cache toggle: change it to 1 to restore both variables from
# IPython's %store instead of recomputing; the else branch computes them and stores them.
# NOTE(review): the meaning of 4096 and of the last argument (1 vs. the list of scales,
# presumably sample count and stride) is defined by find_pca_directions -- confirm.
default_scales = [3,5,7,9,13,15]
if 0: 
    %store -r noise_back_pca_directions_1_stride noise_back_pca_directions_s_stride
else:
    noise_back_pca_directions_1_stride = find_pca_directions(valid_set, 4096, default_scales, 1)
    noise_back_pca_directions_s_stride = find_pca_directions(valid_set, 4096, default_scales, default_scales)
    %store noise_back_pca_directions_1_stride noise_back_pca_directions_s_stride

In [None]:
# Visualise the directions that were computed with stride equal to scale.
visualize_pca_directions(noise_back_pca_directions_s_stride, "Strides=scales", default_scales, lines=True)

In [None]:
# Twenty fixed seeds handed to generate_many_pca together with the validation set.
seeds = [1_2123, 1_40_124, 1_508_559, 1_5_019_258, 1_2_429_852, 9032, 5832, 12, 5014, 92, 42, 52, 
         52_934, 935_152, 1_000_000, 1_000_001, 27, 24, 512, 999_105]  # 20 

In [None]:
# PCA maps for permuted_net using the stride-1 directions, first component, strides=3.
# The second returned value is discarded.
pca_map_s_strides, _, grad_maps, explain_imgs = generate_many_pca(permuted_net, seeds, 
                noise_back_pca_directions_1_stride, default_scales, valid_set, component=0, 
                batch_size=512, strides=3, skip_1_stride=True, device=device)

In [None]:
# Grid of the explained images, their PCA maps and the gradient maps, titled per group.
plt_grid_figure([explain_imgs, pca_map_s_strides, grad_maps], transpose=True, titles=["Image", "Strides=3", "Gradient"])