Credits: [Symmetric Skip Connections paper](https://arxiv.org/pdf/1606.08921)

In [1]:
import torch.nn as nn
import torch.nn.functional as F

Let's make a random tensor and play around to see how the dimensions work out.
I'm using a random tensor because I'm running this locally right now (my Mac only has about 20 GB of space left); I'll work with the real dataset in Colab.

In [2]:
import torch

# Fixed seed so the random input and the conv weight initialisation are reproducible.
torch.manual_seed(0)

# Stand-in for one RGB image laid out as (channels, height, width) = (3, 224, 224).
x = torch.randn(3, 224, 224)
print("Original x shape, this is how the image is going to be, 3 channels with 224x224 : ", x.shape)
# We feed the conv layers (N, C, H, W), so add a leading batch dimension of size 1.
x = x.unsqueeze(0)
print("We have added an extra dimesion to suit required input dimensions: ", x.shape)

# Every encoder layer uses kernel f=3, stride s=2, padding p=1 (dilation stays at its default of 1).
# Output size per spatial dim: floor((n + 2p - f) / s) + 1, which for these settings is ceil(n / 2).
c1 = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1)
c2 = nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=1)
c3 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
c4 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1)
c5 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
c6 = nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1)

# Pass x through the six convolutional layers in order, printing the shape after each one.
encoder_stages = [("enc1", c1), ("enc2", c2), ("enc3", c3), ("enc4", c4), ("enc5", c5), ("enc6", c6)]
for stage_name, conv_layer in encoder_stages:
    x = conv_layer(x)
    print(f"{stage_name}: ", x.shape)

# 224 is not a multiple of 2**6, so the last layer rounds 7 up to 4 instead of halving exactly.
# A decoder that simply doubles six times will therefore end at 256x256, not 224x224.
assert x.shape == (1, 128, 4, 4), x.shape

Original x shape, this is how the image is going to be, 3 channels with 224x224 :  torch.Size([3, 224, 224])
We have added an extra dimesion to suit required input dimensions:  torch.Size([1, 3, 224, 224])
enc1:  torch.Size([1, 32, 112, 112])
enc2:  torch.Size([1, 32, 56, 56])
enc3:  torch.Size([1, 64, 28, 28])
enc4:  torch.Size([1, 64, 14, 14])
enc5:  torch.Size([1, 128, 7, 7])
enc6:  torch.Size([1, 128, 4, 4])


In [3]:
# Mirror of the encoder: six transposed convs with kernel f=3, stride s=2, padding p=1, output_padding=1.
# Output size per spatial dim: (n - 1)*s - 2p + (f - 1) + output_padding + 1, which here is exactly 2n.
d1 = nn.ConvTranspose2d(128, 128, kernel_size=3, stride=2, padding=1, output_padding=1)
d2 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)
d3 = nn.ConvTranspose2d(64, 64, kernel_size=3, stride=2, padding=1, output_padding=1)
d4 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1)
d5 = nn.ConvTranspose2d(32, 32, kernel_size=3, stride=2, padding=1, output_padding=1)
d6 = nn.ConvTranspose2d(32, 3, kernel_size=3, stride=2, padding=1, output_padding=1)

# Decode into a separate variable. `x` (the encoder output from the previous cell) is left
# untouched, so this cell can be re-run without feeding a 3-channel image back into d1.
y = x
decoder_stages = [
    ("dec1", d1),
    ("dec2", d2),
    ("dec3", d3),
    ("dec4", d4),
    ("dec5", d5),
    ("output (reconstructed image)", d6),
]
for stage_name, deconv_layer in decoder_stages:
    y = deconv_layer(y)
    print(f"{stage_name}: ", y.shape)

# NOTE: starting from the 4x4 bottleneck the decoder ends at 256x256, not the 224x224 input,
# because the encoder rounded 7 up to 4. Skip connections and a reconstruction loss both need
# the sizes reconciled (e.g. crop decoder maps, or use inputs whose size is a multiple of 64).


dec1:  torch.Size([1, 128, 8, 8])
dec2:  torch.Size([1, 64, 16, 16])
dec3:  torch.Size([1, 64, 32, 32])
dec4:  torch.Size([1, 32, 64, 64])
dec5:  torch.Size([1, 32, 128, 128])
output (reconstructed image):  torch.Size([1, 3, 256, 256])


In [None]:
class SkipAutoencoder(nn.Module):
    """Convolutional autoencoder with additive symmetric skip connections.

    Six stride-2 conv layers (``enc1``..``enc6``) downsample the input and six
    stride-2 transposed convs (``dec1``..``dec5``, ``out``) upsample it again.
    Two encoder feature maps are added onto the decoder maps that have the same
    channel count and depth:

    * skip-2: ``enc5`` output (128 ch) is added to the ``dec1`` output before ``dec2``.
    * skip-1: ``enc3`` output (64 ch) is added to the ``dec3`` output before ``dec4``.

    Each conv maps a spatial size ``n`` to ``ceil(n / 2)``, while each transposed
    conv maps ``n`` to ``2 * n``. For sizes that are not a multiple of 64 (e.g.
    224) the decoder would therefore overshoot the matching encoder map. Decoder
    maps are cropped at the bottom/right to the size of the tensor they are
    combined with, so the skips line up. The final output has the same height and
    width as the input.

    Expected input: a (N, 3, H, W) tensor. Output: same shape, values in [0, 1]
    (sigmoid).
    """

    def __init__(self):
        super().__init__()

        # Encoder: each stage halves H and W (rounding up).
        self.enc1 = self._down(3, 32)
        self.enc2 = self._down(32, 32)
        self.enc3 = self._down(32, 64)
        self.enc4 = self._down(64, 64)
        self.enc5 = self._down(64, 128)
        self.enc6 = self._down(128, 128)

        # Decoder (transpose conv): each stage doubles H and W.
        self.dec1 = self._up(128, 128)
        self.dec2 = self._up(128, 64)
        self.dec3 = self._up(64, 64)
        self.dec4 = self._up(64, 32)
        self.dec5 = self._up(32, 32)
        self.out = nn.Sequential(
            nn.ConvTranspose2d(32, 3, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid(),
        )

        # NOTE(review): not used in forward(). Kept so the module's parameters and
        # state_dict keys stay the same (existing checkpoints still load).
        self.bn = nn.BatchNorm2d(32)

    @staticmethod
    def _down(in_ch, out_ch):
        """Stride-2 conv + LeakyReLU: maps spatial size n to ceil(n / 2)."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(),
        )

    @staticmethod
    def _up(in_ch, out_ch):
        """Stride-2 transposed conv + LeakyReLU: maps spatial size n to 2 * n."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.LeakyReLU(),
        )

    @staticmethod
    def _crop_to(t, ref):
        """Crop the last two (spatial) dims of ``t`` to those of ``ref``.

        The decoder is never smaller than the encoder map it mirrors. Any surplus
        comes from the encoder rounding odd sizes up. The conv and the transposed
        conv both map index i <-> 2i, so the surplus sits at the bottom/right and
        keeping the top-left region preserves alignment.
        """
        height, width = ref.shape[-2:]
        return t[..., :height, :width]

    def forward(self, x):
        # Encoder
        x1 = self.enc1(x)   # 32 ch
        x2 = self.enc2(x1)  # 32 ch
        x3 = self.enc3(x2)  # 64 ch  -> skip-1
        x4 = self.enc4(x3)  # 64 ch
        x5 = self.enc5(x4)  # 128 ch -> skip-2
        x6 = self.enc6(x5)  # 128 ch bottleneck

        # Decoder with skips. Each skip is added to the decoder map with the same
        # channel count, after cropping it to the skip's spatial size.
        y = self._crop_to(self.dec1(x6), x5)  # 128 ch, same H/W as x5
        y = self.dec2(y + x5)                 # skip-2
        y = self._crop_to(self.dec3(y), x3)   # 64 ch, same H/W as x3
        y = self.dec4(y + x3)                 # skip-1
        y = self.dec5(y)
        # Final crop so the reconstruction matches the input's height and width.
        return self._crop_to(self.out(y), x)