In [3]:
import torch
import torch.nn as nn
 
import os
import numpy as np

# Dummy UNET

In [None]:
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=(1, 3)):
        super().__init__()

        padding = (1, 1)

        if not kernel_size[0]:
            padding[0] = 0

        self.conv_op = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, padding=padding),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv_op(x)


class DownSample(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=None, pool_kernel_size=None):
        super().__init__()
        self.conv = DoubleConv(in_channel, out_channel, kernel_size=kernel_size)
        self.pool = nn.MaxPool2d(kernel_size=pool_kernel_size)

    def forward(self, x):
        down = self.conv(x)
        p = self.pool(down)

        return down, p


class UpSample(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=None, pool_kernel_size=None):
        super().__init__()

        if kernel_size is None:
            kernel_size = (1, 3)

        self.up = nn.ConvTranspose2d(in_channels, in_channels//2, kernel_size=kernel_size)
        self.conv = DoubleConv(in_channels, out_channels, kernel_size)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        x = torch.cat([x1, x2], 1)
        return self.conv(x)


class UNet(nn.Module):
    def __init__(self, in_channels, num_classes, channels, kernel_size=None):
        super().__init__()
        self.down_convolution_1 = DownSample(in_channels, channels, kernel_size)
        self.down_convolution_2 = DownSample(channels, channels*2, kernel_size)
        self.down_convolution_3 = DownSample(channels*2, channels*4, kernel_size)
        self.down_convolution_4 = DownSample(channels*4, channels*8, kernel_size)

        self.bottle_neck = DoubleConv(channels*8, channels*16, kernel_size)

        self.up_convolution_1 = UpSample(channels*16, channels*8, kernel_size)
        self.up_convolution_2 = UpSample(channels*8, channels*4, kernel_size)
        self.up_convolution_3 = UpSample(channels*4, channels*2, kernel_size)
        self.up_convolution_4 = UpSample(channels*2, channels, kernel_size)

        self.out = nn.Conv2d(in_channels=channels, out_channels=num_classes, kernel_size=1)

    def forward(self, x):
        print(x.shape)

        down_1, p1 = self.down_convolution_1(x)
        print(f'down_1 = {down_1.shape}, p1 = {p1.shape}')

        down_2, p2 = self.down_convolution_2(p1)
        print(f'down_2 = {down_2.shape}, p2 = {p2.shape}')

        down_3, p3 = self.down_convolution_3(p2)
        print(f'down_3 = {down_3.shape}, p3 = {p3.shape}')
        
        down_4, p4 = self.down_convolution_4(p3)
        print(f'down_4 = {down_4.shape}, p4 = {p4.shape}')



        b = self.bottle_neck(p4)
        print(f'b = {b.shape}')



        up_1 = self.up_convolution_1(b, down_4)
        print(f'up_1 = {up_1.shape}')

        up_2 = self.up_convolution_2(up_1, down_3)
        print(f'up_2 = {up_2.shape}')

        up_3 = self.up_convolution_3(up_2, down_2)
        print(f'up_3 = {up_3.shape}')

        up_4 = self.up_convolution_4(up_3, down_1)
        print(f'up_4 = {up_4.shape}')

        out = self.out(up_4)
        return out

In [9]:
n = 128

inp_shape = (1, 1, 1, n)

input_image = torch.rand(inp_shape)

model = UNet(1, 10, channels=16, kernel_size=(1, 3))
output = model(input_image)
print(output.size())
# You should get torch.Size([1, 10, 512, 512]) as a result

torch.Size([1, 1, 1, 128])
down_1 = torch.Size([1, 16, 5, 128]), p1 = torch.Size([1, 16, 2, 64])
down_2 = torch.Size([1, 32, 6, 64]), p2 = torch.Size([1, 32, 3, 32])
down_3 = torch.Size([1, 64, 7, 32]), p3 = torch.Size([1, 64, 3, 16])
down_4 = torch.Size([1, 128, 7, 16]), p4 = torch.Size([1, 128, 3, 8])
b = torch.Size([1, 256, 7, 8])


RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 14 but got size 7 for tensor number 1 in the list.

In [11]:
x = torch.ones((1, 1, 1, n))

In [17]:
in_channels, out_channels = 1, 4
kernel_size = (1, 3)
padding_size = (1, 1)
padding_size = 1

lconv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding_size)
lconv(x).shape



torch.Size([1, 4, 3, 128])