In [1]:
import os
import sys
import torch
from pathlib import Path
import numpy as np

import monai

from monai.engines import SupervisedEvaluator, SupervisedTrainer
from monai.metrics.metric import IterationMetric
from monai.metrics.utils import do_metric_reduction

# from ignite.metrics.epoch_metric import EpochMetric
from monai.handlers import (
    ValidationHandler,
    CheckpointSaver,
    LrScheduleHandler,
    MeanDice,
    StatsHandler,
    TensorBoardImageHandler,
    TensorBoardStatsHandler,
    GarbageCollector,
    EarlyStopHandler
)

from monai.inferers import SimpleInferer
from monai.transforms import (
    AddChanneld,
    Compose,
    LoadImaged,
    RandAffined,
    ToTensord,
    RandAdjustContrastd,
    ScaleIntensityd,
    RandScaleIntensityd,
    ScaleIntensityRangePercentilesd,
    AsDiscreted, 
    KeepLargestConnectedComponentd
)

from monai.engines.utils import CommonKeys as Keys
from monai.data import DataLoader
from monai.engines.utils import IterationEvents

from ignite.contrib.handlers import ProgressBar
from ignite.engine import Events

# Local imports
from utils import generate_directory_name, get_list_of_file_names
from loss import DicePlusConstantCatCrossEntropyLoss
from optimizer import RAdam
from transforms import DistanceTransformd, OneHotTransformd
from model import VNet
from metric import LossMetric

# monai.config.print_config()

import logging
# Send log records at INFO level and above to stdout so they appear in the notebook cell output.
logging.basicConfig(stream=sys.stdout, level=logging.INFO)

## Common Functions

In [3]:
def prepare_batch(batch, device, non_blocking=False):
    """Assemble the network input and target tensors from a collated batch dict.

    The "mid-systolic-images" and "annuli" tensors are moved to ``device`` and
    concatenated along dim 1 (the channel axis added by ``AddChanneld`` and
    preserved by collation) to form the two-channel network input. "labels" is
    moved to ``device`` and used as the target. Both tensors are also written
    back into ``batch`` under ``Keys.IMAGE`` / ``Keys.LABEL`` so that handlers
    reading the batch (e.g. the TensorBoard image handler's ``batch_transform``)
    can find them.

    Args:
        batch: dict holding the "mid-systolic-images", "annuli" and "labels" tensors.
        device: target torch device.
        non_blocking: forwarded to ``Tensor.to`` so that engines configured for
            asynchronous host-to-device copies actually get them.

    Returns:
        ``(inputs, labels)`` tuple of tensors on ``device``.
    """
    input_fields = ("mid-systolic-images", "annuli")
    inputs = torch.cat(
        [batch[field].to(device, non_blocking=non_blocking) for field in input_fields],
        dim=1,
    )
    batch[Keys.IMAGE] = inputs
    batch[Keys.LABEL] = batch["labels"].to(device, non_blocking=non_blocking)
    return batch[Keys.IMAGE], batch[Keys.LABEL]


def prepare_device(n_gpu_use: int):
    """Choose the torch device and the GPU ids to train on.

    The requested GPU count is clamped to what ``torch.cuda.device_count()``
    reports, and a warning is printed whenever the request cannot be honoured.

    Args:
        n_gpu_use: number of GPUs the caller would like to use.

    Returns:
        ``(device, list_ids)``: ``device`` is ``cuda:0`` when at least one GPU
        will be used and ``cpu`` otherwise; ``list_ids`` is ``[0, ..., n - 1]``
        for the ``n`` GPUs actually used.
    """
    available = torch.cuda.device_count()
    wanted = n_gpu_use
    if wanted > 0 and available == 0:
        print("Warning: There\'s no GPU available on this machine, training will be performed on CPU.")
        wanted = 0
    if wanted > available:
        message = ("Warning: The number of GPU\'s configured to use is {}, but only {} are available on this "
                   "machine.").format(wanted, available)
        print(message)
        wanted = available
    use_cuda = wanted > 0
    chosen = torch.device("cuda:0" if use_cuda else "cpu")
    return chosen, list(range(wanted))

## Configuration Parameters

In [None]:
# Run identification
leaflet_names = ["anterior", "posterior", "septal"]
name = "single_phase_with_annulus"

# Input data: one sub-folder per entry of `all_keys` (listed in the dataloader cell)
data_dir = "/data/in/DL_DATA_224_224_224_6_vox_min_tricuspid/train"

# Outputs of this run live under <output_dir>/<name>/<generated directory name>
output_dir = "/data/out/test_monai"
directory = generate_directory_name()
run_dir = Path(output_dir) / name / directory
checkpoint_dir = run_dir
train_log_dir = run_dir / "training"
val_log_dir = run_dir / "validation"

# Training hyper-parameters
validation_split = 0.1  # fraction of samples held out for validation
n_gpu = 2
batch_size = 8
num_workers = 8
use_amp = True  # automatic mixed precision
max_epochs = 200

# Keys of the per-sample dicts fed to the transforms
all_keys = ["mid-systolic-images", "annuli", "labels"]

## Transforms

In [None]:
# Maximum random rotation, in radians (30 degrees).
rot_rad = 30 * np.pi/180
# The volumes are 3D (224^3 according to data_dir). MONAI's RandAffined reads
# rotate_range / translate_range / scale_range element-wise, one entry per spatial
# axis, with a scalar entry x meaning uniform(-x, x). Pass one value per axis; a
# 2-tuple would only randomise the first two axes.
spatial_dims = 3

train_transforms = Compose(
    [
        LoadImaged(keys=all_keys, reader="NibabelReader"),
        AddChanneld(keys=all_keys),
        RandAffined(
            keys=all_keys,
            prob=0.5,
            rotate_range=(rot_rad,) * spatial_dims,  # radians, uniform(-rot_rad, rot_rad) per axis
            translate_range=(30,) * spatial_dims,    # voxels, uniform(-30, 30) per axis
            scale_range=(0.2,) * spatial_dims,       # zoom 1 + uniform(-0.2, 0.2) per axis
            # NOTE(review): nearest interpolation is also applied to the image key;
            # mode=("bilinear", "nearest", "nearest") would keep annuli/labels crisp
            # while interpolating intensities smoothly.
            mode="nearest",
            padding_mode="zeros",
            as_tensor_output=False
        ),
        OneHotTransformd(keys=["labels"]),
        DistanceTransformd(keys=["annuli"]),
        # Normalise to [0, 1] first, THEN randomly scale intensities: a min-max
        # normalisation applied after a multiplicative scaling undoes it exactly,
        # which would make the augmentation a no-op.
        ScaleIntensityd(
            keys=["mid-systolic-images"],
            minv=0.0,
            maxv=1.0
        ),
        RandScaleIntensityd(
            keys=["mid-systolic-images"],
            factors=0.3,
            prob=0.5
        ),
        ToTensord(keys=all_keys)
    ]
)

val_transforms = Compose(
    [
        LoadImaged(keys=all_keys, reader="NibabelReader"),
        AddChanneld(keys=all_keys),
        OneHotTransformd(keys=["labels"]),
        DistanceTransformd(keys=["annuli"]),
        # Alternative normalisation (disabled):
        # ScaleIntensityRangePercentilesd(keys=["mid-systolic-images"], lower=2, upper=99,
        #                                 b_min=0.0, b_max=1.0, clip=True, relative=True),
        ScaleIntensityd(
            keys=["mid-systolic-images"],
            minv=0.0,
            maxv=1.0
        ),
        ToTensord(keys=all_keys)
    ]
)

# Passed as `post_transform` to both the trainer and the validator.
post_transforms = Compose(
    [
        # Binarise the per-channel probabilities (MONAI's default threshold,
        # presumably 0.5 — confirm for the installed version).
        AsDiscreted(keys=Keys.PRED, threshold_values=True),
        # Keep only the largest connected component of labels 1-3
        # (presumably the three leaflets, 0 being background).
        KeepLargestConnectedComponentd(keys=Keys.PRED, applied_labels=[1, 2, 3]),
    ]
)

## Dataloader (train / validation)

In [None]:

from glob import glob
import pprint

# Absolute file paths of every input modality. Samples are paired by their
# position in the sorted listings, so the three folders must hold matching files.
images = sorted(get_list_of_file_names(os.path.join(data_dir, "mid-systolic-images"), absolute_path=True))
annuli = sorted(get_list_of_file_names(os.path.join(data_dir, "annuli"), absolute_path=True))
labels = sorted(get_list_of_file_names(os.path.join(data_dir, "labels"), absolute_path=True))

# zip() silently truncates to the shortest list, so a missing or extra file in one
# folder would mis-pair every later sample. Fail loudly instead.
if not len(images) == len(annuli) == len(labels):
    raise ValueError(
        f"Mismatched number of files in {data_dir}: {len(images)} mid-systolic-images, "
        f"{len(annuli)} annuli, {len(labels)} labels"
    )

n_samples = len(images)
len_valid = int(n_samples * validation_split)
n_train = n_samples - len_valid  # the last `len_valid` samples in sorted order are held out


def _pair_files(img_paths, ann_paths, lbl_paths):
    """Build one transform-input dict per (image, annulus, label) path triple."""
    return [
        {"mid-systolic-images": img, "annuli": ann, "labels": lbl}
        for img, ann, lbl in zip(img_paths, ann_paths, lbl_paths)
    ]


train_files = _pair_files(images[:n_train], annuli[:n_train], labels[:n_train])
val_files = _pair_files(images[n_train:], annuli[n_train:], labels[n_train:])

if not train_files:
    raise ValueError(f"No training samples found in {data_dir}")

train_dataset = monai.data.Dataset(train_files, transform=train_transforms)
train_data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

# No shuffling for validation: metrics do not need it, and a fixed order keeps the
# batch drawn by the TensorBoard image handler the same case from epoch to epoch.
val_dataset = monai.data.Dataset(val_files, transform=val_transforms)
val_data_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

print("Num training datasets:", len(train_files))
print("Num validation datasets:", len(val_files))

print("first dataset:")
pprint.pprint(train_files[0])


## Network, Loss, Optimizer, LR Scheduler

In [None]:
from torch.optim.lr_scheduler import ReduceLROnPlateau

device, device_ids = prepare_device(n_gpu_use=n_gpu)

# NETWORK: custom VNet fed with the two stacked input channels (image + annulus,
# see prepare_batch) and predicting 4 output channels.
net = VNet(n_channels=2, n_classes=4, n_filters=16, normalization="batchnorm")
net = net.to(device)

# Alternative: MONAI's built-in VNet
# from monai.networks.nets.vnet import VNet
# net = VNet(in_channels=2, out_channels=4, act="relu").to(device)

# Multi-GPU training
use_data_parallel = len(device_ids) > 1
if use_data_parallel:
    net = torch.nn.DataParallel(net, device_ids=device_ids).cuda()

# LOSS
loss_function = DicePlusConstantCatCrossEntropyLoss(
    boundaries_weight_factor=50,
    boundaries_pool=3,
    sigma=0.02,
)

# OPTIMIZER (only parameters that require gradients)
# NB: as of now using local copy of RAdam optimizer but seems to be integrated into pytorch soon (https://github.com/pytorch/pytorch/pull/58968)
trainable_params = (param for param in net.parameters() if param.requires_grad)
optimizer = RAdam(params=trainable_params, lr=0.02, weight_decay=1e-05)

# SCHEDULER: halve the learning rate when the monitored value stops decreasing
lr_scheduler = ReduceLROnPlateau(
    optimizer=optimizer,
    mode='min',
    factor=0.5,
    patience=3,
    verbose=True,
)

## (optional) Restore previous model state

In [None]:
# checkpoint_file = "/home/herzc/sources/DeepHeartPrivate/notebooks/Visualization/mid_systolic_images_with_annulus_detached.pth"
# checkpoint = torch.load(checkpoint_file)
# net.load_state_dict(checkpoint['state_dict'])
# optimizer.load_state_dict(checkpoint['optimizer'])

In [None]:
class SupervisedValidator(SupervisedEvaluator):
    """``SupervisedEvaluator`` that also evaluates a loss on every validation batch.

    The batch-mean loss is stored as a Python float under the ``"val_loss"`` key of
    ``engine.state.output``, next to the image, label and prediction tensors.

    NOTE(review): this class is not instantiated anywhere in this notebook. The
    training cell builds a plain ``SupervisedEvaluator`` and obtains ``val_loss``
    through a ``LossMetric`` instead; delete this class or switch the validator over.

    Args:
        loss_function: callable ``(predictions, targets) -> tensor``. It must be
            passed by keyword; a ``KeyError`` is raised if it is missing.
        **kwargs: every other argument, forwarded unchanged to ``SupervisedEvaluator``.
    """

    def __init__(self, **kwargs):
        # Removed from kwargs so that it is not forwarded to the base-class constructor.
        self.loss_function = kwargs.pop("loss_function")
        super().__init__(**kwargs)

    def _iteration(self, engine, batchdata):
        """Run the forward pass and the loss for one validation batch.

        Args:
            engine: the engine driving this iteration (presumably ``self``).
            batchdata: one batch from the validation data loader.

        Returns:
            ``engine.state.output``: dict with ``Keys.IMAGE``, ``Keys.LABEL``,
            ``Keys.PRED`` and ``"val_loss"``.

        Raises:
            ValueError: if ``batchdata`` is None.
        """
        if batchdata is None:
            raise ValueError("Must provide batch data for current iteration.")
        batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking)
        # prepare_batch returns either (inputs, targets) or (inputs, targets, args, kwargs);
        # the extra args/kwargs are forwarded to the inferer.
        if len(batch) == 2:
            inputs, targets = batch
            args: tuple = ()
            kwargs: dict = {}
        else:
            inputs, targets, args, kwargs = batch

        # put iteration outputs into engine.state
        engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets}
        # Forward pass and loss; autocast(enabled=False) keeps full precision when AMP is off.
        with self.mode(self.network), torch.cuda.amp.autocast(enabled=self.amp):
            predictions = self.inferer(inputs, self.network, *args, **kwargs)
            loss = self.loss_function(predictions, targets).mean()

        engine.state.output[Keys.PRED] = predictions
        engine.state.output["val_loss"] = loss.item()
        engine.fire_event(IterationEvents.FORWARD_COMPLETED)
        engine.fire_event(IterationEvents.MODEL_COMPLETED)

        # NOTE(review): releasing the CUDA cache after every batch lowers peak memory
        # at a throughput cost; drop it if memory is not tight.
        torch.cuda.empty_cache()

        return engine.state.output


## Training

In [None]:
val_handlers = [ProgressBar(), GarbageCollector("epoch")]


def _val_pred_and_label(output):
    """Return the (prediction, label) pair a validation metric consumes from an evaluator output dict."""
    return output[Keys.PRED], output[Keys.LABEL]


val_key_metric = {
    "val_mean_dice": MeanDice(include_background=False, output_transform=_val_pred_and_label),
}
val_additional_metrics = {
    "val_loss": LossMetric(metric_fn=loss_function, output_transform=_val_pred_and_label),
    # Per-leaflet dice (disabled); same pattern for channels 1-3, e.g.
    # "val_mean_dice_anterior": MeanDice(include_background=False,
    #     output_transform=lambda x: (x[Keys.PRED][:, 1, :], x[Keys.LABEL][:, 1, :])),
}

validator = SupervisedEvaluator(
    device=device,
    val_data_loader=val_data_loader,
    network=net,
    prepare_batch=prepare_batch,
    inferer=SimpleInferer(),
    key_val_metric=val_key_metric,
    additional_metrics=val_additional_metrics,
    val_handlers=val_handlers,
    post_transform=post_transforms,
    amp=use_amp,
)

    
# Handlers attached to the trainer, listed in attachment order.
run_validation = ValidationHandler(
    validator=validator,
    interval=1,  # validate after every training epoch
    epoch_level=True,
)
log_train_loss = StatsHandler(
    tag_name="train_loss",
    output_transform=lambda output: output[Keys.LOSS],
)
save_checkpoints = CheckpointSaver(
    save_dir=checkpoint_dir,
    save_dict={"net": net, "opt": optimizer},
    save_interval=5,  # every 5 epochs
    epoch_level=True,
)
train_handlers = [run_validation, log_train_loss, save_checkpoints, GarbageCollector("epoch")]

# Per-leaflet training dice (disabled); same pattern for channels 1-3, e.g.
# "train_mean_dice_anterior": MeanDice(include_background=False,
#     output_transform=lambda x: (x[Keys.PRED][:, 1, :], x[Keys.LABEL][:, 1, :])),
train_additional_metrics = {}

trainer = SupervisedTrainer(
    device=device,
    max_epochs=max_epochs,
    train_data_loader=train_data_loader,
    network=net,
    optimizer=optimizer,
    loss_function=loss_function,
    prepare_batch=prepare_batch,
    inferer=SimpleInferer(),
    key_train_metric={
        "train_mean_dice": MeanDice(
            include_background=False,
            output_transform=lambda output: (output[Keys.PRED], output[Keys.LABEL]),
        ),
    },
    additional_metrics=train_additional_metrics,
    train_handlers=train_handlers,
    post_transform=post_transforms,
    amp=use_amp,
)


# more handlers
# Print the validator's metrics, labelled with the trainer's current epoch.
# output_transform returns None, so presumably no per-iteration value is printed.
StatsHandler(
    output_transform=lambda x: None,
    global_epoch_transform=lambda x: trainer.state.epoch
).attach(validator)


# Step the ReduceLROnPlateau scheduler with the "val_loss" metric (the LossMetric
# registered on the validator) each time the validator completes a run.
LrScheduleHandler(
    lr_scheduler=lr_scheduler, 
    print_lr=True, 
    step_transform=lambda x: x.state.metrics["val_loss"]
).attach(validator)


# Write the validator's metrics to TensorBoard under val_log_dir, indexed by trainer epoch.
TensorBoardStatsHandler(
    log_dir=val_log_dir,
    output_transform=lambda x: None,
    global_epoch_transform=lambda x: trainer.state.epoch
).attach(validator)


# add handler to draw the first image and the corresponding label and model output in the last batch
# here we draw the 3D output as GIF format along Depth axis, at every validation epoch
# NOTE(review): if val_data_loader shuffles, the drawn case changes from epoch to epoch.
val_tensorboard_image_handler = TensorBoardImageHandler(
    log_dir=val_log_dir,
    batch_transform=lambda x: (x[Keys.IMAGE], x[Keys.LABEL]),
    output_transform=lambda x: x[Keys.PRED],
    max_channels=10,
    global_iter_transform=lambda x: trainer.state.epoch,
)
validator.add_event_handler(
    event_name=Events.EPOCH_COMPLETED, handler=val_tensorboard_image_handler
)

# Log the training loss (Keys.LOSS) to TensorBoard under train_log_dir.
# NOTE(review): global_epoch_transform returns trainer.state.iteration here, whereas the
# validation handlers use trainer.state.epoch, so epoch-level training metrics and
# validation metrics are presumably written on different step scales — confirm intended.
TensorBoardStatsHandler(
    log_dir=train_log_dir, 
    tag_name="train_loss", 
    output_transform=lambda x: x[Keys.LOSS],
    global_epoch_transform=lambda x: trainer.state.iteration
).attach(trainer)


# Stop the trainer when "val_mean_dice" has not improved for 30 validation runs
# (presumably higher-is-better, as in ignite's EarlyStopping — confirm).
EarlyStopHandler(
    patience=30,
    score_function=lambda x: x.state.metrics["val_mean_dice"],
    trainer=trainer,
    epoch_level=True,
).attach(validator)


def softmax_predictions(engine):
    """Replace the network logits in ``engine.state.output[Keys.PRED]`` with class probabilities.

    The softmax is taken over dim 1, the class-channel axis (4 channels for this
    VNet). A softmax along one dimension does not depend on the remaining axes, so
    this equals flattening the spatial axes, applying the softmax and reshaping back.

    The function is registered on ``IterationEvents.FORWARD_COMPLETED`` of both
    engines, i.e. right after the forward pass. NOTE(review): for the trainer this
    presumably means the loss function receives probabilities rather than logits.
    Confirm this against the installed MONAI ``SupervisedTrainer`` and the
    expectations of ``DicePlusConstantCatCrossEntropyLoss``.
    """
    pred = engine.state.output[Keys.PRED]
    engine.state.output[Keys.PRED] = torch.softmax(pred, dim=1)


# The same post-forward softmax is used for training and validation.
for workflow in (trainer, validator):
    workflow.add_event_handler(IterationEvents.FORWARD_COMPLETED, softmax_predictions)


trainer.run()

In [None]:
## Playground

In [None]:
# NB: asynchronous post transform call. has absolutely no effect on loss function inputs
# from monai.transforms import apply_transform
# 
# post_transforms = Compose(
#     [
#         Activationsd(keys=Keys.PRED,
#                      other=lambda x: torch.nn.functional.softmax(x.reshape(x.size(0), x.size(1), -1), dim=1).view_as(x)),
# #         AsDiscreted(keys=["pred", "label"], argmax=(True, False), to_onehot=True, n_classes=4)
#         AsDiscreted(keys=[Keys.LABEL], to_onehot=True, n_classes=4)
#     ]
# )

# from monai.engines.utils import IterationEvents
# from monai.transforms import apply_transform

# @trainer.on(IterationEvents.FORWARD_COMPLETED)
# def run_post_transform(engine):
#     print("running post transforms")
#     print(engine.state.output.keys())
#     engine.state.output = apply_transform(post_transforms, engine.state.output)
#     print(engine.state.output["label"].shape)
#     print("completed")

# val_loss = "val_loss"

# @validator.on(IterationEvents.FORWARD_COMPLETED)
# def run_post_transform(engine):
#     engine.state.output[Keys.LOSS] = loss_function(engine.state.output[Keys.PRED], engine.state.output[Keys.LABEL])


## Plotting QC

In [None]:
# paths = [
#     os.path.join(os.getcwd(), "..", "..")
# ]

# for path in paths:
#     if not path in sys.path:
#         sys.path.insert(0, path)

# from deeputils.display import showImage

# from monai.utils.misc import first
# from pathlib import Path

# data_dict = first(train_data_loader)

# for idx, data_dict in enumerate(train_data_loader):
#     # print(data_dict["mid-systolic-images"].shape, data_dict["anterior"].shape, data_dict["image_meta_dict"]["filename_or_obj"])
# #     print(f"unique labels: {np.unique(data_dict['labels'])}")
# #     print(data_dict['mid-systolic-images_meta_dict']['filename_or_obj'])

#     showImage(data_dict["mid-systolic-images"], n_disp_images=5)
#     showImage(data_dict["labels"], n_disp_images=5)
#     showImage(data_dict["labels"][:,0,:], n_disp_images=5, case_name=Path(data_dict['mid-systolic-images_meta_dict']['filename_or_obj'][0]).name)
#     showImage(data_dict["annuli"], n_disp_images=5, case_name=Path(data_dict['mid-systolic-images_meta_dict']['filename_or_obj'][0]).name)

#     if idx == 7:
#         break
