In [1]:
!python -c "import monai" || pip install -q "monai-weekly[gdown, nibabel, tqdm, ignite]"
!python -c "import aim" || pip install -q aim
!python -c "import matplotlib" || pip install -q matplotlib
%matplotlib inline

Traceback (most recent call last):
  File "<string>", line 1, in <module>
ModuleNotFoundError: No module named 'monai'
[K     |████████████████████████████████| 821 kB 5.3 MB/s 
[K     |████████████████████████████████| 251 kB 50.5 MB/s 
[?25hTraceback (most recent call last):
  File "<string>", line 1, in <module>
ModuleNotFoundError: No module named 'aim'
[K     |████████████████████████████████| 3.5 MB 5.2 MB/s 
[K     |████████████████████████████████| 53 kB 2.2 MB/s 
[K     |████████████████████████████████| 4.0 MB 42.4 MB/s 
[K     |████████████████████████████████| 4.3 MB 33.0 MB/s 
[K     |████████████████████████████████| 55 kB 3.7 MB/s 
[K     |████████████████████████████████| 5.1 MB 35.7 MB/s 
[K     |████████████████████████████████| 280 kB 40.1 MB/s 
[K     |████████████████████████████████| 51 kB 320 kB/s 
[K     |████████████████████████████████| 17.2 MB 1.4 MB/s 
[K     |████████████████████████████████| 4.0 MB 26.4 MB/s 
[K     |█████████████████████████

In [2]:
from monai.utils import first, set_determinism
from monai.transforms import (
    AsDiscrete,
    AsDiscreted,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandCropByPosNegLabeld,
    SaveImaged,
    ScaleIntensityRanged,
    Spacingd, 
    EnsureTyped,
    EnsureType,
    Invertd
)

from monai.handlers.utils import from_engine
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.metrics import DiceMetric
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch
from monai.config import print_config
from monai.apps import download_and_extract
import aim
import torch
import matplotlib.pyplot as plt
import tempfile
import shutil
import os
import glob

In [3]:
import nibabel as nib
import numpy as np
from tqdm.notebook import tqdm

In [4]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [5]:
# os.makedirs("./data")
root_dir = "/content/drive/MyDrive/Parse2022/train"

In [6]:
train_images = sorted(glob.glob(os.path.join(root_dir, "*", 'image', "*.nii.gz")))
train_labels = sorted(glob.glob(os.path.join(root_dir, "*", 'label', "*.nii.gz")))

In [7]:
data_dicts = [{"images": images_name, "labels": label_name} for images_name, label_name in zip(train_images, train_labels)]
train_files, val_files = data_dicts[:-9], data_dicts[-9:]

set_determinism(seed = 0)

In [8]:
train_transforms = Compose(
    [
     LoadImaged(keys=['images', 'labels']),
     EnsureChannelFirstd(keys = ["images", "labels"]),
     Orientationd(keys=['images', 'labels'], axcodes = 'LPS'),
     Spacingd(keys=['images', 'labels'], pixdim = (1.5,1.5,2), mode = ("bilinear", 'nearest')),
     ScaleIntensityRanged(
            keys=["images"], a_min=-700, a_max=300,
            b_min=0.0, b_max=1.0, clip=True,
        ),
     CropForegroundd(keys=['images', 'labels'], source_key="images"),
     RandCropByPosNegLabeld(
            keys=['images', 'labels'],
            label_key="labels",
            spatial_size=(96, 96, 96),
            pos=1,
            neg=1,
            num_samples=4,
            image_key="images",
            image_threshold=0,
        ),
     EnsureTyped(keys=['images', 'labels']),     
          
    ]
)

val_transforms = Compose(
    [
        LoadImaged(keys=["images", "labels"]),
        EnsureChannelFirstd(keys=["images", "labels"]),
        Orientationd(keys=["images", "labels"], axcodes="LPS"),
        Spacingd(keys=["images", "labels"], pixdim=(
            1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
        ScaleIntensityRanged(
            keys=["images"], a_min=-700, a_max=300,
            b_min=0.0, b_max=1.0, clip=True,
        ),
        CropForegroundd(keys=["images", "labels"], source_key="images"),
        EnsureTyped(keys=["images", "labels"]),
    ]
)


In [9]:
check_ds = Dataset(data = val_files, transform = val_transforms)
check_loader = DataLoader(check_ds, batch_size = 1)
check_data = first(check_loader)
img, label = check_data['images'], check_data['labels']

In [None]:
img.shape

In [None]:
check_ds = Dataset(data = train_files, transform = train_transforms)
check_loader = DataLoader(check_ds, batch_size = 1)
check_data = first(check_loader)
img, label = check_data['images'], check_data['labels']

In [None]:
img.shape

In [None]:
np.unique(label)

In [None]:
plt.imshow(img[2,0,:,:,56])

In [None]:
plt.imshow(label[2,0,:,:,56])

In [9]:
train_ds = CacheDataset(
    data = train_files, transform = train_transforms,
    cache_rate = 1.0, num_workers = 2
)

train_loader = DataLoader(train_ds, batch_size = 2, shuffle = True, num_workers=2)
val_ds = CacheDataset(
    data = val_files, transform = val_transforms,
    cache_rate = 1.0, num_workers = 2
)
val_loader = DataLoader(val_ds, batch_size = 1, shuffle = False, num_workers=2)

Loading dataset: 100%|██████████| 91/91 [10:16<00:00,  6.78s/it]
Loading dataset: 100%|██████████| 9/9 [01:21<00:00,  9.05s/it]


In [11]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 
# model = UNet(spatial_dims=3,
#              in_channels=1, 
#              out_channels=2,
#              channels = (16,32,64,128,256),
#              strides = (2,2,2,2),
#              num_res_units = 2,
#              norm = Norm.BATCH
#              ).to(device)        

UNet_meatdata = dict(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH
)

model = UNet(**UNet_meatdata).to(device)


In [13]:
model.load_state_dict(torch.load("/content/drive/MyDrive/best_metric_model_1.pth", map_location=torch.device('cpu')))

<All keys matched successfully>

In [14]:
loss_function = DiceLoss(to_onehot_y=True, softmax=True)
loss_type = "DiceLoss"
optimizer = torch.optim.Adam(model.parameters(), 1e-4)
dice_metric = DiceMetric(include_background=False, reduction="mean")


In [15]:
Optimizer_metadata = {}
for ind, param_group in enumerate(optimizer.param_groups):
    optim_meta_keys = list(param_group.keys())
    Optimizer_metadata[f'param_group_{ind}'] = {key: value for (key, value) in param_group.items() if 'params' not in key}

In [None]:
max_epochs = 500
val_interval = 10
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = []
metric_values = []
post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=2)])
post_label = Compose([EnsureType(), AsDiscrete(to_onehot=2)])

# initialize a new Aim Run
aim_run = aim.Run()
# log model metadata
aim_run['UNet_meatdata'] = UNet_meatdata
# log optimizer metadata
aim_run['Optimizer_metadata'] = Optimizer_metadata


slice_to_track = 80

for epoch in tqdm(range(max_epochs)):
  model.train()
  epoch_loss = 0
  step = 0
  for batch_data in train_loader:
    step += 1
    inputs, labels = (
        batch_data['images'].to(device),
        batch_data['labels'].to(device)
    )
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = loss_function(outputs, labels)
    loss.backward()
    optimizer.step()
    epoch_loss +=loss.item()
    print(f"{step}/{len(train_ds) // train_loader.batch_size}, "
            f"train_loss: {loss.item():.4f}")
    aim_run.track(loss.item(), name="batch_loss", context={'type':loss_type})

  epoch_loss /= step
  epoch_loss_values.append(epoch_loss)
  aim_run.track(epoch_loss, name="epoch_loss", context={'type':loss_type})

  print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

  if (epoch + 1) % val_interval == 0:

    model.eval()
    with torch.no_grad():
      for index, val_data in enumerate(val_loader):
        val_inputs, val_labels = val_data['images'].to(device), val_data['labels'].to(device)
        roi_size = (160, 160, 160)
        sw_batch_size = 4
        val_outputs = sliding_window_inference(
                    val_inputs, roi_size, sw_batch_size, model)
        
        output = torch.argmax(val_outputs, dim=1)[0, :, :, slice_to_track].float()

        # aim_run.track(aim.Image(val_inputs[0, 0, :, :, slice_to_track], \
        #                                 caption=f'Input Image: {index}'), \
        #                        name='validation', context={'type':'input'})
        # aim_run.track(aim.Image(val_labels[0, 0, :, :, slice_to_track], \
        #                         caption=f'Label Image: {index}'), \
        #                 name='validation', context={'type':'label'})
        # aim_run.track(aim.Image(output, caption=f'Predicted Label: {index}'), \
        #                 name = 'predictions', context={'type':'labels'})
                      
        val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]
        val_labels = [post_label(i) for i in decollate_batch(val_labels)]
        dice_metric(y_pred=val_outputs, y=val_labels)

      metric = dice_metric.aggregate().item()
      aim_run.track(metric, name="val_metric", context={'type':loss_type})
      dice_metric.reset()

      metric_values.append(metric)
      if metric > best_metric:
        best_metric = metric
        best_metric_epoch = epoch + 1
        torch.save(model.state_dict(), os.path.join(
            root_dir, "/content/drive/MyDrive/best_metric_model.pth"))
        
        best_model_log_message = f"saved new best metric model at the {epoch+1}th epoch"
        aim_run.track(aim.Text(best_model_log_message), name='best_model_log_message', epoch=epoch+1)
        print(best_model_log_message)
              
        message1 = f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
        message2 = f"\nbest mean dice: {best_metric:.4f} "
        message3 = f"at epoch: {best_metric_epoch}"
  
        aim_run.track(aim.Text(message1 +"\n" + message2 + message3), name='epoch_summary', epoch=epoch+1)
        print(message1, message2, message3)










        





      

  0%|          | 0/500 [00:00<?, ?it/s]

1/45, train_loss: 0.0926
2/45, train_loss: 0.0943
3/45, train_loss: 0.1468
4/45, train_loss: 0.1265
5/45, train_loss: 0.0899
6/45, train_loss: 0.1259
7/45, train_loss: 0.1065
8/45, train_loss: 0.1093
9/45, train_loss: 0.0941
10/45, train_loss: 0.1030
11/45, train_loss: 0.0917
12/45, train_loss: 0.1238
13/45, train_loss: 0.0889
14/45, train_loss: 0.0817
15/45, train_loss: 0.1088
16/45, train_loss: 0.1326
17/45, train_loss: 0.1091
18/45, train_loss: 0.1049
19/45, train_loss: 0.1063
20/45, train_loss: 0.1420
21/45, train_loss: 0.1117
22/45, train_loss: 0.0855
23/45, train_loss: 0.1267
24/45, train_loss: 0.1627
25/45, train_loss: 0.1537
26/45, train_loss: 0.1012
27/45, train_loss: 0.0876
28/45, train_loss: 0.0971
29/45, train_loss: 0.1146
30/45, train_loss: 0.1388
31/45, train_loss: 0.0789
32/45, train_loss: 0.1670
33/45, train_loss: 0.1271
34/45, train_loss: 0.1179
35/45, train_loss: 0.1096
36/45, train_loss: 0.1385
37/45, train_loss: 0.1264
38/45, train_loss: 0.1180
39/45, train_loss: 0.

In [None]:
aim_run.close()

In [6]:


print(
    f"train completed, best_metric: {best_metric:.4f} "
    f"at epoch: {best_metric_epoch}")



NameError: ignored

In [None]:
%load_ext aim
%aim up

torch.Size([4, 2, 96, 96, 96])