In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np
import time
import os
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
from collections import defaultdict

import color_regions, network, visualizations, utils
from color_regions import *
from network import *
from visualizations import *
from utils import *
from hooks import *

# Turn on cuDNN benchmark mode (per the PyTorch docs, cuDNN then times several
# convolution algorithms and keeps the fastest one).
torch.backends.cudnn.benchmark = True

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

In [None]:
# Set up autoreloading of the shared project modules so that edits to them are
# picked up without restarting the kernel.
%load_ext autoreload
# Mode 1: per the IPython autoreload docs, only modules registered via %aimport are reloaded.
%autoreload 1
# Register the project-local modules for reloading.
%aimport color_regions,network,visualizations,utils,hooks
# With no arguments, %aimport lists which modules are currently marked for autoreload.
%aimport

In [None]:
# Per-sample preprocessing: only the tensor conversion is active.
# Normalisation is currently disabled; the previous setting is kept here for reference:
#     transforms.Normalize((0.5), (0.5))
transform = transforms.Compose([
    transforms.ToTensor(),
])

# Mini-batch size shared by all DataLoaders (the author found this to be the fastest).
batch_size = 128

# Index pairs passed to the dataset generator as `image_indices`, one per split
# (presumably (start, stop) bounds -- confirm against ColorDatasetGenerator).
train_indices = (0, 100_000)  # size of training set
valid_indices = (1_250_000, 1_300_000)
test_indices = (2_260_000, 2_360_000)

def color_classifier(color):
    """Map a colour value to one of three class labels (0, 1 or 2).

    The value axis is cut at 30, 60, ..., 240.  Each band is ``(lower, upper]``
    (the first one is simply ``color <= 30``) and consecutive bands are labelled
    0, 1, 2, 1, 0, 1, 2, 0; anything above 240 is labelled 2.
    Over the integers 1..255 this yields 90 values for class 0, 90 for class 1
    and 75 for class 2.

    Returns:
        The class label, or ``None`` if no comparison succeeds (e.g. NaN).
    """
    # (inclusive upper bound, label), in ascending order of the bound
    bands = (
        (30, 0),
        (60, 1),
        (90, 2),
        (120, 1),
        (150, 0),
        (180, 1),
        (210, 2),
        (240, 0),
    )
    for upper_bound, label in bands:
        if color <= upper_bound:
            return label
    if color > 240:
        return 2
    return None
# Multiples of 30 from 0 to 240 inclusive, i.e. 0 plus the band edges used by color_classifier.
critical_color_values = [30 * step for step in range(9)]

def set_loader_helper(indices, infinite=False):
    """Create a ColorDatasetGenerator for ``indices`` plus a shuffling DataLoader over it.

    Reads the notebook-level ``color_classifier``, ``transform`` and ``batch_size``.

    Args:
        indices: forwarded to the generator as ``image_indices``.
        infinite: forwarded unchanged to the generator's ``infinite`` option.

    Returns:
        Tuple ``(data_set, loader)``.
    """
    image_size = 128
    generator_options = dict(
        color_classifier=color_classifier,
        image_indices=indices,
        transform=transform,
        color_range=(5, 255),
        noise_size=(1, 9),
        num_classes=3,
        infinite=infinite,
        size=image_size,
        num_objects=0,
        radius=(image_size // 8, image_size // 7),
    )
    data_set = ColorDatasetGenerator(**generator_options)

    loader = torch.utils.data.DataLoader(
        data_set,
        batch_size=batch_size,
        shuffle=True,
        num_workers=6,
        pin_memory=True,
    )
    return data_set, loader

# Build the (dataset, loader) pair for each split, in train / valid / test order.
# All three use the helper's default infinite=False.
(train_set, train_loader), (valid_set, valid_loader), (test_set, test_loader) = (
    set_loader_helper(split_indices)
    for split_indices in (train_indices, valid_indices, test_indices)
)

In [None]:
# Determine a good weight decay and gain parameter by training one model per
# (gain, weight_decay) configuration and evaluating it on the test loader.
#
# `curves` holds the value returned by train() for each configuration trained in
# this run; `results` holds the value returned by evaluate().
curves = {}

# `results` is deliberately not reset here, so configurations that already have an
# entry (from an earlier run of this cell, or from a pasted-in literal) are skipped.
# It is created only when missing, so the cell also works on a fresh kernel instead
# of raising NameError.
try:
    results
except NameError:
    results = {}

for gain in [0.1]:#, 0.4]: #0.01]:
    for weight_decay in list(10**np.linspace(-5, 1, 12)) + [1e-6, 3e-6]:
        # Skip configurations that already have a recorded result; the weight decay
        # is matched with a tolerance because it is a computed float.
        if any(k[0] == gain and np.isclose(k[1], weight_decay) for k in results):
            print("Done gain", gain, "weight decay", weight_decay, "continuing...")
            continue

        # set up model
        large_net = ResNet([[16, 3, 1],  # num_channels (input and output), kernel_size, stride
                            [32, 3, 1]], 3, [128, 128, 1],
                           f"decay/large_{weight_decay:.7f}_{gain}.dict",
                           global_avg_pooling=True, fc_layers=[]).to(device)
        # NOTE(review): the original comment "dont start from same initialization" sat on
        # the next line; presumably it refers to building and initialising a fresh model
        # for every configuration -- confirm.
        loss_func = nn.CrossEntropyLoss()
        optim = torch.optim.Adam(large_net.parameters(), weight_decay=weight_decay)
        set_initializers(large_net, gain)

        # train and evaluate
        print("Training with gain", gain, "weight decay", weight_decay)
        curve = train(large_net, optim, loss_func, 25, train_loader, valid_loader, device=device)
        large_net.load_model_state_dict()  # load the best model found over training
        result = evaluate(large_net, loss_func, test_loader, device=device)

        # record result
        curves[(gain, weight_decay)] = curve
        results[(gain, weight_decay)] = result

In [None]:
# Dump every (gain, weight_decay) -> result entry collected so far.
print(results)

In [None]:
# Hard-coded snapshot of sweep results, keyed by (gain, weight_decay).
# Each value is the tuple returned by evaluate() (presumably (loss, accuracy) -- confirm).
# Running this cell replaces whatever `results` currently holds in memory.
results = {
    (0.1, 1e-05): (99.80202204734087, 0.94754),
    (0.1, 3.511191734215127e-05): (103.67627491801977, 0.94252),
    (0.1, 0.0001232846739442066): (102.99732363969088, 0.94883),
    (0.1, 0.0004328761281083057): (118.29680179059505, 0.93689),
    (0.1, 0.0015199110829529332): (149.95270317792892, 0.9212),
    (0.1, 0.005336699231206307): (194.24484246224165, 0.90363),
    (0.1, 0.01873817422860383): (196.70017708837986, 0.92232),
    (0.1, 0.06579332246575675): (300.3072998225689, 0.89403),
    (0.1, 0.2310129700083158): (622.37857401371, 0.62512),
    (0.1, 0.8111308307896856): (832.7933620214462, 0.44141),
}

In [None]:
# Arrange the first component of every recorded result on a (gain x weight decay)
# grid and display it. Grid cells without a recorded result stay at 0.
plot_gains = [0.1, 0.01]
plot_weight_decays = 10**np.linspace(-5, 1, 12)

# Shape is derived from the grids so the two cannot drift apart.
results_arr = np.zeros((len(plot_gains), len(plot_weight_decays)))
for i, gain in enumerate(plot_gains):
    for j, weight_decay in enumerate(plot_weight_decays):
        # Match the weight decay with a tolerance (as the sweep cell does) rather than
        # by exact dict key, so a pasted or rounded `results` literal still lines up
        # with the computed grid values.
        for (done_gain, done_decay), result in results.items():
            if done_gain == gain and np.isclose(done_decay, weight_decay):
                results_arr[i, j] = result[0]
                break
imshow_centered_colorbar(results_arr)
