# In [3]:
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from collections import OrderedDict

# In [95]:
class Unet(nn.Module):
    def contracting_block(self, in_channels, out_channels, kernel_size=3):
        block = nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size)),
            ('relu1', nn.ReLU()),
            ('batch_norm1', nn.BatchNorm2d(out_channels)),
            ('conv2', nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size)),
            ('relu2', nn.ReLU()),
            ('batch_norm2', nn.BatchNorm2d(out_channels))
        ]))
        return block
     
    def expansion_block(self, in_channels, mid_channels, out_channels, kernel_size=3):
        block = nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(in_channels=in_channels, out_channels=mid_channels, kernel_size=kernel_size)),
            ('relu1', nn.ReLU()),
            ('batch_norm1', nn.BatchNorm2d(mid_channels)),
            ('conv2', nn.Conv2d(in_channels=mid_channels, out_channels=mid_channels, kernel_size=kernel_size)),
            ('relu2', nn.ReLU()),
            ('batch_norm2', nn.BatchNorm2d(mid_channels)),
            ('upconv', nn.ConvTranspose2d(in_channels=mid_channels, out_channels=out_channels, 
                                          kernel_size=3, stride=2, padding=1, output_padding=1))
        ]))
        return block
    
    def final_block(self, in_channels, mid_channels, out_channels, kernel_size=3):
        block = nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(in_channels=in_channels, out_channels=mid_channels, kernel_size=kernel_size)),
            ('relu1', nn.ReLU()),
            ('batch_norm1', nn.BatchNorm2d(mid_channels)),
            ('conv2', nn.Conv2d(in_channels=mid_channels, out_channels=mid_channels, kernel_size=kernel_size)),
            ('relu2', nn.ReLU()),
            ('batch_norm2', nn.BatchNorm2d(mid_channels)),
            ('final_conv', nn.Conv2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=kernel_size)),
            ('relu3', nn.ReLU()),
            ('batch_norm3', nn.BatchNorm2d(out_channels))
        ]))
    
    def __init__(self, input_channels, output_channels):
        self.max_pool = nn.MaxPool2d(kernel_size=2)
        self.conv_encode1 = self.contracting_block(input_channels, 64)
        self.conv_encode2 = self.contracting_block(64, 128)
        self.conv_encode3 = self.contracting_block(128, 256)
        self.conv_encode4 = self.contracting_block(256,512)
        self.bottle_neck = nn.Sequential(
            nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3),
            nn.ReLU(),
            nn.BatchNorm2d(1024),
            nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=3),
            nn.ReLU(),
            nn.BatchNorm2d(1024),
            nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=3, stride=2, padding=1, output_padding=1)
        )
        
        self.decode1 = self.expansion_block(512,256,128)
        self.decode2 = self.expansion_block(256,128,64)
        self.final_layer = self.final_block(128,64, output_channels)
        