In [2]:
import sys
sys.path.append("..")
import torch
import torch.nn as nn
import torch.nn.functional as F
from model.util import MinPool
from model.RESUNet import ResBlock
from model.model import *
from model.util import cat_tensor, crop_tensor

  from .autonotebook import tqdm as notebook_tqdm


In [10]:


class DecodeBlock(nn.Module):
    def __init__(self, middle_channel = [8, 16, 32, 64, 128]):
        super().__init__( )
        self.pre = nn.Conv2d(in_channel, middle_channel[0], 1, 1)
        
        self.encode_1 = Encode(middle_channel[0], middle_channel[1], block_number)
        self.encode_2 = Encode(middle_channel[1], middle_channel[2], block_number)
        self.encode_3 = Encode(middle_channel[2], middle_channel[3], block_number)
        self.encode_4 = Encode(middle_channel[3], middle_channel[4], block_number)
        
    def forward(self, x):
        x = self.pre(x)
        x1 = self.encode_1(x)
        x2 = self.encode_2(x1)
        x3 = self.encode_3(x2)
        x4 = self.encode_4(x3)
        return x1, x2, x3, x4

class Encode(nn.Module):
    def __init__(self, in_channel, out_channel, block_number = 1):
        super().__init__( )
        self.conv = RCS(in_channel, out_channel, block_number = block_number)
        self.downsample = nn.MaxPool2d(2)
        
    def forward(self, x):
        x1 = self.conv(x)
        x2 = self.downsample(x1)
        return x1, x2


class EncodeBlock(nn.Module):
    def __init__(self, in_channel, block_number = [ 2, 2, 2, 2], middle_channel = [8, 16, 32, 64, 128]):
        super().__init__( )
        self.pre = nn.Conv2d(in_channel, middle_channel[0], 1, 1)
        
        self.encode_1 = Encode(middle_channel[0], middle_channel[1], block_number[0])
        self.encode_2 = Encode(middle_channel[1], middle_channel[2], block_number[1])
        self.encode_3 = Encode(middle_channel[2], middle_channel[3], block_number[2])
        self.encode_4 = Encode(middle_channel[3], middle_channel[4], block_number[3])
        
    def forward(self, x):
        x = self.pre(x)
        x_1_1, x_1_2 = self.encode_1(x)
        x2 = self.encode_2(x1)
        x3 = self.encode_3(x2)
        x4 = self.encode_4(x3)
        return x1, x2, x3, x4

    
class Decode(nn.Module):
    def __init__(self, in_channel, out_channel):
        super().__init__( )
        self.deconv = DCBL( in_channel, out_channel)
        self.conv = RCS(in_channel, out_channel)
        
    def forward(self, x, y):
        x = self.deconv(x)
        # print(x.shape, y.shape)
        concat = torch.cat([x, y], dim=1)
        x = self.conv(concat)
        return x


class UBlock(nn.Module):
    def __init__(self, in_channel = 1, out_channel = 16, middle_channel = [ 8, 16, 32, 64, 128 ]):
        super().__init__()
        self.encode = EncodeBlock(in_channel, block_number = [ 2, 2, 2, 2], middle_channel = middle_channel )
        self.brige = nn.Conv2d(middle_channel[-1], middle_channel[4], 1, 1)
        self.decode_1 = Decode(middle_channel[2], middle_channel[1])
        self.decode_2 = Decode(middle_channel[3], middle_channel[2])
        self.decode_3 = Decode(middle_channel[4], middle_channel[3])
        self.final = nn.Conv2d( middle_channel[1], out_channel, 1, 1)
        self.up = nn.ConvTranspose2d( middle_channel[1], middle_channel[1], 2, 2)
        self.sigmoid = nn.Sigmoid()
    def forward(self, x):
        x1, x2, x3, x4 = self.encode(x)
        
        x = self.decode_3(x4, x3)
        x = self.decode_2(x, x2)
        x = self.decode_1(x, x1)
        x = self.up(x)
        outp = self.sigmoid(self.final(x))
        return  outp

In [11]:
model = UBlock(1,1)
model(torch.zeros((1,1,32,32))).shape

torch.Size([1, 1, 32, 32])

In [9]:
from model.FPN import FL_base

class FL_tiny(FL_base):
    def __init__(
                self,
                in_channel = 1,
                out_channel = 1,
                middle_channel = 1,
                embed_shape = ( 2, 4),
                batch_size = 16,
                need_return_dict = False
        ):
        super(FL_tiny, self).__init__()
     
        self.batch_size = batch_size
        self.embed_shape = embed_shape
        self.need_return_dict = need_return_dict
        self.middle_channel = middle_channel

        # replace your model
        ####################################
        self.model = nn.Sequential(
            UBlock(middle_channel, middle_channel)
        )
        ####################################
        self.edge_body = nn.Sequential(
            ResBlock(8,4),
            nn.Conv2d( 4, out_channel, 1, 1),
            nn.ReLU(),
        )
        ####################################
        self.consit_body = nn.Sequential(
            nn.Conv2d( middle_channel, 32, 4, 2, 1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Upsample(scale_factor = 2),

            nn.Conv2d( 32, 64, 3, 1, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),

            nn.Conv2d( 64, 8, 2, 2),
            nn.BatchNorm2d(8),
            nn.ReLU(),
            nn.Upsample(scale_factor = 2),
        )
        ####################################
        
        self.final = nn.Conv2d(8, out_channel, 1,1 )
        self.edge_final = nn.Conv2d(8, out_channel, 1,1 )
        self.relu = nn.ReLU()
        self.sigmod = nn.Sigmoid()
        
        
    def ext_feature(self, x):
        B,C,W,H =  x.shape
        x_embed = self.get_embeding(x) 
        batch_item_combined_hm_preds = []
        for batch_index in range(B): 

            batch_item_x_embed = x_embed[batch_index,:,:,:,:]
            # print(batch_item_x_embed.shape)
            
            #### your forward model here
            output = self.model( batch_item_x_embed ) # only for mask not edge, edge will have another way
            #### 
                
            batch_item_combined_hm_preds.append(output)
            
        x_combine = torch.stack(batch_item_combined_hm_preds, 0)
        outp = self.re_build(x_combine)

        return outp
    
    def consist(self, x):
        x = x.clone().detach()
        x = self.consit_body(x)
        return x
    
    def edge_hot_map(self, x):
        edge = nn.functional.pad(x, (1, 0, 1, 0))
        edge = self.dilate(edge) - self.erode(edge)
        edge = self.edge_body(edge)
        return edge
    
    def forward(self, x):   
        x = self.downsample(x)
        # print(x.shape)
        x = self.ext_feature(x)
        # print(x.shape)
        outp = self.consist(x)
        edge = self.edge_hot_map(outp)
        
        outp = self.final(outp)
        outp = self.sigmod(outp * x)
        # print(outp.shape, edge.shape)
        outp = self.upsample(outp)
        return self.build_results(outp, edge) if (self.need_return_dict) else (outp, edge)

In [10]:
model = FL_tiny(batch_size = 2)

NameError: name 'EncodeBlock' is not defined

In [6]:
model(torch.zeros((2, 1, 320, 640)))

(tensor([[[[0.3197, 0.3175, 0.4390,  ..., 0.4244, 0.3159, 0.2808],
           [0.3141, 0.2890, 0.4597,  ..., 0.3846, 0.2963, 0.2747],
           [0.4872, 0.4927, 0.2044,  ..., 0.4536, 0.3769, 0.3363],
           ...,
           [0.4940, 0.4975, 0.4432,  ..., 0.5497, 0.3067, 0.2640],
           [0.3905, 0.2954, 0.5661,  ..., 0.3658, 0.1840, 0.3157],
           [0.3470, 0.3447, 0.5770,  ..., 0.3713, 0.3525, 0.2240]]],
 
 
         [[[0.3197, 0.3175, 0.4390,  ..., 0.4244, 0.3159, 0.2808],
           [0.3141, 0.2890, 0.4597,  ..., 0.3846, 0.2963, 0.2747],
           [0.4872, 0.4927, 0.2044,  ..., 0.4536, 0.3769, 0.3363],
           ...,
           [0.4940, 0.4975, 0.4432,  ..., 0.5497, 0.3067, 0.2640],
           [0.3905, 0.2954, 0.5661,  ..., 0.3658, 0.1840, 0.3157],
           [0.3470, 0.3447, 0.5770,  ..., 0.3713, 0.3525, 0.2240]]]],
        grad_fn=<SigmoidBackward0>),
 tensor([[[[1.1205, 1.3335, 1.6467,  ..., 0.3186, 0.3237, 0.2227],
           [1.7295, 0.3899, 0.2119,  ..., 0.1929, 0