In [1]:
import torch
import monai
from monai.utils import first, set_determinism
from monai.transforms import (
    AsDiscrete,
    AsDiscreted,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandCropByPosNegLabeld,
    SaveImaged,
    ScaleIntensityRanged,
    Spacingd, 
    EnsureTyped,
    EnsureType,
    Invertd,
    RandGridDistortiond,
    Rand3DElasticd,
    RandRotate90d,
    RandFlipd,
    RandRotated,
    RandZoomd,
    CropForegroundd,
    RandGaussianNoised,
    ShiftIntensityd,
    RandShiftIntensityd,
    RandKSpaceSpikeNoised,
    KeepLargestConnectedComponentd,
    OneOf,
    AddChanneld,
    ToTensord
)

from monai.data import DataLoader, Dataset
from monai.config import print_config
from monai.apps import download_and_extract
from monai.visualize import blend_images, matshow3d, plot_2d_or_3d_image
import tempfile
import shutil
import os
import glob
import matplotlib.pyplot as plt
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.metrics import DiceMetric
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch
from monai.config import print_config
from monai.apps import download_and_extract
import matplotlib.pyplot as plt
import tempfile
import shutil
import os
import glob
import nibabel as nib
import numpy as np
from tqdm.notebook import tqdm

In [2]:
# Each case lives in its own directory:
#   <root_dir>/<case>/image/*.nii.gz  and  <root_dir>/<case>/label/*.nii.gz
root_dir = "/scratch/scratch6/akansh12/Parse_data/train/train/"
NUM_VAL_CASES = 9  # the last N cases (in sorted path order) are held out for validation

train_images = sorted(glob.glob(os.path.join(root_dir, "*", "image", "*.nii.gz")))
train_labels = sorted(glob.glob(os.path.join(root_dir, "*", "label", "*.nii.gz")))


def _case_dir(path):
    """Return the ``<case>`` component of ``<root>/<case>/<image|label>/<file>``."""
    return os.path.basename(os.path.dirname(os.path.dirname(path)))


def pair_images_and_labels(images, labels):
    """Pair sorted image and label paths into MONAI-style data dicts.

    A bare ``zip`` silently truncates on a length mismatch, and if one case is
    missing its image or its label every following pair lands on the wrong
    case. Both conditions are checked here so they fail loudly instead.

    Args:
        images: sorted list of image volume paths.
        labels: sorted list of label volume paths.

    Returns:
        List of ``{"images": image_path, "labels": label_path}`` dicts.

    Raises:
        ValueError: if the lists differ in length, or a pair comes from two
            different case directories.
    """
    if len(images) != len(labels):
        raise ValueError(
            f"found {len(images)} images but {len(labels)} labels; cannot pair them"
        )
    for image_path, label_path in zip(images, labels):
        if _case_dir(image_path) != _case_dir(label_path):
            raise ValueError(f"mismatched pair: {image_path!r} vs {label_path!r}")
    return [{"images": img, "labels": lbl} for img, lbl in zip(images, labels)]


data_dicts = pair_images_and_labels(train_images, train_labels)
train_files, val_files = data_dicts[:-NUM_VAL_CASES], data_dicts[-NUM_VAL_CASES:]

In [3]:
# train_transforms = Compose(
#     [
#         LoadImaged(keys=["images", "labels"]),
#         AddChanneld(keys=["images", "labels"]),
#         Orientationd(keys=["images", "labels"], axcodes="LPS"),
#         ScaleIntensityRanged(
#             keys=["images"],
#             a_min=-1000,
#             a_max=1000,
#             b_min=0.0,
#             b_max=1.0,
#             clip=True,
#         ),
#         CropForegroundd(keys=["images", "labels"], source_key="images"),
#         RandCropByPosNegLabeld(
#             keys=["images", "labels"],
#             label_key="labels",
#             spatial_size=(128, 128, 128),
#             pos=1,
#             neg=1,
#             num_samples=4,
#             image_key="images",
#             image_threshold=0,
#         ),
#         RandFlipd(
#             keys=["images", "labels"],
#             spatial_axis=[0],
#             prob=0.20,
#         ),
#         RandFlipd(
#             keys=["images", "labels"],
#             spatial_axis=[1],
#             prob=0.20,
#         ),
#         RandFlipd(
#             keys=["images", "labels"],
#             spatial_axis=[2],
#             prob=0.20,
#         ),
#         RandRotate90d(
#             keys=["images", "labels"],
#             prob=0.10,
#             max_k=3,
#         ),
#         RandShiftIntensityd(
#             keys=["images"],
#             offsets=0.10,
#             prob=0.50,
#         ),
#         ToTensord(keys=["images", "labels"]),
#     ]
# )


# val_transforms = Compose(
#     [
#         LoadImaged(keys=["images", "labels"]),
#         EnsureChannelFirstd(keys=["images", "labels"]),
#         Orientationd(keys=["images", "labels"], axcodes="LPS"),
#         ScaleIntensityRanged(
#             keys=["images"], a_min=-1000, a_max=1000,
#             b_min=0.0, b_max=1.0, clip=True,
#         ),
#         CropForegroundd(keys=["images", "labels"], source_key="images"),
#         EnsureTyped(keys=["images", "labels"]),
#     ]
# )

In [3]:
def _deterministic_preprocessing():
    """Build a fresh list of the preprocessing steps shared by train and val.

    Loads image and label, moves the channel axis first, reorients both to LPS,
    maps image intensities from [-1000, 1000] to [0, 1] with clipping, and crops
    both volumes to the image foreground. ``Spacingd`` resampling is deliberately
    not applied (this run is saved as "no_spacing" in the training cell).
    A new list of new transform objects is returned on every call, so the two
    pipelines never share instances.
    """
    both = ["images", "labels"]
    return [
        LoadImaged(keys=both),
        EnsureChannelFirstd(keys=both),
        Orientationd(keys=both, axcodes="LPS"),
        ScaleIntensityRanged(
            keys=["images"],
            a_min=-1000,
            a_max=1000,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        CropForegroundd(keys=both, source_key="images"),
    ]


# Training adds random patch sampling: 4 patches of 128^3 per volume, with a
# pos=1 / neg=1 ratio between label-foreground and background patch centres.
train_transforms = Compose(
    _deterministic_preprocessing()
    + [
        RandCropByPosNegLabeld(
            keys=["images", "labels"],
            label_key="labels",
            spatial_size=(128, 128, 128),
            pos=1,
            neg=1,
            num_samples=4,
            image_key="images",
            image_threshold=0,
        ),
        EnsureTyped(keys=["images", "labels"]),
    ]
)

# Validation keeps whole (foreground-cropped) volumes; sliding-window inference
# handles them later.
val_transforms = Compose(
    _deterministic_preprocessing() + [EnsureTyped(keys=["images", "labels"])]
)

In [None]:
# Both splits are fully cached (cache_rate=1.0) using 4 worker processes.
_CACHE_OPTIONS = {"cache_rate": 1.0, "num_workers": 4}

train_ds = CacheDataset(data=train_files, transform=train_transforms, **_CACHE_OPTIONS)
train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4)

# Validation volumes are processed one at a time, in a fixed order.
val_ds = CacheDataset(data=val_files, transform=val_transforms, **_CACHE_OPTIONS)
val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=4)

Loading dataset:  81%|███████████████████████████████████████████████████████████████████████████████████████████▉                     | 74/91 [02:58<00:39,  2.32s/it]

In [3]:
# Use the GPU when available, otherwise fall back to the CPU; the training
# cell moves every batch to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 3D residual UNet hyper-parameters, kept as a dict so they can be reused.
# (The misspelled name is kept as-is in case other cells refer to it.)
UNet_meatdata = {
    "spatial_dims": 3,
    "in_channels": 1,
    "out_channels": 2,  # two classes, matching to_onehot=2 in the post-processing
    "channels": (16, 32, 64, 128, 256),
    "strides": (2, 2, 2, 2),
    "num_res_units": 2,
    "norm": Norm.BATCH,
}
model = UNet(**UNet_meatdata).to(device)

In [None]:
# Dice loss on softmax outputs; the integer label map is one-hot encoded by the loss.
loss_function = DiceLoss(to_onehot_y=True, softmax=True)
loss_type = "DiceLoss"  # human-readable name of the loss configured above

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Validation Dice, averaged, with the background channel excluded.
dice_metric = DiceMetric(include_background=False, reduction="mean")

In [None]:
# Snapshot of each optimizer param group's hyper-parameters (lr, betas, ...),
# keyed "param_group_<i>". Only the "params" entry (the parameter tensors
# themselves) is excluded; an exact key comparison is used so that no other
# setting whose name merely contains "params" is dropped by accident.
Optimizer_metadata = {
    f"param_group_{group_index}": {
        name: setting for name, setting in param_group.items() if name != "params"
    }
    for group_index, param_group in enumerate(optimizer.param_groups)
}

In [4]:
# Warm-start the model from a previously trained checkpoint.
# map_location=device lets a checkpoint saved on GPU load when the device cell
# fell back to CPU.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
PRETRAINED_CHECKPOINT = "/scratch/scratch6/akansh12/challenges/parse2022/temp/best_metric_model_Unet_1000_hu.pth"
model.load_state_dict(torch.load(PRETRAINED_CHECKPOINT, map_location=device))

<All keys matched successfully>

In [20]:
max_epochs = 600
val_interval = 10
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = []
metric_values = []
post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=2)])
post_label = Compose([EnsureType(), AsDiscrete(to_onehot=2)])

# Sliding-window settings for full-volume validation (loop-invariant).
# NOTE(review): training patches are 128^3 but validation windows are 256^3;
# confirm this is intentional and fits in GPU memory.
roi_size = (256, 256, 256)
sw_batch_size = 4

OUTPUT_DIR = "/scratch/scratch6/akansh12/challenges/parse2022/temp"
RUN_NAME = "Unet_3d_no_spacing_128"
best_model_path = os.path.join(OUTPUT_DIR, f"{RUN_NAME}.pth")


def _train_one_epoch():
    """Run one optimisation pass over ``train_loader``; return the mean batch loss.

    Raises:
        RuntimeError: if ``train_loader`` yields no batches.
    """
    model.train()
    # len(train_loader) counts the final partial batch, unlike
    # len(train_ds) // batch_size, which produced the "5/4" progress lines.
    steps_per_epoch = len(train_loader)
    total_loss = 0.0
    step = 0
    for step, batch_data in enumerate(train_loader, start=1):
        inputs = batch_data["images"].to(device)
        labels = batch_data["labels"].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        batch_loss = loss.item()
        total_loss += batch_loss
        print(f"{step}/{steps_per_epoch}, train_loss: {batch_loss:.4f}")
    if step == 0:
        raise RuntimeError("train_loader yielded no batches")
    return total_loss / step


def _validate():
    """Return the mean foreground Dice of ``model`` over ``val_loader``.

    ``dice_metric`` is always reset afterwards, even if validation is
    interrupted, so partial results never leak into the next validation round.
    """
    model.eval()
    try:
        with torch.no_grad():
            for val_data in val_loader:
                val_inputs = val_data["images"].to(device)
                val_labels = val_data["labels"].to(device)
                val_outputs = sliding_window_inference(
                    val_inputs, roi_size, sw_batch_size, model
                )
                val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]
                val_labels = [post_label(i) for i in decollate_batch(val_labels)]
                dice_metric(y_pred=val_outputs, y=val_labels)
            return dice_metric.aggregate().item()
    finally:
        dice_metric.reset()


try:
    for epoch in tqdm(range(max_epochs)):
        epoch_loss = _train_one_epoch()
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % val_interval != 0:
            continue

        metric = _validate()
        metric_values.append(metric)
        if metric > best_metric:
            best_metric = metric
            best_metric_epoch = epoch + 1
            torch.save(model.state_dict(), best_model_path)
            best_model_log_message = f"saved new best metric model at the {epoch+1}th epoch"
            print(best_model_log_message)

        # Report after every validation, not only when the best score improves.
        message1 = f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
        message2 = f"\nbest mean dice: {best_metric:.4f} "
        message3 = f"at epoch: {best_metric_epoch}"
        print(message1, message2, message3)
finally:
    # Persist the curves even when training stops early (e.g. KeyboardInterrupt);
    # the exception still propagates after saving.
    np.save(os.path.join(OUTPUT_DIR, f"epoch_loss_{RUN_NAME}.npy"), epoch_loss_values)
    np.save(os.path.join(OUTPUT_DIR, f"metric_values_{RUN_NAME}.npy"), metric_values)

  0%|          | 0/600 [00:00<?, ?it/s]

1/4, train_loss: 0.6544
2/4, train_loss: 0.6469
3/4, train_loss: 0.6669
4/4, train_loss: 0.6427
5/4, train_loss: 0.6520
epoch 1 average loss: 0.6526
1/4, train_loss: 0.6335
2/4, train_loss: 0.6316
3/4, train_loss: 0.6388
4/4, train_loss: 0.6478
5/4, train_loss: 0.6447
epoch 2 average loss: 0.6393
1/4, train_loss: 0.6421
2/4, train_loss: 0.6539
3/4, train_loss: 0.6479
4/4, train_loss: 0.6307
5/4, train_loss: 0.6319
epoch 3 average loss: 0.6413
1/4, train_loss: 0.6427
2/4, train_loss: 0.6473
3/4, train_loss: 0.6366
4/4, train_loss: 0.6395
5/4, train_loss: 0.6336
epoch 4 average loss: 0.6399
1/4, train_loss: 0.6235
2/4, train_loss: 0.6577
3/4, train_loss: 0.6241
4/4, train_loss: 0.6335
5/4, train_loss: 0.6106
epoch 5 average loss: 0.6299
1/4, train_loss: 0.6323
2/4, train_loss: 0.6183
3/4, train_loss: 0.6394
4/4, train_loss: 0.6129
5/4, train_loss: 0.6407
epoch 6 average loss: 0.6287
1/4, train_loss: 0.6293
2/4, train_loss: 0.6111


KeyboardInterrupt: 