In [None]:
import torch
import torch.nn as nn

# Seed PyTorch's global RNG so the torch.randint inputs and the randomly
# initialised layers created below are reproducible across runs.
torch.manual_seed(42)

<torch._C.Generator at 0x7bf8145c97d0>

## **Transposed Convolution**

In [None]:
# Random single-channel 2x2 feature map (unbatched, channel-first) with
# integer values in [0, 5), stored as float32 so it can be fed to conv layers.
# NOTE(review): `input` shadows the Python builtin; the name is kept because
# every later cell reads it.
input = torch.randint(
    5, (1, 2, 2),
    dtype=torch.float32
)
input

tensor([[[2., 2.],
         [1., 4.]]])

**Example 1**

In [None]:
# Plain transposed convolution: 2x2 kernel, default stride (1) and padding (0),
# no bias. A 2x2 input grows to (2 - 1) * 1 + 2 = 3 pixels per side.
conv_layer = nn.ConvTranspose2d(
    in_channels=1,
    out_channels=1,
    kernel_size=2,
    bias=False
)

In [None]:
# Set every kernel weight to 1 so the layer's output can be verified by hand.
# nn.init.ones_ fills the existing parameter in place (without recording the
# write in autograd), which avoids reassigning the discouraged `.data` field.
nn.init.ones_(conv_layer.weight)
conv_layer.weight

Parameter containing:
tensor([[[[1., 1.],
          [1., 1.]]]], requires_grad=True)

In [None]:
# Each input value scatters a scaled copy of the all-ones 2x2 kernel into the
# output; overlapping contributions are summed (centre = 2 + 2 + 1 + 4 = 9).
output = conv_layer(input)
output

tensor([[[2., 4., 2.],
         [3., 9., 6.],
         [1., 5., 4.]]], grad_fn=<SqueezeBackward1>)

**Example 2**

In [None]:
# Same configuration as Example 1, but with the learnable bias term enabled.
conv_layer = nn.ConvTranspose2d(
    in_channels=1,
    out_channels=1,
    kernel_size=2,
    bias=True
)

In [None]:
# Set every kernel weight to 1 so the effect of the bias is easy to isolate.
# nn.init.ones_ fills the existing parameter in place (without recording the
# write in autograd), which avoids reassigning the discouraged `.data` field.
nn.init.ones_(conv_layer.weight)
conv_layer.weight

Parameter containing:
tensor([[[[1., 1.],
          [1., 1.]]]], requires_grad=True)

In [None]:
# Replace the randomly initialised bias with a fixed value of 1 so that it
# shows up as a uniform +1 on every output pixel.
conv_layer.bias = nn.Parameter(torch.tensor([1], dtype=torch.float32))
conv_layer.bias

Parameter containing:
tensor([1.], requires_grad=True)

In [None]:
# Same values as Example 1, with the bias (1) added to every output element.
output = conv_layer(input)
output

tensor([[[ 3.,  5.,  3.],
         [ 4., 10.,  7.],
         [ 2.,  6.,  5.]]], grad_fn=<SqueezeBackward1>)

**Example 3: Padding**

In [None]:
# For a transposed convolution, padding *removes* pixels from the border of the
# un-padded result. With padding=1 the 3x3 result of Example 1 shrinks to
# (2 - 1) * 1 - 2 * 1 + 2 = 1 pixel per side.
conv_layer = nn.ConvTranspose2d(
    in_channels=1,
    out_channels=1,
    kernel_size=2,
    padding=1,
    bias=False
)

In [None]:
# Set every kernel weight to 1 so the layer's output can be verified by hand.
# nn.init.ones_ fills the existing parameter in place (without recording the
# write in autograd), which avoids reassigning the discouraged `.data` field.
nn.init.ones_(conv_layer.weight)
conv_layer.weight

Parameter containing:
tensor([[[[1., 1.],
          [1., 1.]]]], requires_grad=True)

In [None]:
# Only the centre element (9) of Example 1's 3x3 output survives the crop.
output = conv_layer(input)
output

tensor([[[9.]]], grad_fn=<SqueezeBackward1>)

In [None]:
# Asymmetric padding given as (height, width): crop one row from the top and
# bottom but no columns, so the 3x3 un-padded result becomes 1x3.
conv_layer = nn.ConvTranspose2d(
    in_channels=1,
    out_channels=1,
    kernel_size=2,
    padding=(1, 0),
    bias=False
)

In [None]:
# Set every kernel weight to 1 so the layer's output can be verified by hand.
# nn.init.ones_ fills the existing parameter in place (without recording the
# write in autograd), which avoids reassigning the discouraged `.data` field.
nn.init.ones_(conv_layer.weight)
conv_layer.weight

Parameter containing:
tensor([[[[1., 1.],
          [1., 1.]]]], requires_grad=True)

In [None]:
# The result is the middle row of Example 1's 3x3 output: [3, 9, 6].
output = conv_layer(input)
output

tensor([[[3., 9., 6.]]], grad_fn=<SqueezeBackward1>)

**Example 4: Stride**

In [None]:
# stride=2 places consecutive kernel copies two pixels apart, so the 2x2
# kernels no longer overlap: output side = (2 - 1) * 2 + 2 = 4.
conv_layer = nn.ConvTranspose2d(
    in_channels=1,
    out_channels=1,
    kernel_size=2,
    stride=2,
    bias=False
)

In [None]:
# Set every kernel weight to 1 so the layer's output can be verified by hand.
# nn.init.ones_ fills the existing parameter in place (without recording the
# write in autograd), which avoids reassigning the discouraged `.data` field.
nn.init.ones_(conv_layer.weight)
conv_layer.weight

Parameter containing:
tensor([[[[1., 1.],
          [1., 1.]]]], requires_grad=True)

In [None]:
# Each input value is replicated over its own 2x2 patch; nothing overlaps, so
# no values are summed.
output = conv_layer(input)
output

tensor([[[2., 2., 2., 2.],
         [2., 2., 2., 2.],
         [1., 1., 4., 4.],
         [1., 1., 4., 4.]]], grad_fn=<SqueezeBackward1>)

**Example 5: Padding & Stride**

In [None]:
# Stride and padding combined: the 4x4 strided result of Example 4 loses a
# one-pixel border, giving (2 - 1) * 2 - 2 * 1 + 2 = 2 pixels per side.
conv_layer = nn.ConvTranspose2d(
    in_channels=1,
    out_channels=1,
    kernel_size=2,
    stride=2,
    padding=1,
    bias=False
)

In [None]:
# Set every kernel weight to 1 so the layer's output can be verified by hand.
# nn.init.ones_ fills the existing parameter in place (without recording the
# write in autograd), which avoids reassigning the discouraged `.data` field.
nn.init.ones_(conv_layer.weight)
conv_layer.weight

Parameter containing:
tensor([[[[1., 1.],
          [1., 1.]]]], requires_grad=True)

In [None]:
# The central 2x2 of Example 4's output, which here happens to equal the
# original input.
output = conv_layer(input)
output

tensor([[[2., 2.],
         [1., 4.]]], grad_fn=<SqueezeBackward1>)

**Example 6: Multi Channels**

In [None]:
# New random input with two channels (2x2 each), again integers in [0, 5)
# stored as float32. This overwrites the single-channel `input` used above.
input = torch.randint(
    5, (2, 2, 2),
    dtype=torch.float32
)
input

tensor([[[0., 1.],
         [3., 0.]],

        [[1., 1.],
         [2., 4.]]])

In [None]:
# Two input channels are combined into a single output channel; stride 2 keeps
# the 2x2 kernel placements from overlapping (4x4 output).
conv_layer = nn.ConvTranspose2d(
    in_channels=2,
    out_channels=1,
    kernel_size=2,
    stride=2,
    bias=False
)

In [None]:
# Set every kernel weight to 1 (one 2x2 kernel per input channel) so the
# layer's output can be verified by hand.
# nn.init.ones_ fills the existing parameter in place (without recording the
# write in autograd), which avoids reassigning the discouraged `.data` field.
nn.init.ones_(conv_layer.weight)
conv_layer.weight

Parameter containing:
tensor([[[[1., 1.],
          [1., 1.]]],


        [[[1., 1.],
          [1., 1.]]]], requires_grad=True)

In [None]:
# With all-ones kernels the two channels are summed per pixel and each sum
# fills its own 2x2 patch (e.g. top-left 0 + 1 = 1, bottom-left 3 + 2 = 5).
output = conv_layer(input)
output

tensor([[[1., 1., 2., 2.],
         [1., 1., 2., 2.],
         [5., 5., 4., 4.],
         [5., 5., 4., 4.]]], grad_fn=<SqueezeBackward1>)

## **U-Net**

In [None]:
class ConvBlock(nn.Module):
    """Two consecutive ``Conv2d -> BatchNorm2d -> ReLU`` stages.

    Both convolutions use a 3x3 kernel with ``padding=1``, so the spatial size
    of the input is preserved; only the channel count changes, from
    ``in_channels`` to ``out_channels`` in the first stage and staying at
    ``out_channels`` in the second. The convolutions are created without a
    bias term.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by both stages.
    """

    def __init__(self, in_channels, out_channels) -> None:
        super().__init__()
        stages = []
        stage_in = in_channels
        # Build the two stages in order so the resulting nn.Sequential keeps
        # the layout conv, bn, relu, conv, bn, relu (indices 0..5).
        for _ in range(2):
            stages.extend(
                [
                    nn.Conv2d(
                        stage_in, out_channels, kernel_size=3, padding=1, bias=False
                    ),
                    nn.BatchNorm2d(out_channels),
                    nn.ReLU(inplace=True),
                ]
            )
            # The second stage consumes what the first one produced.
            stage_in = out_channels
        self.conv_block = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both stages and return the resulting feature map."""
        return self.conv_block(x)

In [None]:
# Batch of one single-channel 256x256 image, used to check ConvBlock's output shape.
input = torch.randint(5, (1, 1, 256, 256), dtype=torch.float32)

In [None]:
# Input block that lifts a 1-channel image to 64 feature channels.
in_conv = ConvBlock(in_channels=1, out_channels=64)

In [None]:
# The 3x3 convolutions use padding=1, so height and width are unchanged and
# only the channel count grows (1 -> 64).
output = in_conv(input)
output.shape

torch.Size([1, 64, 256, 256])

In [None]:
# Down-Sample
class Encoder(nn.Module):
    """Down-sampling stage: a 2x2 max-pool followed by a ``ConvBlock``.

    The pooling halves the height and width of the input; the ``ConvBlock``
    then maps ``in_channels`` to ``out_channels`` at that reduced resolution.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels of the returned feature map.
    """

    def __init__(self, in_channels, out_channels) -> None:
        super().__init__()
        halve = nn.MaxPool2d(2)
        refine = ConvBlock(in_channels, out_channels)
        # Pool first, then convolve at the lower resolution.
        self.encoder = nn.Sequential(halve, refine)

    def forward(self, x):
        """Return the pooled and convolved feature map for ``x``."""
        return self.encoder(x)

In [None]:
# Down-sampling stage that maps 64 input channels to 128 output channels.
encoder = Encoder(in_channels=64, out_channels=128)

In [None]:
# Stand-in for a 64-channel 256x256 feature map (batch of one) to feed the encoder.
input = torch.randint(5, (1, 64, 256, 256), dtype=torch.float32)

In [None]:
# MaxPool2d(2) halves the spatial size (256 -> 128) and the ConvBlock doubles
# the channel count (64 -> 128).
output = encoder(input)
output.shape

torch.Size([1, 128, 128, 128])

In [None]:
# Up-Sample
class Decoder(nn.Module):
    """Up-sampling stage: transposed convolution, skip concatenation, ConvBlock.

    ``x1`` is up-sampled by a ``ConvTranspose2d`` (kernel 4, stride 2,
    padding 1), which doubles its height and width and reduces its channels
    from ``in_channels`` to ``out_channels``. The result is concatenated with
    the skip connection ``x2`` along the channel axis and passed through a
    ``ConvBlock(in_channels, out_channels)``.

    Because that ``ConvBlock`` expects ``in_channels`` input channels, ``x2``
    must carry ``in_channels - out_channels`` channels.

    Args:
        in_channels: Channels of ``x1`` and of the concatenated tensor.
        out_channels: Channels produced by the up-sampling and by the block.
    """

    def __init__(self, in_channels, out_channels) -> None:
        super().__init__()
        self.conv_trans = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size=4, stride=2, padding=1
        )
        self.conv_block = ConvBlock(in_channels, out_channels)

    def forward(self, x1, x2):
        """Up-sample ``x1``, merge it with the skip tensor ``x2`` and refine.

        Args:
            x1: Feature map to up-sample, channel-first with height and width
                as the last two dimensions.
            x2: Skip-connection feature map at the target resolution.

        Returns:
            The output of the ConvBlock applied to ``cat([x2, up(x1)])``; its
            height and width match those of ``x2``.
        """
        x1 = self.conv_trans(x1)

        # A max-pool on an odd-sized map floors the size, so doubling it back
        # can leave x1 one pixel short of the skip connection. Zero-pad (or
        # crop, for a negative difference) x1 so torch.cat does not fail.
        diff_h = x2.size(-2) - x1.size(-2)
        diff_w = x2.size(-1) - x1.size(-1)
        if diff_h != 0 or diff_w != 0:
            # F.pad takes (left, right, top, bottom) for the last two dims.
            x1 = nn.functional.pad(
                x1,
                [
                    diff_w // 2, diff_w - diff_w // 2,
                    diff_h // 2, diff_h - diff_h // 2,
                ],
            )

        x = torch.cat([x2, x1], dim=1)
        return self.conv_block(x)

In [None]:
# Up-sampling stage that takes a 128-channel map and produces 64 channels.
decoder = Decoder(in_channels=128, out_channels=64)

In [None]:
# x1: deeper feature map to be up-sampled (128 channels, 128x128).
# x2: skip connection at the target resolution (64 channels, 256x256).
x1 = torch.randint(5, (1, 128, 128, 128), dtype=torch.float32)
x2 = torch.randint(5, (1, 64, 256, 256), dtype=torch.float32)

In [None]:
# x1 is up-sampled to 256x256 with 64 channels, concatenated with x2
# (64 + 64 = 128 channels) and reduced to 64 channels by the ConvBlock.
output = decoder(x1, x2)
output.shape

torch.Size([1, 64, 256, 256])

In [None]:
class UNet(nn.Module):
    """U-Net built from ``ConvBlock``, ``Encoder`` and ``Decoder`` stages.

    Layout:
        * ``in_conv`` lifts the input to 64 channels.
        * Four encoders halve the resolution each time while widening the
          channels 64 -> 128 -> 256 -> 512 -> 1024.
        * Four decoders mirror the encoders (1024 -> 512 -> 256 -> 128 -> 64),
          each receiving the encoder output of matching width as its skip
          connection.
        * ``out_conv`` is a 1x1 convolution mapping 64 channels to
          ``n_classes`` output channels; no activation is applied afterwards.

    With four 2x down-samplings, inputs whose height and width are divisible
    by 16 come back at exactly the input resolution.

    Args:
        n_channels: Number of channels of the input image.
        n_classes: Number of output channels (one score map per class).
    """

    def __init__(self, n_channels, n_classes) -> None:
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes

        self.in_conv = ConvBlock(n_channels, 64)

        # Contracting path.
        self.enc_1 = Encoder(64, 128)
        self.enc_2 = Encoder(128, 256)
        self.enc_3 = Encoder(256, 512)
        self.enc_4 = Encoder(512, 1024)

        # Expanding path.
        self.dec_1 = Decoder(1024, 512)
        self.dec_2 = Decoder(512, 256)
        self.dec_3 = Decoder(256, 128)
        self.dec_4 = Decoder(128, 64)

        self.out_conv = nn.Conv2d(
            64, n_classes, kernel_size=1
        )

    def forward(self, x):
        """Return per-class score maps for the input batch ``x``."""
        x1 = self.in_conv(x)

        # Keep every encoder output: each one is reused as a skip connection.
        x2 = self.enc_1(x1)
        x3 = self.enc_2(x2)
        x4 = self.enc_3(x3)
        x5 = self.enc_4(x4)

        # Decode from the bottleneck, pairing each stage with the encoder
        # output of the same resolution.
        x = self.dec_1(x5, x4)
        x = self.dec_2(x, x3)
        x = self.dec_3(x, x2)
        x = self.dec_4(x, x1)

        x = self.out_conv(x)

        return x

In [None]:
# U-Net for single-channel images with two output classes.
model = UNet(n_channels=1, n_classes=2)

In [None]:
# Batch of four random single-channel 256x256 images (integers in [0, 5) as float32).
input = torch.randint(
    5, (4, 1, 256, 256), dtype=torch.float32
)

In [None]:
# Forward pass through the freshly initialised (untrained) model.
predictions = model(input)

In [None]:
# One score map per class (2) for each of the 4 images, at the input resolution.
predictions.shape

torch.Size([4, 2, 256, 256])

In [None]:
# Raw per-class scores for the first image; no activation follows out_conv and
# the model has not been trained, so the values carry no meaning yet.
predictions[0]

tensor([[[-0.0563,  0.0170, -0.3142,  ...,  0.0013, -0.1259,  0.0598],
         [-0.0853, -0.1676, -0.4449,  ..., -0.6244, -0.4742, -0.0824],
         [-0.2460, -0.4544, -0.3539,  ..., -0.2415, -0.4214, -0.4095],
         ...,
         [-0.3938, -0.9863, -0.3355,  ..., -0.8172, -0.6500,  0.0783],
         [-0.2712, -0.0243, -0.9204,  ..., -0.4072, -0.5009,  0.0499],
         [-0.0061,  0.0325, -0.5424,  ..., -0.1152, -0.4482, -0.0343]],

        [[-0.1440,  0.5064, -0.1361,  ..., -0.3488,  0.2831,  0.2494],
         [-0.5166,  0.2389, -0.0374,  ...,  0.2069,  0.2725, -0.2756],
         [-0.2379,  0.4510, -0.2720,  ...,  0.1055, -0.4505,  0.2620],
         ...,
         [ 0.0731, -0.0044,  0.4153,  ..., -0.1009,  0.0774, -0.0852],
         [ 0.2208,  0.3233, -0.0505,  ...,  0.0305,  0.5263,  0.1346],
         [ 0.0940, -0.1567, -0.1524,  ...,  0.1272, -0.4107,  0.0829]]],
       grad_fn=<SelectBackward0>)