In [1]:
import itertools

import numpy as np
import torch

In [7]:
def get_conv2d_output_size(input_height:int, input_width:int, 
                           kernel_size:"int | tuple[int, int]", stride:int, 
                           padding:int) -> tuple:
    """Compute the output size of a 2D convolution.

    Args:
        input_height (int): Height of the input tensor.
        input_width (int): Width of the input tensor.
        kernel_size (int | tuple[int, int]): Kernel size. A single int is
            treated as a square kernel; a sequence is read as
            (kernel_height, kernel_width).
        stride (int): Stride of the convolution (applied to both axes).
        padding (int): Zero-padding added to both sides of each spatial axis.

    Returns:
        tuple: (output_height, output_width) as Python ints.
    """
    # Accept the documented "square size" int as well as the (height, width)
    # pair; indexing an int directly would raise TypeError.
    if isinstance(kernel_size, (int, np.integer)):
        kernel_height = kernel_width = kernel_size
    else:
        kernel_height, kernel_width = kernel_size[0], kernel_size[1]

    # Padding is applied on both sides of each spatial dimension.
    padded_height = input_height + 2 * padding
    padded_width = input_width + 2 * padding

    # Number of kernel positions that fit: floor((size - kernel) / stride) + 1.
    output_height = (padded_height - kernel_height) // stride + 1
    output_width = (padded_width - kernel_width) // stride + 1
    return int(output_height), int(output_width)

def conv2d_forward(a, weight, stride, tup1, tup2):
    """Try to contract sliding windows of ``a`` with ``weight`` over given axes.

    Builds a read-only strided view of ``a`` with shape
    (batch, in_channel, out_height, out_width, kernel_height, kernel_width)
    and calls ``np.tensordot(view, weight, axes=[tup1, tup2])``. If the
    contraction succeeds, the axis pair is printed and the result returned.

    Args:
        a (np.ndarray): 4-D input; its shape is unpacked as
            (batch, in_channel, height, width).
        weight (np.ndarray): 4-D kernel; its shape is unpacked as
            (num_filters, _, kernel_height, kernel_width).
        stride (int): Step between consecutive windows along both spatial axes.
        tup1 (tuple[int, ...]): Axes of the 6-D window view to sum over.
        tup2 (tuple[int, ...]): Axes of ``weight`` to sum over.

    Returns:
        np.ndarray | None: The tensordot result, or ``None`` when the chosen
        axes are incompatible.
    """
    batch_size, in_channel, im_height, im_width = a.shape
    _, _, kernel_height, kernel_width = weight.shape
    # No padding is applied (padding argument is 0).
    output_height, output_width = get_conv2d_output_size(
        im_height, im_width, (kernel_height, kernel_width), stride, 0
    )

    # Window view without copying: the two output axes advance by `stride`
    # elements of the original spatial axes, the two kernel axes by one.
    a_stride = np.lib.stride_tricks.as_strided(
        x=a,
        shape=(batch_size, in_channel, output_height, output_width, kernel_height, kernel_width),
        strides=(*a.strides[0:2], a.strides[2]*stride, a.strides[3]*stride, *a.strides[2:4]),
        writeable=False
    )

    try:
        ret = np.tensordot(a_stride, weight, axes=[tup1, tup2])
    except (ValueError, IndexError):
        # Incompatible axis combination (shape mismatch or axis out of range):
        # this is the expected "miss" while probing, so report it as None.
        # Anything else (e.g. KeyboardInterrupt) is allowed to propagate.
        return None

    print(tup1, tup2)
    return ret

In [8]:
# Seeded generator so the random probe tensors are identical on every run.
rng = np.random.default_rng(seed=0)

# Input batch; conv2d_forward unpacks this shape as (batch, in_channel, height, width).
x = rng.normal(0, 1, (32, 3, 60, 60))
# Weight tensor; conv2d_forward unpacks this shape as
# (num_filters, _, kernel_height, kernel_width).
# NOTE(review): axis 0 (size 3) matches x's channel count rather than axis 1 —
# confirm the intended weight layout.
w = rng.normal(0, 1, (3, 10, 3, 3))

In [9]:
# 2-sized tuples: probe every increasing pair of axes of the 6-D window view
# against every increasing pair of axes of the 4-D weight. conv2d_forward
# prints the pairs that np.tensordot accepts. Stride is fixed at 3.
for tup1 in itertools.combinations(range(6), 2):
    for tup2 in itertools.combinations(range(4), 2):
        conv2d_forward(x, w, 3, tup1, tup2)

(1, 4) (0, 2)
(1, 4) (0, 3)
(1, 4) (2, 3)
(1, 5) (0, 2)
(1, 5) (0, 3)
(1, 5) (2, 3)
(4, 5) (0, 2)
(4, 5) (0, 3)
(4, 5) (2, 3)


In [10]:
# 3-sized tuples: probe every increasing triple of axes of the 6-D window view
# against every increasing triple of axes of the 4-D weight. conv2d_forward
# prints the triples that np.tensordot accepts. Stride is fixed at 3.
for tup1 in itertools.combinations(range(6), 3):
    for tup2 in itertools.combinations(range(4), 3):
        conv2d_forward(x, w, 3, tup1, tup2)

(1, 4, 5) (0, 2, 3)
