In [None]:
from monai.transforms import LoadImaged, Compose, RandRotate90d, ScaleIntensityd
from monai.data import PILReader
import numpy as np
from glob import glob
import os
import matplotlib.pyplot as plt
import torch
import monai

# More image/mask pairs give a better model, but training time grows with them,
# so cap how many pairs are used.
# TODO: add saving model states for long training times
MAX_IMAGES = 1500
TRAIN_RATIO = 0.7
VAL_RATIO = 0.2
# Fixed seed so the train/val/test split is identical on every run. Without it,
# a checkpoint saved in one session can later be "tested" on images it was
# trained on.
SPLIT_SEED = 42

all_imgs = sorted(glob(os.path.join("data", "imgs", "*.png")))
all_masks = sorted(glob(os.path.join("data", "masks", "*.png")))
# Images and masks are paired purely by sorted position, so both folders must
# hold the same number of files; otherwise zip() below silently mispairs them.
assert len(all_imgs) == len(all_masks), (
    f"found {len(all_imgs)} images but {len(all_masks)} masks"
)
imgs = all_imgs[:MAX_IMAGES]
masks = all_masks[:MAX_IMAGES]

n = len(imgs)

print(f"Number of images: {n}")

train_size = int(TRAIN_RATIO * n)
val_size = int(VAL_RATIO * n)
# The test split takes whatever remains, so the three sizes always sum to n.
test_size = n - train_size - val_size

files = [{"img": img, "mask": mask} for img, mask in zip(imgs, masks)]

train_files, val_files, test_files = torch.utils.data.random_split(
    files,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)


def _to_grayscale(image):
    """Convert a PIL image to single-channel 8-bit grayscale (PIL mode "L")."""
    return image.convert("L")


def _make_load_transform():
    """Build the LoadImaged step shared by the training and evaluation pipelines.

    Reads both the "img" and "mask" files through PIL as grayscale, requests a
    channel-first layout, and casts the data to float.
    """
    return LoadImaged(
        keys=["img", "mask"],
        image_only=True,
        ensure_channel_first=True,
        reader=PILReader(converter=_to_grayscale, reverse_indexing=False),
        dtype=torch.float,
    )


# Training pipeline: load, rescale masks to [0, 1], then randomly rotate
# image and mask together by multiples of 90 degrees (augmentation).
train_transform = Compose(
    [
        _make_load_transform(),
        ScaleIntensityd(keys=["mask"], minv=0, maxv=1),
        RandRotate90d(keys=["img", "mask"], prob=0.5, spatial_axes=(0, 1)),
    ]
)

# Validation/test pipeline: identical to training but without augmentation.
val_transform = Compose(
    [
        _make_load_transform(),
        ScaleIntensityd(keys=["mask"], minv=0, maxv=1),
    ]
)

train_dataset = monai.data.Dataset(data=train_files, transform=train_transform)
val_dataset = monai.data.Dataset(data=val_files, transform=val_transform)
test_dataset = monai.data.Dataset(data=test_files, transform=val_transform)

In [None]:
import torch
import gc
from tqdm import trange
from monai.networks.nets import UNet
from monai.losses import DiceLoss
import monai

# Run the garbage collector and release unused cached CUDA memory before
# (re)building the model. No torch.no_grad() is needed for this call.
gc.collect()
torch.cuda.empty_cache()

plt.set_cmap("gray")

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
lr = 0.001
# TODO: 2000 epochs
nepochs = 10
bs = 12
# Seed the global torch RNG before the model and loaders are created so that
# weight initialisation and DataLoader shuffling are reproducible.
TRAIN_SEED = 0
torch.manual_seed(TRAIN_SEED)

# https://docs.monai.io/en/stable/networks.html#unet
# MONAI's UNet variant; num_res_units=2 adds residual units to each layer.
# NOTE(review): this architecture is duplicated in the model-reload cell; keep both in sync.
model = UNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# sigmoid=True: the network outputs raw logits and the loss applies the sigmoid.
criterion = DiceLoss(sigmoid=True)

# NOTE(review): test_dataloader is created here, so the evaluation cell depends
# on this (training) cell having been run.
train_dataloader = monai.data.DataLoader(train_dataset, batch_size=bs, shuffle=True)
val_dataloader = monai.data.DataLoader(val_dataset, batch_size=bs, shuffle=False)
test_dataloader = monai.data.DataLoader(test_dataset, batch_size=bs, shuffle=False)

# Mean per-batch Dice loss for each epoch.
train_losses = []
val_losses = []
for epoch in (pbar := trange(nepochs)):
    # --- training pass ---
    model.train()
    train_loss = 0.0
    for i, batch_data in enumerate(train_dataloader):
        img, mask = batch_data["img"].to(device), batch_data["mask"].to(device)
        optimizer.zero_grad()
        output = model(img)
        loss = criterion(output, mask)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        pbar.set_description(
            f"Epoch {epoch+1}/{nepochs}, iteration {i+1}/{len(train_dataloader)}, loss: {loss.item()}"
        )
    train_losses.append(train_loss / len(train_dataloader))

    # --- validation pass (no gradient tracking) ---
    model.eval()
    with torch.no_grad():
        val_loss = 0.0
        for i, batch_data in enumerate(val_dataloader):
            img, mask = batch_data["img"].to(device), batch_data["mask"].to(device)
            output = model(img)
            loss = criterion(output, mask)
            val_loss += loss.item()
            pbar.set_description(
                f"Epoch {epoch+1}/{nepochs}, iteration {i+1}/{len(val_dataloader)}, loss: {loss.item()}"
            )
        val_losses.append(val_loss / len(val_dataloader))

In [None]:
import matplotlib.pyplot as plt

# Mean Dice loss per epoch on the training and validation sets.
fig, ax = plt.subplots()
ax.plot(train_losses, label="train")
ax.plot(val_losses, label="val")
ax.set(title="Dice loss per epoch", xlabel="Epoch", ylabel="Mean Dice loss")
ax.legend()
plt.show()

In [None]:
# Persist only the learned parameters (state_dict), not the pickled module.
# NOTE(review): "10epoch"/"11603" are hard-coded rather than derived from
# nepochs / the number of images used — confirm they describe this run.
checkpoint_path = "./model-monai-10epoch-11603.pth"
weights = model.state_dict()
torch.save(weights, checkpoint_path)

In [None]:
import torch
from monai.networks.nets import UNet

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Must match the architecture the checkpoint was trained with (see training cell).
model = UNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
).to(device)

# map_location remaps tensors saved on a GPU onto `device`, so a CUDA-trained
# checkpoint can also be loaded on a CPU-only machine.
# NOTE(review): this loads a different file than the one saved by the previous
# cell ("model-monai-10epoch-11603.pth") — confirm which checkpoint is intended.
model.load_state_dict(
    torch.load("./model-monai-200epoch-1500img.pth", map_location=device)
)

In [None]:
import torch
import matplotlib.pyplot as plt
from monai.metrics import DiceMetric, compute_iou


plt.set_cmap("gray")

model.eval()

# Per-batch mean scores (one entry per test batch), used for the scatter plot.
dices = []
ious = []
rocs = []  # NOTE(review): never filled in the visible code
# Per-sample scores, so the overall averages are not skewed by a smaller last batch.
sample_dices = []
sample_ious = []

dice_metric = DiceMetric(ignore_empty=False)

with torch.no_grad():
    for i, batch_data in enumerate(test_dataloader):
        img, mask = batch_data["img"].to(device), batch_data["mask"].to(device)
        logits = model(img)
        # The model was trained with DiceLoss(sigmoid=True), so its outputs are
        # logits; convert to probabilities before thresholding at 0.5.
        output = (torch.sigmoid(logits) >= 0.5).float()
        batch_dice = dice_metric(output, mask)  # one score per (sample, channel)
        batch_iou = compute_iou(output, mask, ignore_empty=False)
        dices.append(batch_dice.mean(dim=0).item())
        ious.append(batch_iou.mean(dim=0).item())
        sample_dices.extend(batch_dice.flatten().tolist())
        sample_ious.extend(batch_iou.flatten().tolist())


# Indices of the best and worst test batches by mean Dice.
print(np.argmax(dices), np.argmin(dices))
print(f"Average iou: {(sum(sample_ious) / len(sample_ious)):.2f}")
print(f"Average dice coefficient: {(sum(sample_dices) / len(sample_dices)):.2f}")
plt.scatter(np.arange(len(dices)), dices, s=3, c="red")
plt.xlabel("Test batch")
plt.ylabel("Mean dice coefficient")