<a href="https://colab.research.google.com/github/Nanashi-bot/autoencoder/blob/main/unet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Initial data preprocessing


In [2]:
# Importing libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms

In [16]:
dim = 572

class UNet(nn.Module):
    def __init__(self, input_channels = 3, num_classes = 1):
        super(UNet, self).__init__()

        # Encoders
        self.encoder1 = self.encoder_block(input_channels, 64)
        self.encoder2 = self.encoder_block(64, 128)
        self.encoder3 = self.encoder_block(128, 256)
        self.encoder4 = self.encoder_block(256, 512)

        # Bottleneck
        self.bottleneck = nn.Sequential(
            nn.Conv2d(512, 1024, kernel_size=3, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv2d(1024, 1024, kernel_size=3, padding=0),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2),
        )

        # Decoders
        self.decoder1 = self.decoder_block(1024, 512, 256)
        self.decoder2 = self.decoder_block(512, 256, 128)
        self.decoder3 = self.decoder_block(256, 128, 64)

        # Output Layer
        self.output = self.output_block(128,64,32)


    def encoder_block(self, in_channels, num_filters):
        return nn.Sequential(
            nn.Conv2d(in_channels, num_filters, kernel_size=3, padding=0),  # Convolution with 3x3 filter
            nn.ReLU(inplace=True),                                          # ReLU activation
            nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=0),  # Convolution with 3x3 filter
            nn.ReLU(inplace=True),                                          # ReLU activation
            # nn.MaxPool2d(kernel_size=2, stride=2)                           # Max Pooling with 2x2 filter
            )

    def decoder_block(self, in_channels, num_filters, output_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, num_filters, kernel_size=3, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=0),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(num_filters, output_channels, kernel_size=2, stride=2)
        )

    def output_block(self, in_channels, num_filters, output_channels):
        return nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, out_channels=1, kernel_size=1),
            nn.Sigmoid()   # Use softmax for multiclass, sigmoid for binary
        )

    def forward(self, x):
        # print("Input shape:", x.shape)
        # Encoders
        s1 = self.encoder1(x)           # Gives first encoder output to be passed to d4
        # print("First encoding done, s1 shape:",s1.shape)
        s2 = self.encoder2(F.max_pool2d(s1, kernel_size=2, stride=2)) # Gives second encoder output to be passed to d3
        # print("Second encoding done, s2 shape:",s2.shape)
        s3 = self.encoder3(F.max_pool2d(s2, kernel_size=2, stride=2))  # Gives third encoder output to be passed to d2
        # print("Third encoding done, s3 shape:",s3.shape)
        s4 = self.encoder4(F.max_pool2d(s3, kernel_size=2, stride=2))  # Gives fourth encoder output to be passed to d1
        # print("Fourth encoding done, s4 shape:",s4.shape)
        # Bottleneck
        b1 = self.bottleneck(F.max_pool2d(s4, kernel_size=2, stride=2))
        # print("Bottleneck passed, shape:",b1.shape)


        # Decoders
        d1 = b1
        # print("First downsampling, output shape:",d1.shape)
        resize_transform1 = transforms.Resize((56, 56))
        s4 = resize_transform1(s4)
        # print("Shape of s4:",s4.shape)
        d1 = torch.cat([d1, s4], dim=1)
        # print("After concatenating",d1.shape)
        d2 = self.decoder1(d1)
        # print("First decoding done, d2 shape:",d2.shape)

        resize_transform2 = transforms.Resize((104, 104))
        s3 = resize_transform2(s3)
        # print("Shape of s3:",s3.shape)
        d2 = torch.cat([d2, s3], dim=1)
        # print("After concatenating, d2 shape:",d2.shape)
        d3 = self.decoder2(d2)
        # print("Second decoding done, output shape:",d3.shape)

        resize_transform3 = transforms.Resize((200, 200))
        s2 = resize_transform3(s2)
        # print("Shape of s2:",s2.shape)
        d3 = torch.cat([d3, s2], dim=1)
        # print("After concatenating, d3 shape:",d3.shape)
        d4 = self.decoder3(d3)
        # print("Third decoding done, d4 shape:",d4.shape)

        resize_transform4 = transforms.Resize((392, 392))
        s1 = resize_transform4(s1)
        # print("Shape of s1:",s1.shape)
        d4 = torch.cat([d4, s1], dim=1)
        print("After concatenating, d4 shape:",d4.shape)

        # Output
        outputs = self.output(d4)
        print("Output shape:",output.shape)
        return outputs


In [17]:
x = torch.randn(1, 3, 572, 572)
model = UNet()
output = model(x)

After concatenating, d4 shape: torch.Size([1, 128, 392, 392])
Output shape: torch.Size([1, 2, 388, 388])


In [None]:
# CODE FOLLOWING THIS IS FOR TEST PURPOSES

In [None]:
up_conv = nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=2, stride=2)

input_tensor = torch.randn(1, 1024, 28, 28)
normal_conv = nn.Conv2d(512, 512, kernel_size=3, padding=0)
output_tensor = up_conv(input_tensor)
print(output_tensor.shape)
v = normal_conv(output_tensor)
print(v.shape)

In [88]:
import torchvision.transforms as transforms
image = torch.randn(1, 3, 64, 64)
resize_transform = transforms.Resize((56, 56))
resized_image = resize_transform(image)

print(f"Resized shape: {resized_image.shape}")

Resized shape: torch.Size([1, 3, 56, 56])


In [9]:
input_tensor = torch.randn(1, 128, 392, 392)
conv1 = nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=0)
conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=0)

output1 = conv1(input_tensor)
print(f"Shape after first conv: {output1.shape}")

output2 = conv2(output1)
print(f"Shape after second conv: {output2.shape}")

Shape after first conv: torch.Size([1, 64, 390, 390])
Shape after second conv: torch.Size([1, 64, 388, 388])
