In [1]:
import torch
import os
from PIL import Image
import torchvision.transforms as T
import pandas as pd
from rle import rle_decode
import numpy as np
from tqdm import tqdm

# Pick the integer dtype used for index tensors: int32 on torch >= 2, int64
# (long) otherwise — presumably because older releases only index with long.
torch_ver_major = int(torch.__version__.split(".", 1)[0])
dtype_index = torch.long if torch_ver_major < 2 else torch.int32

# Kaggle dataset layout for the kidney_1_dense volume (paths end with "/" so
# file names can be appended by plain string concatenation).
train_dir = "/kaggle/input/blood-vessel-segmentation/train/"
msks_dir = train_dir + "kidney_1_dense/labels/"
imgs_dir = train_dir + "kidney_1_dense/images/"

# Slice file names in sorted order.
slices_ids = os.listdir(imgs_dir)
slices_ids.sort()

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [2]:
# Spatial size of one label slice. Open the file once and close it
# deterministically; PIL reports `size` as (width, height).
with Image.open(msks_dir + slices_ids[0]) as first_mask:
    w, h = first_mask.size
print("Width:", w)
print("Height:", h)

Width: 912
Height: 1303


In [3]:
class TestMetricDataset(torch.utils.data.Dataset):
    """Dataset yielding ``(pred, target)`` mask pairs for metric evaluation.

    ``pred`` is decoded from the run-length-encoded ``"rle"`` column of the
    submission frame; ``target`` is the ground-truth label image read from
    ``msks_dir``.

    Args:
        sub_df: submission DataFrame with an ``"rle"`` column. Row ``idx`` is
            paired with ``slices_ids[idx]`` purely by position —
            NOTE(review): assumes the rows are in the same order as
            ``slices_ids``; confirm against how the submission was written.
        msks_dir: label directory; joined to each slice id by plain string
            concatenation, so it must end with a path separator.
        slices_ids: label file names, one per item.
        transform: stored on the instance but not used by ``__getitem__``.
        target_transform: optional callable applied to the label PIL image;
            its output is cast to ``torch.int8`` and squeezed. When ``None``
            the label is returned as a PIL image.
    """

    def __init__(self, sub_df, msks_dir, slices_ids, transform=None, target_transform=None):
        self.sub_df = sub_df
        self.msks_dir = msks_dir
        self.slices_ids = slices_ids
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.slices_ids)

    def __getitem__(self, idx):
        """Return ``(pred, target)`` for the slice at position ``idx``."""
        slice_id = self.slices_ids[idx]
        target_path = self.msks_dir + slice_id

        with Image.open(target_path) as mask_img:
            # Decode to this label's own (height, width) rather than notebook
            # globals, so the prediction always matches its label's size.
            shape = (mask_img.height, mask_img.width)
            if self.target_transform is not None:
                target = self.target_transform(mask_img).type(torch.int8).squeeze()
            else:
                # copy() reads the pixels into a new image that outlives close().
                target = mask_img.copy()

        # Read from the frame passed to the constructor, not a global.
        pred_rle = self.sub_df.iloc[idx]["rle"]
        pred = torch.from_numpy(rle_decode(pred_rle, shape))

        return pred, target

In [4]:
sub_df = pd.read_csv("ref_sub.csv")

# Ground-truth masks only need converting from PIL images to tensors.
target_transform = T.Compose([T.ToTensor()])

ds = TestMetricDataset(
    sub_df=sub_df,
    msks_dir=msks_dir,
    slices_ids=slices_ids,
    target_transform=target_transform,
)
dl = torch.utils.data.DataLoader(ds, batch_size=2)
n_batches = len(dl)
for label, size in (("ds len:", len(ds)), ("dl len:", n_batches)):
    print(label, size)


ds len: 2279
dl len: 1140


In [5]:
# Size of the final, partial batch (0 means every batch is full); derived
# from the loader instead of repeating the batch-size literal.
len(ds) % dl.batch_size

1

In [6]:
# pred, target = next(iter(dl))
# print(pred.shape, target.shape)

In [7]:
from surface_dice import SurfaceDiceMetric

# Feed every (pred, target) batch to the metric; it is constructed with the
# total number of batches. `pred` / `target` keep their names on purpose:
# later cells inspect the last batch through them.
metric = SurfaceDiceMetric(n_batches)
batches = tqdm(dl)
for pred, target in batches:
    metric.process_batch(pred, target)

100%|██████████| 1140/1140 [01:43<00:00, 11.05it/s]


In [8]:
# Final score accumulated over all processed batches.
dice_score = metric.compute_metric()
dice_score

0.8802322745323181

In [None]:
# NOTE(review): `pred` / `target` are whatever the metric loop left behind
# (its final batch), and re-running this cell prepends yet another slice.
# Prepend one all-zero slice, presumably so the first real slice has a
# neighbour when slices are paired up. zeros_like keeps each stack's own
# dtype (target is int8, not uint8) and spatial size.
pred = torch.cat([torch.zeros_like(pred[:1]), pred])
target = torch.cat([torch.zeros_like(target[:1]), target])
print(pred.shape, target.shape)

In [None]:
# 2x2 spatial windows with 1-pixel zero padding on every side.
unfold = torch.nn.Unfold(kernel_size=(2, 2), padding=1)

# Compute the surface area for pred: pair every slice with the next one,
# giving (n_slices - 1, 2, H, W). Each pair of adjacent slices is one layer of
# 2x2x2 cubes. A fixed view(3, 2, h, w) only works when the slice count is
# exactly 6. float32 is exact for 0/1 values and avoids half-precision unfold,
# which some torch releases do not implement on CPU.
pred_float = pred.to(torch.float32)
surface_pred = torch.stack([pred_float[:-1], pred_float[1:]], dim=1)
print("surface_pred shape:", surface_pred.shape)

# Unfold -> (n_pairs, 2 * 2 * 2 = 8, L); put the 8 cube corners first and
# flatten all cube positions of all pairs into the second axis.
cubes_float = unfold(surface_pred).permute(1, 0, 2).reshape(8, -1)
print("cubes_float shape:", cubes_float.shape)
cubes_byte = torch.zeros(cubes_float.size(1), dtype=dtype_index)
print("cubes_byte shape:", cubes_byte.shape)

In [None]:
# x = torch.arange(3*8*5).view(3, 8, 5)
# print(x)
# print(x.shape)
# print(x.permute(1, 0, 2).reshape(8, -1))


In [None]:
# Pack each cube's 8 corner values (expected 0/1) into one integer code:
# corner row `bit` contributes (value << bit), accumulated in place.
for bit, corner_row in enumerate(cubes_float[:8]):
    cubes_byte += corner_row.to(dtype_index) << bit

In [None]:
# Look up a per-cube area value from the metric's `area` table, indexed by
# each cube's packed corner code.
area_lut = metric.area
cubes_area = area_lut[cubes_byte]
cubes_area