In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np
import time
import os
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
from collections import defaultdict
import pickle

from color_regions import *
from network import *
from visualizations import *
from utils import *
from hooks import *
from config_objects import *
from training import *

# set up autoreloading of shared code
%load_ext autoreload
%autoreload 1
%aimport color_regions,network,visualizations,utils,hooks,config_objects,training
%aimport

torch.backends.cudnn.benchmark = True
device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [None]:
# Per-split (start, end) sample-index ranges passed to ColorDatasetGenerator.
# Sample counts in the comments assume `end` is exclusive — TODO confirm
# against ColorDatasetGenerator.
train_indices = (0, 250_000)            # 250k training samples
valid_indices = (1_250_000, 1_275_000)  # 25k validation samples
test_indices = (3_260_000, 3_560_000)   # 300k test samples

# Guard against train/valid/test leakage if the ranges above are edited:
# once sorted by start, each range must end no later than the next one starts.
split_ranges = sorted([train_indices, valid_indices, test_indices])
assert all(prev[1] <= nxt[0] for prev, nxt in zip(split_ranges, split_ranges[1:])), \
    f"dataset split index ranges overlap: {split_ranges}"

critical_color_values = list(range(0, 241, 30))  # 0, 30, ..., 240

dset_config = ColorDatasetConfig(task_difficulty="hard",
                                 noise_size=(1,9),
                                 num_classes=3,
                                 num_objects=0,  # => permuted
                                 radius=(1/8., 1/7.),
                                 device=device,
                                 batch_size=128)

# copies the config each time
train_set = ColorDatasetGenerator(train_indices, dset_config)
valid_set = ColorDatasetGenerator(valid_indices, dset_config)
test_set = ColorDatasetGenerator(test_indices, dset_config)
# Optional toggle (disabled): set the training generator's `infinite` flag,
# presumably so it cycles indefinitely — confirm in ColorDatasetGenerator.
# train_set.cfg.infinite = True

In [None]:
# Build the "tiny" ResNet variant and restore its saved state from disk.
tiny_config = ExperimentConfig()
tiny_config.layer_sizes = [[2, 3, 4], [6, 3, 4]]

# The checkpoint filename appears to encode the training run's hyperparameters.
tiny_checkpoint = "full_random_noisy/0.001_3.7926901907322535e-06_False_tiny_size_0.2_30.dict"

loss_func = nn.CrossEntropyLoss()
tiny_net = ResNet(tiny_checkpoint, tiny_config, dset_config).to(device)
tiny_optim = torch.optim.Adam(tiny_net.parameters())

print(tiny_net.num_params())
# The optimizer is handed to the loader, presumably so its state is restored
# together with the weights.
tiny_net.load_model_state_dict(optim=tiny_optim)

In [None]:
# Put the network in eval mode and wrap it so intermediate activations can be
# inspected by the plotting cells below.
tiny_net.eval()
interp_net = AllActivations(tiny_net)
# Reuse the same wrapper for the colour profile rather than constructing a
# second AllActivations around tiny_net (a redundant wrapper over the same
# model, and a duplicate set of hooks if the wrapper registers any).
profile_plots, _ = activation_color_profile(interp_net, valid_set)

In [None]:
# Conv block 0, first activation: weights alongside the colour profile.
show_conv_weights(
    interp_net,
    "conv_blocks.0.act_func1",
    color_profile=profile_plots,
)

In [None]:
# Conv block 0, second activation.
show_conv_weights(
    interp_net,
    "conv_blocks.0.act_func2",
    color_profile=profile_plots,
)

In [None]:
# Conv block 1, first activation.
show_conv_weights(
    interp_net,
    "conv_blocks.1.act_func1",
    color_profile=profile_plots,
)

In [None]:
# Conv block 1, second activation — this call also sets full_gridspec=False
# and show_scale=True, unlike the earlier layers.
show_conv_weights(
    interp_net,
    "conv_blocks.1.act_func2",
    color_profile=profile_plots,
    full_gridspec=False,
    show_scale=True,
)

In [None]:
# Final plot: show_fc_conv on the same wrapped net and colour profile
# (presumably the fully connected head's view of the conv features).
show_fc_conv(
    interp_net,
    color_profile=profile_plots,
)