In [1]:
import torch
import torch.nn.functional as F
import numpy as np
import pickle as pkl
from itertools import product

In [5]:
def conv(x_shape, k_shape, bias=False, stride=1, padding=0, dilation=1):
    x = torch.randn(x_shape, requires_grad=True)
    k = torch.randn(k_shape, requires_grad=True)
    grad_list = []
    def get_grad(grad):
        grad_list.append(grad)
    x.register_hook(get_grad)
    k.register_hook(get_grad)
    
    if bias and bias is not None:
        b = torch.randn((k_shape[0],), requires_grad=True)
        b.register_hook(get_grad)
    else:
        b = None
    
    y = F.conv2d(x, k, bias=b, stride=stride, padding=padding, dilation=dilation)
    ysum = y.sum()
    ysum.backward()
    
    bnp = b.detach().numpy() if bias else None
    return x.detach().numpy(), k.detach().numpy(), bnp, y.detach().numpy(), [g.detach().numpy() for g in grad_list]

In [3]:
x_shape = ((2,3,15,15), (2,3,10,10))
k_shape = ((6,3,3,3), (4,3,5,5), (5,3,2,4))
bias = (True, False)
stride = (1, 2, (1,2), (2,2))
padding = (0, 1, (1,1), (2,1))
dilation = (1, 2, (1,2))

In [6]:
def form_dict(setting, x, k, bias, y, grads):
    xs, ks, b, s, p, d = settings
    return dict(
        x_shape=xs, 
        k_shape=ks,
        bias=b,
        stride=s,
        padding=p,
        dilation=d,
        in_x = x,
        in_k = k,
        in_b = bias,
        out=y,
        x_grad=grads[0],
        k_grad=grads[1],
        b_grad=grads[2] if b else None
    )

runs = []
for settings in product(x_shape, k_shape, bias, stride, padding, dilation):
    xs, ks, b, s, p, d = settings
    x, k, b, y, grads = conv(xs, ks, bias=b, stride=s, padding=p, dilation=d)
    cfg = form_dict(settings, x, k, b, y, grads)
    runs.append(cfg)

In [43]:
y, grads = conv((2,3,5,5), (6,3,3,3), bias=True, stride=1, padding=1, dilation=1)

3


In [8]:
with open('test_conv.pkl', 'wb') as f:
    pkl.dump(runs, f)

In [7]:
runs[0]

{'x_shape': (2, 3, 15, 15),
 'k_shape': (6, 3, 3, 3),
 'bias': True,
 'stride': 1,
 'padding': 0,
 'dilation': 1,
 'in_x': array([[[[-0.93367916, -1.8539337 ,  1.0122513 , ..., -1.0466522 ,
            2.605991  ,  0.7909215 ],
          [-0.3860215 , -0.33534542, -0.43465018, ..., -0.02299583,
            0.25409183,  0.36206657],
          [ 0.76595634, -0.4854421 , -0.9836886 , ..., -1.3520507 ,
            1.5341712 , -0.38855985],
          ...,
          [ 0.6420412 ,  1.8220707 ,  0.37680963, ...,  0.17724241,
            2.4159837 , -0.15687484],
          [-0.9942302 ,  1.3658289 ,  0.95067024, ...,  0.34386325,
           -1.6399461 ,  0.47545066],
          [-0.18554337, -1.0168406 , -0.21929225, ..., -1.3134903 ,
           -0.81779337, -0.36577868]],
 
         [[-1.9375788 , -0.7595097 ,  1.4188529 , ..., -1.2485442 ,
            1.3041812 ,  0.6565139 ],
          [ 1.0405416 , -0.36724988, -0.25460222, ..., -0.21712847,
            0.01979127, -0.98112464],
          [ 