In [1]:
import sys
import os
# NOTE(review): presumably a workaround for the OpenMP "duplicate runtime" abort seen when
# several libraries bundle their own OpenMP copy; it is process-wide — confirm it is still needed.
os.environ.update({'KMP_DUPLICATE_LIB_OK': 'True'})



In [2]:
import torch
import torch.nn as nn
import torchvision

In [3]:
# VGG-19 with randomly initialised weights (pretrained=False); its Conv2d weights are
# later overwritten with constants by apply_back_hooks. Displaying the model shows its layers.
vggnet = torchvision.models.vgg19(pretrained=False)
vggnet.eval()

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padd

In [4]:
def replace_pool(m, name):
    """Recursively replace every ``nn.MaxPool2d`` inside ``m`` with an ``nn.AvgPool2d``.

    The replacement keeps the window's ``kernel_size``, ``stride``, ``padding`` and
    ``ceil_mode``. ``nn.AvgPool2d`` has no ``dilation`` argument, so a max-pool's
    ``dilation`` is not carried over.

    Args:
        m: module whose sub-modules are rewritten in place.
        name: label of ``m``; kept for call compatibility (the printed message shows
            only the replaced module).
    """
    # Walk the registered children directly instead of ``dir(m)`` + ``getattr``:
    # ``Module.__dir__`` hides digit-named children, so children of containers such as
    # ``nn.ModuleList`` were never seen. ``_modules`` is used rather than
    # ``named_children()`` because the latter yields a shared instance only once, and a
    # pool registered under several names must be replaced at every position.
    for child_name, child in list(m._modules.items()):
        if child is None:
            continue
        if isinstance(child, nn.MaxPool2d):
            print('replaced: ', child)
            # setattr on a registered child name re-registers the new module in its place
            # (works for Sequential/ModuleList indices as well as named attributes).
            setattr(m, child_name, nn.AvgPool2d(child.kernel_size, child.stride,
                                                child.padding, child.ceil_mode))
        else:
            replace_pool(child, child_name)
        
# Rewrite every MaxPool2d in vggnet as an AvgPool2d with the same window settings
# (prints one "replaced:" line per swap).
replace_pool(vggnet, "model")

replaced:  MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
replaced:  MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
replaced:  MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
replaced:  MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
replaced:  MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)


In [5]:
# Handles returned by register_backward_hook, kept so the hooks can be removed later.
# Detach handles from a previous run of this cell first: merely rebinding
# ``hooks = []`` would orphan them while they stay attached to the modules.
for _stale_hook in globals().get("hooks", []):
    _stale_hook.remove()
hooks = []


def _passthrough_grad_hook(module, grad_input, grad_output):
    """Backward hook that hands the module's output gradient back as its input gradient.

    Only the first input gradient is replaced; any further entries (e.g. gradients of
    extra inputs some modules expose) are kept as computed.
    """
    if isinstance(grad_input, tuple):
        return (grad_output[0],) + grad_input[1:]
    return grad_output


def apply_back_hooks(mod):
    """Prepare one module for the linearised backward pass; use via ``model.apply``.

    * ``nn.ReLU``: turn off ``inplace`` so the output gradient is not lost.
    * ``nn.Conv2d``: set every weight to ``1 / weight.shape[0]`` (1 / out_channels).
    * every module that is neither ``nn.Conv2d`` nor ``nn.AvgPool2d``: register
      ``_passthrough_grad_hook`` and append its handle to the global ``hooks`` list.
    """
    if isinstance(mod, nn.ReLU):
        mod.inplace = False

    if isinstance(mod, nn.Conv2d):
        # In-place fill under no_grad keeps the same Parameter (and its dtype/device)
        # instead of swapping ``.data`` for a freshly allocated tensor.
        with torch.no_grad():
            mod.weight.fill_(1.0 / mod.weight.shape[0])

    if not isinstance(mod, (nn.Conv2d, nn.AvgPool2d)):
        # NOTE(review): register_backward_hook is the legacy (non-"full") hook API; on
        # modules whose forward spans several autograd nodes it attaches to the last
        # one only. Kept as-is to preserve the original behaviour.
        hooks.append(mod.register_backward_hook(_passthrough_grad_hook))
  
# Run apply_back_hooks on vggnet and every one of its sub-modules; the displayed
# value is the modified model.
vggnet.apply(apply_back_hooks)

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU()
    (4): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU()
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU()
    (9): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU()
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU()
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU()
    (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (17): ReLU()
    (18): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1,

In [6]:
def modify_conv_module(mod):
    """Make every ``nn.Conv2d`` unpadded ("valid" convolution).

    ``nn.Conv2d`` derives ``_reversed_padding_repeated_twice`` from ``padding`` once, in
    its constructor, and pads with that cache (not ``padding``) whenever
    ``padding_mode != 'zeros'``. It is refreshed here so the module stays consistent
    if a later step only switches ``padding_mode``.
    """
    if isinstance(mod, torch.nn.Conv2d):
        mod.padding = (0, 0)
        mod._reversed_padding_repeated_twice = tuple(
            p for p in reversed(mod.padding) for _ in range(2))
# Apply the valid-padding change to every Conv2d in vggnet; the displayed value is the
# updated model.
vggnet.apply(modify_conv_module)

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
    (3): ReLU()
    (4): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
    (6): ReLU()
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
    (8): ReLU()
    (9): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1))
    (11): ReLU()
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
    (13): ReLU()
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
    (15): ReLU()
    (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
    (17): ReLU()
    (18): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1))
    (20): ReLU()
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1))
    (22): ReLU()
    (23): Conv2d(512, 512, 

In [7]:
# Output of the hooked layer from the most recent forward pass (None until one runs).
saved_layer = None


def forward_hook(mod, inp, out):
    """Forward hook that stores ``out`` in the module-level ``saved_layer``.

    Returns None, so the hooked module's output is passed on unchanged.
    """
    global saved_layer
    saved_layer = out
    return None
# Capture the output of vggnet.features[34] on every forward pass.
# NOTE(review): index 34 is assumed to be the last conv of torchvision's vgg19 — confirm
# against the printed model. The handle gets its own name (re-binding ``forward_hook``
# would shadow the hook function), and a handle from a previous run is detached first so
# hooks do not stack up.
if 'forward_hook_handle' in globals():
    forward_hook_handle.remove()
forward_hook_handle = vggnet.features[34].register_forward_hook(forward_hook)

In [8]:
# # Valid Padding
# inp = torch.zeros(1,3,512,512)
# inp.requires_grad = True
# out = vggnet(inp)
# grad_inp_no_padding = torch.autograd.grad(torch.sum(saved_layer), inp)

In [9]:
# # Half Padding
# def modify_conv_module(mod):
#   if isinstance(mod, torch.nn.Conv2d):
#     mod.padding = (1, 1)
# vggnet.apply(modify_conv_module)

# inp = torch.zeros(1,3,512,512)
# inp.requires_grad = True
# out = vggnet(inp)
# grad_inp_same_padding = torch.autograd.grad(torch.sum(saved_layer), inp)

In [10]:
# # Full padding   
# def modify_conv_module(mod):
#   if isinstance(mod, torch.nn.Conv2d):
#     mod.padding = (2,2)
# vggnet.apply(modify_conv_module)

# inp = torch.zeros(1,3,512,512)
# inp.requires_grad = True
# out = vggnet(inp)
# grad_inp_full_padding = torch.autograd.grad(torch.sum(saved_layer), inp)


In [11]:
# # Reflect Padding
# def modify_conv_module(mod):
#   if isinstance(mod, torch.nn.Conv2d):
#     mod.padding_mode = 'reflect'
#     mod.padding = (1, 1)
    
# vggnet.apply(modify_conv_module)

# inp = torch.zeros(1,3,512,512)
# inp.requires_grad = True
# out = vggnet(inp)
# grad_inp_reflect_padding = torch.autograd.grad(torch.sum(saved_layer), inp)


In [12]:
# Replicate Padding
def modify_conv_module(mod):
    """Switch every ``nn.Conv2d`` to replicate padding of 2 pixels per side.

    ``padding``, ``padding_mode`` and ``dilation`` are all set explicitly, so the
    result does not depend on which earlier padding cell ran last.
    """
    if isinstance(mod, torch.nn.Conv2d):
        mod.padding = (2, 2)
        mod.padding_mode = 'replicate'
        mod.dilation = (1, 1)
        # kernel_size is metadata only: the convolution uses the existing weight tensor,
        # so report its real spatial size instead of claiming a different kernel.
        mod.kernel_size = tuple(mod.weight.shape[-2:])
        # Non-'zeros' padding modes pad with this cache, which nn.Conv2d builds only in
        # __init__; without refreshing it the old padding (not (2, 2)) would be applied.
        mod._reversed_padding_repeated_twice = tuple(
            p for p in reversed(mod.padding) for _ in range(2))
    
vggnet.apply(modify_conv_module)

# Probe with an all-zero (1, 3, 512, 512) input; the gradient of the summed hooked
# activation ``saved_layer`` w.r.t. this input is the map plotted later.
inp = torch.zeros(1, 3, 512, 512, requires_grad=True)
out = vggnet(inp)
grad_inp_replicate_padding = torch.autograd.grad(saved_layer.sum(), inp)




In [13]:
# Dilated Padding
def modify_conv_module(mod):
    """Switch every ``nn.Conv2d`` to zero padding (1, 1) with dilation 2.

    ``padding_mode`` is reset to ``'zeros'`` explicitly; previously it was inherited
    from whichever padding cell ran before (e.g. ``'replicate'``).
    """
    if isinstance(mod, torch.nn.Conv2d):
        mod.padding = (1, 1)
        mod.padding_mode = 'zeros'
        mod.dilation = (2, 2)
        # kernel_size is metadata only; keep it in sync with the actual weight tensor.
        mod.kernel_size = tuple(mod.weight.shape[-2:])
        # Refresh the padding cache nn.Conv2d builds in __init__ (used for non-'zeros'
        # modes) so a later mode switch sees the current padding.
        mod._reversed_padding_repeated_twice = tuple(
            p for p in reversed(mod.padding) for _ in range(2))
vggnet.apply(modify_conv_module)

# Probe with an all-zero (1, 3, 512, 512) input; the gradient of the summed hooked
# activation ``saved_layer`` w.r.t. this input is the map plotted later.
inp = torch.zeros(1, 3, 512, 512, requires_grad=True)
out = vggnet(inp)
grad_inp_dilated_padding = torch.autograd.grad(saved_layer.sum(), inp)


In [14]:
# Circular Padding
def modify_conv_module(mod):
    """Switch every ``nn.Conv2d`` to circular padding of 1 pixel per side, no dilation."""
    if isinstance(mod, torch.nn.Conv2d):
        mod.padding = (1, 1)
        mod.padding_mode = 'circular'
        mod.dilation = (1, 1)
        # kernel_size is metadata only; keep it in sync with the actual weight tensor.
        mod.kernel_size = tuple(mod.weight.shape[-2:])
        # 'circular' pads with the cache nn.Conv2d builds in __init__, not with
        # ``padding``; refresh it so the padding above is what is actually applied.
        mod._reversed_padding_repeated_twice = tuple(
            p for p in reversed(mod.padding) for _ in range(2))
vggnet.apply(modify_conv_module)

# Probe with an all-zero (1, 3, 512, 512) input; the gradient of the summed hooked
# activation ``saved_layer`` w.r.t. this input is the map plotted later.
inp = torch.zeros(1, 3, 512, 512, requires_grad=True)
out = vggnet(inp)
grad_inp_circular_padding = torch.autograd.grad(saved_layer.sum(), inp)


In [15]:
import numpy as np

In [16]:
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable

def _plot_gradient_panel(fig, position, n_panels, grad_result, title):
    """Add one input-gradient heat map, with its own colorbar, to ``fig``.

    Args:
        fig: figure the panel is added to.
        position: 1-based subplot index in a ``1 x n_panels`` row.
        n_panels: total number of panels in the row.
        grad_result: tuple returned by ``torch.autograd.grad`` in the experiment cells
            (first entry is the gradient w.r.t. the (1, 3, 512, 512) probe input);
            batch item 0 is summed over dim 0 (the channels) to give a 2-D map.
        title: panel title.

    Returns:
        The ``Axes`` holding the heat map.
    """
    ax = fig.add_subplot(1, n_panels, position)
    heat = grad_result[0][0].sum(dim=0).detach().cpu().numpy()
    image = ax.imshow(heat, cmap=plt.get_cmap('plasma'))
    ax.set_title(title, size=21, pad=15)
    cax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.25)
    cbar = fig.colorbar(image, cax=cax)
    cbar.ax.tick_params(labelsize=10)
    return ax


# (gradient, title) per panel, in display order. Titles state the padding each
# experiment cell sets: circular uses (1, 1); the dilated run uses padding (1, 1) with
# dilation 2. The valid / same / full / reflect variants from the commented-out
# experiment cells can be added here once those cells are re-enabled.
gradient_panels = [
    (grad_inp_replicate_padding, "Replicate Padding (2, 2)"),
    (grad_inp_circular_padding, "Circular Padding (1, 1)"),
    (grad_inp_dilated_padding, "Dilated Padding (1, 1), dilation 2"),
]

fig = plt.figure(figsize=(31, 10), dpi=63)
for panel_position, (panel_grad, panel_title) in enumerate(gradient_panels, start=1):
    _plot_gradient_panel(fig, panel_position, len(gradient_panels), panel_grad, panel_title)
plt.show()

