In [109]:
from UNet import ConvBlock, Encoder, Decoder, UNet
import torch
import torch.nn as nn

torch.manual_seed(42)

<torch._C.Generator at 0x1ca7f31e9b0>

In [110]:
input = torch.randint(5, (1, 2, 2), dtype=torch.float32)
input

tensor([[[2., 2.],
         [1., 4.]]])

Weight & Bias

In [111]:
tconv_layer = nn.ConvTranspose2d(
    in_channels=1, out_channels=1, kernel_size=2, bias=True)

In [112]:
tconv_layer.weight.data = torch.ones(
    tconv_layer.weight.data.shape
)
tconv_layer.weight

Parameter containing:
tensor([[[[1., 1.],
          [1., 1.]]]], requires_grad=True)

In [113]:
tconv_layer.bias = nn.Parameter(
    torch.tensor([1], dtype=torch.float32)
)

In [114]:
output = tconv_layer(input)
output

tensor([[[ 3.,  5.,  3.],
         [ 4., 10.,  7.],
         [ 2.,  6.,  5.]]], grad_fn=<SqueezeBackward1>)

Padding

In [115]:
tconv_layer = nn.ConvTranspose2d(
    in_channels=1, out_channels=1, kernel_size=2, padding=1
)

In [116]:
tconv_layer.weight.data = torch.ones(
    tconv_layer.weight.data.shape
)
tconv_layer.weight

Parameter containing:
tensor([[[[1., 1.],
          [1., 1.]]]], requires_grad=True)

In [117]:
output = tconv_layer(input)
output

tensor([[[9.0677]]], grad_fn=<SqueezeBackward1>)

In [118]:
tconv_layer = nn.ConvTranspose2d(
    in_channels=1,
    out_channels=1,
    kernel_size=2,
    padding=(1, 0),
    bias=False
)

tconv_layer.weight.data = torch.ones(
    tconv_layer.weight.data.shape
)

output = tconv_layer(input)
output

tensor([[[3., 9., 6.]]], grad_fn=<SqueezeBackward1>)

Stride

In [120]:
tconv_layer = nn.ConvTranspose2d(
    in_channels=1,
    out_channels=1,
    kernel_size=2,
    stride=2,
    padding=0,
    bias=False
)

tconv_layer.weight.data = torch.ones(
    tconv_layer.weight.data.shape
)

output = tconv_layer(input)
output

tensor([[[2., 2., 2., 2.],
         [2., 2., 2., 2.],
         [1., 1., 4., 4.],
         [1., 1., 4., 4.]]], grad_fn=<SqueezeBackward1>)

In [121]:
tconv_layer = nn.ConvTranspose2d(
    in_channels=1,
    out_channels=1,
    kernel_size=2,
    stride=2,
    padding=1,
    bias=False
)

tconv_layer.weight.data = torch.ones(
    tconv_layer.weight.data.shape
)

output = tconv_layer(input)
output

tensor([[[2., 2.],
         [1., 4.]]], grad_fn=<SqueezeBackward1>)

ConvBlock

In [122]:
input = torch.randint(5, (1, 1, 256, 256), dtype=torch.float32)

In [124]:
in_conv = ConvBlock(in_channels=1, out_channels=64)
in_conv

ConvBlock(
  (conv_block): Sequential(
    (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
  )
)

In [125]:
output = in_conv(input)
output.shape

torch.Size([1, 64, 256, 256])

Encoder

In [127]:
encoder = Encoder(in_channels=64, out_channels=128)
input = torch.randint(5, (1, 64, 256, 256), dtype=torch.float32)
output = encoder(input)
output.shape

torch.Size([1, 128, 128, 128])

UNet

In [128]:
model = UNet(n_channels=1, n_classes=2)

In [129]:
input = torch.randint(
    5, (4, 1, 256, 256), dtype=torch.float32
)

predictions = model(input)

Shape Input Conv: torch.Size([4, 64, 256, 256])
Shape Encoder 1: torch.Size([4, 128, 128, 128])
Shape Encoder 2: torch.Size([4, 256, 64, 64])
Shape Encoder 3: torch.Size([4, 512, 32, 32])
Shape Encoder 4: torch.Size([4, 1024, 16, 16])
Shape Decoder 1: torch.Size([4, 512, 32, 32])
Shape Decoder 2: torch.Size([4, 256, 64, 64])
Shape Decoder 3: torch.Size([4, 128, 128, 128])
Shape Decoder 4: torch.Size([4, 64, 256, 256])
Shape Output decoder: torch.Size([4, 2, 256, 256])
