In [1]:
import torch
import numpy as np
import torch.nn as nn
import numpy as np
from tqdm import tqdm
import cv2
from glob import glob
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader

In [2]:
class CFG:
    """Hard-coded settings read by the helpers and cells of this notebook."""

    # -------- model --------
    model_name = "Unet"
    backbone = "se_resnext50_32x4d"
    # Number of consecutive slices stacked as input channels (an earlier run used 65).
    in_chans = 5
    target_size = 1

    # -------- tiling --------
    image_size = 512
    input_size = 512
    tile_size = image_size
    stride = tile_size // 4  # = 128 with the sizes above
    # (sic) attribute name kept as-is for compatibility; not read by the visible cells.
    drop_egde_pixel = 32

    # -------- pre-/post-processing --------
    # Fraction of voxels clipped at each tail of the intensity distribution in load_data.
    chopping_percentile = 1e-3
    # Top fraction of predicted values used to derive a binarization threshold.
    th_percentile = 0.0021

    # -------- validation / inference --------
    valid_id = 1
    batch = 128
    model_path = ["data/se_resnext50_32x4d_19_loss0.12_score0.79_val_loss0.25_val_score0.79.pt"]

In [3]:
def rle_encode(mask):
    """Run-length encode a mask in flattened (row-major) order.

    Any non-zero value counts as foreground, so boolean, {0, 1} and {0, 255}
    masks encode identically.

    Returns:
        ``"start length start length ..."`` with 1-based starts, or ``"1 0"``
        when the mask contains no foreground.
    """
    # Binarize before looking for transitions: with more than two distinct
    # values the number of transitions can be odd, and the pairing step
    # ``run[1::2] -= run[::2]`` would fail with a shape mismatch.
    pixel = np.asarray(mask).flatten() != 0
    # Pad with background on both sides so every run has a start and an end.
    pixel = np.concatenate([[0], pixel, [0]])
    # Indices where the value changes; +1 makes starts 1-based.
    run = np.where(pixel[1:] != pixel[:-1])[0] + 1
    run[1::2] -= run[::2]  # turn end positions into run lengths
    rle = " ".join(str(r) for r in run)
    if rle == "":
        rle = "1 0"
    return rle


def min_max_normalization(x: torch.Tensor) -> torch.Tensor:
    """Rescale every sample of ``x`` to [0, 1] independently.

    input.shape=(batch, f1, ...). The min and max are taken over all non-batch
    elements of each sample. The result has the same shape as ``x``. Integer
    inputs come back as floating point (true division).
    """
    shape = x.shape
    if x.ndim > 2:
        x = x.reshape(x.shape[0], -1)

    min_ = x.min(dim=-1, keepdim=True)[0]
    max_ = x.max(dim=-1, keepdim=True)[0]
    # Shortcut only when *every* sample already spans exactly [0, 1]. Comparing
    # batch means (the previous check) let e.g. mins (-1, 1) / maxes (0.5, 1.5)
    # through un-normalized, and .mean() crashes on integer tensors.
    if bool((min_ == 0).all()) and bool((max_ == 1).all()):
        return x.reshape(shape)

    # The epsilon keeps constant samples (max == min) from dividing by zero.
    x = (x - min_) / (max_ - min_ + 1e-9)
    return x.reshape(shape)


def norm_with_clip(x: torch.Tensor, smooth=1e-5):
    """Standardize each sample, then softly compress the tails.

    Mean and std are computed over every dim except the first. After
    standardization, values above 5 become ``5 + (v - 5) * 1e-3`` and values
    below -3 become ``-3 + (v + 3) * 1e-3``. Outliers keep their order but only
    barely exceed the bounds. A new tensor is returned; ``x`` is not modified.
    """
    reduce_dims = list(range(1, x.ndim))
    centered = x - x.mean(dim=reduce_dims, keepdim=True)
    z = centered / (x.std(dim=reduce_dims, keepdim=True) + smooth)
    z = torch.where(z > 5, (z - 5) * 1e-3 + 5, z)
    z = torch.where(z < -3, (z + 3) * 1e-3 - 3, z)
    return z


class Data_loader(Dataset):
    """Dataset over the sorted ``*.tif`` files matching ``path + s + "*.tif"``.

    Each item is one slice read with ``cv2.IMREAD_GRAYSCALE`` and returned as a
    torch tensor. The tensor is ``torch.bool`` when ``s == "/labels/"`` (any
    non-zero pixel becomes True) and ``torch.uint8`` otherwise.
    """

    def __init__(self, path, s="/images/"):
        self.paths = sorted(glob(path + f"{s}*.tif"))
        # Label folders are binarized; anything else is treated as image data.
        self.bool = s == "/labels/"

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        img_path = self.paths[index]
        img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        # cv2.imread signals failure by returning None instead of raising;
        # without this check the error surfaces as an opaque TypeError below.
        if img is None:
            raise FileNotFoundError(f"could not read image: {img_path}")
        img = torch.from_numpy(img)
        return img.to(torch.bool) if self.bool else img.to(torch.uint8)


def _clip_to_percentiles(x: torch.Tensor, percentile: float) -> torch.Tensor:
    """Clamp ``x`` in place to its upper and lower ``percentile`` tails; returns ``x``."""
    flat = x.reshape(-1).numpy()
    if flat.size == 0:
        return x
    k = int(len(flat) * percentile)
    # Upper threshold: the k-th largest value. The max(k, 1) guard matters for
    # small inputs: with k == 0 the old index -0 == 0 selected the *minimum*
    # and flattened the whole volume onto it.
    upper_pos = len(flat) - max(k, 1)
    upper = np.partition(flat, upper_pos)[upper_pos]
    x[x > upper] = int(upper)
    # Lower threshold: the value at 0-based sorted position k of the
    # (already top-clipped) data; k == 0 is the minimum, i.e. a no-op.
    flat = x.reshape(-1).numpy()
    lower = np.partition(flat, k)[k]
    x[x < lower] = int(lower)
    return x


def load_data(path, s):
    """Load every ``*.tif`` slice under ``path + s`` and clip intensity outliers.

    Slices are read through ``Data_loader`` and stacked along dim 0. The result
    is then clamped at both tails by ``CFG.chopping_percentile`` (modified in
    place and returned).

    Raises:
        FileNotFoundError: if no ``*.tif`` file matches ``path + s``.
    """
    dataset = Data_loader(path, s)
    loader = DataLoader(dataset, batch_size=16, num_workers=2)
    batches = [batch for batch in tqdm(loader)]
    if not batches:
        raise FileNotFoundError(f"no .tif files found under {path + s}")
    x = torch.cat(batches, dim=0)
    return _clip_to_percentiles(x, CFG.chopping_percentile)


class Pipeline_Dataset(Dataset):
    """Sliding window of ``CFG.in_chans`` consecutive slices along dim 0 of ``x``.

    ``x`` is zero-padded with ``in_chans // 2`` slices at both ends, and item
    ``i`` is ``(padded[i : i + in_chans], i)``. ``img_paths`` (sorted files
    under ``path + "/images/"``) is only used to build per-slice ids.
    """

    def __init__(self, x, path):
        self.img_paths = sorted(glob(path + "/images/*"))
        self.in_chan = CFG.in_chans
        pad = torch.zeros(self.in_chan // 2, *x.shape[1:], dtype=x.dtype)
        self.x = torch.cat((pad, x, pad), dim=0)

    def __len__(self):
        # Number of complete windows that fit into the padded volume.
        return self.x.shape[0] - self.in_chan + 1

    def __getitem__(self, index):
        window = self.x[index : index + self.in_chan]
        return window, index

    def get_mark(self, index):
        """Build an id from the last three path parts, dropping the middle one.

        e.g. ``.../kidney_1_dense/images/0000.tif`` -> ``kidney_1_dense_0000``.
        """
        parts = self.img_paths[index].split("/")[-3:]
        del parts[1]  # drop the middle directory (e.g. "images")
        return "_".join(parts)[:-4]  # strip a 4-character extension such as ".tif"

    def get_marks(self):
        """Return ``get_mark(i)`` for every ``i`` in ``range(len(self))``."""
        return [self.get_mark(i) for i in range(len(self))]


def add_edge(x: torch.Tensor, edge: int):
    """Pad a (C, H, W) tensor by ``edge`` pixels on every side.

    The fill value is the mean of ``x`` truncated to an int. The output has
    shape (C, H + 2*edge, W + 2*edge). Bands are appended bottom, right, top,
    left, in that order.
    """
    fill_value = int(x.to(torch.float32).mean())

    def _band(ref: torch.Tensor, rows: int, cols: int) -> torch.Tensor:
        # Constant block matching ref's channel count, dtype and device.
        return torch.ones([ref.shape[0], rows, cols], dtype=ref.dtype, device=ref.device) * fill_value

    padded = torch.cat([x, _band(x, edge, x.shape[2])], dim=1)
    padded = torch.cat([padded, _band(padded, padded.shape[1], edge)], dim=2)
    padded = torch.cat([_band(padded, edge, padded.shape[2]), padded], dim=1)
    return torch.cat([_band(padded, padded.shape[1], edge), padded], dim=2)

In [4]:
# Load the saved predictions; later cells iterate over `output` and index output[0].
# NOTE(review): torch.load unpickles the file, so only load trusted files
# (recent PyTorch versions support weights_only=True for plain tensors).
output = [torch.load("output.pt")]
path = "data/blood-vessel-segmentation/train/kidney_1_dense"
# Image volume stacked from the slices, percentile-clipped by load_data.
x = load_data(path, "/images/")
# Per-slice ids built by Pipeline_Dataset.get_mark (presumably "kidney_1_dense_XXXX").
# NOTE(review): neither `x` nor `ids` is used by the cells below — confirm they are still needed.
ids = Pipeline_Dataset(x, path).get_marks()

100%|██████████| 143/143 [00:06<00:00, 21.09it/s]


In [5]:
# Pool the predicted values of every output volume into one flat array.
TH = np.concatenate([out.flatten().numpy() for out in output])
# Threshold = the k-th largest value, k = int(n * CFG.th_percentile).
# max(..., 1) guards small inputs: with k == 0 the index would be -0 == 0,
# which selects the *minimum* value instead of the maximum.
index = -max(int(len(TH) * CFG.th_percentile), 1)
TH: int = np.partition(TH, index)[index]
# NOTE(review): the next cell binarizes with a hard-coded 128, not this TH — confirm which is intended.
print(TH)

255


In [6]:
# Binarize the prediction volume: values >= 128 become 1, everything else 0 (uint8).
# NOTE(review): the percentile threshold TH computed in In [5] is not used here — confirm 128 is intended.
pred = (output[0] >= 128).byte()

In [7]:
from PIL import Image
class KidneyDataset(torch.utils.data.Dataset):
    """Pair each predicted slice with its ground-truth label mask.

    Item ``idx`` is ``(pred[idx] as float32, mask)``. The mask is read from
    ``msks_dir + f"{idx:04}.tif"`` and divided by 255, so it assumes labels
    are stored as {0, 255} (TODO confirm for other datasets).
    """

    def __init__(self, pred, msks_dir):
        self.pred = pred.float()
        self.msks_dir = msks_dir

    def __len__(self):
        return len(self.pred)

    def __getitem__(self, idx):
        msk_path = self.msks_dir + f"{idx:04}.tif"

        # Close the file deterministically. Pillow leaves multi-frame-capable
        # formats such as TIFF open after loading until the image object is
        # garbage-collected.
        with Image.open(msk_path) as img:
            msk = np.array(img)
        slice_pred = self.pred[idx]

        msk = torch.as_tensor(msk, dtype=torch.float32)
        msk /= 255  # {0, 255} -> {0, 1}

        return slice_pred, msk

In [8]:
# Pair every predicted slice with its label mask. shuffle=False keeps slices in
# volume order, which the surface-dice metric presumably relies on (TODO confirm).
ds = KidneyDataset(pred, "data/blood-vessel-segmentation/train/kidney_1_dense/labels/")
dl = torch.utils.data.DataLoader(ds, batch_size=5, shuffle=False, num_workers=8)


In [9]:
from surface_dice import SurfaceDiceMetric
# The metric receives the number of batches (len(dl)) as its first argument.
# NOTE(review): "cuda:1" is hard-coded here and in the evaluation loop — consider a single DEVICE constant.
metric = SurfaceDiceMetric(len(dl), device="cuda:1")

In [10]:
# Feed the metric one batch at a time. Distinct loop names keep the global
# `pred` volume intact, so re-running the dataset cell afterwards still sees
# every slice rather than only the last batch.
# NOTE(review): re-running this loop without re-creating `metric` presumably
# feeds the batches twice — confirm against SurfaceDiceMetric.
for batch_pred, batch_target in tqdm(dl):
    batch_pred, batch_target = batch_pred.to("cuda:1"), batch_target.to("cuda:1")
    metric.process_batch(batch_pred, batch_target)

100%|██████████| 456/456 [00:20<00:00, 22.24it/s]


In [11]:
# Report the metric after every batch of `dl` has been passed to process_batch.
metric.compute()

0.630912184715271