In [15]:
!python -c "import monai" || pip install -q "monai-weekly[nibabel, tqdm]"
!python -c "import matplotlib" || pip install -q matplotlib
%matplotlib inline

In [16]:
"""
0: background,
1: sacrum,
2: right_hip,
3: left_hip,
4: lumbar_vertebra.        
"""        

'\n0: background,\n1: sacrum,\n2: right_hip,\n3: left_hip,\n4: lumbar_vertebra.        \n'

In [17]:
# import EarlyStopping
from pytorchtools import EarlyStopping

In [18]:
from monai.utils import first, set_determinism
import time
from ignite.handlers import EarlyStopping, ModelCheckpoint
from monai.transforms import (
    AddChanneld,
    AsDiscrete,
    AsDiscreted,
    EnsureChannelFirstd,
    Compose,
    Resized,
    CropForegroundd,
    RandScaleIntensityd,
    DataStatsd,
    LoadImaged,
    Orientationd,
    Activationsd,
    RandCropByPosNegLabeld,
    RandShiftIntensityd,
    NormalizeIntensityd,
    RandSpatialCropd,
    ScaleIntensityRanged,
    Spacingd,
    ScaleIntensityd,
    SaveImaged,
    EnsureTyped,
    EnsureType,
    Invertd,
)
from monai.handlers.utils import from_engine
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.metrics import DiceMetric
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch
from monai.config import print_config
from monai.apps import download_and_extract
import torch
import matplotlib.pyplot as plt
import tempfile
import shutil
import os
import glob
from monai.optimizers import Novograd
from monai.utils import get_torch_version_tuple, set_determinism
from ignite.handlers import EarlyStopping, ModelCheckpoint

print_config()


MONAI version: 0.9.dev2209
Numpy version: 1.21.5
Pytorch version: 1.10.2
MONAI flags: HAS_EXT = False, USE_COMPILED = False
MONAI rev id: 9db6e543d6090a3256f20695c1d3224df8cbbc0e
MONAI __file__: C:\Users\Hripsime\anaconda3\envs\abys\lib\site-packages\monai\__init__.py

Optional dependencies:
Pytorch Ignite version: 0.4.8
Nibabel version: 3.2.2
scikit-image version: 0.19.2
Pillow version: 9.0.1
Tensorboard version: 2.6.0
gdown version: 4.2.1
TorchVision version: 0.11.3
tqdm version: 4.63.0
lmdb version: 1.3.0
psutil version: 5.9.0
pandas version: 1.4.1
einops version: 0.4.0
transformers version: 4.16.2
mlflow version: 1.23.1

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies



## Setup data directory

## Set dataset path

In [19]:
train_dir = "C:/Users/Hripsime/OneDrive - ABYS MEDICAL/projects/CTPelvic1K_data/train_dir/"

train_images = sorted(glob.glob(os.path.join(train_dir, "*data.nii.gz")))
train_labels = sorted(glob.glob(os.path.join(train_dir, "*mask_4label.nii.gz")))
data_dicts = [{"image": image_name, "mask": label_name} for image_name, label_name in zip(train_images, train_labels)]
train_files, val_files = data_dicts[:-10], data_dicts[-10:]

In [20]:
set_determinism(seed=0)

In [21]:
train_transforms = Compose(
    [
        LoadImaged(keys=["image", "mask"]),
        EnsureChannelFirstd(keys=["image", "mask"]),
        Orientationd(keys=["image", "mask"], axcodes="RAS"),
        Spacingd(keys=["image", "mask"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "nearest")),
        CropForegroundd(keys=["image", "mask"], source_key="image"),
        ScaleIntensityRanged(keys=["image"], a_min=-120, a_max=360, b_min=0.0, b_max=1.0, clip=True,),
        #DataStatsd(keys=["image", "mask"], prefix='Data', data_type=True, data_shape=True, value_range=True, data_value=False),
        
        RandCropByPosNegLabeld(
            keys=["image", "mask"],
            label_key="mask",
            spatial_size=(128, 128, 128),
            pos=1,
            neg=1,
            num_samples=4,
            image_key="image",
            image_threshold=0,),
            
        EnsureTyped(keys=["image", "mask"]),
    ]
)
val_transforms = Compose(
    [
        LoadImaged(keys=["image", "mask"]),
        EnsureChannelFirstd(keys=["image", "mask"]),
        Orientationd(keys=["image", "mask"], axcodes="RAS"),
        Spacingd(keys=["image", "mask"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "nearest")),
        ScaleIntensityRanged(keys=["image"], a_min=-120, a_max=360, b_min=0.0, b_max=1.0, clip=True,),
        CropForegroundd(keys=["image", "mask"], source_key="image"),
        
        EnsureTyped(keys=["image", "mask"]),
    ]
)


In [22]:
train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=1.0, num_workers=0)
train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=0)

val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0, num_workers=0)
val_loader = DataLoader(val_ds, batch_size=1, num_workers=0)


Loading dataset: 100%|█████████████████████████████████████████████████████████████████| 31/31 [04:51<00:00,  9.41s/it]
Loading dataset: 100%|█████████████████████████████████████████████████████████████████| 10/10 [01:28<00:00,  8.84s/it]


In [23]:
import gc #garbage collector
gc.collect()
torch.cuda.empty_cache()
torch.cuda.synchronize()

## Create Model, Loss, Optimizer

In [24]:
from monai.losses import DiceCELoss

# standard PyTorch program style: create UNet, DiceLoss and Adam optimizer

VAL_AMP = True #Automatic mixed precision to accelerate training speed and reduce a memory usage
learning_rate = 2e-4

device=torch.device("cuda:0")     
#device=torch.device("cpu")      
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=5,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
).to(device)

#loss_function = DiceLoss(to_onehot_y=True, softmax=True)
loss_function = DiceCELoss(to_onehot_y=True, softmax=True)

optimizer = torch.optim.Adam(model.parameters())
#optimizer = Novograd(model.parameters(), learning_rate * 10)
dice_metric = DiceMetric(include_background=False, reduction="mean")


# define inference method
def inference(input):

    def _compute(input):
        return sliding_window_inference(
            inputs=input,
            roi_size=(160, 160, 160),
            sw_batch_size=4,
            predictor=model,
            overlap=0.5,
        )

    if VAL_AMP:
        with torch.cuda.amp.autocast():
            return _compute(input)
    else:
        return _compute(input)
    
# use amp to accelerate training
scaler = torch.cuda.amp.GradScaler()
# enable cuDNN benchmark
torch.backends.cudnn.benchmark = True    

In [25]:
# import EarlyStopping
from pytorchtools import EarlyStopping

## Execute a PyTorch training process with Early Stopping

In [26]:
n_epochs = 100    
train_losses = []
valid_losses = []
avg_train_losses = []
avg_valid_losses = [] 

post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=5)])
post_label = Compose([EnsureType(), AsDiscrete(to_onehot=5)])

early_stopping = EarlyStopping(patience=20, verbose=True)
    
for epoch in range(1, n_epochs + 1):
    model.train() 
    for batch_data in train_loader:
        inputs, labels = (
            batch_data["image"].to(device),
            batch_data["mask"].to(device),
        )
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        train_losses.append(loss.item())

    model.eval()
    for val_data in val_loader:
        val_inputs, val_labels = (
            val_data["image"].to(device),
            val_data["mask"].to(device),
        )
            
       
        val_outputs = inference(val_inputs)
        val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]
        val_labels = [post_label(i) for i in decollate_batch(val_labels)]
        loss_val = loss_function(val_outputs, val_labels)
        valid_losses.append(loss_val.item())

             
        # calculate average loss over an epoch
        train_loss = np.average(train_losses)
        valid_loss = np.average(valid_losses)
        avg_train_losses.append(train_loss)
        avg_valid_losses.append(valid_loss)
        
        epoch_len = len(str(n_epochs))
        
        # print training/validation statistics
        print_msg = (f'[{epoch:>{epoch_len}}/{n_epochs:>{epoch_len}}] ' +
                     f'train_loss: {train_loss:.5f} ' +
                     f'valid_loss: {valid_loss:.5f}')
        
        print(print_msg)
        
        # clear lists to track next epoch
        train_losses = []
        valid_losses = []
        
        # early_stopping needs the validation loss to check if it has decresed, 
        # and if it has, it will make a checkpoint of the current model
        early_stopping(valid_loss, model)
        
        if early_stopping.early_stop:
            print("Early stopping")
            break
        
    # load the last checkpoint with the best model
    torch.save(model.state_dict(), os.path.join(train_dir, "best_metric_model.pth"))


RuntimeError: CUDA out of memory. Tried to allocate 128.00 MiB (GPU 0; 8.00 GiB total capacity; 5.29 GiB already allocated; 0 bytes free; 5.68 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

## Execute a PyTorch training process without Early Stopping

In [13]:
n_epochs = 15
val_interval = 1   
best_metric = -1
best_metric_epoch = -1

epoch_loss_values = []
metric_values = []
post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=5)])
post_label = Compose([EnsureType(), AsDiscrete(to_onehot=5)])

start_time = time.time()

for epoch in range(max_epochs):
    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")
    
    
    model.train()
    epoch_loss = 0
    step = 0
    for batch_data in train_loader:
        step += 1
        inputs, labels = (
            batch_data["image"].to(device),
            batch_data["mask"].to(device),
        )
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        epoch_loss += loss.item()
        print(
            f"{step}/{len(train_ds) // train_loader.batch_size}, "
            f"train_loss: {loss.item():.4f}")
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    if (epoch + 1) % val_interval == 0:
        
        model.eval()
        with torch.no_grad():
            for val_data in val_loader:
                val_inputs, val_labels = (
                    val_data["image"].to(device),
                    val_data["mask"].to(device),
                )
                
                val_outputs = inference(val_inputs)
                val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]
                val_labels = [post_label(i) for i in decollate_batch(val_labels)]
                
                # compute metric for current iteration
                dice_metric(y_pred=val_outputs, y=val_labels)

            # aggregate the final mean dice result
            metric = dice_metric.aggregate().item()
            
            # reset the status for next validation round
            dice_metric.reset()

            metric_values.append(metric)
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), os.path.join(root_dir, "best_metric_model.pth"))
                print("saved new best metric model")
            
            print(
                "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {} Total time seconds: {:.2f}".format(epoch + 1, metric, best_metric, best_metric_epoch, (time.time()- start_time))
                    )       

----------
epoch 1/15
1/14, train_loss: 2.7198
2/14, train_loss: 2.6451
3/14, train_loss: 2.6131
4/14, train_loss: 2.5705
5/14, train_loss: 2.5150
6/14, train_loss: 2.5007
7/14, train_loss: 2.4746
8/14, train_loss: 2.4421
9/14, train_loss: 2.4101
10/14, train_loss: 2.3967
11/14, train_loss: 2.3370
12/14, train_loss: 2.3954
13/14, train_loss: 2.2957
14/14, train_loss: 2.3679
15/14, train_loss: 2.3078
epoch 1 average loss: 2.4661
saved new best metric model
current epoch: 1 current mean dice: 0.0612 best mean dice: 0.0612 at epoch 1 Total time seconds: 31.35
----------
epoch 2/15
1/14, train_loss: 2.2905
2/14, train_loss: 2.2591
3/14, train_loss: 2.2795
4/14, train_loss: 2.2333
5/14, train_loss: 2.2339
6/14, train_loss: 2.2632
7/14, train_loss: 2.2602
8/14, train_loss: 2.1670
9/14, train_loss: 2.2260
10/14, train_loss: 2.1862
11/14, train_loss: 2.2460
12/14, train_loss: 2.1600
13/14, train_loss: 2.1570
14/14, train_loss: 2.1603
15/14, train_loss: 2.1448
epoch 2 average loss: 2.2178
saved

14/14, train_loss: 0.9005
15/14, train_loss: 0.9016
epoch 15 average loss: 0.8800
saved new best metric model
current epoch: 15 current mean dice: 0.5305 best mean dice: 0.5305 at epoch 15 Total time seconds: 269.68


In [14]:
print(
    f"train completed, best_metric: {best_metric:.4f} "
    f"at epoch: {best_metric_epoch}")

train completed, best_metric: 0.5305 at epoch: 15
