<a href="https://colab.research.google.com/github/Meddebma/pyradiomics/blob/master/Spleen_dict.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
%pip install monai==0.3.0
#%pip install git+https://github.com/Project-MONAI/MONAI#egg=MONAI
%pip install 'monai[all]'
%pip install pytorch-ignite
%pip install nibabel==3.2.0

import logging
import os
import sys
import tempfile
from glob import glob

import nibabel as nib
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import monai
from monai.data import create_test_image_3d, list_data_collate
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric
from monai.transforms import (
    Activations,
    AsChannelFirstd,
    AsDiscrete,
    Compose,
    LoadNiftid,
    Spacingd,
    Orientationd,
    CropForegroundd,
    RandCropByPosNegLabeld,
    ScaleIntensityRanged,
    RandRotate90d,
    ScaleIntensityd,
    ToTensord,
    AddChanneld,
    KeepLargestConnectedComponent, 
    LabelToContour
)
from monai.losses import DiceLoss
from monai.metrics import compute_meandice
from monai.networks.layers import Norm
from monai.utils import first, set_determinism
from monai.visualize import plot_2d_or_3d_image

Collecting monai==0.3.0
[?25l  Downloading https://files.pythonhosted.org/packages/87/47/29473afdd2a90e54b8b11429f8d36b41a74d6d37371ebb50e4de74f42d10/monai-0.3.0-202010042353-py3-none-any.whl (298kB)
[K     |█                               | 10kB 22.6MB/s eta 0:00:01[K     |██▏                             | 20kB 10.0MB/s eta 0:00:01[K     |███▎                            | 30kB 9.1MB/s eta 0:00:01[K     |████▍                           | 40kB 8.7MB/s eta 0:00:01[K     |█████▌                          | 51kB 8.6MB/s eta 0:00:01[K     |██████▋                         | 61kB 9.2MB/s eta 0:00:01[K     |███████▊                        | 71kB 9.7MB/s eta 0:00:01[K     |████████▉                       | 81kB 9.7MB/s eta 0:00:01[K     |█████████▉                      | 92kB 9.9MB/s eta 0:00:01[K     |███████████                     | 102kB 10.1MB/s eta 0:00:01[K     |████████████                    | 112kB 10.1MB/s eta 0:00:01[K     |█████████████▏                  | 1



# Load Data and Define Transforms

In [3]:
root = "/content/drive/My Drive/Task09_Spleen/"
N_TRAIN = 20  # the first N_TRAIN sorted cases are used for training
N_VAL = 20  # the last N_VAL sorted cases are used for validation

images = sorted(glob(os.path.join(root, "imagesTr", "*.nii.gz")))
segs = sorted(glob(os.path.join(root, "labelsTr", "*.nii.gz")))
print(len(images))
print(len(segs))

# zip() below silently truncates and mispairs if the two folders disagree, so check the pairing
# first. Assumes each label file shares its image's file name -- verify if the layout changes.
assert len(images) == len(segs), f"{len(images)} images but {len(segs)} labels"
assert all(
    os.path.basename(img) == os.path.basename(seg) for img, seg in zip(images, segs)
), "image and label file names do not match pairwise"
# The two slices below overlap (train/validation leakage) when there are too few cases.
assert len(images) >= N_TRAIN + N_VAL, (
    f"need at least {N_TRAIN + N_VAL} cases for disjoint train/val splits, found {len(images)}"
)

train_files = [{"img": img, "seg": seg} for img, seg in zip(images[:N_TRAIN], segs[:N_TRAIN])]
val_files = [{"img": img, "seg": seg} for img, seg in zip(images[-N_VAL:], segs[-N_VAL:])]


def _preprocessing():
    """Return fresh instances of the deterministic transforms shared by train and validation.

    The pipeline:
    - loads the image/label pair and adds a channel axis;
    - resamples to pixdim (1.5, 1.5, 2.0), bilinear for the image and nearest for the label so
      label values are never interpolated;
    - reorients to RAS;
    - rescales image intensities from [-57, 164] to [0, 1] with clipping;
    - crops both arrays to the image foreground.
    """
    return [
        LoadNiftid(keys=["img", "seg"]),
        AddChanneld(keys=["img", "seg"]),
        Spacingd(keys=["img", "seg"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
        Orientationd(keys=["img", "seg"], axcodes="RAS"),
        ScaleIntensityRanged(
            keys=["img"], a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True,
        ),
        CropForegroundd(keys=["img", "seg"], source_key="img"),
    ]


train_transforms = Compose(
    _preprocessing()
    + [
        # 4 random 96^3 patches per volume, balanced between label-positive and -negative centres.
        RandCropByPosNegLabeld(
            keys=["img", "seg"],
            label_key="seg",
            spatial_size=(96, 96, 96),
            pos=1,
            neg=1,
            num_samples=4,
            image_key="img",
            image_threshold=0,
        ),
        # user can also add other random transforms (RandAffined would need importing), e.g.
        # RandAffined(keys=["img", "seg"], mode=("bilinear", "nearest"), prob=1.0, spatial_size=(96, 96, 96),
        #             rotate_range=(0, 0, np.pi/15), scale_range=(0.1, 0.1, 0.1)),
        ToTensord(keys=["img", "seg"]),
    ]
)
val_transforms = Compose(_preprocessing() + [ToTensord(keys=["img", "seg"])])

# Sanity check: load one training batch. batch_size=2 volumes x 4 crops each, merged by
# list_data_collate, gives 8 patches per batch.
check_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
check_loader = DataLoader(check_ds, batch_size=2, num_workers=4, collate_fn=list_data_collate)
check_data = first(check_loader)
print(check_data["img"].shape, check_data["seg"].shape)

# Training data loader: 2 volumes per batch -> 2 x 4 random patches per optimisation step.
train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
train_loader = DataLoader(
    train_ds,
    batch_size=2,
    shuffle=True,
    num_workers=4,
    collate_fn=list_data_collate,
    pin_memory=torch.cuda.is_available(),
)
# Validation data loader: whole (foreground-cropped) volumes, one at a time.
val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate)
dice_metric = DiceMetric(include_background=True, reduction="mean")
# Logits -> sigmoid probabilities -> thresholded binary mask before computing Dice.
post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=True)])


41
41
torch.Size([8, 1, 96, 96, 96]) torch.Size([8, 1, 96, 96, 96])


# Create the Model (UNet)

In [4]:
# Build the segmentation network, its loss and its optimiser.
# Prefer the GPU when one is visible to torch; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# 3-D UNet with five resolution levels, a stride-2 step between consecutive levels and two
# residual units per level. It maps one input channel (added by AddChanneld) to one output
# logit channel.
model = monai.networks.nets.UNet(
    dimensions=3,
    in_channels=1,
    out_channels=1,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
)
model = model.to(device)

# The network outputs raw logits, so the Dice loss applies the sigmoid itself.
loss_function = monai.losses.DiceLoss(sigmoid=True)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

# Train the Model

In [5]:
# Training configuration (the epoch count used to be hard-coded in two places).
max_epochs = 5
val_interval = 2  # validate after every `val_interval`-th epoch
roi_size = (96, 96, 96)  # sliding-window patch size, same as the training crop size
sw_batch_size = 4  # sliding-window patches evaluated per forward pass


def _validate(net, loader, metric_fn, postprocess, roi, sw_batch, dev):
    """Run sliding-window inference over `loader` and return the mean Dice score.

    Args:
        net: segmentation network; it is switched to eval mode here.
        loader: validation DataLoader yielding dicts with "img" and "seg" tensors.
        metric_fn: Dice metric called as ``metric_fn(y_pred=..., y=...)``. Assumed to return
            ``(value, not_nans)`` like MONAI 0.3.0's ``DiceMetric``, where ``not_nans`` counts
            the batch items that produced a valid Dice -- confirm if the MONAI pin changes.
        postprocess: transform applied to the raw network output before scoring
            (here: sigmoid + threshold).
        roi: sliding-window patch size.
        sw_batch: number of sliding-window patches per forward pass.
        dev: torch device the batches are moved to.

    Returns:
        ``(mean_dice, images, labels, outputs)``. ``mean_dice`` is the not_nans-weighted mean
        (0.0 if no batch contributed). The other three are the last batch, for TensorBoard.
    """
    net.eval()
    metric_sum, metric_count = 0.0, 0.0
    images = labels = outputs = None
    with torch.no_grad():
        for val_data in loader:
            images, labels = val_data["img"].to(dev), val_data["seg"].to(dev)
            outputs = postprocess(sliding_window_inference(images, roi, sw_batch, net))
            value, not_nans = metric_fn(y_pred=outputs, y=labels)
            n_valid = not_nans.sum().item()
            metric_count += n_valid
            metric_sum += value.item() * n_valid
    mean_dice = metric_sum / metric_count if metric_count > 0 else 0.0
    return mean_dice, images, labels, outputs


best_metric = -1
best_metric_epoch = -1
epoch_loss_values = list()
metric_values = list()
# len(train_loader) is the real number of steps per epoch (ceil of len(train_ds) / batch_size).
# The old floor division under-counted whenever the dataset size was not a multiple of
# batch_size, which made TensorBoard global steps of consecutive epochs overlap.
epoch_len = len(train_loader)
writer = SummaryWriter()
try:
    for epoch in range(max_epochs):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{max_epochs}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data["img"].to(device), batch_data["seg"].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            loss_value = loss.item()  # one device->host sync per step instead of three
            epoch_loss += loss_value
            print(f"{step}/{epoch_len}, train_loss: {loss_value:.4f}")
            writer.add_scalar("train_loss", loss_value, epoch_len * epoch + step)
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % val_interval == 0:
            metric, val_images, val_labels, val_outputs = _validate(
                model, val_loader, dice_metric, post_trans, roi_size, sw_batch_size, device
            )
            metric_values.append(metric)
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), "best_metric_model_segmentation3d_dict.pth")
                print("saved new best metric model")
            print(
                "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}".format(
                    epoch + 1, metric, best_metric, best_metric_epoch
                )
            )
            writer.add_scalar("val_mean_dice", metric, epoch + 1)
            # plot the last model output as GIF image in TensorBoard with the corresponding image and label
            plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="image")
            plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="label")
            plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="output")
finally:
    # Flush and release the event file even when training or validation raises.
    writer.close()

print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")

----------
epoch 1/5
1/10, train_loss: 0.9326
2/10, train_loss: 0.8809
3/10, train_loss: 0.9259
4/10, train_loss: 0.9467
5/10, train_loss: 0.8414
6/10, train_loss: 0.9327
7/10, train_loss: 0.9758
8/10, train_loss: 0.9417
9/10, train_loss: 0.9362
10/10, train_loss: 0.8649
epoch 1 average loss: 0.9179
----------
epoch 2/5
1/10, train_loss: 0.8246
2/10, train_loss: 0.8822
3/10, train_loss: 0.8926
4/10, train_loss: 0.9259
5/10, train_loss: 0.9587
6/10, train_loss: 0.9139
7/10, train_loss: 0.9315
8/10, train_loss: 0.8317
9/10, train_loss: 0.9348
10/10, train_loss: 0.9483
epoch 2 average loss: 0.9044


ValueError: ignored