In [1]:
import logging
import os
import sys
import tempfile
from glob import glob

import nibabel as nib
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import torch
import glob
from ignite.metrics import Accuracy

import monai
from monai.apps import get_logger
from monai.data import create_test_image_3d
from monai.networks.layers import Norm
from monai.engines import SupervisedEvaluator, SupervisedTrainer
from torch.utils.tensorboard import SummaryWriter
from monai.handlers import (
    CheckpointSaver,
    EarlyStopHandler,
    LrScheduleHandler,
    MeanDice,
    StatsHandler,
    TensorBoardImageHandler,
    TensorBoardStatsHandler,
    ValidationHandler,
    from_engine,
)
from monai.inferers import SimpleInferer, SlidingWindowInferer
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch
from monai.networks.nets import UNet
from monai.transforms import (
    Activationsd,
    AsChannelFirstd,
    Orientationd,
    EnsureChannelFirstd,
    AsDiscreted,
    Spacingd,
    Compose,
    KeepLargestConnectedComponentd,
    LoadImaged,
    ScaleIntensityRanged,
    CropForegroundd,
    RandCropByPosNegLabeld,
    RandRotate90d,
    ScaleIntensityd,
    EnsureTyped,
)

train_dir = 'C:/Users/Hripsime/OneDrive - ABYS MEDICAL/projects/CTPelvic1K_data/train_dir/'

def main():
    #monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    get_logger("train_log")
    
    train_images = sorted(glob.glob(os.path.join(train_dir, "*data.nii.gz")))
    train_labels = sorted(glob.glob(os.path.join(train_dir, "*mask_4label.nii.gz")))
    data_dicts = [{"image": image_name, "label": label_name} for image_name, label_name in zip(train_images, train_labels)]
    train_files, val_files = data_dicts[:-10], data_dicts[-10:]
    

    
    
    # define transforms for image and segmentation
    train_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys=["image", "label"]),
            Orientationd(keys=["image", "label"], axcodes="RAS"),
            Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "nearest")),
            ScaleIntensityRanged(keys=["image"], a_min=-120, a_max=360,b_min=0.0, b_max=1.0, clip=True,),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            RandCropByPosNegLabeld(
                keys=["image", "label"],
                label_key="label",
                spatial_size=(128, 128, 128),
                pos=1,
                neg=1,
                num_samples=4,
                image_key="image",
                image_threshold=0,
        ),
            EnsureTyped(keys=["image", "label"]),
        ]
    )
    
    val_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys=["image", "label"]),
            Orientationd(keys=["image", "label"], axcodes="RAS"),
            Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "nearest")),
            ScaleIntensityRanged(keys=["image"], a_min=-120, a_max=360,b_min=0.0, b_max=1.0, clip=True,),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            EnsureTyped(keys=["image", "label"]),
        ]
    )

    # create a training data loader
    train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=1.0, num_workers=0)
    train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=0)

    # create a validation data loader
    val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0, num_workers=0)
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=0)
    
    
    # create UNet, DiceLoss and Adam optimizer
    device=torch.device("cuda:0")  
    net = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=5,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
        norm=Norm.BATCH,
    ).to(device)
    loss = monai.losses.DiceCELoss(to_onehot_y=True, softmax=True)
    opt = torch.optim.Adam(net.parameters(), 2e-4)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=1, gamma=0.1)

   
    train_post_transforms = Compose(
        [
            EnsureTyped(keys="pred"),
            Activationsd(keys="pred", softmax=True),
            AsDiscreted(keys="pred", argmax=True, to_onehot=5),
            AsDiscreted(keys="label", to_onehot=5),
            #KeepLargestConnectedComponentd(keys="pred", is_onehot=5),
        ]
    )

    val_post_transforms = Compose(
        [
            EnsureTyped(keys="pred"),
            Activationsd(keys="pred", softmax=True),
            AsDiscreted(keys="pred", argmax=True, to_onehot=5),
            AsDiscreted(keys="label", to_onehot=5),
            #KeepLargestConnectedComponentd(keys="pred", is_onehot=5),
        ]
    )
    
    
    val_handlers = [
        # apply “EarlyStop” logic based on the validation metrics
        EarlyStopHandler(trainer=None, patience=5, score_function=lambda x: x.state.metrics["val_mean_dice"]),
        StatsHandler(name="train_log", output_transform=lambda x: None),
        TensorBoardStatsHandler(log_dir="./runs/", output_transform=lambda x: None),
        TensorBoardImageHandler(
            log_dir="./runs/",
            batch_transform=from_engine(["image", "label"]),
            output_transform=from_engine(["pred"]),
        ),
        CheckpointSaver(save_dir="./runs/", save_dict={"net": net}, save_key_metric=True),
    ]   
    
    
    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=SlidingWindowInferer(roi_size=(160, 160, 160), sw_batch_size=4, overlap=0.5),
        postprocessing=val_post_transforms,
        key_val_metric={"val_mean_dice": MeanDice(include_background=False, output_transform=from_engine(["pred", "label"]))},
        additional_metrics={"val_acc": Accuracy(output_transform=from_engine(["pred", "label"]))},
        val_handlers=val_handlers,
        amp=True if monai.utils.get_torch_version_tuple() >= (1, 6) else False,
    )
        
        
        
    
    train_handlers = [
        # apply “EarlyStop” logic based on the loss value, use “-” negative value because smaller loss is better
        EarlyStopHandler(trainer=None, patience=20, score_function=lambda x: -x.state.output[0]["loss"], epoch_level=True),
        LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True),
        ValidationHandler(validator=evaluator, interval=1, epoch_level=True),
        StatsHandler(name="train_log", tag_name="train_loss", output_transform=from_engine(["loss"], first=True)),
        TensorBoardStatsHandler(log_dir="./runs/", tag_name="train_loss", output_transform=from_engine(["loss"], first=True)),
        CheckpointSaver(save_dir="./runs/", save_dict={"net": net, "opt": opt}, save_interval=1, epoch_level=True),
    ]
    
 

    trainer = SupervisedTrainer(
        device=device,
        max_epochs=50,
        train_data_loader=train_loader,
        network=net,
        optimizer=opt,
        loss_function=loss,
        inferer=SimpleInferer(),
        postprocessing=train_post_transforms,
        key_train_metric={"train_acc": Accuracy(output_transform=from_engine(["pred", "label"]))},
        train_handlers=train_handlers,
        amp=True if monai.utils.get_torch_version_tuple() >= (1, 6) else False,
    )
    
    # set initialized trainer for "early stop" handlers
    train_handlers[0].set_trainer(trainer=trainer)
    val_handlers[0].set_trainer(trainer=trainer)
    trainer.run()
    

    %load_ext tensorboard
    %tensorboard --logdir=$log_dir
    


if __name__ == "__main__":
    main()

Loading dataset: 100%|█████████████████████████████████████████████████████████████████| 31/31 [04:59<00:00,  9.67s/it]
Loading dataset: 100%|█████████████████████████████████████████████████████████████████| 10/10 [01:28<00:00,  8.88s/it]

INFO:ignite.engine.engine.SupervisedTrainer:Engine run resuming from iteration 0, epoch 0 until 50 epochs





2022-04-16 16:55:15,844 - INFO - Epoch: 1/50, Iter: 1/16 -- train_loss: 2.6128 
2022-04-16 16:55:16,757 - INFO - Epoch: 1/50, Iter: 2/16 -- train_loss: 2.5742 
2022-04-16 16:55:17,671 - INFO - Epoch: 1/50, Iter: 3/16 -- train_loss: 2.5298 
2022-04-16 16:55:18,523 - INFO - Epoch: 1/50, Iter: 4/16 -- train_loss: 2.4920 
2022-04-16 16:55:19,368 - INFO - Epoch: 1/50, Iter: 5/16 -- train_loss: 2.4586 
2022-04-16 16:55:20,236 - INFO - Epoch: 1/50, Iter: 6/16 -- train_loss: 2.4391 
2022-04-16 16:55:20,996 - INFO - Epoch: 1/50, Iter: 7/16 -- train_loss: 2.4111 
2022-04-16 16:55:21,815 - INFO - Epoch: 1/50, Iter: 8/16 -- train_loss: 2.3533 
2022-04-16 16:55:22,649 - INFO - Epoch: 1/50, Iter: 9/16 -- train_loss: 2.3280 
2022-04-16 16:55:23,488 - INFO - Epoch: 1/50, Iter: 10/16 -- train_loss: 2.3107 
2022-04-16 16:55:24,299 - INFO - Epoch: 1/50, Iter: 11/16 -- train_loss: 2.2997 
2022-04-16 16:55:25,048 - INFO - Epoch: 1/50, Iter: 12/16 -- train_loss: 2.2754 
2022-04-16 16:55:25,847 - INFO - Epoc

2022-04-16 16:56:54,682 - INFO - Epoch: 4/50, Iter: 15/16 -- train_loss: 2.1673 
2022-04-16 16:56:55,084 - INFO - Epoch: 4/50, Iter: 16/16 -- train_loss: 2.1702 
INFO:ignite.engine.engine.SupervisedTrainer:Got new best metric of train_acc: 0.840355674682125
INFO:ignite.engine.engine.SupervisedTrainer:Current learning rate: 2.0000000000000007e-08
INFO:ignite.engine.engine.SupervisedEvaluator:Engine run resuming from iteration 0, epoch 3 until 4 epochs
2022-04-16 16:57:05,223 - INFO - Epoch[4] Metrics -- val_acc: 0.9283 val_mean_dice: 0.0706 
2022-04-16 16:57:05,223 - INFO - Key metric: val_mean_dice best value: 0.07519263029098511 at epoch: 3
INFO:ignite.engine.engine.SupervisedEvaluator:Epoch[4] Complete. Time taken: 00:00:16
INFO:ignite.engine.engine.SupervisedEvaluator:Engine run complete. Time taken: 00:00:16
2022-04-16 16:57:10,974 - INFO - Epoch[4] Metrics -- train_acc: 0.8404 
2022-04-16 16:57:10,974 - INFO - Key metric: train_acc best value: 0.840355674682125 at epoch: 4
INFO:ig

2022-04-16 16:58:47,222 - INFO - Epoch: 8/50, Iter: 7/16 -- train_loss: 2.1608 
2022-04-16 16:58:48,129 - INFO - Epoch: 8/50, Iter: 8/16 -- train_loss: 2.1633 
2022-04-16 16:58:48,911 - INFO - Epoch: 8/50, Iter: 9/16 -- train_loss: 2.1748 
2022-04-16 16:58:49,738 - INFO - Epoch: 8/50, Iter: 10/16 -- train_loss: 2.1717 
2022-04-16 16:58:50,534 - INFO - Epoch: 8/50, Iter: 11/16 -- train_loss: 2.1693 
2022-04-16 16:58:51,365 - INFO - Epoch: 8/50, Iter: 12/16 -- train_loss: 2.1842 
2022-04-16 16:58:52,207 - INFO - Epoch: 8/50, Iter: 13/16 -- train_loss: 2.1462 
2022-04-16 16:58:53,046 - INFO - Epoch: 8/50, Iter: 14/16 -- train_loss: 2.1418 
2022-04-16 16:58:53,932 - INFO - Epoch: 8/50, Iter: 15/16 -- train_loss: 2.1673 
2022-04-16 16:58:54,433 - INFO - Epoch: 8/50, Iter: 16/16 -- train_loss: 2.1714 
INFO:ignite.engine.engine.SupervisedTrainer:Current learning rate: 2.000000000000001e-12
INFO:ignite.engine.engine.SupervisedEvaluator:Engine run resuming from iteration 0, epoch 7 until 8 epoc

2022-04-16 16:59:05,107 ignite.handlers.early_stopping.EarlyStopping INFO: EarlyStopping: Stop training


INFO:ignite.engine.engine.SupervisedTrainer:Terminate signaled. Engine will stop after current iteration is finished.
2022-04-16 16:59:05,108 - INFO - Epoch[8] Metrics -- val_acc: 0.9232 val_mean_dice: 0.0675 
2022-04-16 16:59:05,108 - INFO - Key metric: val_mean_dice best value: 0.07519263029098511 at epoch: 3
INFO:ignite.engine.engine.SupervisedEvaluator:Epoch[8] Complete. Time taken: 00:00:17
INFO:ignite.engine.engine.SupervisedEvaluator:Engine run complete. Time taken: 00:00:17
2022-04-16 16:59:11,270 - INFO - Epoch[8] Metrics -- train_acc: 0.8397 
2022-04-16 16:59:11,271 - INFO - Key metric: train_acc best value: 0.8418822411567934 at epoch: 5
INFO:ignite.engine.engine.SupervisedTrainer:Saved checkpoint at epoch: 8
INFO:ignite.engine.engine.SupervisedTrainer:Epoch[8] Complete. Time taken: 00:00:30
INFO:ignite.engine.engine.SupervisedTrainer:Engine run complete. Time taken: 00:03:59
