In [2]:
import sys 
sys.path.append("..") 
from train import *
from utils.score import cal_all_score
from dataloader import Dataload
from torch.utils.data import DataLoader
import cv2
import matplotlib.pyplot as plt
matplotlib.use('TkAgg')
%matplotlib inline
import datetime

  from .autonotebook import tqdm as notebook_tqdm


use gpu: True


In [3]:
def crop_tensor(image_pack, scale_x, scale_y = None, axis = 1):
    if scale_y is None:
        scale_y = scale_x
    _, _, w, h = image_pack.size()
    a = int(w/scale_x)
    b = int(h/scale_y)
    # print(a, b)
    t = torch.split(image_pack, a, dim = 2)
    ans = []
    for i in t:
        for j in torch.split(i, b, dim=3):
            ans.append(j)
            # print(j.shape)
    d = torch.stack(ans, axis)
    return d

def cat_tensor(image_pack, scale_x, scale_y = None):
    if scale_y is None:
        scale_y = scale_x
    data = []
    for i in range(scale_x):
        m = []
        for j in range(scale_y):
            m.append(image_pack[:, i * scale_y + j ,:,:,:])
            # print(  i * scale_y + j, i,j )
        data.append(torch.cat(m, dim = -1))
    
    data = torch.cat(data, dim = -2)
    return data

In [4]:
from model.GTU.models.GT_UNet import *

In [5]:
class BotBlock(nn.Module):

    def __init__(self, in_dimension, curr_h, curr_w, proj_factor = 4, activation = 'relu', pos_enc_type = 'relative',
                 stride = 1, target_dimension = None):
        super(BotBlock, self).__init__()
        if stride !=  1 or in_dimension !=  target_dimension:
            self.shortcut  =  nn.Sequential(
                nn.Conv2d(in_dimension, target_dimension, kernel_size = 3, padding = 1, stride = stride),
                BNReLU(target_dimension, activation = activation, nonlinearity = True),
            )
        else:
            self.shortcut  =  None

        bottleneck_dimension  =  target_dimension // proj_factor
        self.conv1  =  nn.Sequential(
            nn.Conv2d(in_dimension, bottleneck_dimension, kernel_size = 3, padding = 1, stride = 1 ),
            BNReLU(bottleneck_dimension, activation = activation, nonlinearity = True)
        )

        self.mhsa  =  MHSA(
            in_channels = bottleneck_dimension, 
            heads = 4, curr_h = curr_h, curr_w = curr_w,
            pos_enc_type = pos_enc_type
        )
        conv2_list  =  []
        if stride !=  1:
            assert stride  ==  2, stride
            conv2_list.append(nn.AvgPool2d(kernel_size = (2, 2), stride = (2, 2)))  # TODO: 'same' in tf.pooling
        conv2_list.append(BNReLU(bottleneck_dimension, activation = activation, nonlinearity = True))
        
        self.conv2  =  nn.Sequential(*conv2_list)

        self.conv3  =  nn.Sequential(
            nn.Conv2d(bottleneck_dimension, target_dimension, kernel_size = 3,padding = 1, stride = 1),
            BNReLU(target_dimension, nonlinearity = False, init_zero = True),
        )
        self.last_act  =  get_act(activation)


    def forward(self, x):
        # print("x origin:", x.shape)
        if self.shortcut is not None:
            shortcut  =  self.shortcut(x)
        else:
            shortcut  =  x
        Q_h  =  Q_w  =  4
        N, C, H, W  =  x.shape
        P_h, P_w  =  H // Q_h, W // Q_w
        # print("x:", x.shape, "P_h:",P_h, "P_w:", P_w)
        x  =  x.reshape(N * P_h * P_w, C, Q_h, Q_w)
        print("x:origin", x.shape)
        out  =  self.conv1(x)
        print("out in :", out.shape)
        out  =  self.mhsa(out)
        print("out out :", out.shape)
        out  =  out.permute(0, 3, 1, 2)  # back to pytorch dim order

        out  =  self.conv2(out)
        out  =  self.conv3(out)
        # print("out:", out.shape)
        N1, C1, H1, W1  =  out.shape
        out  =  out.reshape(N, C1, int(H), int(W))

        out +=  shortcut
        out  =  self.last_act(out)

        return out

In [6]:
def _make_bot_layer(ch_in, ch_out, w = 4 ):

    W  =  H  =  w
    dim_in  =  ch_in
    dim_out  =  ch_out

    stage5  =  []

    stage5.append(
        BotBlock(in_dimension = dim_in, curr_h = H, curr_w = W, stride = 1 , target_dimension = dim_out)
    )

    return nn.Sequential(*stage5)

In [7]:
class ResGroupFormer(nn.Module):
    def __init__(self, ch_in, ch_out):
        super().__init__()
        self.g1 = _make_bot_layer(
                    ch_in = ch_in, #  2 * middle_channel[i],             
                    ch_out = ch_out, #middle_channel[ i+1 ], 
                    w = 4
                ) 
        self.g2 = _make_bot_layer(
                    ch_in = ch_in, # 2 * middle_channel[i],             
                    ch_out = ch_out, #middle_channel[ i+1 ], 
                    w = 4
                ) 
        self.g3 = _make_bot_layer(
                    ch_in = ch_in, #, 2 * middle_channel[i],             
                    ch_out = ch_out,# , middle_channel[ i+1 ], 
                    w = 2
                )
        self.sigmod = nn.Sigmoid()
    def forward(self, x):
        x_clone = x.clone().detach()
        x = self.g1(x)
        x = self.g2(x)
        y = self.g3(x_clone)
        x = self.sigmod(x + y)
        return x

In [8]:
class Decode(nn.Module):
    def __init__(self, ch_in, ch_out):
        dim_in = ch_in
        dim_out = ch_out
        super(Decode, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d( dim_in, dim_out, kernel_size = 1, stride = 1, padding = 0),
            nn.BatchNorm2d(dim_out),
            nn.ReLU(),
            nn.Conv2d(dim_out, dim_out, kernel_size = 1, stride = 1, padding = 0),
            nn.BatchNorm2d(dim_out),
            nn.ReLU(),
            nn.Conv2d(dim_out, dim_out, kernel_size = 1, stride = 1, padding = 0),
            nn.BatchNorm2d(dim_out),
            nn.ReLU(),
        )
        self.up = nn.Upsample( scale_factor = 2)
        
    def forward(self, x, y):
        y = self.up(y)
        x = torch.cat([x,y], 1)
        x = self.conv(x)
        return x
        

In [9]:
from model_server.GTU.models.pvtv2 import pvt_v2_b2

In [10]:
class SKConv(nn.Module):
    #                  64       32   2  8  2
    def __init__(self, features_list, out_features, WH, M, G, r, stride=1 ,L=32):
        """ Constructor
        Args:
            features: input channel dimensionality.
            WH: input spatial dimensionality, used for GAP kernel size.
            M: the number of branchs.
            G: num of convolution groups.
            r: the radio for compute d, the length of z.
            stride: stride, default 1.
            L: the minimum dim of the vector z in paper, default 32
        """
        super(SKConv, self).__init__()
        features = features_list[-1]
        d = max(int(features/r), L)
        self.M = M
        self.features = features
        self.convs = nn.ModuleList()
        for i in range(M):
            self.convs.append(nn.Sequential(
                nn.Conv2d(features_list[i], out_features, kernel_size=3, stride=1, padding=1, groups=G),
                nn.BatchNorm2d(out_features),
                # nn.ReLU()
            ))
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(features, d)
        self.fcs = nn.ModuleList()
        for i in range(M):
            self.fcs.append(
                nn.Linear(d, features)
            )
        self.softmax = nn.Softmax(dim=1)
        
    def forward(self, x):
        for i, conv in enumerate(self.convs):
            fea = conv(x[i]).unsqueeze_(dim=1)
            if i == 0:
                feas = fea
            else:
                feas = torch.cat([feas, fea], dim=1)
            # print(i)
        # print(feas.shape)
        fea_U = torch.sum(feas, dim=1)
        fea_s = self.gap(fea_U).squeeze_()
        # print(fea_U.shape, fea_s.shape)
        fea_z = self.fc(fea_s)
        for i, fc in enumerate(self.fcs):
            vector = fc(fea_z).unsqueeze_(dim=1)
            if i == 0:
                attention_vectors = vector
            else:
                attention_vectors = torch.cat([attention_vectors, vector], dim=1)
            # print(i)
        attention_vectors = self.softmax(attention_vectors)
        attention_vectors = attention_vectors.unsqueeze(-1).unsqueeze(-1)
        # if attention_vectors.shape[0] != 1 :
        #     attention_vectors = attention_vectors.unsqueeze(0)
        # print( attention_vectors.shape , feas.shape )
        fea_v = (feas * attention_vectors).sum(dim=1)
        return fea_v

In [11]:

class GT_U_DC_PVTNet(nn.Module):
    def __init__(self, img_ch  =  1, output_ch  =  1, 
                middle_channel = [64, 128, 256, 512, 1024], 
                encode_len = 5, 
                need_return_dict = True,
                need_supervision = False,
                decode_type = "conv",
                path = './pretrained_pth/pvt_v2_b2.pth'

        ):
        super(GT_U_DC_PVTNet, self).__init__()
        pvt_channel = [64, 128, 320, 512]
        self.index_len = encode_len - 1
        self.need_return_dict = need_return_dict
        self.downsample = nn.AvgPool2d(2)
        self.Maxpool  =  nn.MaxPool2d(kernel_size = 2,stride = 2)
        self.erode = MinPool(2,2,1)
        self.dilate = nn.MaxPool2d(2, stride = 1)
        self.encode_list = nn.ModuleList()
        self.up_list = nn.ModuleList()
        self.decode_list = nn.ModuleList()
        self.sk_list = nn.ModuleList()
        if self.need_return_dict:
            self.supervision_list = nn.ModuleList()
        self.select  =  nn.Conv2d( 2, output_ch, kernel_size = 1, stride = 1, padding = 0)
        middle_channel = middle_channel[ len(middle_channel) - encode_len : ]
        middle_channel =  [ img_ch, *middle_channel]
        self.need_supervision = need_supervision
        print( middle_channel )
        self.pre_encode = _make_bot_layer(
            ch_in = middle_channel[0],         
            ch_out = middle_channel[ 1 ], 
            w = 4
        )
        for i in range(1, encode_len):
            self.encode_list.append( 
                ResGroupFormer(2 * middle_channel[i],  middle_channel[ i+1 ])
#                 _make_bot_layer(
#                     ch_in = 2 * middle_channel[i],             
#                     ch_out = middle_channel[ i+1 ], 
#                     w = 4
#                 ) 
            )
        
        for i in range(1, encode_len):
            now_dim = encode_len - i + 1
            next_dim = encode_len - i
            # print( middle_channel[ now_dim ],   middle_channel[ next_dim ])
            self.up_list.append( up_conv(ch_in = middle_channel[now_dim ] , ch_out = 2 * middle_channel[next_dim]) )
            
            if i < self.index_len + 1:
                print(pvt_channel[  self.index_len - i ], middle_channel[now_dim ])
                self.sk_list.append( 
                    SKConv([pvt_channel[  self.index_len - i], middle_channel[now_dim ]], middle_channel[now_dim ], 32,2,8,2)
                )
                
            if decode_type == "conv":
                self.decode_list.append( 
                    Decode(ch_in = 2 * middle_channel[ now_dim ], ch_out = middle_channel[next_dim]) 
                )
            else:
                self.decode_list.append( 
                     _make_bot_layer(
                        ch_in = 2 * middle_channel[ now_dim ], 
                        ch_out = middle_channel[next_dim],
                        w = 2 ** i
                    ) 
                )
            
            if self.need_supervision:
                self.supervision_list.append( nn.Conv2d(middle_channel[next_dim], output_ch, 1, 1) )
            #  _make_bot_layer(
            #     ch_in = 2 * middle_channel[ now_dim ], 
            #     ch_out = middle_channel[next_dim],
            #     w = 2 ** i) 
            #  )
            
        self.CBR = nn.ModuleList()
        for i in range(encode_len):
            self.CBR.append(
                nn.Sequential(
                    nn.Conv2d(middle_channel[0], middle_channel[i+1], 3, 1, 1),
                    nn.BatchNorm2d(middle_channel[i+1]),
                    nn.ReLU(),
                )
            )
        self.last_up = nn.Upsample(scale_factor = 2)
        self.last_decode = nn.Conv2d(middle_channel[next_dim], output_ch, kernel_size = 1, stride = 1, padding = 0)
        self.supervision_list.append( nn.Conv2d(middle_channel[next_dim], output_ch, 1, 1) )

        self.backbone = pvt_v2_b2()  # [64, 128, 320, 512]
        
        save_model = torch.load(path)
        model_dict = self.backbone.state_dict()
        state_dict = {k: v for k, v in save_model.items() if k in model_dict.keys()}
        model_dict.update(state_dict)
        self.backbone.load_state_dict(model_dict)
        n_p = sum(x.numel() for x in self.backbone.parameters()) # number parameters
        n_g = sum(x.numel() for x in self.backbone.parameters() if x.requires_grad)  # number gradients
        print(f"pvt Summary: {len(list(self.backbone.modules()))} layers, {n_p} parameters, {n_p/1e6} M, {n_g} gradients")



    def build_results(self, x,y,z, super_vision  = None):
        if super_vision is None:
            return {
            "mask": x,
            "cmask": y,
            "edge":z,
        }
        else:
            return {
                "mask": x,
                "cmask": y,
                "edge":z,
                "super":super_vision
            }
            
    
    def build_feature_pyramid(self, x): # 80
        x_list = []
        for i in range(self.index_len + 1):
            x = self.downsample(x) 
            x_list.append( x )
        return   x_list
    
    def edge_hot_map(self, x):
        x = x.clone().detach()
        edge = nn.functional.pad(x, (1, 0, 1, 0))
        edge = self.dilate(edge) - self.erode(edge)
        return edge
    
    @torch.no_grad()
    def pvt_backbone(self, x):
        x = x.clone().detach()
        pvt_x = torch.cat([x,x,x], 1)
        pvt = self.backbone(pvt_x)
        return pvt
    
    def forward(self,x):
        pvt = self.pvt_backbone(x)
        c1, c2, c3, c4 = pvt
       
        pvt_decode = list( reversed( pvt[:len(self.sk_list)] ) )#  [c3, c2, c1]
        print(c1.shape, c2.shape, c3.shape, c4.shape)
        x_list = self.build_feature_pyramid( x )
        pre_x_list = []
        out_list = []
        supervision = []
        for index in range(len(x_list)):
            pre_x = self.CBR[index](x_list[index])
            pre_x_list.append(pre_x)
            # print(x_list[index].shape, pre_x.shape)
        # pre_x_list = [x, *pre_x_list]
       
        out = self.pre_encode( x )
        out_pool = self.Maxpool(out)
        out_list.append(out)
        print( "x{}:{} {}".format( 0, out.shape, out_pool.shape ))
        
        # encoding
        for index in range(self.index_len):
            print("encode:{},{}".format(pre_x_list[index].shape, out_pool.shape  ))
            x_temp = torch.cat( [pre_x_list[index], out_pool], 1) 
            # print( "x{} cat:{}".format( index, x.shape))
            out = self.encode_list[index]( x_temp )
            out_pool = self.Maxpool(out)
            out_list.append(out)
            
            print( "x{}:{} {}".format( index + 1, out.shape, out_pool.shape ))
            
        x_temp = out_pool 
        # decoding
        
        for index in range(len(self.sk_list)):
            up = out_list[ self.index_len - index ]
            print("index:", index, self.index_len - index, "pvt and up", pvt_decode[index].shape, up.shape)
            conmbine_feature = self.sk_list[index](( self.last_up( pvt_decode[index] ), up))
            out_list[ self.index_len - index] = conmbine_feature
            # print("index:", index, self.index_len - index, "pvt and up", pvt_decode[index].shape, up.shape)
        
            
        for index in range(self.index_len):
            # up = self.up_list[index](x_temp)
            # x_temp = torch.cat( [up, out_list[ self.index_len - index ] ], dim = 1)
            up = out_list[ self.index_len - index ]
           
            print("decode (up, x_temp):", up.shape, x_temp.shape)
            x_temp = self.decode_list[index](up, x_temp)
            
            
            # x_temp = self.decode_list[index](x_temp)
            supervision.append(x_temp)
            # self.need_supervision[index](x_temp)
            # print( "decode{}:{} {}".format( index, up.shape, x_temp.shape))
        
        # print("final decode:", x_temp.shape)
        outp = self.last_up(x_temp)
        outp = self.last_decode(outp)
        
        edge = self.edge_hot_map(outp)
        outp= self.select(torch.cat([outp, edge], 1))
        # print("outp:{}".format( outp.shape ))
        if self.need_supervision:
            for i in range( self.index_len ):
                supervision[i] = self.supervision_list[i](supervision[i])
            return self.build_results(outp, outp, edge, supervision) if self.need_return_dict else( out, edge, supervision) 
        return self.build_results(outp, outp, edge) if self.need_return_dict else( out, edge ) 


In [12]:
model = GT_U_DC_PVTNet(1,1, encode_len = 5, path = r"H:/program/outpage/AITOOTH/model_server/GTU/models/pretrained_pth/pvt_v2_b2.pth")# .to("cpu")

[1, 64, 128, 256, 512, 1024]
512 1024
320 512
128 256
64 128
pvt Summary: 319 layers, 24849856 parameters, 24.849856 M, 24849856 gradients


In [13]:
model = model.to("cuda:0")

In [14]:
batch_image = torch.zeros((2,1, 320, 320)).to("cuda:0")

In [15]:
with torch.no_grad():
    ans = model(batch_image)

torch.Size([2, 64, 80, 80]) torch.Size([2, 128, 40, 40]) torch.Size([2, 320, 20, 20]) torch.Size([2, 512, 10, 10])
x:origin torch.Size([12800, 1, 4, 4])
out in : torch.Size([12800, 16, 4, 4])
out out : torch.Size([12800, 4, 4, 16])
x0:torch.Size([2, 64, 320, 320]) torch.Size([2, 64, 160, 160])
encode:torch.Size([2, 64, 160, 160]),torch.Size([2, 64, 160, 160])
x:origin torch.Size([3200, 128, 4, 4])
out in : torch.Size([3200, 32, 4, 4])
out out : torch.Size([3200, 4, 4, 32])
x:origin torch.Size([3200, 128, 4, 4])
out in : torch.Size([3200, 32, 4, 4])
out out : torch.Size([3200, 4, 4, 32])
x:origin torch.Size([3200, 128, 4, 4])
out in : torch.Size([3200, 32, 4, 4])


RuntimeError: shape '[-1, 16, 4, 7]' is invalid for input of size 614400

In [148]:
ans['mask'].shape

torch.Size([2, 1, 320, 320])

In [161]:
a = [1,2]

In [162]:
a[1:]

[2]

In [155]:
mhsa  =  MHSA(
    in_channels = 128//4, 
            heads = 4, curr_h = 4, curr_w = 4,
            pos_enc_type = "relative"
        )

In [159]:
batch_image = torch.zeros((16, 32, 4, 4))

In [160]:
mhsa(batch_image)

tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,