In [1]:
import sys 
sys.path.append("..") 
from train import *
from utils.score import cal_all_score
from dataloader import Dataload
from torch.utils.data import DataLoader
import cv2
import matplotlib.pyplot as plt
matplotlib.use('TkAgg')
%matplotlib inline
import datetime

  from .autonotebook import tqdm as notebook_tqdm


[0]
detect set : [0]
use gpu: True


In [2]:
from model.GTU.models.GT_UNet import *

In [25]:
class Decode(nn.Module):
    """Decoder stage: upsample the deeper map, concatenate it with a skip tensor, refine with convs.

    Args:
        ch_in: channel count of the concatenated tensor fed to the first convolution.
        ch_out: channel count produced by every convolution in the stage.
    """

    def __init__(self, ch_in, ch_out):
        super(Decode, self).__init__()
        # (input channels, kernel size, padding) of the three Conv -> BatchNorm -> ReLU groups:
        # one 3x3 convolution on the fused tensor followed by two 1x1 convolutions.
        conv_specs = (
            (ch_in, 3, 1),
            (ch_out, 1, 0),
            (ch_out, 1, 0),
        )
        layers = []
        for in_channels, kernel, pad in conv_specs:
            layers.append(nn.Conv2d(in_channels, ch_out, kernel_size=kernel, stride=1, padding=pad))
            layers.append(nn.BatchNorm2d(ch_out))
            layers.append(nn.ReLU())
        self.conv = nn.Sequential(*layers)
        # Doubles the spatial size with nn.Upsample's default interpolation mode.
        self.up = nn.Upsample(scale_factor=2)

    def forward(self, x, y):
        """Fuse skip tensor ``x`` with the 2x-upsampled tensor ``y`` along dim 1 and run the conv stack."""
        upsampled = self.up(y)
        fused = torch.cat([x, upsampled], 1)
        return self.conv(fused)


In [26]:
class BotBlock(nn.Module):
    """Residual bottleneck block whose middle layer is multi-head self-attention (MHSA).

    The input is reshaped into chunks of spatial size ``curr_w x curr_w``, passed through
    conv1 -> MHSA -> conv2 -> conv3, reshaped back to the input's spatial size and added to
    a (possibly projected) shortcut.

    Args:
        in_dimension: channels of the incoming tensor.
        curr_h, curr_w: spatial size handed to MHSA; ``curr_w`` is also the chunk size used in forward.
        proj_factor: ``target_dimension // proj_factor`` is the bottleneck width.
        activation: activation name forwarded to ``BNReLU`` / ``get_act``.
        pos_enc_type: positional-encoding type forwarded to MHSA.
        stride: 1, or 2 to add average pooling and a strided shortcut.
        target_dimension: channels of the output tensor.

    NOTE(review): the chunking in ``forward`` is a plain ``reshape`` of the (N, C, H, W) tensor with
    no permute, so each chunk is a run of consecutive elements rather than a spatial window, and
    the inverse reshape likewise does not permute. Confirm this is intended.
    NOTE(review): every caller visible in this notebook uses stride = 1; the stride = 2 path is not
    exercised here.
    """

    def __init__(self, in_dimension, curr_h, curr_w, proj_factor = 4, activation = 'relu', pos_enc_type = 'relative',
                 stride = 1, target_dimension = None):
        super(BotBlock, self).__init__()
        self.w = curr_w

        # A projection shortcut is only needed when the residual's shape would not match.
        identity_shortcut = (stride == 1 and in_dimension == target_dimension)
        if identity_shortcut:
            self.shortcut = None
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_dimension, target_dimension, kernel_size=3, padding=1, stride=stride),
                BNReLU(target_dimension, activation=activation, nonlinearity=True),
            )

        bottleneck_dimension = target_dimension // proj_factor

        self.conv1 = nn.Sequential(
            nn.Conv2d(in_dimension, bottleneck_dimension, kernel_size=3, padding=1, stride=1),
            BNReLU(bottleneck_dimension, activation=activation, nonlinearity=True),
        )

        self.mhsa = MHSA(
            in_channels=bottleneck_dimension,
            heads=4,
            curr_h=curr_h,
            curr_w=curr_w,
            pos_enc_type=pos_enc_type,
        )

        post_attention = []
        if stride != 1:
            assert stride == 2, stride
            post_attention.append(nn.AvgPool2d(kernel_size=(2, 2), stride=(2, 2)))  # TODO: 'same' in tf.pooling
        post_attention.append(BNReLU(bottleneck_dimension, activation=activation, nonlinearity=True))
        self.conv2 = nn.Sequential(*post_attention)

        self.conv3 = nn.Sequential(
            nn.Conv2d(bottleneck_dimension, target_dimension, kernel_size=3, padding=1, stride=1),
            BNReLU(target_dimension, nonlinearity=False, init_zero=True),
        )
        self.last_act = get_act(activation)

    def forward(self, x):
        """Run the block on an (N, C, H, W) tensor; prints intermediate shapes as a debug trace."""
        residual = x if self.shortcut is None else self.shortcut(x)

        patch = self.w
        batch, channels, height, width = x.shape
        rows, cols = height // patch, width // patch
        print("x:", x.shape)
        x = x.reshape(batch * rows * cols, channels, patch, patch)
        print("x:", x.shape, "P_h:",rows, "P_w:", cols)

        out = self.conv1(x)
        print("out mhsa in:",out.shape)
        out = self.mhsa(out)
        print("out mhsa out:",out.shape)
        # MHSA returns channels-last; move channels back to dim 1 for the convolutions.
        out = out.permute(0, 3, 1, 2)

        out = self.conv3(self.conv2(out))

        # Fold the chunk dimension back into the original batch / spatial layout.
        out = out.reshape(batch, out.shape[1], int(height), int(width))

        out += residual
        out = self.last_act(out)
        print("out:", out.shape)
        return out

In [27]:
class MHSA(nn.Module):
    """Multi-head self-attention: three grouped point-wise projections feeding a relative-position attention.

    Args:
        in_channels: channels of the input; each head gets ``in_channels // heads``.
        heads: number of attention heads.
        curr_h, curr_w: spatial size forwarded to ``RelPosSelfAttention``.
        pos_enc_type: ``'relative'`` is the only implemented option; ``'absolute'`` raises
            ``NotImplementedError`` and anything else fails the assertion.
        use_pos: accepted but not read anywhere in this class.
    """

    def __init__(self, in_channels, heads, curr_h, curr_w, pos_enc_type = 'relative', use_pos = True):
        super(MHSA, self).__init__()
        # Query, key and value projections are built identically, in this order.
        for proj_name in ("q_proj", "k_proj", "v_proj"):
            setattr(self, proj_name, GroupPointWise(in_channels, heads, proj_factor=1))

        assert pos_enc_type in ('relative', 'absolute')
        if pos_enc_type != 'relative':
            raise NotImplementedError
        self.self_attention = RelPosSelfAttention(curr_h, curr_w, in_channels // heads, fold_heads=True)

    def forward(self, input):
        """Project ``input`` to q, k and v and return the attention output."""
        # print("qkv:", q.shape, k.shape, v.shape, )
        return self.self_attention(
            q=self.q_proj(input),
            k=self.k_proj(input),
            v=self.v_proj(input),
        )

In [28]:
class GT_U_DCNet(nn.Module):
    """U-shaped network: BotBlock encoder fed by an average-pooled input pyramid, a decoder, and an edge branch.

    Args:
        img_ch: channels of the input image.
        output_ch: channels of the predicted mask.
        middle_channel: candidate stage widths; only the last ``encode_len`` entries are used.
        encode_len: number of encoder stages, counting ``pre_encode``.
        need_return_dict: return the dict built by ``build_results`` instead of a tuple.
        need_supervision: additionally return one 1x1-conv projection per decoder stage.
        decode_type: ``"conv"`` builds ``Decode`` stages; any other value builds BotBlock stages.

    Returns (from ``forward``):
        dict with keys ``mask``, ``cmask``, ``edge`` (plus ``super`` with supervision) when
        ``need_return_dict`` is true; otherwise ``(out, edge)`` or ``(out, edge, supervision)``.
    """

    def __init__(self, img_ch  =  1, output_ch  =  1, 
                middle_channel = [64, 128, 256, 512, 1024], 
                encode_len = 5, 
                need_return_dict = True,
                need_supervision = False,
                decode_type = "conv",

        ):
        super(GT_U_DCNet,self).__init__()

        self.index_len = encode_len - 1
        self.need_return_dict = need_return_dict
        self.need_supervision = need_supervision
        # Remembered so forward() can call each decoder kind with the right arguments.
        self.decode_type = decode_type

        self.downsample = nn.AvgPool2d(2)
        self.Maxpool = nn.MaxPool2d(kernel_size = 2, stride = 2)
        self.erode = MinPool(2, 2, 1)
        self.dilate = nn.MaxPool2d(2, stride = 1)
        self.encode_list = nn.ModuleList()
        self.up_list = nn.ModuleList()
        self.decode_list = nn.ModuleList()
        # Always created: it is appended to unconditionally at the end of __init__. It used to be
        # created only when need_return_dict was true, so need_return_dict=False raised AttributeError.
        self.supervision_list = nn.ModuleList()
        # Fuses the raw mask and its edge map (2 channels) into the refined mask.
        self.select = nn.Sequential(
                nn.Conv2d(2, 32, 3, 1, 1),
                nn.BatchNorm2d(32),
                nn.ReLU(),
                nn.Conv2d(32, 16, 3, 1, 1),
                nn.BatchNorm2d(16),
                nn.ReLU(),
                nn.Conv2d(16, output_ch, 1, 1),
                nn.BatchNorm2d(output_ch),
                nn.ReLU(),
            )

        # Keep the last `encode_len` widths and prepend the image channels, so that
        # middle_channel[k] is the width after stage k (index 0 = input image).
        middle_channel = middle_channel[ len(middle_channel) - encode_len : ]
        middle_channel = [ img_ch, *middle_channel ]
        print( middle_channel )

        self.pre_encode = _make_bot_layer(
            ch_in = middle_channel[0],
            ch_out = middle_channel[1],
            w = 4
        )
        # Each later encoder stage consumes the pooled previous output concatenated with a
        # pyramid feature of the same width, hence the factor 2.
        for i in range(1, encode_len):
            self.encode_list.append(
                _make_bot_layer(
                    ch_in = 2 * middle_channel[i],
                    ch_out = middle_channel[i + 1],
                    w = 4
                )
            )

        for i in range(1, encode_len):
            now_dim = encode_len - i + 1
            next_dim = encode_len - i
            print( middle_channel[ now_dim ],   middle_channel[ next_dim ])
            # up_list is registered (and so part of the state_dict) but not used by forward().
            self.up_list.append( up_conv(ch_in = middle_channel[now_dim], ch_out = 2 * middle_channel[next_dim]) )
            if decode_type == "conv":
                self.decode_list.append(
                    Decode(ch_in = 2 * middle_channel[now_dim], ch_out = middle_channel[next_dim])
                )
            else:
                self.decode_list.append(
                    _make_bot_layer(
                        ch_in = 2 * middle_channel[now_dim],
                        ch_out = middle_channel[next_dim],
                        w = 2 ** i
                    )
                )

            if self.need_supervision:
                self.supervision_list.append( nn.Conv2d(middle_channel[next_dim], output_ch, 1, 1) )

        self.CBR = nn.ModuleList()
        for i in range(encode_len):
            self.CBR.append(
                nn.Sequential(
                    nn.Conv2d(middle_channel[0], middle_channel[i + 1], 3, 1, 1),
                    nn.BatchNorm2d(middle_channel[i + 1]),
                    nn.ReLU(),
                )
            )
        self.last_up = nn.Upsample(scale_factor = 2)
        # The last decoder stage always outputs middle_channel[1] channels. Indexing with 1 instead
        # of the leaked loop variable `next_dim` also keeps encode_len == 1 from raising NameError.
        self.last_decode = nn.Conv2d(middle_channel[1], output_ch, kernel_size = 1, stride = 1, padding = 0)
        # Extra head kept so the state_dict layout is unchanged; forward() does not call it.
        self.supervision_list.append( nn.Conv2d(middle_channel[1], output_ch, 1, 1) )


    def build_results(self, x, y, z, super_vision = None):
        """Pack outputs into a dict; the ``super`` key is present only when ``super_vision`` is given."""
        results = {
            "mask": x,
            "cmask": y,
            "edge": z,
        }
        if super_vision is not None:
            results["super"] = super_vision
        return results


    def build_feature_pyramid(self, x): # 80
        """Return ``index_len + 1`` successively 2x average-pooled copies of ``x`` (the first is already pooled once)."""
        x_list = []
        for i in range(self.index_len + 1):
            x = self.downsample(x)
            x_list.append( x )
        return x_list

    def edge_hot_map(self, x):
        """Return ``dilate(x) - erode(x)`` computed on a detached copy, so no gradient flows back into ``x``."""
        x = x.clone().detach()
        # Pad one pixel on the left and on the top before the stride-1 pooling.
        edge = nn.functional.pad(x, (1, 0, 1, 0))
        edge = self.dilate(edge) - self.erode(edge)
        return edge


    def forward(self, x):
        """Run encoder, decoder, edge extraction and mask refinement (shape prints are a debug trace)."""
        x_list = self.build_feature_pyramid( x )
        out_list = []
        supervision = []

        # One Conv-BN-ReLU projection per pyramid level.
        pre_x_list = [self.CBR[index](x_list[index]) for index in range(len(x_list))]

        out = self.pre_encode( x )
        out_pool = self.Maxpool(out)
        out_list.append(out)

        # encoding
        for index in range(self.index_len):
            x_temp = torch.cat( [pre_x_list[index], out_pool], 1)
            out = self.encode_list[index]( x_temp )
            out_pool = self.Maxpool(out)
            out_list.append(out)
            print( "encode {}:{} {}".format( index + 1, out.shape, out_pool.shape ))

        x_temp = out_pool
        # decoding
        for index in range(self.index_len):
            up = out_list[ self.index_len - index ]
            if self.decode_type == "conv":
                # Decode upsamples and concatenates internally.
                x_temp = self.decode_list[index](up, x_temp)
            else:
                # BotBlock decoders are nn.Sequential and take one tensor, so do the same
                # upsample + concat that Decode performs before calling them (calling them
                # with two arguments raised TypeError).
                x_temp = self.decode_list[index](torch.cat([up, self.last_up(x_temp)], 1))
            supervision.append(x_temp)
            print( "decode{}:{} {}".format( index, up.shape, x_temp.shape))

        outp = self.last_up(x_temp)
        out = self.last_decode(outp)

        edge = self.edge_hot_map(out)
        outp = self.select(torch.cat([out, edge], 1))

        if self.need_supervision:
            for i in range( self.index_len ):
                supervision[i] = self.supervision_list[i](supervision[i])
            # NOTE(review): this dict puts the refined `outp` under "mask", whereas the
            # non-supervised dict below puts the raw `out` there; confirm which is intended.
            return self.build_results(outp, outp, edge, supervision) if self.need_return_dict else( out, edge, supervision)
        return self.build_results(out, outp, edge) if self.need_return_dict else( out, edge )

In [29]:
def _make_bot_layer(ch_in, ch_out, w = 4):
    """Return an ``nn.Sequential`` holding one stride-1 ``BotBlock``.

    Args:
        ch_in: input channels (``in_dimension`` of the block).
        ch_out: output channels (``target_dimension`` of the block).
        w: passed as both ``curr_h`` and ``curr_w`` of the block.
    """
    block = BotBlock(
        in_dimension=ch_in,
        curr_h=w,
        curr_w=w,
        stride=1,
        target_dimension=ch_out,
    )
    return nn.Sequential(block)

In [30]:
# Build the 6-stage variant. Six widths are supplied because the constructor keeps only the last
# `encode_len` entries of `middle_channel` and prepends the image channel count.
model = GT_U_DCNet(1,1, middle_channel = [64,128, 256, 512, 768, 1024], encode_len = 6)# .to("cpu")

[1, 64, 128, 256, 512, 768, 1024]
1024 768
768 512
512 256
256 128
128 64


In [31]:
# Move the model to the first CUDA device; this cell requires a GPU.
model = model.to("cuda:0")

In [32]:
# All-zero dummy batch of shape (1, 1, 640, 640) on the first CUDA device, used to trace tensor shapes.
batch_image = torch.zeros((1,1, 640, 640)).to("cuda:0")

In [33]:
# Forward pass with autograd disabled; the shape lines printed below come from print calls
# inside the model's forward methods.
with torch.no_grad():
    ans = model(batch_image)

x: torch.Size([1, 1, 640, 640])
x: torch.Size([25600, 1, 4, 4]) P_h: 160 P_w: 160
out mhsa in: torch.Size([25600, 16, 4, 4])
out mhsa out: torch.Size([25600, 4, 4, 16])
out: torch.Size([1, 64, 640, 640])
x: torch.Size([1, 128, 320, 320])
x: torch.Size([6400, 128, 4, 4]) P_h: 80 P_w: 80
out mhsa in: torch.Size([6400, 32, 4, 4])
out mhsa out: torch.Size([6400, 4, 4, 32])
out: torch.Size([1, 128, 320, 320])
encode 1:torch.Size([1, 128, 320, 320]) torch.Size([1, 128, 160, 160])
x: torch.Size([1, 256, 160, 160])
x: torch.Size([1600, 256, 4, 4]) P_h: 40 P_w: 40
out mhsa in: torch.Size([1600, 64, 4, 4])
out mhsa out: torch.Size([1600, 4, 4, 64])
out: torch.Size([1, 256, 160, 160])
encode 2:torch.Size([1, 256, 160, 160]) torch.Size([1, 256, 80, 80])
x: torch.Size([1, 512, 80, 80])
x: torch.Size([400, 512, 4, 4]) P_h: 20 P_w: 20
out mhsa in: torch.Size([400, 128, 4, 4])
out mhsa out: torch.Size([400, 4, 4, 128])
out: torch.Size([1, 512, 80, 80])
encode 3:torch.Size([1, 512, 80, 80]) torch.Size(

RuntimeError: Given groups=1, weight of size [32, 2, 3, 3], expected input[1, 65, 640, 640] to have 2 channels, but got 65 channels instead

In [10]:
# NOTE(review): execution count [10] is lower than the cells that define the model above, so the
# recorded output of this cell presumably comes from an earlier revision; re-run on a fresh kernel.
ans = model(batch_image)

x: torch.Size([1, 1, 320, 320])
x: torch.Size([400, 1, 16, 16]) P_h: 20 P_w: 20
out mhsa in: torch.Size([400, 8, 16, 16])
out mhsa out: torch.Size([400, 16, 16, 8])
out: torch.Size([400, 32, 16, 16])
x0:torch.Size([1, 32, 320, 320]) torch.Size([1, 32, 160, 160])
x: torch.Size([1, 64, 160, 160])
x: torch.Size([100, 64, 16, 16]) P_h: 10 P_w: 10
out mhsa in: torch.Size([100, 16, 16, 16])
out mhsa out: torch.Size([100, 16, 16, 16])
out: torch.Size([100, 64, 16, 16])
x1:torch.Size([1, 64, 160, 160]) torch.Size([1, 64, 80, 80])
x: torch.Size([1, 128, 80, 80])
x: torch.Size([100, 128, 8, 8]) P_h: 10 P_w: 10
out mhsa in: torch.Size([100, 32, 8, 8])
out mhsa out: torch.Size([100, 8, 8, 32])
out: torch.Size([100, 128, 8, 8])
x2:torch.Size([1, 128, 80, 80]) torch.Size([1, 128, 40, 40])
x: torch.Size([1, 256, 40, 40])
x: torch.Size([100, 256, 4, 4]) P_h: 10 P_w: 10
out mhsa in: torch.Size([100, 64, 4, 4])
out mhsa out: torch.Size([100, 4, 4, 64])
out: torch.Size([100, 256, 4, 4])
x3:torch.Size([1,

In [19]:
# Shape of the "mask" entry of the model's returned dict.
ans['mask'].shape

torch.Size([1, 512, 80, 80])

In [305]:
# Save only the weights (state_dict), not the module object, to test.pkl (a relative path).
torch.save(obj=model.state_dict(), f="test.pkl")

In [7]:
from model.FL_GTU import FL_GTU

In [8]:
# Rebinds `model`: from this cell on it is the imported FL_GTU network, not GT_U_DCNet.
model = FL_GTU()

[1, 128, 256, 512]
[1, 128, 256, 512]
[1, 128, 256, 512]


In [9]:
# Non-square all-zero dummy batch of shape (1, 1, 320, 640); no device is given.
batch_image = torch.zeros((1,1, 320, 640))

In [36]:
class GT_UFPN_Net(nn.Module):
    """U-shaped network with BotBlock encoder and decoder stages, fed by an average-pooled input pyramid.

    Args:
        img_ch: channels of the input image.
        output_ch: channels of the predicted mask.
        middle_channel: candidate stage widths; only the last ``encode_len`` entries are used.
        encode_len: number of encoder stages, counting ``pre_encode``.
        need_return_dict: return the dict built by ``build_results`` instead of a tuple.

    Returns (from ``forward``):
        ``{"mask": outp, "cmask": outp, "edge": 0}`` when ``need_return_dict`` is true,
        otherwise ``(outp, 0)``. No edge map is computed by this model.
    """

    def __init__(self, img_ch  =  1, output_ch  =  1, 
                middle_channel = [32, 64, 128, 256, 512], 
                encode_len = 3, 
                need_return_dict = False
        ):
        super(GT_UFPN_Net,self).__init__()

        self.index_len = encode_len - 1
        self.need_return_dict = need_return_dict
        self.downsample = nn.AvgPool2d(2)
        self.Maxpool = nn.MaxPool2d(kernel_size = 2, stride = 2)
        self.erode = MinPool(2, 2, 1)
        self.dilate = nn.MaxPool2d(2, stride = 1)
        self.encode_list = nn.ModuleList()
        self.up_list = nn.ModuleList()
        self.decode_list = nn.ModuleList()

        # Keep the last `encode_len` widths and prepend the image channels, so that
        # middle_channel[k] is the width after stage k (index 0 = input image).
        middle_channel = middle_channel[ len(middle_channel) - encode_len : ]
        middle_channel = [ img_ch, *middle_channel ]
        print( middle_channel )

        self.pre_encode = _make_bot_layer(ch_in = middle_channel[0], ch_out = middle_channel[1])
        # Later encoder stages consume the pooled previous output concatenated with a pyramid
        # feature of the same width, hence the factor 2.
        for i in range(1, encode_len):
            self.encode_list.append(
                _make_bot_layer(ch_in = 2 * middle_channel[i], ch_out = middle_channel[i + 1])
            )

        for i in range(1, encode_len):
            now_dim = encode_len - i + 1
            next_dim = encode_len - i
            up_channels = 2 * middle_channel[next_dim]
            self.up_list.append( up_conv(ch_in = middle_channel[now_dim], ch_out = up_channels) )
            # forward() concatenates the up-conv output (up_channels) with the encoder skip
            # (middle_channel[now_dim]). The previous hard-coded 2 * middle_channel[now_dim] is the
            # same number only when every width is double the one before it; other width lists
            # failed in forward() with a channel mismatch.
            self.decode_list.append(
                _make_bot_layer(
                    ch_in = up_channels + middle_channel[now_dim],
                    ch_out = middle_channel[next_dim],
                )
            )

        self.CBR = nn.ModuleList()
        for i in range(encode_len):
            self.CBR.append(
                nn.Sequential(
                    nn.Conv2d(middle_channel[0], middle_channel[i + 1], 3, 1, 1),
                    nn.BatchNorm2d(middle_channel[i + 1]),
                    nn.ReLU(),
                )
            )
        self.last_up = nn.Upsample(scale_factor = 2)
        # The last decoder stage always outputs middle_channel[1] channels. Indexing with 1 instead
        # of the leaked loop variable `next_dim` also keeps encode_len == 1 from raising NameError.
        self.last_decode = nn.Conv2d(middle_channel[1], output_ch, kernel_size = 1, stride = 1, padding = 0)



    def build_results(self, x, y, z):
        """Pack the three outputs into a dict with keys ``mask``, ``cmask`` and ``edge``."""
        return {
            "mask": x,
            "cmask": y,
            "edge": z,
        }

    def build_feature_pyramid(self, x): # 80
        """Return ``index_len + 1`` successively 2x average-pooled copies of ``x`` (the first is already pooled once)."""
        x_list = []
        for i in range(self.index_len + 1):
            x = self.downsample(x)
            x_list.append( x )
        return x_list

    def edge_hot_map(self, x):
        """Return ``dilate(x) - erode(x)`` on a detached copy of ``x``; not called by ``forward``."""
        x = x.clone().detach()
        # Pad one pixel on the left and on the top before the stride-1 pooling.
        edge = nn.functional.pad(x, (1, 0, 1, 0))
        edge = self.dilate(edge) - self.erode(edge)
        return edge


    def forward(self, x):
        """Run the encoder and decoder and project to ``output_ch`` channels."""
        x_list = self.build_feature_pyramid( x )
        out_list = []

        # One Conv-BN-ReLU projection per pyramid level.
        pre_x_list = [self.CBR[index](x_list[index]) for index in range(len(x_list))]

        out = self.pre_encode( x )
        out_pool = self.Maxpool(out)
        out_list.append(out)

        # encoding
        for index in range(self.index_len):
            x_temp = torch.cat( [pre_x_list[index], out_pool], 1)
            out = self.encode_list[index]( x_temp )
            out_pool = self.Maxpool(out)
            out_list.append(out)

        x_temp = out_pool
        # decoding: learned up-conv, concat with the matching encoder output, BotBlock stage
        for index in range(self.index_len):
            up = self.up_list[index](x_temp)
            x_temp = torch.cat( [up, out_list[ self.index_len - index ] ], dim = 1)
            x_temp = self.decode_list[index](x_temp)

        outp = self.last_up(x_temp)
        outp = self.last_decode(outp)
        return self.build_results(outp, outp, 0) if self.need_return_dict else( outp, 0 )

In [46]:
# Four widths are passed but encode_len = 3, so only the last three (256, 512, 1024) are used,
# as the printed channel list shows.
model = GT_UFPN_Net( encode_len = 3 , middle_channel=[128,256,512,1024])

[1, 256, 512, 1024]


In [42]:
# Small all-zero dummy batch of shape (2, 1, 80, 80) for a quick shape check.
batch_image = torch.zeros((2,1,80,80))

In [44]:
# Runs with autograd enabled and displays the whole return value.
# NOTE(review): this cell's execution count [44] precedes the model construction in [46], so the
# recorded output presumably belongs to an earlier model instance; re-run on a fresh kernel.
model(batch_image)

x: torch.Size([2, 1, 80, 80])
x: torch.Size([800, 1, 4, 4]) P_h: 20 P_w: 20
out mhsa in: torch.Size([800, 32, 4, 4])
out mhsa out: torch.Size([800, 4, 4, 32])
out: torch.Size([800, 128, 4, 4])
x: torch.Size([2, 256, 40, 40])
x: torch.Size([200, 256, 4, 4]) P_h: 10 P_w: 10
out mhsa in: torch.Size([200, 64, 4, 4])
out mhsa out: torch.Size([200, 4, 4, 64])
out: torch.Size([200, 256, 4, 4])
x: torch.Size([2, 512, 20, 20])
x: torch.Size([50, 512, 4, 4]) P_h: 5 P_w: 5
out mhsa in: torch.Size([50, 128, 4, 4])
out mhsa out: torch.Size([50, 4, 4, 128])
out: torch.Size([50, 512, 4, 4])
x: torch.Size([2, 1024, 20, 20])
x: torch.Size([50, 1024, 4, 4]) P_h: 5 P_w: 5
out mhsa in: torch.Size([50, 64, 4, 4])
out mhsa out: torch.Size([50, 4, 4, 64])
out: torch.Size([50, 256, 4, 4])
x: torch.Size([2, 512, 40, 40])
x: torch.Size([200, 512, 4, 4]) P_h: 10 P_w: 10
out mhsa in: torch.Size([200, 32, 4, 4])
out mhsa out: torch.Size([200, 4, 4, 32])
out: torch.Size([200, 128, 4, 4])


tensor([[[[-0.0528, -0.0528, -0.0528,  ..., -0.0528, -0.0528, -0.0528],
          [-0.0528, -0.0528, -0.0528,  ..., -0.0528, -0.0528, -0.0528],
          [-0.0528, -0.0528, -0.0528,  ..., -0.0528, -0.0528, -0.0528],
          ...,
          [-0.0528, -0.0528, -0.0528,  ..., -0.0528, -0.0528, -0.0528],
          [-0.0528, -0.0528, -0.0528,  ..., -0.0528, -0.0528, -0.0528],
          [-0.0528, -0.0528, -0.0528,  ..., -0.0528, -0.0528, -0.0528]]],


        [[[-0.0528, -0.0528, -0.0528,  ..., -0.0528, -0.0528, -0.0528],
          [-0.0528, -0.0528, -0.0528,  ..., -0.0528, -0.0528, -0.0528],
          [-0.0528, -0.0528, -0.0528,  ..., -0.0528, -0.0528, -0.0528],
          ...,
          [-0.0528, -0.0528, -0.0528,  ..., -0.0528, -0.0528, -0.0528],
          [-0.0528, -0.0528, -0.0528,  ..., -0.0528, -0.0528, -0.0528],
          [-0.0528, -0.0528, -0.0528,  ..., -0.0528, -0.0528, -0.0528]]]],
       grad_fn=<ConvolutionBackward0>)