In [None]:
from monai.transforms import (
    LoadImaged,
    Compose,
    RandRotate90d,
    ScaleIntensityd,
    Resized,
)
from monai.data import PILReader
import numpy as np
from glob import glob
import os
import matplotlib.pyplot as plt
import torch
import monai

# more images is definitely better, but the training time is larger...
# TODO: add saving model states for long training times
MAX_IMAGES = 3000
TRAIN_RATIO = 0.7
VAL_RATIO = 0.2
# Fixed seed so the train/val/test split is identical across kernel restarts; without it,
# reloading a saved model and evaluating on a freshly drawn "test" split would leak
# training images into the test set.
SPLIT_SEED = 42
IMAGE_SIZE = (256, 256)

imgs = sorted(glob(os.path.join("data", "imgs", "*.png")))[:MAX_IMAGES]
masks = sorted(glob(os.path.join("data", "masks", "*.png")))[:MAX_IMAGES]

# Images and masks are paired purely by sorted position, so the counts must agree:
# otherwise zip() silently drops files and random_split fails on the size mismatch.
assert imgs, "no images found under data/imgs"
assert len(imgs) == len(masks), f"{len(imgs)} images but {len(masks)} masks"

n = len(imgs)

print(f"Number of images: {n}")

train_size = int(TRAIN_RATIO * n)
val_size = int(VAL_RATIO * n)
# The test split absorbs the rounding remainder, so the three sizes always sum to n.
test_size = n - train_size - val_size

files = [{"img": img, "mask": mask} for img, mask in zip(imgs, masks)]

train_files, val_files, test_files = torch.utils.data.random_split(
    files,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)


def _to_grayscale(image):
    """PIL converter: collapse the loaded image to single-channel 8-bit grayscale ("L")."""
    return image.convert("L")


def _load_pair():
    """Build the loader shared by every pipeline for the "img"/"mask" PNG pair.

    Each file is read with PIL, converted to grayscale, given a leading channel axis
    and cast to float32.
    """
    return LoadImaged(
        keys=["img", "mask"],
        image_only=True,
        ensure_channel_first=True,
        # reverse_indexing=False keeps PIL/numpy (row, column) axis order instead of
        # MONAI's default transpose.
        reader=PILReader(converter=_to_grayscale, reverse_indexing=False),
        dtype=torch.float,
    )


# Label masks must not be blended across pixels when resizing, so images keep "area"
# interpolation (the Resized default) while masks use "nearest".
_RESIZE_MODES = ("area", "nearest")

# NOTE(review): min-max scaling a constant mask (e.g. entirely foreground) maps it to all
# zeros in MONAI; confirm masks are never uniform, or switch masks to fixed-range scaling.
train_transform = Compose(
    [
        _load_pair(),
        ScaleIntensityd(keys=["img", "mask"], minv=0, maxv=1),
        # A single random draw per sample is applied to both keys, so img/mask stay aligned.
        RandRotate90d(keys=["img", "mask"], prob=0.5, spatial_axes=(0, 1)),
        Resized(keys=["img", "mask"], spatial_size=IMAGE_SIZE, mode=_RESIZE_MODES),
    ]
)

# Deterministic pipeline for validation and test: identical to training minus augmentation.
val_transform = Compose(
    [
        _load_pair(),
        ScaleIntensityd(keys=["img", "mask"], minv=0, maxv=1),
        Resized(keys=["img", "mask"], spatial_size=IMAGE_SIZE, mode=_RESIZE_MODES),
    ]
)

train_dataset = monai.data.Dataset(data=train_files, transform=train_transform)
val_dataset = monai.data.Dataset(data=val_files, transform=val_transform)
test_dataset = monai.data.Dataset(data=test_files, transform=val_transform)

In [None]:
import torch
import gc
from tqdm import trange

from monai.networks.nets import UNet
from monai.losses import DiceLoss
import matplotlib.pyplot as plt
import monai

# from resunet2 import UNet
from resunet import ResUNet

# Release memory left over from a previous run of this cell before building a new model.
gc.collect()
torch.cuda.empty_cache()  # does nothing if CUDA was never initialised

# Inputs all have the same spatial size, so let cuDNN benchmark and cache the fastest kernels.
torch.backends.cudnn.benchmark = True

plt.set_cmap("gray")

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
lr = 0.010
# TODO: 2000 epochs
nepochs = 400
bs = 32

# https://docs.monai.io/en/stable/networks.html#unet
# It is an enhanced version of unet with residual units implemented
# TODO: write from scratch
model = UNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# TODO: test BCEWithLogitsLoss as loss and dice as evaluation metric
# sigmoid=True: the network emits raw logits and DiceLoss applies the sigmoid itself.
criterion = DiceLoss(sigmoid=True)

train_dataloader = monai.data.DataLoader(
    train_dataset, batch_size=bs, shuffle=True, pin_memory=True, num_workers=8
)
val_dataloader = monai.data.DataLoader(
    val_dataset, batch_size=bs, shuffle=False, pin_memory=True, num_workers=8
)
test_dataloader = monai.data.DataLoader(
    test_dataset, batch_size=bs, shuffle=False, pin_memory=True, num_workers=8
)

# Per-epoch mean training loss, and mean validation loss recorded every `val_interval` epochs
# (after 0-based epochs val_interval-1, 2*val_interval-1, ...).
train_losses = []
val_losses = []
val_interval = 2
for epoch in (pbar := trange(nepochs)):
    model.train()
    train_loss_sum = 0.0
    train_samples = 0
    for i, batch_data in enumerate(train_dataloader):
        img, mask = batch_data["img"].to(device), batch_data["mask"].to(device)
        optimizer.zero_grad(set_to_none=True)
        output = model(img)
        loss = criterion(output, mask)
        loss.backward()
        optimizer.step()
        curr_loss = loss.item()
        # The loss is a batch mean; weight it by batch size so a short final batch does
        # not count as much as a full one in the epoch average.
        train_loss_sum += curr_loss * img.shape[0]
        train_samples += img.shape[0]
        pbar.set_description(
            f"Epoch {epoch+1}/{nepochs}, iteration {i+1}/{len(train_dataloader)}, loss: {curr_loss}"
        )
    train_losses.append(train_loss_sum / train_samples)

    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            val_loss_sum = 0.0
            val_samples = 0
            for i, batch_data in enumerate(val_dataloader):
                img, mask = batch_data["img"].to(device), batch_data["mask"].to(device)
                output = model(img)
                loss = criterion(output, mask)
                curr_loss = loss.item()
                val_loss_sum += curr_loss * img.shape[0]
                val_samples += img.shape[0]
                pbar.set_description(
                    f"Epoch {epoch+1}/{nepochs}, iteration {i+1}/{len(val_dataloader)}, loss: {curr_loss}"
                )
            val_loss_mean = val_loss_sum / val_samples
            val_losses.append(val_loss_mean)

In [None]:
import matplotlib.pyplot as plt

# Validation runs only every `val_interval` epochs, so place each validation loss at the
# 0-based epoch where it was measured rather than at its list index. Otherwise the val curve
# is squeezed into the left part of the axis and misaligned with the train curve.
val_epochs = [(k + 1) * val_interval - 1 for k in range(len(val_losses))]

plt.plot(train_losses, label="train")
plt.plot(val_epochs, val_losses, label="val")
plt.xlabel("Epoch")
plt.ylabel("Dice loss")
plt.title("Training and validation loss")
plt.legend()
plt.show()

In [None]:
# Persist only the learned parameters (state_dict); the architecture is rebuilt in code
# before these weights are loaded back.
filename = "./model-monai-400epoch-3000img-256x256.pth"
torch.save(obj=model.state_dict(), f=filename)

In [None]:
import torch
from monai.networks.nets import UNet

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

filename = "./model-monai-400epoch-3000img-256x256.pth"

# Must match the training architecture exactly; otherwise load_state_dict reports
# missing/unexpected keys or shape mismatches.
model = UNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
).to(device)

# map_location remaps tensors saved on the training device (e.g. cuda:0) onto the current
# device, so a GPU-trained checkpoint also loads on a CPU-only machine.
# torch.load unpickles the file: only load checkpoints from a trusted source.
model.load_state_dict(torch.load(filename, map_location=device))

In [None]:
import torch
import matplotlib.pyplot as plt
from monai.metrics import compute_iou, compute_dice
from util import recall, precision


plt.set_cmap("gray")


# Per-batch mean metrics (kept per batch for the scatter plot below) plus each batch's size,
# so dataset-level averages can weight every image equally despite a short final batch.
dices = []
ious = []
recalls = []
precisions = []
batch_sizes = []


model.eval()
with torch.no_grad():
    for batch_data in test_dataloader:
        img, mask = batch_data["img"].to(device), batch_data["mask"].to(device)
        # compute_dice / compute_iou expect binarized tensors: threshold the sigmoid
        # probabilities at 0.5 and treat any non-zero mask pixel as foreground.
        output = (torch.sigmoid(model(img)) > 0.5).float()
        mask = (mask > 0.0).float()
        dice_coeff = compute_dice(output, mask, ignore_empty=False).mean(dim=0)
        iou = compute_iou(output, mask, ignore_empty=False).mean(dim=0)
        # util.recall / util.precision take (target, prediction), as in the original call;
        # presumably they return per-sample values that .mean() averages — confirm in util.
        r = recall(mask, output).mean()
        p = precision(mask, output).mean()
        dices.append(dice_coeff.item())
        ious.append(iou.item())
        recalls.append(r.item())
        precisions.append(p.item())
        batch_sizes.append(img.shape[0])

print(f"Average iou: {np.average(ious, weights=batch_sizes):.2f}")
print(f"Average dice coefficient: {np.average(dices, weights=batch_sizes):.2f}")
print(f"Average recall: {np.average(recalls, weights=batch_sizes):.2f}")
print(f"Average precision: {np.average(precisions, weights=batch_sizes):.2f}")
plt.scatter(np.arange(len(dices)), dices, s=3, c="red")
plt.xlabel("Test batch")
plt.ylabel("Mean dice coefficient")
plt.title("Per-batch mean dice coefficient on the test set")
plt.show()