In [1]:
from typing import (
    Optional,
    Tuple,
    Union,
)

import ipytest
import numpy as np
import pytest
import torch
from torch import nn
from torch.nn.common_types import Tensor

In [7]:
# Configure ipytest with its default settings so that the pytest tests defined
# in this notebook can be collected and run via the `%%ipytest` cell magic.
ipytest.autoconfig()

In [356]:
def _as_pair(value, default: int, name: str) -> Tuple[int, int]:
    """Normalise an int, a 1/2-element sequence, or ``None`` into an ``(h, w)`` pair of ints."""
    if value is None:
        value = default
    if isinstance(value, str):
        # PyTorch's conv_transpose2d has no string padding modes ("same"/"valid").
        raise TypeError(f"string {name} ({value!r}) is not supported by conv_transpose2d")
    if isinstance(value, (int, np.integer)):
        return int(value), int(value)
    values = tuple(int(v) for v in value)
    if len(values) == 1:
        return values[0], values[0]
    if len(values) != 2:
        raise ValueError(f"{name} must be an int or a pair of ints, got {value!r}")
    return values[0], values[1]


def conv_transpose2d(
    input: Tensor,
    weights: Tensor,
    bias: Optional[Tensor] = None,
    stride: Optional[Union[int, Tuple]] = 1,
    padding: Optional[Union[int, Union[Tuple, str]]] = 0,
    output_padding: Optional[Union[int, Union[Tuple, str]]] = 0,
    dilation: Optional[Union[int, Tuple]] = 1,
    groups: Optional[int] = 1,
) -> np.ndarray:
    """Reference NumPy implementation of a 2-D transposed convolution.

    Mirrors ``torch.nn.functional.conv_transpose2d``. Every input pixel
    ``input[b, c, i, j]`` scatters ``input[b, c, i, j] * weights[c, o, s, t]``
    into an un-cropped output at row ``i * stride + s * dilation`` and column
    ``j * stride + t * dilation``. Then ``padding`` rows/columns are cropped
    from each side, and ``output_padding`` extra rows/columns are kept at the
    bottom/right.

    Args:
        input: array-like of shape ``(batch, in_channels, H, W)``. NumPy arrays
            and torch tensors (without grad) are both accepted.
        weights: array-like of shape ``(in_channels, out_channels // groups, kH, kW)``,
            which is PyTorch's layout for transposed-convolution weights.
        bias: optional array-like with ``out_channels`` elements, added per
            output channel.
        stride, padding, output_padding, dilation: int or ``(h, w)`` pair.
            ``None`` selects the default (1, 0, 0, 1 respectively).
        groups: number of channel groups. Input channels ``[g*Cin/G, (g+1)*Cin/G)``
            only feed output channels ``[g*Cout/G, (g+1)*Cout/G)``. ``None`` means 1.

    Returns:
        A float64 ``np.ndarray`` of shape ``(batch, out_channels, H_out, W_out)``, where
        ``H_out = (H - 1) * stride - 2 * padding + dilation * (kH - 1) + output_padding + 1``
        (and likewise for ``W_out``).

    Raises:
        TypeError: if a geometry argument is a string.
        ValueError: on inconsistent shapes, non-positive stride/dilation/groups,
            negative padding/output_padding, or an empty output.
    """
    stride_h, stride_w = _as_pair(stride, 1, "stride")
    pad_h, pad_w = _as_pair(padding, 0, "padding")
    out_pad_h, out_pad_w = _as_pair(output_padding, 0, "output_padding")
    dil_h, dil_w = _as_pair(dilation, 1, "dilation")
    groups = 1 if groups is None else int(groups)

    if min(stride_h, stride_w, dil_h, dil_w, groups) < 1:
        raise ValueError("stride, dilation and groups must be positive")
    if min(pad_h, pad_w, out_pad_h, out_pad_w) < 0:
        raise ValueError("padding and output_padding must be non-negative")

    x = np.asarray(input, dtype=np.float64)
    w = np.asarray(weights, dtype=np.float64)
    if x.ndim != 4 or w.ndim != 4:
        raise ValueError(
            f"expected 4-D input and weights, got shapes {x.shape} and {w.shape}"
        )

    batch_size, in_channels, input_height, input_width = x.shape
    weight_in_channels, out_per_group, kernel_height, kernel_width = w.shape
    if weight_in_channels != in_channels:
        raise ValueError(
            f"weights.shape[0] ({weight_in_channels}) must equal input channels ({in_channels})"
        )
    if in_channels % groups != 0:
        raise ValueError(f"in_channels ({in_channels}) must be divisible by groups ({groups})")
    in_per_group = in_channels // groups
    out_channels = out_per_group * groups

    out_height = (input_height - 1) * stride_h - 2 * pad_h + dil_h * (kernel_height - 1) + out_pad_h + 1
    out_width = (input_width - 1) * stride_w - 2 * pad_w + dil_w * (kernel_width - 1) + out_pad_w + 1
    if out_height <= 0 or out_width <= 0:
        raise ValueError(f"computed output size ({out_height}, {out_width}) is empty; padding too large")

    # The un-cropped buffer covers everything the scatter can reach, plus the
    # output_padding rows/cols that survive the crop on the bottom/right.
    full = np.zeros(
        (
            batch_size,
            out_channels,
            (input_height - 1) * stride_h + dil_h * (kernel_height - 1) + 1 + out_pad_h,
            (input_width - 1) * stride_w + dil_w * (kernel_width - 1) + 1 + out_pad_w,
        )
    )
    # For a fixed kernel tap, the H x W input pixels land on a strided grid
    # spanning this many rows/cols.
    row_span = (input_height - 1) * stride_h + 1
    col_span = (input_width - 1) * stride_w + 1

    for g in range(groups):
        x_g = x[:, g * in_per_group:(g + 1) * in_per_group]
        w_g = w[g * in_per_group:(g + 1) * in_per_group]
        out_group = slice(g * out_per_group, (g + 1) * out_per_group)
        for s in range(kernel_height):
            r0 = s * dil_h
            for t in range(kernel_width):
                c0 = t * dil_w
                # (B, Cin_g, H, W) x (Cin_g, Cout_g) -> (B, Cout_g, H, W)
                contribution = np.einsum("bihw,io->bohw", x_g, w_g[:, :, s, t])
                full[:, out_group, r0:r0 + row_span:stride_h, c0:c0 + col_span:stride_w] += contribution

    output = full[:, :, pad_h:pad_h + out_height, pad_w:pad_w + out_width]

    if bias is not None:
        bias_values = np.asarray(bias, dtype=np.float64).reshape(-1)
        if bias_values.shape[0] != out_channels:
            raise ValueError(
                f"bias has {bias_values.shape[0]} elements, expected out_channels={out_channels}"
            )
        output = output + bias_values.reshape(1, -1, 1, 1)

    return np.ascontiguousarray(output)

In [357]:
def _seeded_randn(*shape: int, seed: int) -> torch.Tensor:
    """Draw a standard-normal tensor from a private, seeded generator.

    A dedicated ``torch.Generator`` per fixture keeps the test data identical on
    every run (and independent of fixture execution order), so a failure can
    be reproduced.
    """
    return torch.randn(*shape, generator=torch.Generator().manual_seed(seed))


@pytest.fixture(scope='class')
def inputs():
    """Input batch of shape (1, 1, 10, 10): one sample, one channel."""
    return _seeded_randn(1, 1, 10, 10, seed=0)

@pytest.fixture(scope='class')
def weights():
    """Transposed-conv weights of shape (1, 1, 5, 5): one in/out channel, 5x5 kernel."""
    return _seeded_randn(1, 1, 5, 5, seed=1)

@pytest.fixture(scope='class')
def bias():
    """Bias with a single element (one output channel)."""
    return _seeded_randn(1, seed=2)

In [358]:
%%ipytest -s

class TestConv2D:
    """Compare the NumPy ``conv_transpose2d`` with ``torch.nn.functional.conv_transpose2d``.

    The fixtures are requested directly as test arguments, so no
    ``usefixtures`` marks are needed. ``np.testing.assert_allclose`` is used
    rather than ``np.allclose``: ``np.allclose`` broadcasts its operands, so a
    wrongly shaped result could pass. ``assert_allclose`` fails on a shape
    mismatch and reports the offending elements.
    """

    ATOL = 1e-4
    RTOL = 1e-3

    def test_conv2d_success(self, inputs, weights):
        # Torch tensors are passed directly (not converted with .numpy()) here.
        result = conv_transpose2d(inputs, weights)
        expected_result = nn.functional.conv_transpose2d(inputs, weights).numpy()
        np.testing.assert_allclose(result, expected_result, atol=self.ATOL, rtol=self.RTOL)

    @pytest.mark.parametrize(
        "geometry",
        [{}, {"padding": 5}, {"padding": 5, "stride": 2}],
        ids=["bias", "bias_padding", "bias_padding_stride"],
    )
    def test_conv2d_bias_success(self, inputs, weights, bias, geometry):
        result = conv_transpose2d(inputs.numpy(), weights.numpy(), bias.numpy(), **geometry)
        expected_result = nn.functional.conv_transpose2d(
            inputs, weights, bias=bias, **geometry
        ).numpy()
        np.testing.assert_allclose(result, expected_result, atol=self.ATOL, rtol=self.RTOL)


[32m.[0m[32m.[0m[32m.[0m[32m.[0m
[32m[32m[1m4 passed[0m[32m in 0.21s[0m[0m
