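# Baseline

U-Net baseline for blood-vessel segmentation on `kidney_1_dense`: trained on random 224×224 crops and evaluated on full slices with overlapping-patch inference and surface Dice. Note that evaluation reuses the training slices, so the reported score is a training-set score.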


In [1]:
import torch
from PIL import Image
import numpy as np
from tqdm import tqdm
import os
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torch.utils.data import DataLoader
from surface_dice import SurfaceDiceMetric
import albumentations as A
import random
import segmentation_models_pytorch as smp
from patcher import Patcher

# Machine-specific configuration: the host named "gamma" reads data from a local
# relative path; every other host uses the Kaggle input mount.
hostname = os.uname().nodename
# Both roots MUST end with "/" because the paths below are built by plain string
# concatenation (input_dir + "train/", train_dir + "kidney_1_dense/...").
input_dir = "data/blood-vessel-segmentation/" if hostname == "gamma" else "/kaggle/input/blood-vessel-segmentation/"

device = "cuda" if torch.cuda.is_available() else "cpu"
train_dir = input_dir + "train/"

# Reproducibility: seed every RNG the notebook touches.
seed = 42
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)  # every visible GPU, not only the current device
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# NOTE: PYTHONHASHSEED is read at interpreter start-up, so this only affects
# processes started afterwards; it cannot change hashing in the running kernel.
os.environ["PYTHONHASHSEED"] = str(seed)

  from .autonotebook import tqdm as notebook_tqdm


## Load data

### Dataset

In [2]:
class KidneyDataset(torch.utils.data.Dataset):
    """Paired image/mask slices of one kidney volume.

    The image for index ``idx`` is read from ``imgs_dir + slices_ids[idx]`` and
    its mask from ``msks_dir + slices_ids[idx]``. Both directories therefore need
    files with identical names, and both prefixes must end with a path separator.

    Args:
        imgs_dir: directory prefix of the image slices.
        msks_dir: directory prefix of the mask slices.
        slices_ids: file names of the slices served by this dataset.
        transforms: optional callable invoked as ``transforms(image=img, mask=msk)``
            that returns a dict with "image" and "mask" keys (albumentations style).

    Attributes:
        h, w: height and width of the first slice. Callers presumably assume every
            slice has this size (TODO confirm for other kidneys).
    """

    def __init__(self, imgs_dir, msks_dir, slices_ids, transforms=None):
        self.imgs_dir = imgs_dir
        self.msks_dir = msks_dir
        self.slices_ids = slices_ids
        self.transforms = transforms
        # Open the first slice once and close it again; only its header is needed.
        with Image.open(imgs_dir + slices_ids[0]) as first:
            self.w, self.h = first.size  # PIL reports (width, height)

    def __len__(self):
        return len(self.slices_ids)

    def __getitem__(self, idx):
        """Return ``(img, msk)`` for slice ``idx``.

        ``img`` is a float32 tensor with a leading channel axis, divided by 31000.
        ``msk`` is a float32 tensor divided by 255 (a 0/255 mask becomes 0/1).
        """
        slice_id = self.slices_ids[idx]
        img_path = self.imgs_dir + slice_id
        msk_path = self.msks_dir + slice_id

        # np.array() loads the pixel data, so the files can be closed right away
        # instead of leaking one handle per sample until garbage collection.
        with Image.open(img_path) as im:
            img = np.array(im, dtype=np.float32)
        with Image.open(msk_path) as im:
            msk = np.array(im)

        if self.transforms is not None:
            t = self.transforms(image=img, mask=msk)
            img = t["image"]
            msk = t["mask"]

        # torch.from_numpy rejects negative-stride arrays (e.g. from flip augmentations).
        img = torch.from_numpy(np.ascontiguousarray(img))[None, :]
        msk = torch.as_tensor(msk, dtype=torch.float32)
        img /= 31000  # fixed intensity scale of this baseline (TODO confirm vs. data max)
        msk /= 255  # {0, 255} -> {0, 1} values

        return img, msk

In [3]:
# Image and label folders of kidney_1_dense; slice files share names across both.
imgs_dir, msks_dir = (f"{train_dir}kidney_1_dense/{sub}/" for sub in ("images", "labels"))
slices_ids = sorted(os.listdir(imgs_dir))

# Training sees random square crops of each slice; evaluation sees whole slices.
patch_size = 224
transforms = A.Compose([A.RandomCrop(patch_size, patch_size)])

# NOTE: both datasets read the SAME slices, so evaluation is on training data.
train_ds = KidneyDataset(imgs_dir, msks_dir, slices_ids, transforms=transforms)
eval_ds = KidneyDataset(imgs_dir, msks_dir, slices_ids)

print("Train Dataset length:", len(train_ds))
print("Eval Dataset length:", len(eval_ds))


Train Dataset length: 2279
Eval Dataset length: 2279


### Dataloaders

In [4]:
# Training loader: shuffled batches of random crops; worker processes are kept
# alive between epochs. The "gamma" host gets 8 workers, any other host gets 2.
train_dl = DataLoader(
    train_ds,
    shuffle=True,
    batch_size=32,
    persistent_workers=True,
    num_workers=(8 if hostname == "gamma" else 2),
)

# Evaluation loader: full slices in a fixed (sorted file name) order.
eval_dl = DataLoader(
    eval_ds,
    shuffle=False,
    batch_size=16,
    persistent_workers=False,
    num_workers=(8 if hostname == "gamma" else 2),
)

print("Train DataLoader length:", len(train_dl))
print("Eval DataLoader length:", len(eval_dl))

Train DataLoader length: 72
Eval DataLoader length: 143


## Define model

In [5]:
# U-Net with the timm "mobilenetv3_small_075" encoder, trained from scratch
# (no pretrained weights), taking 1 input channel and producing 1 logit map.
net = smp.Unet(
    encoder_name="timm-mobilenetv3_small_075",
    encoder_weights=None,
    in_channels=1,
    classes=1,
).to(device)  # nn.Module.to moves parameters in place and returns the module itself
print(f"Number of params: {sum(p.numel() for p in net.parameters()):,}")

Number of params: 2,881,625


## Train and evaluation pipeline

### Loss function

In [6]:
# Per-pixel binary cross-entropy on raw logits (sigmoid applied inside), averaged.
loss_fn = torch.nn.BCEWithLogitsLoss(reduction="mean")

In [7]:
# Reference point: a predictor outputting p = 0.5 everywhere has BCE -log(0.5) = ln 2.
print("Random loss:", torch.tensor(0.5).log().neg())

Random loss: tensor(0.6931)


### Optimizer (no learning-rate scheduler)

In [8]:
lr = 3e-4  # constant Adam learning rate
optimizer = torch.optim.Adam(net.parameters(), lr=lr)

### Train method

In [9]:
def train():
    """Run one training epoch over ``train_dl`` and return the mean batch loss.

    Uses the notebook-level ``net``, ``train_dl``, ``loss_fn``, ``optimizer`` and
    ``device``. The result is the unweighted mean of the per-batch losses, so a
    smaller final batch weighs as much as a full one.
    """
    net.train()
    total_loss = 0.0
    for x, y in tqdm(train_dl):
        x, y = x.to(device), y.to(device)
        # Drop only the channel axis (the model has classes=1). A bare .squeeze()
        # would also drop the batch axis when a batch holds a single sample, making
        # logits (H, W) while y stays (1, H, W) and the loss raise a size mismatch.
        logits = net(x).squeeze(1)
        loss = loss_fn(logits, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(train_dl)

### Extract and Merge Patches

In [10]:
# Full-slice inference splits each slice into patch_size x patch_size tiles and
# merges the tile predictions back. The size comes from the first eval slice;
# overlap=50 is passed to Patcher (presumably pixels shared by neighbouring
# tiles — see patcher.py).
h = eval_ds.h
w = eval_ds.w
overlap = 50
patcher = Patcher(h, w, patch_size=patch_size, overlap=overlap)

### Eval method

In [11]:
@torch.no_grad()
def eval(save_preds=False):
    """Evaluate ``net`` on every batch of ``eval_dl`` using patch-wise inference.

    Each slice is split into patches by ``patcher``. All patches of the batch run
    through ``net`` together, and the patch logits are merged back into full-slice
    logits. The loss uses ``loss_fn``. Predictions thresholded at probability 0.5
    are accumulated in a ``SurfaceDiceMetric``.

    NOTE: the name shadows the ``eval`` builtin; it is kept for existing callers.

    Args:
        save_preds: if True, write each slice's probability map to
            ``preds/NNNN.tif`` as uint16 (probability * 65535, truncated).
            Files are numbered in loader order. ``preds/`` is created if missing.

    Returns:
        ``(eval_loss, surface_dice)``: the mean per-batch loss and the value
        returned by the metric's ``compute()``.
    """
    net.eval()
    metric = SurfaceDiceMetric(n_batches=len(eval_dl), device=device)
    if save_preds:
        os.makedirs("preds", exist_ok=True)

    eval_loss = 0.0
    idx = 0
    for x, y in tqdm(eval_dl):
        bs = len(x)
        x, y = x.to(device), y.to(device)
        x = patcher.extract_patches(x)  # (bs, n_patches, h, w)

        logits = net(x.reshape(-1, 1, patch_size, patch_size))  # (bs * n_patches, 1, patch_size, patch_size)
        logits = logits.view(bs, -1, patch_size, patch_size)  # (bs, n_patches, patch_size, patch_size)
        # Match the (bs, H, W) target explicitly. A bare .squeeze() would also drop
        # the batch axis for a final batch of one slice and break the loss.
        logits = patcher.merge_patches(logits).reshape(y.shape)

        loss = loss_fn(logits, y)
        probs = logits.sigmoid()

        # save probabilities maps: move the whole batch to the host once, not once per slice
        if save_preds:
            maps = (probs.cpu() * (2**16 - 1)).numpy().astype(np.uint16)
            for prob_map in maps:
                Image.fromarray(prob_map).save(f"preds/{idx:04}.tif")
                idx += 1

        pred = torch.where(probs >= 0.5, 1, 0)

        metric.process_batch(pred, y)
        eval_loss += loss.item()

    eval_loss /= len(eval_dl)
    surface_dice = metric.compute()

    return eval_loss, surface_dice

In [12]:
%%time

epochs = 40
losses = []
dices = []

for epoch in range(1, epochs + 1):
    train_loss = train()
    print(f"EPOCH {epoch}, TLOSS {train_loss:.4f}")
    losses.append(train_loss)

eval_loss, surface_dice = eval(save_preds=False)
print("Evaluation")
print(f"ELOSS {eval_loss:.4f}, SURFACE_DICE {surface_dice:.4f}")

100%|██████████| 72/72 [00:08<00:00,  8.07it/s]


EPOCH 1, TLOSS 0.3802


100%|██████████| 72/72 [00:07<00:00, 10.17it/s]


EPOCH 2, TLOSS 0.1480


100%|██████████| 72/72 [00:06<00:00, 10.48it/s]


EPOCH 3, TLOSS 0.0872


100%|██████████| 72/72 [00:07<00:00, 10.12it/s]


EPOCH 4, TLOSS 0.0596


100%|██████████| 72/72 [00:06<00:00, 10.31it/s]


EPOCH 5, TLOSS 0.0480


100%|██████████| 72/72 [00:06<00:00, 10.37it/s]


EPOCH 6, TLOSS 0.0363


100%|██████████| 72/72 [00:06<00:00, 10.44it/s]


EPOCH 7, TLOSS 0.0313


100%|██████████| 72/72 [00:06<00:00, 10.29it/s]


EPOCH 8, TLOSS 0.0280


100%|██████████| 72/72 [00:07<00:00, 10.19it/s]


EPOCH 9, TLOSS 0.0245


100%|██████████| 72/72 [00:06<00:00, 10.39it/s]


EPOCH 10, TLOSS 0.0230


100%|██████████| 72/72 [00:06<00:00, 10.42it/s]


EPOCH 11, TLOSS 0.0200


100%|██████████| 72/72 [00:07<00:00, 10.17it/s]


EPOCH 12, TLOSS 0.0177


100%|██████████| 72/72 [00:06<00:00, 10.47it/s]


EPOCH 13, TLOSS 0.0173


100%|██████████| 72/72 [00:07<00:00, 10.25it/s]


EPOCH 14, TLOSS 0.0155


100%|██████████| 72/72 [00:07<00:00, 10.21it/s]


EPOCH 15, TLOSS 0.0152


100%|██████████| 72/72 [00:07<00:00, 10.21it/s]


EPOCH 16, TLOSS 0.0143


100%|██████████| 72/72 [00:07<00:00,  9.99it/s]


EPOCH 17, TLOSS 0.0131


100%|██████████| 72/72 [00:07<00:00, 10.04it/s]


EPOCH 18, TLOSS 0.0116


100%|██████████| 72/72 [00:07<00:00, 10.15it/s]


EPOCH 19, TLOSS 0.0107


100%|██████████| 72/72 [00:07<00:00,  9.95it/s]


EPOCH 20, TLOSS 0.0110


100%|██████████| 72/72 [00:07<00:00,  9.90it/s]


EPOCH 21, TLOSS 0.0102


100%|██████████| 72/72 [00:07<00:00, 10.24it/s]


EPOCH 22, TLOSS 0.0101


100%|██████████| 72/72 [00:07<00:00, 10.11it/s]


EPOCH 23, TLOSS 0.0096


100%|██████████| 72/72 [00:07<00:00, 10.19it/s]


EPOCH 24, TLOSS 0.0084


100%|██████████| 72/72 [00:07<00:00, 10.17it/s]


EPOCH 25, TLOSS 0.0095


100%|██████████| 72/72 [00:07<00:00,  9.89it/s]


EPOCH 26, TLOSS 0.0086


100%|██████████| 72/72 [00:07<00:00, 10.11it/s]


EPOCH 27, TLOSS 0.0091


100%|██████████| 72/72 [00:06<00:00, 10.30it/s]


EPOCH 28, TLOSS 0.0083


100%|██████████| 72/72 [00:07<00:00, 10.18it/s]


EPOCH 29, TLOSS 0.0082


100%|██████████| 72/72 [00:07<00:00, 10.17it/s]


EPOCH 30, TLOSS 0.0077


100%|██████████| 72/72 [00:06<00:00, 10.40it/s]


EPOCH 31, TLOSS 0.0076


100%|██████████| 72/72 [00:07<00:00, 10.10it/s]


EPOCH 32, TLOSS 0.0071


100%|██████████| 72/72 [00:07<00:00, 10.15it/s]


EPOCH 33, TLOSS 0.0071


100%|██████████| 72/72 [00:06<00:00, 10.53it/s]


EPOCH 34, TLOSS 0.0073


100%|██████████| 72/72 [00:06<00:00, 10.36it/s]


EPOCH 35, TLOSS 0.0070


100%|██████████| 72/72 [00:06<00:00, 10.66it/s]


EPOCH 36, TLOSS 0.0070


100%|██████████| 72/72 [00:06<00:00, 10.69it/s]


EPOCH 37, TLOSS 0.0063


100%|██████████| 72/72 [00:06<00:00, 10.53it/s]


EPOCH 38, TLOSS 0.0067


100%|██████████| 72/72 [00:07<00:00, 10.20it/s]


EPOCH 39, TLOSS 0.0068


100%|██████████| 72/72 [00:06<00:00, 10.43it/s]


EPOCH 40, TLOSS 0.0066


100%|██████████| 143/143 [01:18<00:00,  1.81it/s]

Evaluation
ELOSS 0.0042, SURFACE_DICE 0.7274
CPU times: user 5min 10s, sys: 38 s, total: 5min 48s
Wall time: 6min 2s





In [13]:
# Checkpoint named after the surface Dice from the evaluation above (measured on
# the training slices). torch.save does not create missing parent directories.
os.makedirs("checkpoints", exist_ok=True)
torch.save(net.state_dict(), f"checkpoints/baseline_train_sdc_{surface_dice:.3f}.pth")