In [13]:
import torch
import torch.nn as nn
from torchvision import transforms,models
from torchvision.models.resnet import ResNet, BasicBlock

In [14]:
# Input image size in pixels. NimbroNet18 sizes its location-dependent bias
# from these (at a quarter of this resolution), so network inputs are expected
# at exactly this size.
WIDTH, HEIGHT = 640, 480

In [15]:
def Conv1x1(in_planes, out_planes, stride=1):
    """Build a bias-free 1x1 convolution.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride (default 1).

    Returns:
        An ``nn.Conv2d`` with ``kernel_size=1`` and no bias term.
    """
    conv_options = dict(kernel_size=1, stride=stride, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_options)

In [16]:
def ConvTrans2x2(in_planes, out_planes):
    """Build a bias-free 2x2 transposed convolution with stride 2.

    With kernel 2 and stride 2 the output spatial size is exactly twice the
    input's, so this acts as a learned x2 upsampling layer.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.

    Returns:
        An ``nn.ConvTranspose2d`` module without a bias term.
    """
    upsample = nn.ConvTranspose2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=2,
        stride=2,
        bias=False,
    )
    return upsample

In [17]:
def ReLU_BN(in_planes):
    """Return a ``ReLU`` followed by ``BatchNorm2d(in_planes)`` as one module.

    Note the order: the activation is applied *before* batch normalisation.

    Args:
        in_planes: number of channels normalised by the BatchNorm layer.

    Returns:
        An ``nn.Sequential`` of ``[nn.ReLU(), nn.BatchNorm2d(in_planes)]``.
    """
    layers = [nn.ReLU(), nn.BatchNorm2d(in_planes)]
    return nn.Sequential(*layers)

In [18]:
class NimbroNet18(ResNet):
    """ResNet-18 encoder with a transposed-convolution decoder and two 1x1 heads.

    The torchvision ResNet-18 backbone (``conv1`` ... ``layer4``) is frozen. Only
    the lateral 1x1 convolutions, the decoder and the two output heads created
    after freezing are trainable. ``forward`` returns ``(seg, blobs)``, each with
    3 channels. Both heads add the same learned location-dependent bias
    ``loc_dep_bias`` of shape ``(1, 3, HEIGHT // 4, WIDTH // 4)``, so the head
    outputs must have that spatial size and inputs are expected to be
    ``HEIGHT x WIDTH``.

    Args:
        pretrained: if True (default), download and load the ImageNet ResNet-18
            weights before freezing the backbone. Pass False to skip the
            download, e.g. for offline use or tests. The backbone is frozen
            either way.
    """

    # ImageNet weights for torchvision's ResNet-18 layout.
    resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'

    def __init__(self, pretrained=True):
        super().__init__(BasicBlock, [2, 2, 2, 2])
        if pretrained:
            # torch.hub is the canonical home of this helper; the
            # ``torchvision.models.utils`` alias is not available in newer
            # torchvision releases.
            state_dict = torch.hub.load_state_dict_from_url(self.resnet18_url,
                                                            progress=True)
            self.load_state_dict(state_dict)
        # The classification head is not used by forward().
        del self.avgpool
        del self.fc
        # Freeze every backbone parameter. This runs before the modules below
        # are created, so only those new modules remain trainable.
        # NOTE(review): requires_grad=False does not stop the backbone's
        # BatchNorm layers from updating their running statistics in train().
        for param in self.parameters():
            param.requires_grad = False

        # Lateral 1x1 convolutions applied to the layer1..layer3 skip outputs.
        self.conv_1_1x1 = Conv1x1(64, 128)
        self.conv_2_1x1 = Conv1x1(128, 256)
        self.conv_3_1x1 = Conv1x1(256, 256)

        # Decoder: each ConvTrans2x2 doubles the spatial size; its output is
        # then concatenated with the matching skip connection.
        self.relu1 = nn.ReLU()
        self.conv_trans1 = ConvTrans2x2(512, 256)

        self.relu_bn1 = ReLU_BN(512)
        self.conv_trans2 = ConvTrans2x2(512, 256)

        self.relu_bn2 = ReLU_BN(512)
        self.conv_trans3 = ConvTrans2x2(512, 128)

        self.relu_bn3 = ReLU_BN(256)

        # Output heads, 3 channels each.
        self.conv_4_1x1 = Conv1x1(256, 3)  # -> seg
        self.conv_5_1x1 = Conv1x1(256, 3)  # -> blobs

        # Learned bias added to both heads. Its fixed spatial size ties the
        # accepted input resolution to HEIGHT x WIDTH.
        # (nn.Parameter already sets requires_grad=True.)
        self.loc_dep_bias = nn.Parameter(torch.randn((1, 3, HEIGHT // 4, WIDTH // 4)))

    def forward(self, x):
        """Run the encoder-decoder and both output heads.

        Args:
            x: input image batch. Assumed (N, 3, HEIGHT, WIDTH) — TODO confirm
                with the data pipeline. The head outputs must broadcast against
                ``loc_dep_bias`` of shape ``(1, 3, HEIGHT // 4, WIDTH // 4)``.

        Returns:
            ``(seg, blobs)``: outputs of ``conv_4_1x1`` and ``conv_5_1x1``, each
            with 3 channels and ``loc_dep_bias`` added.
        """
        # Backbone stem.
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        # Encoder stages; x1..x3 are kept as skip connections.
        x1 = self.layer1(x)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)
        x4 = self.layer4(x3)

        x1 = self.conv_1_1x1(x1)
        x2 = self.conv_2_1x1(x2)
        x3 = self.conv_3_1x1(x3)

        # Decoder: upsample, then concatenate with the matching skip along channels.
        x4 = self.conv_trans1(self.relu1(x4))
        x4_x3 = torch.cat((x3, x4), 1)

        x4_x3 = self.conv_trans2(self.relu_bn1(x4_x3))
        x4_x3_x2 = torch.cat((x4_x3, x2), 1)

        x4_x3_x2 = self.conv_trans3(self.relu_bn2(x4_x3_x2))
        x4_x3_x2_x1 = torch.cat((x4_x3_x2, x1), 1)

        features = self.relu_bn3(x4_x3_x2_x1)

        # loc_dep_bias is a registered Parameter, so model.to(device) keeps it on
        # the same device as `features`. No manual .cuda() is needed; forcing
        # one breaks a CPU-resident model on a CUDA-capable machine.
        seg = self.conv_4_1x1(features) + self.loc_dep_bias
        blobs = self.conv_5_1x1(features) + self.loc_dep_bias
        return seg, blobs
        