In [2]:
import torch
import torch.nn as nn
import sys
sys.path.append("..")
import torch
import torch.nn as nn
import torch.nn.functional as F
from model.util import MinPool
from model.RESUNet import ResBlock
from model.model import *
from model.util import cat_tensor, crop_tensor
from model.model import Unet
from model.FL_seris import Encode, Decode
from torchsummary import summary
from model.util import MinPool, cat_tensor, crop_tensor

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from model.FL_base import FL_base

In [5]:
class Unet(nn.Module):
    def __init__(
                self,
                in_channel = 1,
                out_channel = 1,
                # block_layers=[6, 12, 24, 16], 
                # transition_layer = [256, 512, 1024, 1024],
                middle_channel = [16, 32, 64, 128],
              
        ):
        super(Unet,self).__init__()
   
        self.pre_encode = nn.Sequential(
            Encode(in_channel, middle_channel[0], 4)
        )
        self.out = nn.Sequential(
            nn.Conv2d( middle_channel[0], out_channel,1,1)
        )
        self.brige = nn.Sequential(
            nn.Conv2d(middle_channel[-1], 2 * middle_channel[-1], 1,1),
            nn.BatchNorm2d(2 * middle_channel[-1]),
            nn.ReLU()
        )
        self.brige1 = nn.Sequential(
            nn.Conv2d(middle_channel[-1], middle_channel[-1], 1,1),
            nn.BatchNorm2d(middle_channel[-1]),
            nn.ReLU()
        )
        self.last_decode = Decode(middle_channel[1], middle_channel[0], conv_type = "conv") 
        self.encode = nn.ModuleList(
            [ Encode(
                middle_channel[i], 
                middle_channel[ i+1 ], 2, 
                conv_type = "conv"
            )  for i in range(3) ]
        )
        self.decode = nn.ModuleList(
            [
                Decode(
                    2 * middle_channel[3 - i], 
                    2 * middle_channel[3 - i - 1], 
                    conv_type = "conv") 
                for i in range(3)
            ]
        )
        
    def feature(self, x):
        xc_0, xp_0 = self.pre_encode(x)
        ec_0, ep_0 = self.encode[0](xp_0)
        ec_1, ep_1 = self.encode[1](ep_0)
        ec_2, ep_2 = self.encode[2](ep_1)
        
        x_m = self.brige(ep_2)
        x_n = self.brige1(ec_2)

        d_0 = self.decode[0](x_m, x_n)
        d_1 = self.decode[1](d_0, ec_1)
        d_2 = self.decode[2](d_1, ec_0)
        d_3 = self.last_decode(d_2, xc_0)
        out = self.out(d_3)
        return  out, 0
    
    
    def forward(self, x):
        x= self.feature(x)
        return x

In [60]:
class FL_DETR(FL_base):
    def __init__(
                self,
                in_channel = 1,
                out_channel = 1,
                encode_len = 4, 
                need_return_dict = False
        ):
        super(FL_FPN, self).__init__()
        self.index_len = encode_len - 1
        self.need_return_dict = need_return_dict
        self.downsample = nn.AvgPool2d(2)
        self.upsample = nn.Upsample(scale_factor = 2)
        self.model = Unet()
        self.edge_body = nn.Sequential(
            ResBlock(1,4),
            nn.Conv2d( 4, out_channel, 1, 1),
            nn.ReLU(),
        )
        
    def build_feature_pyramid(self, x): # 80
        x_list = []
        x_list.append(x)
        for i in range(self.index_len):
            x = self.downsample(x) 
            x_list.append( x )
        #         for i in range(self.index_len):
        #             print(x_list[i].shape)
        return   x_list
    
    def get_embeding(self, x_list):
        x_embed1 = crop_tensorensor(x_list[0], 4, 8)
        x_embed2 = crop_tensor(x_list[1], 2, 4)
        x_embed3 = crop_tensor(x_list[2], 1, 2)
        return x_embed1, x_embed2, x_embed3
    
    def get_embeding_detail(self, x, w, h):
        x_re1 = crop_tensor(x, w, h)   
        return x_re1
    
    def re_build_detail(self, x, w, h):
        x_re1 = cat_tensor(x, w, h)   
        return x_re1
    
    def re_build(self, x_list):
        # x = x.permute(0, 2, 1, 3, 4)
        x_re1 = self.re_build_detail(x_list[0], 4, 8)
        x_re2 = self.re_build_detail(x_list[1], 2, 4)
        x_re3 = self.re_build_detail(x_list[2], 1, 2)
        return x_re1, x_re2, x_re3
    
    def ext_feature(self, x):
        # x_feature_list = []
        
        feature = self.ext_feature_batch(x[2], 1, 2)
        attn_map = self.re_build_detail(feature, 1, 2)
        
        feature = self.ext_feature_batch(x[1], 2, 4, attention_map = attn_map)
        attn_map = self.re_build_detail(feature, 2, 4)
        
        feature = self.ext_feature_batch(x[0], 4, 8, attention_map = attn_map)
        hot_map = self.re_build_detail(feature, 4, 8)
        
        return hot_map

    def ext_feature_batch(self, x, w, h, attention_map = None):
        
        if attention_map is not None:
            x = x * self.upsample(attention_map)
        
        print(x.shape, w, h)
        x_embed = self.get_embeding_detail(x, w, h)
        BB, B, C, W, H =  x_embed.shape
        
        batch_item_combined_hm_preds = []
        for batch_index in range(BB): 

            batch_item_x_embed = x_embed[batch_index,:,:,:,:]
            #### your forward model here
            output, _ = self.model( batch_item_x_embed )
            #### 
            batch_item_combined_hm_preds.append(output)

        x_combine = torch.stack(batch_item_combined_hm_preds, 0)
        return x_combine
    

    def build_results(self, x,y,z):
        return {
            "mask": x,
            "cmask": y,
            "edge":z,
        }
    
    def edge_hot_map(self, x):
        edge = nn.functional.pad(x, (1, 0, 1, 0))
        edge = self.dilate(edge) - self.erode(edge)
        edge = self.edge_body(edge)
        return edge
    
    def forward(self, x):  
        x_list = self.build_feature_pyramid(x)
        out = self.ext_feature(x_list)
        edge = self.edge_hot_map(out)
        return self.build_results(out, out, edge) if self.need_return_dict else (out, out, edge)


In [61]:
model = FL_FPN()

In [62]:
image = torch.zeros((2,1,320,640))

In [63]:
model(image)

torch.Size([2, 1, 80, 160]) 1 2
torch.Size([2, 1, 160, 320]) 2 4
torch.Size([2, 1, 320, 640]) 4 8


(tensor([[[[ 0.4556,  0.2871,  0.6322,  ...,  0.5112,  0.6055,  0.5030],
           [ 0.1032,  0.1723,  0.9200,  ...,  0.2042,  0.5902,  0.5217],
           [ 0.3199,  0.3158,  0.2130,  ...,  0.9888, -0.3649,  0.4125],
           ...,
           [ 0.3667,  0.1408,  0.4353,  ...,  0.4583,  0.6676,  0.4669],
           [ 0.4667,  0.3416,  0.5208,  ...,  0.4969,  0.2872,  0.3834],
           [ 0.6321,  0.7829,  0.4685,  ...,  0.5026,  0.8143,  0.0414]]],
 
 
         [[[ 0.4556,  0.2871,  0.6322,  ...,  0.5112,  0.6055,  0.5030],
           [ 0.1032,  0.1723,  0.9200,  ...,  0.2042,  0.5902,  0.5217],
           [ 0.3199,  0.3158,  0.2130,  ...,  0.9888, -0.3649,  0.4125],
           ...,
           [ 0.3667,  0.1408,  0.4353,  ...,  0.4583,  0.6676,  0.4669],
           [ 0.4667,  0.3416,  0.5208,  ...,  0.4969,  0.2872,  0.3834],
           [ 0.6321,  0.7829,  0.4685,  ...,  0.5026,  0.8143,  0.0414]]]],
        grad_fn=<CatBackward0>),
 tensor([[[[ 0.4556,  0.2871,  0.6322,  ...,  0.51

In [25]:
l = model.build_feature_pyramid(image)

torch.Size([2, 1, 320, 640])
torch.Size([2, 1, 160, 320])
torch.Size([2, 1, 80, 160])
torch.Size([2, 1, 40, 80])


In [35]:
cr1 = crop_tensor(l[0], 4, 8)
cr2 = crop_tensor(l[1], 2, 4)
cr3 = crop_tensor(l[2], 1, 2)
cr1.shape, cr2.shape, cr3.shape

(torch.Size([2, 32, 1, 80, 80]),
 torch.Size([2, 8, 1, 80, 80]),
 torch.Size([2, 2, 1, 80, 80]))

In [None]:
cr = crop_tensor(l[0], 2, 4)
cr.shape

In [21]:
model(image)

torch.Size([2, 1, 320, 640])
torch.Size([2, 1, 160, 320])
torch.Size([2, 1, 80, 160])
torch.Size([2, 1, 40, 80])


[tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           ...,
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.]]],
 
 
         [[[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           ...,
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.]]]]),
 tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           ...,
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.]]],
 
 
         [[[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           ...,
           [0., 0