In [39]:
import torch
import torch.nn as nn

In [24]:
class simple_net(nn.Module):
    """Tiny conv -> BN -> ReLU -> conv network used to try out ``nn.DataParallel``.

    ``conv0`` (in_channels -> mid_channels, 3x3, stride 1, no padding), then
    ``bn`` + ``relu``, then ``conv1`` (mid_channels -> out_channels, 3x3, stride 1,
    no padding); each spatial dimension therefore shrinks by 4.

    :param in_channels: channels of the input tensor (default 10).
    :param mid_channels: channels after ``conv0`` / ``bn`` (default 20).
    :param out_channels: channels produced by ``conv1`` (default 20).
    """
    # NOTE: no ``__getattr__`` override here. ``nn.Module`` resolves registered
    # sub-modules (``self.conv0`` ...) through its own ``__getattr__``; an override
    # returning ``getattr(self.module, name)`` recurses forever because this class
    # has no ``module`` attribute. Forwarding attributes through a wrapper belongs
    # on an ``nn.DataParallel`` subclass, not on the wrapped model.

    def __init__(self, in_channels=10, mid_channels=20, out_channels=20):
        super(simple_net, self).__init__()

        self.conv0 = nn.Conv2d(in_channels, mid_channels, 3, 1)
        self.bn = nn.BatchNorm2d(mid_channels)
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(mid_channels, out_channels, 3, 1)

    def forward(self, inputs):
        """Apply conv0 -> bn -> relu -> conv1 to a (B, in_channels, H, W) tensor."""
        output = self.relu(self.bn(self.conv0(inputs)))
        output = self.conv1(output)
        return output

In [25]:
# Fall back to CPU when no GPU is present: a hard-coded `.to("cuda")` raises on CPU-only machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = simple_net().to(device)
# DataParallel splits each batch across visible GPUs; with no GPUs it simply calls the wrapped module.
model = nn.DataParallel(model)

In [30]:

a = [1, 2, 3]
b = ['a', 'b', 'c']
# zip pairs elements positionally (stopping at the shorter list);
# it is the idiomatic form of map(lambda x, y: (x, y), a, b).
pairs = list(zip(a, b))
pairs

[(1, 'a'), (2, 'b'), (3, 'c')]

In [None]:
# The original line `filter(lambda x: )` was an unfinished experiment and a SyntaxError,
# which broke Restart & Run All. Minimal working `filter` example; `filter` is lazy,
# so it is materialised with `list` for display.
odd_numbers = list(filter(lambda x: x % 2 == 1, [1, 2, 3]))
odd_numbers

In [37]:
class Separable_conv(nn.Module):
    """
    Depth-wise separable convolution:
    depth-wise conv (kernel x kernel, groups=in_channel)
    --> [BatchNorm2d, PReLU] (only when is_bn=True)
    --> 1x1 point-wise conv.

    Nothing (no BN, no activation) is applied after the point-wise conv.
    With the defaults (kernel=3, stride=1, padding=1, dilation=1) H and W are preserved.

    :param in_channel: input channels; the depth-wise conv keeps this channel count.
    :param out_channel: output channels of the point-wise conv; defaults to in_channel.
    :param kernel: depth-wise kernel size.
    :param stride: depth-wise stride.
    :param padding: depth-wise zero padding.
    :param dilation: depth-wise dilation.
    :param is_bn: if True, insert BatchNorm2d + PReLU between the two convs.
    :param bias: whether both convs carry a bias term.
    """
    def __init__(self, in_channel, out_channel=None, kernel=3, stride=1, padding=1, dilation=1, is_bn=False, bias=False):
        super(Separable_conv, self).__init__()
        self.is_bn = is_bn

        if out_channel is None:
            out_channel = in_channel

        # groups=in_channel gives one filter per input channel (the depth-wise part)
        self.depthwise_conv = nn.Conv2d(in_channel, in_channel, kernel_size=kernel, stride=stride, padding=padding,
                                        dilation=dilation, groups=in_channel, bias=bias)
        if self.is_bn:
            self.bn = nn.BatchNorm2d(in_channel)
            self.act = nn.PReLU()

        # 1x1 conv mixes channels and sets the output channel count
        self.pointwise_conv = nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=1, padding=0, bias=bias)

    def forward(self, inputs):
        """Apply depth-wise conv, optional BN + PReLU, then the 1x1 point-wise conv."""
        output = self.depthwise_conv(inputs)

        if self.is_bn:
            output = self.act(self.bn(output))

        output = self.pointwise_conv(output)

        return output
    
class SE_block(nn.Module):
    """Squeeze-and-excitation block with channel (cSE), spatial (sSE) or combined (scSE) gating.

    https://arxiv.org/pdf/1808.08127.pdf
    https://github.com/ai-med/squeeze_and_excitation/blob/master/squeeze_and_excitation/squeeze_and_excitation.py
    """

    def __init__(self, in_channel, ratio=2, is_dropout=False, block_type='scSE'):
        """
        :param in_channel: number of channels C of the (B, C, H, W) input.
        :param ratio: reduction ratio of the cSE bottleneck; the hidden width is
            in_channel // ratio, clamped to at least 1.
        :param is_dropout: if True, apply nn.Dropout2d(0.2) to the block output.
        :param block_type: 'scSE', 'sSE', 'cSE'
            'cSE'  - spatial squeeze & channel excitation:
                     global average pool -> fc -> ReLU -> fc -> sigmoid, scales each channel.
            'sSE'  - channel squeeze & spatial excitation:
                     1x1 conv to one channel -> sigmoid, scales each pixel.
            'scSE' - element-wise maximum of the cSE and sSE outputs.
        """
        super(SE_block, self).__init__()
        assert block_type in ('scSE', 'sSE', 'cSE')

        self.type = block_type
        self.is_dropout = is_dropout
        self.channel_se = block_type != 'sSE'
        self.spatial_se = block_type != 'cSE'

        # spatial squeeze & channel excitation
        if self.channel_se:
            # clamp so in_channel < ratio does not create a zero-width Linear layer
            hidden = max(1, in_channel // ratio)
            self.cse_fc0 = nn.Linear(in_channel, hidden)
            self.relu = nn.ReLU(inplace=True)
            self.cse_fc1 = nn.Linear(hidden, in_channel)

        # channel squeeze & spatial excitation
        if self.spatial_se:
            self.sse_conv = nn.Conv2d(in_channel, 1, kernel_size=1, stride=1)

        # gate shared by both branches (Sigmoid holds no parameters)
        self.sigmoid = nn.Sigmoid()

        if self.is_dropout:
            self.dropout = nn.Dropout2d(0.2)

    def forward(self, inputs):
        """
        :param inputs: tensor of shape (B, C, H, W) with C == in_channel.
        :return: tensor of the same shape, recalibrated by the selected branch(es).
        """
        # 4-way unpack also rejects inputs that are not 4-D
        batch_size, num_channels, _, _ = inputs.size()
        output = inputs

        if self.channel_se:
            # global average pool over H, W; unlike .view(), mean() also works on non-contiguous inputs
            channel_weights = inputs.mean(dim=(2, 3))
            channel_weights = self.relu(self.cse_fc0(channel_weights))     # fc --> relu
            channel_weights = self.sigmoid(self.cse_fc1(channel_weights))  # fc --> sigmoid

            # channel-wise multiply
            output_cse = inputs * channel_weights.view(batch_size, num_channels, 1, 1)
            output = output_cse

        if self.spatial_se:
            # the 1x1 conv already yields (B, 1, H, W), which broadcasts over channels
            output_sse = inputs * self.sigmoid(self.sse_conv(inputs))
            output = output_sse

        if self.type == 'scSE':
            # alternative combination: output = output_cse + output_sse
            output = torch.max(output_cse, output_sse)

        if self.is_dropout:
            output = self.dropout(output)

        return output
    
class Transition_Up(nn.Module):
    """Learned 2x spatial up-sampling with a transposed convolution.

    ConvTranspose2d(kernel_size=3, stride=2, padding=1, output_padding=1) maps an
    (N, in_channel, H, W) tensor to (N, out_channel, 2H, 2W).

    :param in_channel: channels of the incoming feature map.
    :param out_channel: channels produced; falls back to in_channel when omitted.
    """

    def __init__(self, in_channel, out_channel=None):
        super().__init__()
        channels_out = in_channel if out_channel is None else out_channel
        self.out_channel = channels_out
        self.trans_conv = nn.ConvTranspose2d(
            in_channel,
            channels_out,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
        )

    def forward(self, inputs):
        """Up-sample ``inputs`` by a factor of two in H and W."""
        return self.trans_conv(inputs)



In [41]:

class Decoder_Dense(nn.Module):
    """Decoder stage: up-sample, concatenate with the encoder skip, separable conv, SE gating.

    up_in --Transition_Up--> (encode_in channels, 2x spatial, aligned to the skip)
    cat([skip, up]) --[Dropout2d(0.2)]--> BatchNorm2d -> PReLU -> Separable_conv(is_bn=True)
    --> SE_block(out_channel) (its default block_type, 'scSE').

    :param in_channel: channels of the deeper feature map passed as ``up_in``.
    :param encode_in: channels of the encoder skip tensor. The up-sampled tensor is
        projected to the same count, so the concatenation has 2 * encode_in channels.
    :param out_channel: channels produced by this stage.
    :param is_dropout: apply nn.Dropout2d(0.2) right after the concatenation.
    """

    def __init__(self, in_channel, encode_in, out_channel, is_dropout=True):
        super(Decoder_Dense, self).__init__()
        self.is_dropout = is_dropout
        self.trans_up = Transition_Up(in_channel, encode_in)
        if self.is_dropout:
            self.dropout = nn.Dropout2d(0.2)

        # attribute name kept as-is because it fixes the state_dict keys;
        # it is BN -> PReLU -> full separable conv, not only a depth-wise conv
        self.depthwise_conv = nn.Sequential(nn.BatchNorm2d(2*encode_in),
                                            nn.PReLU(),
                                            Separable_conv(2*encode_in, out_channel, is_bn=True))

        self.SE_block = SE_block(out_channel)

    def forward(self, encode_in, up_in):
        """
        :param encode_in: encoder skip TENSOR (not a channel count here),
            shape (B, encode_in channels, H, W).
        :param up_in: deeper feature map (B, in_channel, h, w), up-sampled to (2h, 2w)
            and then padded/cropped to (H, W).
        :return: tensor of shape (B, out_channel, H, W).
        """
        output = self.trans_up(up_in)
        output = self._match_spatial(output, encode_in)
        output = torch.cat([encode_in, output], dim=1)

        if self.is_dropout:
            output = self.dropout(output)

        output = self.depthwise_conv(output)
        output = self.SE_block(output)

        return output

    @staticmethod
    def _match_spatial(tensor, reference):
        """Center-pad (or crop, via negative padding) ``tensor`` to ``reference``'s H and W.

        Without this, torch.cat raises whenever the 2x up-sampled tensor is off by a
        pixel from the skip tensor (e.g. when an encoder feature map had an odd size).
        """
        diff_h = reference.size(2) - tensor.size(2)
        diff_w = reference.size(3) - tensor.size(3)
        if diff_h == 0 and diff_w == 0:
            return tensor
        # pad order for the last two dims: [left, right, top, bottom]
        return nn.functional.pad(tensor, [diff_w // 2, diff_w - diff_w // 2,
                                          diff_h // 2, diff_h - diff_h // 2])

In [42]:
# One decoder stage: 20-channel deeper input, 10-channel encoder skip, 20 output channels.
a = Decoder_Dense(in_channel=20, encode_in=10, out_channel=20)

In [43]:
# Print every registered parameter name of the decoder stage, one per line.
for param_name, _param in a.named_parameters():
    print(param_name)

trans_up.trans_conv.weight
trans_up.trans_conv.bias
depthwise_conv.0.weight
depthwise_conv.0.bias
depthwise_conv.1.weight
depthwise_conv.2.depthwise_conv.weight
depthwise_conv.2.bn.weight
depthwise_conv.2.bn.bias
depthwise_conv.2.act.weight
depthwise_conv.2.pointwise_conv.weight
SE_block.cse_fc0.weight
SE_block.cse_fc0.bias
SE_block.cse_fc1.weight
SE_block.cse_fc1.bias
SE_block.sse_conv.weight
SE_block.sse_conv.bias
