In [None]:
import glob

import torch
from monai.data import CacheDataset, DataLoader, Dataset
from monai.inferers import sliding_window_inference
from monai.losses import DiceCELoss
from monai.metrics import DiceMetric
from monai.transforms import (
    Compose,
    EnsureChannelFirst,
    EnsureChannelFirstd,
    LoadImage,
    LoadImaged,
    RandFlip,
    RandFlipd,
    RandGaussianNoise,
    RandGaussianNoised,
    RandRotate90,
    RandRotate90d,
    RandZoom,
    RandZoomd,
    Resized,
    ScaleIntensity,
    ScaleIntensityd,
    ToTensor,
)
from sklearn.model_selection import train_test_split
from torch.optim import Adam

from utils.CIS_UNet import CIS_UNet

# Seed torch's global RNG so weight initialisation and DataLoader shuffling are
# repeatable across runs (and, presumably, MONAI's per-worker augmentation seeds,
# which are derived from torch's RNG — confirm for the installed MONAI version).
SEED = 42
torch.manual_seed(SEED)

# Train on CUDA when a GPU is visible, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
mha_files = sorted(glob.glob("data/masks/*.mha"))  
image_files = sorted(glob.glob("data/images/*.png"))  
data_dicts = [{"mha": mha, "image": img} for mha, img in zip(mha_files, image_files)]

In [None]:
# 80/20 train/validation split of the paired file dicts; random_state pins the split.
train_files, val_files = train_test_split(data_dicts, test_size=0.2, random_state=42)


def _base_transforms():
    """Return fresh instances of the deterministic steps shared by both pipelines.

    Every step addresses its inputs by key, so the dictionary ("...d") variants
    of the MONAI transforms are required; the array variants accept no ``keys``.
    """
    return [
        LoadImaged(keys=["mha", "image"]),
        EnsureChannelFirstd(keys=["mha", "image"]),
        # Intensity scaling is applied to the model input only.
        ScaleIntensityd(keys=["mha"]),
        # "image" is the training target (one-hot encoded by DiceCELoss), so it is
        # resized with nearest-neighbour to keep integer class labels intact.
        # NOTE(review): spatial_size is 2-D while the model is built with
        # spatial_dims=3, and "mha" is not resized to match — confirm shapes.
        Resized(keys=["image"], spatial_size=(256, 256), mode="nearest"),
    ]


train_transforms = Compose(
    _base_transforms()
    + [
        # Spatial augmentations touch input and target alike.
        RandFlipd(keys=["mha", "image"], prob=0.5, spatial_axis=0),
        RandRotate90d(keys=["mha", "image"], prob=0.5),
        # Per-key interpolation: "area" (the default) for the input, "nearest"
        # for the label map so zooming cannot create fractional class values.
        RandZoomd(
            keys=["mha", "image"],
            prob=0.2,
            min_zoom=0.9,
            max_zoom=1.1,
            mode=("area", "nearest"),
        ),
        # Intensity noise is input-only.
        RandGaussianNoised(keys=["mha"], prob=0.1),
    ]
)

val_transforms = Compose(_base_transforms())

In [None]:
# Cache 80% of each split's preprocessed samples in memory so they are not
# re-read from disk every epoch. Training split first, then validation.
train_dataset, val_dataset = (
    CacheDataset(data=split_files, transform=split_tfm, cache_rate=0.8)
    for split_files, split_tfm in (
        (train_files, train_transforms),
        (val_files, val_transforms),
    )
)

# Both loaders share batch size and worker count; only training is shuffled.
_LOADER_KWARGS = {"batch_size": 4, "num_workers": 4}
train_loader = DataLoader(train_dataset, shuffle=True, **_LOADER_KWARGS)
val_loader = DataLoader(val_dataset, shuffle=False, **_LOADER_KWARGS)

In [None]:
# Number of output channels of the segmentation head (classes, background included).
NUM_CLASSES = 24

# NOTE(review): spatial_dims=3 / in_channels=1 assume single-channel volumetric
# "mha" inputs — confirm against the data pipeline.
model = CIS_UNet(
    spatial_dims=3,
    in_channels=1,
    num_classes=NUM_CLASSES,
    encoder_channels=[64, 64, 128, 256],
    feature_size=48,
).to(device)

# Combined Dice + cross-entropy on softmaxed logits; the integer target is one-hot encoded.
loss_function = DiceCELoss(include_background=True, to_onehot_y=True, softmax=True)
optimizer = Adam(model.parameters(), lr=1e-4)
# Mean Dice over samples and classes, background class included.
dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)

In [None]:
# Training schedule: total epoch count, and how often (in epochs) to validate.
max_epochs, val_interval = 50, 2

In [None]:
def _to_onehot(index_map, num_classes):
    """Expand a ``(B, 1, *spatial)`` class-index map into a ``(B, C, *spatial)`` one-hot float tensor.

    DiceMetric compares binarized per-class masks, so both the argmax prediction
    and the integer target are one-hot encoded before being passed to it.
    ``one_hot`` raises if any index is >= ``num_classes``, which surfaces
    malformed label maps instead of scoring them silently.
    """
    indices = index_map.squeeze(1).long()
    return torch.nn.functional.one_hot(indices, num_classes).movedim(-1, 1).float()


for epoch in range(max_epochs):
    print(f"Epoch {epoch + 1}/{max_epochs}")

    # ---- training ----
    model.train()
    epoch_loss = 0.0
    for batch_data in train_loader:
        mha_batch = batch_data["mha"].to(device)
        image_batch = batch_data["image"].to(device)
        optimizer.zero_grad()
        outputs = model(mha_batch)
        loss = loss_function(outputs, image_batch)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    print(f"Average Training Loss: {epoch_loss / len(train_loader):.4f}")

    # ---- validation ----
    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            for val_data in val_loader:
                mha_batch = val_data["mha"].to(device)
                image_batch = val_data["image"].to(device)
                outputs = sliding_window_inference(mha_batch, (128, 128, 128), 4, model)
                n_classes = outputs.shape[1]
                # DiceMetric needs discrete masks, not raw logits: take the
                # per-voxel argmax and one-hot both prediction and target.
                # NOTE(review): assumes the target has one channel of class indices.
                pred_onehot = _to_onehot(torch.argmax(outputs, dim=1, keepdim=True), n_classes)
                target_onehot = _to_onehot(image_batch, n_classes)
                # Accumulate into the metric's internal buffer; the per-call return
                # value is not a scalar, so it is not reduced here.
                dice_metric(y_pred=pred_onehot, y=target_onehot)
        # Reduce over every validation batch, then clear the buffer so the next
        # validation round does not include this round's scores.
        val_dice = dice_metric.aggregate().item()
        dice_metric.reset()
        print(f"Validation Dice Score: {val_dice:.4f}")

torch.save(model.state_dict(), "cis_unet.pth")
print("Training Complete. Model saved as cis_unet.pth.")