In [None]:
#!pip install connected-components-3d

import cc3d

In [None]:
import sys, os
sys.path.append('/kaggle/input/blood-vessel-segmentation-third-party')
sys.path.append('/kaggle/input/blood-vessel-segmentation-00')

from helper import *

import cv2
import pandas as pd
from glob import glob
import numpy as np

from timeit import default_timer as timer


import torch
import torch.nn as nn
import torch.nn.functional as F

import matplotlib
import matplotlib.pyplot as plt

In [None]:
class dotdict(dict):
    """Dictionary whose keys are also reachable with attribute syntax.

    * ``d.key`` looks the key up with ``dict.get`` semantics, so a missing
      key yields ``None`` instead of raising.
    * ``d.key = value`` stores the item under ``'key'``.
    * ``del d.key`` removes the item and raises ``KeyError`` when absent.
    """

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails.
        return dict.get(self, name)

    def __setattr__(self, name, value):
        dict.__setitem__(self, name, value)

    def __delattr__(self, name):
        dict.__delitem__(self, name)

In [None]:
# Inference settings, read elsewhere as attributes (cfg.batch_size, ...).
cfg = dotdict({
    # Slice counts are integer-divided by this to decide how many batches to make.
    'batch_size': 3,
    # Averaged predictions strictly above this value become foreground.
    'p_threshold': 0.10,
    # Passed to cc3d.dust as `threshold`; the clean-up runs only when this is > 0.
    'cc_threshold': -1,
})

# 'submit' reads every folder under <data_dir>/test; 'local' reads a fixed
# range of slices from <data_dir>/train for debugging.
mode = 'submit'

data_dir = '/kaggle/input/blood-vessel-segmentation'


def file_to_id(f):
    """Build a slice id of the form ``<folder>_<stem>`` from an image path.

    The folder is the third-from-last '/'-separated component of ``f`` and the
    stem is the last component with its final four characters removed
    (e.g. ``.../kidney_5/images/0001.tif`` -> ``kidney_5_0001``).
    """
    parts = f.split('/')
    folder, filename = parts[-3], parts[-1]
    return f'{folder}_{filename[:-4]}'

def _make_meta(name, file):
    """Describe one scan as a dotdict record.

    Args:
        name: label stored in the record's ``name`` field.
        file: ordered list of slice image paths.

    Returns:
        dotdict with ``name``, ``file``, ``shape`` (number of slices plus the
        height/width of the first slice) and ``id`` (one id per slice).

    Raises:
        ValueError: if ``file`` is empty.
        FileNotFoundError: if the first slice cannot be read.
    """
    if len(file) == 0:
        raise ValueError(f'no image files found for {name}')
    # Only the first slice is read; its size is taken as the size of every slice.
    first = cv2.imread(file[0], cv2.IMREAD_GRAYSCALE)
    if first is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError(f'cannot read image: {file[0]}')
    H, W = first.shape
    return dotdict(
        name=name,
        file=file,
        shape=(len(file), H, W),
        id=[file_to_id(f) for f in file],
    )


# Fail early with a clear message instead of a NameError on valid_meta below.
if ('local' not in mode) and ('submit' not in mode):
    raise ValueError(f"mode must contain 'local' or 'submit', got {mode!r}")

if 'local' in mode:
    # (folder under <data_dir>/train, (first slice number, stop slice number))
    valid_folder = [
        ('kidney_3_sparse', (496, 996+1)),
        #('kidney_1_dense', (0, 1000+1)),
    ] #debug for local development

    valid_meta = [
        _make_meta(
            image_folder,
            [f'{data_dir}/train/{image_folder}/images/{i:04d}.tif' for i in range(*image_no)],
        )
        for image_folder, image_no in valid_folder
    ]

if 'submit' in mode:
    # Every folder under <data_dir>/test is one scan; the record name is the folder path.
    valid_folder = sorted(glob(f'{data_dir}/test/*'))
    valid_meta = [
        _make_meta(image_folder, sorted(glob(f'{image_folder}/images/*.tif')))
        for image_folder in valid_folder
    ]


print('len(valid_meta) :', len(valid_meta))
if valid_meta:
    print(valid_meta[0].file[:3])

In [None]:
class MyLoader(object):
    """Indexable source of normalized image batches for one scan.

    ``meta.file`` (a list of image paths) is divided into
    ``max(1, len(meta.file) // cfg.batch_size)`` nearly equal groups;
    indexing the loader reads and returns one group.
    """

    def __init__(self, meta):
        num_split = max(1, int(len(meta.file) // cfg.batch_size))
        self.meta = meta
        self.split = np.array_split(meta.file, num_split)

    def __len__(self):
        """Number of batches."""
        return len(self.split)

    def __getitem__(self, index):
        """Read batch ``index`` and return it as a float tensor of shape (n, 1, H, W)."""
        batch = []
        for path in self.split[index]:
            gray = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
            # Min-max scale each slice on its own; 0.001 avoids division by zero
            # on a constant image.
            lo, hi = gray.min(), gray.max()
            batch.append((gray - lo) / (hi - lo + 0.001))

        return torch.from_numpy(np.stack(batch)).float().unsqueeze(1)

In [None]:
def make_dummy_submission():
    """Write a placeholder ``submission.csv`` and print it.

    Every slice id found in the module-level ``valid_meta`` gets the
    run-length string ``'1 0'``, the same value ``rle_encode`` emits for an
    empty mask.
    """
    frames = [
        pd.DataFrame(data={
            'id'  : meta['id'],
            'rle' : ['1 0'] * len(meta['id']),
        })
        for meta in valid_meta
    ]
    submission_df = pd.concat(frames)
    submission_df.to_csv('submission.csv', index=False)
    print(submission_df)
    

#https://www.kaggle.com/competitions/blood-vessel-segmentation/discussion/456033
def choose_biggest_object(mask, threshold):
    """Keep only the largest 8-connected component of a thresholded 2-D map.

    Args:
        mask: 2-D array of scores.
        threshold: values strictly greater than this count as foreground.

    Returns:
        uint8 array, same shape as ``mask``, with 1 on the largest component
        and 0 elsewhere (all zeros when nothing exceeds the threshold). When
        several components share the largest area, the one with the highest
        label wins.
    """
    binary = ((mask > threshold) * 255).astype(np.uint8)
    num_label, label, stats, _ = cv2.connectedComponentsWithStats(binary, connectivity=8)

    # Row 0 of `stats` is skipped: only labels 1..num_label-1 are candidates.
    area = stats[1:, cv2.CC_STAT_AREA]
    if len(area) == 0:
        winner = -1  # matches no label, so the result is all zeros
    else:
        # argmax on the reversed areas finds the LAST maximum in label order.
        winner = num_label - 1 - int(np.argmax(area[::-1]))

    return (label == winner).astype(np.uint8)


def remove_small_objects(mask, min_size, threshold):
    """Drop 8-connected components smaller than ``min_size`` pixels.

    Args:
        mask: 2-D array of scores.
        min_size: smallest component area (in pixels) that is kept.
        threshold: values strictly greater than this count as foreground.

    Returns:
        uint8 array, same shape as ``mask``, with 1 on every kept component
        and 0 elsewhere.
    """
    binary = ((mask > threshold) * 255).astype(np.uint8)
    num_label, label, stats, _ = cv2.connectedComponentsWithStats(binary, connectivity=8)

    # Per-label keep/drop table, indexed by label value.
    keep = stats[:, cv2.CC_STAT_AREA] >= min_size
    keep[0] = False  # label 0 is never kept, whatever its area

    return keep[label].astype(np.uint8)


def rle_encode(mask):
    """Run-length encode a binary mask as ``'start length start length ...'``.

    The mask is flattened in row-major order; starts are 1-based positions in
    that flattened sequence. A mask without any foreground encodes as ``'1 0'``.
    """
    # Zero padding on both sides makes every run begin and end with a value change.
    padded = np.concatenate([[0], mask.flatten(), [0]])
    change = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    # Even slots are run starts, odd slots are run ends: turn ends into lengths.
    change[1::2] -= change[::2]
    encoded = ' '.join(map(str, change))
    return encoded if encoded else '1 0'

In [None]:
checkpoint_file = '/kaggle/input/blood-vessel-segmentation/trained_unet.pth'

# `Net` is not defined in this notebook; presumably it comes from
# `from helper import *` (verify against helper.py).
net = Net()

# The map_location callback returns each storage unchanged, so it is not moved
# to the device it was saved from. The loaded object is used directly as the
# state dict (no ['state_dict'] lookup).
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
state_dict = torch.load(checkpoint_file, map_location=lambda storage, loc: storage)

# strict=False tolerates missing/unexpected keys; the printed result lists them.
print(net.load_state_dict(state_dict, strict=False))

# Inference mode, on the GPU. torch.compile(net) was tried here and left disabled.
net = net.eval().cuda()

In [None]:
def _tta_pairs():
    """Return the (forward, inverse) transform pairs used for test-time augmentation.

    Every transform acts on dimensions 2 and 3 of a 4-D tensor. The list order
    is the order in which the predictions are accumulated: identity, flip of
    dim 2, flip of dim 3, then rotations by 1, 2 and 3 quarter turns.
    """
    pairs = [(lambda x: x, lambda x: x)]  # un-augmented pass
    for dim in (2, 3):
        # A flip is its own inverse.
        flip = lambda x, dim=dim: torch.flip(x, dims=[dim])
        pairs.append((flip, flip))
    for k in (1, 2, 3):
        pairs.append((
            lambda x, k=k: torch.rot90(x, k=k, dims=[2, 3]),
            lambda x, k=k: torch.rot90(x, k=-k, dims=[2, 3]),
        ))
    return pairs


def _predict_tta(image):
    """Average the outputs of the module-level `net` over all TTA transforms.

    Args:
        image: 4-D CUDA tensor (batch, 1, h, w).

    Returns:
        (vessel, kidney): float32 numpy arrays holding the mean of the two
        network outputs after each has been mapped back to the input
        orientation.
    """
    pairs = _tta_pairs()
    vessel, kidney = 0, 0
    with torch.cuda.amp.autocast(enabled=True):
        with torch.no_grad():
            for forward, inverse in pairs:
                v, k = net(forward(image))
                vessel += inverse(v)
                kidney += inverse(k)

    vessel = vessel / len(pairs)
    kidney = kidney / len(pairs)
    return (
        vessel.float().detach().cpu().numpy(),
        kidney.float().detach().cpu().numpy(),
    )


def do_submit():
    """Predict every scan in `valid_meta` and write ``submission.csv``.

    For each scan the whole volume is loaded, then sliced along each of its
    three axes in turn. Every batch of slices is normalized, run through the
    network with test-time augmentation, and the vessel output (masked by the
    largest connected component of the kidney output thresholded at 0.5) is
    accumulated into a volume-sized buffer. The three per-axis results are
    averaged, binarized with ``cfg.p_threshold``, optionally cleaned with
    ``cc3d.dust`` (only when ``cfg.cc_threshold > 0``) and run-length encoded
    one slice of the first axis at a time.

    Reads the module-level `valid_meta`, `cfg`, `mode` and `net`; the result
    is also printed.
    """
    submission_df = []
    for d in valid_meta:
        volume = np.stack([cv2.imread(f, cv2.IMREAD_GRAYSCALE) for f in d.file])

        predict = np.zeros(d.shape, dtype=np.float16)
        axes = [0, 1, 2]
        for axis in axes:
            # Views with the slicing axis moved to the front: indexing the first
            # dimension then reads/writes whole slices perpendicular to `axis`.
            volume_view = np.moveaxis(volume, axis, 0)
            predict_view = np.moveaxis(predict, axis, 0)

            length = volume.shape[axis]
            loader = np.array_split(np.arange(length), max(1, int(length // cfg.batch_size)))
            num_valid = len(loader)

            B = 0  # position along `axis` of the first slice of the current batch
            start_timer = timer()
            for t in range(num_valid):
                print(f'\r validation: {t}/{num_valid}', timer() - start_timer, end='', flush=True)

                image = volume_view[loader[t].tolist()]

                # Min-max scale to roughly [0, 1).
                # NOTE(review): min/max are taken over the whole batch (no axis
                # given), not per slice, although the reshape suggests per-slice
                # scaling may have been intended. Kept as is so the network
                # input is unchanged -- confirm against the training pipeline.
                batch_size, bh, bw = image.shape
                m = image.reshape(batch_size, -1)
                m = (m - m.min(keepdims=True)) / (m.max(keepdims=True) - m.min(keepdims=True) + 0.001)
                m = m.reshape(batch_size, bh, bw)
                m = np.ascontiguousarray(m)
                image = torch.from_numpy(m).float().cuda().unsqueeze(1)

                vessel, kidney = _predict_tta(image)

                # ----------------------------------------
                batch_size = len(vessel)
                for b in range(batch_size):
                    # Keep vessel scores only inside the largest kidney component.
                    mk = choose_biggest_object(kidney[b, 0], threshold=0.5)
                    p = vessel[b, 0] * mk
                    predict_view[B + b] += p  # writes through to `predict`

                    #debug only
                    if (t == 0) and (mode == 'local'):
                        shown = image[b, 0].float().detach().cpu().numpy()
                        plt.imshow(np.hstack([shown, p]), cmap='gray')
                        plt.show()

                #----------------------------------------
                B += batch_size

        print('')
        predict = predict / len(axes)
        predict = (predict > cfg.p_threshold).astype(np.uint8)

        #post processing ---
        if cfg.cc_threshold > 0:
            predict = cc3d.dust(
                predict,
                connectivity=26,
                threshold=cfg.cc_threshold,
                in_place=False
            )

        rle = [rle_encode(p) for p in predict]
        submission_df.append(
            pd.DataFrame(data={
                'id'  : d['id'],
                'rle' : rle,
            })
        )

    submission_df = pd.concat(submission_df)
    submission_df.to_csv('submission.csv', index=False)
    print(submission_df)
    

# When submitting with exactly three kidney_5 test slices there are too few
# files for 3d inference, so only a placeholder submission is written.
# Every other case runs the full prediction.
glob_file = glob(f'{data_dir}/test/kidney_5/images/*.tif')
if (mode != 'submit') or (len(glob_file) != 3):
    do_submit()
else:
    make_dummy_submission()