In [None]:
# Colab-only setup: mount Google Drive at /content/drive.
# NOTE(review): `google.colab` is presumably only importable inside a Colab
# runtime, so this cell is expected to fail elsewhere — confirm before running
# the notebook locally.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Import libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.init as init
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import torchvision
from torchvision import transforms

import os
import numpy as np
import seaborn as sns
from PIL import Image
import matplotlib.pyplot as plt

In [None]:
# Pick the compute device once: use CUDA when PyTorch reports it as available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# Encoder / Contracting Path
class Encoder(nn.Module):
    """One contracting-path stage: two 3x3 conv + ReLU layers, then 2x2 max-pooling.

    The convolutions are created without padding, so each one trims one pixel
    from every border of the feature map before pooling.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Layer order fixes the Sequential indices, and with them the
        # parameter names (conv.0.*, conv.2.*) used in checkpoints.
        double_conv = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*double_conv)

        # Halves height and width for the next stage.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """Run the stage on ``x``.

        Returns:
            A ``(feat, out)`` tuple: ``feat`` is the pre-pooling feature map
            (kept for the skip connection) and ``out`` is its max-pooled
            version (the input to the next stage).
        """
        skip = self.conv(x)
        return skip, self.pool(skip)

In [None]:
# Bottleneck
class Bottleneck(nn.Module):
    """Deepest stage of the network: two unpadded 3x3 conv + ReLU layers.

    The channel counts used to be hard-coded; they are now parameters whose
    defaults reproduce the previous behaviour, so ``Bottleneck()`` is unchanged.

    Args:
        in_channels: Channels of the incoming feature map (default 512).
        out_channels: Channels produced by both convolutions (default 1024).
    """

    def __init__(self, in_channels=512, out_channels=1024):
        super(Bottleneck, self).__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Apply the double convolution; no pooling and no skip output here."""
        return self.conv(x)

In [None]:
def center_crop(enc_feat, target_tensor):
    """Crop ``enc_feat`` spatially, around its centre, to the size of ``target_tensor``.

    Only the last two axes (height, width) are cropped; the first two axes of
    ``enc_feat`` are returned untouched.

    Args:
        enc_feat: 4-D tensor to crop.
        target_tensor: 4-D tensor whose height and width define the crop size.

    Returns:
        A view of ``enc_feat`` whose height and width equal those of
        ``target_tensor``.

    Raises:
        ValueError: If ``enc_feat`` is smaller than ``target_tensor`` in
            height or width, since a crop cannot enlarge a tensor.
    """
    _, _, h, w = target_tensor.shape
    enc_h, enc_w = enc_feat.shape[2], enc_feat.shape[3]

    if enc_h < h or enc_w < w:
        raise ValueError(
            f"cannot crop a {enc_h}x{enc_w} feature map to the larger size {h}x{w}"
        )

    top = (enc_h - h) // 2
    left = (enc_w - w) // 2

    # Derive the end index from the target size rather than mirroring the
    # offset: with an odd size difference, `enc - delta // 2` ends one pixel
    # too late and the result no longer matches the target.
    return enc_feat[:, :, top : top + h, left : left + w]

In [None]:
# Decoder / Expanding Path
class Decoder(nn.Module):
    """One expanding-path stage: 2x2 transposed conv, skip concat, double 3x3 conv.

    Args:
        in_channels: Channels of the feature map coming from the stage below.
        out_channels: Channels produced by the up-convolution and by both
            3x3 convolutions. The skip feature map is expected to carry
            ``out_channels`` channels too, because the first convolution is
            built for ``out_channels * 2`` input channels.
    """

    def __init__(self, in_channels, out_channels):
        super(Decoder, self).__init__()

        # Doubles height and width while reducing the channel count.
        self.upconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)

        # `inplace` is a boolean flag; the previous `inplace=2` only worked
        # because 2 is truthy.
        self.conv = nn.Sequential(
            nn.Conv2d(out_channels * 2, out_channels, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3),
            nn.ReLU(inplace=True),
        )

    def forward(self, x, enc_feat):
        """Upsample ``x``, merge it with the cropped skip features and convolve.

        Args:
            x: Feature map from the previous (deeper) stage.
            enc_feat: Skip-connection feature map from the matching encoder
                stage; it is centre-cropped to the spatial size of the
                upsampled ``x`` before concatenation.

        Returns:
            The feature map produced by the double convolution.
        """
        x = self.upconv(x)
        enc_feat = center_crop(enc_feat, x)
        # Skip features come first along the channel axis, then the upsampled ones.
        x = torch.cat([enc_feat, x], dim=1)
        return self.conv(x)

In [None]:
class UNet(nn.Module):
    """U-Net built from four encoder stages, a bottleneck and four decoder stages.

    Args:
        in_channels: Channels of the input image (default 1).
        num_classes: Channels of the map produced by the final 1x1
            convolution (default 2).
    """

    def __init__(self, in_channels=1, num_classes=2):
        super().__init__()

        # Contracting path: the channel count doubles at every stage.
        self.enc1 = Encoder(in_channels, 64)
        self.enc2 = Encoder(64, 128)
        self.enc3 = Encoder(128, 256)
        self.enc4 = Encoder(256, 512)

        # Deepest stage, between the two paths.
        self.bottleneck = Bottleneck()

        # Expanding path: the channel count halves at every stage.
        self.dec4 = Decoder(1024, 512)
        self.dec3 = Decoder(512, 256)
        self.dec2 = Decoder(256, 128)
        self.dec1 = Decoder(128, 64)

        # 1x1 convolution mapping the 64 feature channels to per-class outputs.
        self.final = nn.Conv2d(64, num_classes, kernel_size=1)

    def forward(self, x):
        """Run the full network and return the output of the final 1x1 convolution.

        No activation is applied after the final convolution.
        """
        # Each encoder yields (skip features, pooled features).
        skip1, down1 = self.enc1(x)
        skip2, down2 = self.enc2(down1)
        skip3, down3 = self.enc3(down2)
        skip4, down4 = self.enc4(down3)

        bridge = self.bottleneck(down4)

        # Walk back up, pairing each decoder with the matching encoder's skip.
        up4 = self.dec4(bridge, skip4)
        up3 = self.dec3(up4, skip3)
        up2 = self.dec2(up3, skip2)
        up1 = self.dec1(up2, skip1)

        return self.final(up1)
