In [None]:
!pip install monai[nibabel]

In [3]:
from monai.utils import first, set_determinism
from monai.transforms import (
    AsDiscrete,
    AsDiscreted,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandCropByPosNegLabeld,
    SaveImaged,
    ScaleIntensityRanged,
    Spacingd,
    Invertd,
    SpatialPadd
)
from monai.handlers.utils import from_engine
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.metrics import DiceMetric
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch
from monai.config import print_config
from monai.apps import download_and_extract
import torch
import matplotlib.pyplot as plt
import tempfile
import shutil
import os
import glob

In [35]:
root_dir = tempfile.mkdtemp()
print(root_dir)
data_root_dir = os.path.join(root_dir, "data")
os.makedirs(data_root_dir)

/var/tmp/tmp6sljes8c


In [27]:
# Download data
!gsutil -m cp -r "gs://marketplace-2xim6sjc/Medical Decathlon Spleen" {data_root_dir}

Copying gs://marketplace-2xim6sjc/Medical Decathlon Spleen/148/artifactFiles/dataset.json...
Copying gs://marketplace-2xim6sjc/Medical Decathlon Spleen/148/artifactFiles/imagesTr/spleen_10.nii.gz...
Copying gs://marketplace-2xim6sjc/Medical Decathlon Spleen/148/artifactFiles/imagesTr/spleen_12.nii.gz...
Copying gs://marketplace-2xim6sjc/Medical Decathlon Spleen/148/artifactFiles/imagesTr/spleen_13.nii.gz...
Copying gs://marketplace-2xim6sjc/Medical Decathlon Spleen/148/artifactFiles/imagesTr/spleen_14.nii.gz...
Copying gs://marketplace-2xim6sjc/Medical Decathlon Spleen/148/artifactFiles/imagesTr/spleen_16.nii.gz...
Copying gs://marketplace-2xim6sjc/Medical Decathlon Spleen/148/artifactFiles/imagesTr/spleen_17.nii.gz...
Copying gs://marketplace-2xim6sjc/Medical Decathlon Spleen/148/artifactFiles/imagesTr/spleen_18.nii.gz...
Copying gs://marketplace-2xim6sjc/Medical Decathlon Spleen/148/artifactFiles/imagesTr/spleen_19.nii.gz...
Copying gs://marketplace-2xim6sjc/Medical Decathlon Spleen/

In [28]:
data_dir = os.path.join(data_root_dir, "Medical Decathlon Spleen/148/artifactFiles")

In [29]:
train_images = sorted(glob.glob(os.path.join(data_dir, "imagesTr", "*.nii.gz")))
train_labels = sorted(glob.glob(os.path.join(data_dir, "labelsTr", "*.nii.gz")))
data_dicts = [{"image": image_name, "label": label_name} for image_name, label_name in zip(train_images, train_labels)]
train_files, val_files = data_dicts[:-9], data_dicts[-9:]

In [30]:
set_determinism(seed=0)


In [31]:
train_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        ScaleIntensityRanged(
            keys=["image"],
            a_min=-57,
            a_max=164,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
        SpatialPadd(keys=["image", "label"], spatial_size=(96, 96, 96)),
        RandCropByPosNegLabeld(
            keys=["image", "label"],
            label_key="label",
            spatial_size=(96, 96, 96),
            pos=1,
            neg=1,
            num_samples=4,
            image_key="image",
            image_threshold=0,
        ),
        # user can also add other random transforms
        # RandAffined(
        #     keys=['image', 'label'],
        #     mode=('bilinear', 'nearest'),
        #     prob=1.0, spatial_size=(96, 96, 96),
        #     rotate_range=(0, 0, np.pi/15),
        #     scale_range=(0.1, 0.1, 0.1)),
    ]
)
val_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        ScaleIntensityRanged(
            keys=["image"],
            a_min=-57,
            a_max=164,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        # SpatialPadd(keys=["image", "label"], spatial_size=(96, 96, 96)),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
    ]
)

In [32]:
train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=1.0, num_workers=4)
# train_ds = Dataset(data=train_files, transform=train_transforms)

# use batch_size=2 to load images and use RandCropByPosNegLabeld
# to generate 2 x 4 images for network training
train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4)

val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0, num_workers=4)
# val_ds = Dataset(data=val_files, transform=val_transforms)
val_loader = DataLoader(val_ds, batch_size=1, num_workers=4)

Loading dataset: 100%|██████████| 32/32 [00:34<00:00,  1.08s/it]
Loading dataset: 100%|██████████| 9/9 [00:06<00:00,  1.47it/s]


In [33]:
# standard PyTorch program style: create UNet, DiceLoss and Adam optimizer
device = torch.device("cuda:0")
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
).to(device)
loss_function = DiceLoss(to_onehot_y=True, softmax=True)
optimizer = torch.optim.Adam(model.parameters(), 1e-4)
dice_metric = DiceMetric(include_background=False, reduction="mean")

In [34]:
max_epochs = 2
val_interval = 2
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = []
metric_values = []
post_pred = Compose([AsDiscrete(argmax=True, to_onehot=2)])
post_label = Compose([AsDiscrete(to_onehot=2)])

for epoch in range(max_epochs):
    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")
    model.train()
    epoch_loss = 0
    step = 0
    for batch_data in train_loader:
        step += 1
        inputs, labels = (
            batch_data["image"].to(device),
            batch_data["label"].to(device),
        )
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        print(f"{step}/{len(train_ds) // train_loader.batch_size}, " f"train_loss: {loss.item():.4f}")
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
    
    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            for val_data in val_loader:
                val_inputs, val_labels = (
                    val_data["image"].to(device),
                    val_data["label"].to(device),
                )
                roi_size = (160, 160, 160)
                sw_batch_size = 4
                val_outputs = sliding_window_inference(val_inputs, roi_size, sw_batch_size, model)
                val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]
                val_labels = [post_label(i) for i in decollate_batch(val_labels)]
                # compute metric for current iteration
                dice_metric(y_pred=val_outputs, y=val_labels)

            # aggregate the final mean dice result
            metric = dice_metric.aggregate().item()
            # reset the status for next validation round
            dice_metric.reset()

            metric_values.append(metric)
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), os.path.join(root_dir, "best_metric_model.pth"))
                print("saved new best metric model")
            print(
                f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                f"\nbest mean dice: {best_metric:.4f} "
                f"at epoch: {best_metric_epoch}"
            )

----------
epoch 1/2
1/16, train_loss: 0.6680
2/16, train_loss: 0.6792
3/16, train_loss: 0.6834
4/16, train_loss: 0.6849
5/16, train_loss: 0.6857
6/16, train_loss: 0.6920
7/16, train_loss: 0.6590
8/16, train_loss: 0.6759
9/16, train_loss: 0.6681
10/16, train_loss: 0.6563
11/16, train_loss: 0.6539
12/16, train_loss: 0.6792
13/16, train_loss: 0.6526
14/16, train_loss: 0.6629
15/16, train_loss: 0.6517
16/16, train_loss: 0.6292
epoch 1 average loss: 0.6676
----------
epoch 2/2
1/16, train_loss: 0.6618
2/16, train_loss: 0.6614
3/16, train_loss: 0.6438
4/16, train_loss: 0.6484
5/16, train_loss: 0.6577
6/16, train_loss: 0.6419
7/16, train_loss: 0.6516
8/16, train_loss: 0.6140
9/16, train_loss: 0.6245
10/16, train_loss: 0.6492
11/16, train_loss: 0.6294
12/16, train_loss: 0.6113
13/16, train_loss: 0.6226
14/16, train_loss: 0.6431
15/16, train_loss: 0.6437
16/16, train_loss: 0.6450
epoch 2 average loss: 0.6406
saved new best metric model
current epoch: 2 current mean dice: 0.0320
best mean dice:

In [36]:
model.load_state_dict(torch.load(os.path.join(root_dir, "best_metric_model.pth")))

  model.load_state_dict(torch.load(os.path.join(root_dir, "best_metric_model.pth")))


FileNotFoundError: [Errno 2] No such file or directory: '/var/tmp/tmp6sljes8c/best_metric_model.pth'