## Wprowadzenie


## Preprocessing


In [None]:
from glob import glob
import os
import torch


# Upper bound on how many image/mask pairs to use.
n_images = 3000
# Fixed seed so the train/val/test split is reproducible across kernel restarts:
# the evaluation cells reload a saved checkpoint and must see the same held-out files.
SPLIT_SEED = 42

imgs = sorted(glob(os.path.join("data", "imgs", "*.png")))[:n_images]
masks = sorted(glob(os.path.join("data", "masks", "*.png")))[:n_images]
print(len(imgs), len(masks))

# Images and masks are paired purely by sorted filename order, so both
# directories must yield the same number of files.
if len(imgs) != len(masks):
    raise ValueError(
        f"Found {len(imgs)} images but {len(masks)} masks; they cannot be paired."
    )
if not imgs:
    raise FileNotFoundError("No PNG files found under data/imgs and data/masks.")

files = [{"img": img, "mask": mask} for img, mask in zip(imgs, masks)]
# Split sizes are computed from the pairs actually found, which may be fewer
# than n_images.
n_files = len(files)

TRAIN_RATIO = 0.7
VAL_RATIO = 0.2

train_size = int(TRAIN_RATIO * n_files)
val_size = int(VAL_RATIO * n_files)
# The test split absorbs the rounding remainder, so the sizes always sum to n_files.
test_size = n_files - train_size - val_size

train_files, val_files, test_files = torch.utils.data.random_split(
    files,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

## Augmentacja danych


In [None]:
from monai.transforms import (
    LoadImaged,
    Compose,
    RandRotate90d,
    ScaleIntensityd,
    Resized,
)
from monai.data import PILReader

# Both pipelines operate on the same dictionary keys and output size.
_KEYS = ["img", "mask"]
_SPATIAL_SIZE = (256, 256)
# Per-key interpolation for Resized (order matches _KEYS): "area" resampling for
# the image, "nearest" for the mask so labels stay binary instead of being
# blurred into fractional values along object boundaries.
_RESIZE_MODES = ("area", "nearest")


def _to_grayscale(image):
    """Convert a PIL image to single-channel grayscale (PIL mode "L").

    A named function pickles by reference, unlike a lambda. This matters if the
    pipeline is moved into an importable module and DataLoader workers are
    started with the "spawn" method.
    """
    return image.convert("L")


def _make_loader():
    """Return the LoadImaged step shared by the training and validation pipelines."""
    return LoadImaged(
        keys=_KEYS,
        image_only=True,
        ensure_channel_first=True,
        reader=PILReader(converter=_to_grayscale, reverse_indexing=False),
        dtype=torch.float,
    )


# NOTE(review): min-max scaling also rescales the mask; a constant mask
# (entirely background or entirely foreground) comes out as all zeros.
# Confirm no mask is fully foreground.
train_transform = Compose(
    [
        _make_loader(),
        ScaleIntensityd(keys=_KEYS, minv=0, maxv=1),
        # Dictionary transform: the same random rotation is applied to image and mask.
        RandRotate90d(keys=_KEYS, prob=0.5, spatial_axes=(0, 1)),
        Resized(keys=_KEYS, spatial_size=_SPATIAL_SIZE, mode=_RESIZE_MODES),
    ]
)

# Deterministic pipeline (no augmentation) used for validation and testing.
val_transform = Compose(
    [
        _make_loader(),
        ScaleIntensityd(keys=_KEYS, minv=0, maxv=1),
        Resized(keys=_KEYS, spatial_size=_SPATIAL_SIZE, mode=_RESIZE_MODES),
    ]
)

## Stworzenie datasetów


In [None]:

import monai

# Build the three datasets. Only the training split gets the augmenting
# pipeline; validation and test share the deterministic one.
train_dataset, val_dataset, test_dataset = (
    monai.data.Dataset(data=split, transform=pipeline)
    for split, pipeline in (
        (train_files, train_transform),
        (val_files, val_transform),
        (test_files, val_transform),
    )
)

In [None]:
import gc

# Collect unreachable Python objects first, so tensors they referenced can be
# released, then hand PyTorch's cached-but-unused GPU blocks back to the driver.
# No torch.no_grad() is needed here: empty_cache() does not involve autograd.
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

## Parametry


In [None]:
from monai.networks.nets import UNet
from monai.losses import DiceLoss

from resunet import ResUNet

# Let cuDNN benchmark and cache the fastest convolution algorithms
# (inputs are resized to a fixed 256x256 by the transforms).
torch.backends.cudnn.benchmark = True

if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

lr = 1e-3  # Adam learning rate
nepochs = 400  # number of training epochs
bs = 2  # batch size for every dataloader

# 2D UNet with residual units: one input channel (grayscale), one output
# channel holding raw logits for the foreground class.
model = UNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
)
model = model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# The loss applies the sigmoid itself, which is why the network outputs logits.
criterion = DiceLoss(sigmoid=True)

## Dataloadery


In [None]:
import monai

# Settings shared by all three loaders. Pinned host memory only helps when
# batches are copied to a CUDA device (and warns otherwise). The worker count
# is capped at the number of available CPUs.
_loader_kwargs = dict(
    batch_size=bs,
    pin_memory=torch.cuda.is_available(),
    num_workers=min(8, os.cpu_count() or 1),
)

# Only the training split is shuffled; validation/test iterate in a fixed order.
train_dataloader = monai.data.DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_dataloader = monai.data.DataLoader(val_dataset, shuffle=False, **_loader_kwargs)
test_dataloader = monai.data.DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

## Trening


In [None]:
import torch
from tqdm import trange

import matplotlib.pyplot as plt
from pathlib import Path


# Loss history: one mean training loss per epoch, and one mean validation loss
# every `val_interval` epochs (so the two lists have different lengths).
train_losses = []
val_losses = []
val_interval = 2  # run a validation pass every N epochs
save_interval = 10  # write a checkpoint every N epochs
save_dir = "model"
Path(save_dir).mkdir(exist_ok=True)

n_train_batches = len(train_dataloader)
n_val_batches = len(val_dataloader)

for epoch in (pbar := trange(nepochs)):
    # Training pass.
    model.train()
    train_loss = 0.0
    for i, batch_data in enumerate(train_dataloader):
        img, mask = batch_data["img"].to(device), batch_data["mask"].to(device)
        optimizer.zero_grad(set_to_none=True)
        output = model(img)
        loss = criterion(output, mask)
        loss.backward()
        optimizer.step()
        curr_loss = loss.item()
        train_loss += curr_loss
        pbar.set_description(
            f"Epoch {epoch+1}/{nepochs}, train iteration {i+1}/{n_train_batches}, loss: {curr_loss:.4f}"
        )
    train_losses.append(train_loss / n_train_batches)

    # Periodic validation pass in eval mode, without gradient tracking.
    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            val_loss = 0.0
            for i, batch_data in enumerate(val_dataloader):
                img, mask = batch_data["img"].to(device), batch_data["mask"].to(device)
                output = model(img)
                loss = criterion(output, mask)
                curr_loss = loss.item()
                val_loss += curr_loss
                pbar.set_description(
                    f"Epoch {epoch+1}/{nepochs}, val iteration {i+1}/{n_val_batches}, loss: {curr_loss:.4f}"
                )
            val_loss_mean = val_loss / n_val_batches
            val_losses.append(val_loss_mean)

    # Keep the latest epoch-level means visible next to the per-iteration text.
    postfix = {"train_loss": f"{train_losses[-1]:.4f}"}
    if val_losses:
        postfix["val_loss"] = f"{val_losses[-1]:.4f}"
    pbar.set_postfix(postfix)

    if (epoch + 1) % save_interval == 0:
        torch.save(model.state_dict(), os.path.join(save_dir, f"model_{epoch+1}.pth"))

## Wykresy strat


In [None]:
import matplotlib.pyplot as plt

# Epochs are 1-based in the training loop. Validation only ran every
# `val_interval` epochs, so its x-values are spaced accordingly.
train_epochs = range(1, len(train_losses) + 1)
val_epochs = range(val_interval, val_interval * len(val_losses) + 1, val_interval)

# Use an explicit figure per plot. On a non-interactive backend plt.show() may
# leave the current figure open, and the validation curve would otherwise be
# drawn on top of the training curve and saved into val_loss.png as well.
plt.figure()
plt.plot(train_epochs, train_losses, label="train")
plt.title("Train loss")
plt.ylabel("Loss")
plt.xlabel("Epoch")
plt.legend()
plt.savefig("train_loss.png")
plt.show()

plt.figure()
plt.plot(val_epochs, val_losses, label="val", color="orange")
plt.title("Validation loss")
plt.ylabel("Loss")
plt.xlabel("Epoch")
plt.legend()
plt.savefig("val_loss.png")
plt.show()

In [None]:
# Persist the weights from the final training epoch, not necessarily the epoch
# with the lowest validation loss. The testing section reloads this same path.
filename = "./model-monai-400epoch-3000img-256x256v2.pth"
torch.save(obj=model.state_dict(), f=filename)

## Testowanie


In [None]:
import torch
from monai.networks.nets import UNet

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Alternative checkpoint: "./models/best-model/model-monai-400epoch-3000img-256x256.pth"
filename = "./model-monai-400epoch-3000img-256x256v2.pth"

# The architecture must match the one used for training exactly; otherwise
# load_state_dict() fails on missing/unexpected keys or shape mismatches.
model = UNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
).to(device)

# map_location remaps every stored tensor onto the current device, so a
# checkpoint written on a GPU also loads on a CPU-only machine.
# NOTE(review): on torch >= 1.13 consider weights_only=True for untrusted files.
model.load_state_dict(torch.load(filename, map_location=device))

In [None]:
from monai.metrics import compute_iou, compute_dice
import matplotlib.pyplot as plt
import numpy as np

from util import recall, precision

# Per-batch mean of each metric (one entry per test batch); `dices` also feeds
# the scatter plot below.
dices = []
ious = []
recalls = []
precisions = []
# Samples per batch, used to weight the overall averages so every test image
# counts equally even when the final batch is smaller than the others.
batch_sizes = []

model.eval()
with torch.no_grad():
    for batch_data in test_dataloader:
        img, mask = batch_data["img"].to(device), batch_data["mask"].to(device)
        # Binarize the prediction at probability 0.5 and the ground truth at any
        # positive value, keeping a float dtype for the metric functions.
        output = (torch.sigmoid(model(img)) > 0.5).float()
        mask = (mask > 0.0).float()
        dice_coeff = compute_dice(output, mask, ignore_empty=False).mean(dim=0)
        iou = compute_iou(output, mask, ignore_empty=False).mean(dim=0)
        # NOTE(review): assumes util.recall/precision take (ground_truth, prediction)
        # and return per-sample (or batch-mean) values. Confirm against util.py.
        r = recall(mask, output).mean()
        p = precision(mask, output).mean()
        dices.append(dice_coeff.item())
        ious.append(iou.item())
        recalls.append(r.item())
        precisions.append(p.item())
        batch_sizes.append(img.shape[0])

if not batch_sizes:
    raise RuntimeError("test_dataloader produced no batches; nothing to evaluate.")

print(f"Average iou: {np.average(ious, weights=batch_sizes):.2f}")
print(f"Average dice coefficient: {np.average(dices, weights=batch_sizes):.2f}")
print(f"Average recall: {np.average(recalls, weights=batch_sizes):.2f}")
print(f"Average precision: {np.average(precisions, weights=batch_sizes):.2f}")

# Fresh figure so the scatter is not drawn onto a still-open loss plot.
plt.figure()
plt.scatter(np.arange(len(dices)), dices, s=20, c="red")
plt.xlabel("Test batch")
plt.ylabel("Mean dice coefficient")
plt.title("Mean dice coefficient for each test batch")
plt.savefig("dice_metric.png")