In [1]:
import torch
from PIL import Image
import numpy as np
from tqdm import tqdm
import os
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torch.utils.data import DataLoader
from surface_dice import SurfaceDiceMetric
import albumentations as A
import random
import segmentation_models_pytorch as smp
from patcher import Patcher

import platform

# `platform.node()` is portable; `os.uname()` does not exist on Windows.
hostname = platform.node()
print("Hostname:", hostname)
# The machine named "gamma" keeps the data under ./data; anywhere else the Kaggle input path is used.
input_dir = "data/blood-vessel-segmentation/" if hostname == "gamma" else "/kaggle/input/blood-vessel-segmentation/"

device = "cuda" if torch.cuda.is_available() else "cpu"
train_dir = input_dir + "train/"

# reproducibility
seed = 42
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# NOTE: the running interpreter's hash seed is fixed at startup; this only
# affects Python processes launched after this line.
os.environ["PYTHONHASHSEED"] = str(seed)




Hostname: 6a6058df1898


In [2]:
class KidneyDataset(torch.utils.data.Dataset):
    """Image slices and their segmentation masks, read from disk on demand.

    Slice ``slices_ids[i]`` is read from ``imgs_dir + slices_ids[i]`` and its
    mask from ``msks_dir + slices_ids[i]`` (same file name in both directories).

    Args:
        imgs_dir: directory prefix of the image files (concatenated directly,
            so it must end with a path separator).
        msks_dir: directory prefix of the mask files (same convention).
        slices_ids: file names of the slices, served in this order.
        transforms: optional callable invoked as ``transforms(image=..., mask=...)``
            returning a dict with ``"image"`` and ``"mask"`` entries.

    Attributes:
        h, w: height and width of the first image; presumably every slice
            shares this size (TODO confirm).

    Raises:
        ValueError: if ``slices_ids`` is empty.
    """

    # Divisor applied to raw image intensities. The meaning of 31000 is not
    # documented here — presumably near the scans' max intensity (TODO confirm).
    IMG_SCALE = 31000

    def __init__(self, imgs_dir, msks_dir, slices_ids, transforms=None):
        if len(slices_ids) == 0:
            raise ValueError("slices_ids must not be empty")
        self.imgs_dir = imgs_dir
        self.msks_dir = msks_dir
        self.slices_ids = slices_ids
        self.transforms = transforms
        # Open once and close the handle; PIL's size is (width, height).
        with Image.open(imgs_dir + slices_ids[0]) as first:
            self.w, self.h = first.size

    def __len__(self):
        return len(self.slices_ids)

    def __getitem__(self, idx):
        """Return ``(img, msk)``: img is float32 ``(1, H, W)`` scaled by ``1 / IMG_SCALE``; msk is ``pixel // 255``."""
        slice_id = self.slices_ids[idx]
        # Context managers close the files as soon as the pixels are decoded.
        with Image.open(self.imgs_dir + slice_id) as im:
            img = np.array(im, dtype=np.float32)
        with Image.open(self.msks_dir + slice_id) as im:
            msk = np.array(im)

        if self.transforms is not None:
            t = self.transforms(image=img, mask=msk)
            img, msk = t["image"], t["mask"]

        # Transforms such as flips may return negative-stride views, which
        # torch.from_numpy rejects; ascontiguousarray is a no-op otherwise.
        img = torch.from_numpy(np.ascontiguousarray(img))[None, :]  # add channel dim
        msk = torch.from_numpy(np.ascontiguousarray(msk))
        img /= self.IMG_SCALE
        # presumably masks are stored as 0/255, mapping them to 0/1 (TODO confirm)
        msk = msk // 255

        return img, msk

class KidneyDatasetPrefetched(torch.utils.data.Dataset):
    """Hold a whole dataset in memory as two stacked tensors.

    The tensors are either loaded from ``imgs_file``/``msks_file`` (both
    required) or materialised by reading every item of ``ds`` in a single
    DataLoader batch.

    Args:
        ds: dataset yielding ``(img, msk)`` pairs; ignored when loading from files.
        imgs_file: path of the stacked image tensor to ``torch.load``.
        msks_file: path of the stacked mask tensor to ``torch.load``.

    Attributes:
        imgs, msks: stacked tensors; item ``i`` is ``(imgs[i], msks[i])``.
        h, w: the last two dimensions of ``imgs``.

    Raises:
        ValueError: if only one of ``imgs_file`` / ``msks_file`` is given.
    """

    def __init__(self, ds, imgs_file=None, msks_file=None):
        # Previously a single file path was silently ignored and ds was re-read.
        if (imgs_file is None) != (msks_file is None):
            raise ValueError("imgs_file and msks_file must be given together")
        from_file = imgs_file is not None
        print("From file:", from_file)
        if from_file:
            # NOTE(review): torch.load unpickles its input — only load trusted files.
            self.imgs = torch.load(imgs_file)
            self.msks = torch.load(msks_file)
        else:
            # One batch containing every item: the whole dataset is held in memory at once.
            dl = DataLoader(ds, batch_size=len(ds), shuffle=False, persistent_workers=False)
            self.imgs, self.msks = next(iter(dl))

        self.h, self.w = self.imgs.shape[-2:]

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        return self.imgs[idx], self.msks[idx]

In [3]:
import torch
import torch.nn.functional as F
from math import ceil

class PatcherNew:
    """Split images into overlapping square patches and stitch them back.

    Images are reflect-padded on the bottom/right so that a grid of
    ``p_size`` patches with stride ``p_size - overlap`` tiles the padded image
    exactly. ``merge_patches`` averages overlapping pixels and crops the
    padding, so ``merge_patches(extract_patches(x))`` reconstructs ``x`` up to
    float rounding.

    Args:
        h, w: spatial size of the (unpadded) images.
        p_size: side length of each square patch.
        overlap: pixels shared by neighbouring patches; must be < ``p_size``.

    Raises:
        ValueError: if ``overlap >= p_size`` (non-positive stride).
    """

    def __init__(self, h, w, p_size, overlap):
        self.h, self.w = h, w
        self.p_size = p_size
        self.overlap = overlap

        self.stride = p_size - overlap
        if self.stride <= 0:
            raise ValueError(f"overlap ({overlap}) must be smaller than p_size ({p_size})")
        self.h_pad = self._pad_amount(h)
        self.w_pad = self._pad_amount(w)

        self.unfold = torch.nn.Unfold(kernel_size=(p_size, p_size), stride=self.stride)
        self.fold = torch.nn.Fold(
            output_size=(h + self.h_pad, w + self.w_pad),
            kernel_size=(p_size, p_size),
            stride=self.stride
        )

    def _pad_amount(self, size):
        """Padding that makes ``size`` fit a whole number of patch steps (always >= p_size total)."""
        # Clamp at 0 so that an image smaller than a patch is padded up to one
        # full patch rather than to something shorter than p_size (zero patches).
        n_extra_steps = max(0, ceil((size - self.p_size) / self.stride))
        return self.stride * n_extra_steps + self.p_size - size

    def extract_patches(self, x):
        """Cut a batch into overlapping patches.

        Args:
            x: tensor of shape (B, C, h, w).

        Returns:
            Tensor of shape (B, n_patches, C, p_size, p_size), patches ordered
            row-major over the patch grid (the order ``merge_patches`` expects).
        """
        assert x.ndim == 4
        x = F.pad(x, (0, self.w_pad, 0, self.h_pad), mode="reflect")
        B, C, _, _ = x.shape

        patches = x.unfold(2, self.p_size, self.stride).unfold(3, self.p_size, self.stride)  # (B, C, h_steps, w_steps, p_size, p_size)
        patches = patches.permute(0, 2, 3, 1, 4, 5)                                           # (B, h_steps, w_steps, C, p_size, p_size)
        return patches.reshape(B, -1, C, self.p_size, self.p_size)                           # (B, n_patches, C, p_size, p_size)

    def merge_patches(self, patches):
        """Inverse of ``extract_patches``: average overlaps and crop to (B, C, h, w)."""
        assert patches.ndim == 5
        B, N, C, _, _ = patches.shape

        # Fold expects (B, C * p_size * p_size, N), flattened in (C, row, col) order.
        cols = patches.permute(0, 2, 3, 4, 1).reshape(B, C * self.p_size * self.p_size, N)
        summed = self.fold(cols)  # (B, C, h + h_pad, w + w_pad)

        # Folding ones counts how many patches cover each pixel; dividing by
        # that count averages overlapping values. One sample suffices (broadcast).
        counts = self.fold(torch.ones_like(cols[:1]))
        merged = summed / counts

        return merged[:, :, :self.h, :self.w]



In [4]:
# Evaluation data: every slice of kidney_1_dense, served in sorted file-name order.
kidney_dir = f"{train_dir}kidney_1_dense/"
imgs_dir = kidney_dir + "images/"
msks_dir = kidney_dir + "labels/"
slices_ids = sorted(os.listdir(imgs_dir))

patch_size = 224

eval_ds = KidneyDataset(imgs_dir=imgs_dir, msks_dir=msks_dir, slices_ids=slices_ids)

# In-memory alternatives (prefetch once, or load tensors cached with torch.save):
# eval_ds = KidneyDatasetPrefetched(eval_ds, imgs_file="kidney_1_imgs_prefetched.pth", msks_file="kidney_1_msks_prefetched.pth")
# eval_ds = KidneyDatasetPrefetched(eval_ds)

n_workers = 8 if hostname == "gamma" else 2
eval_dl = DataLoader(
    eval_ds,
    batch_size=16,
    num_workers=n_workers,
    shuffle=False,
    persistent_workers=False,
)

h, w = eval_ds.h, eval_ds.w

In [5]:
net = smp.Unet(
    encoder_name="timm-mobilenetv3_small_075",
    encoder_weights=None,  # all weights come from the checkpoint below
    in_channels=1,
    classes=1,
)
# map_location lets a checkpoint saved from GPU load when `device` fell back to "cpu".
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
net.load_state_dict(torch.load("baseline_train_sdc_0.727.pth", map_location=device))
net.to(device)

loss_fn = torch.nn.BCEWithLogitsLoss()

In [6]:
# Sanity check: cutting every batch into patches and merging them back must
# reproduce the input (overlapping pixels are averaged, padding is cropped).
overlap = 50
patcher = PatcherNew(h, w, p_size=patch_size, overlap=overlap)
for imgs, _ in tqdm(eval_dl):
    imgs = imgs.to(device)
    merged = patcher.merge_patches(patcher.extract_patches(imgs))
    assert merged.shape == imgs.shape
    assert torch.allclose(merged, imgs, atol=1e-5), "patch round-trip does not reconstruct the input"

100%|██████████| 143/143 [00:59<00:00,  2.39it/s]


In [7]:
@torch.no_grad()
def eval_old(save_preds=False, overlap=50, preds_dir="preds"):
    """Evaluate `net` on `eval_dl` using the original `Patcher` (reference implementation).

    Each batch is cut into overlapping `patch_size` x `patch_size` patches, run
    through the network, and stitched back into full-size logits.

    Args:
        save_preds: when True, write each slice's probability map, scaled to
            uint16 (``p * 65535``), to ``<preds_dir>/<index:04>.tif``.
        overlap: overlap in pixels between neighbouring patches.
        preds_dir: output directory for ``save_preds``; created if missing.

    Returns:
        ``(eval_loss, surface_dice)``: the mean of the per-batch BCE losses and
        the value returned by ``SurfaceDiceMetric.compute()``.
    """
    patcher = Patcher(h, w, patch_size=patch_size, overlap=overlap)
    eval_loss = 0.0
    idx = 0
    net.eval()
    metric = SurfaceDiceMetric(n_batches=len(eval_dl), device=device)
    if save_preds:
        os.makedirs(preds_dir, exist_ok=True)

    for x, y in tqdm(eval_dl):
        bs = len(x)
        x, y = x.to(device), y.to(device).float()
        x = patcher.extract_patches(x)  # presumably (bs, n_patches, patch_size, patch_size)

        logits = net(x.reshape(-1, 1, patch_size, patch_size))  # (bs * n_patches, 1, patch_size, patch_size)
        logits = logits.view(bs, -1, patch_size, patch_size)  # (bs, n_patches, patch_size, patch_size)
        # Reshape to the target's (bs, h, w) rather than a bare squeeze(), which
        # would also drop the batch dim for a 1-slice batch and break the loss.
        logits = patcher.merge_patches(logits).reshape(y.shape)

        loss = loss_fn(logits, y)
        probs = logits.sigmoid()

        # save probability maps
        if save_preds:
            prob_maps = (probs * (2**16 - 1)).cpu().numpy().astype(np.uint16)
            for prob_map in prob_maps:
                Image.fromarray(prob_map).save(os.path.join(preds_dir, f"{idx:04}.tif"))
                idx += 1

        pred = torch.where(probs >= 0.5, 1, 0)

        metric.process_batch(pred, y)
        eval_loss += loss.item()

    # Mean of per-batch means: a smaller final batch counts like a full one.
    eval_loss /= len(eval_dl)
    surface_dice = metric.compute()

    return eval_loss, surface_dice

@torch.no_grad()
def eval(save_preds=False, overlap=50, preds_dir="preds"):
    """Evaluate `net` on `eval_dl` with patch-based inference via `PatcherNew`.

    NOTE: the name shadows Python's built-in `eval` in this notebook; it is kept
    so existing calls keep working.

    Args:
        save_preds: when True, write each slice's probability map, scaled to
            uint16 (``p * 65535``), to ``<preds_dir>/<index:04>.tif``.
        overlap: overlap in pixels between neighbouring patches.
        preds_dir: output directory for ``save_preds``; created if missing.

    Returns:
        ``(eval_loss, surface_dice)``: the mean of the per-batch BCE losses and
        the value returned by ``SurfaceDiceMetric.compute()``.
    """
    patcher = PatcherNew(h, w, p_size=patch_size, overlap=overlap)
    eval_loss = 0.0
    idx = 0
    net.eval()
    metric = SurfaceDiceMetric(n_batches=len(eval_dl), device=device)
    if save_preds:
        os.makedirs(preds_dir, exist_ok=True)

    for x, y in tqdm(eval_dl):
        B = x.shape[0]  # local batch size (the last batch may be smaller)
        x, y = x.to(device), y.to(device).float()
        patches = patcher.extract_patches(x)  # (B, n_patches, C, p_size, p_size)

        logits = net(patches.flatten(end_dim=1))  # (B * n_patches, classes=1, p_size, p_size)
        logits = logits.unflatten(0, (B, -1))  # (B, n_patches, 1, p_size, p_size)
        # squeeze only the channel dim; a bare squeeze() would also drop the
        # batch dim when B == 1 and break the shape match with y.
        logits = patcher.merge_patches(logits).squeeze(1)  # (B, h, w)

        loss = loss_fn(logits, y)
        probs = logits.sigmoid()

        # save probability maps
        if save_preds:
            prob_maps = (probs * (2**16 - 1)).cpu().numpy().astype(np.uint16)
            for prob_map in prob_maps:
                Image.fromarray(prob_map).save(os.path.join(preds_dir, f"{idx:04}.tif"))
                idx += 1

        pred = torch.where(probs >= 0.5, 1, 0)

        metric.process_batch(pred, y)
        eval_loss += loss.item()

    # Mean of per-batch means: a smaller final batch counts like a full one.
    eval_loss /= len(eval_dl)
    surface_dice = metric.compute()

    return eval_loss, surface_dice

In [8]:
# Reference run with the original Patcher implementation.
old_metrics = eval_old(save_preds=False)
eval_loss, surface_dice = old_metrics
print(f"ELOSS {eval_loss}, SURFACE_DICE {surface_dice}")

100%|██████████| 143/143 [01:39<00:00,  1.44it/s]

ELOSS 0.00424858136896298, SURFACE_DICE 0.7274355888366699





In [9]:
# Same evaluation with PatcherNew; the numbers should match the reference run.
new_metrics = eval(save_preds=False)
eval_loss, surface_dice = new_metrics
print(f"ELOSS {eval_loss}, SURFACE_DICE {surface_dice}")

100%|██████████| 143/143 [01:11<00:00,  2.01it/s]

ELOSS 0.00424858136896298, SURFACE_DICE 0.7274355888366699



