### Min-max normalization and intensity-window transforms

In [1]:
import monai
from monai.transforms import MapTransform
from monai.config import KeysCollection
class MinMax(MapTransform):
	"""Dictionary transform that min-max rescales each selected entry to the range [0, 1].

	For every selected key the value ``x`` is replaced by
	``(x - x.min()) / (x.max() - x.min())``. A constant input is mapped to all zeros.
	"""

	def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False):
		"""
		Args:
			keys: keys of the dictionary entries to normalize.
			allow_missing_keys: if False, a selected key that is missing from the input
				raises KeyError; if True it is skipped.
		"""
		super().__init__(keys=keys, allow_missing_keys=allow_missing_keys)

	def normalize(self, data: monai.data.MetaTensor):
		"""Return ``data`` rescaled to [0, 1] (all zeros when ``data`` is constant).

		Raises:
			ValueError: if the minimum and maximum cannot be ordered, which in practice
				only happens when ``data`` contains NaN.
		"""
		data_min = data.min()
		data_max = data.max()
		if data_max > data_min:
			return (data - data_min) / (data_max - data_min)
		if data_max == data_min:
			# A constant input has no range to scale by. The previous `/ max` produced NaN
			# for an all-zero volume (0 / 0). After shifting, every value is 0; true division
			# by 1.0 keeps the floating dtype that the other branch produces.
			return (data - data_min) / 1.0
		# Both comparisons are False only when min/max are NaN.
		raise ValueError('MinMax failed: Minimum seems to be greater than maximum')

	def __call__(self, data):
		"""Normalize the selected entries of ``data`` in place and return the same dict."""
		# key_iterator honours allow_missing_keys: a missing key is skipped only when it is True.
		for key in self.key_iterator(data):
			data[key] = self.normalize(data[key])
		return data

class Window(MapTransform):
	"""Dictionary transform that clips intensities to ``[center - width / 2, center + width / 2]``.

	Values below the lower bound are set to the lower bound and values above the upper
	bound are set to the upper bound. Values inside the window are unchanged.
	"""

	def __init__(self, window_center: float, window_width: float, keys: KeysCollection, allow_missing_keys: bool = False):
		"""
		Args:
			window_center: centre of the intensity window.
			window_width: total width of the intensity window.
			keys: keys of the dictionary entries to window.
			allow_missing_keys: if False, a selected key that is missing from the input
				raises KeyError; if True it is skipped.
		"""
		super().__init__(keys=keys, allow_missing_keys=allow_missing_keys)
		self.window_center = window_center
		self.window_width = window_width

	def window(self, data: monai.data.MetaTensor):
		"""Clip ``data`` to the configured window (writing into ``data``) and return it."""
		# True division: the previous `// 2` floored the half-width, so an odd or
		# fractional width produced a window narrower than requested.
		half_width = self.window_width / 2
		img_min = self.window_center - half_width
		img_max = self.window_center + half_width
		# Boolean-mask assignment modifies `data` itself rather than a copy.
		data[data < img_min] = img_min
		data[data > img_max] = img_max
		return data

	def __call__(self, data):
		"""Window the selected entries of ``data`` in place and return the same dict."""
		# key_iterator honours allow_missing_keys: a missing key is skipped only when it is True.
		for key in self.key_iterator(data):
			data[key] = self.window(data[key])
		return data

### Set up transforms for training and validation

In [2]:
from monai.transforms import (
    Compose,
    LoadImaged,
    EnsureChannelFirstd,
    SpatialPadD,
    Resized,
    RandFlipd,
    RandRotate90d,
    RandShiftIntensityd,
)

root_dir = '/home/samartha-mig/Projects/monai_prac_3d/Seg3'

# Spatial size fed to the network; UNETR's img_size is set to this value.
volume_size = (224, 224, 224)
# Volumes smaller than this are padded up to it before being resized to volume_size.
pad_size = (512, 512, 256)


def shared_preprocessing():
    """Return new instances of the deterministic load / pad / resize steps.

    This list is the common prefix of both the training and the validation pipeline, so
    the two cannot drift apart. A fresh list is built on every call, so the pipelines
    share no transform objects.
    """
    return [
        LoadImaged(keys=['image', 'label']),
        EnsureChannelFirstd(keys=['image', 'label']),
        SpatialPadD(keys=['image', 'label'], spatial_size=pad_size),
        # Interpolate the image; use nearest for the label so class ids are not blended.
        Resized(keys=['image', 'label'], spatial_size=volume_size, mode=('bilinear', 'nearest')),
    ]


# Previously tried and currently disabled: Window(window_center=150, window_width=500),
# Orientationd(axcodes="RAS"), Spacingd(pixdim=(1.5, 1.5, 2.0)),
# ScaleIntensityRanged(a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True),
# CropForegroundd(source_key="image"),
# RandCropByPosNegLabeld(spatial_size=(96, 96, 96), pos=1, neg=1, num_samples=4).
# NOTE(review): no active step rescales intensities, so RandShiftIntensityd's offset of
# 0.10 is applied to unnormalised values - confirm whether a normalisation step is intended.
train_transforms = Compose(
    shared_preprocessing()
    + [
        # Independent random flips along each spatial axis.
        RandFlipd(keys=["image", "label"], spatial_axis=[0], prob=0.10),
        RandFlipd(keys=["image", "label"], spatial_axis=[1], prob=0.10),
        RandFlipd(keys=["image", "label"], spatial_axis=[2], prob=0.10),
        RandRotate90d(keys=["image", "label"], prob=0.10, max_k=3),
        # Intensity augmentation is applied to the image only.
        RandShiftIntensityd(keys=["image"], offsets=0.10, prob=0.50),
    ]
)
val_transforms = Compose(shared_preprocessing())

### Set up data

In [3]:
from monai.data import DataLoader
from monai.data.dataset import Dataset
from sklearn.model_selection import train_test_split
import os
import json
data_dir = "/home/samartha-mig/Projects/data/Task08_HepaticVessel"
# Fixed seed so the train/validation split is identical across kernel restarts.
split_seed = 42

with open(os.path.join(data_dir, 'dataset.json')) as fh:
    files = json.load(fh)['training']
# The paths in dataset.json are presumably relative to data_dir with a leading './'.
# Join and normalise them: normpath drops the './' component without mangling '../'
# segments, which the previous str.replace('./', '') would corrupt.
files = [
    {key: os.path.normpath(os.path.join(data_dir, rel_path)) for key, rel_path in entry.items()}
    for entry in files
]
train_list, val_list = train_test_split(files, train_size=0.7, random_state=split_seed)
train_ds = Dataset(data=train_list, transform=train_transforms)
val_ds = Dataset(data=val_list, transform=val_transforms)

train_loader = DataLoader(train_ds, batch_size=1, shuffle=True, num_workers=5, pin_memory=True)
val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=5, pin_memory=True)

### Create Model, Loss, Optimizer


In [4]:
import torch
from monai.networks.nets.unetr import UNETR
from monai.losses.dice import DiceCELoss
# Prefer the second GPU (the original setting). Fall back to the first GPU, or to the CPU,
# so the notebook still runs on machines with fewer devices.
if torch.cuda.device_count() > 1:
    device = "cuda:1"
elif torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# img_size must equal the spatial size produced by the transforms (volume_size).
# out_channels=3 matches the to_onehot=3 post-processing used during validation.
model = UNETR(
    in_channels=1,
    out_channels=3,
    img_size=volume_size,
    feature_size=16,
    hidden_size=768,
    mlp_dim=3072,
    num_heads=12,
    proj_type="conv",
    norm_name="instance",
    res_block=True,
    dropout_rate=0.0,
).to(device)

# The loss one-hot encodes the labels and applies softmax to the logits internally.
# The background channel is excluded from the loss.
loss_function = DiceCELoss(
    include_background=False,
    to_onehot_y=True,
    softmax=True,
)
# Every input has the same size (volume_size), so cuDNN autotuning can be reused.
torch.backends.cudnn.benchmark = True
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)

### Execute a typical PyTorch training process

In [None]:
from monai.inferers import sliding_window_inference
from monai.transforms import AsDiscrete
from monai.metrics import DiceMetric
from monai.data import decollate_batch
from tqdm import tqdm
# NOTE: PyTorch reads PYTORCH_CUDA_ALLOC_CONF when the CUDA caching allocator is first
# initialised. The model cell above already calls `.to(device)`, so in a top-to-bottom run
# this assignment is likely too late to take effect. Move it to the first cell (before
# anything touches CUDA) if expandable segments are wanted.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
if torch.cuda.is_initialized():
    print("Warning: PYTORCH_CUDA_ALLOC_CONF was set after CUDA was initialised and has no effect in this session.")

def validation(epoch_iterator_val):
    """Compute the mean validation Dice of the global ``model`` over ``epoch_iterator_val``.

    Each batch is predicted with sliding-window inference using a ``volume_size`` window
    and 4 windows per forward pass. Predictions and labels are post-processed with
    ``post_pred`` / ``post_label`` and accumulated in the global ``dice_metric``, which is
    always reset afterwards. The model's train/eval mode is restored before returning.

    Args:
        epoch_iterator_val: a tqdm-wrapped validation loader yielding dicts with
            "image" and "label" tensors.

    Returns:
        float: the aggregated Dice score.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for step, batch in enumerate(epoch_iterator_val, start=1):
                val_inputs = batch["image"].to(device)
                val_labels = batch["label"].to(device)
                val_outputs = sliding_window_inference(val_inputs, volume_size, 4, model)
                val_labels_convert = [post_label(label) for label in decollate_batch(val_labels)]
                val_output_convert = [post_pred(pred) for pred in decollate_batch(val_outputs)]
                dice_metric(y_pred=val_output_convert, y=val_labels_convert)
                # Show validation progress. The old message printed the module-level
                # global_step (not updated until train() returns) against a hard-coded 10.
                epoch_iterator_val.set_description(
                    f"Validate ({step} / {epoch_iterator_val.total} Steps)"
                )
            mean_dice_val = dice_metric.aggregate().item()
    finally:
        # Never leave partial results in the shared metric buffer.
        dice_metric.reset()
        # Return the model to the mode the caller had it in.
        model.train(was_training)
    return mean_dice_val

def train(global_step, train_loader, dice_val_best, global_step_best):
    """Train the global ``model`` for at most one pass over ``train_loader``.

    ``global_step`` is the number of optimizer steps completed before this call and is
    incremented after every step. Validation on ``val_loader`` runs whenever it reaches a
    multiple of ``eval_num`` or equals ``max_iterations``. Each validation does three things:
      - appends the running mean training loss of this pass to ``epoch_loss_values``;
      - appends the Dice score to ``metric_values``;
      - saves the weights to ``root_dir/best_metric_model.pth`` if the Dice beats
        ``dice_val_best``.
    The pass stops early once ``max_iterations`` steps have been completed.

    Uses the notebook globals model, optimizer, loss_function, device, val_loader,
    max_iterations, eval_num, root_dir, epoch_loss_values and metric_values.

    Returns:
        tuple: the updated ``(global_step, dice_val_best, global_step_best)``.
    """
    model.train()
    epoch_loss = 0.0
    epoch_iterator = tqdm(train_loader)
    for step, batch in enumerate(epoch_iterator, start=1):
        x, y = batch["image"].to(device), batch["label"].to(device)
        logit_map = model(x)
        loss = loss_function(logit_map, y)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        loss_value = loss.item()
        epoch_loss += loss_value
        # Count *completed* steps. The previous 0-based counter never equalled
        # max_iterations inside the loop, so the final validation was skipped.
        global_step += 1
        epoch_iterator.set_description(
            "Training (%d / %d Steps) (loss=%2.5f)" % (global_step, max_iterations, loss_value)
        )
        if global_step % eval_num == 0 or global_step == max_iterations:
            dice_val = validation(tqdm(val_loader))
            # validation() may leave the model in eval mode; the rest of this pass must train.
            model.train()
            # Running mean of this pass so far. The accumulator is left intact: dividing it
            # in place corrupted every loss added after the validation.
            epoch_loss_values.append(epoch_loss / step)
            metric_values.append(dice_val)
            if dice_val > dice_val_best:
                dice_val_best = dice_val
                global_step_best = global_step
                torch.save(model.state_dict(), os.path.join(root_dir, "best_metric_model.pth"))
                print(
                    "Model Was Saved ! Current Best Avg. Dice: {} Current Avg. Dice: {}".format(dice_val_best, dice_val)
                )
            else:
                print(
                    "Model Was Not Saved ! Current Best Avg. Dice: {} Current Avg. Dice: {}".format(
                        dice_val_best, dice_val
                    )
                )
        # Do not overshoot the iteration budget in the middle of an epoch.
        if global_step >= max_iterations:
            break
    return global_step, dice_val_best, global_step_best


max_iterations = 21200
eval_num = 1060
# The model outputs 3 channels; one-hot encode labels and argmax'ed predictions to match.
post_label = AsDiscrete(to_onehot=3)
post_pred = AsDiscrete(argmax=True, to_onehot=3)
# Exclude background, consistent with DiceCELoss(include_background=False). Including it
# lets the (typically near-perfect) background Dice dominate the mean and the model selection.
dice_metric = DiceMetric(include_background=False, reduction="mean", get_not_nans=False)
global_step = 0
dice_val_best = 0.0
global_step_best = 0
epoch_loss_values = []
metric_values = []
while global_step < max_iterations:
    previous_step = global_step
    global_step, dice_val_best, global_step_best = train(global_step, train_loader, dice_val_best, global_step_best)
    if global_step == previous_step:
        # An empty loader would otherwise spin this loop forever.
        raise RuntimeError("train_loader produced no batches; training cannot progress")

best_model_path = os.path.join(root_dir, "best_metric_model.pth")
if os.path.exists(best_model_path):
    model.load_state_dict(torch.load(best_model_path, map_location=device))
else:
    print(f"No checkpoint found at {best_model_path}; keeping the final weights.")

Training (113 / 21200 Steps) (loss=1.54195):  54%|█████▍    | 114/212 [05:33<03:11,  1.95s/it]

In [None]:
# Persist the training history so it can be inspected or plotted without re-training.
# `with` guarantees the file is closed even if serialisation fails part-way.
with open(os.path.join(root_dir, 'savefile.json'), 'w') as savefile:
    json.dump(
        {
            'dice_val_best': dice_val_best,
            'global_step_best': global_step_best,
            'epoch_loss_values': epoch_loss_values,
            'metric_values': metric_values,
        },
        savefile,
    )

In [None]:
# One-line summary of the best validation Dice and the step at which it was reached.
print("train completed, best_metric: {:.4f} at iteration: {}".format(dice_val_best, global_step_best))