# Baseline


In [1]:
import torch
from PIL import Image
import numpy as np
from tqdm import tqdm
import os
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torch.utils.data import DataLoader
from surface_dice import SurfaceDiceMetric
import albumentations as A
import random
import segmentation_models_pytorch as smp
from patcher import Patcher

# --- Environment -------------------------------------------------------------
# The dataset root depends on the machine: host "gamma" reads a relative
# data/ folder, every other host reads the /kaggle/input mount.
hostname = os.uname().nodename
print("Hostname:", hostname)
if hostname == "gamma":
    input_dir = "data/blood-vessel-segmentation/"
else:
    input_dir = "/kaggle/input/blood-vessel-segmentation/"
train_dir = input_dir + "train/"

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# --- Reproducibility ---------------------------------------------------------
# Seed every RNG the notebook touches and ask cuDNN for deterministic kernels.
seed = 42
os.environ["PYTHONHASHSEED"] = str(seed)
for seed_fn in (np.random.seed, random.seed, torch.manual_seed, torch.cuda.manual_seed):
    seed_fn(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True



Hostname: 6a6058df1898


## Load data

### Dataset

In [2]:
class KidneyDataset(torch.utils.data.Dataset):
    """Slice-wise dataset of grayscale images and their binary masks.

    Each item is one 2D slice. The image is read as float32, divided by
    ``intensity_scale`` and returned with a leading channel axis; the mask is
    integer-divided by 255, so a 0/255 label image becomes 0/1.

    Args:
        imgs_dir: Directory holding the image slices (a trailing slash is
            optional).
        msks_dir: Directory holding the mask slices; a mask is looked up under
            the same file name as its image.
        slices_ids: File names of the slices to serve. Must be non-empty: the
            first one is opened to read the slice size.
        transforms: Optional callable invoked as
            ``transforms(image=img, mask=msk)`` and expected to return a
            mapping with ``"image"`` and ``"mask"`` entries.
        intensity_scale: Divisor applied to the raw image intensities.
            NOTE(review): 31000 is presumably close to the maximum raw
            intensity of the scans - confirm against the data.

    Attributes:
        h, w: Height and width (in pixels) of the first slice.
    """

    def __init__(self, imgs_dir, msks_dir, slices_ids, transforms=None, intensity_scale=31000):
        self.imgs_dir = imgs_dir
        self.msks_dir = msks_dir
        self.slices_ids = slices_ids
        self.transforms = transforms
        self.intensity_scale = intensity_scale
        # Read the size from a single open of the first slice and close the
        # file right away; PIL keeps the handle open until pixel data is loaded.
        with Image.open(os.path.join(imgs_dir, slices_ids[0])) as first_slice:
            self.w, self.h = first_slice.size

    def __len__(self):
        """Return the number of slices served by the dataset."""
        return len(self.slices_ids)

    def __getitem__(self, idx):
        """Return ``(img, msk)`` tensors for the slice at position ``idx``.

        ``img`` has shape ``(1, H, W)`` (float32, scaled); ``msk`` has shape
        ``(H, W)`` and holds the mask values integer-divided by 255.
        """
        slice_id = self.slices_ids[idx]

        # os.path.join works whether or not the directories end with a slash.
        with Image.open(os.path.join(self.imgs_dir, slice_id)) as img_file:
            img = np.array(img_file, dtype=np.float32)
        with Image.open(os.path.join(self.msks_dir, slice_id)) as msk_file:
            msk = np.array(msk_file)

        if self.transforms is not None:
            t = self.transforms(image=img, mask=msk)
            img = t["image"]
            msk = t["mask"]

        img = torch.from_numpy(img)[None, :]  # add the channel axis
        msk = torch.as_tensor(msk)
        img /= self.intensity_scale
        msk = msk // 255

        return img, msk

In [3]:
# Image and label slices of kidney_1_dense; a mask shares its image's file name.
imgs_dir = f"{train_dir}kidney_1_dense/images/"
msks_dir = f"{train_dir}kidney_1_dense/labels/"
slices_ids = sorted(os.listdir(imgs_dir))

# Training sees random square crops; evaluation sees whole slices.
patch_size = 224
transforms = A.Compose([A.RandomCrop(patch_size, patch_size)])

# Both datasets read the same slices and differ only in the transform.
shared_ds_args = dict(imgs_dir=imgs_dir, msks_dir=msks_dir, slices_ids=slices_ids)
train_ds = KidneyDataset(transforms=transforms, **shared_ds_args)
eval_ds = KidneyDataset(**shared_ds_args)

print("Train Dataset length:", len(train_ds))
print("Eval Dataset length:", len(eval_ds))


Train Dataset length: 2279
Eval Dataset length: 2279


### Dataloaders

In [4]:
# Worker count depends on the host: 8 on "gamma", 2 everywhere else.
n_workers = 8 if hostname == "gamma" else 2

# Shuffled crops for training; workers are kept alive between epochs.
train_dl = DataLoader(
    train_ds, batch_size=32, shuffle=True, num_workers=n_workers, persistent_workers=True
)

# Full slices in file order for evaluation.
eval_dl = DataLoader(
    eval_ds, batch_size=16, shuffle=False, num_workers=n_workers, persistent_workers=False
)

print("Train DataLoader length:", len(train_dl))
print("Eval DataLoader length:", len(eval_dl))

Train DataLoader length: 72
Eval DataLoader length: 143


## Define model

In [5]:
# U-Net with a MobileNetV3-small encoder, trained from scratch (no pretrained
# weights), taking one input channel and producing one output channel.
unet_config = {
    "encoder_name": "timm-mobilenetv3_small_075",
    "encoder_weights": None,
    "in_channels": 1,
    "classes": 1,
}
net = smp.Unet(**unet_config)
net.to(device)

n_params = sum(p.nelement() for p in net.parameters())
print(f"Number of params: {n_params:,}")

Number of params: 2,881,625


## Train and evaluation pipeline

### Loss function

In [6]:
# Binary cross-entropy evaluated on raw logits: the sigmoid is applied inside
# the loss, so the network output is passed in without an activation.
loss_fn = torch.nn.BCEWithLogitsLoss()

In [None]:
# Reference value: the BCE of a constant prediction p = 1/2 is -ln(1/2) = ln 2.
chance_loss = -torch.log(torch.tensor(0.5))
print("Random loss:", chance_loss)

### Optimizer and scheduler

In [None]:
# Adam over all network parameters with a fixed learning rate; this cell does
# not create a scheduler.
lr = 3e-4
optimizer = torch.optim.Adam(net.parameters(), lr=lr)

### Train method

In [None]:
def train():
    """Run one optimisation pass over ``train_dl`` and return the mean batch loss.

    Relies on the notebook-level ``net``, ``train_dl``, ``loss_fn``,
    ``optimizer`` and ``device``.

    Returns:
        float: Sum of the per-batch losses divided by the number of batches.
    """
    train_loss = 0.0
    net.train()
    for x, y in tqdm(train_dl):
        x, y = x.to(device), y.to(device).float()
        # Drop only the channel axis. A bare ``squeeze()`` would also drop the
        # batch axis of a batch holding a single sample, and the loss would
        # then reject the (H, W) logits against the (1, H, W) target.
        logits = net(x).squeeze(1)
        loss = loss_fn(logits, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    train_loss /= len(train_dl)

    return train_loss

### Extract and Merge Patches

In [7]:
# The patcher is sized from the evaluation slices and cuts tiles of the
# training patch size with the given overlap.
overlap = 50
h = eval_ds.h
w = eval_ds.w
patcher = Patcher(h, w, patch_size=patch_size, overlap=overlap)

### Eval method

In [8]:
@torch.no_grad()
def eval(save_preds=False, preds_dir="preds"):
    """Evaluate ``net`` on ``eval_dl`` with patch-wise inference.

    Every batch is split into patches by ``patcher``, run through the network,
    merged back and compared with the target. Relies on the notebook-level
    ``net``, ``eval_dl``, ``patcher``, ``loss_fn`` and ``device``.

    Note: this function shadows the built-in ``eval`` inside the notebook.

    Args:
        save_preds: When true, write one 16-bit probability map per slice to
            ``{preds_dir}/{index:04}.tif``, numbered in loader order.
        preds_dir: Output directory for the probability maps; created if
            missing.

    Returns:
        tuple: ``(eval_loss, surface_dice)`` - the mean batch loss and the
        value returned by ``SurfaceDiceMetric.compute()``.
    """
    eval_loss = 0.0
    idx = 0  # running slice counter used to name the saved prediction files
    net.eval()
    metric = SurfaceDiceMetric(n_batches=len(eval_dl), device=device)
    if save_preds:
        os.makedirs(preds_dir, exist_ok=True)

    for x, y in tqdm(eval_dl):
        B = x.shape[0]
        x, y = x.to(device), y.to(device).float()
        x = patcher.extract_patches(x)  # (B, n_patches, C, H, W)
        x = x.flatten(end_dim=1)  # (B * n_patches, C, H, W)

        logits = net(x)  # (B * n_patches, C, H, W)
        logits = logits.unflatten(0, (B, -1))  # (B, n_patches, C, H, W)
        # Reshape to the target's shape rather than calling a bare squeeze(),
        # which would also drop the batch axis of a single-sample batch.
        logits = patcher.merge_patches(logits).reshape(y.shape)  # (B, H, W)

        loss = loss_fn(logits, y)

        # save probabilities maps
        if save_preds:
            # Move the batch to the CPU once instead of once per slice.
            probs = logits.cpu().sigmoid()
            for i in range(B):
                prob_map = (probs[i] * (2**16 - 1)).numpy().astype(np.uint16)
                Image.fromarray(prob_map).save(f"{preds_dir}/{idx:04}.tif")
                idx += 1

        pred = torch.where(logits.sigmoid() >= 0.5, 1, 0)

        metric.process_batch(pred, y)
        eval_loss += loss.item()

    eval_loss /= len(eval_dl)
    surface_dice = metric.compute()

    return eval_loss, surface_dice

In [None]:
%%time

# Train for a fixed number of epochs, keeping the mean training loss of each.
epochs = 40
losses, dices = [], []

for epoch in range(1, epochs + 1):
    train_loss = train()
    losses.append(train_loss)
    print(f"EPOCH {epoch}, TLOSS {train_loss:.4f}")

In [9]:
# Full-slice evaluation without writing the probability maps to disk.
eval_loss, surface_dice = eval(save_preds=False)
print("Evaluation", f"ELOSS {eval_loss:.4f}, SURFACE_DICE {surface_dice:.4f}", sep="\n")

  0%|          | 0/143 [00:00<?, ?it/s]

100%|██████████| 143/143 [01:11<00:00,  2.00it/s]

Evaluation
ELOSS 0.0042, SURFACE_DICE 0.7274





In [None]:
# torch.save does not create missing parent directories, so make sure the
# target folder exists before writing the checkpoint (named after the surface
# dice reached in the evaluation above).
os.makedirs("checkpoints", exist_ok=True)
torch.save(net.state_dict(), f"checkpoints/baseline_train_sdc_{surface_dice:.3f}.pth")