In [1]:
import torch
import os
from PIL import Image
import torchvision.transforms as T
import pandas as pd
from rle import rle_decode
import numpy as np
from tqdm import tqdm

# Major component of the installed torch version string.
torch_ver_major = int(torch.__version__.partition('.')[0])
# Index dtype chosen by torch major version: int32 on 2.x and later, long before that.
dtype_index = torch.long if torch_ver_major < 2 else torch.int32

# Dataset locations on the Kaggle input mount (every prefix ends with "/").
train_dir = "/kaggle/input/blood-vessel-segmentation/train/"
msks_dir = f"{train_dir}kidney_1_dense/labels/"
imgs_dir = f"{train_dir}kidney_1_dense/images/"

# File names found in the images directory, in sorted order.
slices_ids = os.listdir(imgs_dir)
slices_ids.sort()

# Use the GPU when one is visible to torch, otherwise the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [2]:
# Read the mask dimensions once from the first slice. The context manager
# closes the file handle, instead of opening the same file twice and leaving
# both images unclosed.
with Image.open(msks_dir + slices_ids[0]) as first_mask:
    w, h = first_mask.size  # PIL reports size as (width, height)
print("Width:", w)
print("Height:", h)

Width: 912
Height: 1303


In [3]:
class TestMetricDataset(torch.utils.data.Dataset):
    """Pairs each slice's RLE-encoded prediction with its ground-truth mask image.

    Item ``idx`` combines row ``idx`` (by position) of ``sub_df`` with the mask
    file named ``slices_ids[idx]`` under ``msks_dir``, so the rows of ``sub_df``
    must be in the same order as ``slices_ids``.

    Args:
        sub_df: DataFrame with an ``"rle"`` column holding the encoded predictions.
        msks_dir: Directory prefix of the mask images. It is joined to a slice id
            by plain string concatenation, so it must end with a path separator.
        slices_ids: Sequence of mask file names, one per dataset item.
        transform: Stored on the instance but not used by ``__getitem__``.
        target_transform: Optional callable applied to the opened mask image; its
            result is cast to ``torch.int8`` and squeezed.
        mask_shape: Optional ``(height, width)`` handed to ``rle_decode``. When
            ``None`` (the default) the shape is taken from each target image.
    """

    def __init__(self, sub_df, msks_dir, slices_ids, transform=None, target_transform=None,
                 mask_shape=None):
        self.sub_df = sub_df
        self.msks_dir = msks_dir
        self.slices_ids = slices_ids
        self.transform = transform
        self.target_transform = target_transform
        self.mask_shape = mask_shape

    def __len__(self):
        """Number of items, i.e. the number of slice ids."""
        return len(self.slices_ids)

    def __getitem__(self, idx):
        """Return ``(pred, target)`` for the slice at position ``idx``.

        ``pred`` is the tensor built from ``rle_decode``'s output. ``target`` is
        the transformed mask when ``target_transform`` is set, otherwise the
        opened PIL image as-is.
        """
        slice_id = self.slices_ids[idx]

        target_path = self.msks_dir + slice_id
        target = Image.open(target_path)

        # Decode the prediction to the target's own (height, width) unless an
        # explicit shape was supplied, so the dataset does not depend on
        # notebook-level `h` / `w` variables.
        if self.mask_shape is not None:
            shape = self.mask_shape
        else:
            shape = (target.height, target.width)

        # Use the frame passed to the constructor, not a global `sub_df`.
        pred_rle = self.sub_df.iloc[idx]["rle"]
        pred = rle_decode(pred_rle, shape)
        pred = torch.from_numpy(pred)

        if self.target_transform is not None:
            target = self.target_transform(target).type(torch.int8).squeeze()

        return pred, target

In [4]:
# Reference submission; its "rle" column is read positionally by the dataset.
sub_df = pd.read_csv("ref_sub.csv")

# Convert each opened mask image to a tensor before the dataset casts it.
target_transform = T.Compose([
    T.ToTensor(),
])

ds = TestMetricDataset(sub_df=sub_df, msks_dir=msks_dir, slices_ids=slices_ids, target_transform=target_transform)
# batch_size needs to be an odd number
# NOTE(review): the reason is not visible in this notebook — presumably a
# constraint of SurfaceDiceMetric; confirm.
# os.cpu_count() returns None when the CPU count cannot be determined; fall back
# to 0 workers (load in the main process) instead of passing None to DataLoader.
dl = torch.utils.data.DataLoader(ds, batch_size=15, num_workers=os.cpu_count() or 0, drop_last=False)
n_batches = len(dl)
print("ds len:", len(ds))
print("dl len:", n_batches)


ds len: 2279
dl len: 152


In [5]:
from surface_dice import SurfaceDiceMetric

# The metric is evaluated on the CPU here; this rebinds the `device` chosen in
# the configuration cell.
device = "cpu"
metric = SurfaceDiceMetric(n_batches, device)

# Feed every (prediction, target) batch to the metric accumulator.
for batch in tqdm(dl):
    pred, target = (tensor.to(device) for tensor in batch)
    metric.process_batch(pred, target)

# Last expression of the cell, so the notebook displays the computed value.
metric.compute_metric()

# NOTE(review): presumably metric values recorded from earlier runs — confirm.
# 0.87990802526474
# 0.8799066326898859
# 0.8799066326898859

  0%|          | 0/152 [00:00<?, ?it/s]

  6%|▌         | 9/152 [00:17<04:39,  1.96s/it]


KeyboardInterrupt: 