In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [192]:
class DoubleConv(nn.Module):
    """Two stacked 3x3 convolutions (stride 1), each followed by ReLU.

    Args:
        in_channel: number of channels of ``x``. When ``attach`` is True, the
            attached tensor ``o`` must have the same number of channels: the two
            are concatenated, so the first conv sees ``2 * in_channel`` channels.
        out_channel: number of channels produced by both convolutions.
        attach: if True, ``forward`` concatenates ``x`` and ``o`` along dim 1
            before convolving, and both convs use padding=1 so H and W are
            preserved. If False, the convs are unpadded, so each one shrinks H
            and W by 2 (by 4 in total).
    """

    def __init__(self, in_channel, out_channel, attach=False):
        super().__init__()
        self.attach = attach
        # The attached tensor doubles the channel count seen by the first conv.
        conv_in = in_channel * 2 if attach else in_channel
        # "Same" padding only on the attach path; the plain path is a valid conv.
        padding = 1 if attach else 0
        # Identical layer layout in both modes, so state_dict keys are always
        # double_conv.0.* and double_conv.2.*.
        self.double_conv = nn.Sequential(
            nn.Conv2d(conv_in, out_channel, 3, 1, padding),
            nn.ReLU(),
            nn.Conv2d(out_channel, out_channel, 3, 1, padding),
            nn.ReLU(),
        )

    def forward(self, x, o=None):
        """Apply the two conv+ReLU stages.

        Args:
            x: input tensor; concatenation is along dim 1, which is the channel
                dim for batched (N, C, H, W) input.
            o: tensor to concatenate with ``x`` (required when ``attach`` is True;
                ignored otherwise).

        Raises:
            TypeError: if ``attach`` is True and ``o`` is not given.
        """
        if self.attach:
            if o is None:
                raise TypeError(
                    "DoubleConv(attach=True) requires the attached tensor `o`"
                )
            x = torch.cat((x, o), dim=1)
        return self.double_conv(x)
    
    
class AttachConvertor(nn.Module):
    """Side branch that turns a raw input into a pooled tensor for a skip connection.

    Runs an unpadded ``DoubleConv`` on the input (shrinking H and W by 4), then
    max-pools with kernel size and stride ``down_count``. The result has
    ``out_channel`` channels and spatial size ``(H - 4) // down_count``.

    Args:
        in_channel: channels of the input tensor.
        out_channel: channels of the produced tensor.
        down_count: max-pool kernel size (and stride), i.e. the downsampling
            factor applied after the convs, not a number of pooling steps.
    """

    def __init__(self, in_channel, out_channel, down_count):
        super().__init__()
        self.maxpool = nn.MaxPool2d(down_count)
        # No extra ReLU here: DoubleConv already ends in ReLU and ReLU is
        # idempotent. The Sequential wrapper keeps the "a_convertor.0.*"
        # state_dict keys stable.
        self.a_convertor = nn.Sequential(DoubleConv(in_channel, out_channel))

    def forward(self, x):
        """Return the max-pooled DoubleConv features of ``x``."""
        return self.maxpool(self.a_convertor(x))
    
    
class Network(nn.Module):
    """Conv network whose main path is fused, level by level, with input-derived skips.

    Main path: ``head`` (unpadded DoubleConv, -> 128 channels), then three
    attach-mode DoubleConv levels (-> 64, 32, 16 channels). Each stage is
    followed by a 2x2 max-pool. Each level concatenates the main-path tensor
    with an ``AttachConvertor`` branch computed directly from the raw input and
    pooled to the same resolution.

    Spatial size per side for input size H: head gives H - 4, and each of the
    four max-pools halves it (floor), so the output is ``(H - 4) // 16``
    (15 for H = 256).

    The final 1x1 conv emits ``4 + classes`` channels. NOTE(review): presumably
    4 box-regression values plus per-class scores; confirm against the
    loss/decoder.

    Attribute names (including the ``attact_*`` spelling) are kept as-is because
    they determine state_dict keys.

    Args:
        in_channel: channels of the input image.
        classes: number of classes; the output has ``4 + classes`` channels.
    """

    def __init__(self, in_channel, classes):
        super().__init__()
        # Construction order is kept fixed: it determines parameter-init RNG
        # consumption and state_dict ordering.
        self.maxpool = nn.MaxPool2d(2)
        self.head = DoubleConv(in_channel, 128)
        self.attact_l1 = AttachConvertor(in_channel, 128, 2)
        self.conv_l1 = DoubleConv(128, 64, True)
        self.attact_l2 = AttachConvertor(in_channel, 64, 4)
        self.conv_l2 = DoubleConv(64, 32, True)
        self.attact_l3 = AttachConvertor(in_channel, 32, 8)
        self.conv_l3 = DoubleConv(32, 16, True)
        # Kept so checkpoints containing its weights still load, but forward()
        # does not evaluate it: nothing consumes its output. NOTE(review): if it
        # was meant to be concatenated before `tail`, `tail` would need 32 input
        # channels.
        self.attact_t = AttachConvertor(in_channel, 16, 16)
        self.tail = nn.Conv2d(16, 4+classes, 1)

    def forward(self, x):
        """Map the input to a ``(N, 4 + classes, (H-4)//16, (W-4)//16)`` tensor.

        Expects a batched (N, C, H, W) tensor. The skip concatenation is along
        dim 1, which is the channel dim only for 4-D input.
        """
        # Each skip branch is computed right before the level that consumes it,
        # so its activation can be freed as soon as that level is done.
        feats = self.maxpool(self.head(x))
        feats = self.maxpool(self.conv_l1(feats, self.attact_l1(x)))
        feats = self.maxpool(self.conv_l2(feats, self.attact_l2(x)))
        feats = self.maxpool(self.conv_l3(feats, self.attact_l3(x)))
        return self.tail(feats)

In [196]:
# Smoke test: output shape for a single 3-channel 256x256 input.
torch.manual_seed(0)  # deterministic weights and input across re-runs
nn_network = Network(3, 1)
a = torch.randn((1, 3, 256, 256))
with torch.no_grad():  # only the shape is inspected, so skip building the autograd graph
    out = nn_network(a)
# (256 - 4) // 16 = 15 per side; 4 + classes = 5 channels.
assert out.shape == (1, 5, 15, 15), out.shape
out.shape

torch.Size([1, 5, 15, 15])