## Transposed convolution

In [1]:
import torch
import torch.nn as nn

<img src="images/image16.png" height="150"/>

In [10]:
# Worked example of a transposed convolution: a 1x1x2x2 input and a 2x2 kernel
# with stride 1 produce a 1x1x3x3 output. Each input element scales the whole
# kernel, and the scaled copies are summed at the matching output offsets.
# NOTE: `input` shadows the Python builtin; the name is kept as-is because
# other cells may refer to it.
input = torch.tensor([[[[0, 1], [2, 3]]]], dtype=torch.float32)

conv = nn.ConvTranspose2d(
    in_channels=1,
    out_channels=1,
    kernel_size=2,
    stride=1,
    bias=False
)

# Replace the randomly initialised kernel with known values. Copying in place
# under no_grad keeps `conv.weight` the same Parameter and fails loudly if the
# new values do not fit its shape, whereas assigning through `.data` bypasses
# autograd's bookkeeping and accepts a tensor of any shape.
with torch.no_grad():
    conv.weight.copy_(torch.tensor([[[[0, 1], [2, 3]]]], dtype=torch.float32))

# Run the forward pass outside no_grad so the result still carries a grad_fn.
output = conv(input)

print(output)


tensor([[[[ 0.,  0.,  1.],
          [ 0.,  4.,  6.],
          [ 4., 12.,  9.]]]], grad_fn=<ConvolutionBackward0>)


## Pixel Shuffle

In [11]:
# Four 2x2 channels holding the values 1..16 in order (shape 1x4x2x2).
# PixelShuffle with upscale_factor=2 rearranges them into a single 4x4 channel.
input = torch.arange(1, 17, dtype=torch.float32).reshape(1, 4, 2, 2)

pixel_shuffle = nn.PixelShuffle(upscale_factor=2)
output = pixel_shuffle(input)

# One print call with a newline separator emits each label and tensor on its
# own line.
print("Entrada:", input, "Salida:", output, sep="\n")

Entrada:
tensor([[[[ 1.,  2.],
          [ 3.,  4.]],

         [[ 5.,  6.],
          [ 7.,  8.]],

         [[ 9., 10.],
          [11., 12.]],

         [[13., 14.],
          [15., 16.]]]])
Salida:
tensor([[[[ 1.,  5.,  2.,  6.],
          [ 9., 13., 10., 14.],
          [ 3.,  7.,  4.,  8.],
          [11., 15., 12., 16.]]]])


## U-Net

<img src="images/image17.png" height="600"/>

In [12]:
def DoubleConv(in_chan, out_chan):
    """Return two stacked Conv2d -> BatchNorm2d -> SiLU stages as one nn.Sequential.

    Both convolutions use a 3x3 kernel with padding 1. The first maps
    ``in_chan`` to ``out_chan`` channels; the second maps ``out_chan`` to
    ``out_chan``.

    Args:
        in_chan: number of input channels of the first convolution.
        out_chan: number of output channels of both convolutions.

    Returns:
        An ``nn.Sequential`` with six modules (indices 0-5).
    """
    stages = []
    stage_in = in_chan
    for _ in range(2):
        stages.append(nn.Conv2d(stage_in, out_chan, kernel_size=3, padding=1))
        stages.append(nn.BatchNorm2d(out_chan))
        stages.append(nn.SiLU(inplace=True))
        # Only the first stage changes the channel count.
        stage_in = out_chan
    return nn.Sequential(*stages)

In [13]:
class Unet(nn.Module):
    """U-Net with four encoder stages, a bottleneck and four decoder stages.

    Each encoder stage is a ``DoubleConv`` followed by a 2x2 max-pool with
    stride 2. Each decoder stage is a stride-2 transposed convolution, a
    channel-wise concatenation with the matching encoder feature map (skip
    connection) and a ``DoubleConv``. A final 1x1 convolution maps the 64
    decoder channels to ``n_classes`` channels.

    Args:
        n_chan: number of channels of the input fed to the first encoder stage.
        n_classes: number of channels produced by the output layer.
    """

    def __init__(self, n_chan, n_classes):
        super().__init__()

        # Encoder: the channel count doubles at every stage (64 -> 512).
        self.enc1  = DoubleConv(n_chan,64)
        self.pool1 = nn.MaxPool2d(kernel_size=2,stride=2)

        self.enc2  = DoubleConv(64,128)
        self.pool2 = nn.MaxPool2d(kernel_size=2,stride=2)

        self.enc3  = DoubleConv(128,256)
        self.pool3 = nn.MaxPool2d(kernel_size=2,stride=2)

        self.enc4  = DoubleConv(256,512)
        self.pool4 = nn.MaxPool2d(kernel_size=2,stride=2)

        self.bottleneck = DoubleConv(512,1024)

        # Decoder: each `up` halves the channels; the following `dec` takes
        # twice that many because the skip connection is concatenated first.
        self.up4  = nn.ConvTranspose2d(1024,512,kernel_size=2,stride=2)
        self.dec4 = DoubleConv(1024,512)

        self.up3  = nn.ConvTranspose2d(512,256,kernel_size=2,stride=2)
        self.dec3 = DoubleConv(512,256)

        self.up2  = nn.ConvTranspose2d(256,128,kernel_size=2,stride=2)
        self.dec2 = DoubleConv(256,128)

        self.up1  = nn.ConvTranspose2d(128,64,kernel_size=2,stride=2)
        self.dec1 = DoubleConv(128,64)

        self.OutLayer = nn.Conv2d(64,n_classes,kernel_size=1)

    @staticmethod
    def _merge_skip(upsampled, skip):
        """Concatenate ``upsampled`` and ``skip`` along the channel dimension.

        The max-pools floor odd sizes while the transposed convolutions only
        double them, so ``upsampled`` can be smaller than ``skip`` in the last
        two dimensions. In that case it is zero-padded (split as evenly as
        possible between both sides) to the size of ``skip``. When the sizes
        already agree the tensors are concatenated untouched.
        """
        diff_h = skip.size(2) - upsampled.size(2)
        diff_w = skip.size(3) - upsampled.size(3)
        if diff_h != 0 or diff_w != 0:
            # F.pad takes the padding for the last dimension first:
            # [left, right, top, bottom].
            upsampled = nn.functional.pad(
                upsampled,
                [diff_w // 2, diff_w - diff_w // 2,
                 diff_h // 2, diff_h - diff_h // 2],
            )
        return torch.cat([upsampled, skip], dim=1)

    def forward(self, x):
        """Run the encoder, bottleneck and decoder and return the output map.

        Args:
            x: 4-D tensor with ``n_chan`` channels in dimension 1. The last two
                dimensions need not be multiples of 16; they should be at
                least 16 so that the four pooling steps leave a non-empty map.

        Returns:
            Tensor with ``n_classes`` channels and the same last two
            dimensions as ``x``.
        """
        # down: keep every pre-pool feature map for the skip connections
        skip1 = self.enc1(x)
        skip2 = self.enc2(self.pool1(skip1))
        skip3 = self.enc3(self.pool2(skip2))
        skip4 = self.enc4(self.pool3(skip3))

        y = self.bottleneck(self.pool4(skip4))

        # up: upsample, align with the skip map, concatenate, then convolve
        y = self.dec4(self._merge_skip(self.up4(y), skip4))
        y = self.dec3(self._merge_skip(self.up3(y), skip3))
        y = self.dec2(self._merge_skip(self.up2(y), skip2))
        y = self.dec1(self._merge_skip(self.up1(y), skip1))

        return self.OutLayer(y)

In [14]:
# Instantiate the network for 3 input channels and 10 output classes, then
# print its module tree.
model = Unet(n_chan=3, n_classes=10)
print(model)

Unet(
  (enc1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): SiLU(inplace=True)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): SiLU(inplace=True)
  )
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (enc2): Sequential(
    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): SiLU(inplace=True)
    (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): SiLU(inplace=True)
  )
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False