**This code is based on [2.5d segmentaion baseline [inference]](https://www.kaggle.com/code/tanakar/2-5d-segmentaion-baseline-inference)**
If you think my code is useful, please upvote it ^w^.

# Import

In [20]:
# !python -m pip install /kaggle/input/install-cc3d/connected_components_3d-3.12.3-cp310-cp310-win_amd64.whl

In [21]:
import torch 
import torch.nn as nn  
import numpy as np
from tqdm import tqdm
from torch.cuda.amp import autocast
import cv2
import os,sys
from glob import glob
import matplotlib.pyplot as plt
import pandas as pd
!python -m pip install --no-index --find-links=/kaggle/input/pip-download-for-segmentation-models-pytorch segmentation-models-pytorch
import segmentation_models_pytorch as smp
# import cc3d
from torch.utils.data import Dataset, DataLoader

Looking in links: /kaggle/input/pip-download-for-segmentation-models-pytorch
INFO: pip is looking at multiple versions of torchvision to determine which version is compatible with other requirements. This could take a while.

The conflict is caused by:
    efficientnet-pytorch 0.7.1 depends on torch
    pretrainedmodels 0.7.4 depends on torch
    timm 0.9.2 depends on torch>=1.7
    torchvision 0.15.2 depends on torch==2.0.1

To fix this you could try to:
1. loosen the range of package versions you've specified
2. remove package versions to allow pip attempt to solve the dependency conflict



ERROR: Cannot install efficientnet-pytorch==0.7.1, pretrainedmodels==0.7.4, timm==0.9.2 and torchvision==0.15.2 because these package versions have conflicting dependencies.
ERROR: ResolutionImpossible: for help visit https://pip.pypa.io/en/latest/topics/dependency-resolution/#dealing-with-dependency-conflicts


# config

In [22]:
class CFG:
    """Static configuration read by the model builder and the tiling code."""

    # ---- model ----
    model_name = 'Unet'
    # alternative encoder noted by the author: 'efficientnet-b0'
    backbone = 'se_resnet50'
    in_chans = 5  # 65

    # ---- tiling ----
    image_size = 256
    input_size = 256
    tile_size = image_size
    stride = tile_size // 2
    drop_egde_pixel = 0

    target_size = 1

    # ---- fold / inference ----
    valid_id = 1
    batch = 64
    # The original note next to this value also lists 0.005 and the local folder
    # E:\CSworks\kaggle_blood_vessel\models\2023-12-10-baseline-v0
    th_percentile = 0.002
    model_path = [
        "./models/2023-12-10-baseline-v0/se_resnet50_9_loss0.01_score0.76_val_loss0.01_val_scorenan.pt",
    ]

# Model

In [23]:
class CustomModel(nn.Module):
    """Wrapper around ``smp.Unet`` whose forward pass returns output channel 0."""

    def __init__(self, CFG, weight=None):
        super().__init__()
        self.CFG = CFG
        unet_kwargs = dict(
            encoder_name=CFG.backbone,
            encoder_weights=weight,
            in_channels=CFG.in_chans,
            classes=CFG.target_size,
            activation=None,
        )
        # Note: the attribute is called ``encoder`` but stores the object
        # returned by ``smp.Unet``.
        self.encoder = smp.Unet(**unet_kwargs)

    def forward(self, image):
        """Return channel 0 of the network output; no sigmoid is applied here."""
        return self.encoder(image)[:, 0]


def build_model(weight=None):
    """Build a ``CustomModel`` from the module-level ``CFG``.

    Loads environment variables via ``dotenv`` first, prints the configured
    model name and backbone, then returns the constructed model. ``weight`` is
    passed straight through to ``CustomModel``.
    """
    import dotenv

    dotenv.load_dotenv()

    for label, value in (('model_name', CFG.model_name), ('backbone', CFG.backbone)):
        print(label, value)

    return CustomModel(CFG, weight)

# Functions

In [24]:
def rle_encode(mask):
    """Run-length encode a binary mask as a "start length start length ..." string.

    The mask is flattened, starts are 1-based, and an empty mask is encoded
    as ``'1 0'``.
    """
    # Pad with zeros on both sides so every run has a start and an end edge.
    padded = np.concatenate([[0], mask.flatten(), [0]])
    edges = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    # Turn each (start, end) pair into (start, length).
    edges[1::2] -= edges[::2]
    encoded = ' '.join(map(str, edges))
    return encoded if encoded != '' else '1 0'

def load_img(paths):
    """Load grayscale images and stack them into one float32 tensor along dim 0.

    Args:
        paths: sequence of file paths; an entry may be ``None`` to mark a
            missing slice.

    Returns:
        ``torch.Tensor`` of shape ``(len(paths), H, W)``. Each ``None`` entry is
        replaced by standard-normal noise with the shape of the first image
        that was actually loaded.

    Raises:
        OSError: if ``cv2.imread`` cannot read one of the given paths.
        ValueError: if ``paths`` holds no real path, so no image shape is known.
    """
    images = []
    for path in paths:
        if path is None:
            images.append(None)
            continue
        # IMREAD_GRAYSCALE yields a single-channel image; it is converted to
        # float32 so it can be stacked with the noise placeholders.
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise OSError(f"cv2.imread failed to read image: {path}")
        images.append(img.astype('float32'))

    reference = next((img for img in images if img is not None), None)
    if reference is None:
        raise ValueError(
            "load_img needs at least one non-None path to infer the image shape"
        )

    # Keep every element a numpy array so np.stack never has to coerce tensors.
    filled = [
        torch.randn(reference.shape).numpy() if img is None else img
        for img in images
    ]
    return torch.from_numpy(np.stack(filled, axis=0))

def min_max_normalization(x:torch.Tensor)->torch.Tensor:
    """Min-max scale every entry along dim 0 independently.

    input.shape=(batch,f1,...). Inputs with more than two dims are flattened
    to (batch, -1) for the statistics and restored to the input shape on
    return. If the mean of the per-row minima is 0 and the mean of the
    per-row maxima is 1, the input is returned unchanged.
    """
    original_shape = x.shape
    flat = x.reshape(x.shape[0], -1) if x.ndim > 2 else x

    lo = flat.min(dim=-1, keepdim=True).values
    hi = flat.max(dim=-1, keepdim=True).values

    looks_normalised = lo.mean() == 0 and hi.mean() == 1
    if not looks_normalised:
        # Small epsilon keeps constant rows from dividing by zero.
        flat = (flat - lo) / (hi - lo + 1e-9)
    return flat.reshape(original_shape)

class Data_loader(Dataset):
    """Dataset over the sorted ``*.tif`` files found under ``path + s``.

    Items are read as grayscale and returned as ``torch.bool`` tensors when
    ``s`` is ``"/labels/"``, otherwise as ``torch.uint8`` tensors.
    """

    def __init__(self,path,s="/images/"):
        self.paths = sorted(glob(path+f"{s}*.tif"))
        # True only for the label folder; selects the output dtype per item.
        self.bool = (s == "/labels/")

    def __len__(self):
        return len(self.paths)

    def __getitem__(self,index):
        raw = cv2.imread(self.paths[index], cv2.IMREAD_GRAYSCALE)
        target_dtype = torch.bool if self.bool else torch.uint8
        return torch.from_numpy(raw).to(target_dtype)

def load_data(path,s):
    """Read every slice under ``path + s`` and concatenate them along dim 0.

    Slices are loaded through ``Data_loader`` in batches of 16 using two
    worker processes, with a tqdm progress bar over the batches.
    """
    slices = Data_loader(path, s)
    batches = [
        batch
        for batch in tqdm(DataLoader(slices, batch_size=16, num_workers=2))
    ]
    return torch.cat(batches, dim=0)

class Pipeline_Dataset(Dataset):
    """Sliding window of ``in_chan`` consecutive slices over a stacked volume.

    The volume ``x`` is zero-padded with ``in_chan // 2`` slices on both ends
    of dim 0, so item ``index`` is the window centred on original slice
    ``index``.

    Args:
        x: tensor whose dim 0 enumerates slices.
        path: folder containing an ``images`` sub-folder; its sorted file list
            supplies the per-slice ids.
        labels: optional per-slice labels indexed like ``x``; when ``None`` an
            all-zero label shaped like one slice is returned.
    """

    def __init__(self,x,path,labels=None):
        self.img_paths = sorted(glob(path+"/images/*"))
        #assert int(self.img_paths[-1].split("/")[-1][:-4])+1==len(x)
        self.debug = labels
        self.in_chan = 5
        z = torch.zeros(self.in_chan//2, *x.shape[1:], dtype=x.dtype)
        self.x = torch.cat((z, x, z), dim=0)
        self.labels = labels

    def __len__(self):
        # Remove the padding added on both ends in __init__ (derived from
        # in_chan instead of a hard-coded 4).
        return self.x.shape[0] - 2 * (self.in_chan // 2)

    def __getitem__(self, index):
        """Return ``(window, label, sample_id)`` for slice ``index``."""
        window = self.x[index:index+self.in_chan]
        if self.labels is not None:
            label = self.labels[index]
        else:
            label = torch.zeros_like(window[0])
        # Build "<folder>_<file stem>" from ".../<folder>/images/<file>.ext".
        # Backslashes are normalised first because glob returns them on
        # Windows, which would otherwise corrupt the id.
        parts = self.img_paths[index].replace("\\", "/").split("/")[-3:]
        parts.pop(1)
        sample_id = "_".join(parts)
        # Drop the 4-character extension (e.g. ".tif").
        return window, label, sample_id[:-4]

def add_edge(x:torch.Tensor,edge:int):
    """Surround a (C, H, W) tensor with a border of ``edge`` pixels valued 128.

    The result has shape (C, H + 2*edge, W + 2*edge).
    """
    def border(*size):
        # Reads the *current* x so dtype/device follow the tensor being grown.
        return torch.ones(size, dtype=x.dtype, device=x.device) * 128

    # Rows above and below first ...
    rows = border(x.shape[0], edge, x.shape[2])
    x = torch.cat([rows, x, rows], dim=1)
    # ... then full-height columns on the left and right.
    cols = border(x.shape[0], x.shape[1], edge)
    x = torch.cat([cols, x, cols], dim=2)
    return x

# def TTA(x:tc.Tensor,model:nn.Module,batch=CFG.batch):
#     x=x.to(tc.float32)
#     x=min_max_normalization(x)
#     #x.shape=(batch,c,h,w)
#     if CFG.input_size!=CFG.image_size:
#         x=nn.functional.interpolate(x,size=(CFG.input_size,CFG.input_size),mode='bilinear',align_corners=True)
    
#     shape=x.shape
#     x=[x,*[tc.rot90(x,k=i,dims=(-2,-1)) for i in range(1,4)]]
#     x=tc.cat(x,dim=0)
#     with autocast():
#         with tc.no_grad():
#             #x=[model((x[i*batch:(i+1)*batch],print(x[i*batch:(i+1)*batch].shape))[0]) for i in range(x.shape[0]//batch+1)]
#             x=[model(x[i*batch:(i+1)*batch]) for i in range(x.shape[0]//batch+1)]
#             # batch=64,64...48
#             x=tc.cat(x,dim=0)
#     x=x.sigmoid()
#     x=x.reshape(4,shape[0],*shape[2:])
#     x=[tc.rot90(x[i],k=-i,dims=(-2,-1)) for i in range(4)]
#     x=tc.stack(x,dim=0).mean(0)
    
#     if CFG.input_size!=CFG.image_size:
#         x=nn.functional.interpolate(x[None],size=(CFG.image_size,CFG.image_size),mode='bilinear',align_corners=True)[0]
#     return x


def TTA(x:torch.Tensor, model:nn.Module, batch=CFG.batch):
    """Predict with 4-fold rotation test-time augmentation.

    Args:
        x: tiles of shape (N, C, H, W); converted to float32 and min-max
            normalised with ``min_max_normalization`` before inference.
        model: network whose output for an (n, C, H, W) input can be reshaped
            to (n, H, W), i.e. one logit map per tile.
        batch: maximum number of tiles per forward pass.

    Returns:
        Tensor of shape (N, H, W): sigmoid probabilities averaged over the
        0/90/180/270 degree rotations.
    """
    x = x.to(torch.float32)
    x = min_max_normalization(x)
    # x.shape=(batch,c,h,w)

    orig_h, orig_w = x.shape[-2:]
    needs_padding = CFG.input_size != CFG.image_size
    if needs_padding:
        # Zero-pad on the bottom and right up to input_size. F.pad expects
        # (left, right, top, bottom) for the last two dims, so the amounts go
        # in the 2nd and 4th slots.
        pad_height = CFG.input_size - orig_h
        pad_width = CFG.input_size - orig_w
        x = nn.functional.pad(x, [0, pad_width, 0, pad_height], 'constant', 0)

    shape = x.shape
    # k=0 is the unrotated copy; the four rotations are stacked along dim 0.
    x = torch.cat([torch.rot90(x, k=i, dims=(-2, -1)) for i in range(4)], dim=0)
    with autocast():
        with torch.no_grad():
            # Step by `batch` so no empty trailing chunk is sent to the model
            # when the number of tiles is an exact multiple of `batch`.
            x = torch.cat(
                [model(x[start:start + batch]) for start in range(0, x.shape[0], batch)],
                dim=0,
            )
    x = x.sigmoid()
    x = x.reshape(4, shape[0], *shape[2:])
    # Undo each rotation before averaging.
    x = torch.stack(
        [torch.rot90(x[i], k=-i, dims=(-2, -1)) for i in range(4)], dim=0
    ).mean(0)

    if needs_padding:
        # Crop the bottom/right padding away again.
        x = x[:, :orig_h, :orig_w]

    return x

# Build model(s)

In [25]:
from torch.nn import DataParallel

In [26]:
def get_output(debug=False):
    """Run tiled inference over every slice and collect uint8 prediction maps.

    With ``debug`` set, a single training folder is used, the slice order is
    shuffled, each prediction is plotted, and the inner loop stops after 7
    slices. Otherwise every folder under the test directory is processed.

    Returns a list of ``(prediction, id)`` tuples where ``prediction`` is the
    averaged tile output scaled by 255 and cast to uint8, and ``id`` is the id
    batch produced by the DataLoader for that slice.

    NOTE(review): ``model`` used below is not defined in the visible code;
    presumably a module-level global is expected — confirm before calling.
    """
    outputs=[]
    if debug:
        paths=["./kaggle/input/blood-vessel-segmentation/train/kidney_2"]
    else:
        paths=glob("./kaggle/input/blood-vessel-segmentation/test/*")
    debug_count=0
    for path in paths:
        x=load_data(path,"/images/")
        dataset=Pipeline_Dataset(x,path,None)
        dataloader=DataLoader(dataset,batch_size=1,shuffle=debug,num_workers=2)
        for img,label,id in tqdm(dataloader):
            #print(label.shape)
            #img=(C,H,W)
            img=img.to("cuda:0")
            label=label.to("cuda:0")
            # Pad image and label by half a tile on every side so border
            # pixels are covered by as many tiles as interior pixels.
            img=add_edge(img[0],CFG.tile_size//2)[None]
            label=add_edge(label,CFG.tile_size//2)
            # Top-left corners of the overlapping tiles (step = CFG.stride).
            x1_list = np.arange(0, label.shape[-2]-CFG.tile_size+1, CFG.stride)
            y1_list = np.arange(0, label.shape[-1]-CFG.tile_size+1, CFG.stride)

            # Accumulators: summed predictions and per-pixel tile counts.
            mask_pred = torch.zeros_like(label,dtype=torch.float32,device=label.device)
            mask_count = torch.zeros_like(label,dtype=torch.float32,device=label.device)

            indexs=[]
            chip=[]
            for y1 in y1_list:
                for x1 in x1_list:
                    x2 = x1 + CFG.tile_size
                    y2 = y1 + CFG.tile_size
                    # Destination window, shrunk by the dropped edge pixels.
                    indexs.append([x1+CFG.drop_egde_pixel,x2-CFG.drop_egde_pixel,
                                   y1+CFG.drop_egde_pixel,y2-CFG.drop_egde_pixel])
                    chip.append(img[...,x1:x2,y1:y2])

            # All tiles of this slice go through TTA as one batch.
            y_preds = TTA(torch.cat(chip),model)
            if CFG.drop_egde_pixel:
                y_preds=y_preds[...,CFG.drop_egde_pixel:-CFG.drop_egde_pixel,
                                    CFG.drop_egde_pixel:-CFG.drop_egde_pixel]
            for i,(x1,x2,y1,y2) in enumerate(indexs):
                mask_pred[...,x1:x2, y1:y2] += y_preds[i]
                mask_count[...,x1:x2, y1:y2] += 1

            # Average overlapping tiles; pixels no tile covered have count 0.
            mask_pred /= mask_count

            #Rrecover
            # Strip the half-tile border added by add_edge above.
            mask_pred=mask_pred[...,CFG.tile_size//2:-CFG.tile_size//2,CFG.tile_size//2:-CFG.tile_size//2]
            label=label[...,CFG.tile_size//2:-CFG.tile_size//2,CFG.tile_size//2:-CFG.tile_size//2]

            outputs.append(((mask_pred*255).to(torch.uint8).cpu().numpy()[0],id))
            if debug:
                debug_count+=1
                plt.subplot(121)
                # Channel 2 is the centre slice of the 5-slice window.
                plt.imshow(img[0,2].cpu().detach().numpy())
                plt.subplot(122)
                plt.imshow(mask_pred[0].cpu().detach().numpy())
                plt.show()
                if debug_count>6:
                    # Only leaves the inner (per-slice) loop.
                    break
    return outputs
    

In [27]:
# is_submit=len(glob("/kaggle/input/blood-vessel-segmentation/test/kidney_5/images/*.tif"))!=3
# outputs=get_output(not is_submit)

# TH = [output.flatten() for output,id in outputs] 
# TH = np.concatenate(TH)
# index = -int(len(TH) * CFG.th_percentile)
# TH:int = np.partition(TH, index)[index]
# print(TH)
# submission_df=[]
# for mask_pred,id in outputs:
#     if not is_submit:
#         plt.subplot(121)
#         plt.imshow(mask_pred)
#         plt.subplot(122)
#         plt.imshow(mask_pred>TH)
#         plt.show()
#     mask_pred=mask_pred>TH
#     rle = rle_encode(mask_pred)
    
#     submission_df.append(
#         pd.DataFrame(data={
#             'id'  : id,
#             'rle' : rle,
#         })
#     )

# submission_df =pd.concat(submission_df)
# submission_df.to_csv('submission.csv', index=False)
# submission_df.head(6)

In [28]:
import sys, os
# sys.path.append('/kaggle/input/blood-vessel-segmentation-third-party')
# sys.path.append('/kaggle/input/blood-vessel-segmentation-00')

# from helper import *

import cv2
import pandas as pd
from glob import glob
import numpy as np

from timeit import default_timer as timer


import torch
import torch.nn as nn
import torch.nn.functional as F

import matplotlib
import matplotlib.pyplot as plt

print('IMPORT OK  !!!!')


IMPORT OK  !!!!


In [29]:
class dotdict(dict):
    """Dictionary whose keys can also be read, written and deleted as attributes.

    Reading a missing attribute yields ``None``; deleting a missing one raises
    ``KeyError``.
    """

    def __getattr__(self, attr):
        # Only invoked when normal attribute lookup fails.
        return dict.get(self, attr)

    def __setattr__(self, key, value):
        dict.__setitem__(self, key, value)

    def __delattr__(self, key):
        dict.__delitem__(self, key)

In [30]:
# Settings read by MyLoader and do_submit.
cfg = dotdict(batch_size=3, p_threshold=0.002, cc_threshold=-1)

# Run mode: the code below branches on whether it contains 'local' or 'submit'.
mode = 'local'

# Root folder of the competition data.
data_dir = './kaggle/input/blood-vessel-segmentation'

#-----
def file_to_id(f):
    """Build "<folder>_<stem>" from a path shaped ".../<folder>/<sub>/<name>.ext".

    The folder is the third path component from the end and the stem is the
    file name without its last four characters.
    """
    parts = f.split('/')
    folder, filename = parts[-3], parts[-1]
    return f"{folder}_{filename[:-4]}"

# Local mode: validate on a fixed slice range of one training folder.
if 'local' in mode:
    # Each entry is (folder name, (first slice number, stop slice number)).
    valid_folder = [
        ('kidney_3_sparse', (496, 996+1)),
        #('kidney_1_dense', (0, 1000+1)),
    ] #debug for local development
    
    valid_meta = []
    for image_folder, image_no in valid_folder:
        # File names are the zero-padded 4-digit slice numbers.
        file = [f'{data_dir}/train/{image_folder}/images/{i:04d}.tif' for i in range(*image_no)]
        # The first slice is read only to obtain the in-plane size.
        H,W = cv2.imread(file[0],cv2.IMREAD_GRAYSCALE).shape
        valid_meta.append(dotdict(
            name  = image_folder,
            file  = file,
            shape = (len(file), H, W),
            id = [file_to_id(f) for f in file],
        ))
        
# Submit mode: rebuild valid_meta from every folder found under test/.
if 'submit' in mode:
    valid_meta = []
    valid_folder = sorted(glob(f'{data_dir}/test/*'))
    for image_folder in valid_folder:
        file = sorted(glob(f'{image_folder}/images/*.tif'))
        # The first slice is read only to obtain the in-plane size.
        H, W = cv2.imread(file[0], cv2.IMREAD_GRAYSCALE).shape
        valid_meta.append(dotdict(
            name=image_folder,
            file=file,
            shape=(len(file), H, W),
            id=[file_to_id(f) for f in file],
        ))

#     glob_file = glob(f'{data_dir}/kidney_5/images/*.tif')
#     if len(glob_file)==3:
#         mode = 'submit-fake' #fake submission to save gpu time when submitting
#         #todo .....




print('len(valid_meta) :', len(valid_meta))
print(valid_meta[0].file[:3])

print('MODE OK  !!!!')

len(valid_meta) : 1
['./kaggle/input/blood-vessel-segmentation/train/kidney_3_sparse/images/0496.tif', './kaggle/input/blood-vessel-segmentation/train/kidney_3_sparse/images/0497.tif', './kaggle/input/blood-vessel-segmentation/train/kidney_3_sparse/images/0498.tif']
MODE OK  !!!!


In [31]:
# Submit mode: build valid_meta from every folder found under test/.
# NOTE(review): this repeats the 'submit' branch of the previous cell — confirm
# whether both copies are needed.
if 'submit' in mode:
    valid_meta = []
    valid_folder = sorted(glob(f'{data_dir}/test/*'))
    for image_folder in valid_folder:
        file = sorted(glob(f'{image_folder}/images/*.tif'))
        # The first slice is read only to obtain the in-plane size.
        H, W = cv2.imread(file[0], cv2.IMREAD_GRAYSCALE).shape
        valid_meta.append(dotdict(
            name=image_folder,
            file=file,
            shape=(len(file), H, W),
            id=[file_to_id(f) for f in file],
        ))

In [32]:
class MyLoader(object):
    """Indexable loader that yields batches of normalised grayscale slices.

    The files in ``meta.file`` are divided into
    ``max(1, len(meta.file) // cfg.batch_size)`` groups with
    ``np.array_split``; each item is one group as a float tensor of shape
    (n, 1, H, W).
    """

    def __init__(self, meta):
        self.meta = meta
        n_groups = max(1, int(len(meta.file)//cfg.batch_size))
        self.split = np.array_split(meta.file, n_groups)

    def __len__(self,):
        return len(self.split)

    def __getitem__(self, index):
        def read_scaled(f):
            # Per-image min-max scaling; 0.001 avoids division by zero.
            m = cv2.imread(f,cv2.IMREAD_GRAYSCALE)
            return (m - m.min())/(m.max() - m.min() +0.001)

        stacked = np.stack([read_scaled(f) for f in self.split[index]])
        return torch.from_numpy(stacked).float().unsqueeze(1)

print('DATASET OK  !!!!')

DATASET OK  !!!!


In [33]:
def make_dummy_submission():
    """Write ``submission.csv`` with the empty-mask RLE ``'1 0'`` for every id.

    Ids are taken from the module-level ``valid_meta``; the resulting frame is
    also printed.
    """
    frames = [
        pd.DataFrame(data={
            'id'  : d['id'],
            'rle' : ['1 0']*len(d['id']),
        })
        for d in valid_meta
    ]
    submission_df = pd.concat(frames)
    submission_df.to_csv('submission.csv', index=False)
    print(submission_df)
    

#https://www.kaggle.com/competitions/blood-vessel-segmentation/discussion/456033
def choose_biggest_object(mask, threshold):
    """Return a uint8 mask that keeps only the largest 8-connected component.

    ``mask`` is binarised with ``mask > threshold`` first. When several
    components share the largest area, the one with the highest label wins.
    With no foreground at all the result is all zeros.
    """
    binary = ((mask > threshold) * 255).astype(np.uint8)
    n_components, label, stats, centroid = cv2.connectedComponentsWithStats(binary, connectivity=8)

    best_label, best_area = -1, -1
    # Label 0 is the background, so the scan starts at 1.
    for candidate in range(1, n_components):
        area = stats[candidate, cv2.CC_STAT_AREA]
        if area < best_area:
            continue
        best_label, best_area = candidate, area

    return (label == best_label).astype(np.uint8)


def remove_small_objects(mask, min_size, threshold):
    """Return a uint8 mask of the 8-connected components with area >= ``min_size``.

    ``mask`` is binarised with ``mask > threshold`` first; kept pixels are 1,
    everything else 0.
    """
    binary = ((mask > threshold) * 255).astype(np.uint8)
    n_components, label, stats, centroid = cv2.connectedComponentsWithStats(binary, connectivity=8)

    # Lookup table indexed by component label: 1 = keep, 0 = drop.
    # Entry 0 (background) always stays 0.
    keep = np.zeros(n_components, dtype=np.uint8)
    keep[1:] = stats[1:, cv2.CC_STAT_AREA] >= min_size
    return keep[label]


def rle_encode(mask):
    """Encode a binary mask as space-separated (1-based start, length) pairs.

    The mask is flattened before encoding; a mask without foreground gives
    ``'1 0'``.
    """
    flat = mask.flatten()
    # Zero sentinels guarantee that every run produces two transitions.
    padded = np.concatenate([[0], flat, [0]])
    changes = padded[1:] != padded[:-1]
    run = np.nonzero(changes)[0] + 1
    # Odd positions hold run ends; subtracting the starts yields lengths.
    run[1::2] = run[1::2] - run[::2]
    pieces = [str(value) for value in run]
    if not pieces:
        return '1 0'
    return ' '.join(pieces)

#-------------------------------

In [34]:
# Build the model and wrap it in DataParallel before loading the weights, so
# the state dict is loaded into the wrapper rather than the bare model.
net=build_model()
net = DataParallel(net)
# The checkpoint tensors are mapped to CPU while loading.
# NOTE(review): presumably the checkpoint was saved from a DataParallel model
# (keys prefixed with "module.") — confirm against the training code.
net.load_state_dict(torch.load(CFG.model_path[0],"cpu"))
# Needs a CUDA-enabled PyTorch build; the recorded cell output shows this step
# failing with "Torch not compiled with CUDA enabled".
net=net.cuda()

model_name Unet
backbone se_resnet50


AssertionError: Torch not compiled with CUDA enabled

In [None]:



# from model_2cls_50 import *
# checkpoint_file = \
#     '/public/sist/home/hongmt2022/MyWorks/kaggle-bv/models/baseline-v0-2023-12-20/se_resnet50_9_loss0.01_score0.76_val_loss0.01_val_scorenan.pt'
    #'/kaggle/input/blood-vessel-segmentation-00/00001085.pth'
    #'/kaggle/input/blood-vessel-segmentation-00/00000966.pth'

# net = Net()
# #run_check_net()
# state_dict = torch.load(checkpoint_file, map_location=lambda storage, loc: storage)['state_dict']
# print(net.load_state_dict(state_dict, strict=False))  # True


#net = torch.compile(net)

"""
1. 函数`do_submit()`定义了一个提交函数。

2. 创建一个空列表`submission_df`，用于最后将所有的预测结果合并成一个DataFrame对象。

3. 通过一个for循环遍历在`valid_meta`中的每个元素（可能代表一个患者的医学影像数据）。

4. 读取该元素所关联的体积影像，体积影像由多个切片组成。

5. 将这些切片堆叠成一个三维numpy数组`volume`。

6. 获取体积影像的深度（D）、高度（H）、宽度（W）。

7. 创建一个和影像体积相同形状全零的numpy数组`predict`，用于保存预测结果。

8. 定义需要遍历的轴`axes`。

9. 通过三个不同的方向（轴0、轴1、轴2）遍历影像体积。

10. 计算每个轴的分块索引，并赋值给变量`loader`。

11. 初始化计时器`start_timer`。

12. 使用两层嵌套的for循环批量处理图片数据。

13. 打印进度信息。

14. 根据当前的轴，获取batch的子集`image`，并可能执行必要的转置操作（针对轴1或轴2）。

15. 对影像数据标准化，使其值分布在0到1之间。

16. 将numpy数组转换为PyTorch张量，并移动到GPU上。

17. 进入混合精度执行环境，通常用于辅助GPU以提高效率。

18. 在没有梯度的情况下执行预测。

19. 网络`net`对图片数据`image`做出预测，获取血管和肾脏的预测结果`vessel`和`kidney`。

20. 通过四次数据增强（水平、垂直翻转和两次旋转）来增加模型预测的可靠性，并将结果相加。

21. 将预测值归一化（即计数器`counter`）。

22. 将预测结果从GPU转移到CPU，并将其转换回numpy格式。

23. 用预测结果更新`predict`数组。

24. 如果配置允许，使用`choose_biggest_object`函数选择超过某个阈值的最大物体。

25. 对局部模式（可能是调试模式）下的首个批量结果进行可视化显示。

26. 更新处理过的批量大小`B`。

27. 综合所有轴方向的预测结果，并应用一个阈值来二值化预测结果。

28. 如果配置指定了连通性分量阈值，则移除体积小于此阈值的连通组件。

29. 使用RLE编码压缩每个预测，并将结果保存到DataFrame中。

30. 将每个患者的结果添加到`submission_df`列表中。

31. 将结果列表合并为一个DataFrame，并将其保存为csv文件。

32. 打印最终的提交DataFrame。
"""
def do_submit():
    """Predict every volume in ``valid_meta`` and write ``submission.csv``.

    For each volume the slices are read as grayscale and stacked to (D, H, W).
    The volume is then sliced along each of the three axes; every chunk of
    slices is min-max normalised, run through ``net`` six times (identity, two
    flips, three rotations) and the outputs are averaged. The vessel output is
    multiplied by the largest connected component of the thresholded kidney
    output and accumulated into ``predict``. Finally the accumulated volume is
    divided by the number of axes, thresholded with ``cfg.p_threshold``,
    RLE-encoded slice by slice along axis 0 and saved together with the ids.

    NOTE(review): ``net(image)`` is unpacked into two values ``(v, k)`` and
    indexed as ``[b, 0]``; ``CustomModel.forward`` in this file returns a single
    tensor, so this presumably expects a different two-output model — confirm.
    """
    
    submission_df = []
    for d in valid_meta:
        volume = [cv2.imread(f, cv2.IMREAD_GRAYSCALE) for f in d.file]
        volume = np.stack(volume)
        D, H, W = volume.shape
        
        # Accumulator for the per-axis predictions (half precision).
        predict = np.zeros(d.shape, dtype=np.float16)
        axes = [0,1,2] #[2]  # 
        for axis in axes:  # 0
            # Index chunks along the current axis; the number of chunks is
            # (axis length // cfg.batch_size), at least 1.
            loader = np.array_split(np.arange((D, H, W)[axis]), max(1, int((D, H, W)[axis] // cfg.batch_size)))
            num_valid = len(loader)
            
            # Running offset of the first slice of the current chunk.
            B = 0 
            start_timer = timer()
            for t in range(num_valid):
                # print(f'\r validation: {t}/{num_valid}', time_to_str(timer() - start_timer, 'min'), end='', flush=True)

                # Bring the sliced axis to the front so each chunk is (n, h, w).
                if axis == 0:
                    image = volume[loader[t].tolist()]
                if axis == 1:
                    image = volume[:, loader[t].tolist()]
                    image = image.transpose(1, 0, 2)
                if axis == 2:
                    image = volume[:, :, loader[t].tolist()]
                    image = image.transpose(2, 0, 1)

                batch_size, bh, bw = image.shape
                m = image.reshape(batch_size, -1)
                # min()/max() are called without an axis, so the statistics
                # are taken over the whole chunk, not per slice.
                # NOTE(review): MyLoader normalises per image — confirm which
                # one matches training.
                m = (m - m.min(keepdims=True)) / (m.max(keepdims=True) - m.min(keepdims=True) + 0.001)
                m = m.reshape(batch_size, bh, bw)
                m = np.ascontiguousarray(m)
                image = torch.from_numpy(m).float().cuda().unsqueeze(1)

                #----
                # Test-time augmentation: sum six predictions, each mapped
                # back to the original orientation, then divide by counter.
                counter = 0
                vessel, kidney = 0, 0
                image = image.cuda() 
                with torch.cuda.amp.autocast(enabled=True):
                    with torch.no_grad():
                        v, k = net(image)
                        vessel += v
                        kidney += k
                        counter += 1

                        v, k = net(torch.flip(image, dims=[2,]))
                        vessel += torch.flip(v, dims=[2,])
                        kidney += torch.flip(k, dims=[2,])
                        counter += 1

                        v, k = net(torch.flip(image, dims=[3,]))
                        vessel += torch.flip(v, dims=[3,])
                        kidney += torch.flip(k, dims=[3,])
                        counter += 1

                        v, k = net(torch.rot90(image, k=1, dims=[2,3]))
                        vessel += torch.rot90(v, k=-1, dims=[2,3])
                        kidney += torch.rot90(k, k=-1, dims=[2,3])
                        counter += 1

                        v, k = net(torch.rot90(image, k=2, dims=[2,3]))
                        vessel += torch.rot90(v, k=-2, dims=[2,3])
                        kidney += torch.rot90(k, k=-2, dims=[2,3])
                        counter += 1

                        v, k = net(torch.rot90(image, k=3, dims=[2,3]))
                        vessel += torch.rot90(v, k=-3, dims=[2,3])
                        kidney += torch.rot90(k, k=-3, dims=[2,3])
                        counter += 1

                vessel = vessel/counter   
                kidney = kidney/counter      
                #print(i, image.shape, mask.shape) 

                vessel = vessel.float().data.cpu().numpy()
                kidney = kidney.float().data.cpu().numpy()

                # ----------------------------------------
                batch_size = len(vessel)
                for b in range(batch_size):
                    # Keep the vessel output only inside the largest kidney
                    # component (kidney output thresholded at 0.5).
                    mk = kidney[b, 0]
                    mk = choose_biggest_object(mk, threshold=0.5) 
                    mv = vessel[b, 0]
                    p = (mv * mk)
                    # Write the slice back along the axis it was taken from.
                    if axis == 0:
                        predict[B + b] += p
                    if axis == 1:
                        predict[:, B + b] += p
                    if axis == 2:
                        predict[:, :, B + b] += p

                    #debug only
                    # Shows input and prediction side by side for the first
                    # chunk of each axis when running locally.
                    if (t==0) and (mode=='local'): 
                  
                        m = image[b, 0].float().data.cpu().numpy()
                        #p = predict[B+b]

                        plt.imshow(np.hstack([m,p]),cmap='gray')
                        plt.show()
                        #plt.waitforbuttonpress()

                #----------------------------------------
                B += batch_size

        print('')
        # Average over the axes, then binarise.
        predict = predict / len(axes)
        predict = (predict>cfg.p_threshold).astype(np.uint8)

        #post processing ---
#         if cfg.cc_threshold>0:
#             predict = cc3d.dust(
#                 predict,
#                 connectivity=26,
#                 threshold=cfg.cc_threshold,
#                 in_place=False
#             )

        # One RLE string per slice along axis 0, paired with d['id'].
        rle = [rle_encode(p) for p in predict]
        submission_df.append(
            pd.DataFrame(data={
                'id'  : d['id'],
                'rle' : rle,
            })
        )

    submission_df =pd.concat(submission_df)
    submission_df.to_csv('submission.csv', index=False)
    print(submission_df)
    

# Entry point: when mode is exactly 'submit' and the test folder kidney_5 holds
# exactly 3 images, only a placeholder submission is written; otherwise the
# full prediction pipeline runs.
glob_file = glob(f'{data_dir}/test/kidney_5/images/*.tif')
if (mode=='submit') and (len(glob_file)==3): #cannot do 3d cnn because too few test files
    make_dummy_submission()
else:
    do_submit()