<a href="https://colab.research.google.com/github/akansh12/monai-tutorial/blob/main/MONAI/MONAI_tutorial(3D_seg).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!python -c "import monai" || pip install -q "monai-weekly[pillow, tqdm]"
!python -c "import matplotlib" || pip install -q matplotlib
%matplotlib inline

Traceback (most recent call last):
  File "<string>", line 1, in <module>
ModuleNotFoundError: No module named 'monai'
[K     |████████████████████████████████| 816 kB 5.3 MB/s 
[?25h

In [None]:
import logging
import os
import sys
import tempfile
from glob import glob
import nibabel as nib
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import monai
from monai.data import create_test_image_3d, list_data_collate, decollate_batch
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric

from monai.transforms import (
    Activations,
    AsChannelFirstd,
    AsDiscrete,
    Compose,
    LoadImaged,
    RandCropByPosNegLabeld,
    RandRotate90d,
    ScaleIntensityd,
    EnsureTyped,
    EnsureType
)

from monai.visualize import plot_2d_or_3d_image


In [None]:
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

In [None]:
logging.basicConfig(stream=sys.stdout, level=logging.INFO)

# Directory that receives the synthetic NIfTI volumes.
# exist_ok=True keeps this cell idempotent: a plain os.makedirs raises
# FileExistsError the second time the notebook is run.
tempdir = "./temp_file"
os.makedirs(tempdir, exist_ok=True)
print(f"generating synthetic data to {tempdir} (this may take a while)")

generating synthetic data to ./temp_file (this may take a while)


In [None]:
# Generate paired synthetic 3D image / binary-mask volumes and save them as NIfTI.
NUM_CASES = 40  # total number of image/segmentation pairs written to `tempdir`
DATA_SEED = 0  # seed for the synthetic generator so Restart & Run All reproduces the data

# A single seeded RandomState is passed to every create_test_image_3d call so the
# dataset is identical across runs (the unseeded version produced new data each time).
data_rng = np.random.RandomState(DATA_SEED)
identity_affine = np.eye(4)

for case_idx in range(NUM_CASES):
    im, seg = create_test_image_3d(
        128, 128, 128, num_seg_classes=1, channel_dim=-1, random_state=data_rng
    )
    # channel_dim=-1 stores the channel axis last; the transforms move it to the front.
    nib.save(nib.Nifti1Image(im, identity_affine), os.path.join(tempdir, f"img{case_idx:d}.nii.gz"))
    nib.save(nib.Nifti1Image(seg, identity_affine), os.path.join(tempdir, f"seg{case_idx:d}.nii.gz"))

In [None]:
# Inspect one generated image volume (shape and distinct intensity values).
# The path is built from `tempdir` instead of the Colab-only absolute
# "/content/temp_file/..." so the cell also runs outside Colab.
sample_img = nib.load(os.path.join(tempdir, "img13.nii.gz"))
sample_img_data = sample_img.get_fdata()
print(sample_img_data.shape)
print(np.unique(sample_img_data))

(128, 128, 128, 1)
[0.         0.56419837 0.58732367 0.75521755 0.77602738 0.85623723
 0.88002145 0.90849024 0.9440394  0.96790183 1.        ]


In [None]:
# Inspect the matching segmentation volume; with num_seg_classes=1 the distinct
# values shown should be background/foreground only.
# The path is built from `tempdir` rather than a hardcoded Colab absolute path.
sample_seg = nib.load(os.path.join(tempdir, "seg13.nii.gz"))
sample_seg_data = sample_seg.get_fdata()
print(sample_seg_data.shape)
print(np.unique(sample_seg_data))

(128, 128, 128, 1)
[0. 1.]


In [None]:
# Build image/segmentation file pairs and split them into train and validation sets.
N_TRAIN = 20  # the first N_TRAIN sorted pairs are used for training, the rest for validation

images = sorted(glob(os.path.join(tempdir, "img*.nii.gz")))
segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))

# zip() silently truncates on a length mismatch, so fail loudly instead.
assert len(images) == len(segs), f"{len(images)} images vs {len(segs)} segmentations"
# Both lists are sorted the same way, so "imgK" must line up with "segK".
assert all(
    os.path.basename(img)[3:] == os.path.basename(seg)[3:] for img, seg in zip(images, segs)
), "image/segmentation files are not paired by case index"

all_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]
# Complementary slices guarantee no case appears in both sets regardless of the
# total number of cases (a fixed `[-20:]` overlaps training when fewer than 40 exist).
train_files = all_files[:N_TRAIN]
val_files = all_files[N_TRAIN:]

In [None]:
def _load_and_scale():
    """Return the deterministic preprocessing steps shared by train and validation.

    A fresh list of transform instances is built on each call so the two
    pipelines never share transform objects.
    """
    return [
        LoadImaged(keys=["img", "seg"]),
        # Volumes were saved with the channel axis last; move it to the front.
        AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
        # Intensity scaling is applied to the image only, never to the mask.
        ScaleIntensityd(keys="img"),
    ]


# Training: shared preprocessing + random 90-degree rotations (image and mask together).
# Random patch cropping (RandCropByPosNegLabeld) is intentionally not used: whole
# volumes are fed to the network.
train_transforms = Compose(
    _load_and_scale()
    + [
        RandRotate90d(keys=["img", "seg"], prob=0.5, spatial_axes=[0, 2]),
        EnsureTyped(keys=["img", "seg"]),
    ]
)

# Validation: shared preprocessing only, no randomness.
val_transforms = Compose(
    _load_and_scale()
    + [
        EnsureTyped(keys=["img", "seg"]),
    ]
)

In [None]:
# Sanity-check one training batch before training.
# No random cropping is configured in train_transforms, so a batch_size=2 loader
# yields 2 whole volumes laid out as (batch, channel, *spatial).
check_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
check_loader = DataLoader(check_ds, batch_size=2, num_workers=2, collate_fn=list_data_collate)
check_data = monai.utils.misc.first(check_loader)
# Image and mask must stay spatially aligned through the (shared-key) transforms.
assert check_data["img"].shape == check_data["seg"].shape, (
    check_data["img"].shape,
    check_data["seg"].shape,
)
print(check_data["img"].shape, check_data["seg"].shape)

torch.Size([2, 1, 128, 128, 128]) torch.Size([2, 1, 128, 128, 128])


In [None]:
# Page-locked host memory only helps when batches are copied to a GPU; apply the
# same setting to both loaders (the validation loader previously omitted it).
use_pin_memory = torch.cuda.is_available()

train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
train_loader = DataLoader(
    train_ds,
    batch_size=2,
    shuffle=True,  # reshuffle training volumes every epoch
    num_workers=2,
    collate_fn=list_data_collate,
    pin_memory=use_pin_memory,
)

val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
val_loader = DataLoader(
    val_ds,
    batch_size=1,
    shuffle=False,  # deterministic validation order
    num_workers=2,
    collate_fn=list_data_collate,
    pin_memory=use_pin_memory,
)

In [None]:
# Dice score averaged with reduction="mean". include_background=True keeps the
# model's only output channel (out_channels=1) in the computation.
dice_metric = DiceMetric(
    include_background=True,
    reduction="mean",
    get_not_nans=False,
)

# Network logits -> sigmoid probabilities -> hard mask thresholded at 0.5.
post_trans = Compose(
    [
        EnsureType(),
        Activations(sigmoid=True),
        AsDiscrete(threshold=0.5),
    ]
)


In [None]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# 3D residual U-Net mapping a single-channel volume to a single-channel output
# (logits: the sigmoid is applied by the loss and by post-processing).
model = monai.networks.nets.UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=1,
    channels=(32, 64, 128, 256, 512),  # feature width at each of the five resolution levels
    strides=(2, 2, 2, 2),  # four stride-2 downsamplings between those levels
    num_res_units=2,
)
model = model.to(device)

In [None]:
# Dice loss computed on sigmoid(logits); Adam with learning rate 1e-3.
loss_function = monai.losses.DiceLoss(sigmoid=True)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

In [None]:
# Train the U-Net, validate every `val_interval` epochs, keep the best checkpoint,
# and log losses/metrics/images to TensorBoard.
MAX_EPOCHS = 40
val_interval = 2
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = list()
metric_values = list()
writer = SummaryWriter()

# Optimizer steps per epoch, used to build a monotonically increasing TensorBoard
# global step. len(train_loader) counts a trailing partial batch; the previous
# len(train_ds) // batch_size did not, so step indices could repeat across epochs.
steps_per_epoch = len(train_loader)

try:
    for epoch in tqdm(range(MAX_EPOCHS)):
        # ---- training ----
        model.train()
        epoch_loss = 0.0
        step = 0
        for batch_data in tqdm(train_loader):
            step += 1
            inputs = batch_data["img"].to(device)
            labels = batch_data["seg"].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            # Supervise against the ground-truth mask, not the input image.
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            loss_value = loss.item()
            epoch_loss += loss_value
            writer.add_scalar("train_loss", loss_value, steps_per_epoch * epoch + step)
        epoch_loss /= max(step, 1)  # guard against an empty loader
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        # ---- validation ----
        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                # Keep the last validation batch around for the TensorBoard images below.
                val_images = val_labels = val_outputs = None
                for val_data in tqdm(val_loader):
                    val_images = val_data["img"].to(device)
                    val_labels = val_data["seg"].to(device)
                    # Whole volumes go straight through the model (no sliding-window
                    # inference); post-processing is applied per decollated sample.
                    val_outputs = [post_trans(pred) for pred in decollate_batch(model(val_images))]
                    dice_metric(y_pred=val_outputs, y=val_labels)

                metric = dice_metric.aggregate().item()
                dice_metric.reset()

            metric_values.append(metric)
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), "best_metric_model_segmentation3d_dict.pth")
                print("saved new best metric model")
            writer.add_scalar("val_mean_dice", metric, epoch + 1)

            # Plot the last validation batch as GIF images in TensorBoard.
            plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="image")
            plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="label")
            plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="output")
finally:
    # Flush and close the event file even if training is interrupted.
    writer.close()

print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")


  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

epoch 1 average loss: 0.7468


  0%|          | 0/10 [00:00<?, ?it/s]

epoch 2 average loss: 0.7012


  0%|          | 0/20 [00:00<?, ?it/s]

saved new best metric model


  0%|          | 0/10 [00:00<?, ?it/s]

epoch 3 average loss: 0.6802


  0%|          | 0/10 [00:00<?, ?it/s]

epoch 4 average loss: 0.6679


  0%|          | 0/20 [00:00<?, ?it/s]

saved new best metric model


  0%|          | 0/10 [00:00<?, ?it/s]

epoch 5 average loss: 0.6596
train completed, best_metric: 0.9258 at epoch: 4
