In [1]:
import torch
import torch.nn as nn
import sys
sys.path.append("..")
import torch
import torch.nn as nn
import torch.nn.functional as F
from model.util import MinPool
from model.RESUNet import ResBlock
from model.model import *
from model.util import cat_tensor, crop_tensor
from model.model import Unet
from model.FL_seris import Encode, Decode
from torchsummary import summary

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class Encode(nn.Module):
    def __init__(self, in_channel, out_channel, block_number = 1, conv_type = "conv"):
        super().__init__( )
        if conv_type == "conv":
            self.conv = RC(in_channel, out_channel, block_number = block_number)
        else:
            self.conv = RCS(in_channel, out_channel, block_number = block_number)
        self.downsample = nn.MaxPool2d(2)
        
    def forward(self, x):
        x_conv = self.conv(x)
        x_pool = self.downsample(x_conv)
        return x_conv, x_pool

In [3]:

class FPN(nn.Module):
    def __init__(
                self,
                in_channel = 1,
                out_channel = 1,
                # block_layers=[6, 12, 24, 16], 
                # transition_layer = [256, 512, 1024, 1024],
                middle_channel = [16, 32, 64, 128],
                encode_len = 4,
                need_return_dict = False
        ):
        super(FPN,self).__init__()
        self.need_return_dict = need_return_dict
        self.downsample = nn.AvgPool2d(2,2)
        self.erode = MinPool(2,2,1)
        self.dilate = nn.MaxPool2d(2, stride = 1)
        middle_channel = middle_channel[ len(middle_channel) - encode_len : ]
        index_len = encode_len - 1
    
        self.pre_encode = nn.Sequential(
            Encode(in_channel, middle_channel[0], 4)
        )
        self.out = nn.Sequential(
            nn.Conv2d( middle_channel[0], out_channel,1,1)
        )
        self.last_decode = Decode(middle_channel[1], middle_channel[0], conv_type = "conv") 
        
        self.encode = nn.ModuleList(
            [ Encode(
                2 * middle_channel[i], 
                middle_channel[ i+1 ], 2, 
                conv_type = "conv"
            )  for i in range(index_len) ]
        )
        self.decode = nn.ModuleList(
            [
                Decode(
                    2 * middle_channel[index_len - i], 
                    2 * middle_channel[index_len - i - 1], 
                    conv_type = "conv") 
                for i in range(index_len)
            ]
        )
        self.CBR = nn.ModuleList()
        for i in range(encode_len):
            self.CBR.append(
                nn.Sequential(
                    nn.Conv2d(in_channel, middle_channel[i],3, 1, 1),
                    nn.BatchNorm2d(middle_channel[i]),
                    nn.ReLU(),
                )
            )
        self.index_len = index_len     
        
    def build_feature_pyramid(self, x): # 80
        x_list = []
        x_list.append(x)
        for i in range(self.index_len + 1):
            x = self.downsample(x) 
            x_list.append( x )
        return   x_list

    def feature(self, x):
        x_encode_list = []
        for i in range(self.index_len + 1):
            print(i)
            x_encode_list.append( self.CBR[i]( x[ i + 1 ] ) )
            print(x_encode_list[-1].shape)
        xc_list = []
        xp_list = []
        
        xc_0, xp_0 = self.pre_encode(x[0])
        xc_list.append(xc_0)
        xp_list.append(xp_0)
        
        for i in range(self.index_len):
            x_cat  = torch.cat([xp_list[i], x_encode_list[i]], 1)
            ec, ep = self.encode[i](x_cat)
            # print(ec.shape, ep.shape, )
            xc_list.append(ec)
            xp_list.append(ep)
        
        x_c = torch.cat([xp_list[-1], x_encode_list[-1]], 1)
        # print("cat :", x_c.shape)
        xc_list = list(reversed( xc_list ))
        
        for i in range(self.index_len):
            x_c = self.decode[i](x_c, xc_list[i])
            # print("decode :", x_c.shape)
        
        x_c = self.last_decode(x_c, xc_0)
        # print("decode :", x_c.shape)
        out = self.out(x_c)
        edge = nn.functional.pad(out, (1, 0, 1, 0))
        edge = self.dilate(edge) - self.erode(edge)
        return  out, edge
    
    
    def forward(self, x):
        x= self.feature(self.build_feature_pyramid(x))
        return x


In [4]:
model = FPN(middle_channel = [8,16,32,64,128], encode_len = 4)

In [24]:
batch_image = torch.zeros(( 1, 1, 80, 80))

In [25]:
320/3

106.66666666666667

In [27]:
ans = model(batch_image)

0
torch.Size([1, 16, 40, 40])
1
torch.Size([1, 32, 20, 20])
2
torch.Size([1, 64, 10, 10])
3
torch.Size([1, 128, 5, 5])


In [185]:
summary(model.cuda(), (1,80,80))

0
torch.Size([2, 32, 40, 40])
1
torch.Size([2, 64, 20, 20])
2
torch.Size([2, 128, 10, 10])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
         AvgPool2d-1            [-1, 1, 40, 40]               0
         AvgPool2d-2            [-1, 1, 20, 20]               0
         AvgPool2d-3            [-1, 1, 10, 10]               0
            Conv2d-4           [-1, 32, 40, 40]             320
       BatchNorm2d-5           [-1, 32, 40, 40]              64
              ReLU-6           [-1, 32, 40, 40]               0
            Conv2d-7           [-1, 64, 20, 20]             640
       BatchNorm2d-8           [-1, 64, 20, 20]             128
              ReLU-9           [-1, 64, 20, 20]               0
           Conv2d-10          [-1, 128, 10, 10]           1,280
      BatchNorm2d-11          [-1, 128, 10, 10]             256
             ReLU-12          [-1, 128, 10, 10]               0
           C

In [147]:
ans[1].shape

torch.Size([1, 1, 160, 160])

In [108]:
a = [0.16820919930934905, 0.11555813759565353, 0.10406000860035419, 0.09637853257358074, 0.09082919470965863, 0.08597795620560646, 0.08181536719202995, 0.07813958384096623, 0.07437959440052509, 0.07080989994108677, 0.06790822520852088, 0.0651423578336835, 0.06234778765588999, 0.05975172657519579, 0.05761727601289749, 0.0554961309209466, 0.05328702192753553, 0.05152266371995211, 0.050051631219685075, 0.048332386501133445, 0.04696276798844337, 0.04537936400622129, 0.044138674549758436, 0.04291338182985783, 0.04156916078180075, 0.04070404600352049, 0.03959252323955297, 0.038757648505270484, 0.037821538373827936, 0.03702100329101086, 0.036216543540358546, 0.035624408610165116, 0.03498207278549671, 0.03414684461429715, 0.033567180875688794, 0.03290734075009823, 0.03242268029600382, 0.03188748175278306, 0.031559570580720904, 0.03112207241356373, 0.030761862453073264, 0.030335290562361478, 0.029909557178616524, 0.029660565312951803, 0.029230893459171056, 0.028874552380293607, 0.028548107556998728, 0.028311173133552074, 0.028025481533259154, 0.027650082129985095, 0.027576511520892383, 0.02735332813113928, 0.027140375636518003, 0.026942391190677883, 0.026719382479786873, 0.02650448229163885, 0.026182646062225104, 0.026366023905575277, 0.025947123821824788, 0.0257209007255733, 0.025685580223798753, 0.025583725329488514, 0.025302596911787986, 0.025257541853934525, 0.025111505798995494, 0.024911839812994004, 0.02482968198135495, 0.024761652015149592, 0.02465659558773041, 0.024481285382062196, 0.024310764838010073, 0.024279800951480867, 0.024063430037349464, 0.023906650431454183, 0.024009056072682142, 0.02399568434804678, 0.023833607267588378, 0.023892152924090623, 0.02363531494513154]

In [110]:
a[48]

0.028025481533259154