In [1]:
import torch
import numpy as np
from torch.nn.functional import conv3d as libConv3d

In [2]:
class Conv3DSelf():
    def __init__(
        self,
        input_data,
        in_channels: int,
        out_channels: int,
        kernel_size: tuple | int,
        bias: float | None = None,
        stride: int = 1,
        padding: tuple[int, int] | int | str = (0, 0),
        dilation: int = 1,
    ):
        self.input_data = input_data.numpy()
        self.input_data_for_torch = input_data
        self.bias = bias
        
        self.in_channels, self.out_channels = in_channels, out_channels
          
        if type(kernel_size) == tuple:
            self.kernel_size = kernel_size
        else:
            self.kernel_size = (kernel_size, kernel_size, kernel_size)
        
        if type(stride) == tuple:
            self.stride = stride
        else:
            self.stride = (stride, stride, stride)

        if type(dilation) == tuple:
            self.dilation = dilation
        else:
            self.dilation = (dilation, dilation, dilation)
            
        if type(padding) == tuple:
            self.padding = padding
        elif padding == "same":
            if self.stride[0] != 1 or self.stride[1] != 1 or self.stride[2] != 1:
                raise ValueError("padding 'same' works only with stride=1")
            self.padding = (self.kernel_size[0]-1,self.kernel_size[1]-1,self.kernel_size[2]-1)
        elif padding == "valid":
            self.padding = (0,0,0)
        else:
            self.padding = (padding,padding,padding)
            
        self.weight_tensor_for_torch = torch.randn(1, 1, self.kernel_size[0], self.kernel_size[1], self.kernel_size[2])
        self.weight_tensor = self.weight_tensor_for_torch.numpy()
        
    def conv3d(self):
        batches = len(self.input_data)
        out = []

        for b in range(batches):
            d_in = self.input_data[b].shape[1]
            h_in = self.input_data[b].shape[2]
            w_in = self.input_data[b].shape[3]

            if self.kernel_size[0] > h_in or self.kernel_size[1] > w_in or self.kernel_size[2] > d_in:
                raise ValueError('kernel size can\'t be greater than input size')

            d_out = int(
                (d_in + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / (self.stride[0]) + 1)

            h_out = int(
                (h_in + 2 * self.padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1) / (self.stride[1]) + 1)

            w_out = int(
                (w_in + 2 * self.padding[2] - self.dilation[2] * (self.kernel_size[2] - 1) - 1) / (self.stride[2]) + 1)

            out.append(np.zeros((self.out_channels, d_out, h_out, w_out)))

            for c_out in range(self.out_channels):
                for z_out in range(d_out):
                    for y_out in range(h_out):
                        for x_out in range(w_out):
                            sum = 0
                            for c_in in range(self.in_channels):
                                for kernel_z in range(self.kernel_size[0]):
                                    for kernel_y in range(self.kernel_size[1]):
                                        for kernel_x in range(self.kernel_size[2]):
                                            z_in = z_out * self.stride[0] + kernel_z * self.dilation[0] - self.padding[0]
                                            y_in = y_out * self.stride[1] + kernel_y * self.dilation[1] - self.padding[1]
                                            x_in = x_out * self.stride[2] + kernel_x * self.dilation[2] - self.padding[2]
                                            if 0 <= z_in < d_in and 0 <= y_in < h_in and 0 <= x_in < w_in:
                                                sum += self.input_data[b][c_in][z_in][y_in][x_in] * self.weight_tensor[c_out][c_in][kernel_z][kernel_y][kernel_x]

                            out[b][c_out][z_out][y_out][x_out] = sum + (self.bias if self.bias else 0)

        return np.array(out)
    
    def torch_conv3d(self):
        return libConv3d(
            self.input_data_for_torch,
            self.weight_tensor_for_torch,
            bias=torch.tensor([self.bias]), 
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
        )
    
    def test(self, print_flg=False):
        my_conv3d = self.conv3d()
        torch_out = self.torch_conv3d().squeeze().detach().numpy()
        if print_flg:
            print(my_conv3d[0][0])
            print(torch_out)
        print(np.allclose(my_conv3d[0][0], torch_out))


In [16]:
def testing(input_data, in_channels, out_channels, kernel_size, bias, stride, padding, dilation):
    """Build a ``Conv3DSelf`` layer with these settings and run its ``test``.

    ``test`` prints whether the hand-written convolution agrees with
    PyTorch's. Nothing is returned, so a notebook cell ending in a call to
    this function shows only that printed line.
    """
    layer_kwargs = dict(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        bias=bias,
        stride=stride,
        padding=padding,
        dilation=dilation,
    )
    Conv3DSelf(input_data, **layer_kwargs).test(print_flg=False)
# Case: 'same' padding, kernel 4, stride 1, dilation 1. Seeded so the random
# input (and the layer's random weights) are reproducible on re-run.
torch.manual_seed(0)
input_data = torch.randn(1, 1, 5, 5, 5)
testing(input_data, in_channels=1, out_channels=1, kernel_size=4, bias=0.5, stride=1, padding='same', dilation=1)

True


In [19]:
# Case: kernel 4, stride 1, padding 2, dilation 2 (seeded for reproducibility).
torch.manual_seed(0)
input_data = torch.randn(1, 1, 5, 5, 5)
testing(input_data, in_channels=1, out_channels=1, kernel_size=4, bias=0.5, stride=1, padding=2, dilation=2)

True


In [20]:
# Case: kernel 4, stride 4, padding 2, dilation 2 (seeded for reproducibility).
torch.manual_seed(0)
input_data = torch.randn(1, 1, 5, 5, 5)
testing(input_data, in_channels=1, out_channels=1, kernel_size=4, bias=0.5, stride=4, padding=2, dilation=2)

True


In [24]:
# Case: kernel 2, stride 1, padding 2, dilation 2 (seeded for reproducibility).
torch.manual_seed(0)
input_data = torch.randn(1, 1, 5, 5, 5)
testing(input_data, in_channels=1, out_channels=1, kernel_size=2, bias=0.5, stride=1, padding=2, dilation=2)

True
