In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np
import time
import os
from color_regions import ColorDatasetGenerator
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
# Let cuDNN benchmark candidate convolution algorithms and cache the fastest one.
torch.backends.cudnn.benchmark = True
if torch.cuda.is_available():
    device = "cuda:0"  # first GPU
else:
    device = "cpu"

In [None]:
import sys
prev_time = 0
gamma = 0.99
stats = {}  # point name -> exponentially weighted moving average (EWMA) of elapsed seconds


def benchmark(point=None, profile=True, verbose=True, cuda=True):  # not thread safe at all
    """Record the wall-clock time elapsed since the previous ``benchmark`` call.

    Called without ``point`` it only resets the reference time. With ``point`` it
    updates an EWMA (weight ``1 - gamma`` on the new sample) of the elapsed time,
    stored in the module-level ``stats`` dict under ``"<caller function>-<point>"``.

    Args:
        point: label for this checkpoint, or None to just reset the timer.
        profile: if False, do nothing at all.
        verbose: print the raw elapsed time and the running EWMA.
        cuda: synchronize CUDA first so queued GPU work is included in the
            measurement; skipped automatically when CUDA is not available.
    """
    global prev_time
    if not profile:
        return
    # torch.cuda.synchronize() raises on CPU-only setups, which made ProfileExecution
    # (which relies on the default cuda=True) unusable without a GPU.
    if cuda and torch.cuda.is_available():
        torch.cuda.synchronize()
    time_now = time.perf_counter()
    if point is not None:
        point = f"{sys._getframe().f_back.f_code.co_name}-{point}"
        time_taken = time_now - prev_time
        # The first sample seeds the average, so the EWMA starts at that sample.
        previous = stats.get(point, time_taken)
        stats[point] = previous * gamma + time_taken * (1 - gamma)
        if verbose:
            print(f"took {time_taken} to reach {point}, ewma={stats[point]}")
    prev_time = time_now

In [None]:
# Images are only converted to tensors. Normalisation was tried and left disabled:
#     transforms.Normalize((0.5), (0.5))
transform = transforms.Compose([transforms.ToTensor()])

batch_size = 512  # reported by the author as the fastest batch size
# (start, end) image-index ranges for each split
train_indices, valid_indices, test_indices = (
    (0, 250_000),            # training split (size of training set)
    (1_250_000, 1_270_000),  # validation split
    (2_260_000, 2_270_000),  # test split
)

def color_classifier(color):
    """Map a colour value to one of 3 classes using fixed 30-wide colour bands.

    Bands are (-inf, 30], (30, 60], ..., (210, 240], (240, inf) with labels
    0, 1, 2, 1, 0, 1, 2, 0, 2 -- a non-monotonic ("hard") mapping. Over the integer
    colours 0..255 the classes cover 91 / 90 / 75 values respectively.
    Inputs that compare False against every bound (e.g. NaN) yield None.
    """
    band_upper_bounds = (30, 60, 90, 120, 150, 180, 210, 240)
    band_classes = (0, 1, 2, 1, 0, 1, 2, 0)
    for upper, label in zip(band_upper_bounds, band_classes):
        if color <= upper:
            return label
    if color > 240:
        return 2
    return None


# Lower edges of the colour bands; used to draw decision boundaries in plots.
critical_color_values = list(range(0, 241, 30))

def set_loader_helper(indices):
    """Build a ColorDatasetGenerator over ``indices`` and a shuffling DataLoader for it.

    Args:
        indices: (start, end) range of image indices for this split.

    Returns:
        (dataset, loader) tuple.
    """
    img_side = 128
    dataset_kwargs = dict(
        color_classifier=color_classifier,
        image_indices=indices,
        transform=transform,
        color_range=(5, 255),
        noise_size=(1, 9),
        num_classes=3,
        size=img_side,
        radius=(img_side // 6, img_side // 3),
    )
    data_set = ColorDatasetGenerator(**dataset_kwargs)
    loader = torch.utils.data.DataLoader(
        data_set,
        batch_size=batch_size,
        shuffle=True,
        num_workers=6,
        pin_memory=True,
    )
    return data_set, loader


train_set, train_loader = set_loader_helper(train_indices)
valid_set, valid_loader = set_loader_helper(valid_indices)
test_set, test_loader = set_loader_helper(test_indices)

In [None]:
# the "hard" task
color_probe = np.linspace(0, 255, 255)
color_class = [color_classifier(x) for x in color_probe]
plt.plot(color_probe, color_class)

In [None]:
class ResNet(nn.Module):
    """Small residual CNN classifier.

    Each entry of ``conv_layers`` is ``[out_channels, kernel_size, stride]`` and builds
    one block of two convolutions (conv -> BN -> ReLU -> conv -> BN -> ReLU), following
    Figure 3 of https://arxiv.org/pdf/1512.03385.pdf. A block is residual (its input is
    added to its output) only when it keeps the spatial size (stride 1) and the channel
    count. The flattened feature map is passed through fully connected layers of widths
    ``fc_layers`` and a final layer with ``num_classes`` outputs.

    Args:
        conv_layers: list of ``[out_channels, kernel_size, stride]``. A float
            ``out_channels`` is truncated to int and rounded down to a multiple of
            ``groups``. The caller's list is not modified.
        num_classes: number of outputs. With 1 output, ``forward`` applies a sigmoid
            unless ``logits=True``.
        img_shape: ``(height, width, channels)``; ``img_shape[0]`` (square images
            assumed) and ``img_shape[-1]`` are read.
        path: default checkpoint path for ``save_model_state_dict`` / ``load_model_state_dict``.
        fc_layers: hidden widths of the fully connected head (not modified).
        groups: ``groups`` argument passed to every ``nn.Conv2d``.
    """

    def __init__(self, conv_layers, num_classes, img_shape, path, fc_layers=(1000,), groups=1):
        super().__init__()
        self.is_resid = []
        self.path = path
        self.num_classes = num_classes
        channels = img_shape[-1]
        img_size = img_shape[0]
        conv_layers1, conv_layers2 = [], []  # entry conv / second conv of each residual block
        batch_norms1, batch_norms2 = [], []
        for layer in conv_layers:
            out_channels, kernel_size, stride = layer[0], layer[1], layer[2]
            if stride > 1:
                # padding="same" is not supported for strided convs; use "valid" and
                # track the shrinking size (https://arxiv.org/pdf/1603.07285.pdf)
                pad_type = "valid"
                img_size = (img_size - kernel_size) // stride + 1
            else:
                pad_type = "same"
            if isinstance(out_channels, float):
                out_channels = int(out_channels)
                out_channels -= out_channels % groups  # ensure divisible by groups
            self.is_resid.append(stride == 1 and channels == out_channels)
            conv_layers1.append(nn.Conv2d(channels, out_channels, kernel_size, stride=stride,
                                          padding=pad_type, groups=groups))
            channels = out_channels
            # track_running_stats=False is avoided: it performs poorly at inference with
            # batch size 1 (or the same image repeated many times).
            batch_norms1.append(nn.BatchNorm2d(channels))
            batch_norms2.append(nn.BatchNorm2d(channels))
            conv_layers2.append(nn.Conv2d(channels, channels, kernel_size, stride=1,
                                          padding="same", groups=groups))
        # Computed after the loop so an empty conv_layers still works.
        self.final_num_logits = channels * img_size * img_size
        self.conv_layers1 = nn.ModuleList(conv_layers1)
        self.conv_layers2 = nn.ModuleList(conv_layers2)
        self.batch_norms1 = nn.ModuleList(batch_norms1)
        self.batch_norms2 = nn.ModuleList(batch_norms2)

        # Build a new list instead of insert/append on fc_layers, which mutated the
        # caller's list and the shared default (later instances got extra layers).
        widths = [self.final_num_logits, *fc_layers, num_classes]
        self.fully_connected = nn.ModuleList(
            nn.Linear(w_in, w_out) for w_in, w_out in zip(widths, widths[1:]))

    def forward(self, x, logits=False):
        """Classify a NCHW batch; returns logits, or sigmoid probabilities when
        ``num_classes == 1`` and ``logits`` is False."""
        network_iter = zip(self.conv_layers1, self.conv_layers2, self.batch_norms1,
                           self.batch_norms2, self.is_resid)
        for conv1, conv2, batch_norm1, batch_norm2, is_resid in network_iter:
            x_conv1 = F.relu(batch_norm1(conv1(x)))
            x_conv2 = F.relu(batch_norm2(conv2(x_conv1)))
            if is_resid:
                x = x + x_conv2  # residual block
            else:
                x = x_conv2  # dimension-changing block
        x = torch.flatten(x, 1)
        for i, fc_layer in enumerate(self.fully_connected):
            x = fc_layer(x)
            if i != len(self.fully_connected) - 1:  # no ReLU on the output layer
                x = F.relu(x)
        if self.num_classes == 1 and not logits:  # logits=True always returns raw logits
            x = torch.sigmoid(x)
        return x

    def num_params(self):
        """Number of trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def save_model_state_dict(self, path=None, optim=None):
        """Save weights (and optimizer state as {"model", "optim"} when ``optim`` is given)."""
        if path is None:
            path = self.path
        if optim is not None:
            save_dict = {"model": self.state_dict(), "optim": optim.state_dict()}
        else:
            save_dict = self.state_dict()
        torch.save(save_dict, path)

    def load_model_state_dict(self, path=None, optim=None):
        """Load a checkpoint written by ``save_model_state_dict``; no-op if the file is missing."""
        if path is None:
            path = self.path
        if not os.path.exists(path):
            return
        # Map to CPU so GPU-saved checkpoints load on CPU-only machines; load_state_dict
        # then copies tensors onto the devices of this model's (and optimizer's) params.
        load_dict = torch.load(path, map_location="cpu")
        if "model" in load_dict:
            if optim is not None:
                optim.load_state_dict(load_dict["optim"])
            self.load_state_dict(load_dict["model"])
        else:
            self.load_state_dict(load_dict)

In [None]:
class ProfileExecution(nn.Module):
    """Wrap a model so every forward pass records per-layer timings via ``benchmark``.

    A forward hook is attached to each grandchild module of ``model`` (e.g. each layer
    inside a ModuleList attribute). After that layer runs, its hook calls
    ``benchmark("<child>_<grandchild>", verbose=False)``.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
        self.handles = [
            layer.register_forward_hook(self.benchmark_hook(f"{group_name}_{layer_name}"))
            for group_name, group in self.model.named_children()
            for layer_name, layer in group.named_children()
        ]

    def benchmark_hook(self, name):
        """Return a forward hook that records a ``benchmark`` point called ``name``."""
        def _hook(_module, _inputs, _output):
            benchmark(name, verbose=False)
        return _hook

    def clean_up(self):
        """Detach every hook registered in ``__init__``."""
        for h in self.handles:
            h.remove()

    def forward(self, *args):
        benchmark()  # reset the reference time before the first layer runs
        return self.model(*args)

class AllActivations(nn.Module):
    """Wrap a model so ``forward`` returns the outputs of all its grandchild layers.

    A forward hook on each grandchild module (e.g. each layer inside a ModuleList
    attribute) stores that layer's output under ``"<child>_<grandchild>"``. ``forward``
    returns the same dict object on every call, refreshed with the latest outputs.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
        self._features = {}
        self.handles = []

        for list_name, mod_list in self.model.named_children():
            for mod_name, mod in mod_list.named_children():
                name = f"{list_name}_{mod_name}"
                self.handles.append(mod.register_forward_hook(self.save_activations_hook(name)))

    def save_activations_hook(self, name):
        """Return a forward hook that stores the layer output under ``name``."""
        def fn(layer, inpt, output):
            self._features[name] = output
        return fn

    def clean_up(self):
        """Remove all registered hooks."""
        # The loop variable was previously `hand` while the body used `handle` (NameError).
        for handle in self.handles:
            handle.remove()
        self.handles.clear()

    def forward(self, *args):
        self.model(*args)
        return self._features

In [None]:
def correct(pred_logits, labels):
    """Per-sample correctness mask for a batch of predictions.

    Args:
        pred_logits: model outputs of shape (N, C). For C > 1 these are class scores
            (argmax is taken directly; softmax would not change it). For the single
            output case they are expected to be sigmoid probabilities, as returned by
            ``ResNet.forward`` with ``logits=False``.
        labels: targets of shape (N, C): one-hot rows when C > 1, 0/1 values when C == 1.

    Returns:
        Bool tensor: shape (N,) for multi-class, (N, 1) for the single-output case.
    """
    if labels.shape[1] != 1:
        predicted = torch.argmax(pred_logits, dim=1)
        target = torch.argmax(labels, dim=1)
    else:
        # .int() truncated every probability below 1.0 to 0; threshold at 0.5 instead.
        predicted = (pred_logits > 0.5).to(labels.dtype)
        target = labels
    return target == predicted

def train(net, optimizer, loss, epochs, tr_loader=None, va_loader=None):
    """Train a classifier, evaluating after each epoch and checkpointing the best model.

    Args:
        net: model exposing ``save_model_state_dict(optim=...)``.
        optimizer: torch optimizer over ``net``'s parameters.
        loss: callable ``loss(outputs, labels)``.
        epochs: number of epochs.
        tr_loader, va_loader: loaders yielding dicts with "image" and "label";
            default to the module-level ``train_loader`` / ``valid_loader``.

    Returns:
        (va_losses, va_accuracies, tr_losses), one entry per epoch. Losses are sums of
        per-batch losses; accuracy is correct / number of validation samples seen.
    """
    tr_loader = train_loader if tr_loader is None else tr_loader
    va_loader = valid_loader if va_loader is None else va_loader
    va_losses = []
    tr_losses = []
    va_accuracies = []
    for epoch in range(epochs):
        epoch_tr_loss = 0.0
        net.train()
        # tqdm over the loader itself (not enumerate(...)) so it knows the total length
        for sample in tqdm(tr_loader):
            imgs = sample["image"].to(device, non_blocking=False).float()
            labels = sample["label"].to(device).float()
            optimizer.zero_grad()
            batch_loss = loss(net(imgs), labels)
            epoch_tr_loss += batch_loss.item()
            batch_loss.backward()
            optimizer.step()
        epoch_va_loss = 0.0
        epoch_va_correct = 0
        epoch_va_seen = 0
        net.eval()
        with torch.no_grad():
            for sample in va_loader:
                imgs = sample["image"].to(device).float()
                labels = sample["label"].to(device).float()
                outputs = net(imgs)
                epoch_va_loss += loss(outputs, labels).item()
                epoch_va_correct += correct(outputs, labels).sum().item()
                epoch_va_seen += labels.shape[0]
        # Divide by the samples actually evaluated, not the global index range.
        epoch_va_accuracy = epoch_va_correct / max(epoch_va_seen, 1)
        print(f'Epoch {epoch + 1}: va_loss: {epoch_va_loss}, va_accuracy: {epoch_va_accuracy}, tr_loss: {epoch_tr_loss}')
        if not va_losses or epoch_va_loss < min(va_losses):
            net.save_model_state_dict(optim=optimizer)
        va_losses.append(epoch_va_loss)
        tr_losses.append(epoch_tr_loss)
        va_accuracies.append(epoch_va_accuracy)
    return va_losses, va_accuracies, tr_losses

In [None]:
small_net = ResNet(
    conv_layers=[[16, 3, 2],  # [out_channels, kernel_size, stride]
                 [32, 3, 2],
                 [64, 3, 2]],
    num_classes=3,
    img_shape=[128, 128, 1],  # (height, width, channels)
    path="small_net_noise_hard_grey.dict",
    fc_layers=[32],
).to(device)
loss_func = nn.CrossEntropyLoss()
small_optim = torch.optim.Adam(small_net.parameters())
print(small_net.num_params())
small_net.load_model_state_dict(optim=small_optim)  # resume if the checkpoint file exists

In [None]:
results = train(net=small_net, optimizer=small_optim, loss=loss_func, epochs=1000)

In [None]:
tiny_net = ResNet(
    conv_layers=[[2, 3, 4],  # [out_channels, kernel_size, stride]
                 # [4, 3, 2],  (disabled middle layer)
                 [6, 3, 4]],
    num_classes=3,
    img_shape=[128, 128, 1],  # (height, width, channels)
    path="tiny_net_noise_hard_grey.dict",
    fc_layers=[],  # no hidden FC layers: conv features feed the output layer directly
).to(device)
loss_func = nn.CrossEntropyLoss()
tiny_optim = torch.optim.Adam(tiny_net.parameters())
print(tiny_net.num_params())
tiny_net.load_model_state_dict(optim=tiny_optim)  # resume if the checkpoint file exists

In [None]:
results = train(net=tiny_net, optimizer=tiny_optim, loss=loss_func, epochs=1000)

In [None]:
@torch.no_grad()
def rate_distribution(net, loader, dataset, buckets=100):
    """Count correct and total predictions per colour bucket.

    The colour range ``dataset.color_range = (lo, hi)`` is split into ``buckets``
    equal-width buckets; bucket k holds colours starting at ``lo + k*(hi-lo)/buckets``.

    Args:
        net: classifier; put into eval mode here.
        loader: yields dicts with "image", "label" and "color" (one colour value per
            sample assumed -- TODO confirm against ColorDatasetGenerator).
        dataset: provides ``color_range``.
        buckets: number of colour buckets.

    Returns:
        (num_correct, total): float arrays of length ``buckets``.
    """
    net.eval()
    total = np.zeros(buckets)
    num_correct = np.zeros(buckets)
    color_lo, color_hi = dataset.color_range[0], dataset.color_range[1]
    num_possible_colors = color_hi - color_lo
    for sample in tqdm(loader):
        imgs = sample["image"].to(device).float()
        labels = sample["label"].to(device).float()
        actual_colors = sample["color"]
        color_indices = (buckets * (actual_colors - color_lo) / num_possible_colors).int().numpy().reshape(-1)
        # A colour equal to color_range[1] maps to index `buckets` (IndexError) and colours
        # below color_range[0] would wrap to the end; clamp into the valid range.
        color_indices = np.clip(color_indices, 0, buckets - 1)
        # Flatten so the single-output (N, 1) mask lines up with the per-sample indices.
        correct_preds = correct(net(imgs), labels).cpu().numpy().reshape(-1)
        np.add.at(total, color_indices, 1)
        np.add.at(num_correct, color_indices, correct_preds)
    return num_correct, total
_num_correct, _total = rate_distribution(small_net, valid_loader, valid_set)

In [None]:
def make_graph(num_correct, total, dataset, critical_values=(150 - 0.5,), buckets=100):
    """Stacked bar chart of correct vs. wrong predictions per colour bucket.

    Args:
        num_correct, total: per-bucket counts, as returned by ``rate_distribution``.
        dataset: only ``dataset.color_range`` (lo, hi) is read, to place the buckets.
        critical_values: x positions of the dashed decision-boundary lines.
        buckets: number of buckets; must match the binning used for the counts.
    """
    num_wrong = total - num_correct
    width = 0.4
    color_lo, color_hi = dataset.color_range[0], dataset.color_range[1]
    # Bucket k starts at lo + k*(hi-lo)/buckets (the binning of rate_distribution).
    # np.linspace(lo, hi, buckets) drifted right by up to one bucket width.
    positions = color_lo + np.arange(buckets) * (color_hi - color_lo) / buckets
    plt.bar(positions, num_correct, width, label="correct amount")
    plt.bar(positions, num_wrong, width, bottom=num_correct, label="wrong amount")
    # Span the full bar height (a 5 .. -5 range is invisible next to the counts).
    plt.vlines(critical_values, 0, np.max(total),
               linewidth=0.8, colors="r", label="decision boundary",
               linestyles="dashed")
    plt.legend()
    plt.xlabel("Color value")
    plt.ylabel("Number of images")
    plt.show()
make_graph(_num_correct, _total, valid_set,   # with .eval(), looks good
           critical_values=[x-0.5 for x in critical_color_values])

In [None]:
small_net.eval()  # very important! (BatchNorm must use its running statistics)
with torch.no_grad():
    # Results seem pretty dependent on the image, especially in small-colour regimes,
    # probably because of bad BatchNorm estimates.
    test_index = 987_652
    counterfactual_color_values = np.linspace(0, 255, 255)
    responses = []
    for color in counterfactual_color_values:
        # Reseed so each iteration regenerates the same image with only its colour
        # changed (assuming generate_one draws from np.random).
        np.random.seed(test_index)
        generated_img, lbl, *__ = valid_set.generate_one(set_color=color)
        generated_img = np.expand_dims(generated_img, 0).transpose(0, 3, 1, 2)  # HWC -> NCHW
        generated_img = torch.tensor(generated_img).to(device).float()
        # Already a tensor on `device`; wrapping it in torch.tensor() again only warns and copies.
        response = small_net(generated_img, logits=True).cpu().numpy()
        responses.append(np.squeeze(response))

In [None]:
def plot_responses(resp, colors, title):
    """Plot network outputs (arcsinh-compressed) against the counterfactual colour value.

    Args:
        resp: sequence with one entry per colour; each entry is a 1-D array holding
            one value per plotted output.
        colors: colour values, same length as ``resp``.
        title: plot title.
    """
    # arcsinh keeps the sign and compresses large magnitudes while staying ~linear near 0.
    resp = np.arcsinh(np.array(resp))
    for output_logit in range(resp.shape[1]):
        plt.plot(colors, resp[:, output_logit], label=f"class {output_logit}")
    plt.vlines(critical_color_values, np.min(resp), np.max(resp), linewidth=0.8,
               colors="r", label="decision boundary",
               linestyles="dashed")  # with .eval() works well
    plt.legend()  # after vlines so the decision-boundary entry is included
    plt.xlabel("Color value")
    plt.ylabel("arcsinh(network output logit)")
    plt.title(title)

In [None]:
plot_responses(resp=responses, colors=counterfactual_color_values, title="")
# Hypothesis: with a larger training set, the roughness of the response (mean squared
# gradient across the image or across pixel values) should go down; the net might be
# overfitting (cf. deep double descent).
#   -> this could be causing the decrease in quality of the FD maps (to be expected).

In [None]:
@torch.no_grad()
def finite_differences(model, dataset, target_class, stacked_img, locations, channel, unfairness, values_prior):
    """Signed finite-difference slope of the target output, one perturbed pixel per image copy.

    Copy ``b`` of ``stacked_img`` has the pixel ``locations[b]`` in ``channel`` set to
    candidate values. For each candidate, the change of the target output divided by the
    change of the pixel value is clipped to [-30, 30]. The slope with the largest
    magnitude (sign kept) is returned per copy.

    Args:
        model: network, called as ``model(x)``, or ``model(x, logits=True)`` when
            ``dataset.num_classes == 2``.
        dataset: only ``dataset.num_classes`` is read.
        target_class: output column to read (multi-class), or the class whose sign is
            kept (binary: outputs are negated unless target_class == 1).
        stacked_img: float32 numpy batch with channels on axis 1 (as built by
            finite_differences_map: identical copies of one image).
        locations: int array of shape (batch, 2) with a (row, col) per copy.
        channel: channel index being perturbed.
        unfairness: "fair" / "unfair" (handled identically here: the pixel is set to each
            prior value in turn) or "very unfair" (single pass: the pixel is moved to 10
            units on the other side of its nearest value in ``values_prior``).
        values_prior: candidate pixel values / boundaries. None -> ``np.linspace(5, 250,
            batch)``; a list is turned into a (K, 1) column for broadcasting.

    Returns:
        numpy array of shape (batch,) with one slope per copy.
    """
    num_iters = 20 # NOTE(review): unused below; the default prior has stacked_img.shape[0] values, not 20 -- confirm intent
    cuda_stacked_img = torch.tensor(stacked_img).to(device)
    if dataset.num_classes == 2:
        # flip the sign so that a larger value always means more evidence for target_class
        class_multiplier = 1 if target_class == 1 else -1 
        baseline_activations = class_multiplier*model(cuda_stacked_img, logits=True)
    else:
        baseline_activations = model(cuda_stacked_img)[:, target_class]
    largest_slope = np.zeros(stacked_img.shape[0])  # per copy: signed slope with the largest magnitude so far
    # one (copy, channel, row, col) element per copy: copy b perturbs pixel locations[b]
    slices = np.index_exp[np.arange(stacked_img.shape[0]), channel, locations[:, 0], locations[:, 1]]
    if values_prior is None:
        values_prior = np.linspace(5, 250, stacked_img.shape[0]) # uniform distribution assumption; count tied to the batch size
    elif isinstance(values_prior, list):
        # (K, 1) column so "very unfair" can broadcast K boundaries against the batch
        values_prior = np.expand_dims(np.asarray(values_prior), 1)
    num_loops = 1 if unfairness == "very unfair" else len(values_prior)
    for i in range(num_loops):
        shift_img = stacked_img.copy()
        # shifting method
        # NOTE(review): any other `unfairness` value leaves shift_img unchanged, so
        # actual_diffs is 0 and the division below yields inf/nan.
        if unfairness in ["fair", "unfair"]:
            shift_img[slices] = values_prior[i]+0.01  # add tiny offset to "guarantee" non-zero shift
        elif unfairness in ["very unfair"]:
            # NOTE(review): with values_prior=None (1-D, batch-length) this subtraction is
            # elementwise per copy rather than all-boundaries-vs-all-pixels.
            critical_value_dists = shift_img[slices] - values_prior
            closest = np.argmin(abs(critical_value_dists), axis=0) # find closest class boundary
            # jump 10 units across the nearest boundary (to the side opposite the pixel)
            shift_img[slices] = 0.01 + np.choose(closest, values_prior) - 10*np.sign(np.choose(closest, critical_value_dists))
        
        actual_diffs = shift_img[slices] - stacked_img[slices]  
        img_norm = torch.tensor(shift_img).to(device) # best is no normalization anyway
        if dataset.num_classes == 2:
            activations = class_multiplier*model(img_norm, logits=True)
        else:
            activations = model(img_norm)[:, target_class]
        # NOTE(review): squeeze() turns a batch of 1 into a 0-d array
        activation_diff = (activations - baseline_activations).cpu().numpy().squeeze()
        finite_difference = np.clip(activation_diff/actual_diffs, -30, 30) # signed slope, clipped to [-30, 30]
        # keep whichever slope has the larger magnitude, preserving its sign
        largest_slope = np.where(abs(finite_difference) > abs(largest_slope), finite_difference, largest_slope)
    return largest_slope      

def finite_differences_map(model, dataset, target_class, img, unfairness="fair", values_prior=None):
    """Saliency map from per-pixel finite differences (iterates over candidate colours).

    Pixels are probed ``batch_size`` at a time with ``finite_differences``, one pixel per
    copy of the image, separately for every channel.

    Args:
        model, target_class, unfairness, values_prior: forwarded to ``finite_differences``.
        dataset: must provide ``size`` (image side) and ``channels``; also forwarded.
        img: image in HWC layout (it is transposed to NCHW below); square of side
            ``dataset.size`` assumed.

    Returns:
        Float array shaped like ``img`` holding the slope per pixel and channel.
    """
    model.eval()
    batch_size = 32  # number of pixel positions probed in parallel
    img = np.asarray(img)
    im_size = dataset.size
    num_pixels = im_size * im_size
    rows = np.repeat(np.arange(im_size), im_size)
    cols = np.tile(np.arange(im_size), im_size)
    indices = np.stack((rows, cols), axis=1)
    stacked_img = np.repeat(np.expand_dims(img, 0), batch_size, axis=0)
    stacked_img = np.transpose(stacked_img, (0, 3, 1, 2)).astype(np.float32)  # NCHW format
    # zeros_like(img) inherited an integer dtype from uint8 images and truncated the slopes.
    img_heat_map = np.zeros(img.shape, dtype=np.promote_types(img.dtype, np.float32))
    for channel in range(dataset.channels):
        for k in tqdm(range(0, num_pixels, batch_size)):
            locations = indices[k:k + batch_size]
            n_valid = len(locations)
            if n_valid < batch_size:
                # finite_differences perturbs one pixel per stacked copy, so a short final
                # batch must be padded (repeat the last location) and the extras dropped.
                padding = np.repeat(locations[-1:], batch_size - n_valid, axis=0)
                locations = np.concatenate((locations, padding))
            largest_slopes = finite_differences(model, dataset, target_class, stacked_img, locations,
                                                channel, unfairness, values_prior)
            img_heat_map[locations[:n_valid, 0], locations[:n_valid, 1], channel] = largest_slopes[:n_valid]
    return img_heat_map  # .sum(axis=2) would be a linear-approximation aggregation

In [None]:
# Seed NumPy's global RNG so the example image is reproducible
# (assuming generate_one draws from np.random).
np.random.seed(500_001)
explain_img, label, *_ = valid_set.generate_one()
heat_map = finite_differences_map(
    model=small_net, dataset=valid_set, target_class=label.argmax(), img=explain_img)

In [None]:
def plt_grid_figure(grid, titles=None, colorbar=True, cmap=None, transpose=False, hspace=-0.4):
    """Show a grid of greyscale images and saliency maps.

    Args:
        grid: nested sequence (rows x cols) of 2-D images (singleton dims are squeezed);
            a single row may be passed as a flat list of images.
        titles: optional per-column titles, shown on the first row.
        colorbar: add a colorbar next to every non-first-column image.
        cmap: colormap for maps (default "bwr", symmetric around 0); "gray" disables
            the symmetric limits.
        transpose: swap grid rows and columns.
        hspace: vertical spacing between rows (negative tightens the layout).

    The first column is always drawn in grey (it is assumed to hold the explained image).
    """
    np_grid = np.array(grid).squeeze()
    if np_grid.ndim != 4:
        np_grid = np.expand_dims(np_grid, 0)
    if transpose:
        # Swap only the grid axes; `.T` also reversed the image axes and broke the grid shape.
        np_grid = np.swapaxes(np_grid, 0, 1)
    if cmap is None:
        cmap = "bwr"
    nrows, ncols = np_grid.shape[0], np_grid.shape[1]
    im_size = np_grid.shape[2]
    fig = plt.figure(figsize=(4 / 128 * im_size * ncols, 5 / 128 * im_size * nrows))
    gridspec = fig.add_gridspec(nrows, ncols, hspace=hspace)
    # squeeze=False always gives a 2-D array of axes (also for 1xN, Nx1 and 1x1 grids).
    axes = gridspec.subplots(sharex="col", sharey="row", squeeze=False)
    for i, row in enumerate(np_grid):
        for j, img in enumerate(row):
            ax = axes[i, j]
            if j == 0:  # assume the explained image is the first column
                ax.imshow(img, cmap="gray")
            else:
                if cmap != "gray":
                    img_max = np.max(abs(img))
                    im = ax.imshow(img, cmap=cmap, interpolation="nearest", vmax=img_max, vmin=-img_max)
                else:
                    im = ax.imshow(img, cmap=cmap)  # previously not assigned -> stale colorbar
                if colorbar:
                    # ax= attaches the colorbar to this image's axes, not the current axes
                    plt.colorbar(im, ax=ax, pad=0, fraction=0.048)
            if titles and i == 0:
                ax.set_title(titles[j])
    plt.show()

In [None]:
plt_grid_figure(grid=[explain_img, heat_map], titles=["Image", "FD Map"])

In [None]:
# One reproducible validation image per id (seeding NumPy's global RNG before each draw).
image_ids = [20_000, 25_000, 30_000, 600_000, 600_001]
heat_maps = []
explain_imgs = []
for image_id in image_ids:  # the enumerate() index was never used
    np.random.seed(image_id)
    explain_img_i, target_i, *__ = valid_set.generate_one()
    heat_map_i = finite_differences_map(small_net, valid_set, target_i.argmax(), explain_img_i)
    heat_maps.append(heat_map_i)
    explain_imgs.append(explain_img_i)

In [None]:
plt_grid_figure(grid=[explain_imgs, heat_maps], colorbar=True, transpose=True)

In [None]:
def _pixels_test_image(test_index=987_652):
    """Regenerate the fixed validation image used by the pixel-response experiments.

    Reseeds NumPy's global RNG with ``test_index`` first, so the same image comes back
    on every call (assuming ``generate_one`` draws from ``np.random``).

    Returns:
        (image as a 1xCxHxW numpy array, label, circle size, circle position)
    """
    np.random.seed(test_index)
    generated_img, lbl, color, size, pos = valid_set.generate_one()
    generated_img = np.expand_dims(generated_img, 0).transpose(0, 3, 1, 2)
    return generated_img, lbl, size, pos


def _sweep_pixel_colors(generated_img, selected_pixels, lbl, one_class):
    """Set ``selected_pixels`` (channel 0) to each colour in 0..255 and collect small_net's logits.

    ``generated_img`` is modified in place. With ``one_class`` only the logit of the
    labelled class (``lbl.argmax()``) is kept, as a length-1 array per colour.

    Returns:
        (responses, counterfactual_color_values)
    """
    counterfactual_color_values = np.linspace(0, 255, 255)
    rows, cols = selected_pixels[:, 0], selected_pixels[:, 1]
    responses = []
    with torch.no_grad():
        for color in counterfactual_color_values:
            generated_img[0, 0, rows, cols] = color
            tensor_img = torch.tensor(generated_img).to(device).float()
            response = small_net(tensor_img, logits=True).cpu().numpy()
            if one_class:
                responses.append(np.expand_dims(np.squeeze(response[:, lbl.argmax()]), 0))
            else:
                responses.append(np.squeeze(response))
    return responses, counterfactual_color_values


def random_pixels_response(num_pixels, one_class=True):
    """Plot small_net's response while ``num_pixels`` random pixels sweep through all colours.

    Also prints the percentage of the chosen pixels that lie inside the image's circle.
    """
    small_net.eval()  # very important! (BatchNorm must use its running statistics)
    generated_img, lbl, size, pos = _pixels_test_image()
    np.random.seed(int(time.time() / np.pi))  # different pixel choice on every call
    selected_pixels = np.random.randint(0, valid_set.size, (num_pixels, 2))
    num_inside = sum(1 for p in selected_pixels if np.linalg.norm(p - pos) < size)
    print(f"Percent of random inside circle: {num_inside/num_pixels*100.}")
    responses, colors = _sweep_pixel_colors(generated_img, selected_pixels, lbl, one_class)
    plot_responses(responses, colors, "Randomly selected pixels")


def circle_pixels_response(num_pixels, one_class=True):
    """Plot small_net's response while ``num_pixels`` pixels inside the circle sweep through all colours."""
    small_net.eval()  # very important! (BatchNorm must use its running statistics)
    generated_img, lbl, size, pos = _pixels_test_image()
    np.random.seed(int(time.time() / np.pi))  # different pixel choice on every call
    angle = np.random.uniform(0, 2 * np.pi, num_pixels)
    radii = np.random.uniform(0, size, num_pixels)
    center_row, center_col = pos[0][0], pos[0][1]
    selected_pixels = np.stack((center_row + np.cos(angle) * radii,
                                center_col + np.sin(angle) * radii), axis=1)
    selected_pixels = np.round(selected_pixels).astype(np.int64)
    # Keep indices inside the image: a circle touching the border could otherwise yield
    # out-of-range indices (IndexError) or negative ones (silent wrap-around).
    selected_pixels = np.clip(selected_pixels, 0, valid_set.size - 1)
    responses, colors = _sweep_pixel_colors(generated_img, selected_pixels, lbl, one_class)
    plot_responses(responses, colors, "Pixels inside circle")


def both_pixels_response(num_pixels, one_class):
    """Side-by-side plots: pixels inside the circle (left) vs. random pixels (right)."""
    plt.figure(figsize=(10, 4))
    plt.subplot(1, 2, 1)
    circle_pixels_response(num_pixels, one_class=one_class)
    plt.subplot(1, 2, 2)
    random_pixels_response(num_pixels, one_class=one_class)

In [None]:
both_pixels_response(num_pixels=1, one_class=True)
# Notes / ideas:
# - Deeper layers = more ability to represent a step function; a single conv layer
#   should not be able to represent one.
# - Different pixels may have different jobs: display the graph for a bunch of pixels at
#   once. If the theory is correct we would see a lot of clutter, but the average would
#   converge to the decision boundary.
# - Mask out everything but one pixel, and train on that too.
# - A vector-valued (colour) input might be OK.
# - Want this to go beyond edge detection (add colour detection).
#
# Interesting result: even if the network is forced to learn a non-linear decision
# boundary, it seemingly does not do so on a pixel-by-pixel basis. One could build a
# dataset where that really matters, but at that point it is probably not worth it.
# Want a better understanding of how this actually happens: what is the network
# learning, and what can we hypothesise about how nets function in general? It seems
# to take the "average value" somehow and use that for its decisions.

What if we rerun the same experiment but cheat by using a prior over pixel values that we know *should* be informative for the output logit — namely, the values closest to the decision boundaries?

In [None]:
class AutoEncoder(nn.Module):
    """Convolutional autoencoder whose decoder mirrors the encoder stage by stage.

    Each entry of ``conv_layers`` is ``[out_channels, kernel_size, pool]``. An encoder
    stage is conv -> BN -> ReLU -> (MaxPool2d(pool) if pool > 1) -> conv -> BN -> ReLU,
    with a residual connection when ``pool == 1`` and the channel count is unchanged.
    The flattened features pass through fully connected layers of widths ``fc_layers``
    down to an ``embed_size`` code. The decoder applies the mirrored FC layers (each
    followed by ReLU), reshapes, and undoes each stage with ConvTranspose2d layers plus
    MaxUnpool2d using the encoder's pooling indices. The last decoder conv skips its
    BatchNorm but is still ReLU'd, so reconstructions are >= 0.

    Args:
        conv_layers: list of ``[out_channels, kernel_size, pool]``.
        img_shape: ``(height, width, channels)``; ``img_shape[0]`` (square images
            assumed) and ``img_shape[-1]`` are read.
        path: default checkpoint path for save/load.
        fc_layers: hidden widths between the conv features and the code (copied).
        embed_size: size of the latent code.
    """

    def __init__(self, conv_layers, img_shape, path, fc_layers=(), embed_size=128):
        super().__init__()
        self.path = path  # for saving and loading

        enc_layers1 = []
        enc_layers2 = []
        enc_maxpools = []
        enc_batchnorms1 = []
        enc_batchnorms2 = []

        # decoder lists are built in reverse (insert at 0) so dec stage i mirrors
        # encoder stage n-1-i
        dec_layers1 = []
        dec_layers2 = []
        dec_maxpools = []  # MaxUnpool2d layers
        dec_batchnorms1 = []
        dec_batchnorms2 = []

        is_resid = []
        channels = img_shape[-1]
        img_size = img_shape[0]
        for l in conv_layers:  # (out_channels, kernel_size, pool) is each l
            is_resid.append(l[2] == 1 and channels == l[0])

            if l[2] > 1:
                img_size = (img_size - l[2]) // l[2] + 1  # MaxPool2d(kernel=stride=l[2])
                enc_maxpools.append(nn.MaxPool2d(l[2], return_indices=True))
                dec_maxpools.insert(0, nn.MaxUnpool2d(l[2]))
            else:
                enc_maxpools.append(None)
                dec_maxpools.insert(0, None)

            enc_layers1.append(nn.Conv2d(channels, l[0], l[1], padding="same"))
            enc_batchnorms1.append(nn.BatchNorm2d(l[0]))

            dec_layers2.insert(0, nn.ConvTranspose2d(l[0], channels, l[1], padding=l[1] // 2))
            dec_batchnorms2.insert(0, nn.BatchNorm2d(channels))

            channels = l[0]

            enc_layers2.append(nn.Conv2d(channels, channels, l[1], padding="same"))
            enc_batchnorms2.append(nn.BatchNorm2d(channels))

            dec_layers1.insert(0, nn.ConvTranspose2d(channels, channels, l[1], padding=l[1] // 2))
            dec_batchnorms1.insert(0, nn.BatchNorm2d(channels))

        self.final_flat_shape = channels * img_size * img_size
        self.final_img_shape = [channels, img_size, img_size]
        self.embed_size = embed_size

        enc_fully_connected = []
        dec_fully_connected = []
        extended_fc_layers = [self.final_flat_shape, *fc_layers, embed_size]
        for fc_prev, fc_next in zip(extended_fc_layers, extended_fc_layers[1:]):
            enc_fully_connected.append(nn.Linear(fc_prev, fc_next))
            dec_fully_connected.insert(0, nn.Linear(fc_next, fc_prev))

        self.enc_layers1 = nn.ModuleList(enc_layers1)
        self.enc_layers2 = nn.ModuleList(enc_layers2)
        self.enc_maxpools = nn.ModuleList(enc_maxpools)
        self.enc_batchnorms1 = nn.ModuleList(enc_batchnorms1)
        self.enc_batchnorms2 = nn.ModuleList(enc_batchnorms2)
        self.enc_fully_connected = nn.ModuleList(enc_fully_connected)

        self.dec_layers1 = nn.ModuleList(dec_layers1)
        self.dec_layers2 = nn.ModuleList(dec_layers2)
        self.dec_maxpools = nn.ModuleList(dec_maxpools)
        self.dec_batchnorms1 = nn.ModuleList(dec_batchnorms1)
        self.dec_batchnorms2 = nn.ModuleList(dec_batchnorms2)
        self.dec_fully_connected = nn.ModuleList(dec_fully_connected)

        self.enc_is_resid = is_resid
        # A list, not reversed(...): a reversed iterator is exhausted after its first use.
        self.dec_is_resid = list(reversed(is_resid))

        iter_names = ["layers1", "layers2", "maxpools", "batchnorms1", "batchnorms2", "is_resid"]
        self.enc_iter = list(zip(*[getattr(self, "enc_" + name) for name in iter_names]))
        self.dec_iter = list(zip(*[getattr(self, "dec_" + name) for name in iter_names]))

    def net_block(self, x, indices, block_name):
        """Run the encoder (``block_name="enc"``) or the decoder (``"dec"``) half.

        Args:
            x: input images (enc) or latent codes (dec).
            indices: shared list of pooling indices. The encoder prepends one entry per
                stage (None for stages without pooling), so after encoding ``indices[i]``
                belongs to decoder stage ``i``.
            block_name: "enc" or "dec".
        """
        conv_iter = getattr(self, block_name + "_iter")
        fc_iter = getattr(self, block_name + "_fully_connected")

        if block_name == "dec":
            for fc_layer in fc_iter:
                x = F.relu(fc_layer(x))
            x = torch.reshape(x, (-1, *self.final_img_shape))
        last_stage = len(conv_iter) - 1
        for i, (conv1, conv2, maxpool, batch_norm1, batch_norm2, is_resid) in enumerate(conv_iter):
            x_conv1 = F.relu(batch_norm1(conv1(x)))
            if maxpool is not None:
                if block_name == "enc":
                    x_conv1, indices_layer = maxpool(x_conv1)
                    indices.insert(0, indices_layer)
                else:
                    x_conv1 = maxpool(x_conv1, indices[i])
            elif block_name == "enc":
                indices.insert(0, None)  # placeholder keeps indices aligned with decoder stages

            if block_name == "dec" and i == last_stage:
                x_conv2 = F.relu(conv2(x_conv1))  # final reconstruction layer: no BatchNorm
            else:
                x_conv2 = F.relu(batch_norm2(conv2(x_conv1)))

            if is_resid:
                x = x + x_conv2  # residual block
            else:
                x = x_conv2  # dimension-changing block
        if block_name == "enc":
            x = torch.flatten(x, 1)
            for i, fc_layer in enumerate(fc_iter):
                x = fc_layer(x)
                if i != len(fc_iter) - 1:  # no ReLU on the code layer
                    x = F.relu(x)
        return x

    def encode(self, x, indices):  # due to the symmetry of decoding/encoding, we can do this nicely
        """Encode images to codes; appends pooling indices to ``indices``."""
        return self.net_block(x, indices, "enc")

    def decode(self, z, indices):  # => this implementation means the output is ReLU'd
        """Decode codes using the pooling indices recorded by ``encode``."""
        return self.net_block(z, indices, "dec")

    def forward(self, x):
        save_indices = []
        z = self.encode(x, save_indices)  # writes to save_indices
        return self.decode(z, save_indices)  # reads from save_indices

    def num_params(self):
        """Number of trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def save_model_state_dict(self, path=None, optim=None):
        """Save weights (and optimizer state as {"model", "optim"} when ``optim`` is given)."""
        if path is None:
            path = self.path
        if optim is not None:
            save_dict = {"model": self.state_dict(), "optim": optim.state_dict()}
        else:
            save_dict = self.state_dict()
        torch.save(save_dict, path)

    def load_model_state_dict(self, path=None, optim=None):
        """Load a checkpoint written by ``save_model_state_dict``; no-op if the file is missing."""
        if path is None:
            path = self.path
        if not os.path.exists(path):
            return
        # Map to CPU so GPU-saved checkpoints load on CPU-only machines; load_state_dict
        # then copies tensors onto the devices of this model's (and optimizer's) params.
        load_dict = torch.load(path, map_location="cpu")
        if "model" in load_dict:
            if optim is not None:
                optim.load_state_dict(load_dict["optim"])
            self.load_state_dict(load_dict["model"])
        else:
            self.load_state_dict(load_dict)

In [None]:
def autoenc_train(net, optimizer, loss, epochs, train_data=None, valid_data=None):
    """Train ``net`` to reconstruct its input, checkpointing on validation improvement.

    Args:
        net: model called as ``net(imgs)`` that also provides
            ``save_model_state_dict(optim=...)``.
        optimizer: torch optimizer over ``net``'s parameters.
        loss: criterion called as ``loss(outputs, imgs)`` (the target is the input).
        epochs: number of passes over the training data.
        train_data: iterable of dicts with an ``"image"`` tensor; defaults to the
            notebook-level ``train_loader``.
        valid_data: same for validation; defaults to the notebook-level ``valid_loader``.

    Returns:
        (va_losses, tr_losses): per-epoch sums of the batch losses (sums, not
        means, so the two lists scale with the number of batches in each loader).
    """
    if train_data is None:
        train_data = train_loader
    if valid_data is None:
        valid_data = valid_loader
    va_losses = []
    tr_losses = []
    best_va_loss = float("inf")
    for epoch in range(epochs):
        net.train()
        epoch_tr_loss = 0.0
        # Wrap the loader itself (not enumerate(loader)) so tqdm can call len() on it
        # and display a real total / ETA.
        for sample in tqdm(train_data):
            imgs = sample["image"].to(device).float()
            optimizer.zero_grad()
            outputs = net(imgs)  # reconstruction; should be close to the image
            batch_loss = loss(outputs, imgs)
            epoch_tr_loss += batch_loss.item()
            batch_loss.backward()
            optimizer.step()
        net.eval()
        epoch_va_loss = 0.0
        with torch.no_grad():
            for sample in valid_data:
                imgs = sample["image"].to(device).float()
                epoch_va_loss += loss(net(imgs), imgs).item()
        print(f'Epoch {epoch + 1}: va_loss: {epoch_va_loss}, tr_loss: {epoch_tr_loss}')
        # Checkpoint whenever validation loss beats the best so far; keeping a running
        # best avoids rescanning the whole loss history every epoch.
        if epoch_va_loss < best_va_loss:
            best_va_loss = epoch_va_loss
            net.save_model_state_dict(optim=optimizer)
        va_losses.append(epoch_va_loss)
        tr_losses.append(epoch_tr_loss)
    return va_losses, tr_losses

In [None]:
# Build the greyscale convolutional autoencoder. Per the inline note, each inner list
# is [num_channels, kernel_size, max_pool_kernel] for one block. The meaning of the
# remaining positional arguments ([128, 128, 1] and the filename) is defined by
# AutoEncoder.__init__, which is not shown here. Presumably the filename becomes
# self.path, the default checkpoint for save/load_model_state_dict (TODO confirm).
auto_enc = AutoEncoder([[16, 7, 1],  # num_channels (input and output), kernel_size, max_pool kernel
                        [32, 3, 2],  # make sure to change the batch size before working with this
                        [32, 3, 2],
                        [64, 3, 2],
                        [64, 3, 2],
                        [128, 3, 2],
                        [128, 3, 1]], [128, 128, 1], "auto_enc_greyscale_no_norm.dict",
                        fc_layers=[], embed_size=256).to(device)
auto_enc_loss = nn.MSELoss()  # mean squared error over all elements (default reduction="mean")
auto_enc_optim = torch.optim.Adam(auto_enc.parameters())  # Adam with default hyperparameters
print(auto_enc.num_params())  # number of trainable scalar parameters
auto_enc.load_model_state_dict(optim=auto_enc_optim)  # resume from the checkpoint if the file exists

In [None]:
# Train for 200 epochs; `results` is (per-epoch validation losses, per-epoch training losses).
results = autoenc_train(auto_enc, auto_enc_optim, auto_enc_loss, epochs=200)

In [None]:
# Simple sanity check: reconstruct one freshly generated validation image.
auto_enc.eval()
# The transpose below implies generate_one()[0] is (H, W, C) -- TODO confirm.
generated_img = valid_set.generate_one()[0]
tensor_img = torch.tensor(np.expand_dims(generated_img, 0).transpose(0, 3, 1, 2)).to(device).float()
with torch.no_grad():  # inference only: no need to build an autograd graph
    reconstruction = np.expand_dims(auto_enc(tensor_img).cpu().numpy().squeeze(), -1)

In [None]:
# Mean value of the reconstruction next to the original's; similar numbers are a quick sign training worked.
print(*(arr.mean() for arr in (reconstruction, generated_img)))

In [None]:
# Show the original image and its autoencoder reconstruction, in that order, in greyscale.
plt_grid_figure(
    [generated_img.squeeze(), reconstruction.squeeze()],
    cmap="gray",
    colorbar=False,
    transpose=False,
)

In [None]:
def is_indicator(outpts, inputs):
    """Summarize how step-like (indicator-like) a sampled 1-D function is.

    Args:
        outpts: function values, one per entry of ``inputs`` (array-like).
        inputs: sample locations, distinct and ordered. The heuristic assumes
            equal spacing.

    Returns:
        (output_range, flat_fraction, relative_grad_spread):
        - output_range: max(outpts) - min(outpts).
        - flat_fraction: fraction of consecutive steps whose finite-difference
          slope has magnitude below 1e-3. A step function is flat almost
          everywhere, so this is near 1; a line with noticeable slope gives 0.
        - relative_grad_spread: (max|slope| - min|slope|) / output_range, or NaN
          when the output is constant (output_range == 0).

    Raises:
        ValueError: if the two inputs differ in shape or have fewer than 2 samples.
    """
    # Cast to float so unsigned-integer inputs (e.g. uint8 pixel values) cannot
    # wrap around when consecutive samples are subtracted.
    outpts = np.asarray(outpts, dtype=float)
    inputs = np.asarray(inputs, dtype=float)
    if outpts.shape != inputs.shape or outpts.size < 2:
        raise ValueError("is_indicator needs two equal-shaped arrays with at least 2 samples")
    slopes = np.abs(np.diff(outpts) / np.diff(inputs))
    flat_fraction = np.mean(slopes < 1e-3)
    output_range = outpts.max() - outpts.min()
    grad_spread = slopes.max() - slopes.min()
    if output_range == 0:
        # Constant output: the ratio is 0/0. Return NaN explicitly instead of
        # emitting a RuntimeWarning.
        return output_range, flat_fraction, float("nan")
    return output_range, flat_fraction, grad_spread / output_range
    
def random_polynomial(num_pts, degree, pts_range, inpts, rng=None):
    """Evaluate a random smooth polynomial curve at ``inpts``.

    Draws ``num_pts`` control points with x uniform over [min(inpts), max(inpts)]
    and y uniform over ``pts_range``. Adds two more control points pinned at the
    two ends of the input range, each with a random y. Then least-squares fits a
    polynomial of ``degree`` and evaluates it at ``inpts``.

    Args:
        num_pts: number of interior random control points.
        degree: degree of the fitted polynomial.
        pts_range: (low, high) range for the control-point y values.
        inpts: 1-D array of evaluation points.
        rng: optional numpy Generator / RandomState; defaults to the global
            ``np.random`` state (the original behaviour).

    Returns:
        Array of the fitted polynomial's values at ``inpts``.
    """
    rng = np.random if rng is None else rng
    inpts = np.asarray(inpts)
    # Anchor the ends at the actual input range; the hard-coded [0, 255] only
    # matched inputs spanning exactly 0..255.
    lo, hi = inpts.min(), inpts.max()
    input_pts = np.concatenate((rng.uniform(lo, hi, num_pts), [lo, hi]))
    output_pts = np.concatenate((rng.uniform(*pts_range, num_pts), rng.uniform(*pts_range, 2)))
    coeffs = np.polyfit(input_pts, output_pts, degree)
    return np.polyval(coeffs, inpts)

def random_indicator(jumps_range, inputs, noise_scale=0.01, rng=None):
    """Evaluate a random, noisy 0/1 step function at ``inputs``.

    The number of jumps is drawn uniformly from [jumps_range[0], jumps_range[1])
    (upper bound exclusive). Jump locations are uniform over the input range. The
    clean value is 0 below the first jump and toggles between 0 and 1 at each
    jump; Gaussian noise with standard deviation ``noise_scale`` is then added.

    Args:
        jumps_range: (low, high) bounds for the number of jumps.
        inputs: 1-D array of evaluation points.
        noise_scale: std of the additive Gaussian noise (0 gives exact 0/1 values).
        rng: optional numpy Generator / RandomState; defaults to the global
            ``np.random`` state (the original behaviour).
    """
    rng = np.random if rng is None else rng
    # Generator exposes `integers`; RandomState and the np.random module expose `randint`.
    draw_int = rng.integers if hasattr(rng, "integers") else rng.randint
    inputs = np.asarray(inputs)
    num_jumps = draw_int(*jumps_range)
    jumps = np.sort(rng.uniform(inputs.min(), inputs.max(), num_jumps))
    noise = rng.normal(loc=0, scale=noise_scale, size=len(inputs))
    return np.digitize(inputs, jumps) % 2 + noise

# Demo: a random smooth curve vs. a random noisy step function over the 0..255
# input range, followed by the is_indicator statistics for each.
input_range = np.linspace(0, 255, 255)
poly_results = random_polynomial(4, 3, (-4, 4), input_range)
indic_results = random_indicator((1, 5), input_range)

fig, (ax_poly, ax_indic) = plt.subplots(1, 2, figsize=(10, 4))
ax_poly.plot(input_range, poly_results)
ax_poly.set(title="random_polynomial", xlabel="input", ylabel="output")
ax_indic.plot(input_range, indic_results)
ax_indic.set(title="random_indicator", xlabel="input", ylabel="output")
fig.tight_layout()

print(is_indicator(poly_results, input_range))
print(is_indicator(indic_results, input_range))

In [None]:
def baseline_image(encoder_net, img, dataset, target_class, sample_size=1024, encode_batch_size=32):
    """Find the other-class sample whose encoding is nearest to ``img``'s and pass both on.

    Work in progress: the final step delegates to ``alpha_bin_search`` (not shown
    here), which is meant to binary-search an interpolation coefficient between
    ``img`` and the closest other-class image.

    Args:
        encoder_net: model exposing ``encode(x, indices)`` in the style of
            ``AutoEncoder.encode`` (``indices`` is a list the encoder writes into).
        img: query image; the transpose below assumes (H, W, C) -- TODO confirm.
        dataset: object whose ``generate_one()`` returns ``(image, label, ...)``
            with ``label.argmax()`` giving the class.
        target_class: class of ``img``; samples of this class are skipped.
        sample_size: number of other-class images to draw (must be >= 1).
        encode_batch_size: batch size used when encoding the sampled images.

    Returns:
        Whatever ``alpha_bin_search`` returns.

    Raises:
        ValueError: if ``sample_size`` < 1.
    """
    if sample_size < 1:
        raise ValueError("sample_size must be at least 1")
    # The original called an undefined `model.eval()`; the encoder is the network used here.
    encoder_net.eval()
    sample = []
    while len(sample) < sample_size:
        sampled_img, sampled_label, *_ = dataset.generate_one()
        if sampled_label.argmax() != target_class:  # only consider samples from other classes
            sample.append(sampled_img)
    sample = np.array(sample).transpose(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
    sample_tensor = torch.tensor(sample).to(device).float()

    encoding_batches = []
    with torch.no_grad():
        for start in range(0, sample_size, encode_batch_size):
            batch = sample_tensor[start:start + encode_batch_size]
            # encode() writes its per-layer indices into the list it is given; they
            # are not needed for the nearest-neighbour search, so a throwaway list is used.
            encodings = encoder_net.encode(batch, [])
            encoding_batches.append(encodings.reshape(len(batch), -1).cpu().numpy())
        img_tensor = torch.tensor(img.transpose(2, 0, 1)).unsqueeze(0).to(device).float()
        img_enc = encoder_net.encode(img_tensor, []).reshape(1, -1).cpu().numpy()

    # One row per sampled image. Concatenating keeps rows aligned with `sample`, even
    # when the last batch is short. np.array on the list of batches produced a 3-D
    # array, so the norm and argmin were computed over the wrong axes.
    encoding_vectors = np.concatenate(encoding_batches, axis=0)
    diffs = np.linalg.norm(encoding_vectors - img_enc, axis=1)
    min_diff = np.argmin(diffs)
    closest_img = sample[min_diff]  # NOTE: (C, H, W), whereas `img` is (H, W, C)
    closest_enc = encoding_vectors[min_diff]
    # TODO: binary search on alpha values for the closest value that is the other class.
    # Open questions from earlier notes: how to interpolate the recorded pooling
    # indices; whether a convolutional VAE (see paperswithcode) would suit better.
    # `classifier` and `alpha_bin_search` are expected to be defined elsewhere in
    # the notebook. The encoder passed on is this function's own `encoder_net`
    # (previously an undefined `auto_encoder`).
    return alpha_bin_search(img, closest_img, classifier, encoder_net, img_enc, closest_enc)

# Model Optimization: Profiling Inference Time

In [None]:
# Profile single-image inference of small_net: 1000 forward passes through the
# ProfileExecution wrapper (defined elsewhere in the notebook).
small_net.eval()
generated_img = torch.tensor(valid_set.generate_one()[0].transpose(2,0,1)).unsqueeze(0).to(device).float()
profile_model = ProfileExecution(small_net)
try:
    for _ in tqdm(range(1000)):
        profile_model.forward(generated_img)
finally:
    # Always call clean_up, even if the loop raises or is interrupted (Ctrl-C).
    # Presumably it undoes whatever instrumentation ProfileExecution installed on
    # small_net; confirm against its definition.
    profile_model.clean_up()

In [None]:
# Share of the summed EWMA timings (from `benchmark`) for each profiled point, as a
# percentage, listed alphabetically by point name.
# Earlier finding: acting on this breakdown gave a ~3x speedup (possibly related to
# "Fast and Accurate Model Scaling"). That version underfit, so a ~2x variant was kept.
total = sum(stats.values())
for k in sorted(stats):
    v = stats[k]
    print(k, 100.0 * v / total)