# Baseline

- [ ] Evaluate on full-resolution images (evaluation currently runs on randomly cropped 224x224 training slices).
- [ ] Put the net in eval mode during evaluation.

In [None]:
import torch
from PIL import Image
from torchvision.transforms import v2
import torchvision.transforms as v1
from torch import nn
import numpy as np
from tqdm import tqdm
import random
import os
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from surface_dice import SurfaceDiceMetric
import wandb
import albumentations as A

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Competition data root; each kidney lives under <train_dir>/<kidney>/.
input_dir = "/kaggle/input/blood-vessel-segmentation/"
train_dir = f"{input_dir}train/"

# Reproducibility: seed the Python, NumPy and PyTorch (CPU + CUDA) RNGs and
# make cuDNN choose deterministic kernels instead of autotuning.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
# NOTE: hash randomization of the already-running interpreter is fixed at
# startup; setting PYTHONHASHSEED here only reaches child processes that start
# a fresh interpreter.
os.environ["PYTHONHASHSEED"] = str(seed)

## Load data

### Dataset

In [None]:
class KidneyDataset(torch.utils.data.Dataset):
    """Map-style dataset of 2D kidney slices paired with their masks.

    The image and mask of a slice share the same file name, ``slices_ids[idx]``,
    inside ``imgs_dir`` and ``msks_dir`` respectively.

    Every item is an ``(img, msk)`` pair of float32 tensors, whether or not
    ``transforms`` is given:

    * ``img`` gets a leading channel dimension and is divided by its own
      maximum (left unscaled when that maximum is not positive, e.g. an
      all-zero crop, instead of producing NaNs);
    * ``msk`` is divided by 255. This assumes masks are stored as 0/255 —
      TODO confirm for every kidney.

    Args:
        imgs_dir: Directory holding the image slices.
        msks_dir: Directory holding the mask slices.
        slices_ids: File names of the slices, in dataset order.
        transforms: Optional albumentations-style callable, invoked as
            ``transforms(image=..., mask=...)`` and returning a dict with
            ``"image"`` and ``"mask"`` numpy arrays.
    """

    def __init__(self, imgs_dir, msks_dir, slices_ids, transforms=None):
        self.imgs_dir = imgs_dir
        self.msks_dir = msks_dir
        self.slices_ids = slices_ids
        self.transforms = transforms

    def __len__(self):
        return len(self.slices_ids)

    @staticmethod
    def _to_tensors(img, msk):
        """Convert numpy image/mask arrays into the normalized float32 tensors."""
        # ascontiguousarray also copes with negative strides (e.g. after flips),
        # which torch.from_numpy rejects.
        img = torch.from_numpy(np.ascontiguousarray(img, dtype=np.float32))[None, :]
        msk = torch.as_tensor(msk, dtype=torch.float32)
        max_val = img.max()
        if max_val > 0:  # avoid 0/0 -> NaN on blank crops
            img = img / max_val
        msk = msk / 255
        return img, msk

    def __getitem__(self, idx):
        slice_id = self.slices_ids[idx]
        # os.path.join works with or without a trailing separator on the dirs.
        img_path = os.path.join(self.imgs_dir, slice_id)
        msk_path = os.path.join(self.msks_dir, slice_id)

        # Load the pixels and close the files right away instead of leaking
        # one open handle per sample.
        with Image.open(img_path) as img_file, Image.open(msk_path) as msk_file:
            img = np.array(img_file, dtype=np.float32)
            msk = np.array(msk_file)

        if self.transforms is not None:
            t = self.transforms(image=img, mask=msk)
            img, msk = t["image"], t["mask"]

        return self._to_tensors(img, msk)

In [None]:
# Only the densely annotated kidney_1 volume is used for now.
imgs_dir = f"{train_dir}kidney_1_dense/images/"
msks_dir = f"{train_dir}kidney_1_dense/labels/"
# Lexical order of the file names; assumes zero-padded slice names so this
# matches slice order — TODO confirm.
slices_ids = sorted(os.listdir(imgs_dir))

# Random 224x224 crop, applied jointly to image and mask by the dataset.
transforms = A.Compose([A.RandomCrop(height=224, width=224)])

ds = KidneyDataset(imgs_dir, msks_dir, slices_ids, transforms=transforms)
print("Dataset length:", len(ds))

In [None]:
bs = 32
# os.cpu_count() can return None; fall back to loading in the main process.
num_workers = os.cpu_count() or 0
train_dl = DataLoader(
    ds,
    batch_size=bs,
    num_workers=num_workers,
    # Kept unshuffled: eval(train_dl) feeds these batches to SurfaceDiceMetric,
    # which presumably expects consecutive slices of the volume — confirm
    # before enabling shuffling for training.
    shuffle=False,
    # DataLoader raises ValueError for persistent_workers=True with 0 workers.
    persistent_workers=num_workers > 0,
)
print("DataLoader length:", len(train_dl))

## Define model

In [None]:
import segmentation_models_pytorch as smp

net = smp.Unet(
    encoder_name="timm-mobilenetv3_small_075",
    encoder_weights=None,  # no pretrained weights: encoder starts from random init
    in_channels=1,  # single-channel slices
    classes=1,  # one output logit per pixel
)
print(f"Number of params: {sum(p.numel() for p in net.parameters()):,}")

In [None]:
# Module.to moves parameters in place and returns the same module.
net = net.to(device)

## Train model

### Loss function

#### Test loss function

In [None]:
# Dummy logits and binary targets for sanity-checking the loss function.
x = torch.randn(16, 512, 512)
y = torch.randint(0, 2, (16, 512, 512), dtype=torch.float32)

In [None]:
# Sanity-check how pos_weight can be passed. The first line is the unweighted
# loss. The next three pass pos_weight=3 as a 0-dim integer tensor, a 1-element
# vector and a (1, 1) tensor, to see whether these shapes give the same value.
print(torch.nn.functional.binary_cross_entropy_with_logits(x, y))
print(torch.nn.functional.binary_cross_entropy_with_logits(x, y, pos_weight=torch.tensor(3)))
print(torch.nn.functional.binary_cross_entropy_with_logits(x, y, pos_weight=torch.tensor([3])))
print(torch.nn.functional.binary_cross_entropy_with_logits(x, y, pos_weight=torch.tensor([3]).view(1, 1)))

In [None]:
# Hand-written reference for the values above: mean negative log-likelihood of
# the targets under sigmoid(x). The second line scales the positive-class term
# by 3 to mirror pos_weight=3. Note that log(sigmoid(x)) computed this way can
# underflow to -inf for large |x|; it is fine for the randn inputs used here.
print(-torch.where(y == 1, x.sigmoid().log(), (1 - x.sigmoid()).log()).mean())
print(-torch.where(y == 1, 3 * x.sigmoid().log(), (1 - x.sigmoid()).log()).mean())

#### Determine positive weight

In [None]:
# Estimate the class balance over the training loader (random crops, so this
# is a sample estimate). pos_weight = #negatives / #positives is the factor
# BCEWithLogitsLoss would use to up-weight the positive pixels.
total = 0
pos = 0
# Distinct loop name so the global `y` used by the loss-test cells is not
# clobbered by a mask batch.
for _, msk_batch in tqdm(train_dl):
    total += msk_batch.nelement()
    # Sum in float64: a float32 sum stops being exact past 2**24 elements.
    pos += msk_batch.sum(dtype=torch.float64).long().item()

if pos == 0:
    raise ValueError("No positive pixels in the training masks; pos_weight is undefined.")
pos_weight = (total - pos) / pos
pos_weight

In [None]:
# Plain (unweighted) BCE on raw logits, averaged over all pixels. Swap in the
# commented line to up-weight positive pixels with the pos_weight estimated
# above.
# loss_fn = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(pos_weight))
loss_fn = nn.BCEWithLogitsLoss(reduction="mean")

### Optimizer and scheduler

In [None]:
lr = 3e-4
# Adam over every parameter of the network.
optimizer = torch.optim.Adam(net.parameters(), lr=lr)

### Train method

In [None]:
def train():
    """Run one optimization epoch over ``train_dl``.

    Uses the notebook globals ``net``, ``train_dl``, ``loss_fn``,
    ``optimizer`` and ``device``.

    Returns:
        The mean of the per-batch losses, or ``nan`` if ``train_dl`` yields no
        batches.
    """
    net.train()
    total_loss = 0.0
    n_batches = 0
    for x, y in train_dl:
        x, y = x.to(device), y.to(device)
        # Drop only the channel dim of the (presumably (B, 1, H, W)) output.
        # A bare .squeeze() would also drop the batch dim of a size-1 batch
        # and break the shape match with y.
        logits = net(x).squeeze(1)
        loss = loss_fn(logits, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        n_batches += 1

    return total_loss / n_batches if n_batches else float("nan")

### Eval method

In [None]:
@torch.no_grad()
def eval(dl):
    """Evaluate ``net`` on ``dl`` without gradients.

    The net runs in eval mode, so BatchNorm uses its running statistics and
    does not update them with evaluation data. The caller's train/eval mode is
    restored afterwards. Uses the notebook globals ``net``, ``loss_fn``,
    ``device`` and ``SurfaceDiceMetric``. (The name shadows the builtin
    ``eval``; it is kept for existing callers.)

    Args:
        dl: DataLoader yielding ``(x, y)`` batches. Presumably it must yield
            consecutive slices for the surface Dice metric — TODO confirm.

    Returns:
        ``(eval_loss, surface_dice)``: the mean per-batch loss and the value of
        ``SurfaceDiceMetric.compute()``.
    """
    was_training = net.training
    net.eval()
    try:
        eval_loss = 0.0
        metric = SurfaceDiceMetric(n_batches=len(dl), device=device)
        for x, y in dl:
            x, y = x.to(device), y.to(device)
            # Squeeze only the channel dim so size-1 batches keep their batch dim.
            logits = net(x).squeeze(1)
            loss = loss_fn(logits, y)

            pred = torch.where(logits.sigmoid() >= 0.5, 1, 0)
            metric.process_batch(pred, y)
            eval_loss += loss.item()

        eval_loss /= len(dl)
        surface_dice = metric.compute()
    finally:
        net.train(was_training)

    return eval_loss, surface_dice

In [None]:
%%time
epochs = 10
losses, dices = [], []

for epoch in range(epochs):
    train_loss = train()
    # NOTE(review): evaluation reuses the training loader (random crops of the
    # training volume), so DICE/ELOSS are not held-out metrics.
    eval_loss, surface_dice = eval(train_dl)
    losses.append(train_loss)
    dices.append(surface_dice)
    print(f"EPOCH {epoch}, DICE {surface_dice:.4f} TLOSS {train_loss:.4f}, ELOSS {eval_loss:.4f}")

In [None]:
fig, axs = plt.subplots(nrows=2, ncols=1, figsize=(15, 7))
# Top panel: mean training loss per epoch.
axs[0].set_title("Training loss")
axs[0].plot(losses)
# Bottom panel: surface Dice per epoch.
axs[1].set_title("Surface Dice")
axs[1].plot(dices, color="orange")
plt.show()

## Inference on train dataset

In [None]:
# Run the net once on one training batch and report its loss with and without
# the estimated pos_weight.
x, y = next(iter(train_dl))
# Use the configured device rather than hard-coding CUDA (x.cuda() fails on
# CPU-only machines).
x = x.to(device)
# One forward pass without autograd bookkeeping, reused for every value below.
# NOTE(review): the net is in whatever mode it was last left in (train mode
# after the training loop) — switch to net.eval() if BatchNorm running stats
# should be used.
with torch.no_grad():
    logits = net(x).squeeze(1).cpu()  # squeeze(1): keep the batch dim
probs = logits.sigmoid()
print(
    f"loss with pos_weight: {torch.nn.functional.binary_cross_entropy_with_logits(logits, y, pos_weight=torch.tensor(pos_weight)).item():.5f}"
)
print(
    f"loss without pos_weight: {torch.nn.functional.binary_cross_entropy_with_logits(logits, y).item():.5f}"
)

In [None]:
# Pick a random sample from the batch. Use the real batch length rather than bs
# so this also works for a batch smaller than bs.
idx = torch.randint(len(y), (1,)).item()
threshold = 0.5
print(idx)

fig, axs = plt.subplots(1, 4, figsize=(15, 7))
axs[0].set_title("Image")
axs[1].set_title("Ground Truth")
axs[2].set_title("Probabilities")
axs[3].set_title(f"Prediction with p>={threshold}")
# Index before copying to CPU so only the displayed sample is transferred.
axs[0].imshow(x[idx].squeeze().cpu())
axs[1].imshow(y[idx])
axs[2].imshow(probs[idx])
# Threshold only the displayed sample, not the whole batch.
axs[3].imshow((probs[idx] >= threshold).long())
plt.show()