In [None]:
# Colab-only: mount Google Drive at /content/drive so later cells can read and
# write files stored there. The import fails with ImportError in environments
# where the `google.colab` package is not available (i.e. outside Colab).
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Import libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.init as init
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import torchvision
from torchvision import transforms

import os
import numpy as np
import seaborn as sns
from PIL import Image
import matplotlib.pyplot as plt

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# Encoder / Contracting Path
class Encoder(nn.Module):
    """One contracting-path stage: two unpadded 3x3 conv + ReLU layers, then 2x2 max-pool.

    Because the convolutions use no padding, each one trims 2 pixels from the
    height and width, so the pre-pool feature map is 4 pixels smaller than the
    input in each spatial dimension. The pool then halves it (flooring).

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*layers)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """Return ``(skip, pooled)``.

        ``skip`` is the pre-pool feature map kept for the skip connection.
        ``pooled`` is the downsampled map fed to the next stage.
        """
        skip = self.conv(x)
        return skip, self.pool(skip)

In [None]:
# Bottleneck
class Bottleneck(nn.Module):
    """Bottom of the U: two unpadded 3x3 conv + ReLU layers, with no pooling.

    Each unpadded 3x3 convolution trims 2 pixels per spatial dimension, so the
    output is 4 pixels smaller than the input in height and width.

    Args:
        in_channels: channels of the incoming feature map (default 512, the
            value previously hard-coded).
        out_channels: channels produced by both convolutions (default 1024,
            the value previously hard-coded).
    """

    def __init__(self, in_channels=512, out_channels=1024):
        super(Bottleneck, self).__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Apply both convolutions and return the resulting feature map."""
        return self.conv(x)

In [None]:
# Decoder / Expanding Path
class Decoder(nn.Module):
    """One expanding-path stage: 2x2 transposed-conv upsampling, skip
    concatenation, then two unpadded 3x3 conv + ReLU layers.

    Args:
        in_channels: channels of the tensor coming from the stage below.
        out_channels: channels after upsampling. The skip tensor must also have
            this many channels, because the first 3x3 conv expects
            ``out_channels * 2`` input channels after concatenation.
    """

    def __init__(self, in_channels, out_channels):
        super(Decoder, self).__init__()

        self.upconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)

        self.conv = nn.Sequential(
            nn.Conv2d(out_channels * 2, out_channels, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3),
            nn.ReLU(inplace=True),
        )

    def forward(self, x, skip):
        """Upsample ``x``, center-crop ``skip`` to match it, concatenate, and convolve.

        Args:
            x: feature map from the stage below.
            skip: skip-connection feature map; it must be at least as large
                spatially as the upsampled ``x``.

        Returns:
            The output of the two 3x3 conv + ReLU layers.

        Raises:
            ValueError: if ``skip`` is spatially smaller than the upsampled ``x``.
        """
        x = self.upconv(x)
        h, w = x.shape[-2:]
        skip_h, skip_w = skip.shape[-2:]
        if skip_h < h or skip_w < w:
            raise ValueError(
                f"skip connection ({skip_h}x{skip_w}) is smaller than the "
                f"upsampled input ({h}x{w})"
            )
        # The contracting path uses unpadded convolutions, so the skip map is
        # larger than the upsampled tensor. Center-crop it to the same size,
        # slicing by the target size so odd differences are handled correctly.
        top = (skip_h - h) // 2
        left = (skip_w - w) // 2
        skip = skip[..., top:top + h, left:left + w]
        # Concatenate along the channel dimension (dim=1)
        x = torch.cat([skip, x], dim=1)
        return self.conv(x)

In [None]:
def center_crop(enc_feat, target_tensor):
    """Center-crop ``enc_feat`` to the spatial size of ``target_tensor``.

    Only the last two (height, width) dimensions are cropped; the first two
    dimensions of ``enc_feat`` are kept as-is. When the size difference is
    odd, the extra row/column is dropped from the bottom/right.

    Args:
        enc_feat: 4-D tensor to crop, indexed as (N, C, H, W).
        target_tensor: 4-D tensor whose last two dims give the target (h, w).

    Returns:
        A slice (view) of ``enc_feat`` with spatial size (h, w).

    Raises:
        ValueError: if the target is larger than ``enc_feat`` in either spatial dim.
    """
    _, _, h, w = target_tensor.shape
    enc_h, enc_w = enc_feat.shape[2], enc_feat.shape[3]

    if enc_h < h or enc_w < w:
        raise ValueError(
            f"cannot crop {enc_h}x{enc_w} feature map to larger size {h}x{w}"
        )

    top = (enc_h - h) // 2
    left = (enc_w - w) // 2
    # Slice by the target size rather than trimming delta // 2 from both ends:
    # the latter keeps one extra row/column whenever the difference is odd.
    return enc_feat[:, :, top:top + h, left:left + w]

In [None]:
class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()

        base_channels, depth = 64, 5
        # Widest encoder stage: _make_encoder_block doubles the width after each
        # stage, so the last of `depth` stages has base_channels * 2**(depth-1)
        # channels (1024 here).
        deepest = base_channels * 2 ** (depth - 1)

        self.encoder = self._make_encoder_block(in_channels=1, out_channels=base_channels, num_blocks=depth)
        # The expanding path starts from the widest encoder output. Its first
        # stage produces half that many channels, and there is one fewer
        # decoder stage than encoder stages. (This line previously used
        # undefined names, so constructing UNet raised NameError.)
        self.decoder = self._make_decoder_block(in_channels=deepest, out_channels=deepest // 2, num_blocks=depth - 1)
        # NOTE(review): Encoder.forward returns a (skip, pooled) tuple and
        # Decoder.forward needs a second `skip` argument, so neither list of
        # stages can be run by calling the nn.Sequential directly. forward must
        # iterate the stages itself (e.g. store them in nn.ModuleList).
    
    def _make_encoder_block(self, in_channels, out_channels, num_blocks):
        encoder_blocks = []

        for _ in range(num_blocks):
            encoder_blocks.append(Encoder(in_channels, out_channels))

            in_channels = out_channels
            out_channels = out_channels * 2
        
        return nn.Sequential(*encoder_blocks)
    
    def _make_decoder_block(self, in_channels, out_channels, num_blocks):
        encoder_blocks = []

        for _ in range(num_blocks):
            encoder_blocks.append(Decoder(in_channels, out_channels))

            in_channels = out_channels
            out_channels = out_channels * 2
        
        return nn.Sequential(*encoder_blocks)

    def forward(self, x):
        # Encoder
        x1 = self.encoder(x)
