# In [3]:
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from collections import OrderedDict

# In [93]:
class Unet(nn.Module):
    def contracting_block(self, in_channels, out_channels, kernel_size=3):
        block = nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size)),
            ('relu1', nn.ReLU()),
            ('batch_norm1', nn.BatchNorm2d(out_channels)),
            ('conv2', nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size)),
            ('relu2', nn.ReLU()),
            ('batch_norm2', nn.BatchNorm2d(out_channels))
        ]))
        return block
     
    def expansion_block(self, in_channels, mid_channels, out_channels, kernel_size=3):
        block = nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(in_channels=in_channels, out_channels=mid_channels, kernel_size=kernel_size)),
            ('relu1', nn.ReLU()),
            ('batch_norm1', nn.BatchNorm2d(mid_channels)),
            ('conv2', nn.Conv2d(in_channels=mid_channels, out_channels=mid_channels, kernel_size=kernel_size)),
            ('relu2', nn.ReLU()),
            ('batch_norm2', nn.BatchNorm2d(mid_channels)),
            ('upconv', nn.ConvTranspose2d(in_channels=mid_channels, out_channels=out_channels, 
                                          kernel_size=3, stride=2, padding=1, output_padding=1))
        ]))
        return block
    
    def final_block(self, in_channels, out_channels, kernel_size=3):
        block = nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(in_channels=in_channels, out_channels=mid_channels, kernel_size=kernel_size)),
            ('relu1', nn.ReLU()),
            ('batch_norm1', nn.BatchNorm2d(mid_channels)),
            ('conv2', nn.Conv2d(in_channels=mid_channels, out_channels=mid_channels, kernel_size=kernel_size)),
            ('relu2', nn.ReLU()),
            ('batch_norm2', nn.BatchNorm2d(mid_channels)),
            ('final_conv', nn.Conv2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=kernel_size)),
            ('relu3', nn.ReLU()),
            ('batch_norm3', nn.BatchNorm2d(out_channels))
        ]))