# Libraries

In [None]:
import torch
import torch.nn as nn
import torchvision
import numpy as np
import torch.nn.functional as F
from torchsummary import summary
from torch.nn import init
from skimage import morphology as morph
import torch.utils.model_zoo as model_zoo
import math
#from skimage.morphology import watershed
from skimage.segmentation import find_boundaries
from scipy import ndimage


# Preparation

In [None]:
# https://github.com/miguelvr/dropblock/blob/master/dropblock/dropblock.py
class DropBlock2D(nn.Module):
    """DropBlock regularisation for 4-D feature maps.

    During training, square regions of side ``block_size`` are zeroed in
    every channel and the surviving activations are rescaled by
    ``numel / kept`` of the block mask.  In eval mode, or when ``drop_prob``
    is 0, the input is returned unchanged.

    Args:
        drop_prob: drop strength; block centres are sampled with probability
            ``drop_prob / block_size ** 2`` per spatial position.
        block_size: side length of each dropped square.
    """

    def __init__(self, drop_prob, block_size):
        super().__init__()
        self.drop_prob = drop_prob
        self.block_size = block_size

    def forward(self, x):
        """Apply DropBlock to ``x`` of shape (bsize, channels, height, width)."""
        assert x.dim() == 4, \
            "Expected input with 4 dimensions (bsize, channels, height, width)"
        if not self.training or self.drop_prob == 0.:
            return x

        # Probability that a position becomes the centre of a dropped block.
        gamma = self.drop_prob / (self.block_size ** 2)
        # One mask per sample, shared by all channels.
        seeds = (torch.rand(x.shape[0], *x.shape[2:], device=x.device) < gamma).float()
        block_mask = self._compute_block_mask(seeds)
        out = x * block_mask[:, None, :, :]
        # Rescale the survivors.  The mask holds only 0/1, so its sum is a
        # whole number; clamping to 1 changes nothing unless *everything* was
        # dropped, where the unclamped division would produce 0 / 0 = NaN.
        kept = block_mask.sum().clamp(min=1.0)
        return out * block_mask.numel() / kept

    def _compute_block_mask(self, mask):
        """Grow each sampled centre into a block and return the keep mask (1 = keep)."""
        block_mask = F.max_pool2d(input=mask[:, None, :, :],
                                  kernel_size=(self.block_size, self.block_size),
                                  stride=(1, 1),
                                  padding=self.block_size // 2)

        # An even kernel with this padding yields one extra row/column.
        if self.block_size % 2 == 0:
            block_mask = block_mask[:, :, :-1, :-1]
        return 1 - block_mask.squeeze(1)

class DropBlock3D(DropBlock2D):
    """DropBlock regularisation for 5-D (volumetric) feature maps.

    Same scheme as ``DropBlock2D`` but cubes of side ``block_size`` are
    dropped, and block centres are sampled with probability
    ``drop_prob / block_size ** 3`` per position.

    Args:
        drop_prob: drop strength.
        block_size: side length of each dropped cube.
    """

    def __init__(self, drop_prob, block_size):
        super().__init__(drop_prob, block_size)

    def forward(self, x):
        """Apply DropBlock to ``x`` of shape (bsize, channels, depth, height, width)."""
        assert x.dim() == 5, \
            "Expected input with 5 dimensions (bsize, channels, depth, height, width)"
        if not self.training or self.drop_prob == 0.:
            return x

        gamma = self.drop_prob / (self.block_size ** 3)
        # Sample directly on the input's device (as DropBlock2D does) instead
        # of sampling on the CPU and copying the mask over afterwards.
        seeds = (torch.rand(x.shape[0], *x.shape[2:], device=x.device) < gamma).float()
        block_mask = self._compute_block_mask(seeds)
        out = x * block_mask[:, None, :, :, :]
        # The mask holds only 0/1; clamping the count to 1 only matters when
        # every position was dropped, where 0 / 0 would otherwise give NaN.
        kept = block_mask.sum().clamp(min=1.0)
        return out * block_mask.numel() / kept

    def _compute_block_mask(self, mask):
        """Grow each sampled centre into a cube and return the keep mask (1 = keep)."""
        block_mask = F.max_pool3d(input=mask[:, None, :, :, :],
                                  kernel_size=(self.block_size, self.block_size, self.block_size),
                                  stride=(1, 1, 1),
                                  padding=self.block_size // 2)
        # An even kernel with this padding yields one extra slice per axis.
        if self.block_size % 2 == 0:
            block_mask = block_mask[:, :, :-1, :-1, :-1]
        return 1 - block_mask.squeeze(1)

In [None]:
class CBAM(nn.Module):
    """Channel gate followed by a spatial gate over a 4-D feature map.

    Args:
        in_channel: number of channels of the input (and output).
        reduction_ratio: divisor for the hidden width of the shared MLP;
            the hidden width is never smaller than 1.
    """

    def __init__(self, in_channel, reduction_ratio = 8):
        super().__init__()
        self.hid_channel = max(1, in_channel // reduction_ratio)
        self.globalAvgPool = nn.AdaptiveAvgPool2d(1)
        self.globalMaxPool = nn.AdaptiveMaxPool2d(1)
        # One MLP (as 1x1 convs) is applied to both pooled descriptors.
        squeeze = nn.Conv2d(in_channel, self.hid_channel, 1, bias=False)
        expand = nn.Conv2d(self.hid_channel, in_channel, 1, bias=False)
        self.fc = nn.Sequential(squeeze, nn.Mish(), expand)
        self.sigmoid = nn.Sigmoid()
        # 7x7 conv that turns the [max, mean] channel statistics into one map.
        self.conv1 = nn.Conv2d(2, 1, kernel_size=7,
                               stride=1, padding=3, bias=False)

    def forward(self, x):
        """Return ``x`` scaled by the channel gate and then the spatial gate."""
        # Channel attention: per-channel weights from global avg/max pooling.
        from_avg = self.fc(self.globalAvgPool(x))
        from_max = self.fc(self.globalMaxPool(x))
        channel_refined = self.sigmoid(from_avg + from_max) * x

        # Spatial attention: per-pixel weight from channel-wise max and mean.
        mean_map = torch.mean(channel_refined, dim=1, keepdim=True)
        peak_map = torch.max(channel_refined, dim=1, keepdim=True)[0]
        spatial_gate = self.sigmoid(self.conv1(torch.cat([peak_map, mean_map], dim=1)))
        return spatial_gate * channel_refined

In [None]:
class ConvBn(nn.Sequential):
    """Conv2d -> (optional DropBlock2D) -> BatchNorm2d -> Mish -> CBAM.

    Args:
        in_channel: input channels of the convolution.
        out_channel: output channels of the convolution and of the block.
        kernel_size: convolution kernel size.
        padding: convolution padding.
        drop_block: when true, a ``DropBlock2D`` is inserted after the conv.
        block_size: ``block_size`` passed to ``DropBlock2D``.
        drop_prob: ``drop_prob`` passed to ``DropBlock2D``.
    """

    def __init__(self, in_channel, out_channel, kernel_size = 3,
                 padding = 1, drop_block=False, block_size = 1, drop_prob = 0):
        super().__init__()
        stages = [
            ("conv", nn.Conv2d(in_channel, out_channel, kernel_size,
                               padding=padding, bias=False)),
        ]
        if drop_block:
            stages.append(
                ("drop_block", DropBlock2D(block_size=block_size, drop_prob=drop_prob))
            )
        stages.append(("bn", nn.BatchNorm2d(out_channel)))
        stages.append(("mish", nn.Mish()))
        stages.append(("cbam", CBAM(out_channel)))
        # Registration order is the execution order of the Sequential.
        for label, stage in stages:
            self.add_module(label, stage)

class DownSampleBlock(nn.Sequential):
    """Halve the channel count (1x1 conv) and then the resolution (2x2, stride-2 conv).

    Pipeline: conv1 -> DropBlock2D -> BatchNorm2d -> Mish -> CBAM -> conv2 -> DropBlock2D.

    Args:
        in_channel: input channels; the block outputs ``in_channel // 2``.
        block_size: ``block_size`` for both ``DropBlock2D`` layers.
        drop_prob: ``drop_prob`` for both ``DropBlock2D`` layers.
    """

    def __init__(self, in_channel, block_size = 1, drop_prob = 0):
        super().__init__()
        half = in_channel // 2
        stages = (
            ("conv1", nn.Conv2d(in_channel, half, 1, bias=False)),
            ("drop_block1", DropBlock2D(block_size=block_size, drop_prob=drop_prob)),
            ("bn", nn.BatchNorm2d(half)),
            ("mish", nn.Mish()),
            ("cbam", CBAM(half)),
            ("conv2", nn.Conv2d(half, half, 2, 2, bias=False)),
            ("drop_block2", DropBlock2D(block_size=block_size, drop_prob=drop_prob)),
        )
        for label, stage in stages:
            self.add_module(label, stage)


class AttentionBlock(nn.Module):
    """Gate a skip feature map with a single-channel attention map.

    The coarse input ``x`` is projected and upsampled 2x by a transposed
    conv, the skip map is projected to the same width, and their sum is
    reduced to a sigmoid map that multiplies ``skip``.

    Args:
        in_channel: channels of the coarse input ``x``.
        in_channel_skip: channels of ``skip``.
        out_channel: width of the intermediate projections.
    """

    def __init__(self, in_channel, in_channel_skip, out_channel):
        super().__init__()
        # Coarse path: project, normalise, upsample 2x, refine with CBAM.
        coarse_path = [
            nn.Conv2d(in_channel, out_channel, 1, padding=0, bias=False),
            nn.BatchNorm2d(out_channel),
            nn.ConvTranspose2d(out_channel, out_channel, 2, 2),
            CBAM(out_channel),
        ]
        self.conv_input = nn.Sequential(*coarse_path)
        # Skip path: project to the shared width.
        skip_path = [
            nn.Conv2d(in_channel_skip, out_channel, 1, bias=False),
            nn.BatchNorm2d(out_channel),
        ]
        self.conv_skip = nn.Sequential(*skip_path)
        # Collapse the fused features into one attention value per pixel.
        gate_head = [
            nn.Mish(),
            nn.Conv2d(out_channel, 1, 1, bias=False),
            nn.BatchNorm2d(1),
            nn.Sigmoid(),
        ]
        self.mixed_weight = nn.Sequential(*gate_head)

    def forward(self, x, skip):
        """Return ``skip`` multiplied by the attention map derived from ``x`` and ``skip``."""
        fused = self.conv_input(x) + self.conv_skip(skip)
        return self.mixed_weight(fused) * skip

class DenseLayer(nn.Module):
    """Bottleneck dense layer that prepends ``grow_rate`` new maps to its input.

    Args:
        in_channel: channels of the input.
        grow_rate: number of new feature maps produced; the 1x1 bottleneck
            uses ``4 * grow_rate`` channels.
    """

    def __init__(self, in_channel, grow_rate):
        super().__init__()
        bottleneck_width = grow_rate * 4
        self.layer = nn.Sequential(
            ConvBn(in_channel, bottleneck_width, kernel_size=1, padding=0),
            ConvBn(bottleneck_width, grow_rate),
        )

    def forward(self, x):
        """Return ``[new_features, x]`` concatenated along the channel axis."""
        fresh = self.layer(x)
        return torch.cat((fresh, x), dim=1)

class DenseBlock(nn.Sequential):
    """Stack of ``repetition`` ``DenseLayer`` modules named ``dense_layer_1..N``.

    Every ``DenseLayer`` widens its input by ``grow_rate`` channels, so the
    block outputs ``in_channel + repetition * grow_rate`` channels.
    """

    def __init__(self, in_channel, grow_rate, repetition):
        super().__init__()
        width = in_channel
        for position in range(1, repetition + 1):
            self.add_module(f"dense_layer_{position}", DenseLayer(width, grow_rate))
            # The next layer sees this layer's input plus its new maps.
            width += grow_rate

class DecoderBlock(nn.Module):
    """Decoder stage: 2x upsample, attention-gated skip, concat, ``ConvBn``.

    Args:
        in_channel: channels of the coarse input ``x``.
        in_channel_skip: channels of the encoder skip map.
        out_channel: channels produced by this stage.
        block_size: DropBlock ``block_size`` used inside the ``ConvBn``.
        drop_prob: DropBlock ``drop_prob`` used inside the ``ConvBn``.
    """

    def __init__(self, in_channel, in_channel_skip, out_channel,
                 block_size = 1, drop_prob = 0):
        super().__init__()
        self.conv_trans = nn.ConvTranspose2d(in_channel, out_channel, 2, 2)
        self.attention = AttentionBlock(in_channel, in_channel_skip, out_channel)
        # The fused tensor carries the upsampled map plus the gated skip.
        fused_width = in_channel_skip + out_channel
        self.convbn = ConvBn(fused_width, out_channel, drop_block=True,
                             block_size=block_size, drop_prob=drop_prob)

    def forward(self, x, skip):
        """Upsample ``x``, gate ``skip`` with it, and fuse both."""
        upsampled = self.conv_trans(x)
        gated_skip = self.attention(x, skip)
        return self.convbn(torch.cat((upsampled, gated_skip), dim=1))

class UpsampleBlock(nn.Sequential):
    """``times`` rounds of (2x transposed conv -> CBAM).

    The first transposed conv maps ``in_channel`` to ``out_channel``; later
    rounds keep ``out_channel``.  Modules are registered as
    ``convtrans1``/``cbam1`` ... ``convtransN``/``cbamN``.
    """

    def __init__(self,  in_channel, out_channel, times):
        super().__init__()
        source_width = in_channel
        for step in range(1, times + 1):
            self.add_module(f"convtrans{step}",
                            nn.ConvTranspose2d(source_width, out_channel, 2, 2))
            self.add_module(f"cbam{step}", CBAM(out_channel))
            source_width = out_channel

In [None]:
class conv_block(nn.Sequential):
    """Two (Conv2d -> BatchNorm2d -> ReLU) stages, optional DropBlock2D after the 2nd conv.

    Args:
        ch_in: input channels.
        ch_out: output channels of both convolutions.
        kernel_size: kernel size of both convolutions.
        padding: padding of both convolutions.
        drop_block: when true, insert ``DropBlock2D`` between conv2 and bn2.
        block_size: ``block_size`` passed to ``DropBlock2D``.
        drop_prob: ``drop_prob`` passed to ``DropBlock2D``.
    """

    def __init__(self, ch_in, ch_out, kernel_size = 3,
                 padding = 1, drop_block=False, block_size = 1, drop_prob = 0):
        super().__init__()
        stages = [
            ("conv1", nn.Conv2d(ch_in, ch_out, kernel_size, padding=padding, bias=False)),
            ("bn1", nn.BatchNorm2d(ch_out)),
            ("relu1", nn.ReLU(inplace=True)),
            ("conv2", nn.Conv2d(ch_out, ch_out, kernel_size, padding=padding, bias=False)),
        ]
        if drop_block:
            stages.append(
                ("drop_block", DropBlock2D(block_size=block_size, drop_prob=drop_prob))
            )
        stages.append(("bn2", nn.BatchNorm2d(ch_out)))
        stages.append(("relu2", nn.ReLU(inplace=True)))
        for label, stage in stages:
            self.add_module(label, stage)

class up_conv(nn.Module):
    """2x upsampling followed by Conv2d(k=2, 'same' padding) -> BatchNorm2d -> ReLU.

    Args:
        ch_in: input channels.
        ch_out: output channels.
    """

    def __init__(self,ch_in,ch_out):
        super().__init__()
        stages = [
            nn.Upsample(scale_factor=2),
            nn.Conv2d(ch_in, ch_out, kernel_size=2, stride=1, padding="same", bias=False),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
        ]
        self.up = nn.Sequential(*stages)

    def forward(self,x):
        """Return the upsampled, convolved feature map."""
        return self.up(x)

In [None]:
class Residual(nn.Module):
    """Wrap a callable ``fn`` so that the module computes ``fn(x) + x``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        """Return the wrapped transform of ``x`` plus ``x`` itself."""
        transformed = self.fn(x)
        return transformed + x

def ConvMixer(dim, depth, kernel_size=9, patch_size=7):
    """Build a ConvMixer-style feature extractor as an ``nn.Sequential``.

    A strided patch-embedding stem is followed by ``depth`` mixer layers.
    Each layer applies a residual pair of depthwise 1xk and kx1 convs and
    then a pointwise 1x1 conv, every conv being followed by GELU and
    BatchNorm2d.  No pooling or classification head is attached.

    NOTE: a class with the same name ``ConvMixer`` is defined further down
    in this file; whichever definition runs last is the one bound to the name.

    Args:
        dim: channel count of the input and of every layer.
        depth: number of mixer layers.
        kernel_size: length of the separable depthwise kernels.
        patch_size: kernel size and stride of the stem convolution.
    """

    def mixer_layer():
        # Depthwise spatial mixing, split into a horizontal and a vertical pass.
        spatial = nn.Sequential(
            nn.Conv2d(dim, dim, (1, kernel_size), groups=dim, padding="same"),
            nn.GELU(),
            nn.BatchNorm2d(dim),
            nn.Conv2d(dim, dim, (kernel_size, 1), groups=dim, padding="same"),
            nn.GELU(),
            nn.BatchNorm2d(dim),
        )
        # Pointwise channel mixing after the residual spatial step.
        return nn.Sequential(
            Residual(spatial),
            nn.Conv2d(dim, dim, kernel_size=1),
            nn.GELU(),
            nn.BatchNorm2d(dim),
        )

    stem = [
        nn.Conv2d(dim, dim, kernel_size=patch_size, stride=patch_size),
        nn.GELU(),
        nn.BatchNorm2d(dim),
    ]
    layers = [mixer_layer() for _ in range(depth)]
    return nn.Sequential(*stem, *layers)

In [None]:
class Attention_block(nn.Module):
    """Additive attention gate: scale ``x`` by a map computed from ``g`` and ``x``.

    Args:
        F_g: channels of the gating signal ``g``.
        F_l: channels of the feature map ``x`` to be gated.
        F_int: channels of the intermediate projections.
    """

    def __init__(self,F_g,F_l,F_int):
        super().__init__()

        def project(channels_in, channels_out):
            # 1x1 convolution followed by batch normalisation.
            return [
                nn.Conv2d(channels_in, channels_out, kernel_size=1, stride=1, padding=0, bias=True),
                nn.BatchNorm2d(channels_out),
            ]

        self.W_g = nn.Sequential(*project(F_g, F_int))
        self.W_x = nn.Sequential(*project(F_l, F_int))
        # Single-channel attention coefficients in (0, 1).
        self.psi = nn.Sequential(*project(F_int, 1), nn.Sigmoid())
        self.relu = nn.ReLU(inplace=True)

    def forward(self,g,x):
        """Return ``x`` multiplied by its attention coefficients."""
        gate_proj = self.W_g(g)
        feat_proj = self.W_x(x)
        coeff = self.psi(self.relu(gate_proj + feat_proj))
        return x * coeff

In [None]:
'''class Conv2d(nn.Module):
    def __init__(self, pdc, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=False):
        super(Conv2d, self).__init__()
        if in_channels % groups != 0:
            raise ValueError('in_channels must be divisible by groups')
        if out_channels % groups != 0:
            raise ValueError('out_channels must be divisible by groups')
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, kernel_size, kernel_size))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()
        self.pdc = pdc

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input):

        return self.pdc(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)'''

"class Conv2d(nn.Module):\n    def __init__(self, pdc, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=False):\n        super(Conv2d, self).__init__()\n        if in_channels % groups != 0:\n            raise ValueError('in_channels must be divisible by groups')\n        if out_channels % groups != 0:\n            raise ValueError('out_channels must be divisible by groups')\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.kernel_size = kernel_size\n        self.stride = stride\n        self.padding = padding\n        self.dilation = dilation\n        self.groups = groups\n        self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, kernel_size, kernel_size))\n        if bias:\n            self.bias = nn.Parameter(torch.Tensor(out_channels))\n        else:\n            self.register_parameter('bias', None)\n        self.reset_parameters()\n        self.pdc = pdc\n\n    def rese

In [None]:
class _ASPPModule(nn.Module):
    def __init__(self, inplanes, planes, kernel_size, padding, dilation, BatchNorm):
        super().__init__()
        self.atrous_conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size,
                                      stride=1, padding=padding, dilation=dilation, bias=False)
        self.bn = BatchNorm(planes)
        self.silu = nn.SiLU(inplace=True)
        self._init_weight()

    def forward(self, x):
        x = self.atrous_conv(x)
        x = self.bn(x)

        return self.silu(x)

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

class ASPP(nn.Module):
    """Atrous spatial pyramid pooling with three dilated branches and a global branch.

    The 1x1 branch (``aspp1``) is disabled, so the first dilation of each
    preset is unused.  The four branch outputs are concatenated and fused
    by a 1x1 conv, a norm layer and SiLU; the trailing dropout has p=0.

    Args:
        inplanes: input channels.
        outplanes: channels of every branch and of the output.
        output_stride: selects the dilation preset; must be 4, 8 or 2.
        BatchNorm: callable building a normalisation layer from a channel count.

    Raises:
        NotImplementedError: for any other ``output_stride``.
    """

    def __init__(self, inplanes, outplanes, output_stride, BatchNorm):
        super().__init__()
        presets = (
            (4, [1, 6, 12, 18]),
            (8, [1, 4, 6, 10]),
            (2, [1, 12, 24, 36]),
        )
        dilations = next(
            (rates for stride, rates in presets if output_stride == stride), None
        )
        if dilations is None:
            raise NotImplementedError

        _, rate2, rate3, rate4 = dilations
        self.aspp2 = _ASPPModule(inplanes, outplanes, 3, padding=rate2,
                                 dilation=rate2, BatchNorm=BatchNorm)
        self.aspp3 = _ASPPModule(inplanes, outplanes, 3, padding=rate3,
                                 dilation=rate3, BatchNorm=BatchNorm)
        self.aspp4 = _ASPPModule(inplanes, outplanes, 3, padding=rate4,
                                 dilation=rate4, BatchNorm=BatchNorm)

        # Image-level branch: global average -> 1x1 conv -> SiLU (no norm layer).
        self.global_avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(inplanes, outplanes, 1, stride=1, bias=False),
            nn.SiLU(inplace=True),
        )
        self.conv1 = nn.Conv2d(outplanes * 4, outplanes, 1, bias=False)
        self.bn1 = BatchNorm(outplanes)
        self.silu = nn.SiLU(inplace=True)
        self.dropout = nn.Dropout(0.0)
        self._init_weight()

    def forward(self, x):
        """Fuse the dilated branches with the upsampled global branch."""
        branches = [self.aspp2(x), self.aspp3(x), self.aspp4(x)]
        pooled = self.global_avg_pool(x)
        # Broadcast the 1x1 global descriptor back to the branch resolution.
        pooled = F.interpolate(pooled, size=branches[-1].size()[2:],
                               mode='bilinear', align_corners=True)
        fused = torch.cat((*branches, pooled), dim=1)
        fused = self.silu(self.bn1(self.conv1(fused)))
        return self.dropout(fused)

    def _init_weight(self):
        """Kaiming-normal init for convs; BatchNorm2d gets weight 1 and bias 0."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

class PASPP(nn.Module):
    """Four-branch dilated pyramid whose branches are fused pairwise.

    The input is projected into four ``inplanes // 4`` branches.  Branches
    1/2 and 3/4 are summed pairwise, each branch adds its dilated 3x3 conv
    output to its pair sum, the pairs are concatenated and fused by
    ``conv5`` / ``conv6``, and both halves are merged by ``convout``.  The
    output has ``inplanes`` channels.

    Args:
        inplanes: input (and output) channels; must be a multiple of 4.
        outplanes: accepted for interface compatibility but not used.
        output_stride: selects the dilation preset; one of 4, 8, 2, 16, 1.
        BatchNorm: callable building a normalisation layer from a channel count.

    Raises:
        NotImplementedError: for an unsupported ``output_stride``.
        ValueError: if ``inplanes`` is not a multiple of 4.
    """

    def __init__(self, inplanes, outplanes, output_stride=4, BatchNorm=nn.BatchNorm2d):
        super().__init__()
        presets = (
            (4, [1, 6, 12, 18]),
            (8, [1, 4, 6, 10]),
            (2, [1, 12, 24, 36]),
            (16, [1, 2, 3, 4]),
            (1, [1, 16, 32, 48]),
        )
        dilations = next(
            (rates for stride, rates in presets if output_stride == stride), None
        )
        if dilations is None:
            raise NotImplementedError
        # Channel widths are inplanes // 4 and inplanes // 2; they only line
        # up across the concatenations when inplanes is a multiple of 4.
        if inplanes % 4 != 0:
            raise ValueError(
                f"inplanes must be a multiple of 4, got {inplanes}"
            )

        quarter = inplanes // 4
        half = inplanes // 2
        self._norm_layer = BatchNorm
        # One stateless activation instance is reused by every _make_layer call.
        self.silu = nn.SiLU(inplace=True)
        self.conv1 = self._make_layer(inplanes, quarter)
        self.conv2 = self._make_layer(inplanes, quarter)
        self.conv3 = self._make_layer(inplanes, quarter)
        self.conv4 = self._make_layer(inplanes, quarter)
        self.atrous_conv1 = nn.Conv2d(quarter, quarter, kernel_size=3,
                                      dilation=dilations[0], padding=dilations[0])
        self.atrous_conv2 = nn.Conv2d(quarter, quarter, kernel_size=3,
                                      dilation=dilations[1], padding=dilations[1])
        self.atrous_conv3 = nn.Conv2d(quarter, quarter, kernel_size=3,
                                      dilation=dilations[2], padding=dilations[2])
        self.atrous_conv4 = nn.Conv2d(quarter, quarter, kernel_size=3,
                                      dilation=dilations[3], padding=dilations[3])
        self.conv5 = self._make_layer(half, half)
        self.conv6 = self._make_layer(half, half)
        self.convout = self._make_layer(inplanes, inplanes)

    def _make_layer(self, inplanes, outplanes):
        """Return a 1x1 Conv2d -> norm layer -> SiLU block."""
        return nn.Sequential(
            nn.Conv2d(inplanes, outplanes, kernel_size=1),
            self._norm_layer(outplanes),
            self.silu,
        )

    def forward(self, X):
        """Return the fused multi-dilation features for ``X``."""
        x1 = self.conv1(X)
        x2 = self.conv2(X)
        x3 = self.conv3(X)
        x4 = self.conv4(X)

        # Pairwise sums shared by the two branches of each pair.
        x12 = torch.add(x1, x2)
        x34 = torch.add(x3, x4)

        x1 = torch.add(self.atrous_conv1(x1), x12)
        x2 = torch.add(self.atrous_conv2(x2), x12)
        x3 = torch.add(self.atrous_conv3(x3), x34)
        x4 = torch.add(self.atrous_conv4(x4), x34)

        x12 = torch.cat([x1, x2], dim=1)
        x34 = torch.cat([x3, x4], dim=1)

        x12 = self.conv5(x12)
        # The second pair has its own fusion layer.  Previously conv5 was
        # applied here as well, leaving conv6 without any use or gradient.
        # NOTE: checkpoints trained with the old forward hold untrained
        # conv6 weights and need fine-tuning after this change.
        x34 = self.conv6(x34)
        x = torch.cat([x12, x34], dim=1)
        return self.convout(x)

In [None]:
class activation_block(nn.Module):
    """GELU followed by BatchNorm2d over ``outplane`` channels."""

    def __init__(self, outplane):
        super().__init__()
        self.gelu = nn.GELU()
        self.outplane = outplane
        self.batchnorm = nn.BatchNorm2d(outplane)

    def forward(self, x):
        """Return the batch-normalised GELU activation of ``x``."""
        return self.batchnorm(self.gelu(x))

class conv_stem(nn.Module):
    """Patch-embedding stem: strided Conv2d followed by ``activation_block``.

    Args:
        inplane: input channels.
        outplane: output channels.
        patch_size: kernel size and stride of the convolution.
    """

    def __init__(self, inplane, outplane, patch_size):
        super().__init__()
        self.inplane = inplane
        self.outplane = outplane
        self.patch_size = patch_size
        # Kernel size equals stride, so patches do not overlap.
        self.conv = nn.Conv2d(inplane, outplane,
                              kernel_size=patch_size, stride=patch_size)
        self.activation = activation_block(outplane)

    def forward(self, x):
        """Embed ``x`` into patches and apply the activation block."""
        return self.activation(self.conv(x))

class DepthwiseConv2d(nn.Module):
    """Depthwise 3x3 convolution with ``kernels_per_layer`` filters per input channel.

    Args:
        inplane: input channels (also the number of groups).
        kernels_per_layer: channel multiplier; the output has
            ``inplane * kernels_per_layer`` channels.
        outplane: stored on the module but not used by this layer.
    """

    def __init__(self, inplane, kernels_per_layer, outplane):
        super().__init__()
        self.inplane = inplane
        self.outplane = outplane
        self.kernels_per_layer = kernels_per_layer
        expanded = inplane * kernels_per_layer
        self.depthwise = nn.Conv2d(inplane, expanded, kernel_size=3,
                                   padding=1, groups=inplane)

    def forward(self, x):
        """Return the depthwise convolution of ``x`` (spatial size preserved)."""
        return self.depthwise(x)

class ConvMixer(nn.Module):
    """Single mixer step: residual depthwise conv, then pointwise conv + activation.

    NOTE: a factory function with the same name ``ConvMixer`` is defined
    earlier in this file; whichever definition runs last is the one bound
    to the name.

    Args:
        inplane: input channels.
        kernels_per_layer: channel multiplier of the depthwise conv.  The
            residual sum adds the depthwise output to the input, so their
            channel counts have to be broadcast-compatible.
        outplane: output channels of the pointwise conv.
        kernels_size: stored as ``kernel_size`` but not used; the depthwise
            conv built by ``DepthwiseConv2d`` has a fixed 3x3 kernel.
    """

    def __init__(self, inplane, kernels_per_layer, outplane, kernels_size):
        super().__init__()
        self.inplane = inplane
        self.outplane = outplane
        self.kernels_per_layer = kernels_per_layer
        self.kernel_size = kernels_size
        self.depthwise = DepthwiseConv2d(inplane, kernels_per_layer, outplane)
        self.pointwise = nn.Conv2d(inplane * kernels_per_layer, outplane, kernel_size=1)
        self.activation = activation_block(outplane)

    def forward(self, x):
        """Return the mixed features for ``x``."""
        # Spatial mixing with a residual connection.
        mixed = self.depthwise(x) + x
        # Channel mixing.
        return self.activation(self.pointwise(mixed))

'''def get_conv_mixer_256_8(image_size=32, filters=256, depth=8, kernel_size=5, patch_size=2, num_classes=10):
    inputs = keras.Input((image_size, image_size, 3))
    x = layers.Rescaling(scale=1.0 / 255)(inputs)

    # Extract patch embeddings.
    x = conv_stem(x, filters, patch_size)

    # ConvMixer blocks.
    for _ in range(depth):
        x = conv_mixer_block(x, filters, kernel_size)

    # Classification block.
    x = layers.GlobalAvgPool2D()(x)
    outputs = layers.Dense(num_classes, activation="softmax")(x)

    return keras.Model(inputs, outputs)'''

'def get_conv_mixer_256_8(image_size=32, filters=256, depth=8, kernel_size=5, patch_size=2, num_classes=10):\n    inputs = keras.Input((image_size, image_size, 3))\n    x = layers.Rescaling(scale=1.0 / 255)(inputs)\n\n    # Extract patch embeddings.\n    x = conv_stem(x, filters, patch_size)\n\n    # ConvMixer blocks.\n    for _ in range(depth):\n        x = conv_mixer_block(x, filters, kernel_size)\n\n    # Classification block.\n    x = layers.GlobalAvgPool2D()(x)\n    outputs = layers.Dense(num_classes, activation="softmax")(x)\n\n    return keras.Model(inputs, outputs)'

In [None]:
class Residual(nn.Module):
    """Skip-connection wrapper: the module outputs ``fn(x) + x``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        """Apply the wrapped callable and add the input back."""
        delta = self.fn(x)
        return delta + x


class ConvMixerBlock(nn.Module):
    """Stack of ``depth`` ConvMixer layers operating on ``dim`` channels.

    Each layer is a residual depthwise k x k conv (GELU, BatchNorm2d)
    followed by a pointwise 1x1 conv (GELU, BatchNorm2d).

    Args:
        dim: channel count of the input and of every layer.
        depth: number of layers.
        k: depthwise kernel size; padding is ``k // 2``.
    """

    def __init__(self, dim=512, depth=7, k=7):
        super().__init__()

        def mixer_layer():
            # Depthwise spatial mixing wrapped in a skip connection.
            spatial = nn.Sequential(
                nn.Conv2d(dim, dim, kernel_size=(k, k), groups=dim,
                          padding=(k // 2, k // 2)),
                nn.GELU(),
                nn.BatchNorm2d(dim),
            )
            # Pointwise channel mixing.
            return nn.Sequential(
                Residual(spatial),
                nn.Conv2d(dim, dim, kernel_size=(1, 1)),
                nn.GELU(),
                nn.BatchNorm2d(dim),
            )

        self.block = nn.Sequential(*(mixer_layer() for _ in range(depth)))

    def forward(self, x):
        """Run ``x`` through all mixer layers."""
        return self.block(x)

# Models

## Attention U-Net

In [None]:
class Attention_UNet(nn.Module):
    """U-Net whose skip connections pass through ``Attention_block`` gates.

    Five ``conv_block`` encoder stages (64 to 1024 channels) are separated
    by 2x2 max pooling.  Each of the four decoder stages upsamples with
    ``up_conv``, gates the matching encoder map, concatenates both and
    applies a ``conv_block``.  The head is a 1x1 conv plus channel softmax.
    Height and width should be multiples of 16 so that the upsampled maps
    match the skip maps they are concatenated with.

    Args:
        img_ch: channels of the input image.
        output_ch: number of output channels (classes).
        drop_prob: DropBlock probability used in ``Conv4`` and ``Conv5``.
    """

    def __init__(self,img_ch=3,output_ch=2, drop_prob=0):
        super().__init__()
        c1, c2, c3, c4, c5 = 64, 128, 256, 512, 1024

        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder.
        self.Conv1 = conv_block(ch_in=img_ch, ch_out=c1)
        self.Conv2 = conv_block(ch_in=c1, ch_out=c2)
        self.Conv3 = conv_block(ch_in=c2, ch_out=c3)
        self.Conv4 = conv_block(ch_in=c3, ch_out=c4, drop_block=True,
                                block_size=5, drop_prob=drop_prob)
        self.Conv5 = conv_block(ch_in=c4, ch_out=c5, drop_block=True,
                                block_size=5, drop_prob=drop_prob)

        # Decoder: each fusion conv sees the gated skip plus the upsampled map.
        self.Up5 = up_conv(ch_in=c5, ch_out=c4)
        self.Att5 = Attention_block(F_g=c4, F_l=c4, F_int=c4 // 2)
        self.Up_conv5 = conv_block(ch_in=2 * c4, ch_out=c4)

        self.Up4 = up_conv(ch_in=c4, ch_out=c3)
        self.Att4 = Attention_block(F_g=c3, F_l=c3, F_int=c3 // 2)
        self.Up_conv4 = conv_block(ch_in=2 * c3, ch_out=c3)

        self.Up3 = up_conv(ch_in=c3, ch_out=c2)
        self.Att3 = Attention_block(F_g=c2, F_l=c2, F_int=c2 // 2)
        self.Up_conv3 = conv_block(ch_in=2 * c2, ch_out=c2)

        self.Up2 = up_conv(ch_in=c2, ch_out=c1)
        self.Att2 = Attention_block(F_g=c1, F_l=c1, F_int=c1 // 2)
        self.Up_conv2 = conv_block(ch_in=2 * c1, ch_out=c1)

        self.Conv_1x1 = nn.Sequential(
            nn.Conv2d(c1, output_ch, kernel_size=1, stride=1, padding=0),
            nn.Softmax(dim=1),
        )

    def forward(self,x):
        """Return per-pixel class probabilities for ``x``."""
        # Encoding path: remember every stage output except the deepest.
        skips = []
        feat = self.Conv1(x)
        for encoder in (self.Conv2, self.Conv3, self.Conv4, self.Conv5):
            skips.append(feat)
            feat = encoder(self.Maxpool(feat))

        # Decoding path: upsample, gate the skip, concatenate, fuse.
        decoder_stages = (
            (self.Up5, self.Att5, self.Up_conv5),
            (self.Up4, self.Att4, self.Up_conv4),
            (self.Up3, self.Att3, self.Up_conv3),
            (self.Up2, self.Att2, self.Up_conv2),
        )
        for upsample, gate, fuse in decoder_stages:
            feat = upsample(feat)
            gated_skip = gate(feat, skips.pop())
            feat = fuse(torch.cat((gated_skip, feat), dim=1))

        return self.Conv_1x1(feat)

## Proposed

In [None]:
class SegNet(nn.Module):
    """Dense encoder / attention decoder segmentation network with side outputs.

    Four (``DenseBlock`` -> ``DownSampleBlock``) encoder stages feed a dense
    bottleneck.  Four ``DecoderBlock`` stages consume the encoder skips in
    reverse order.  The bottleneck and the first three decoder outputs are
    each upsampled to full resolution by an ``UpsampleBlock`` with
    ``num_classes`` channels, concatenated with the last decoder output, and
    classified by a BatchNorm -> Mish -> 1x1 conv -> softmax head.

    Args:
        input_channel: channels of the input image.
        in_channel: width produced by the stem.
        num_classes: number of output channels.
        drop_prob: DropBlock probability for the down-sampling and decoder blocks.
    """

    def __init__(self, input_channel = 3, in_channel = 32,
                 num_classes = 2, drop_prob = 0):
        super().__init__()
        self.conv1 = nn.Sequential(
            ConvBn(input_channel, in_channel),
            ConvBn(in_channel, in_channel),
        )
        grow_list = [16, 32, 64, 64, 64]
        repetition_list = [6, 6, 6, 6, 6]
        block_list = [5, 4, 3, 2]
        ch_decoder = [256, 128, 64, 32]
        skip_widths = []
        self.dense_list = nn.ModuleList()
        self.downsample_list = nn.ModuleList()
        self.decoder_list = nn.ModuleList()
        self.up_sample_list = nn.ModuleList()

        # Encoder: a dense block widens the features, down-sampling halves them.
        width = in_channel
        for stage in range(4):
            growth, reps = grow_list[stage], repetition_list[stage]
            self.dense_list.append(DenseBlock(width, growth, reps))
            width += reps * growth
            skip_widths.append(width)
            self.downsample_list.append(
                DownSampleBlock(width, block_list[stage], drop_prob)
            )
            width = width // 2

        # Bottleneck uses the fifth entry of the growth/repetition tables.
        self.bottle_neck = DenseBlock(width, grow_list[4], repetition_list[4])
        width += repetition_list[4] * grow_list[4]

        # Decoder: deepest skip first; the side branch at each level needs
        # one fewer 2x step than the previous one to reach full resolution.
        for stage, decoder_width in enumerate(ch_decoder):
            self.decoder_list.append(
                DecoderBlock(width, skip_widths[-stage - 1], decoder_width,
                             block_list[-stage - 1], drop_prob)
            )
            self.up_sample_list.append(UpsampleBlock(width, num_classes, 4 - stage))
            width = decoder_width

        # Head input: last decoder output plus four side outputs.
        width += 4 * num_classes
        self.conv2 = nn.Sequential(
            nn.BatchNorm2d(width),
            nn.Mish(),
            nn.Conv2d(width, num_classes, kernel_size=1, padding=0),
            nn.Softmax(dim=1),
        )

    def forward(self, x):
        """Return per-pixel class probabilities for ``x``."""
        feats = self.conv1(x)
        skips = []
        for stage in range(4):
            feats = self.dense_list[stage](feats)
            skips.append(feats)
            feats = self.downsample_list[stage](feats)
        feats = self.bottle_neck(feats)

        # Side output taken straight from the bottleneck.
        side_outputs = [self.up_sample_list[0](feats)]
        for stage in range(4):
            feats = self.decoder_list[stage](feats, skips[-stage - 1])
            if stage < 3:
                side_outputs.append(self.up_sample_list[stage + 1](feats))
        # The last decoder output is already at full resolution.
        side_outputs.append(feats)

        return self.conv2(torch.cat(side_outputs, dim=1))


## draft

In [None]:
class U_Net_mixer(nn.Module):
    """Plain U-Net (no attention gates) with an unused ``ConvMixer`` bottleneck.

    ``self.middle`` is constructed but never called in ``forward``.  Height
    and width should be multiples of 16 so that the upsampled maps match
    the skip maps they are concatenated with.

    Args:
        img_ch: channels of the input image.
        output_ch: number of output channels (classes).
        drop_prob: DropBlock probability used in ``Conv4`` and ``Conv5``.
    """

    def __init__(self,img_ch=3,output_ch=2, drop_prob = 0):
        super().__init__()
        channel = 64
        c1, c2, c3, c4, c5 = (channel * factor for factor in (1, 2, 4, 8, 16))

        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder.
        self.Conv1 = conv_block(ch_in=img_ch, ch_out=c1)
        self.Conv2 = conv_block(ch_in=c1, ch_out=c2)
        self.Conv3 = conv_block(ch_in=c2, ch_out=c3)
        self.Conv4 = conv_block(ch_in=c3, ch_out=c4, drop_block=True,
                                block_size=5, drop_prob=drop_prob)
        self.Conv5 = conv_block(ch_in=c4, ch_out=c5, drop_block=True,
                                block_size=3, drop_prob=drop_prob)

        # Built with positional arguments; currently not used in forward().
        self.middle = ConvMixer(c5, 1, 9, 7)

        # Decoder: each fusion conv sees the skip plus the upsampled map.
        self.Up5 = up_conv(ch_in=c5, ch_out=c4)
        self.Up_conv5 = conv_block(ch_in=c5, ch_out=c4)

        self.Up4 = up_conv(ch_in=c4, ch_out=c3)
        self.Up_conv4 = conv_block(ch_in=c4, ch_out=c3)

        self.Up3 = up_conv(ch_in=c3, ch_out=c2)
        self.Up_conv3 = conv_block(ch_in=c3, ch_out=c2)

        self.Up2 = up_conv(ch_in=c2, ch_out=c1)
        self.Up_conv2 = conv_block(ch_in=c2, ch_out=c1)

        self.Conv_1x1 = nn.Sequential(
            nn.Conv2d(c1, output_ch, kernel_size=1, stride=1, padding=0),
            nn.Softmax(dim=1),
        )

    def forward(self,x):
        """Return per-pixel class probabilities for ``x``."""
        # Encoding path: remember every stage output except the deepest.
        skips = []
        feat = self.Conv1(x)
        for encoder in (self.Conv2, self.Conv3, self.Conv4, self.Conv5):
            skips.append(feat)
            feat = encoder(self.Maxpool(feat))

        # Decoding path: upsample, concatenate with the skip, fuse.
        decoder_stages = (
            (self.Up5, self.Up_conv5),
            (self.Up4, self.Up_conv4),
            (self.Up3, self.Up_conv3),
            (self.Up2, self.Up_conv2),
        )
        for upsample, fuse in decoder_stages:
            feat = upsample(feat)
            feat = fuse(torch.cat((skips.pop(), feat), dim=1))

        return self.Conv_1x1(feat)

## Modified PiDiNet

In [None]:
class CSAM(nn.Module):
    """
    Compact Spatial Attention Module

    Computes a single-channel sigmoid map from ReLU(x) through a 1x1 conv
    (to 4 channels) and a 3x3 conv, then multiplies the input by that map.
    """
    def __init__(self, channels):
        super().__init__()

        hidden = 4
        self.relu1 = nn.ReLU()
        self.conv1 = nn.Conv2d(channels, hidden, kernel_size=1, padding=0)
        self.conv2 = nn.Conv2d(hidden, 1, kernel_size=3, padding=1, bias=False)
        self.sigmoid = nn.Sigmoid()
        # Start the 1x1 projection without a bias offset.
        nn.init.constant_(self.conv1.bias, 0)

    def forward(self, x):
        """Return ``x`` scaled by its spatial attention map."""
        logits = self.conv2(self.conv1(self.relu1(x)))
        return x * self.sigmoid(logits)

class CDCM(nn.Module):
    """
    Compact Dilation Convolution based Module

    ReLU and a 1x1 conv map the input to ``out_channels``; four bias-free
    3x3 convs with dilations 5, 7, 9 and 11 are applied in parallel and
    their outputs are summed.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.relu1 = nn.ReLU()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)

        def dilated(rate):
            # padding == dilation keeps the spatial size for a 3x3 kernel.
            return nn.Conv2d(out_channels, out_channels, kernel_size=3,
                             dilation=rate, padding=rate, bias=False)

        self.conv2_1 = dilated(5)
        self.conv2_2 = dilated(7)
        self.conv2_3 = dilated(9)
        self.conv2_4 = dilated(11)
        nn.init.constant_(self.conv1.bias, 0)

    def forward(self, x):
        """Return the sum of the four dilated branches."""
        reduced = self.conv1(self.relu1(x))
        branch_outputs = [
            branch(reduced)
            for branch in (self.conv2_1, self.conv2_2, self.conv2_3, self.conv2_4)
        ]
        fused = branch_outputs[0]
        for extra in branch_outputs[1:]:
            fused = fused + extra
        return fused

class MapReduce(nn.Module):
    """
    Reduce a feature map to a 2-channel map with one 1x1 convolution.

    Args:
        channels: number of channels of the incoming feature map.
    """
    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, 2, kernel_size=1, padding=0)
        # The bias starts at zero; the weights keep PyTorch's default init.
        nn.init.constant_(self.conv.bias, 0)

    def forward(self, x):
        """Return the 2-channel projection of ``x`` (spatial size unchanged)."""
        reduced = self.conv(x)
        return reduced

In [None]:
class PiDiNet(nn.Module):
    """
    Modified PiDiNet: five-level encoder/decoder with skip connections and a
    multi-scale side-output head.

    The encoder applies ``conv_block`` at five resolutions separated by 2x2
    max-pooling. The decoder upsamples with ``up_conv``, optionally gates the
    skip feature (``Attention_block`` / ``CSAG`` / ``MSAG``), concatenates it
    with the upsampled feature and refines it with ``conv_block`` and a
    ``ConvMixerBlock``. The four decoder outputs are each passed through
    optional ``CDCM`` / ``CSAM`` modules, reduced to 2 channels by
    ``MapReduce``, resized to the input size, concatenated (8 channels) and
    classified.

    Args:
        img_channel: channels of the input image.
        inplane: channel width of the first level; it doubles at each level.
        num_classes: number of output channels of the classifier.
        dil: output width of the CDCM modules, or ``None`` to disable them.
        sa: if true, apply a CSAM module on every side output.
        ta: if true, gate every skip feature with ``Attention_block``.
        drop_prob: drop probability handed to the ``conv_block`` of levels
            4 and 5 (built with ``drop_block=True``).
        msag: if true, apply an MSAG module on every skip feature.
        csag: if true, apply a CSAG module on every skip feature.

    Note:
        ``self.inplane`` is mutated while the layers are built and equals
        ``8 * inplane`` once construction has finished.
    """
    def __init__(self, img_channel=3, inplane=32, num_classes=2, dil=8, sa=True, ta=True, drop_prob = 0, msag = True, csag = True):
        # Alternative configuration noted by the author: dil=None, sa=False.
        # inplane is the width the image channels are lifted to first (e.g. 3 -> 32/64).
        super(PiDiNet, self).__init__()
        self.sa = sa
        self.ta = ta
        self.msag = msag
        self.csag = csag
        if dil is not None:
            assert isinstance(dil, int), 'dil should be an int'
        self.dil = dil

        # Channel width of each decoder output that feeds the side heads.
        self.fuseplanes = []

        self.inplane = inplane
        self.img_channel = img_channel
        self.num_classes = num_classes
        self.drop_prob = drop_prob

        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Level 5 (bottleneck): 8C -> 16C.
        self.Conv5 = conv_block(ch_in=self.inplane*8, ch_out=self.inplane*16,
                                drop_block=True, block_size=3, drop_prob=self.drop_prob)
        self.middle_1 = conv_stem(inplane=self.inplane*16, outplane=self.inplane*16, patch_size=1)
        self.middle_2 = ConvMixer(inplane=self.inplane*16, kernels_per_layer=1,
                                  outplane=self.inplane*16, kernels_size=5)
        self.softmax = nn.Softmax(dim=1)

        # init_block and aspp are not called by forward(); they stay
        # registered so state_dict keys and parameter order do not change.
        self.init_block = nn.Conv2d(self.img_channel, self.inplane, kernel_size=3, padding=1)
        self.aspp = PASPP(self.inplane*16, self.inplane*16, 16)

        # Level 1: width C.
        self.Conv1 = conv_block(ch_in=self.img_channel, ch_out=self.inplane)
        self.Up2 = up_conv(ch_in=self.inplane*2, ch_out=self.inplane)
        self.Att2 = Attention_block(F_g=self.inplane, F_l=self.inplane, F_int=self.inplane//2)
        self.Up_conv2 = conv_block(ch_in=self.inplane*2, ch_out=self.inplane)
        self.fuseplanes.append(self.inplane)  # C
        self.msag1 = MSAG(self.inplane)
        self.csag1 = CSAG(channel1=self.inplane, channel2=self.img_channel)
        self.mixer1 = ConvMixerBlock(dim=self.inplane, depth=1)

        # Level 2: width 2C.
        self.inplane = self.inplane * 2
        self.Conv2 = conv_block(ch_in=self.inplane//2, ch_out=self.inplane)
        self.Up3 = up_conv(ch_in=self.inplane*2, ch_out=self.inplane)
        self.Att3 = Attention_block(F_g=self.inplane, F_l=self.inplane, F_int=self.inplane//2)
        self.Up_conv3 = conv_block(ch_in=self.inplane*2, ch_out=self.inplane)
        self.fuseplanes.append(self.inplane)  # 2C
        self.msag2 = MSAG(self.inplane)
        self.csag2 = CSAG(channel1=self.inplane, channel2=self.inplane//2)
        self.mixer2 = ConvMixerBlock(dim=self.inplane, depth=1)

        # Level 3: width 4C.
        self.inplane = self.inplane * 2
        self.Conv3 = conv_block(ch_in=self.inplane//2, ch_out=self.inplane)
        self.Up4 = up_conv(ch_in=self.inplane*2, ch_out=self.inplane)
        self.Att4 = Attention_block(F_g=self.inplane, F_l=self.inplane, F_int=self.inplane//2)
        self.Up_conv4 = conv_block(ch_in=self.inplane*2, ch_out=self.inplane)
        self.fuseplanes.append(self.inplane)  # 4C
        self.msag3 = MSAG(self.inplane)
        self.csag3 = CSAG(channel1=self.inplane, channel2=self.inplane//2)
        self.mixer3 = ConvMixerBlock(dim=self.inplane, depth=1)

        # Level 4: width 8C.
        self.inplane = self.inplane * 2
        self.Conv4 = conv_block(ch_in=self.inplane//2, ch_out=self.inplane,
                                drop_block=True, block_size=5, drop_prob=self.drop_prob)
        self.Up5 = up_conv(ch_in=self.inplane*2, ch_out=self.inplane)
        self.Att5 = Attention_block(F_g=self.inplane, F_l=self.inplane, F_int=self.inplane//2)
        self.Up_conv5 = conv_block(ch_in=self.inplane*2, ch_out=self.inplane)
        self.fuseplanes.append(self.inplane)  # 8C
        self.csag4 = CSAG(channel1=self.inplane, channel2=self.inplane//2)
        self.msag4 = MSAG(self.inplane)
        self.mixer4 = ConvMixerBlock(dim=self.inplane, depth=1)

        # Conv-Mixer applied to the 16C bottleneck feature.
        self.convmixer = ConvMixerBlock(dim=self.inplane*2)

        # Side heads, one per decoder output: [CDCM] -> [CSAM] -> MapReduce.
        # The containers are registered, and the modules appended, in the same
        # order as the original four-way branch so that seeded initialisation
        # and state_dict ordering are unchanged.
        self.conv_reduces = nn.ModuleList()
        if self.sa:
            self.attentions = nn.ModuleList()
        if self.dil is not None:
            self.dilations = nn.ModuleList()
        for planes in self.fuseplanes:
            head_width = planes
            if self.dil is not None:
                self.dilations.append(CDCM(planes, self.dil))
                head_width = self.dil
            if self.sa:
                self.attentions.append(CSAM(head_width))
            self.conv_reduces.append(MapReduce(head_width))

        # 8 input channels = 4 side outputs x 2 channels from MapReduce.
        self.classifier = nn.Sequential(
            nn.Conv2d(8, self.num_classes, kernel_size=1, stride=1, padding=0), nn.Softmax(dim=1))

        print('initialization done')

    def get_weights(self):
        """
        Split the parameters into three lists by parameter name.

        Returns:
            ``(conv_weights, bn_weights, relu_weights)``: parameters whose
            name contains ``'bn'``, then those whose name contains
            ``'relu'``, and everything else in ``conv_weights``.
        """
        conv_weights = []
        bn_weights = []
        relu_weights = []
        for pname, p in self.named_parameters():
            if 'bn' in pname:
                bn_weights.append(p)
            elif 'relu' in pname:
                relu_weights.append(p)
            else:
                conv_weights.append(p)

        return conv_weights, bn_weights, relu_weights

    def forward(self, x):
        """
        Run the network on a batch of images.

        Args:
            x: 4-D tensor; its last two dimensions are taken as the output
                size and its channel dimension must match ``img_channel``.

        Returns:
            Tensor with ``num_classes`` channels and the same last two
            dimensions as ``x``.
        """
        H, W = x.size()[2:]

        # Encoder.
        x1 = self.Conv1(x)
        x2 = self.Conv2(self.Maxpool(x1))
        x3 = self.Conv3(self.Maxpool(x2))
        x4 = self.Conv4(self.Maxpool(x3))
        x5 = self.Conv5(self.Maxpool(x4))

        # Bottleneck (the per-call debug print of x5.shape was removed).
        x5 = self.convmixer(x5)
        x5 = self.middle_1(x5)
        x5 = self.middle_2(x5)

        # Decoder level 4. Each level gates the skip feature first; the gated
        # skip is also what the next level's CSAG receives as its coarse input.
        d5 = self.Up5(x5)
        if self.ta:
            x4 = self.Att5(d5, x4)
        if self.csag:
            x4 = self.csag4(x3, x4, x5)
        if self.msag:
            x4 = self.msag4(x4)
        d5 = torch.cat((x4, d5), dim=1)
        d5 = self.Up_conv5(d5)
        d5 = self.mixer4(d5)

        # Decoder level 3.
        d4 = self.Up4(d5)
        if self.ta:
            x3 = self.Att4(d4, x3)
        if self.csag:
            x3 = self.csag3(x2, x3, x4)
        if self.msag:
            x3 = self.msag3(x3)
        d4 = torch.cat((x3, d4), dim=1)
        d4 = self.Up_conv4(d4)
        d4 = self.mixer3(d4)

        # Decoder level 2.
        d3 = self.Up3(d4)
        if self.ta:
            x2 = self.Att3(d3, x2)
        if self.csag:
            x2 = self.csag2(x1, x2, x3)
        if self.msag:
            x2 = self.msag2(x2)
        d3 = torch.cat((x2, d3), dim=1)
        d3 = self.Up_conv3(d3)
        d3 = self.mixer2(d3)

        # Decoder level 1.
        d2 = self.Up2(d3)
        if self.ta:
            x1 = self.Att2(d2, x1)
        if self.csag:
            x1 = self.csag1(x, x1, x2)
        if self.msag:
            x1 = self.msag1(x1)
        d2 = torch.cat((x1, d2), dim=1)
        d2 = self.Up_conv2(d2)
        d2 = self.mixer1(d2)

        # Side outputs: optional dilation module, then optional attention.
        x_fuses = [d2, d3, d4, d5]
        if self.dil is not None:
            x_fuses = [dilate(f) for dilate, f in zip(self.dilations, x_fuses)]
        if self.sa:
            x_fuses = [attend(f) for attend, f in zip(self.attentions, x_fuses)]

        # Reduce every scale to 2 channels and resize to the input resolution.
        outputs = [
            F.interpolate(reduce(f), (H, W), mode="bilinear", align_corners=False)
            for reduce, f in zip(self.conv_reduces, x_fuses)
        ]

        output = self.classifier(torch.cat(outputs, dim=1))

        # NOTE(review): self.classifier already ends in Softmax(dim=1), so this
        # is a second softmax over the class channel. It looks unintended, but
        # is kept because removing it changes the outputs of trained models.
        output = self.softmax(output)
        return output



In [None]:
# Smoke test: push one random batch through PiDiNet and print the output shape.
x = torch.rand(1,3,192,256)  # one random 3-channel 192x256 input
# NOTE(review): conv_stem1 is not used again in this cell; presumably a
# leftover experiment. Confirm no later cell needs it before removing.
conv_stem1 = conv_stem(512,51,1)
model = PiDiNet()  # all constructor arguments left at their defaults
print(model(x).shape)


initialization done
torch.Size([1, 512, 12, 16])
torch.Size([1, 2, 192, 256])


# Hope Net (final version)


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MHSA(nn.Module):
    """
    Multi-head self-attention over the positions of a 2-D feature map, with
    learned relative position terms, followed by LayerNorm and a residual add.

    Args:
        n_dims: channels of the input; split evenly across ``heads``.
        width: size of dimension 2 of the input, as this module names it.
        height: size of dimension 3 of the input, as this module names it.
        heads: number of attention heads.

    The position parameters and the LayerNorm are built for one fixed
    ``(n_dims, width, height)`` shape, so inputs must match it exactly.
    """
    def __init__(self, n_dims, width=14, height=14, heads=4):
        super().__init__()
        self.heads = heads

        # 1x1 convolutions producing the query / key / value maps.
        self.query = nn.Conv2d(n_dims, n_dims, kernel_size=1)
        self.key = nn.Conv2d(n_dims, n_dims, kernel_size=1)
        self.value = nn.Conv2d(n_dims, n_dims, kernel_size=1)

        # One learned vector per head for every index along each spatial axis;
        # the two tensors are broadcast-added in forward().
        per_head = n_dims // heads
        self.rel_h = nn.Parameter(torch.randn([1, heads, per_head, 1, height]), requires_grad=True)
        self.rel_w = nn.Parameter(torch.randn([1, heads, per_head, width, 1]), requires_grad=True)

        self.softmax = nn.Softmax(dim=-1)
        self.norm = nn.LayerNorm([n_dims, width, height])

    def forward(self, x):
        """Return ``norm(attention(x)) + x``; the result has the shape of ``x``."""
        n_batch, C, width, height = x.size()
        per_head = C // self.heads

        def split_heads(t):
            # (batch, C, w, h) -> (batch, heads, C // heads, w * h)
            return t.view(n_batch, self.heads, per_head, -1)

        q = split_heads(self.query(x))
        k = split_heads(self.key(x))
        v = split_heads(self.value(x))

        # Content-content term: q^T k over the flattened positions.
        content_content = torch.matmul(q.transpose(-2, -1), k)

        # Content-position term: (rel_h + rel_w)^T q.
        rel = (self.rel_h + self.rel_w).view(1, self.heads, per_head, -1)
        content_position = torch.matmul(rel.transpose(-2, -1), q)

        attention = self.softmax(content_content + content_position)

        attended = torch.matmul(v, attention.transpose(-2, -1))
        attended = attended.view(n_batch, C, width, height)
        return self.norm(attended) + x

In [None]:
class CSAG(nn.Module):
    """
    Fuse a feature map with its finer and coarser neighbours.

    ``forward(x1, x2, x3)`` brings ``x1`` and ``x3`` to the channel count and
    spatial size of ``x2``, concatenates the three and mixes them back to
    ``channel1`` channels with a 1x1 convolution.

    Args:
        channel1: channels of ``x2`` and of the output; ``x3`` must have
            ``2 * channel1`` channels.
        channel2: channels of ``x1``.
    """

    def __init__(self, channel1, channel2):
        super(CSAG, self).__init__()
        self.channel2 = channel2
        # x3 path: learned 2x upsampling, 2*channel1 -> channel1.
        self.upsample_layer = nn.ConvTranspose2d(in_channels=channel1*2, out_channels=channel1, kernel_size=2, stride=2, padding=0)
        nn.init.constant_(self.upsample_layer.bias, 0)

        # x1 paths: stride-2 projection when x1 is at twice the resolution of
        # x2, plain 1x1 projection when x1 already has x2's resolution.
        self.conv_up = nn.Conv2d(channel2, channel1, kernel_size=2, padding=0, stride = 2)
        self.conv_up_input = nn.Conv2d(channel2, channel1, kernel_size=1, padding=0)
        nn.init.constant_(self.conv_up.bias, 0)
        self.conv_end = nn.Conv2d(channel1*3, channel1, kernel_size=1, padding=0)

    def forward(self, x1, x2, x3):
        """
        Args:
            x1: finer-side input with ``channel2`` channels, either at the
                resolution of ``x2`` or at twice that resolution.
            x2: feature to refine, ``channel1`` channels.
            x3: coarser-side input, ``2 * channel1`` channels at half the
                resolution of ``x2``.

        Returns:
            Tensor with ``channel1`` channels at the resolution of ``x2``.
        """
        x3 = self.upsample_layer(x3)

        # Pick the x1 projection from the actual resolutions rather than from
        # `channel2 == 3`. The old test only recognised RGB images, so e.g. a
        # 1-channel image was wrongly downsampled and the concat below failed.
        # For every input the old test handled, this selects the same layer.
        if x1.shape[-2:] == x2.shape[-2:]:
            x1 = self.conv_up_input(x1)
        else:
            x1 = self.conv_up(x1)

        out = torch.cat([x1, x2, x3], dim=1)
        out = self.conv_end(out)
        return out

In [None]:
import torch.nn as nn
import torch


class MSAG(nn.Module):
    """
    Multi-scale attention gate.

    Three parallel conv + batch-norm branches (1x1, 3x3, and 3x3 with
    dilation 2) look at the input; their concatenated, ReLU-ed outputs are
    turned into a per-channel sigmoid gate by a 1x1 conv + batch-norm, and
    the block returns ``x + x * gate``.

    Args:
        channel: channels of the input (and of the output).
    """
    def __init__(self, channel):
        super().__init__()
        self.channel = channel
        self.pointwiseConv = self._conv_bn(channel, kernel_size=1, padding=0, bias=True)
        self.ordinaryConv = self._conv_bn(channel, kernel_size=3, padding=1, stride=1, bias=True)
        self.dilationConv = self._conv_bn(channel, kernel_size=3, padding=2, stride=1,
                                          dilation=2, bias=True)
        # Collapses the three stacked branches back to `channel` gate values.
        self.voteConv = nn.Sequential(
            nn.Conv2d(channel * 3, channel, kernel_size=(1, 1)),
            nn.BatchNorm2d(channel),
            nn.Sigmoid()
        )
        self.relu = nn.ReLU(inplace=True)

    @staticmethod
    def _conv_bn(channel, **conv_kwargs):
        """Build one channel-preserving conv + batch-norm branch."""
        return nn.Sequential(
            nn.Conv2d(channel, channel, **conv_kwargs),
            nn.BatchNorm2d(channel),
        )

    def forward(self, x):
        """Return ``x`` plus ``x`` weighted by the multi-scale gate."""
        branches = (self.pointwiseConv, self.ordinaryConv, self.dilationConv)
        stacked = torch.cat([branch(x) for branch in branches], dim=1)
        gate = self.voteConv(self.relu(stacked))
        return x + x * gate

In [None]:
class Hope_Net(nn.Module):
    """
    Hope_Net: five-level encoder/decoder with skip connections, an optional
    MHSA bottleneck and a multi-scale side-output head.

    The encoder applies ``conv_block`` at five resolutions separated by 2x2
    max-pooling; the bottleneck feature optionally goes through ``MHSA``. The
    decoder upsamples with ``up_conv``, optionally gates the skip feature
    (``Attention_block`` / ``CSAG`` / ``MSAG``), concatenates it with the
    upsampled feature and refines it with ``conv_block`` and a
    ``ConvMixerBlock``. The four decoder outputs are each passed through
    optional ``CDCM`` / ``CSAM`` modules, reduced to 2 channels by
    ``MapReduce``, resized to the input size, concatenated (8 channels) and
    classified.

    Args:
        img_channel: channels of the input image.
        inplane: channel width of the first level; it doubles at each level.
        num_classes: number of output channels of the classifier.
        dil: output width of the CDCM modules, or ``None`` to disable them.
        sa: if true, apply a CSAM module on every side output.
        ta: if true, gate every skip feature with ``Attention_block``.
        drop_prob: drop probability handed to the ``conv_block`` of levels
            4 and 5 (built with ``drop_block=True``).
        mhsa: if true, pass the bottleneck feature through ``self.mhsa``.
        csag: if true, apply a CSAG module on every skip feature.
        msag: if true (default), apply an MSAG module on every skip feature.

    Notes:
        * ``self.mhsa`` is the MHSA module; the boolean flag is stored as
          ``self.use_mhsa``.
        * The MHSA module is built for a 12 x 16 bottleneck map, i.e. a
          192 x 256 input; other sizes need ``mhsa=False``.
        * ``self.inplane`` is mutated while the layers are built and equals
          ``8 * inplane`` once construction has finished.
    """
    def __init__(self, img_channel=3, inplane=32, num_classes=2, dil=8, sa=True, ta=False, drop_prob = 0, mhsa = True, csag = True, msag = True):
        # Alternative configuration noted by the author: dil=None, sa=False.
        # inplane is the width the image channels are lifted to first (e.g. 3 -> 32/64).
        super(Hope_Net, self).__init__()
        self.sa = sa
        self.ta = ta
        self.csag = csag
        self.msag = msag
        # The flag used to be stored as `self.mhsa` and was then overwritten
        # by the MHSA module below. A module is always truthy, so the flag
        # had no effect. It now has its own attribute.
        self.use_mhsa = mhsa
        if dil is not None:
            assert isinstance(dil, int), 'dil should be an int'
        self.dil = dil

        # Channel width of each decoder output that feeds the side heads.
        self.fuseplanes = []

        self.inplane = inplane
        self.img_channel = img_channel
        self.num_classes = num_classes
        self.drop_prob = drop_prob

        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Level 5 (bottleneck): 8C -> 16C.
        self.Conv5 = conv_block(ch_in=self.inplane*8, ch_out=self.inplane*16,
                                drop_block=True, block_size=3, drop_prob=self.drop_prob)
        # middle_1, middle_2, init_block, aspp and convmixer (below) are not
        # called by forward(); they stay registered so state_dict keys and
        # parameter order do not change.
        self.middle_1 = conv_stem(inplane=self.inplane*16, outplane=self.inplane*16, patch_size=2)
        self.middle_2 = ConvMixer(inplane=self.inplane*16, kernels_per_layer=1,
                                  outplane=self.inplane*16, kernels_size=5)
        self.softmax = nn.Softmax(dim=1)

        self.init_block = nn.Conv2d(self.img_channel, self.inplane, kernel_size=3, padding=1)
        self.aspp = PASPP(self.inplane*16, self.inplane*16, 16)

        # Level 1: width C.
        self.Conv1 = conv_block(ch_in=self.img_channel, ch_out=self.inplane)
        self.Up2 = up_conv(ch_in=self.inplane*2, ch_out=self.inplane)
        self.Att2 = Attention_block(F_g=self.inplane, F_l=self.inplane, F_int=self.inplane//2)
        self.Up_conv2 = conv_block(ch_in=self.inplane*2, ch_out=self.inplane)
        self.fuseplanes.append(self.inplane)  # C
        self.csag1 = CSAG(channel1=self.inplane, channel2=self.img_channel)
        self.mixer1 = ConvMixerBlock(dim=self.inplane, depth=1)
        self.msag1 = MSAG(self.inplane)

        # Level 2: width 2C.
        self.inplane = self.inplane * 2
        self.Conv2 = conv_block(ch_in=self.inplane//2, ch_out=self.inplane)
        self.Up3 = up_conv(ch_in=self.inplane*2, ch_out=self.inplane)
        self.Att3 = Attention_block(F_g=self.inplane, F_l=self.inplane, F_int=self.inplane//2)
        self.Up_conv3 = conv_block(ch_in=self.inplane*2, ch_out=self.inplane)
        self.fuseplanes.append(self.inplane)  # 2C
        self.csag2 = CSAG(channel1=self.inplane, channel2=self.inplane//2)
        self.mixer2 = ConvMixerBlock(dim=self.inplane, depth=1)
        self.msag2 = MSAG(self.inplane)

        # Level 3: width 4C.
        self.inplane = self.inplane * 2
        self.Conv3 = conv_block(ch_in=self.inplane//2, ch_out=self.inplane)
        self.Up4 = up_conv(ch_in=self.inplane*2, ch_out=self.inplane)
        self.Att4 = Attention_block(F_g=self.inplane, F_l=self.inplane, F_int=self.inplane//2)
        self.Up_conv4 = conv_block(ch_in=self.inplane*2, ch_out=self.inplane)
        self.fuseplanes.append(self.inplane)  # 4C
        self.csag3 = CSAG(channel1=self.inplane, channel2=self.inplane//2)
        self.mixer3 = ConvMixerBlock(dim=self.inplane, depth=1)
        self.msag3 = MSAG(self.inplane)

        # Level 4: width 8C.
        self.inplane = self.inplane * 2
        self.Conv4 = conv_block(ch_in=self.inplane//2, ch_out=self.inplane,
                                drop_block=True, block_size=5, drop_prob=self.drop_prob)
        self.Up5 = up_conv(ch_in=self.inplane*2, ch_out=self.inplane)
        self.Att5 = Attention_block(F_g=self.inplane, F_l=self.inplane, F_int=self.inplane//2)
        self.Up_conv5 = conv_block(ch_in=self.inplane*2, ch_out=self.inplane)
        self.fuseplanes.append(self.inplane)  # 8C
        self.csag4 = CSAG(channel1=self.inplane, channel2=self.inplane//2)
        self.mixer4 = ConvMixerBlock(dim=self.inplane, depth=1)
        self.msag4 = MSAG(self.inplane)

        self.convmixer = ConvMixerBlock(dim=self.inplane*2)
        # Bottleneck attention. n_dims follows the bottleneck width (16C; 512
        # for the default inplane=32) instead of a hard-coded 512. The 12 x 16
        # map size corresponds to a 192 x 256 input (four 2x poolings).
        self.mhsa = MHSA(n_dims=self.inplane*2, width=12, height=16, heads=4)

        # Side heads, one per decoder output: [CDCM] -> [CSAM] -> MapReduce.
        # The containers are registered, and the modules appended, in the same
        # order as the original four-way branch so that seeded initialisation
        # and state_dict ordering are unchanged.
        self.conv_reduces = nn.ModuleList()
        if self.sa:
            self.attentions = nn.ModuleList()
        if self.dil is not None:
            self.dilations = nn.ModuleList()
        for planes in self.fuseplanes:
            head_width = planes
            if self.dil is not None:
                self.dilations.append(CDCM(planes, self.dil))
                head_width = self.dil
            if self.sa:
                self.attentions.append(CSAM(head_width))
            self.conv_reduces.append(MapReduce(head_width))

        # 8 input channels = 4 side outputs x 2 channels from MapReduce.
        self.classifier = nn.Sequential(
            nn.Conv2d(8, self.num_classes, kernel_size=1, stride=1, padding=0), nn.Softmax(dim=1))

        print('initialization done')

    def get_weights(self):
        """
        Split the parameters into three lists by parameter name.

        Returns:
            ``(conv_weights, bn_weights, relu_weights)``: parameters whose
            name contains ``'bn'``, then those whose name contains
            ``'relu'``, and everything else in ``conv_weights``.
        """
        conv_weights = []
        bn_weights = []
        relu_weights = []
        for pname, p in self.named_parameters():
            if 'bn' in pname:
                bn_weights.append(p)
            elif 'relu' in pname:
                relu_weights.append(p)
            else:
                conv_weights.append(p)

        return conv_weights, bn_weights, relu_weights

    def forward(self, x):
        """
        Run the network on a batch of images.

        Args:
            x: 4-D tensor; its last two dimensions are taken as the output
                size and its channel dimension must match ``img_channel``.

        Returns:
            Tensor with ``num_classes`` channels and the same last two
            dimensions as ``x``.
        """
        H, W = x.size()[2:]

        # Encoder.
        x1 = self.Conv1(x)
        x2 = self.Conv2(self.Maxpool(x1))
        x3 = self.Conv3(self.Maxpool(x2))
        x4 = self.Conv4(self.Maxpool(x3))
        x5 = self.Conv5(self.Maxpool(x4))

        # Bottleneck attention. The result used to be bound to `X5` (capital
        # X) and thrown away, so the MHSA block never influenced the output.
        if self.use_mhsa:
            x5 = self.mhsa(x5)

        # Decoder level 4. Each level gates the skip feature first; from level
        # 3 down, CSAG takes the previous decoder output as its coarse input.
        d5 = self.Up5(x5)
        if self.ta:
            x4 = self.Att5(d5, x4)
        if self.csag:
            x4 = self.csag4(x3, x4, x5)
        if self.msag:
            x4 = self.msag4(x4)
        d5 = torch.cat((x4, d5), dim=1)
        d5 = self.Up_conv5(d5)
        d5 = self.mixer4(d5)

        # Decoder level 3.
        d4 = self.Up4(d5)
        if self.ta:
            x3 = self.Att4(d4, x3)
        if self.csag:
            x3 = self.csag3(x2, x3, d5)
        if self.msag:
            x3 = self.msag3(x3)
        d4 = torch.cat((x3, d4), dim=1)
        d4 = self.Up_conv4(d4)
        d4 = self.mixer3(d4)

        # Decoder level 2.
        d3 = self.Up3(d4)
        if self.ta:
            x2 = self.Att3(d3, x2)
        if self.csag:
            x2 = self.csag2(x1, x2, d4)
        if self.msag:
            x2 = self.msag2(x2)
        d3 = torch.cat((x2, d3), dim=1)
        d3 = self.Up_conv3(d3)
        d3 = self.mixer2(d3)

        # Decoder level 1.
        d2 = self.Up2(d3)
        if self.ta:
            x1 = self.Att2(d2, x1)
        if self.csag:
            x1 = self.csag1(x, x1, d3)
        if self.msag:
            x1 = self.msag1(x1)
        d2 = torch.cat((x1, d2), dim=1)
        d2 = self.Up_conv2(d2)
        d2 = self.mixer1(d2)

        # Side outputs: optional dilation module, then optional attention.
        x_fuses = [d2, d3, d4, d5]
        if self.dil is not None:
            x_fuses = [dilate(f) for dilate, f in zip(self.dilations, x_fuses)]
        if self.sa:
            x_fuses = [attend(f) for attend, f in zip(self.attentions, x_fuses)]

        # Reduce every scale to 2 channels and resize to the input resolution.
        outputs = [
            F.interpolate(reduce(f), (H, W), mode="bilinear", align_corners=False)
            for reduce, f in zip(self.conv_reduces, x_fuses)
        ]

        output = self.classifier(torch.cat(outputs, dim=1))

        # NOTE(review): self.classifier already ends in Softmax(dim=1), so this
        # is a second softmax over the class channel. It looks unintended, but
        # is kept because removing it changes the outputs of trained models.
        output = self.softmax(output)
        return output



In [None]:
#x = torch.rand(1,3,192,256)
#model = Hope_Net()
#print(model(x).shape)
