In [None]:
import os
os.environ['KMP_DUPLICATE_LIB_OK']='True'
import csv
import math
from datetime import datetime

import monai
import nibabel as nib
import pandas as pd
import torch

from monai.losses import DiceCELoss
from monai.inferers import sliding_window_inference
from monai.transforms import (
    AsDiscrete,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandCropByPosNegLabeld,
    Activations,
    RandRotated,
    RandScaleIntensityd,
    RandShiftIntensityd,
)

from monai.networks.nets import UNETR
from monai.metrics import DiceMetric, DiceHelper, HausdorffDistanceMetric
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch

# Variabel yang perlu diubah

In [None]:
# ---- Model / patch sampling ----
# Model input size (spatial size of each training patch).
model_input_size = (128,) * 3

# Number of patches drawn from each training volume.
patch_num = 4

# ---- Loss weighting ----
# Dice-loss ratio for the DiceCELoss computation; the cross-entropy part gets the remainder.
dice_ratio = 0.6
ce_ratio = 1.0 - dice_ratio

# ---- Optimizer ----
# Optimizer type (e.g. "Adam", "Adagrad", "Adadelta", "SGD", "RMSProp").
optim_type = "SGD"

# ---- Data ----
# Train / validation CSV files.
train_csv = "CSV/data_train.csv"
val_csv = "CSV/data_val.csv"

# Original question: "Is this data without reorientation?"
# When True, the image/label paths read from the CSVs are switched from
# `reorient_dir` to `resized_dir`.
is_resized = True
resized_dir = "resized_data"
reorient_dir = "reoriented_data"

# Fraction of each dataset to cache.
cache_rate = 1.0

# Batch size.
batch_size = 1

# ---- Outputs / checkpointing ----
# Output directory, and an optional checkpoint to load ("" = load nothing).
model_dir = "results/Exp1/Reorient-50-50"
model_ckpt = ""

# Save a model checkpoint every `save_every` epochs.
save_every = 25

# ---- Epoch range ----
start_epoch, end_epoch = 0, 2_500

# Best validation result so far (change only when resuming training).
dice_val_best, dice_val_best_epoch = 0, 0

# Code

In [None]:
# MONAI's RandRotated takes its ranges in RADIANS. The original value of 30 meant
# +/-30 rad (effectively an arbitrary angle); +/-30 degrees is what was intended.
rotation_range_rad = math.radians(30)

# Training pipeline: load image/label, channel-first, crop both to the image
# foreground, reorient to RAS, apply random rotation and intensity augmentation,
# then draw `patch_num` label-balanced (pos=1, neg=1) patches of `model_input_size`.
# NOTE(review): RandCropByPosNegLabeld does not pad, so a foreground-cropped volume
# that is smaller than `model_input_size` on some axis yields a smaller patch than
# UNETR's `img_size`. Confirm the data is large enough, or add SpatialPadd before the crop.
train_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        # Bilinear interpolation for the image, nearest for the label (keeps it binary).
        RandRotated(
            keys=["image", "label"],
            range_x=rotation_range_rad,
            range_y=rotation_range_rad,
            range_z=rotation_range_rad,
            prob=0.4,
            mode=["bilinear", "nearest"],
        ),
        RandScaleIntensityd(keys="image", factors=0.1, prob=0.3),
        RandShiftIntensityd(keys="image", offsets=0.1, prob=0.3),
        RandCropByPosNegLabeld(
            keys=["image", "label"],
            label_key="label",
            spatial_size=model_input_size,
            pos=1,
            neg=1,
            num_samples=patch_num,
            image_key="image",
            image_threshold=0,
        ),
    ]
)

# Validation pipeline: only the deterministic preprocessing (no augmentation, no cropping to patches).
val_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
    ]
)

In [None]:
# Train on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# UNETR with a single output channel (one logit per voxel).
# `img_size` must match the training patch size and the sliding-window ROI.
model = UNETR(
    in_channels=1,
    out_channels=1,
    img_size=model_input_size,
    feature_size=16,
    hidden_size=768,
    mlp_dim=3072,
    num_heads=12,
    proj_type="perceptron",
    norm_name="instance",
    res_block=True,
    dropout_rate=0.0,
).to(device)

# Weighted Dice + cross-entropy on raw logits (`sigmoid=True` applies the sigmoid for the Dice term).
# Take the weights from the config cell so that editing `dice_ratio` actually changes the loss.
# `to_onehot_y=False`: with a single output channel MONAI skips the one-hot conversion anyway
# and only emits a warning on every call.
loss_function = DiceCELoss(to_onehot_y=False, sigmoid=True, lambda_dice=dice_ratio, lambda_ce=ce_ratio)

# Let cuDNN benchmark and cache the fastest convolution algorithms for the fixed input size.
torch.backends.cudnn.benchmark = True

# Metric post-processing: labels are binarised at 0.5; predictions go through a sigmoid
# and are then binarised at 0.5.
post_label = AsDiscrete(threshold=0.5)
post_pred = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])

# Validation metrics (mean Dice and mean 95th-percentile Hausdorff distance).
dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False, ignore_empty=True)
hausdorff_metric = HausdorffDistanceMetric(include_background=True, reduction="mean", get_not_nans=False, percentile=95)

In [None]:
# Build the optimizer named by `optim_type`. Each entry is a factory, so only the
# selected optimizer is constructed; any unrecognised name falls back to AdamW.
def _make_adamw(params):
    """Fallback optimizer used when `optim_type` is not one of the named keys."""
    return torch.optim.AdamW(params, lr=1e-4, weight_decay=1e-5)


_optimizer_factories = {
    "Adam": lambda params: torch.optim.Adam(params, lr=1e-4, weight_decay=1e-5),
    "Adagrad": lambda params: torch.optim.Adagrad(params, lr=1e-2, lr_decay=0, weight_decay=1e-5),
    "Adadelta": lambda params: torch.optim.Adadelta(params, lr=1.0, rho=0.9, eps=1e-6, weight_decay=1e-5),
    "SGD": lambda params: torch.optim.SGD(params, lr=1e-2, momentum=0.9, weight_decay=1e-5, nesterov=True),
    "RMSProp": lambda params: torch.optim.RMSprop(
        params, lr=1e-4, alpha=0.99, eps=1e-8, weight_decay=1e-5, momentum=0.9
    ),
}

optimizer = _optimizer_factories.get(optim_type, _make_adamw)(model.parameters())

In [None]:
df_train = pd.read_csv(train_csv)
df_val = pd.read_csv(val_csv)

# For the resized dataset, point the image ("TOF_pre") and label paths listed in the
# CSVs at `resized_dir` instead of `reorient_dir` (plain substring replacement).
if is_resized:
    for frame in (df_train, df_val):
        for column in ("TOF_pre", "labels"):
            frame[column] = frame[column].str.replace(reorient_dir, resized_dir, regex=False)

# Image and label path columns for each split.
TOF_pre_train_path, label_train_path = df_train['TOF_pre'], df_train['labels']
TOF_pre_val_path, label_val_path = df_val['TOF_pre'], df_val['labels']

# One {"image": ..., "label": ...} dict per row, as expected by the dictionary transforms.
train_dict = [dict(image=img, label=lbl) for img, lbl in zip(TOF_pre_train_path, label_train_path)]
val_dict = [dict(image=img, label=lbl) for img, lbl in zip(TOF_pre_val_path, label_val_path)]

# Cached datasets; only the training loader shuffles.
train_ds = CacheDataset(data=train_dict, transform=train_transforms, cache_rate=cache_rate, num_workers=4)
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4)

val_ds = CacheDataset(data=val_dict, transform=val_transforms, cache_rate=cache_rate, num_workers=4)
val_loader = DataLoader(val_ds, batch_size=batch_size, num_workers=4)

In [None]:
# Per-epoch metrics are appended to this CSV inside `model_dir`.
csv_file = os.path.join(model_dir, "loss_log.csv")

os.makedirs(model_dir, exist_ok=True)

# Optionally warm-start from saved weights. `map_location=device` lets a checkpoint written
# on a GPU be loaded on a CPU-only machine (and vice versa).
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
if model_ckpt:
    if os.path.exists(model_ckpt):
        model.load_state_dict(torch.load(model_ckpt, map_location=device))
        # Report only after the weights were actually loaded.
        print(f"Weights loaded from {model_ckpt}")
    else:
        print(f"Model checkpoint ({model_ckpt}) not found")

# Write the header only once; an existing log (e.g. when resuming) is appended to later.
if not os.path.exists(csv_file):
    with open(csv_file, mode='w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(["epoch", "train loss", "val mean dice", "val hausdorff"])

In [None]:
# `dice_val_best` / `dice_val_best_epoch` are intentionally NOT reset here: they come from the
# config cell ("change only when resuming"), so a resumed run keeps its previous best instead
# of overwriting best_metric_model.pth with a worse model.
for epoch in range(start_epoch, end_epoch):
    time_start = datetime.now()
    print("-" * 10)

    # ---- Training ----
    epoch_loss = 0
    step = 0
    model.train()
    for batch in train_loader:
        step += 1
        x, y = batch["image"].to(device), batch["label"].to(device)

        optimizer.zero_grad()
        logit_map = model(x)
        loss = loss_function(logit_map, y)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    epoch_loss /= step

    # ---- Validation: sliding-window inference over the full volumes ----
    model.eval()
    with torch.no_grad():
        for batch in val_loader:
            # Use `device` (not `.cuda()`) so validation also works on CPU-only machines.
            val_inputs, val_labels = batch["image"].to(device), batch["label"].to(device)
            # The ROI must match UNETR's `img_size`; 4 windows are evaluated per forward pass.
            val_outputs = sliding_window_inference(val_inputs, model_input_size, 4, model)
            val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]
            val_labels = [post_label(i) for i in decollate_batch(val_labels)]
            dice_metric(y_pred=val_outputs, y=val_labels)
            hausdorff_metric(y_pred=val_outputs, y=val_labels)
        mean_dice_val = dice_metric.aggregate().item()
        mean_hausdorff_score = hausdorff_metric.aggregate().item()
        dice_metric.reset()
        hausdorff_metric.reset()

    # ---- Keep the best model by validation Dice ----
    if mean_dice_val > dice_val_best:
        dice_val_best = mean_dice_val
        dice_val_best_epoch = epoch + 1
        torch.save(model.state_dict(), os.path.join(model_dir, "best_metric_model.pth"))
        print("saved new best metric model")
    print(
        f"best mean dice: {dice_val_best:.4f} "
        f"at epoch: {dice_val_best_epoch}"
    )

    # ---- Append this epoch's metrics to the CSV log ----
    with open(csv_file, mode='a', newline='') as file:
        writer = csv.writer(file)
        writer.writerow([epoch + 1, epoch_loss, mean_dice_val, mean_hausdorff_score])

    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f} - Validation Dice: {mean_dice_val:.4f} Hausdorff: {mean_hausdorff_score:.4f} - Time taken: {datetime.now() - time_start}")

    # Periodic checkpoint every `save_every` epochs.
    if (epoch + 1) % save_every == 0:
        torch.save(model.state_dict(), os.path.join(model_dir, f"model_e{epoch+1}.pth"))


Struktur akhir setelah training:

```
results/Exp1/
|-- Reorient-50-50/
|    |-- best_metric_model.pth
|    |-- loss_log.csv
|    |-- model_e25.pth
|    |-- model_e50.pth
|    |...
|-- Resize-50-50/
|    |-- best_metric_model.pth
|    |-- loss_log.csv
|    |-- model_e25.pth
|    |-- model_e50.pth
|    |...
...
```