In [None]:
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import os
import math
import shutil
import random
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils


In [None]:
# Mount Google Drive into the Colab filesystem at /content/drive.
# force_remount=True re-mounts even if Drive is already mounted in this runtime.
from google.colab import drive

drive.mount(mountpoint='/content/drive', force_remount=True)

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

# TODO: the activation for the last layer is still undecided; currently no activation is applied after the final 1x1 conv.

class UNet(nn.Module):
    """Three-level U-Net for per-pixel prediction.

    Encoder: three single-conv stages (64, 128, 256 channels), each followed by
    2x2 max pooling, then a two-conv bottleneck with 512 channels.
    Decoder: at each level the features are upsampled 2x, projected by a 3x3
    conv that halves the channel count, concatenated with the encoder features
    of the same level (taken *before* pooling, so resolutions match), and fused
    by another 3x3 conv. A final 1x1 conv maps 64 channels to ``out_channel``;
    no activation is applied after it.

    Because skip connections are concatenated without cropping, the input
    height and width must both be divisible by 8 (three 2x poolings). The
    output then has the same spatial size as the input.

    Args:
        in_channel: number of channels of the input tensor.
        out_channel: number of channels produced by the final 1x1 conv.
    """

    def contracting_block(self, in_channels, out_channels, kernel_size=(3,3), padding=(1,1)):
        """Return one encoder stage: stride-1 conv followed by an in-place ReLU.

        With the default 3x3 kernel and padding 1 the spatial size is preserved.
        """
        return nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                      kernel_size=kernel_size, padding=padding, stride=1),
            nn.ReLU(inplace=True),
        )

    def bottle_neck(self, in_channels, out_channels, kernel_size=(3,3), padding=(1,1)):
        """Return the bottleneck: two stride-1 conv + in-place ReLU pairs."""
        return nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                      kernel_size=kernel_size, stride=1, padding=padding),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=out_channels, out_channels=out_channels,
                      kernel_size=kernel_size, stride=1, padding=padding),
            nn.ReLU(inplace=True),
        )

    def expansive_block(self, in_channels, out_channels, kernel_size=(3,3), padding=(1,1)):
        """Return one decoder fusion stage: stride-1 conv followed by an in-place ReLU.

        ``in_channels`` is the channel count *after* concatenating the
        upsampled features with the skip connection.
        """
        return nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                      kernel_size=kernel_size, stride=1, padding=padding),
            nn.ReLU(inplace=True),
        )

    def __init__(self, in_channel, out_channel):
        super().__init__()
        # Encode
        self.conv_encode1 = self.contracting_block(in_channels=in_channel, out_channels=64)
        self.conv_encode2 = self.contracting_block(64, 128)
        self.conv_encode3 = self.contracting_block(128, 256)
        self.max_pool = nn.MaxPool2d(2, stride=2)

        self.neck = self.bottle_neck(256, 512)

        # Upsampling path: 2x upsample, then a 3x3 conv that halves the channels
        # so the result can be concatenated with the matching encoder features.
        self.upsample = nn.Upsample(scale_factor=2)
        self.upconv1 = nn.Conv2d(in_channels=512, out_channels=256, kernel_size=(3,3), stride=1, padding=(1,1))
        self.upconv2 = nn.Conv2d(in_channels=256, out_channels=128, kernel_size=(3,3), stride=1, padding=(1,1))
        self.upconv3 = nn.Conv2d(in_channels=128, out_channels=64, kernel_size=(3,3), stride=1, padding=(1,1))

        # Decode: each input channel count is (upconv output + skip channels).
        self.conv_decode3 = self.expansive_block(512, 256)
        self.conv_decode2 = self.expansive_block(256, 128)
        self.conv_decode1 = self.expansive_block(128, 64)
        # No activation after the final 1x1 conv (the choice is still open).
        self.final_conv = nn.Conv2d(in_channels=64, out_channels=out_channel, kernel_size=(1,1), stride=1, padding=(0,0))

    def crop_and_concat(self, upsampled, bypass, crop=False):
        """Concatenate ``upsampled`` and ``bypass`` along the channel axis.

        Args:
            upsampled: decoder tensor, shape (N, C1, H, W).
            bypass: encoder skip tensor, shape (N, C2, Hb, Wb).
            crop: when True, centre-crop ``bypass`` to ``upsampled``'s H and W
                (or zero-pad it if it is smaller). Height and width are handled
                independently; an odd difference removes/adds the extra row or
                column at the bottom/right.

        Returns:
            Tensor of shape (N, C1 + C2, H, W). ``upsampled`` comes first.
        """
        if crop:
            dh = bypass.size(2) - upsampled.size(2)
            dw = bypass.size(3) - upsampled.size(3)
            # Negative padding crops and positive padding pads, so this covers both cases.
            # F.pad order for the last two dims is (left, right, top, bottom).
            bypass = F.pad(bypass, (-(dw // 2), -(dw - dw // 2),
                                    -(dh // 2), -(dh - dh // 2)))
        return torch.cat((upsampled, bypass), 1)

    def forward(self, x):
        """Run the network.

        Args:
            x: tensor of shape (N, in_channel, H, W), H and W divisible by 8.

        Returns:
            Tensor of shape (N, out_channel, H, W).
        """
        # Encode. The pre-pool feature maps are kept as skip connections so that
        # each one has the same resolution as the decoder level it feeds.
        skip1 = self.conv_encode1(x)                     # (N, 64,  H,   W)
        skip2 = self.conv_encode2(self.max_pool(skip1))  # (N, 128, H/2, W/2)
        skip3 = self.conv_encode3(self.max_pool(skip2))  # (N, 256, H/4, W/4)
        neck = self.neck(self.max_pool(skip3))           # (N, 512, H/8, W/8)

        # Decode
        up3 = self.upconv1(self.upsample(neck))          # (N, 256, H/4, W/4)
        decode_block3 = self.conv_decode3(self.crop_and_concat(up3, skip3, crop=False))  # (N, 256, H/4, W/4)
        up2 = self.upconv2(self.upsample(decode_block3)) # (N, 128, H/2, W/2)
        decode_block2 = self.conv_decode2(self.crop_and_concat(up2, skip2, crop=False))  # (N, 128, H/2, W/2)
        up1 = self.upconv3(self.upsample(decode_block2)) # (N, 64,  H,   W)
        decode_block1 = self.conv_decode1(self.crop_and_concat(up1, skip1, crop=False))  # (N, 64,  H,   W)
        return self.final_conv(decode_block1)            # (N, out_channel, H, W)
