In [1]:
import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm
from torch.cuda.amp import autocast
import cv2
import os, sys
from glob import glob
import matplotlib.pyplot as plt
import pandas as pd
import segmentation_models_pytorch as smp
from torch.utils.data import Dataset, DataLoader
from torch.nn.parallel import DataParallel
from dotenv import load_dotenv

1. Load the whole volume into a tensor.
2. Add neccesary top and bottom padding to the volume so that the model can predict on stacks of 5 slices. Remember that the model takes a stack of 5 slices and predicts the mask of the slice in the middle of the stack. So with stack [0, 5] it predicts on [2], with stack [1, 6] it predicts on 3 and so on.
3. Now take a stack of 5 slices, this stack needs to be splitted into a bunch of small patches so that the data can fit in the accelerators memory. Make sure to pad the height and width of the stack so that there is no loss of information in the patches extraction.

In [2]:
class CFG:
    # ============== model CFG =============
    model_name = "Unet"
    backbone = "se_resnext50_32x4d"

    in_chans = 5
    # ============== _ CFG =============
    image_size = 512
    input_size = 512
    tile_size = image_size
    stride = tile_size // 4
    drop_egde_pixel = 32

    target_size = 1
    chopping_percentile = 1e-3
    # ============== fold =============
    batch = 128
    th_percentile = 0.0021
    model_path = ["/kaggle/input/sennet-yoyobar-v6-model/se_resnext50_32x4d_19_loss0.12_score0.79_val_loss0.25_val_score0.79.pt"]

In [3]:
class CustomModel(nn.Module):
    def __init__(self, CFG, weight=None):
        super().__init__()
        self.CFG = CFG
        self.model = smp.Unet(
            encoder_name=CFG.backbone,
            encoder_weights=weight,
            in_channels=CFG.in_chans,
            classes=CFG.target_size,
            activation=None,
        )
        self.batch = CFG.batch

    def forward_(self, image):
        output = self.model(image)
        return output[:, 0]

    def forward(self, x: torch.Tensor):
        x = x.to(torch.float32)  # (bs, c, h, w)
        x = norm_with_clip(x.reshape(-1, *x.shape[2:])).reshape(x.shape)

        with torch.no_grad():
            x = self.forward(x)
        x = x.sigmoid()

        return x


def build_model(weight=None):
    load_dotenv()

    print("model_name", CFG.model_name)
    print("backbone", CFG.backbone)

    model = CustomModel(CFG, weight)

    return model.cuda()

In [4]:
def rle_encode(mask):
    pixel = mask.flatten()
    pixel = np.concatenate([[0], pixel, [0]])
    run = np.where(pixel[1:] != pixel[:-1])[0] + 1
    run[1::2] -= run[::2]
    rle = " ".join(str(r) for r in run)
    if rle == "":
        rle = "1 0"
    return rle


def min_max_normalization(x: torch.Tensor) -> torch.Tensor:
    """input.shape=(batch,f1,...)"""
    shape = x.shape
    if x.ndim > 2:
        x = x.reshape(x.shape[0], -1)

    min_ = x.min(dim=-1, keepdim=True)[0]
    max_ = x.max(dim=-1, keepdim=True)[0]
    if min_.mean() == 0 and max_.mean() == 1:
        return x.reshape(shape)

    x = (x - min_) / (max_ - min_ + 1e-9)
    return x.reshape(shape)


def norm_with_clip(x: torch.Tensor, smooth=1e-5):
    dim = list(range(1, x.ndim))
    mean = x.mean(dim=dim, keepdim=True)
    std = x.std(dim=dim, keepdim=True)
    x = (x - mean) / (std + smooth)
    x[x > 5] = (x[x > 5] - 5) * 1e-3 + 5
    x[x < -3] = (x[x < -3] + 3) * 1e-3 - 3
    return x


class Data_loader(Dataset):
    def __init__(self, path, s="/images/"):
        self.paths = glob(path + f"{s}*.tif")
        self.paths.sort()
        self.bool = s == "/labels/"

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        img = cv2.imread(self.paths[index], cv2.IMREAD_GRAYSCALE)
        img = torch.from_numpy(img)
        if self.bool:
            img = img.to(torch.bool)
        else:
            img = img.to(torch.uint8)
        return img


def load_data(path, s):
    data_loader = Data_loader(path, s)
    data_loader = DataLoader(data_loader, batch_size=16, num_workers=2)
    data = []
    for x in tqdm(data_loader, desc="Loading data"):
        data.append(x)
    x = torch.cat(data, dim=0)
    ########################################################################
    TH = x.reshape(-1).numpy()
    index = -int(len(TH) * CFG.chopping_percentile)
    TH: int = np.partition(TH, index)[index]
    x[x > TH] = int(TH)
    ########################################################################
    TH = x.reshape(-1).numpy()
    index = -int(len(TH) * CFG.chopping_percentile)
    TH: int = np.partition(TH, -index)[-index]
    x[x < TH] = int(TH)
    ########################################################################
    # x=(min_max_normalization(x.to(torch.float16))*255).to(torch.uint8)
    return x


class Pipeline_Dataset(Dataset):
    def __init__(self, x, path):
        self.img_paths = glob(path + "/images/*")
        self.img_paths.sort()
        self.in_chan = CFG.in_chans
        z = torch.zeros(self.in_chan // 2, *x.shape[1:], dtype=x.dtype)
        self.x = torch.cat((z, x, z), dim=0)
        print("Len x:", len(self.x))

    def __len__(self):
        return self.x.shape[0] - self.in_chan + 1

    def __getitem__(self, index):
        x = self.x[index : index + self.in_chan]
        print(index + self.in_chan)
        return x, index

    def get_mark(self, index):
        id = self.img_paths[index].split("/")[-3:]
        id.pop(1)
        id = "_".join(id)
        return id[:-4]

    def get_marks(self):
        ids = []
        for index in range(len(self)):
            ids.append(self.get_mark(index))
        return ids


def add_edge(x: torch.Tensor, edge: int):
    # x=(C,H,W)
    # output=(C,H+2*edge,W+2*edge)
    mean_ = int(x.to(torch.float32).mean())
    x = torch.cat([x, torch.ones([x.shape[0], edge, x.shape[2]], dtype=x.dtype, device=x.device) * mean_], dim=1)
    x = torch.cat([x, torch.ones([x.shape[0], x.shape[1], edge], dtype=x.dtype, device=x.device) * mean_], dim=2)
    x = torch.cat([torch.ones([x.shape[0], edge, x.shape[2]], dtype=x.dtype, device=x.device) * mean_, x], dim=1)
    x = torch.cat([torch.ones([x.shape[0], x.shape[1], edge], dtype=x.dtype, device=x.device) * mean_, x], dim=2)
    return x

In [5]:
model = build_model()
model.load_state_dict(torch.load(CFG.model_path[0], "cpu"))
model.eval();

model_name Unet
backbone se_resnext50_32x4d


In [6]:
def get_output():
    path = "/kaggle/input/blood-vessel-segmentation/train/kidney_1_dense"

    x = load_data(path, "/images/")
    labels = torch.zeros_like(x, dtype=torch.uint8)

    # xy, zx, yz inference
    for axis in [0]:
        if axis == 0:
            x_ = x
            labels_ = labels
        elif axis == 1:
            x_ = x.permute(1, 2, 0)
            labels_ = labels.permute(1, 2, 0)
        elif axis == 2:
            x_ = x.permute(2, 0, 1)
            labels_ = labels.permute(2, 0, 1)

        print("Creating dataset and dataloader")
        dataset = Pipeline_Dataset(x_, path)
        dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=2)
        shape = dataset.x.shape[-2:]

        x1_list = np.arange(0, shape[0] + 1, CFG.stride)
        y1_list = np.arange(0, shape[1] + 1, CFG.stride)

        # predict one patch at a time
        # for img, index in tqdm(dataloader, desc="Inferencing dataset"):
        for img, index in dataloader:
            # print("loading img")
            # img = img.to("cuda")  # (1, C, H, W)
            # print("padding img")
            # img = add_edge(img[0], CFG.tile_size // 2)[None]  # (1, C, H + tile_size, W + tile_size)

            # mask_pred = torch.zeros_like(img[:, 0], dtype=torch.float32, device=img.device)
            # mask_count = torch.zeros_like(img[:, 0], dtype=torch.float32, device=img.device)

            # print("Img shape:", img.shape)
            # indexs = []
            # chip = []
            # for y1 in y1_list:
            #     for x1 in x1_list:
            #         x2 = x1 + CFG.tile_size
            #         y2 = y1 + CFG.tile_size
            #         indexs.append(
            #             [
            #                 x1 + CFG.drop_egde_pixel,
            #                 x2 - CFG.drop_egde_pixel,
            #                 y1 + CFG.drop_egde_pixel,
            #                 y2 - CFG.drop_egde_pixel,
            #             ]
            #         )
            #         chip.append(img[..., x1:x2, y1:y2])
            #         print(f"[{x1}:{x2}, {y1}:{y2}]")
            # break
            # # print("chip shape:", torch.cat(chip).shape)  # (n_chips, 5, 512, 512)

            # y_preds = model.forward(torch.cat(chip)).to(device=0)
            # # print("preds shape:", y_preds.shape)  # (n_chips, 512, 512)

            # if CFG.drop_egde_pixel:
            #     y_preds = y_preds[
            #         ..., CFG.drop_egde_pixel : -CFG.drop_egde_pixel, CFG.drop_egde_pixel : -CFG.drop_egde_pixel
            #     ]
            #     # print("preds shape after drop_edge_pixel:", y_preds.shape)  # (n_chips, 448, 448)

            # for i, (x1, x2, y1, y2) in enumerate(indexs):
            #     mask_pred[..., x1:x2, y1:y2] += y_preds[i]
            #     mask_count[..., x1:x2, y1:y2] += 1

            # mask_pred /= mask_count

            # # print("mask_pred shape:", mask_pred.shape)  # (1, 1815, 1424)

            # # Recover
            # mask_pred = mask_pred[
            #     ..., CFG.tile_size // 2 : -CFG.tile_size // 2, CFG.tile_size // 2 : -CFG.tile_size // 2
            # ]

            # # print("mask_pred shape:", mask_pred.shape) # (1, 1303, 912)

            # labels_[index] += (mask_pred[0] * 255).to(torch.uint8).cpu()
            print(index)



    return labels

In [7]:
# path = "/kaggle/input/blood-vessel-segmentation/train/kidney_1_dense"
# x = load_data(path, "/images/")
# dataset = Pipeline_Dataset(x, path)

In [10]:
# 2 + 2279 + 2

# 2283


# 0, 5
# 1, 6
# 2, 7
# ...
# 2278, 2284

2279

In [7]:
output = get_output()

Loading data:  80%|███████▉  | 114/143 [00:04<00:01, 24.89it/s]

Loading data: 100%|██████████| 143/143 [00:06<00:00, 23.36it/s]


Creating dataset and dataloader
Len x: 2283
5
6
7
89

1110

13
12
15
14
17
16
19
18
21
20
23
22
25
24
27
26
29
28
31
30
33
32
35
34
3736

3938

4140

tensor([0])
tensor([1])
tensor([2])
tensor([3])
tensor([4])
tensor([5])
tensor([6])
tensor([7])
tensor([8])
tensor([9])
tensor([10])
tensor([11])
tensor([12])
tensor([13])
tensor([14])
tensor([15])
tensor([16])
tensor([17])
tensor([18])
tensor([19])
tensor([20])
tensor([21])
tensor([22])
tensor([23])
tensor([24])
tensor([25])
tensor([26])
tensor([27])
tensor([28])
tensor([29])
tensor([30])
tensor([31])
tensor([32])
tensor([33])
tensor([34])
4243

4544

4746
48

5049

5251

53
54
55
56
57
5859

60
61
62
6364

6665

6867

6970

72
71
74
73tensor([35])
tensor([36])
tensor([37])
tensor([38])
tensor([39])
tensor([40])
tensor([41])
tensor([42])
tensor([43])
tensor([44])
tensor([45])
tensor([46])
tensor([47])
tensor([48])
tensor([49])
tensor([50])
tensor([51])
tensor([52])
tensor([53])
tensor([54])
tensor([55])
tensor([56])
tensor([57])
tensor([

In [None]:
####################################
TH = [x.flatten().numpy() for x in output]
TH = np.concatenate(TH)
index = -int(len(TH) * CFG.th_percentile)
TH: int = np.partition(TH, index)[index]
print(TH)

####################################
submission_df = []
debug_count = 0
for index in range(len(ids)):
    id = ids[index]
    i = 0
    for x in output:
        if index >= len(x):
            index -= len(x)
            i += 1
        else:
            break
    mask_pred = (output[i][index] > TH).numpy()
    ####################################
    if not is_submit:
        plt.subplot(121)
        plt.imshow(mask_pred)
        plt.show()
        debug_count += 1
        if debug_count > 6:
            break

    rle = rle_encode(mask_pred)

    submission_df.append(
        pd.DataFrame(
            data={
                "id": id,
                "rle": rle,
            },
            index=[0],
        )
    )

submission_df = pd.concat(submission_df)
submission_df.to_csv("submission.csv", index=False)
submission_df.head(6)