In [4]:
from timm.models.resnet import seresnext26d_32x4d
import torch
import torch.nn.functional as F
from torch import nn
import numpy as np

In [5]:
# Modified from segmentation_models_pytorch (smp).
class MyDecoderBlock(nn.Module):
    """Single decoder stage: 2x nearest upsample, optional skip concat, two conv-BN-ReLU layers.

    Args:
        in_channel: channel count of the tensor passed as ``x``.
        skip_channel: channel count of the tensor passed as ``skip``
            (use 0 when the stage is called without a skip).
        out_channel: channel count produced by both 3x3 convolutions.
    """

    def __init__(self, in_channel, skip_channel, out_channel):
        super().__init__()
        merged_channel = in_channel + skip_channel

        # Attribute names and creation order fix the state_dict keys and the
        # order in which parameters are randomly initialised.
        self.conv1 = self._conv_bn_relu(merged_channel, out_channel)
        self.attention1 = nn.Identity()  # no-op slot applied to the concatenated input
        self.conv2 = self._conv_bn_relu(out_channel, out_channel)
        self.attention2 = nn.Identity()  # no-op slot applied to the block output

    @staticmethod
    def _conv_bn_relu(n_in, n_out):
        """Build a 3x3 same-padding convolution followed by BatchNorm and in-place ReLU."""
        return nn.Sequential(
            nn.Conv2d(n_in, n_out, kernel_size=3, padding=1),
            nn.BatchNorm2d(n_out),
            nn.ReLU(inplace=True),
        )

    def forward(self, x, skip=None):
        """Upsample ``x`` by 2, merge ``skip`` along the channel axis if given, then convolve.

        ``attention1`` is only applied when a skip tensor is concatenated.
        """
        upsampled = F.interpolate(x, scale_factor=2, mode='nearest')
        if skip is None:
            merged = upsampled
        else:
            merged = self.attention1(torch.cat([upsampled, skip], dim=1))
        return self.attention2(self.conv2(self.conv1(merged)))


class MyUnetDecoder(nn.Module):
    """Sequence of ``MyDecoderBlock`` stages applied from the deepest feature upwards.

    Args:
        in_channel: channel count of the feature fed to the first block.
        skip_channel: per-block channel count of the skip tensor.
        out_channel: per-block output channel count; block ``k`` takes the
            output channels of block ``k - 1`` as its input channels.

    ``skip_channel`` and ``out_channel`` may be lists, tuples or any other
    iterable. One block is built per zipped triple, so the number of blocks is
    the length of the shortest sequence.
    """

    def __init__(self,
                 in_channel,
                 skip_channel,
                 out_channel,
                 ):
        super().__init__()
        self.center = nn.Identity()

        # Materialise as lists so tuple configs can be sliced and concatenated
        # with the leading ``in_channel`` without raising TypeError.
        o_channel = list(out_channel)
        s_channel = list(skip_channel)
        i_channel = [in_channel] + o_channel[:-1]

        block = [
            MyDecoderBlock(i, s, o)
            for i, s, o in zip(i_channel, s_channel, o_channel)
        ]
        self.block = nn.ModuleList(block)

    def forward(self, feature, skip):
        """Run every block in order.

        Args:
            feature: input to the first block.
            skip: indexable collection; ``skip[i]`` is handed to block ``i``
                (``None`` means the block runs without a skip tensor).

        Returns:
            ``(last, decode)`` where ``decode`` lists every block output in
            order and ``last`` is the final one.
        """
        d = self.center(feature)
        decode = []
        for i, block in enumerate(self.block):
            d = block(d, skip[i])
            decode.append(d)

        return d, decode

class Net(nn.Module):
    """Encoder-decoder segmentation network producing a one-channel sigmoid mask.

    The encoder is ``seresnext26d_32x4d`` (built with ``pretrained=False`` and
    ``in_chans=3``); the decoder is a ``MyUnetDecoder`` with five stages, the
    last of which receives no skip tensor.
    """

    def __init__(self, ):
        super().__init__()
        # Channel widths the decoder is configured for: the last encoder entry
        # is the decoder input, the remaining ones (deepest first) are skips.
        encoder_dim = [64, 256, 512, 1024, 2048]
        decoder_dim = [256, 128, 128, 64, 32]

        self.encoder = seresnext26d_32x4d(pretrained=False, in_chans=3)
        self.decoder = MyUnetDecoder(
            in_channel=encoder_dim[-1],
            skip_channel=encoder_dim[-2::-1] + [0],
            out_channel=decoder_dim,
        )
        self.logit = nn.Conv2d(decoder_dim[-1], 1, kernel_size=1)

    def forward(self, image):
        """Return a float mask with the same height and width as ``image``.

        The input is cropped at the bottom/right to the largest multiple of 32
        in each spatial dimension before encoding; the cropped-off border of
        the returned mask is filled with zeros.
        """
        _, _, H, W = image.shape
        h = 32 * (H // 32)
        w = 32 * (W // 32)

        # expand() broadcasts a single-channel input to the 3 channels the
        # encoder was built for; a 3-channel input passes through unchanged.
        x = image[:, :, :h, :w].expand(-1, 3, -1, -1)

        enc = self.encoder
        stem = enc.act1(enc.bn1(enc.conv1(x)))
        features = [stem]

        # Average pooling halves the resolution between the stem and layer1.
        x = F.avg_pool2d(stem, kernel_size=2, stride=2)
        for stage in (enc.layer1, enc.layer2, enc.layer3, enc.layer4):
            x = stage(x)
            features.append(x)

        # Deepest feature drives the decoder; shallower ones become skips in
        # deep-to-shallow order, and the final decoder stage gets no skip.
        *shallow, deepest = features
        last, _ = self.decoder(feature=deepest, skip=shallow[::-1] + [None])

        logit = self.logit(last)
        mask = torch.sigmoid(logit.float())
        # Pad right by W - w and bottom by H - h to restore the input size.
        return F.pad(mask, [0, W - w, 0, H - h, 0, 0, 0, 0], mode='constant', value=0)

def run_check_net():
    """Smoke-test ``Net`` on a random batch and verify the mask shape.

    Builds a ``(2, 1, 256, 260)`` random image, runs one forward pass without
    gradients, prints the input and output shapes, and asserts that the mask
    has shape ``(batch, 1, height, width)``.

    The pass runs on the GPU under CUDA autocast when CUDA is available;
    otherwise it runs on the CPU with autocast disabled, so no CUDA autocast
    context is requested on a host without CUDA.

    Raises:
        AssertionError: if the returned mask does not match the input's
            batch size and spatial size with a single channel.
    """
    # The width is not a multiple of 32, so Net's crop-and-pad path is exercised.
    height, width = 256, 260
    batch_size = 2

    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')

    image = torch.from_numpy(
        np.random.uniform(0, 1, (batch_size, 1, height, width))
    ).float().to(device)

    net = Net().to(device)

    with torch.no_grad():
        with torch.cuda.amp.autocast(enabled=use_cuda):
            mask = net(image)

    # Print before asserting so the offending shapes are visible on failure.
    print('image', image.shape)
    print('mask', mask.shape)

    expected = (batch_size, 1, height, width)
    assert tuple(mask.shape) == expected, (
        f'mask shape {tuple(mask.shape)} does not match expected {expected}'
    )


In [6]:
# Execute the smoke test; it prints the input image shape and the output mask shape.
run_check_net()

image torch.Size([2, 1, 256, 260])
mask torch.Size([2, 1, 256, 260])
