In [4]:
!pip install opencv-python
!pip install ipywidgets

Defaulting to user installation because normal site-packages is not writeable
Defaulting to user installation because normal site-packages is not writeable
Collecting ipywidgets
  Downloading ipywidgets-8.1.5-py3-none-any.whl (139 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m139.8/139.8 KB[0m [31m3.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting widgetsnbextension~=4.0.12
  Downloading widgetsnbextension-4.0.13-py3-none-any.whl (2.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.3/2.3 MB[0m [31m17.6 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hCollecting jupyterlab-widgets~=3.0.12
  Downloading jupyterlab_widgets-3.0.13-py3-none-any.whl (214 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m214.4/214.4 KB[0m [31m8.3 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: widgetsnbextension, jupyterlab-widgets, ipywidgets
Successfully installed ipywidgets-8.1.5 jupyterlab-widgets-3.0.13 widgetsnbextension

In [1]:
import logging
import os
import sys
import tempfile
from glob import glob

import torch
from PIL import Image
from torch.utils.tensorboard import SummaryWriter

import monai
from monai.data import ArrayDataset, create_test_image_2d, decollate_batch, DataLoader
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric
from monai.transforms import (
    Activations,
    AsDiscrete,
    Compose,
    LoadImage,
    RandRotate90,
    RandSpatialCrop,
    ScaleIntensity,
)
from monai.visualize import plot_2d_or_3d_image


In [2]:
def generate_data(tempdir):
    """Write 40 synthetic image/segmentation PNG pairs into ``tempdir``.

    Every pair is produced by ``create_test_image_2d(128, 128, num_seg_classes=1)``.
    Both arrays are multiplied by 255 and cast to ``uint8`` before being saved
    as ``img<i>.png`` and ``seg<i>.png`` respectively.

    Args:
        tempdir: Directory that receives the generated PNG files.
    """

    def _write_png(array, filename):
        # Multiply by 255 and cast to uint8 before handing the array to PIL.
        as_uint8 = (array * 255).astype("uint8")
        target = os.path.join(tempdir, filename)
        Image.fromarray(as_uint8).save(target)

    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(40):
        image_arr, mask_arr = create_test_image_2d(128, 128, num_seg_classes=1)
        _write_png(image_arr, f"img{i:d}.png")
        _write_png(mask_arr, f"seg{i:d}.png")

In [3]:
def setup_data_loaders(images, segs, train_imtrans, train_segtrans, val_imtrans, val_segtrans):
    """Build the training and validation data loaders.

    The first 20 entries of ``images``/``segs`` form the training set and the
    last 20 entries form the validation set.

    Args:
        images: Sequence of image file paths.
        segs: Sequence of segmentation file paths, ordered like ``images``.
        train_imtrans: Transform applied to training images.
        train_segtrans: Transform applied to training segmentations.
        val_imtrans: Transform applied to validation images.
        val_segtrans: Transform applied to validation segmentations.

    Returns:
        Tuple ``(train_loader, val_loader)``.
    """
    training_set = ArrayDataset(
        img=images[:20],
        img_transform=train_imtrans,
        seg=segs[:20],
        seg_transform=train_segtrans,
    )
    # Training batches hold 4 shuffled samples; memory is pinned only when CUDA is available.
    train_loader = DataLoader(
        training_set,
        batch_size=4,
        shuffle=True,
        num_workers=8,
        pin_memory=torch.cuda.is_available(),
    )

    validation_set = ArrayDataset(
        img=images[-20:],
        img_transform=val_imtrans,
        seg=segs[-20:],
        seg_transform=val_segtrans,
    )
    # Validation is done one sample at a time, in dataset order.
    val_loader = DataLoader(
        validation_set,
        batch_size=1,
        num_workers=4,
        pin_memory=torch.cuda.is_available(),
    )
    return train_loader, val_loader

In [4]:
def define_transforms():
    """Create the four transform pipelines used by the notebook.

    All pipelines start with ``LoadImage(image_only=True, ensure_channel_first=True)``
    followed by ``ScaleIntensity()``. The two training pipelines additionally
    apply ``RandSpatialCrop((96, 96), random_size=False)`` and
    ``RandRotate90(prob=0.5, spatial_axes=(0, 1))``.

    Returns:
        Tuple ``(train_imtrans, train_segtrans, val_imtrans, val_segtrans)``.
    """

    def _load_and_scale():
        # Build new transform objects on every call so that no two pipelines
        # share a transform instance.
        return [LoadImage(image_only=True, ensure_channel_first=True), ScaleIntensity()]

    def _augment():
        return [
            RandSpatialCrop((96, 96), random_size=False),
            RandRotate90(prob=0.5, spatial_axes=(0, 1)),
        ]

    train_imtrans = Compose(_load_and_scale() + _augment())
    train_segtrans = Compose(_load_and_scale() + _augment())
    val_imtrans = Compose(_load_and_scale())
    val_segtrans = Compose(_load_and_scale())
    return train_imtrans, train_segtrans, val_imtrans, val_segtrans

In [5]:
def setup_model_and_optimizer():
    """Instantiate the UNet, its loss function and its optimizer.

    The network is a 2D MONAI ``UNet`` with one input and one output channel,
    moved to CUDA when available and to the CPU otherwise. The loss is
    ``DiceLoss(sigmoid=True)`` and the optimizer is Adam with a learning rate
    of 1e-3.

    Returns:
        Tuple ``(model, loss_function, optimizer, device)``.
    """
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_name)

    unet_config = dict(
        spatial_dims=2,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    )
    model = monai.networks.nets.UNet(**unet_config)
    model = model.to(device)

    loss_function = monai.losses.DiceLoss(sigmoid=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    return model, loss_function, optimizer, device

In [6]:
def train_model(train_loader, model, loss_function, optimizer, device):
    """Run one optimisation pass over ``train_loader``.

    The model is switched to training mode, then every batch is moved to
    ``device``, pushed through the model, and used for a single optimizer
    step. The loss of each batch is printed as ``train_loss: <value>``.

    Args:
        train_loader: Iterable yielding batches that unpack into exactly
            two tensors, ``(inputs, labels)``.
        model: Network to train.
        loss_function: Callable invoked as ``loss_function(outputs, labels)``.
        optimizer: Optimizer that updates the parameters of ``model``.
        device: Device the batch tensors are moved to.
    """

    def _optimise_on(batch):
        # Unpacking enforces that a batch is exactly an (inputs, labels) pair.
        features, targets = (tensor.to(device) for tensor in batch)
        optimizer.zero_grad()
        batch_loss = loss_function(model(features), targets)
        batch_loss.backward()
        optimizer.step()
        return batch_loss.item()

    model.train()
    for batch in train_loader:
        print(f"train_loss: {_optimise_on(batch):.4f}")

In [9]:
from monai.metrics import DiceMetric
from monai.transforms import Activations, AsDiscrete

def validate_model(val_loader, model, device):
    """Evaluate ``model`` on ``val_loader`` and report the mean Dice score.

    Predictions come from sliding-window inference with a (96, 96) window and
    a sliding-window batch size of 4. They are passed through a sigmoid and
    then thresholded at 0.5, because ``DiceMetric`` expects binarised
    predictions; feeding it raw sigmoid probabilities does not give a
    meaningful score. Labels are thresholded at 0.5 as well.

    Args:
        val_loader: Iterable yielding batches whose first two items are the
            image tensor and the label tensor.
        model: Network to evaluate; it is switched to eval mode.
        device: Device the batch tensors are moved to.

    Returns:
        The aggregated mean Dice score as a Python float. It is also printed
        as ``validation_metric: <value>``.
    """
    model.eval()
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
    post_pred = Activations(sigmoid=True)
    binarize = AsDiscrete(threshold=0.5)

    with torch.no_grad():
        for val_data in val_loader:
            val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
            val_outputs = sliding_window_inference(val_images, (96, 96), 4, model)
            # Logits -> probabilities -> binary mask, as required by DiceMetric.
            val_outputs = binarize(post_pred(val_outputs))
            val_labels = binarize(val_labels)
            # Accumulate the metric for the current batch.
            dice_metric(y_pred=val_outputs, y=val_labels)
        # Aggregate the final mean Dice result, then clear the metric buffers.
        metric = dice_metric.aggregate().item()
        dice_metric.reset()

    print(f"validation_metric: {metric:.4f}")
    return metric

In [10]:
if __name__ == "__main__":
    # Seed the random number generators before anything stochastic happens
    # (synthetic data generation, weight initialisation, shuffling and random
    # augmentation) so that repeated runs of the notebook are comparable.
    monai.utils.set_determinism(seed=0)

    # The synthetic PNG files only need to exist while training and validation
    # run, so everything happens inside the temporary directory's lifetime.
    with tempfile.TemporaryDirectory() as tempdir:
        generate_data(tempdir)
        train_imtrans, train_segtrans, val_imtrans, val_segtrans = define_transforms()
        # Sort both listings so that image i and segmentation i stay paired.
        images = sorted(glob(os.path.join(tempdir, "img*.png")))
        segs = sorted(glob(os.path.join(tempdir, "seg*.png")))
        train_loader, val_loader = setup_data_loaders(images, segs, train_imtrans, train_segtrans, val_imtrans, val_segtrans)
        model, loss_function, optimizer, device = setup_model_and_optimizer()
        train_model(train_loader, model, loss_function, optimizer, device)
        validate_model(val_loader, model, device)

generating synthetic data to /tmp/tmph5dbj9pe (this may take a while)
train_loss: 0.3928
train_loss: 0.3661
train_loss: 0.3300
train_loss: 0.3228
train_loss: 0.3289
validation_metric: 0.0000
