In [3]:
import torch
from torch import nn, optim
from torch.autograd import Variable
import torch.nn.functional as F

In [6]:
# simple 4-layer FCN
class FullyConvNet_1(nn.Module):
    """Plain 4-layer fully convolutional network producing a per-pixel map.

    Every convolution uses stride 1 and "same" padding ((kernel_size - 1) // 2),
    so the spatial size is preserved:
        input  (N, in_channels,  H, W)
        output (N, out_channels, H, W)
    Channel progression: in_channels -> 100 -> 50 -> 10 -> out_channels.
    A ReLU follows every layer except the last; the final conv output is
    returned without any activation.

    Args:
        in_channels: number of input channels (default 200, the original value).
        out_channels: number of output channels (default 1).
    """

    def __init__(self, in_channels=200, out_channels=1):
        # Zero-argument super(): the previous super(FullyConvNet, self) named a
        # class that does not exist here and raised NameError on construction.
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 100, kernel_size=5, stride=1, padding=2)
        self.conv2 = nn.Conv2d(100, 50, kernel_size=7, stride=1, padding=3)
        self.conv3 = nn.Conv2d(50, 10, kernel_size=7, stride=1, padding=3)
        self.conv4 = nn.Conv2d(10, out_channels, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map (N, in_channels, H, W) -> (N, out_channels, H, W)."""
        out_layer1 = self.relu(self.conv1(x))
        out_layer2 = self.relu(self.conv2(out_layer1))
        out_layer3 = self.relu(self.conv3(out_layer2))
        # No activation on the output layer.
        return self.conv4(out_layer3)

In [7]:
# encoder-decoder FCN
class FullyConvNet_2(nn.Module):
    """Encoder-decoder fully convolutional network producing a per-pixel map.

    Encoder: a 1x1 conv (in_channels -> 16), then three conv + 2x2 max-pool
    stages (16 -> 32 -> 64 -> 128 channels), each halving H and W.
    Decoder: three stages of 2x bilinear upsampling followed by a conv
    (128 -> 64 -> 32 -> 16 channels), then a 1x1 conv to ``out_channels``.

    All convolutions use "same" padding, so the only size changes come from
    pooling (floor division by 2) and upsampling (x2):
        input  (N, in_channels,  H, W)
        output (N, out_channels, 8 * (H // 8), 8 * (W // 8))
    The output therefore matches the input size only when H and W are
    multiples of 8.

    Args:
        in_channels: number of input channels (default 200, the original value).
        out_channels: number of output channels (default 1).
    """

    def __init__(self, in_channels=200, out_channels=1):
        # Zero-argument super(): the previous super(FullyConvNet, self) named a
        # class that does not exist here and raised NameError on construction.
        super().__init__()
        # Encoder.
        self.conv1 = nn.Conv2d(in_channels, 16, kernel_size=1, stride=1, padding=0)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=7, stride=1, padding=3)
        self.conv4 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)

        # Decoder.
        self.conv5 = nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1)
        self.conv6 = nn.Conv2d(64, 32, kernel_size=7, stride=1, padding=3)
        self.conv7 = nn.Conv2d(32, 16, kernel_size=5, stride=1, padding=2)
        # A 1x1 conv needs padding=0 to keep the size; padding=1 would grow the
        # map by 2 px per axis with a border whose values come only from the bias.
        self.conv8 = nn.Conv2d(16, out_channels, kernel_size=1, stride=1, padding=0)
        self.relu = nn.ReLU()

    @staticmethod
    def _up2(x):
        """2x bilinear upsampling.

        Uses F.interpolate (F.upsample is deprecated). align_corners=False is
        the default F.upsample used, made explicit to silence the warning.
        """
        return F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)

    def forward(self, x):
        """Map (N, in_channels, H, W) -> (N, out_channels, 8*(H//8), 8*(W//8))."""
        out_layer1 = self.relu(self.conv1(x))

        # ReLU is monotone, so it commutes with max-pooling; applying it after
        # the pool gives the same result on 4x fewer elements.
        out_layer2 = self.relu(F.max_pool2d(self.conv2(out_layer1), 2))
        out_layer3 = self.relu(F.max_pool2d(self.conv3(out_layer2), 2))
        out_layer4 = self.relu(F.max_pool2d(self.conv4(out_layer3), 2))

        out_layer5 = self.relu(self.conv5(self._up2(out_layer4)))
        out_layer6 = self.relu(self.conv6(self._up2(out_layer5)))
        out_layer7 = self.relu(self.conv7(self._up2(out_layer6)))

        # No activation on the output layer.
        return self.conv8(out_layer7)