In [1]:
from typing import (
    Optional,
    Tuple,
    Union,
)

import ipytest
import numpy as np
import pytest
import torch
from torch import nn
from torch.nn.common_types import Tensor

In [2]:
ipytest.autoconfig()

In [92]:
def convolution3D(
    input: Tensor,
    weights: Tensor,
    bias: Optional[Tensor] = None,
    stride: Optional[Union[int, Tuple]] = 1,
    padding: Optional[Union[int, Union[Tuple, str]]] = 0,
    dilation: Optional[Union[int, Tuple]] = 1,
    groups: Optional[int] = 1,
) -> Tensor:
    batch_size, in_channels, input_height, input_width, input_depth = input.shape
    (
        out_channels,
        in_channels_groups,
        weights_height,
        weights_width,
        weights_depth,
    ) = weights.shape

    input = nn.functional.pad(
        input, (padding, padding, padding, padding, padding, padding), mode="constant"
    )

    result_height = (
        input_height + 2 * padding - dilation * (weights_height - 1) - 1
    ) // stride + 1
    result_width = (
        input_width + 2 * padding - dilation * (weights_width - 1) - 1
    ) // stride + 1
    result_depth = (
        input_depth + 2 * padding - dilation * (weights_depth - 1) - 1
    ) // stride + 1

    grouped_channels = out_channels // groups if groups else out_channels

    result = torch.zeros(
        (batch_size, grouped_channels, result_height, result_width, result_depth)
    )

    for batch in range(0, batch_size):
        for channel in range(out_channels):
            for i in range(0, input.shape[2] - (weights_height - 1), stride):
                for j in range(0, input.shape[3] - (weights_width - 1), stride):
                    for k in range(0, input.shape[4] - (weights_depth - 1), stride):
                        for group in range(grouped_channels):
                            result[:, group, i // stride, j // stride, k // stride] = (
                                weights[batch]
                                * input[
                                    batch,
                                    :,
                                    i : i + weights_height,
                                    j : j + weights_width,
                                    k : k + weights_depth,
                                ]
                            ).sum()
            result[batch] += bias[channel] if bias else 0

    return result

In [115]:
@pytest.fixture(scope='class')
def inputs():
    return torch.randn(1, 2, 4, 4, 4)

@pytest.fixture(scope='class')
def weights():
    return torch.randn(1, 2, 3, 3, 3)

@pytest.fixture(scope='class')
def bias():
    return torch.randn(1)

In [119]:
%%ipytest -s

@pytest.mark.usefixtures('inputs')
@pytest.mark.usefixtures('weights')
@pytest.mark.usefixtures('bias')
class TestConv2D:
    def test_conv2d_success(self, inputs, weights):
        result = convolution3D(inputs, weights)
        expected_result = nn.functional.conv3d(inputs, weights)
        assert torch.allclose(expected_result, result)
        
    def test_conv2d_bias_success(self, inputs, weights, bias):
        result = convolution3D(inputs, weights, bias)
        expected_result = nn.functional.conv3d(inputs, weights, bias)
        assert torch.allclose(expected_result, result)
        
    def test_conv2d_bias_padding_success(self, inputs, weights, bias):
        result = convolution3D(inputs, weights, bias, padding=2)
        expected_result = nn.functional.conv3d(inputs, weights, bias, padding=2)
        assert torch.allclose(expected_result, result)
        
    def test_conv2d_bias_padding_stride_success(self, inputs, weights, bias):
        result = convolution3D(inputs, weights, bias, padding=1, stride=2)
        expected_result = nn.functional.conv3d(inputs, weights, bias, padding=1, stride=2)
        assert torch.allclose(expected_result, result)

[32m.[0m[32m.[0m[32m.[0m[32m.[0m
[32m[32m[1m4 passed[0m[32m in 0.07s[0m[0m
