In [35]:
"""
The function `torch_conv_layer_to_affine` takes a `torch.nn.Conv2d` layer `conv`
and produces an equivalent `torch.nn.Linear` layer `fc`.
Specifically, for a single input `x` of shape `(1, C_in, H, W)` (or unbatched `(C_in, H, W)`), where `(H, W)` is the `input_size` passed to the function, the following holds:
    torch.flatten(conv(x)) == fc(torch.flatten(x))
Or equivalently:
    conv(x) == fc(torch.flatten(x)).reshape(conv(x).shape)
allowing of course for some floating-point error.
"""
from typing import Tuple
import torch
import torch.nn as nn
import numpy as np
def torch_conv_layer_to_affine(
    conv: torch.nn.Conv2d, input_size: Tuple[int, int]
) -> torch.nn.Linear:
    """Return an ``nn.Linear`` computing the same map as ``conv`` for a fixed input size.

    Inputs and outputs are flattened row-major over ``(channels, height, width)``,
    i.e. the order produced by ``torch.flatten`` / ``Tensor.reshape(-1)`` on a single
    image, so ``fc(x.reshape(-1)) == conv(x).reshape(-1)`` up to float error.

    Args:
        conv: convolution to convert. Stride, integer zero-padding, dilation and
            groups are supported.
        input_size: spatial size ``(H, W)`` of the input image.

    Returns:
        ``nn.Linear`` with ``C_in*H*W`` inputs and ``C_out*H_out*W_out`` outputs,
        placed on the same device/dtype as ``conv.weight``.

    Raises:
        ValueError: if ``conv`` uses string padding (``'same'``/``'valid'``), a
            ``padding_mode`` other than ``'zeros'``, or the input is smaller than
            the (dilated) kernel.
    """
    if isinstance(conv.padding, str):
        raise ValueError("string padding is not supported; use explicit integer padding")
    if conv.padding_mode != "zeros":
        raise ValueError("only padding_mode='zeros' is supported")

    h_in, w_in = (int(s) for s in input_size)
    c_in, c_out, groups = conv.in_channels, conv.out_channels, conv.groups
    (kh, kw), (sh, sw) = conv.kernel_size, conv.stride
    (ph, pw), (dh, dw) = conv.padding, conv.dilation

    # Output size formula from the Torch docs (including dilation):
    # https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
    h_out = (h_in + 2 * ph - dh * (kh - 1) - 1) // sh + 1
    w_out = (w_in + 2 * pw - dw * (kw - 1) - 1) // sw + 1
    if h_out <= 0 or w_out <= 0:
        raise ValueError("input_size is too small for this convolution")

    in_plane = h_in * w_in
    out_plane = h_out * w_out
    fc = nn.Linear(c_in * in_plane, c_out * out_plane).to(
        device=conv.weight.device, dtype=conv.weight.dtype
    )

    # Channels per group: output channel block g only reads input channel block g.
    c_in_g = c_in // groups
    c_out_g = c_out // groups
    device = fc.weight.device

    # fc.weight / fc.bias are leaf Parameters; in-place writes require autograd off.
    with torch.no_grad():
        fc.weight.zero_()
        if conv.bias is None:
            fc.bias.zero_()
        else:
            # Output index = co * out_plane + yo * w_out + xo, so each channel's bias
            # is repeated once per output pixel.
            fc.bias.copy_(conv.bias.repeat_interleave(out_plane))

        for g in range(groups):
            rows_base = torch.arange(g * c_out_g, (g + 1) * c_out_g, device=device) * out_plane
            cols_base = torch.arange(g * c_in_g, (g + 1) * c_in_g, device=device) * in_plane
            w_g = conv.weight[g * c_out_g:(g + 1) * c_out_g]  # (c_out_g, c_in_g, kh, kw)
            for yo in range(h_out):
                for xo in range(w_out):
                    rows = (rows_base + yo * w_out + xo)[:, None]
                    for ky in range(kh):
                        # Input row hit by kernel row ky; rows outside the image are
                        # zero padding and contribute nothing.
                        yi = yo * sh - ph + ky * dh
                        if not 0 <= yi < h_in:
                            continue
                        for kx in range(kw):
                            xi = xo * sw - pw + kx * dw
                            if not 0 <= xi < w_in:
                                continue
                            cols = (cols_base + yi * w_in + xi)[None, :]
                            # One kernel tap for every (out, in) channel pair at once.
                            fc.weight[rows, cols] = w_g[:, :, ky, kx]
    return fc

def range2d(to_a, to_b):
    """Yield every pair ``(a, b)`` with ``0 <= a < to_a`` and ``0 <= b < to_b``.

    Pairs are produced in row-major order: the second coordinate varies fastest.
    """
    yield from ((row, col) for row in range(to_a) for col in range(to_b))

def enc_tuple(tup: Tuple, shape: Tuple) -> int:
    """Flatten a multi-dimensional coordinate into a row-major (C-order) index.

    Args:
        tup: coordinate, one entry per dimension of ``shape``.
        shape: extent of each dimension.

    Returns:
        The flat index ``sum(tup[i] * prod(shape[i+1:]))``.

    Raises:
        AssertionError: if ``tup`` and ``shape`` differ in length or a coordinate
            lies outside ``[0, shape[i])`` (a negative coordinate would otherwise
            silently alias another position).
    """
    assert len(tup) == len(shape), "coordinate rank must match shape rank"
    res = 0
    # Horner's scheme: each step shifts the accumulated index by one dimension.
    for coord, size in zip(tup, shape):
        assert 0 <= coord < size, "coordinate out of bounds"
        res = res * size + coord
    return res

def dec_tuple(x: int, shape: Tuple) -> Tuple:
    """Inverse of ``enc_tuple``: turn a flat row-major index back into coordinates.

    Args:
        x: flat index, expected in ``[0, prod(shape))``.
        shape: extent of each dimension.

    Returns:
        Tuple of ``len(shape)`` coordinates, most significant dimension first.

    Raises:
        AssertionError: if ``x`` is outside ``[0, prod(shape))``.
    """
    res = []
    # Peel digits off the least-significant (last) dimension first.
    for size in reversed(shape):
        x, digit = divmod(x, size)
        res.append(digit)
    # Anything left over means the index did not fit in `shape`.
    assert x == 0, "index out of range for shape"
    return tuple(reversed(res))

def test_tuple_encoding():
    """Check that dec_tuple undoes enc_tuple for a sample 3-d coordinate."""
    shape = (5, 6, 7)
    coords = (3, 2, 1)
    flat_index = enc_tuple(coords, shape)
    assert dec_tuple(flat_index, shape) == coords
    print("Tuple encoding ok")

def test_layer_conversion():
    """Check that torch_conv_layer_to_affine reproduces Conv2d outputs.

    Sweeps stride, padding and kernel size on a random ``(1, 2, 6, 7)`` input and
    compares the reshaped fully-connected output against the convolution.
    """
    # The conversion writes into nn.Parameter tensors and no gradients are needed
    # here, so run the whole check with autograd disabled.
    with torch.no_grad():
        for stride in [1, 2]:
            for padding in [0, 1, 2]:
                for filter_size in [3, 4]:
                    img = torch.rand((1, 2, 6, 7))
                    conv = nn.Conv2d(2, 5, filter_size, stride=stride, padding=padding)
                    fc = torch_conv_layer_to_affine(conv, img.shape[2:])
                    res2 = conv(img)
                    # Also checks that the encoding flattens inputs/outputs such that
                    # FC(flatten(img)) == flatten(Conv(img))
                    res1 = fc(img.reshape(-1)).reshape(res2.shape)
                    # Absolute difference: a signed .max() would ignore every
                    # position where res1 < res2.
                    worst_error = (res1 - res2).abs().max()
                    print("Output shape", res2.shape, "Worst error: ", float(worst_error))
                    # float32 tolerance; different summation order than Conv2d.
                    assert worst_error <= 1.0e-5
    print("Layer conversion ok")

In [36]:
import cv2
import torch
import torchvision
import numpy as np
import utils

import torchvision.transforms.functional

In [37]:
# Load the test image as RGB in [0, 1]. cv2.imread returns BGR, or None if the
# file cannot be read (which would otherwise crash later with a confusing error).
IMAGE_PATH = 'images/castle.jpg'
img_bgr = cv2.imread(IMAGE_PATH)
if img_bgr is None:
    raise FileNotFoundError(f"could not read image {IMAGE_PATH!r}")
img = img_bgr[..., ::-1] / 255.0

# ImageNet normalization values from
# https://pytorch.org/vision/main/models/generated/torchvision.models.vgg16.html
mean = torch.Tensor([0.485, 0.456, 0.406]).reshape(1, -1, 1, 1)
std = torch.Tensor([0.229, 0.224, 0.225]).reshape(1, -1, 1, 1)

# HWC float64 image -> contiguous NCHW float32 batch of one, then normalize.
X = (torch.from_numpy(np.ascontiguousarray(img.transpose(2, 0, 1)[np.newaxis])).float() - mean) / std

In [38]:
# Pretrained VGG19 switched to inference mode; eval() returns the module itself.
model = torchvision.models.vgg19(pretrained=True).eval()
model



VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padd

In [39]:
# Inference only: disable autograd so the forward pass does not record a graph.
with torch.no_grad():
    logits = model(X)
torch.argmax(logits)

tensor(483)

In [40]:
# First module whose class is named Conv2d, in model.modules() traversal order;
# None if the network has no such layer. Does not assume the layer sits in an
# integer-indexed container.
layer = next(
    (m for m in model.modules() if type(m).__name__ == 'Conv2d'),
    None,
)

In [41]:
# Inspect the selected layer's hyper-parameters.
layer

Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

In [42]:
# Call the module (not .forward, which bypasses hooks) without building a graph.
with torch.no_grad():
    layer_out = layer(X)
X.shape, layer_out.shape, layer.weight.shape, layer.bias.shape

(torch.Size([1, 3, 224, 224]),
 torch.Size([1, 64, 224, 224]),
 torch.Size([64, 3, 3, 3]),
 torch.Size([64]))

In [43]:
# NumPy array of the normalized input, detached from autograd.
Xnp = X.cpu().detach().numpy()

In [44]:
# Conv kernel as NumPy, (Cout, Cin, kh, kw); .detach() is the supported form of .data.
weights = layer.weight.detach().cpu().numpy()

In [45]:
# Per-output-channel bias as NumPy.
biases = layer.bias.detach().cpu().numpy()

In [46]:
# Re-implement `layer` with plain NumPy. Every hyper-parameter is read from the
# layer itself so the reference cannot drift from the module it is checked against
# (a hard-coded kernel size that disagrees with the weights breaks tensordot).
if isinstance(layer.padding, str) or tuple(layer.dilation) != (1, 1) or layer.groups != 1:
    raise ValueError("NumPy reference only supports integer padding, dilation=1, groups=1")

Cout, Cin, kh, kw = weights.shape
kernel_size = (kh, kw)
padding = tuple(layer.padding)
stride = tuple(layer.stride)
_, _, H_in, W_in = Xnp.shape
# Output size formula from the Conv2d docs (dilation=1).
H = (H_in + 2 * padding[0] - kh) // stride[0] + 1
W = (W_in + 2 * padding[1] - kw) // stride[1] + 1

# Zero-pad only the spatial axes, each by its own amount on both sides.
X_pad = np.pad(
    Xnp,
    ((0, 0), (0, 0), (padding[0], padding[0]), (padding[1], padding[1])),
    mode='constant',
    constant_values=0,
)

output = np.zeros((1, Cout, H, W))
for i in range(H):
    for j in range(W):
        r, c = i * stride[0], j * stride[1]
        zone = X_pad[0, :, r:r + kh, c:c + kw]  # (Cin, kh, kw)
        # Contract (Cin, kh, kw) against every output channel's kernel at once.
        output[0, :, i, j] = np.tensordot(weights, zone, axes=3) + biases

In [49]:
# NumPy reference result for the first conv layer (NumPy summarizes large arrays).
output

array([[[[-4.35848653e-01, -2.76941657e-01, -2.86430359e-01, ...,
           7.58498907e-01,  7.56732941e-01,  4.38061357e-01],
         [-3.96351993e-01, -2.71862268e-01, -2.84331560e-01, ...,
           7.47735977e-01,  7.46625900e-01,  4.61679697e-01],
         [-3.92981827e-01, -2.72107422e-01, -2.70687878e-01, ...,
           7.38332510e-01,  7.37986922e-01,  4.65240717e-01],
         ...,
         [-1.86877537e+00, -2.27427530e+00, -2.35830474e+00, ...,
          -1.39870453e+00, -1.32285511e+00, -1.21387553e+00],
         [-1.74284554e+00, -2.03332281e+00, -2.27975750e+00, ...,
          -1.45715129e+00, -1.37679887e+00, -1.28719878e+00],
         [-1.40771985e+00, -1.46523333e+00, -1.65783656e+00, ...,
          -9.74298537e-01, -1.01699233e+00, -1.07208538e+00]],

        [[ 1.62007064e-01,  4.67999279e-01,  4.79892820e-01, ...,
           8.34800184e-01,  8.69134188e-01,  1.78299916e+00],
         [-1.59164339e-01,  2.77136236e-01,  3.30293924e-01, ...,
           1.96036294e

In [50]:
with torch.no_grad():
    torch_output = layer(X).cpu().numpy()
# Check the NumPy reference numerically instead of by eye; the two should differ
# only by float32 rounding from a different summation order.
np.testing.assert_allclose(output, torch_output, rtol=1e-4, atol=1e-5)
torch_output

array([[[[-4.35848564e-01, -2.76941627e-01, -2.86430359e-01, ...,
           7.58498788e-01,  7.56733000e-01,  4.38061357e-01],
         [-3.96352023e-01, -2.71862209e-01, -2.84331471e-01, ...,
           7.47735739e-01,  7.46625960e-01,  4.61679697e-01],
         [-3.92981797e-01, -2.72107184e-01, -2.70687670e-01, ...,
           7.38332450e-01,  7.37986863e-01,  4.65240628e-01],
         ...,
         [-1.86877549e+00, -2.27427554e+00, -2.35830498e+00, ...,
          -1.39870441e+00, -1.32285523e+00, -1.21387565e+00],
         [-1.74284554e+00, -2.03332281e+00, -2.27975726e+00, ...,
          -1.45715141e+00, -1.37679887e+00, -1.28719878e+00],
         [-1.40771985e+00, -1.46523321e+00, -1.65783668e+00, ...,
          -9.74298537e-01, -1.01699233e+00, -1.07208538e+00]],

        [[ 1.62007079e-01,  4.67999279e-01,  4.79892820e-01, ...,
           8.34800303e-01,  8.69134247e-01,  1.78299928e+00],
         [-1.59164310e-01,  2.77136207e-01,  3.30293953e-01, ...,
           1.96036324e