In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
import matplotlib.pyplot as plt 
import torch.optim as optim

In [None]:
def double_conv(in_channels, out_channels):
    """Build a stack of two 3x3 convolutions, each followed by a ReLU.

    Both convolutions use ``padding=1``, so the spatial size of the input is
    preserved. The first convolution maps ``in_channels`` to ``out_channels``;
    the second keeps ``out_channels``.

    Args:
        in_channels: Number of channels of the tensor fed to the stack.
        out_channels: Number of channels produced by both convolutions.

    Returns:
        An ``nn.Sequential`` laid out as Conv2d, ReLU, Conv2d, ReLU.
    """
    stages = []
    # The first conv changes the channel count; the second one keeps it.
    for stage_in in (in_channels, out_channels):
        stages.append(
            nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1)
        )
        stages.append(nn.ReLU())
    return nn.Sequential(*stages)

class GDMapUNet(nn.Module):
    """U-Net style encoder/decoder built from ``double_conv`` stages.

    The encoder runs four ``double_conv`` stages (64, 128, 256 and 512 output
    channels) with a 2x2 max-pool between consecutive stages. The decoder
    repeatedly upsamples bilinearly, concatenates the matching encoder
    feature map along the channel dimension (a skip connection) and applies
    another ``double_conv`` stage. A final 1x1 convolution maps the 64
    decoder channels to ``n_classes`` output channels.

    Args:
        n_classes: Number of output channels of the final 1x1 convolution.
        in_channels: Number of channels of the input tensor. Defaults to 3,
            the value that used to be hard-coded.
    """

    def __init__(self, n_classes, in_channels=3):
        super().__init__()

        # Encoder stages.
        self.downconv1 = double_conv(in_channels, 64)
        self.downconv2 = double_conv(64, 128)
        self.downconv3 = double_conv(128, 256)
        self.downconv4 = double_conv(256, 512)

        self.maxpool = nn.MaxPool2d(kernel_size=2)
        # Kept as a module attribute; forward() reads its ``mode`` and
        # ``align_corners`` settings when resizing to the skip tensors.
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Decoder stages: input channels = skip channels + upsampled channels.
        self.upconv1 = double_conv(256 + 512, 256)
        self.upconv2 = double_conv(128 + 256, 128)
        self.upconv3 = double_conv(128 + 64, 64)
        self.upconvSMAP = nn.Conv2d(64, n_classes, 1)

    def _upsample_like(self, x, skip):
        """Bilinearly resize ``x`` to the spatial size of ``skip``.

        A fixed x2 upsample only matches the skip tensor when the pooled
        dimension was even; max-pooling floors odd sizes, so the target size
        is taken from ``skip`` itself to keep the concatenation valid.
        """
        return F.interpolate(
            x,
            size=skip.shape[-2:],
            mode=self.upsample.mode,
            align_corners=self.upsample.align_corners,
        )

    def forward(self, x):
        """Run the network on a batched image tensor.

        Args:
            x: Input tensor; channels are expected on dim 1 (the skip
                connections concatenate on ``dim=1``) and the two trailing
                dims are treated as the spatial dims.

        Returns:
            Tensor with ``n_classes`` channels on dim 1 and the same spatial
            size as ``x``.
        """
        # Encoder: keep each pre-pooling feature map for the skip connections.
        conv1 = self.downconv1(x)
        x = self.maxpool(conv1)

        conv2 = self.downconv2(x)
        x = self.maxpool(conv2)

        conv3 = self.downconv3(x)
        x = self.maxpool(conv3)

        x = self.downconv4(x)

        # Decoder: feed in information from the encoder using concat skip
        # connections, resizing to each skip tensor so odd sizes also work.
        x = self._upsample_like(x, conv3)
        x = torch.cat([x, conv3], dim=1)

        x = self.upconv1(x)
        x = self._upsample_like(x, conv2)
        x = torch.cat([x, conv2], dim=1)

        x = self.upconv2(x)
        x = self._upsample_like(x, conv1)
        x = torch.cat([x, conv1], dim=1)

        return self.upconvSMAP(self.upconv3(x))