# In [3]:  (Jupyter cell marker left over from notebook export)
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import models

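# Note: the trailing underscore in `in_` only avoids clashing with the Python
# keyword `in`. (In PyTorch, it is *methods* such as `relu_()` whose trailing
# underscore marks an in-place operation.)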
def conv3x3(in_, out):
    """Build a 3x3 ``nn.Conv2d`` from ``in_`` to ``out`` channels.

    A padding of 1 with the default stride of 1 keeps height and width
    unchanged.
    """
    kernel, pad = 3, 1
    return nn.Conv2d(in_channels=in_, out_channels=out, kernel_size=kernel, padding=pad)

def concat(xs):
    """Join the tensors in ``xs`` (e.g. ``(x1, x2)``) along dimension 1.

    For the NCHW feature maps used by the convolutions in this file, that is
    the channel axis, so all other dimensions must already agree.
    """
    channel_dim = 1
    return torch.cat(xs, dim=channel_dim)
      
class ConvRelu(nn.Module):
    """A ``conv3x3`` convolution followed by an in-place ReLU.

    Maps ``in_`` channels to ``out`` channels; with ``conv3x3``'s padding of 1
    the spatial size is unchanged.
    """

    def __init__(self, in_: int, out: int):
        super().__init__()
        self.conv = conv3x3(in_, out)
        self.activation = nn.ReLU(inplace=True)

    # forward must be a method of the class (not nested inside __init__),
    # otherwise nn.Module falls back to its default forward, which raises
    # NotImplementedError.
    def forward(self, x):
        """Return ``relu(conv(x))``."""
        x = self.conv(x)
        x = self.activation(x)
        return x
        

class UNet11(nn.Module):
    """U-Net whose encoder is torchvision's VGG11 feature extractor.

    The encoder convolutions are taken by index from ``vgg11().features``.
    Each decoder stage (``DecoderBlock``) upsamples by 2 and is fed the
    previous decoder output concatenated with the encoder activation of the
    same resolution. A final 1x1 convolution and a sigmoid produce
    per-element probabilities.

    Args:
        num_classes: number of output channels.
        num_filters: base width of the decoder. The encoder (skip) widths are
            fixed by VGG11, so decoder input widths are derived from the
            encoder convs. With the default of 32 every layer has exactly the
            shape of the original hard-coded formulation.
        pretrained: forwarded to ``models.vgg11``. Pass ``False`` to build
            the network without downloading ImageNet weights.

    The input goes through five 2x max-poolings and five 2x upsamplings, so
    height and width must be divisible by 32 or the skip concatenations fail.
    """

    def __init__(self, num_classes=1, num_filters=32, pretrained=True):
        super().__init__()

        self.pool = nn.MaxPool2d(2, 2)
        self.encoder = models.vgg11(pretrained=pretrained).features
        # Reuse VGG11's layers by index; the single ReLU instance is shared by
        # every encoder stage in forward().
        self.relu = self.encoder[1]
        self.conv1 = self.encoder[0]
        self.conv2 = self.encoder[3]
        self.conv3s = self.encoder[6]
        self.conv3 = self.encoder[8]
        self.conv4s = self.encoder[11]
        self.conv4 = self.encoder[13]
        self.conv5s = self.encoder[16]
        self.conv5 = self.encoder[18]

        # Skip-connection widths come from the encoder itself (64/128/256/512/
        # 512 for VGG11) rather than from num_filters, which previously made
        # any num_filters != 32 fail with channel mismatches.
        skip1 = self.conv1.out_channels
        skip2 = self.conv2.out_channels
        skip3 = self.conv3.out_channels
        skip4 = self.conv4.out_channels
        skip5 = self.conv5.out_channels
        nf = num_filters

        self.center = DecoderBlock(skip5, nf * 8 * 2, nf * 8)
        self.dec5 = DecoderBlock(nf * 8 + skip5, nf * 8 * 2, nf * 8)
        self.dec4 = DecoderBlock(nf * 8 + skip4, nf * 8 * 2, nf * 4)
        self.dec3 = DecoderBlock(nf * 4 + skip3, nf * 4 * 2, nf * 2)
        self.dec2 = DecoderBlock(nf * 2 + skip2, nf * 2 * 2, nf)
        self.dec1 = ConvRelu(nf + skip1, nf)

        self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)

    def forward(self, x):
        """Return sigmoid probabilities of shape (N, num_classes, H, W).

        ``x`` is presumably an (N, 3, H, W) image batch (VGG11's first conv
        expects 3 channels) with H and W divisible by 32 -- verify with caller.
        """
        conv1 = self.relu(self.conv1(x))
        conv2 = self.relu(self.conv2(self.pool(conv1)))
        conv3s = self.relu(self.conv3s(self.pool(conv2)))
        conv3 = self.relu(self.conv3(conv3s))
        conv4s = self.relu(self.conv4s(self.pool(conv3)))
        conv4 = self.relu(self.conv4(conv4s))
        conv5s = self.relu(self.conv5s(self.pool(conv4)))
        conv5 = self.relu(self.conv5(conv5s))

        center = self.center(self.pool(conv5))

        # Each decoder stage sees [upsampled decoder features, encoder skip].
        dec5 = self.dec5(torch.cat([center, conv5], dim=1))
        dec4 = self.dec4(torch.cat([dec5, conv4], dim=1))
        dec3 = self.dec3(torch.cat([dec4, conv3], dim=1))
        dec2 = self.dec2(torch.cat([dec3, conv2], dim=1))
        dec1 = self.dec1(torch.cat([dec2, conv1], dim=1))
        # torch.sigmoid instead of the deprecated F.sigmoid (same result).
        return torch.sigmoid(self.final(dec1))
    
class DecoderBlock(nn.Module):
    """One decoder stage: ``ConvRelu``, then a 2x transposed-conv upsample + ReLU.

    ``ConvRelu`` maps ``in_channels`` to ``middle_channels`` at the input
    resolution. The transposed convolution (kernel 3, stride 2, padding 1,
    output_padding 1) maps to ``out_channels`` and doubles height and width.
    """

    def __init__(self, in_channels, middle_channels, out_channels):
        super().__init__()
        # Layers are created in the same order as they run, so parameter
        # initialisation consumes the RNG in a fixed order.
        squeeze = ConvRelu(in_channels, middle_channels)
        upsample = nn.ConvTranspose2d(
            middle_channels,
            out_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
        )
        self.block = nn.Sequential(squeeze, upsample, nn.ReLU(inplace=True))

    def forward(self, x):
        """Run the stage and return the upsampled, rectified feature map."""
        upsampled = self.block(x)
        return upsampled
    
class Loss:
    """Binary cross-entropy plus a weighted negative log soft-Dice term.

    ``loss = BCE(outputs, targets) - dice_weight * log(dice)``, where
    ``dice = (2 * sum(P * T) + eps) / (sum(P) + sum(T) + eps)`` is computed
    over all elements of the batch. ``P`` is ``outputs`` and ``T`` is the
    mask ``targets == 1``.

    ``outputs`` must already be probabilities, since ``nn.BCELoss`` applies
    no sigmoid (``UNet11`` returns sigmoid outputs). ``targets`` must be a
    float tensor of the same shape.

    Usage::

        criterion = Loss()
        loss = criterion(outputs, targets)
    """

    def __init__(self, dice_weight=1):
        # Attribute name kept for compatibility; this is a BCE loss, not NLL.
        self.nll_loss = nn.BCELoss()
        # Scales the Dice term; 0 (or any falsy value) disables it.
        self.dice_weight = dice_weight

    def __call__(self, outputs, targets):
        """Return the combined scalar loss for ``outputs`` vs ``targets``."""
        loss = self.nll_loss(outputs, targets)
        if self.dice_weight:
            eps = 1e-15
            dice_target = (targets == 1).float()
            dice_output = outputs
            intersection = (dice_output * dice_target).sum()
            total = dice_output.sum() + dice_target.sum() + eps
            # eps in the numerator as well: with an empty target mask the
            # intersection is 0 and log(0) would make the loss infinite.
            dice = (2 * intersection + eps) / total
            # Out-of-place update, and scaled by dice_weight, which was
            # previously only used as an on/off flag.
            loss = loss - self.dice_weight * torch.log(dice)

        return loss