In [2]:
# Notebook setup cell: move to the parent directory, then import torch and the
# project's Unet model.
import os 
# NOTE(review): os.chdir('../') is relative to the *current* working directory,
# so re-running this cell moves up one more level each time. It presumably
# exists so that `from VGG import Unet` can find VGG.py in the parent
# directory (a notebook kernel resolves imports relative to its cwd) — confirm.
os.chdir('../')
import torch 

from VGG import Unet
        

In [5]:

# Build the network and move its parameters to the current CUDA device.
# NOTE(review): Unet is defined in VGG.py (not visible here); what the 128
# argument means (input channels? base width?) can't be confirmed from this file.
net = Unet(128).cuda()
# NOTE(review): A is 5-D with shape (8, 128, 256, 256, 8). A 2-D conv network
# normally takes 4-D (N, C, H, W) input — confirm this extra trailing dim is
# intended. The tensor has 8*128*256*256*8 = 536,870,912 elements (2 GiB at the
# default float32 dtype). It is created on the default device (CPU unless
# changed) and then copied, so passing device="cuda" to torch.ones would avoid
# the temporary host copy.
A = torch.ones((8, 128, 256, 256, 8)).cuda()


In [6]:
# Forward pass. There is no torch.no_grad() here, so autograd records the graph
# for net's parameters; wrap this in torch.no_grad() if it is inference only.
out = net(A)

In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Conv2d(nn.Module):
    """Bias-free 2-D convolution (cross-correlation) built from unfold + matmul.

    Args:
        n_channels: number of input channels.
        out_channels: number of output channels (filters).
        kernel_size: side length of the square kernel.
        dilation: kernel dilation, applied to both spatial dims.
        padding: zero padding, applied to both spatial dims.
        stride: stride, applied to both spatial dims.

    ``weights`` has shape ``(out_channels, n_channels, kernel_size**2)``.
    Viewing it as ``(out_channels, n_channels, kernel_size, kernel_size)``
    gives exactly the weight ``F.conv2d`` would need to produce the same
    output with the same padding/dilation/stride.
    """

    def __init__(
        self, n_channels, out_channels, kernel_size, dilation=1, padding=0, stride=1
    ):
        super().__init__()

        self.kernel_size = kernel_size
        self.kernel_size_number = kernel_size * kernel_size
        self.out_channels = out_channels
        self.padding = padding
        self.dilation = dilation
        self.stride = stride
        self.n_channels = n_channels
        self.weights = nn.Parameter(
            torch.empty(self.out_channels, self.n_channels, self.kernel_size**2)
        )
        # torch.empty (like the old torch.Tensor(shape)) returns uninitialised
        # memory; give the weights real starting values.
        self.reset_parameters()

    def reset_parameters(self):
        """(Re)initialise ``weights`` the way ``nn.Conv2d`` does by default.

        For a weight of shape ``(O, C, k*k)``, ``kaiming_uniform_`` uses
        fan_in = ``C * k*k``, which is the same fan_in ``nn.Conv2d`` uses.
        """
        # a = sqrt(5) is nn.Conv2d's default, giving U(-1/sqrt(fan_in), 1/sqrt(fan_in)).
        nn.init.kaiming_uniform_(self.weights, a=5 ** 0.5)

    def __repr__(self):
        return (
            f"Conv2d(n_channels={self.n_channels}, out_channels={self.out_channels}, "
            f"kernel_size={self.kernel_size})"
        )

    def forward(self, x):
        """Convolve ``x`` with ``weights``.

        Args:
            x: input tensor of shape ``(N, n_channels, H, W)``.

        Returns:
            Tensor of shape ``(N, out_channels, H_out, W_out)``, where
            ``H_out = calculate_new_width(x)`` and
            ``W_out = calculate_new_height(x)``.

        Raises:
            ValueError: if ``x.shape[1] != n_channels``.
        """
        if x.shape[1] != self.n_channels:
            raise ValueError(
                f"expected {self.n_channels} input channels, got {x.shape[1]}"
            )
        # calculate_new_width reads x.shape[2] (rows) and calculate_new_height
        # reads x.shape[3] (columns).
        out_rows = self.calculate_new_width(x)
        out_cols = self.calculate_new_height(x)

        # (N, C*k*k, L): one column per output location, rows ordered
        # channel-major, then kernel position row-major.
        patches = self._unfold(x)
        # (O, C*k*k) @ (N, C*k*k, L) broadcasts over N to give (N, O, L).
        # Batch and channel stay on separate axes, so they cannot be mixed up,
        # and the result dtype follows the inputs instead of a fixed float32.
        kernel = self.weights.reshape(self.out_channels, -1)
        result = kernel @ patches
        # L enumerates output positions row-major, so it splits into (rows, cols).
        return result.view(x.shape[0], self.out_channels, out_rows, out_cols)

    def _unfold(self, x):
        """Return ``F.unfold(x)`` using this layer's geometry, shape ``(N, C*k*k, L)``."""
        return F.unfold(
            x,
            kernel_size=(self.kernel_size, self.kernel_size),
            padding=(self.padding, self.padding),
            dilation=(self.dilation, self.dilation),
            stride=(self.stride, self.stride),
        )

    def calculate_windows(self, x):
        """Return the receptive fields of ``x`` grouped per input channel.

        Result shape is ``(C, N*L, k*k)``: ``windows[c][n*L + l]`` is the
        flattened ``k x k`` patch of channel ``c`` for sample ``n`` at output
        location ``l``.
        """
        windows = self._unfold(x)
        # (N, C*k*k, L) -> (N, L, C*k*k) -> (N*L, C, k*k) -> (C, N*L, k*k)
        windows = (
            windows.transpose(1, 2)
            .contiguous()
            .view((-1, x.shape[1], int(self.kernel_size**2)))
            .transpose(0, 1)
        )
        return windows

    def calculate_new_width(self, x):
        """Return the output size along dim 2 (rows) of ``x``, per the conv2d formula."""
        return (
            (x.shape[2] + 2 * self.padding - self.dilation * (self.kernel_size - 1) - 1)
            // self.stride
        ) + 1

    def calculate_new_height(self, x):
        """Return the output size along dim 3 (columns) of ``x``, per the conv2d formula."""
        return (
            (x.shape[3] + 2 * self.padding - self.dilation * (self.kernel_size - 1) - 1)
            // self.stride
        ) + 1

torch.Size([8, 128, 256, 256])