In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from matplotlib import pyplot as plt
from tqdm import tqdm
# Pick the compute device once; later cells move data/models to `device`.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
if torch.cuda.is_available():
    # Make newly created float tensors (model parameters, torch.tensor(...) calls)
    # live on the GPU by default. Only do this when CUDA exists: requesting
    # torch.cuda.FloatTensor on a CPU-only machine raises instead of falling back.
    torch.set_default_tensor_type(torch.cuda.FloatTensor)
# Let cuDNN autotune convolution algorithms for the input shapes it sees.
torch.backends.cudnn.benchmark = True
# Seed the global RNGs (weight init, perturbation sampling, random image picks)
# so Restart & Run All is repeatable.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Data pipeline adapted from https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
# ToTensor() scales pixels to [0, 1]; Normalize(0.5, 0.5) then maps them to [-1, 1].
# (Normalizing here was questioned by the author, but it is what the models are
# trained on.) Code that feeds raw images to the trained models must apply the
# same mapping.
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

batch_size = 32
num_valid = 5000  # images of the official test split held out for validation

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
# Shuffle every epoch. The sampler's generator must live on the same device as the
# default tensor type (CUDA when the GPU default is enabled), otherwise the
# RandomSampler's randperm call fails with a generator/device mismatch.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True,
                                          generator=torch.Generator(device=device))

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
# Split the official test set into a validation part and a held-out test part.
# Seeded so the split is identical on every run; the generator device must match
# the default tensor device used by random_split's randperm.
valid, test = torch.utils.data.random_split(
    testset, [num_valid, len(testset) - num_valid],
    generator=torch.Generator(device=device).manual_seed(4))
# random_split already returns Subset objects, so each loader wraps its own split
# directly (validation loader <- `valid`, test loader <- `test`).
validloader = torch.utils.data.DataLoader(valid, batch_size=batch_size, shuffle=False)
testloader = torch.utils.data.DataLoader(test, batch_size=batch_size, shuffle=False)

In [None]:
class BasicResNet(nn.Module):
    """Stack of two-convolution blocks with optional identity shortcuts, then a 2-layer MLP head.

    Loosely follows the residual block of https://arxiv.org/pdf/1512.03385.pdf (Figure 3).

    Args:
        conv_layers: sequence of ``(out_channels, kernel_size, stride)`` triples, one per
            block. A float ``out_channels`` (e.g. after applying a width multiplier) is
            truncated to an int and rounded down to a multiple of ``groups``. The caller's
            sequence is not modified. A stride of 2 uses "valid" padding and shrinks the
            feature map; any other stride uses "same" padding (which PyTorch only accepts
            for stride 1).
        num_classes: number of output logits.
        img_shape: input image shape as ``(height, width, channels)``; the spatial size is
            taken from ``img_shape[0]``, so square images are assumed.
        groups: ``groups`` argument passed to every convolution.
    """

    def __init__(self, conv_layers, num_classes, img_shape, groups=1):
        super().__init__()
        conv_layers1 = []  # entry conv of each block (may change channels / stride)
        conv_layers2 = []  # second conv of each block (keeps channels and spatial size)
        batch_norms1 = []
        batch_norms2 = []
        self.is_resid = []  # True where a block's input can be added to its output
        channels = img_shape[-1]
        img_size = img_shape[0]
        for spec in conv_layers:
            # Read into locals so the caller's layer spec is never mutated.
            out_channels, kernel_size, stride = spec[0], spec[1], spec[2]
            if isinstance(out_channels, float):
                out_channels = int(out_channels)
                out_channels -= out_channels % groups  # ensure divisible by groups
            if stride == 2:
                pad_type = "valid"
                img_size = (img_size - kernel_size) // stride + 1  # https://arxiv.org/pdf/1603.07285.pdf
            else:
                pad_type = "same"
            # Identity shortcut only when the block keeps both spatial size and channel count.
            self.is_resid.append(stride == 1 and channels == out_channels)
            conv_layers1.append(nn.Conv2d(channels, out_channels, kernel_size, stride=stride,
                                          padding=pad_type, groups=groups))
            channels = out_channels
            batch_norms1.append(nn.BatchNorm2d(channels))
            batch_norms2.append(nn.BatchNorm2d(channels))
            # The second conv always has stride 1; "same" padding keeps the spatial
            # size that final_num_logits is computed from (even after a stride-2 conv).
            conv_layers2.append(nn.Conv2d(channels, channels, kernel_size,
                                          padding="same", groups=groups))
        self.conv_layers1 = nn.ModuleList(conv_layers1)
        self.conv_layers2 = nn.ModuleList(conv_layers2)
        self.batch_norms1 = nn.ModuleList(batch_norms1)
        self.batch_norms2 = nn.ModuleList(batch_norms2)
        # Flattened feature size entering the classifier head.
        self.final_num_logits = channels * img_size * img_size

        self.fully_connected1 = nn.Linear(self.final_num_logits, 1000)
        self.fully_connected2 = nn.Linear(1000, num_classes)

    def forward(self, x):
        """Map an NCHW image batch to ``(N, num_classes)`` unnormalized logits."""
        blocks = zip(self.conv_layers1, self.conv_layers2,
                     self.batch_norms1, self.batch_norms2, self.is_resid)
        for conv1, conv2, batch_norm1, batch_norm2, is_resid in blocks:
            x_conv1 = F.relu(batch_norm1(conv1(x)))
            # The second conv consumes the first conv's output.
            x_conv2 = F.relu(batch_norm2(conv2(x_conv1)))
            if is_resid:
                x = x + x_conv2  # residual block: identity shortcut
            else:
                x = x_conv2  # channel-changing / downsampling block: no shortcut
        x = torch.flatten(x, 1)
        x = F.relu(self.fully_connected1(x))
        return self.fully_connected2(x)

    def num_params(self):
        """Return the number of trainable (``requires_grad``) parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def save_model_state_dict(self, path):
        """Save this model's ``state_dict`` to ``path``."""
        torch.save(self.state_dict(), path)

    def load_model_state_dict(self, path):
        """Load a ``state_dict`` from ``path`` onto the device this model's parameters live on.

        NOTE: ``torch.load`` unpickles the file; only load checkpoints you trust.
        """
        param_device = next(self.parameters()).device
        self.load_state_dict(torch.load(path, map_location=param_device))

In [None]:
def correct_num(pred_logits, labels):
    """Count the predictions in a batch that match their labels.

    Args:
        pred_logits: ``(N, num_classes)`` scores. Softmax is monotonic, so the predicted
            class is simply the argmax over dim 1 of the raw logits.
        labels: ``(N,)`` integer class labels.

    Returns:
        0-d tensor with the number of correct predictions.
    """
    classifications = torch.argmax(pred_logits, dim=1)
    return (classifications == labels).sum()

def train(net, optimizer, loss, epochs, train_loader=None, valid_loader=None):
    """Train ``net`` for ``epochs`` epochs, evaluating on the validation data after each one.

    Args:
        net: model to optimize (updated in place).
        optimizer: optimizer over ``net``'s parameters.
        loss: criterion called as ``loss(outputs, labels)``.
        epochs: number of passes over the training data.
        train_loader: optional training DataLoader; defaults to the global ``trainloader``.
        valid_loader: optional validation DataLoader; defaults to the global ``validloader``.

    Returns:
        ``(va_losses, va_accuracies, tr_losses)`` per-epoch lists. Losses are sums of
        per-batch losses over the epoch; accuracy is correct predictions divided by the
        number of validation samples.
    """
    if train_loader is None:
        train_loader = trainloader
    if valid_loader is None:
        valid_loader = validloader
    num_valid_samples = len(valid_loader.dataset)
    va_losses = []
    tr_losses = []
    va_accuracies = []
    for epoch in range(epochs):
        net.train()  # BatchNorm normalizes with batch statistics while training
        epoch_tr_loss = 0.0
        for imgs, labels in tqdm(train_loader):  # DataLoader has a len, so tqdm shows progress
            imgs, labels = imgs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = net(imgs)
            batch_loss = loss(outputs, labels)
            epoch_tr_loss += batch_loss.item()
            batch_loss.backward()
            optimizer.step()

        # Validation: eval mode so BatchNorm uses (and does not update) its running
        # statistics, and no_grad so no autograd graph is built.
        net.eval()
        epoch_va_loss = 0.0
        epoch_va_correct = 0
        with torch.no_grad():
            for imgs, labels in valid_loader:
                imgs, labels = imgs.to(device), labels.to(device)
                outputs = net(imgs)
                epoch_va_loss += loss(outputs, labels).item()
                epoch_va_correct += correct_num(outputs, labels).item()
        epoch_va_accuracy = epoch_va_correct / num_valid_samples
        print(f'Epoch {epoch + 1}: va_loss: {epoch_va_loss}, va_accuracy: {epoch_va_accuracy}, tr_loss: {epoch_tr_loss}')
        va_losses.append(epoch_va_loss)
        tr_losses.append(epoch_tr_loss)
        va_accuracies.append(epoch_va_accuracy)
    return va_losses, va_accuracies, tr_losses

In [None]:
# Each row is (out_channels, kernel_size, stride). Stride-2 rows downsample with
# "valid" padding; a block gets an identity shortcut when it has stride 1 and keeps
# the previous row's channel count.
basic_res_net = BasicResNet([[64, 7, 1],
                             [64, 3, 1],
                             [128, 3, 1],
                             [128, 3, 1],
                             [128, 3, 2],
                             [128, 3, 2],
                             [256, 3, 1],
                             [256, 3, 1],
                             [256, 3, 1],
                             [256, 3, 1],
                             [512, 3, 1],
                             [512, 3, 1],
                             [512, 3, 2]], 10, [32, 32, 3]).to(device)
# .to(device) places the model explicitly instead of relying on the default tensor type;
# it must happen before the optimizer captures the parameters.
loss_func = nn.CrossEntropyLoss()
optim = torch.optim.Adam(basic_res_net.parameters())
print(basic_res_net.num_params())

In [None]:
# Train the plain ResNet for 200 epochs; returns (va_losses, va_accuracies, tr_losses).
results = train(net=basic_res_net, optimizer=optim, loss=loss_func, epochs=200)

In [None]:
# Grouped-convolution variant of basic_res_net. With 3 input channels and groups=3 in
# every conv, each color channel is processed by its own stream of filters; the
# streams are only mixed by the fully connected head.
# Widths are scaled by sqrt(3): conv parameters scale with in*out/groups, so widening
# both sides by sqrt(3) roughly cancels the 1/3 from groups=3 (presumably the intent:
# a parameter budget comparable to basic_res_net). BasicResNet truncates the float
# widths and rounds them down to a multiple of groups.
w_scale = np.sqrt(3)
color_cnn_spec = [[width * w_scale, kernel_size, stride]  # (out_channels, kernel_size, stride)
                  for width, kernel_size, stride in [[64, 7, 1],
                                                     [64, 3, 1],
                                                     [128, 3, 1],
                                                     [128, 3, 1],
                                                     [128, 3, 2],
                                                     [128, 3, 2],
                                                     [256, 3, 1],
                                                     [256, 3, 1],
                                                     [256, 3, 1],
                                                     [256, 3, 1],
                                                     [512, 3, 1],
                                                     [512, 3, 1],
                                                     [512, 3, 2]]]
color_cnn = BasicResNet(color_cnn_spec, 10, [32, 32, 3], groups=3).to(device)
loss_func = nn.CrossEntropyLoss()
optim = torch.optim.Adam(color_cnn.parameters())
print(color_cnn.num_params())

In [None]:
# Train the grouped color CNN for 200 epochs; returns (va_losses, va_accuracies, tr_losses).
results = train(net=color_cnn, optimizer=optim, loss=loss_func, epochs=200)

In [None]:
# Persist the trained color_cnn weights (state_dict only), relative to the working directory.
color_cnn.save_model_state_dict(path="cifar10_colorcnn.dict")

In [None]:
# `valid.dataset` is the full CIFAR-10 test set (random_split wraps it), so 55 is a raw
# test-set index; the image may fall in either half of the validation/test split.
explain_img = valid.dataset.data[55]  # untransformed image, as finite_differences_map expects
explain_target = valid.dataset.targets[55]  # its label, kept for reuse instead of only printed
print(explain_target)

In [None]:
@torch.no_grad()
def finite_differences(model, target_class, stacked_img, locations, color, num_iters=40):
    """Estimate per-pixel saliency for one color channel by random finite differences.

    Row ``i`` of the batch is perturbed only at pixel ``locations[i]`` in channel ``color``.
    The change in the ``target_class`` logit divided by the change in pixel intensity is a
    slope. ``num_iters`` random perturbations are tried, and for each row the slope with the
    largest magnitude is returned with its sign kept.

    Args:
        model: network mapping an ``(N, C, H, W)`` float batch to ``(N, num_classes)`` logits.
            Inputs are normalized here with ``(x - 0.5) / 0.5``, the same mapping as the
            training transform.
        target_class: index of the logit to explain.
        stacked_img: ``(B, C, H, W)`` float array of image copies with pixels in ``[0, 1]``.
            Only the first ``len(locations)`` rows are used.
        locations: ``(n, 2)`` integer array of ``(row, col)`` pixel positions, ``n <= B``.
        color: channel index to perturb.
        num_iters: number of random perturbations per pixel.

    Returns:
        ``(n,)`` array of signed slopes (logit change per 8-bit intensity level), clipped
        to ``[-30, 30]``. A row whose every step was cancelled by clipping stays 0.
    """
    n = len(locations)
    base = stacked_img[:n]  # one image copy per location (last batch may be short)
    param_device = next(model.parameters()).device

    def _target_logits(batch):
        # Same normalization as ToTensor() -> Normalize(0.5, 0.5) used in training.
        normalized = (batch - 0.5) / 0.5
        return model(torch.as_tensor(normalized, device=param_device))[:, target_class]

    baseline_activations = _target_logits(base)
    largest_slope = np.zeros(n)
    slices = np.index_exp[np.arange(n), color, locations[:, 0], locations[:, 1]]
    for _ in range(num_iters):
        # Random odd step in [-39, 39] intensity levels (odd => never zero),
        # converted to the [0, 1] pixel scale of `base`.
        level_step = np.random.randint(-20, 20, n) * 2 + 1
        img = base.copy()
        img[slices] = np.clip(img[slices] + level_step / 255., 0., 1.)
        # Clipping can shrink or cancel the step, so measure the realized step in levels.
        actual_levels = (img[slices] - base[slices]) * 255.
        activation_diff = (_target_logits(img) - baseline_activations).cpu().numpy()
        moved = actual_levels != 0  # a fully clipped step carries no slope information
        finite_difference = np.zeros(n)
        finite_difference[moved] = np.clip(activation_diff[moved] / actual_levels[moved], -30, 30)
        # Keep the slope with the largest magnitude seen so far (sign preserved).
        largest_slope = np.where(np.abs(finite_difference) > np.abs(largest_slope),
                                 finite_difference, largest_slope)
    return largest_slope
        

def finite_differences_map(model, target_class, img, batch_size=64):
    """Build a finite-difference saliency map for ``target_class`` on one image.

    Every pixel of each of the three color channels is perturbed via
    ``finite_differences``. ``batch_size`` pixel positions are evaluated per forward pass,
    one image copy each.

    Args:
        model: trained classifier. It is put in eval mode while the map is computed and
            restored to its previous mode afterwards.
        target_class: index of the logit to explain.
        img: ``(H, W, 3)`` image with 8-bit intensities (divided by 255 here).
        batch_size: number of pixel positions evaluated in parallel.

    Returns:
        ``(H, W)`` array of per-pixel slopes summed over the three color channels.
    """
    height, width = img.shape[:2]
    img01 = img.astype(np.float32) / 255.  # pixels in range [0, 1]
    # All (row, col) positions in row-major order.
    rows, cols = np.meshgrid(np.arange(height), np.arange(width), indexing="ij")
    indices = np.stack((rows.ravel(), cols.ravel()), axis=1)
    # batch_size identical copies of the image in NCHW layout.
    stacked_img = np.ascontiguousarray(
        np.repeat(img01.transpose(2, 0, 1)[None], batch_size, axis=0))
    img_heat_map = np.zeros((height, width, 3))
    was_training = model.training
    model.eval()  # use running BatchNorm statistics, not per-batch ones
    try:
        for color in range(3):
            for k in tqdm(range(0, height * width, batch_size)):
                locations = indices[k:k + batch_size]  # the last chunk may be shorter
                largest_slopes = finite_differences(
                    model, target_class, stacked_img[:len(locations)], locations, color)
                img_heat_map[locations[:, 0], locations[:, 1], color] = largest_slopes
    finally:
        model.train(was_training)
    return img_heat_map.sum(axis=2)  # aggregate the per-channel slopes by summing
# Saliency of the (hard-coded) class 8 logit of color_cnn on the example image.
heat_map = finite_differences_map(model=color_cnn, target_class=8, img=explain_img)

In [None]:
# Summary statistics of the saliency map: max, min, mean, std (space separated).
print(*(getattr(heat_map, stat)() for stat in ("max", "min", "mean", "std")))

In [None]:
# Symmetric color limits put zero slope at the white midpoint of the diverging "bwr" map.
heat_limit = np.abs(heat_map).max()
plt.imshow(heat_map, cmap="bwr", interpolation="bilinear", vmin=-heat_limit, vmax=heat_limit)
plt.colorbar()
plt.title("color_cnn finite-difference saliency (class 8)")
plt.show()

In [None]:
plt.imshow(explain_img)
plt.title("Image being explained")
plt.show()  # render the figure instead of echoing the AxesImage repr

In [None]:
# Same saliency map for the plain ResNet, for comparison with color_cnn.
heat_map_basic = finite_differences_map(model=basic_res_net, target_class=8, img=explain_img)

In [None]:
# Symmetric color limits put zero slope at the white midpoint of the diverging "bwr" map.
heat_limit_basic = np.abs(heat_map_basic).max()
plt.imshow(heat_map_basic, cmap="bwr", interpolation="bilinear",
           vmin=-heat_limit_basic, vmax=heat_limit_basic)
plt.colorbar()
plt.title("basic_res_net finite-difference saliency (class 8)")
plt.show()

In [None]:
# Explain a few random images drawn from the validation split with color_cnn.
num_images = 10
# `valid.indices` are raw test-set indices and `valid.dataset` is the full test set,
# so sampling from `valid.indices` restricts the picks to validation images.
random_images = np.random.choice(valid.indices, num_images, replace=False)
fig, axes = plt.subplots(num_images, 2, figsize=(8, 5 * num_images), squeeze=False)
for (img_ax, map_ax), r in zip(axes, random_images):
    target_class = valid.dataset.targets[r]
    explain_random_img = valid.dataset.data[r]
    random_heat_map = finite_differences_map(color_cnn, target_class, explain_random_img)
    img_ax.imshow(explain_random_img)
    img_ax.set_title(f"test image {r} (label {target_class})")
    # Symmetric limits keep zero slope at the white midpoint of "bwr".
    heat_limit = np.abs(random_heat_map).max()
    shown = map_ax.imshow(random_heat_map, cmap="bwr", interpolation="bilinear",
                          vmin=-heat_limit, vmax=heat_limit)
    fig.colorbar(shown, ax=map_ax)
plt.show()
