# An Evolutionary Chameleon Swarm Algorithm based Network for 3D Medical Image Segmentation 

In [1]:
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"CUDA device count: {torch.cuda.device_count()}")
if torch.cuda.is_available():
    print(f"CUDA device name: {torch.cuda.get_device_name(0)}")
    # Try a simple GPU operation
    x = torch.rand(5, 3).cuda()
    print(f"Tensor device: {x.device}")

PyTorch version: 2.5.1+cu121
CUDA available: False
CUDA device count: 0


In [2]:
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [2]:
!pip3 install git+https://github.com/Project-MONAI/MONAI#egg=monai

Collecting monai
  Cloning https://github.com/Project-MONAI/MONAI to /tmp/pip-install-jv6oygzi/monai_38ae9478a2504f138dcaaeb5da690b7a
  Running command git clone --filter=blob:none --quiet https://github.com/Project-MONAI/MONAI /tmp/pip-install-jv6oygzi/monai_38ae9478a2504f138dcaaeb5da690b7a
  Resolved https://github.com/Project-MONAI/MONAI to commit b58e883c887e0f99d382807550654c44d94f47bd
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Building wheels for collected packages: monai
  Building wheel for monai (pyproject.toml) ... [?25l[?25hdone
  Created wheel for monai: filename=monai-1.4.1rc1+46.gb58e883c-py3-none-any.whl size=2658851 sha256=306a58a4e8b35f8fde6e0bdce7251822e742341e9c6cbb3318d5e8df4e3fdad2
  Stored in directory: /tmp/pip-ephem-wheel-cache-c0wu8bty/wheels/ae/df/85/e1529c65c7b6d24f94fb29018f2e6a19809d416ee64044d71f
Successfully built monai
Install

In [3]:
!pip3 install monai[nibabel,skimage]



In [4]:
!pip install thop

Collecting thop
  Downloading thop-0.1.1.post2209072238-py3-none-any.whl.metadata (2.7 kB)
Downloading thop-0.1.1.post2209072238-py3-none-any.whl (15 kB)
Installing collected packages: thop
Successfully installed thop-0.1.1.post2209072238


In [5]:
"""importing modules"""
from monai.transforms import (
    AsDiscrete,
    AsDiscreted,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandCropByPosNegLabeld,
    SaveImaged,
    ScaleIntensityRanged,
    Spacingd,
    RandAffined,
    Invertd,
)
import os
import csv
import glob
import time
import torch
import pandas
import shutil
import tempfile
import numpy as np
from thop import profile
from datetime import datetime
import matplotlib.pyplot as plt
from torchsummary import summary
from monai.losses import DiceLoss
from monai.config import print_config
from monai.networks.layers import Norm
from monai.apps import download_and_extract
from monai.handlers.utils import from_engine
from monai.utils import first, set_determinism
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric, MeanIoU, compute_roc_auc
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch

In [None]:
from monai.networks import blocks

def CONV(s, i, o, k, n, a, d):
    return blocks.Convolution(spatial_dims=s, in_channels=i, out_channels=o, kernel_size=k, norm=n, act=a, dropout=d)


#Block1
def UBB(s, i, o, k, n, a, d):
	return blocks.UnetBasicBlock(spatial_dims=s, in_channels=i, out_channels=o, kernel_size=k, stride=1, norm_name=n, act_name=a, dropout=d)

#Block2
def URB(s, i, o, k, n, a, d):
	return blocks.UnetResBlock(spatial_dims=s, in_channels=i, out_channels=o, kernel_size=k, stride=1, norm_name=n, act_name=a, dropout=d)

#Block3
def RB(s, i, o, k, n, a, d):
	return blocks.ResBlock(spatial_dims=s, in_channels=i, norm=n, kernel_size=k, act=a)

#Block4
def SASPP(s, i, o, k, n, a, d):
    return blocks.SimpleASPP(spatial_dims=s, in_channels=i, conv_out_channels=int(o/2), kernel_sizes=[k,k], dilations=[1,1], norm_type=n, acti_type=a, bias=False)

#Block5
def SERB(s, i, o, k, n, a, d):
	return blocks.SEResNetBottleneck(spatial_dims=s, inplanes=i, planes=4, groups=4, reduction=2, stride=1, downsample=None)
#outchannels=16

#Block6
def P3AB(s, i, o, k, n, a, d):
	return blocks.P3DActiConvNormBlock(in_channel=i, out_channel=o, kernel_size=k, padding=1, act_name=a, norm_name=n)

#Block7
def RRCB(s, i, o, k, n, a, d):
    return blocks.RegistrationResidualConvBlock(spatial_dims=s, in_channels=i, out_channels=i, num_layers=2, kernel_size=k)

#Block8
def UrPUB(s, i, o, k, n, a, d):
	return blocks.UnetrPrUpBlock(spatial_dims=s, in_channels=i, out_channels=o, num_layer=1, kernel_size=k, stride=1, upsample_kernel_size=1, norm_name=n, conv_block=True, res_block=True)



def  MaxAvg(k):
	return blocks.MaxAvgPool(spatial_dims=3, kernel_size=k, stride=2, padding=0, ceil_mode=True)



def  Ups(s, i, o, k, n, a, d):   #Upsample
	return blocks.FactorizedIncreaseBlock(in_channel=i, out_channel=o, spatial_dims=s, act_name=a, norm_name=n) 

def  Transp(s, i, o, k, n, a, d):  #conv3dtranspose
	return blocks.UpSample(spatial_dims=s, in_channels=i, out_channels=o, scale_factor=2, kernel_size=None, size=None, mode='deconv', pre_conv='default', interp_mode='LINEAR', align_corners=True, bias=True, apply_pad_pool=True)  


In [None]:

from torch import nn

class Block(nn.Module):
    def __init__(self, bn, inc, outc, kernel, norm, act, dr):
        super().__init__()
        self.bn = bn
        self.inc = inc
        self.outc = outc
        self.kernels = kernel
        self.norms = norm
        self.acts = act
        self.do = dr
        
        self.ubb = UBB(s=3, i=self.inc, o=self.outc, k=self.kernels, n=self.norms, a=self.acts, d=self.do)
        self.urb = URB(s=3, i=self.inc, o=self.outc, k=self.kernels, n=self.norms, a=self.acts, d=self.do)
        self.rb = RB(s=3, i=self.inc, o=self.outc, k=self.kernels, n=self.norms, a=self.acts, d=self.do)
        self.conv = CONV(s=3, i=self.inc, o=self.outc, k=self.kernels, n=self.norms, a=self.acts, d=self.do)
        self.saspp = SASPP(s=3, i=self.inc, o=self.outc, k=self.kernels, n=self.norms, a=self.acts, d=self.do) 
        self.serb = SERB(s=3, i=self.inc, o=self.outc, k=self.kernels, n=self.norms, a=self.acts, d=self.do)
        self.conv1 = CONV(s=3, i=16, o=self.outc, k=self.kernels, n=self.norms, a=self.acts, d=self.do)
        self.p3ab = P3AB(s=3, i=self.inc, o=self.outc, k=self.kernels, n=self.norms, a=self.acts, d=self.do)
        self.rrcb = RRCB(s=3, i=self.inc, o=self.outc, k=self.kernels, n=self.norms, a=self.acts, d=self.do)
        self.urpub = UrPUB(s=3, i=self.inc, o=self.outc, k=self.kernels, n=self.norms, a=self.acts, d=self.do)
        
        
    def forward(self, x):
        if self.bn == 0:
            return self.ubb(x)
        elif self.bn == 1:
            return self.urb(x)
        elif self.bn == 2:
            x = self.rb(x)            
            if self.inc == self.outc:
                return x
            return self.conv(x)
        elif self.bn == 3:
            return self.saspp(x)
        elif self.bn == 4:
            x = self.serb(x)
            return self.conv1(x)
        elif self.bn == 5:
            return self.p3ab(x)
        elif self.bn == 6:
            x = self.rrcb(x)
            if self.inc == self.outc:
                return x
            return self.conv(x)
        elif self.bn == 7:
            return self.urpub(x)
        else:
            raise "Invalid Block number"

In [None]:

from torch import nn

class CS3DEANet(nn.Module):
    def __init__(self, b, acts, norms, upsn, input_channels = 1, num_classes=2, initial_kernel=16):
        super().__init__()
        self.initial_kernel = initial_kernel
        self.input_channels = input_channels
        self.num_classes = num_classes
        self.b = b
        self.acts = acts
        self.norms = norms
        self.upsn = upsn
        
        self.r1 = Block(bn=self.b[0], inc=self.input_channels,   outc=self.initial_kernel,   kernel=3, norm=self.norms[0], act=self.acts[0], dr=None)
        self.r2 = Block(bn=self.b[1], inc=self.initial_kernel*2, outc=self.initial_kernel*2, kernel=3, norm=self.norms[1], act=self.acts[1], dr=None)
        self.r3 = Block(bn=self.b[2], inc=self.initial_kernel*4, outc=self.initial_kernel*4, kernel=3, norm=self.norms[2], act=self.acts[2], dr=None)
        self.r4 = Block(bn=self.b[3], inc=self.initial_kernel*8, outc=self.initial_kernel*8, kernel=3, norm=self.norms[3], act=self.acts[3], dr=None)
        self.r5 = Block(bn=self.b[4], inc=self.initial_kernel*4, outc=self.initial_kernel*4, kernel=3, norm=self.norms[4], act=self.acts[4], dr=None)
        self.r6 = Block(bn=self.b[5], inc=self.initial_kernel*2, outc=self.initial_kernel*2, kernel=3, norm=self.norms[5], act=self.acts[5], dr=None)
        self.r7 = Block(bn=self.b[6], inc=self.initial_kernel,   outc=self.num_classes,       kernel=3, norm=self.norms[6], act=self.acts[6], dr=None)
        
        self.m = MaxAvg(k=3)
        self.u1 = Ups(s=3, i=self.initial_kernel*8, o=self.initial_kernel*4, k=3, n=self.norms[4], a=self.acts[4], d=None)
        self.u2 = Ups(s=3, i=self.initial_kernel*4, o=self.initial_kernel*2, k=3, n=self.norms[5], a=self.acts[5], d=None)
        self.u3 = Ups(s=3, i=self.initial_kernel*2, o=self.initial_kernel,   k=3, n=self.norms[6], a=self.acts[6], d=None)
        
        self.t1 = Transp(s=3, i=self.initial_kernel*8, o=self.initial_kernel*4, k=3, n=self.norms[4], a=self.acts[4], d=None)
        self.t2 = Transp(s=3, i=self.initial_kernel*4, o=self.initial_kernel*2, k=3, n=self.norms[5], a=self.acts[5], d=None)
        self.t3 = Transp(s=3, i=self.initial_kernel*2, o=self.initial_kernel,   k=3, n=self.norms[6], a=self.acts[6], d=None)
        

    def forward(self, x):
        x1 = self.r1(x)
        x = self.m(x1)
        
        x2 = self.r2(x)
        x = self.m(x2)
        
        x3 = self.r3(x)
        x = self.m(x3)
        
        x4 = self.r4(x)
        if self.upsn[0] == 0:
            xu1 = self.t1(x4)
        elif self.upsn[0] == 1:
            xu1 = self.u1(x4)
        else:
            raise "Invalid upsampling "
        
        xu1x3 = xu1+x3
        x5 = self.r5(xu1x3)       
        if self.upsn[1] == 0:
            xu2 = self.t2(x5)
        elif self.upsn[1] == 1:
            xu2 = self.u2(x5)
        else:
            raise "Invalid upsampling "
            
        xu2x2 = xu2+x2
        x6 = self.r6(xu2x2)         
        if self.upsn[2] == 0:
            xu3 = self.t3(x6)
        elif self.upsn[2] == 1:
            xu3 = self.u3(x6)
        else:
            raise "Invalid upsampling "
        xu3x1 = xu3+x1
        x7 = self.r7(xu3x1)  
        
        return x7
    
    

In [None]:
from monai.losses import DiceLoss, DiceFocalLoss, FocalLoss, TverskyLoss, DiceCELoss
from torch.optim import Adam, Adadelta, Adamax, SGD

acts = ('memswish', 'relu', 'prelu', 'leakyrelu')
norms = ('BATCH', 'INSTANCE', ('GROUP', {'num_groups': 1}), '')



def loss_functions(l):
    if l==0:
        return DiceLoss(to_onehot_y=True, softmax=True)
    elif l == 1:
        return DiceFocalLoss(to_onehot_y=True, softmax=True)
    elif l == 2:
        return FocalLoss(to_onehot_y=True)
    elif l == 3:
        return TverskyLoss(to_onehot_y=True, softmax=True)
    else:
        print("Invalid loss function")


def optimizers(o, params,lr=1e-4):
    if o == 0:
        return Adam(params, lr=lr)
    elif o == 1:
        return Adadelta(params, lr=lr)
    elif o == 2:
        return Adamax(params, lr=lr)
    else:
        return SGD(params, lr=lr, momentum=0.9)

    
def todec(b):
    return int(''.join(map(lambda x: str(int(x)), b)), 2)


def encoding(ch):
    b, a, n, j = [], [], [], 0
    
    for i in range(7):  #0-21
        b.append(todec(ch[j:j+3]))
        j += 3        
    for i in range(7):  #21-35
        a.append(acts[todec(ch[j:j+2])])
        j += 2        
    for i in range(7):  #35-49
        n.append(norms[todec(ch[j:j+2])])
        j += 2        
    upsn = ch[j:j+3]  #49-52
    j += 3
    ol = list(map(todec, (ch[j:j+2], ch[j+2:j+4])))  #52-56  optimizer-52,53  loss function-54,55
    return b, a, n, upsn, ol[0], ol[1]


# Data loading and Preprocessing

In [None]:
""" Select dataset name here """
task_name =  "Spleen"  

In [None]:
!ls /kaggle/input/spleen/data/


In [None]:
"""Setup dataset path"""

root_dir = "/kaggle/working/"
print(root_dir)
dataset_name = task_name
img_size = 96
img_shape= (img_size, img_size, img_size)
# data_dir = "/kaggle/input/spleen/data/"+ task_name
data_dir = os.path.join("/kaggle/input/spleen/data", task_name)
# data_dir = os.path.join("/kaggle/input/spleen/data", task_name)

# train_images = sorted(glob.glob(os.path.join(data_dir, "train", "images", "*.nii")))
# train_labels = sorted(glob.glob(os.path.join(data_dir, "train", "labels", "*.nii")))


train_images = sorted(glob.glob(os.path.join(data_dir, "train", "images", "*.nii")))
train_labels = sorted(glob.glob(os.path.join(data_dir, "train", "labels", "*.nii")))

val_images = sorted(glob.glob(os.path.join(data_dir, "val", "images", "*.nii")))
val_label = sorted(glob.glob(os.path.join(data_dir, "val", "labels", "*.nii")))

test_images = sorted(glob.glob(os.path.join(data_dir, "test", "images", "*.nii")))
test_labels = sorted(glob.glob(os.path.join(data_dir, "test", "labels", "*.nii")))


train_files = [{"image": image_name, "label": label_name}
    for image_name, label_name in zip(train_images, train_labels)]

val_files = [{"image": image_name, "label": label_name}
    for image_name, label_name in zip(val_images, val_label)]

test_files = [{"image": image_name, "label": label_name}
    for image_name, label_name in zip(test_images, test_labels)]

print(len(train_files), len(val_files), len(test_files))



"""Set deterministic training for reproducibility"""

set_determinism(seed=0)
if dataset_name == "Spleen":
    max_intens = 164
    min_intens = -57
    train_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys=["image", "label"]),
            ScaleIntensityRanged(
                keys=["image"], a_min=min_intens, a_max=max_intens,
                b_min=0.0, b_max=1.0, clip=True,
            ),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            Orientationd(keys=["image", "label"], axcodes="RAS"),
            Spacingd(keys=["image", "label"], pixdim=(
                1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
            RandCropByPosNegLabeld(
                keys=["image", "label"],
                label_key="label",
                spatial_size=img_shape,
                pos=1,
                neg=1,
                num_samples=4,
                image_key="image",
                image_threshold=0,
            ),
        ]
    )
elif dataset_name == "Heart":    
    max_intens = 2033
    min_intens = 0
    train_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys=["image", "label"]),
            ScaleIntensityRanged(
                keys=["image"], a_min=min_intens, a_max=max_intens,
                b_min=0.0, b_max=1.0, clip=True,
            ),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            Orientationd(keys=["image", "label"], axcodes="RAS"),
            Spacingd(keys=["image", "label"], pixdim=(
                1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
            RandAffined(
                keys=['image', 'label'],
                mode=('bilinear', 'nearest'),
                prob=1.0, spatial_size=img_shape,
                rotate_range=(0, 0, np.pi/15),
                scale_range=(0.1, 0.1, 0.1)),
        ]
    )
    
    
val_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        ScaleIntensityRanged(
            keys=["image"], a_min=min_intens, a_max=max_intens,
            b_min=0.0, b_max=1.0, clip=True,
        ),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(keys=["image", "label"], pixdim=(
            1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
    ]
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


now = 2
cr = 1.0
train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=cr, num_workers=now)
train_loader = DataLoader(train_ds, batch_size=1, shuffle=True, num_workers=now)

val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=cr, num_workers=now)
val_loader = DataLoader(val_ds, batch_size=1, num_workers=now)

test_ds = CacheDataset(data=test_files, transform=val_transforms, cache_rate=cr, num_workers=now)
test_loader = DataLoader(test_ds, batch_size=1, num_workers=now)

test_org_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        Orientationd(keys=["image"], axcodes="RAS"),
        Spacingd(keys=["image"], pixdim=(
            1.5, 1.5, 2.0), mode="bilinear"),
        ScaleIntensityRanged(
            keys=["image"], a_min=min_intens, a_max=max_intens,
            b_min=0.0, b_max=1.0, clip=True,
        ),
        CropForegroundd(keys=["image"], source_key="image"),
    ]
)

test_org_ds = Dataset(
    data=test_files, transform=test_org_transforms)
test_org_loader = DataLoader(test_org_ds, batch_size=1, num_workers=4)

post_transforms = Compose([
    Invertd(
        keys="pred",
        transform=test_org_transforms,
        orig_keys="image",
        meta_keys="pred_meta_dict",
        orig_meta_keys="image_meta_dict",
        meta_key_postfix="meta_dict",
        nearest_interp=False,
        to_tensor=True,
        device="cpu",
    ),
    AsDiscreted(keys="pred", argmax=True, to_onehot=2),
    AsDiscreted(keys="label", to_onehot=2),
])

file_name = '/kaggle/working/'+dataset_name+'.csv'
print(file_name)

if not os.path.exists(file_name):
    print("writing new")
    with open(file_name,'a') as fp:
        wr = csv.writer(fp, dialect='excel')
        wr.writerow(['Generation', 'Index', 'epoch', 'Val_Dice', 'Test_Dice',\
                     'Test_IOU', 'Ch', 'Params', 'Time', 'Start', 'End', 'Task'])
        
def create_dir(path):
    if not os.path.exists(path):
        print(f"{path} created")
        os.makedirs(path)
        

def calc_ff(model):
    input_tensor = torch.randn(1,1,img_size,img_size,img_size)

    # Measure FLOPs and Parameters
    input_flops = input_tensor.to(device)
    macs, params = profile(model, inputs=(input_flops,))
    print(f"FLOPs: {macs / 1e9} Gmacs")
    print(f"Parameters: {params / 1e6} M")
    return macs / 1e9, params / 1e6

In [None]:
for batch_data in test_loader:
    inputs, labels = (
        batch_data["image"].to(device),
        batch_data["label"].to(device),
    )
    break
print(inputs.shape, labels.shape)

ip = np.squeeze(inputs)
op = np.squeeze(labels)

# check_ds = Dataset(data=val_files, transform=val_transforms)
# check_loader = DataLoader(check_ds, batch_size=1)
# check_data = first(check_loader)
# image, label = (check_data["image"][0][0], check_data["label"][0][0])
print(f"image shape: {ip.shape}, label shape: {op.shape}")
# plot the slice [:, :, 80]
plt.figure("check", (12, 6))
plt.subplot(1, 2, 1)
plt.title("image")
plt.imshow(ip[:, :, 80], cmap="gray")
plt.subplot(1, 2, 2)
plt.title("label")
plt.imshow(op[:, :, 80])
plt.show()

# Fitness evaluation

In [None]:
""" Execute a typical PyTorch training process"""
max_epochs = 150
sw_batch_size = 1
create_dir(root_dir)

def cal_params(model):
    return sum(p.numel() for p in model.parameters())

def runModel(g, ind, ch, file_name='results_'+dataset_name+'.csv'):    
    print("\n\n",g,ind)
    
    ptcl = ' '.join(map(str, ch))
    f = pandas.read_csv(file_name)
    if ptcl in f['Ch'].values:
        print("already found")
        f1s = float(np.max(f[f['Ch'] == ptcl]['Test_Dice']))
        return f1s

    b, a, n, upsn, o, l = encoding(ch)
    print(b, a, n, upsn, o, l)
    model = CS3DEANet(b, a, n, upsn, input_channels=1, num_classes=2, initial_kernel=16)
    Total_params = cal_params(model)
    
    # model
    model = model.to(device)
    optimizer = optimizers(o, model.parameters(), lr = 0.00035)
    loss_function = loss_functions(l)
    
    # Setup checkpoint paths
    checkpoint_dir = os.path.join(root_dir, "checkpoints")
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_{dataset_name}_{g}_{ind}.pth")
    best_model_path = os.path.join(root_dir, f"best_metric_model_{dataset_name}_{g}_{ind}.pth")
    
    # Initialize training state
    start_epoch = 0
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = []
    metric_values = []
    
    # Try to load checkpoint if exists
    if os.path.exists(checkpoint_path):
        print(f"Loading checkpoint from {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch']
        best_metric = checkpoint['best_metric']
        best_metric_epoch = checkpoint['best_metric_epoch']
        epoch_loss_values = checkpoint['epoch_loss_values']
        metric_values = checkpoint['metric_values']
        print(f"Resuming from epoch {start_epoch}")
    
    dice_metric = DiceMetric(include_background=False, reduction="mean")
    mean_iou = MeanIoU(include_background=False, reduction="mean")

    val_interval = 2
    post_pred = Compose([AsDiscrete(argmax=True, to_onehot=2)])
    post_label = Compose([AsDiscrete(to_onehot=2)])
    start = time.time()
    
    for epoch in range(start_epoch, max_epochs):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{max_epochs}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = (
                batch_data["image"].to(device),
                batch_data["label"].to(device),
            )       
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            print(
                f"{step}/{len(train_ds) // train_loader.batch_size}, "
                f"train_loss: {loss.item():.4f}")
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                for val_data in val_loader:
                    val_inputs, val_labels = (
                        val_data["image"].to(device),
                        val_data["label"].to(device),
                    )
                    roi_size = (160, 160, 160)

                    val_outputs = sliding_window_inference(
                        val_inputs, roi_size, sw_batch_size, model)
                    val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]
                    val_labels = [post_label(i) for i in decollate_batch(val_labels)]
                    dice_metric(y_pred=val_outputs, y=val_labels)

                metric = dice_metric.aggregate().item()
                dice_metric.reset()

                metric_values.append(metric)
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(model.state_dict(), best_model_path)
                    print(f"saved new best metric model at {best_model_path}")
                print(
                    f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                    f"\nbest mean dice: {best_metric:.4f} "
                    f"at epoch: {best_metric_epoch}"
                )
        
        # Save checkpoint every 10 epochs
        if (epoch + 1) % 10 == 0:
            checkpoint = {
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_metric': best_metric,
                'best_metric_epoch': best_metric_epoch,
                'epoch_loss_values': epoch_loss_values,
                'metric_values': metric_values,
            }
            torch.save(checkpoint, checkpoint_path)
            print(f"Saved checkpoint at epoch {epoch + 1}")
                
        if epoch > 60 and best_metric < 0.15 and metric < 0.15:
            print("Stopping training after 20th epoch as metric not increasing")
            break

    print(
        f"train completed, best_metric: {best_metric:.4f} "
        f"at epoch: {best_metric_epoch}")

    end = time.time()
    
    """Evaluation on test dataset"""
    dice_metric = DiceMetric(include_background=False, reduction="mean")
    mean_iou = MeanIoU(include_background=False, reduction="mean")

    model.load_state_dict(torch.load(best_model_path))
    model.eval()

    with torch.no_grad():
        for test_data in test_org_loader:
            test_inputs = test_data["image"].to(device)
            roi_size = (160, 160, 160)
            test_data["pred"] = sliding_window_inference(
                test_inputs, roi_size, sw_batch_size, model)
            test_data = [post_transforms(i) for i in decollate_batch(test_data)]
            test_outputs, test_labels = from_engine(["pred", "label"])(test_data)
            dice_metric(y_pred=test_outputs, y=test_labels)
            mean_iou(y_pred=test_outputs, y=test_labels)

        dice_metric_org = dice_metric.aggregate().item()
        iou_metric_org = mean_iou.aggregate().item()
        dice_metric.reset()
        mean_iou.reset()

    print("Dice Metric on Test dataset: ", dice_metric_org)
    print("IOU Metric on Test dataset: ", iou_metric_org)

    l=[]
    l.extend([g, ind, best_metric_epoch, best_metric, dice_metric_org,\
              iou_metric_org, ptcl, Total_params, int((end-start)/60), \
              datetime.fromtimestamp(start).strftime('%Y-%m-%d %H:%M:%S'),\
              datetime.fromtimestamp(end).strftime('%Y-%m-%d %H:%M:%S'), dataset_name])
    with open(file_name,'a') as fp:
        wr = csv.writer(fp, dialect='excel')
        wr.writerow(l)

    return dice_metric_org

# CS3DEA-Net

In [None]:

# Main parameters of CSA
searchAgents = 20
dim = 56
ub = [1] * dim
lb = [0] * dim
rho = 1.0
p1 = 2.0
p2 = 2.0
c1 = 2.0
c2 = 1.8
gamma = 2.0
alpha = 4.0
beta = 3.0
iteMax = 30

from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
from enum import Enum
from math import atan, sqrt, tanh, erf, e, pi
from copy import copy

def initialization(Particles_no, dim):
    C = np.random.randint(0,2,(Particles_no, dim))               
    return C


# transfer functions V-Shape and S-Shape
class TransferFuncion(Enum):
    """enumeration value of transfer function
    """
    V1 = 1
    V2 = 2
    V3 = 3
    V4 = 4
    S1 = 5
    S2 = 6
    S3 = 7
    S4 = 8
    
    
def transfer_function(transfer_function_type: TransferFuncion, a: float) -> float:
    """8 transfer functions (4 V-Shapes and 4 S-Shapes) that map values to the [0, 1] interval

     Args:
         transfer_function_type (TransferFuncion): transfer function type
         a (float): any real value

     Returns:
         float: the value mapped to the interval [0, 1] after the transfer function
     """
    if transfer_function_type == TransferFuncion.V1:
        return abs((2/pi)*atan((pi/2)*a))
    
    elif transfer_function_type == TransferFuncion.V2:
        return abs(tanh(a))
        
    elif transfer_function_type == TransferFuncion.V3:
        return abs(a/(sqrt(1+a**2)))
    
    elif transfer_function_type == TransferFuncion.V4:
        return abs(erf((sqrt(pi)/2)*a))
    
    elif transfer_function_type == TransferFuncion.S1:
        return 1/(1+e**(-a))
    
    elif transfer_function_type == TransferFuncion.S2:
        return 1/(1+e**(-2*a))
    
    elif transfer_function_type == TransferFuncion.S3:
        return 1/(1+(e**(-a/2)))
    
    elif transfer_function_type == TransferFuncion.S4:
        return 1/(1+(e**(-a/3)))
    
    else:
        print('[ERROR] Unknow transfer function type, Please use V1~V4 or S1~S4')
        exit
        
        
def arr2bin(arr, tf):
    for i in range(arr.shape[0]):
        if transfer_function(tf, arr[i]) >= np.random.rand():
            arr[i] = 0
        else:
            arr[i] = 1
    return arr


def  fobj(g, ind, particle):
    particle = list(map(int, particle)) 
    # try:
    fitness_score = runModel(g, ind, particle, file_name)
    # except Exception as e:
    #     print(e)
    #     fitness_score = 0
    print(fitness_score)
    return fitness_score


tf=TransferFuncion.V2
chameleonPositions = initialization(searchAgents,dim)
fit = np.zeros((searchAgents,1))

for i in range(searchAgents):
    fit[i] = fobj(0, i, chameleonPositions[i])

    

print(fit, chameleonPositions)

## Convergence curve
cg_curve = np.zeros(iteMax)

## Initalize the parameters of CSA
fitness = fit.copy()
# print(fit)
fmin0,index = np.amax(fit), np.argmax(fit)
chameleonBestPosition = chameleonPositions.copy()
cg_curve[0] = fmin0

# chameleonPositions, fit
gPosition = chameleonPositions[index]
v = 0.1 * chameleonBestPosition
v0 = 0.0 * v

print(fmin0,index)



# CSA loop starts from here
for t in range(1,iteMax):
    a = 0.5 #2590 * (1 - np.exp(- np.log(t)))
    omega = (1 - (t / iteMax)) ** (rho * np.sqrt(t / iteMax))
    p1 = 2 * np.exp(- 2 * (t / iteMax) ** 2)
    p2 = 2 / (1 + np.exp((- t + iteMax / 2) / 100))
    mu = gamma * np.exp(- (alpha * t / iteMax) ** beta)
    ch = np.ceil(searchAgents * np.random.rand(searchAgents)-1)
    
    ## Update the position of CSA (Exploration)
    for i in range(searchAgents):
        if np.random.rand() >= 0.1:
            chameleonPositions[i] = chameleonPositions[i] + p1 * (chameleonBestPosition[int(ch[i])] - chameleonPositions[i]) * np.random.rand() + p2 * (gPosition - chameleonPositions[i]) * np.random.rand()
        else:
            for j in range(dim):
                chameleonPositions[i][j] = gPosition[j] + mu * ((ub[j] - lb[j]) * np.random.rand() + lb[j]) * np.sign(np.random.rand() - 0.5)

    for i in range(searchAgents):
        v[i] = omega * v[i] + p1 * (chameleonBestPosition[i] - chameleonPositions[i]) * np.random.rand() + p2 * (gPosition - chameleonPositions[i]) * np.random.rand()
        chameleonPositions[i] = chameleonPositions[i] + (v[i] ** 2 - v0[i] ** 2) / (2 * a)
    v0 = v
    
    for i in range(searchAgents):
        chameleonPositions[i] = arr2bin(arr=chameleonPositions[i], tf=tf)    
        fit[i] = fobj(t, i, chameleonPositions[i])
        print(fit[i], fitness[i])
        if fit[i] > fitness[i]:
            print("updated")
            chameleonBestPosition[i] = chameleonPositions[i]
            fitness[i] = fit[i]
            
    ## Evaluate the new positions
    fmin,index = np.amax(fitness), np.argmax(fitness)
    # Updating gPosition and best fitness
    if fmin > fmin0:
        gPosition = chameleonBestPosition[index]
        fmin0 = fmin
    cg_curve[t] = fmin0
    
    print("\n\nAt iteration  ", t)
    print(chameleonBestPosition, fitness)

# Top discovered models

In [None]:
if dataset_name == "Spleen":
    best_ch = [0 ,0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0]
    save_path = "/kaggle/working/best_metric_model_Spleen_1_15.pth"
elif dataset_name == "Heart":
    best_ch = [0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1]
    save_path = "CS3DEA_Net_best_Heart.pth" 
print(dataset_name)


b, a, n, upsn, o, l = encoding(best_ch)
cs3dea = CS3DEANet(b, a, n, upsn, input_channels=1, num_classes=2, initial_kernel=16)
cs3dea = cs3dea.to(device)
cs3dea.load_state_dict(torch.load(save_path))
# cs3dea.load_state_dict(torch.load(save_path, map_location=torch.device('cpu')))
# cs3dea.load_state_dict(torch.load(save_path, map_location=torch.device('cpu')))


cs3dea.eval()

calc_ff(cs3dea)
summary(cs3dea, (1,img_size,img_size,img_size))
# optimizer = optimizers(o, cs3dea.parameters(), lr = 0.00035)
# loss_function = loss_functions(l)
# runModel("CS3DEA_Net", cs3dea, optimizer, loss_function)

# Evaluation on test dataset

In [None]:

dice_metric = DiceMetric(include_background=False, reduction="mean")
mean_iou = MeanIoU(include_background=False, reduction="mean")

# cs3dea.load_state_dict(torch.load(save_path))
# cs3dea.eval()

with torch.no_grad():
    for test_data in test_org_loader:
        test_inputs = test_data["image"].to(device)
        roi_size = (160, 160, 160)
        test_data["pred"] = sliding_window_inference(
            test_inputs, roi_size, sw_batch_size, cs3dea)
        test_data = [post_transforms(i) for i in decollate_batch(test_data)]
        test_outputs, test_labels = from_engine(["pred", "label"])(test_data)
        dice_metric(y_pred=test_outputs, y=test_labels)
        mean_iou(y_pred=test_outputs, y=test_labels)

    # aggregate the final mean dice result
    dice_metric_org = dice_metric.aggregate().item()
    iou_metric_org = mean_iou.aggregate().item()
    # reset the status for next validation round
    dice_metric.reset()
    mean_iou.reset()

print("Dice Metric on Test dataset: ", dice_metric_org)
print("IOU Metric on Test dataset: ", iou_metric_org)


# Test image results

In [None]:

with torch.no_grad():
    for i, test_data in enumerate(test_loader):
        roi_size = (160, 160, 160)
        sw_batch_size = 4
        test_outputs = sliding_window_inference(
            test_data["image"].to(device), roi_size, sw_batch_size, cs3dea
        )
        # plot the slice [:, :, 80]
        plt.figure("check", (18, 6))
        plt.subplot(1, 3, 1)
        plt.title(f"image {i}")
        plt.imshow(test_data["image"][0, 0, :, :, 80], cmap="gray")
        plt.subplot(1, 3, 2)
        plt.title(f"label {i}")
        plt.imshow(test_data["label"][0, 0, :, :, 80])
        plt.subplot(1, 3, 3)
        plt.title(f"output {i}")
        plt.imshow(torch.argmax(
            test_outputs, dim=1).detach().cpu()[0, :, :, 80])
        plt.show()

In [None]:
## saving results
model_name="CS3DEA_Net"
res_save_path = '/kaggle/working/'+dataset_name+'_top_model'
create_dir(res_save_path)

with torch.no_grad():
    for i, test_data in enumerate(test_loader):
        roi_size = (160, 160, 160)
        sw_batch_size = 4
        test_outputs = sliding_window_inference(
            test_data["image"].to(device), roi_size, sw_batch_size, cs3dea
        )
        plt.imsave(res_save_path+"/"+str(i)+"_image.png", test_data["image"][0, 0, :, :, 80], cmap="gray")
        plt.imsave(res_save_path+"/"+str(i)+"_label.png", test_data["label"][0, 0, :, :, 80], cmap="gray")
        plt.imsave(res_save_path+"/"+str(i)+"_output.png", 
                   torch.argmax(test_outputs, dim=1).detach().cpu()[0, :, :, 80], cmap="gray")
        print(f"{i} saved in res_save_path")

In [None]:

import os
import glob
import random
import numpy as np
import torch
from torch.optim import Adam
from torch.amp import GradScaler, autocast
from torch.optim.lr_scheduler import ReduceLROnPlateau

from monai.data import CacheDataset, Dataset, DataLoader
from monai.inferers import sliding_window_inference
from monai.transforms import (
    LoadImaged,
    EnsureChannelFirstd,
    Spacingd,
    ScaleIntensityd,
    CropForegroundd,
    CenterSpatialCropd,
    RandFlipd,
    RandRotate90d,
    RandZoomd,
    RandAdjustContrastd,
    Rand3DElasticd,
    RandAffined,
    ToTensord,
)
from monai.networks.nets import UNet
from monai.losses import DiceCELoss
from monai.metrics import DiceMetric


In [None]:
data_root = "/kaggle/input/spleen/data/Spleen"
checkpoint_dir = "./checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)

seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

max_epochs = 100
train_batch_size = 2
val_batch_size = 1
learning_rate = 1e-4
patience = 10
roi_size = (96, 96, 96)

# ----------------------------
# Build dataset paths
# ----------------------------
def make_dataset(phase):
    img_dir = os.path.join(data_root, phase, "images")
    lbl_dir = os.path.join(data_root, phase, "label")
    img_paths = sorted(glob.glob(os.path.join(img_dir, "*.nii*")))
    lbl_paths = sorted(glob.glob(os.path.join(lbl_dir, "*.nii*")))
    assert img_paths and lbl_paths, f"No files for {phase}"
    return [{"image": i, "label": l} for i, l in zip(img_paths, lbl_paths)]



In [None]:
train_files = make_dataset("train")
val_files = make_dataset("val")
test_files = make_dataset("test")

# ----------------------------
# Transforms
# ----------------------------
train_transforms = [
    LoadImaged(keys=["image", "label"]),
    EnsureChannelFirstd(keys=["image", "label"]),
    Spacingd(keys=["image", "label"], pixdim=(1.5,1.5,2.0), mode=("bilinear","nearest")),
    ScaleIntensityd(keys=["image"]),
    CropForegroundd(keys=["image", "label"], source_key="image"),
    CenterSpatialCropd(keys=["image", "label"], roi_size=roi_size),
    RandFlipd(keys=["image", "label"], spatial_axis=0, prob=0.5),
    RandRotate90d(keys=["image", "label"], prob=0.5, max_k=3),
    RandZoomd(keys=["image", "label"], prob=0.3, min_zoom=0.8, max_zoom=1.2),
    RandAdjustContrastd(keys=["image"], prob=0.3, gamma=(0.7,1.5)),
    ToTensord(keys=["image", "label"]),
]

val_transforms = [
    LoadImaged(keys=["image", "label"]),
    EnsureChannelFirstd(keys=["image", "label"]),
    Spacingd(keys=["image", "label"], pixdim=(1.5,1.5,2.0), mode=("bilinear","nearest")),
    ScaleIntensityd(keys=["image"]),
    CropForegroundd(keys=["image", "label"], source_key="image"),
    CenterSpatialCropd(keys=["image", "label"], roi_size=roi_size),
    ToTensord(keys=["image", "label"]),
]

# ----------------------------
# Loaders
# ----------------------------
train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=1.0)
val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0)
train_loader = DataLoader(train_ds, batch_size=train_batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_ds, batch_size=val_batch_size, shuffle=False, num_workers=2)

# ----------------------------
# Model setup
# ----------------------------
model = UNet(spatial_dims=3, in_channels=1, out_channels=2,
             channels=(32,64,128,256,512), strides=(2,2,2,2), num_res_units=2, norm='batch').to(device)
loss_function = DiceLoss(to_onehot_y=True, softmax=True)
optimizer = Adam(model.parameters(), lr=learning_rate)
scheduler = ReduceLROnPlateau(optimizer, 'max', patience=5, factor=0.5)
scaler = GradScaler()
dice_metric = DiceMetric(include_background=True, reduction="mean")

# ----------------------------
# Training & Validation
# ----------------------------
best_metric = 0.0
no_improve = 0

for epoch in range(1, max_epochs+1):
    model.train()
    epoch_loss = 0.0
    for batch in train_loader:
        imgs, segs = batch['image'].to(device), batch['label'].to(device)
        optimizer.zero_grad()
        with autocast(device_type='cuda'):
            outputs = model(imgs)
            loss = loss_function(outputs, segs)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        epoch_loss += loss.item()
    epoch_loss /= len(train_loader)
    print(f"Epoch {epoch}, Avg Loss: {epoch_loss:.4f}")

    # Validation phase
    model.eval()
    dice_metric.reset()
    with torch.no_grad():
        for val_batch in val_loader:
            val_imgs, val_segs = val_batch['image'].to(device), val_batch['label'].to(device)
            val_preds = sliding_window_inference(val_imgs, roi_size, 4, model)
            dice_metric(y_pred=val_preds, y=val_segs)
    mean_dice = dice_metric.aggregate().item()
    print(f"Validation Dice: {mean_dice:.4f}")
    scheduler.step(mean_dice)

    if mean_dice > best_metric:
        best_metric = mean_dice
        torch.save(model.state_dict(), os.path.join(checkpoint_dir, f"best_model_epoch{epoch}_dice{best_metric:.4f}.pth"))
        no_improve = 0
        print("Saved best model")
    else:
        no_improve += 1
        if no_improve >= patience:
            print("Early stopping")
            break

print(f"Training finished. Best Dice = {best_metric:.4f}")

# ----------------------------
# Test Phase
# ----------------------------
model.load_state_dict(torch.load(os.path.join(checkpoint_dir, f"best_model_epoch{epoch-no_improve}_dice{best_metric:.4f}.pth")))
model.eval()
test_ds = Dataset(data=test_files, transform=val_transforms)
test_loader = DataLoader(test_ds, batch_size=1, shuffle=False)

dice_metric.reset()
with torch.no_grad():
    for test_batch in test_loader:
        test_imgs, test_segs = test_batch['image'].to(device), test_batch['label'].to(device)
        test_preds = sliding_window_inference(test_imgs, roi_size, 4, model)
        dice_metric(y_pred=test_preds, y=test_segs)

final_test_dice = dice_metric.aggregate().item()
print(f"Test Mean Dice: {final_test_dice:.4f}")


In [None]:

import os
import glob
import random
import numpy as np
import torch
from torch.optim import Adam
from torch.amp import GradScaler, autocast
from torch.optim.lr_scheduler import ReduceLROnPlateau

from monai.data import CacheDataset, Dataset, DataLoader
from monai.inferers import sliding_window_inference
from monai.transforms import (
    LoadImaged,
    EnsureChannelFirstd,
    Spacingd,
    ScaleIntensityd,
    CropForegroundd,
    SpatialPadd,
    CenterSpatialCropd,
    RandFlipd,
    RandRotate90d,
    RandZoomd,
    RandAdjustContrastd,
    Rand3DElasticd,
    RandAffined,
    ToTensord,
)
from monai.networks.nets import UNet
from monai.losses import DiceCELoss
from monai.metrics import DiceMetric

data_root = "/kaggle/input/spleen/data/Spleen"
checkpoint_dir = "/kaggle/working/checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)

seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

max_epochs = 300
train_batch_size = 2
val_batch_size = 1
learning_rate = 5e-5
patience = 20
roi_size = (128, 128, 128)

# ----------------------------
# Build dataset paths
# ----------------------------
def make_dataset(phase):
    img_dir = os.path.join(data_root, phase, "images")
    lbl_dir = os.path.join(data_root, phase, "labels")
    img_paths = sorted(glob.glob(os.path.join(img_dir, "*.nii*")))
    lbl_paths = sorted(glob.glob(os.path.join(lbl_dir, "*.nii*")))
    assert img_paths and lbl_paths, f"No files for {phase}"
    return [{"image": i, "label": l} for i, l in zip(img_paths, lbl_paths)]

train_files = make_dataset("train")
val_files = make_dataset("val")
test_files = make_dataset("test")

# ----------------------------
# Transforms
# ----------------------------
train_transforms = [
    LoadImaged(keys=["image", "label"]),
    EnsureChannelFirstd(keys=["image", "label"]),
    Spacingd(keys=["image", "label"], pixdim=(1.5,1.5,2.0), mode=("bilinear","nearest")),
    ScaleIntensityd(keys=["image"]),
    CropForegroundd(keys=["image", "label"], source_key="image"),
    SpatialPadd(keys=["image", "label"], spatial_size=(128,128,128)),
    CenterSpatialCropd(keys=["image", "label"], roi_size=(128,128,128)),
    RandFlipd(keys=["image", "label"], spatial_axis=[0,1,2], prob=0.5),
    RandRotate90d(keys=["image", "label"], prob=0.5, max_k=3),
    RandZoomd(keys=["image", "label"], prob=0.3, min_zoom=0.8, max_zoom=1.2),
    RandAdjustContrastd(keys=["image"], prob=0.3, gamma=(0.7,1.5)),
    Rand3DElasticd(keys=["image", "label"], sigma_range=(5,8), magnitude_range=(100,200), prob=0.3),
    RandAffined(keys=["image", "label"], rotate_range=(0.1,0.1,0.1), prob=0.5),
    ToTensord(keys=["image", "label"]),
]

val_transforms = [
    LoadImaged(keys=["image", "label"]),
    EnsureChannelFirstd(keys=["image", "label"]),
    Spacingd(keys=["image", "label"], pixdim=(1.5,1.5,2.0), mode=("bilinear","nearest")),
    ScaleIntensityd(keys=["image"]),
    CropForegroundd(keys=["image", "label"], source_key="image"),
    SpatialPadd(keys=["image", "label"], spatial_size=(128,128,128)),
    CenterSpatialCropd(keys=["image", "label"], roi_size=(128,128,128)),
    ToTensord(keys=["image", "label"]),
]


# ----------------------------
# Loaders
# ----------------------------
train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=1.0)
val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0)
train_loader = DataLoader(train_ds, batch_size=train_batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_ds, batch_size=val_batch_size, shuffle=False, num_workers=2)

# ----------------------------
# Model setup
# ----------------------------
model = UNet(spatial_dims=3, in_channels=1, out_channels=2,
             channels=(32,64,128,256,512,1024), strides=(2,2,2,2,2), num_res_units=2, norm='batch').to(device)
loss_function = DiceCELoss(to_onehot_y=True, softmax=True)
optimizer = Adam(model.parameters(), lr=learning_rate)
scheduler = ReduceLROnPlateau(optimizer, 'max', patience=10, factor=0.5)
scaler = GradScaler()
dice_metric = DiceMetric(include_background=True, reduction="mean")

# ----------------------------
# Training & Validation
# ----------------------------
best_metric = 0.0
no_improve = 0

for epoch in range(1, max_epochs+1):
    model.train()
    epoch_loss = 0.0
    for batch in train_loader:
        imgs, segs = batch['image'].to(device), batch['label'].to(device)
        optimizer.zero_grad()
        with autocast(device_type='cuda'):
            outputs = model(imgs)
            loss = loss_function(outputs, segs)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        epoch_loss += loss.item()
    epoch_loss /= len(train_loader)
    print(f"Epoch {epoch}, Avg Loss: {epoch_loss:.4f}")

    model.eval()
    dice_metric.reset()
    with torch.no_grad():
        for val_batch in val_loader:
            val_imgs, val_segs = val_batch['image'].to(device), val_batch['label'].to(device)
            val_preds = sliding_window_inference(val_imgs, roi_size, 4, model)
            dice_metric(y_pred=val_preds, y=val_segs)
    mean_dice = dice_metric.aggregate().item()
    print(f"Validation Dice: {mean_dice:.4f}")
    scheduler.step(mean_dice)

    if mean_dice > best_metric:
        best_metric = mean_dice
        torch.save(model.state_dict(), os.path.join(checkpoint_dir, f"best_model_epoch{epoch}_dice{best_metric:.4f}.pth"))
        no_improve = 0
        print("Saved best model")
    else:
        no_improve += 1
        if no_improve >= patience:
            print("Early stopping")
            break

print(f"Training finished. Best Dice = {best_metric:.4f}")

# ----------------------------
# Test Phase
# ----------------------------
model.load_state_dict(torch.load(os.path.join(checkpoint_dir, f"best_model_epoch{epoch-no_improve}_dice{best_metric:.4f}.pth")))
model.eval()
test_ds = Dataset(data=test_files, transform=val_transforms)
test_loader = DataLoader(test_ds, batch_size=1, shuffle=False)

dice_metric.reset()
with torch.no_grad():
    for test_batch in test_loader:
        test_imgs, test_segs = test_batch['image'].to(device), test_batch['label'].to(device)
        test_preds = sliding_window_inference(test_imgs, roi_size, 4, model)
        dice_metric(y_pred=test_preds, y=test_segs)

final_test_dice = dice_metric.aggregate().item()
print(f"Test Mean Dice: {final_test_dice:.4f}")


In [None]:
import os
import glob
import random
import numpy as np
import torch
import torch.nn.functional as F
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.cuda.amp import autocast, GradScaler
import matplotlib.pyplot as plt

# --- MONAI Imports ---
from monai.data import CacheDataset, Dataset, DataLoader
from monai.inferers import sliding_window_inference
from monai.transforms import (
    LoadImaged, EnsureChannelFirstd, Spacingd, Orientationd,
    ScaleIntensityRanged, CropForegroundd, RandCropByPosNegLabeld,
    RandFlipd, RandRotate90d, RandShiftIntensityd,
    EnsureTyped, Compose
)
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.losses import DiceCELoss
from monai.metrics import DiceMetric
from monai.config import print_config
from monai.utils import set_determinism

# --- Configuration ---
base_data_dir = "/kaggle/input/spleen"
data_root = os.path.join(base_data_dir, "data", "Spleen")
checkpoint_dir = "/kaggle/working/checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)

# reproducibility
seed = 42
set_determinism(seed=seed)
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
print_config()

# hyperparameters
max_epochs = 100
train_batch_size = 2
val_batch_size = 1
test_batch_size = 1
learning_rate = 5e-5
patience = 20
train_roi_size = (96, 96, 96)
infer_roi_size = (160, 160, 160)
num_samples_per_volume = 4
save_interval = 15  # epochs between periodic checkpoints

# --- Dataset Listing ---
def make_dataset(phase):
    phase_dir = os.path.join(data_root, phase)
    img_paths = sorted(glob.glob(os.path.join(phase_dir, "images", "*.nii*")))
    lbl_paths = sorted(glob.glob(os.path.join(phase_dir, "labels", "*.nii*")))
    if not img_paths or not lbl_paths or len(img_paths) != len(lbl_paths):
        raise RuntimeError(f"Data error in phase '{phase}'.")
    return [{"image": i, "label": l} for i, l in zip(img_paths, lbl_paths)]

train_files = make_dataset("train")
val_files = make_dataset("val")
test_files = make_dataset("test")
print(f"Train/Val/Test: {len(train_files)}/{len(val_files)}/{len(test_files)}")

# --- Transforms ---
train_transforms = Compose([
    LoadImaged(keys=["image","label"]), EnsureChannelFirstd(keys=["image","label"]),
    Orientationd(keys=["image","label"], axcodes="RAS"),
    Spacingd(keys=["image","label"], pixdim=(1.5,1.5,2.0), mode=("bilinear","nearest")),
    ScaleIntensityRanged(keys=["image"], a_min=-57,a_max=164,b_min=0.0,b_max=1.0,clip=True),
    CropForegroundd(keys=["image","label"], source_key="image"),
    RandCropByPosNegLabeld(keys=["image","label"], label_key="label",
                             spatial_size=train_roi_size, pos=1, neg=1, num_samples=num_samples_per_volume),
    RandFlipd(keys=["image","label"], spatial_axis=[0], prob=0.1),
    RandFlipd(keys=["image","label"], spatial_axis=[1], prob=0.1),
    RandFlipd(keys=["image","label"], spatial_axis=[2], prob=0.1),
    RandRotate90d(keys=["image","label"], prob=0.1, max_k=3),
    RandShiftIntensityd(keys=["image"], offsets=0.1, prob=0.5),
    EnsureTyped(keys=["image","label"], track_meta=False),
])
val_transforms = Compose([
    LoadImaged(keys=["image","label"]), EnsureChannelFirstd(keys=["image","label"]),
    Orientationd(keys=["image","label"], axcodes="RAS"),
    Spacingd(keys=["image","label"], pixdim=(1.5,1.5,2.0), mode=("bilinear","nearest")),
    ScaleIntensityRanged(keys=["image"], a_min=-57,a_max=164,b_min=0.0,b_max=1.0,clip=True),
    CropForegroundd(keys=["image","label"], source_key="image"),
    EnsureTyped(keys=["image","label"], track_meta=True),
])

# --- DataLoaders ---
train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=1.0, num_workers=4)
val_ds   = CacheDataset(data=val_files,   transform=val_transforms,   cache_rate=1.0, num_workers=4)
test_ds  = Dataset(data=test_files,      transform=val_transforms)

train_loader = DataLoader(train_ds, batch_size=train_batch_size, shuffle=True,  num_workers=4, pin_memory=torch.cuda.is_available())
val_loader   = DataLoader(val_ds,   batch_size=val_batch_size,   shuffle=False, num_workers=4, pin_memory=torch.cuda.is_available())
test_loader  = DataLoader(test_ds,  batch_size=test_batch_size,  shuffle=False, num_workers=4, pin_memory=torch.cuda.is_available())

# --- Model & Optimizer ---
model = UNet(spatial_dims=3, in_channels=1, out_channels=2,
             channels=(32,64,128,256,512), strides=(2,2,2,2), num_res_units=2, norm=Norm.BATCH).to(device)
loss_fn   = DiceCELoss(to_onehot_y=True, softmax=True)
optimizer = Adam(model.parameters(), lr=learning_rate)
scheduler = ReduceLROnPlateau(optimizer, mode='max', patience=10, factor=0.5, verbose=True)
scaler    = GradScaler()
dice_metric = DiceMetric(include_background=False, reduction="mean", get_not_nans=False)

# --- Training & Validation ---
best_metric, best_epoch, epochs_no_improve = -1.0, -1, 0
for epoch in range(1, max_epochs+1):
    # Training
    model.train(); train_loss = 0.0
    for i, batch in enumerate(train_loader, 1):
        imgs, lbls = batch['image'].to(device), batch['label'].to(device)
        optimizer.zero_grad()
        with autocast(enabled=torch.cuda.is_available()):
            preds = model(imgs)
            loss = loss_fn(preds, lbls)
        scaler.scale(loss).backward(); scaler.step(optimizer); scaler.update()
        train_loss += loss.item()
    print(f"Epoch {epoch}/{max_epochs} - Train loss: {train_loss/i:.4f}")

    # Validation
    model.eval(); dice_metric.reset()
    with torch.no_grad():
        for batch in val_loader:
            imgs, lbls = batch['image'].to(device), batch['label'].to(device)
            outputs = sliding_window_inference(
                imgs, roi_size=infer_roi_size, sw_batch_size=4,
                predictor=model, overlap=0.5, mode="gaussian"
            )
            probs = torch.softmax(outputs, dim=1)
            seg = torch.argmax(probs, dim=1)
            onehot_pred = F.one_hot(seg, num_classes=2).permute(0,4,1,2,3).float()
            gt = lbls.squeeze(1).long()
            onehot_gt = F.one_hot(gt, num_classes=2).permute(0,4,1,2,3).float()
            for b in range(onehot_pred.shape[0]):
                dice_metric(y_pred=[onehot_pred[b]], y=[onehot_gt[b]])
    val_dice = dice_metric.aggregate().item()
    print(f"  Val Dice: {val_dice:.4f}")
    scheduler.step(val_dice)

    # Save best model
    if val_dice > best_metric:
        best_metric, best_epoch, epochs_no_improve = val_dice, epoch, 0
        if 'best_ckpt' in globals() and os.path.exists(best_ckpt): os.remove(best_ckpt)
        best_ckpt = os.path.join(checkpoint_dir, f"best_ep{epoch}_dice{val_dice:.4f}.pth")
        torch.save(model.state_dict(), best_ckpt)
        print(f"  Saved best model: {best_ckpt}")
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= patience:
            print("Early stopping")
            break

    # Periodic checkpoint
    if epoch % save_interval == 0:
        periodic_ckpt = os.path.join(checkpoint_dir, f"ckpt_epoch{epoch}.pth")
        torch.save(model.state_dict(), periodic_ckpt)
        print(f"  Saved periodic checkpoint: {periodic_ckpt}")

print(f"Training complete. Best val dice: {best_metric:.4f} at epoch {best_epoch}")

# --- Test Phase ---
print("\n--- Test Evaluation ---")
if not os.path.exists(best_ckpt): raise FileNotFoundError("Checkpoint not found.")
model.load_state_dict(torch.load(best_ckpt, map_location=device))
model.eval(); dice_metric.reset()
with torch.no_grad():
    for batch in test_loader:
        imgs, lbls = batch['image'].to(device), batch['label'].to(device)
        outputs = sliding_window_inference(
            imgs, roi_size=infer_roi_size, sw_batch_size=4,
            predictor=model, overlap=0.5, mode="gaussian"
        )
        probs = torch.softmax(outputs, dim=1)
        seg = torch.argmax(probs, dim=1)
        onehot_pred = F.one_hot(seg, num_classes=2).permute(0,4,1,2,3).float()
        gt = lbls.squeeze(1).long()
        onehot_gt = F.one_hot(gt, num_classes=2).permute(0,4,1,2,3).float()
        for b in range(onehot_pred.shape[0]):
            dice_metric(y_pred=[onehot_pred[b]], y=[onehot_gt[b]])
final_test_dice = dice_metric.aggregate().item()
print(f"Test Mean Dice: {final_test_dice:.4f}")

# --- Plot Examples ---
print("\n--- Plotting Examples ---")
plots = 0
for batch in val_loader:
    if plots >= 3: break
    img_np = batch['image'][0,0].cpu().numpy()
    lbl_np = batch['label'][0,0].cpu().numpy()
    with torch.no_grad():
        out = sliding_window_inference(
            batch['image'].to(device), roi_size=infer_roi_size, sw_batch_size=4,
            predictor=model, overlap=0.5, mode="gaussian"
        )
        pred_np = torch.argmax(torch.softmax(out, dim=1), dim=1)[0].cpu().numpy()
    sl = img_np.shape[2]//2
    plt.figure(figsize=(12,4))
    plt.subplot(1,3,1); plt.imshow(img_np[:,:,sl], cmap="gray"); plt.axis('off'); plt.title(f"Img {plots}")
    plt.subplot(1,3,2); plt.imshow(lbl_np[:,:,sl]); plt.axis('off'); plt.title(f"Lbl {plots}")
    plt.subplot(1,3,3); plt.imshow(pred_np[:,:,sl]); plt.axis('off'); plt.title(f"Pred {plots}")
    plt.tight_layout(); plt.show()
    plots += 1

print("Done.")


In [None]:
import os
import glob
import random
import numpy as np
import torch
import torch.nn.functional as F
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.cuda.amp import autocast, GradScaler
import matplotlib.pyplot as plt

# MONAI imports
from monai.data import CacheDataset, Dataset, DataLoader
from monai.inferers import sliding_window_inference
from monai.transforms import (
    LoadImaged, EnsureChannelFirstd, Spacingd, Orientationd,
    ScaleIntensityRanged, CropForegroundd, RandCropByPosNegLabeld,
    RandFlipd, RandRotate90d, RandShiftIntensityd,
    EnsureTyped, Compose
)
from monai.losses import DiceCELoss
from monai.metrics import DiceMetric
from monai.config import print_config
from monai.utils import set_determinism

# Configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print_config()
print(f"Using device: {device}")

# Reproducibility
seed = 42
set_determinism(seed=seed)
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

# Hyperparameters
max_epochs = 300
train_batch_size = 1
val_batch_size = 1
test_batch_size = 1
learning_rate = 5e-5
patience = 20
train_roi_size = (64, 64, 64)
infer_roi_size = (128, 128, 128)
save_interval = 15

# Paths
base_data_dir = "/kaggle/input/spleen"
data_root = os.path.join(base_data_dir, "data", "Spleen")
checkpoint_dir = "/kaggle/working/checkpoints_unet++"
os.makedirs(checkpoint_dir, exist_ok=True)

# Dataset
def make_dataset(phase):
    img_dir = os.path.join(data_root, phase, "images")
    lbl_dir = os.path.join(data_root, phase, "labels")
    imgs = sorted(glob.glob(os.path.join(img_dir, "*.nii*")))
    lbls = sorted(glob.glob(os.path.join(lbl_dir, "*.nii*")))
    assert len(imgs) == len(lbls) and len(imgs) > 0, f"Data error in {phase}"
    return [{"image": i, "label": l} for i, l in zip(imgs, lbls)]

train_files = make_dataset("train")
val_files = make_dataset("val")
test_files = make_dataset("test")
print(f"Train/Val/Test: {len(train_files)}/{len(val_files)}/{len(test_files)}")

# Transforms
train_transforms = Compose([
    LoadImaged(keys=["image","label"]),
    EnsureChannelFirstd(keys=["image","label"]),
    Orientationd(keys=["image","label"], axcodes="RAS"),
    Spacingd(keys=["image","label"], pixdim=(1.5,1.5,2.0), mode=("bilinear","nearest")),
    ScaleIntensityRanged(keys=["image"], a_min=-57, a_max=164, b_min=0, b_max=1, clip=True),
    CropForegroundd(keys=["image","label"], source_key="image"),
    RandCropByPosNegLabeld(keys=["image","label"], label_key="label", spatial_size=train_roi_size,
                             pos=1, neg=1, num_samples=4),
    RandFlipd(keys=["image","label"], spatial_axis=[0], prob=0.1),
    RandFlipd(keys=["image","label"], spatial_axis=[1], prob=0.1),
    RandFlipd(keys=["image","label"], spatial_axis=[2], prob=0.1),
    RandRotate90d(keys=["image","label"], prob=0.1, max_k=3),
    RandShiftIntensityd(keys=["image"], offsets=0.1, prob=0.5),
    EnsureTyped(keys=["image","label"], track_meta=False),
])
val_transforms = Compose([
    LoadImaged(keys=["image","label"]),
    EnsureChannelFirstd(keys=["image","label"]),
    Orientationd(keys=["image","label"], axcodes="RAS"),
    Spacingd(keys=["image","label"], pixdim=(1.5,1.5,2.0), mode=("bilinear","nearest")),
    ScaleIntensityRanged(keys=["image"], a_min=-57, a_max=164, b_min=0, b_max=1, clip=True),
    CropForegroundd(keys=["image","label"], source_key="image"),
    EnsureTyped(keys=["image","label"], track_meta=True),
])

# DataLoaders
data_train = CacheDataset(data=train_files, transform=train_transforms, cache_rate=1.0, num_workers=4)
data_val = CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0, num_workers=4)
train_loader = DataLoader(data_train, batch_size=train_batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(data_val, batch_size=val_batch_size, shuffle=False, num_workers=4)

test_ds = Dataset(data=test_files, transform=val_transforms)
test_loader = DataLoader(test_ds, batch_size=test_batch_size, shuffle=False, num_workers=4)

# Define UNet++ with gradient checkpointing
import torch.nn as nn
import torch.utils.checkpoint as cp

def conv_block(in_ch, out_ch):
    return nn.Sequential(
        nn.Conv3d(in_ch, out_ch, 3, padding=1, bias=False),
        nn.BatchNorm3d(out_ch),
        nn.ReLU(inplace=True),
        nn.Conv3d(out_ch, out_ch, 3, padding=1, bias=False),
        nn.BatchNorm3d(out_ch),
        nn.ReLU(inplace=True)
    )

class UNetPlusPlus3D(nn.Module):
    def __init__(self, filters=(8,16,32,64)):
        super().__init__()
        f = filters
        self.conv0_0 = conv_block(1, f[0]); self.conv1_0 = conv_block(f[0], f[1])
        self.conv2_0 = conv_block(f[1], f[2]); self.conv3_0 = conv_block(f[2], f[3])
        self.conv0_1 = conv_block(f[0]+f[1], f[0]); self.conv1_1 = conv_block(f[1]+f[2], f[1]); self.conv2_1 = conv_block(f[2]+f[3], f[2])
        self.conv0_2 = conv_block(f[0]*2+f[1], f[0]); self.conv1_2 = conv_block(f[1]*2+f[2], f[1])
        self.conv0_3 = conv_block(f[0]*3+f[1], f[0])
        self.pool = nn.MaxPool3d(2)
        self.up = lambda x, ref: nn.functional.interpolate(x, size=ref.shape[2:], mode='trilinear', align_corners=True)
        self.final = nn.Conv3d(f[0], 2, 1)

    def forward(self, x):
        x0_0 = cp.checkpoint(self.conv0_0, x)
        x1_0 = cp.checkpoint(self.conv1_0, self.pool(x0_0))
        x0_1 = cp.checkpoint(self.conv0_1, torch.cat([x0_0, self.up(x1_0, x0_0)], dim=1))
        x2_0 = cp.checkpoint(self.conv2_0, self.pool(x1_0))
        x1_1 = cp.checkpoint(self.conv1_1, torch.cat([x1_0, self.up(x2_0, x1_0)], dim=1))
        x0_2 = cp.checkpoint(self.conv0_2, torch.cat([x0_0, x0_1, self.up(x1_1, x0_0)], dim=1))
        x3_0 = cp.checkpoint(self.conv3_0, self.pool(x2_0))
        x2_1 = cp.checkpoint(self.conv2_1, torch.cat([x2_0, self.up(x3_0, x2_0)], dim=1))
        x1_2 = cp.checkpoint(self.conv1_2, torch.cat([x1_0, x1_1, self.up(x2_1, x1_0)], dim=1))
        x0_3 = cp.checkpoint(self.conv0_3, torch.cat([x0_0, x0_1, x0_2, self.up(x1_2, x0_0)], dim=1))
        return self.final(x0_3)

model = UNetPlusPlus3D().to(device)

# Optimizer & Loss
optimizer = Adam(model.parameters(), lr=learning_rate)
scheduler = ReduceLROnPlateau(optimizer, 'max', patience=patience, factor=0.5, verbose=True)
loss_fn = DiceCELoss(to_onehot_y=True, softmax=True)
dice_metric = DiceMetric(include_background=False, reduction="mean")
scaler = GradScaler()

# Training & Validation
best_metric, best_epoch, no_improve = -1.0, -1, 0
for epoch in range(1, max_epochs+1):
    model.train(); train_loss=0.0
    for i, batch in enumerate(train_loader, 1):
        imgs, lbls = batch['image'].to(device), batch['label'].to(device)
        optimizer.zero_grad()
        with autocast():
            preds = model(imgs)
            loss = loss_fn(preds, lbls)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        train_loss += loss.item()
    print(f"Epoch {epoch} | Train Loss: {train_loss/i:.4f}")

    model.eval(); dice_metric.reset()
    with torch.no_grad(), autocast():
        for batch in val_loader:
            imgs, lbls = batch['image'].to(device), batch['label'].to(device)
            out = sliding_window_inference(
                imgs, roi_size=infer_roi_size, sw_batch_size=1, predictor=model, overlap=0.5
            )
            seg = torch.argmax(out.softmax(1), 1)
            onehot_pred = F.one_hot(seg, 2).permute(0,4,1,2,3).float()
            gt = lbls.squeeze(1).long()
            onehot_gt = F.one_hot(gt, 2).permute(0,4,1,2,3).float()
            for b in range(onehot_pred.shape[0]):
                dice_metric(y_pred=[onehot_pred[b]], y=[onehot_gt[b]])
    val_dice = dice_metric.aggregate().item()
    print(f"Epoch {epoch} | Val Dice: {val_dice:.4f}")
    scheduler.step(val_dice)

    if val_dice > best_metric:
        best_metric, best_epoch, no_improve = val_dice, epoch, 0
        best_ckpt = os.path.join(checkpoint_dir, f"best_epoch{epoch}.pth")
        torch.save(model.state_dict(), best_ckpt)
        print(f"Saved best model: {best_ckpt}")
    else:
        no_improve += 1
        if no_improve >= patience:
            print("Early stopping triggered.")
            break

    if epoch % save_interval == 0:
        ckpt = os.path.join(checkpoint_dir, f"ckpt_epoch{epoch}.pth")
        torch.save(model.state_dict(), ckpt)
        print(f"Saved periodic checkpoint: {ckpt}")

print(f"Training done | Best Val Dice: {best_metric:.4f} at epoch {best_epoch}")

# Test Phase
print("\n--- Testing on held-out set ---")
model.load_state_dict(torch.load(best_ckpt, map_location=device))
model.eval(); dice_metric.reset()
with torch.no_grad(), autocast():
    for batch in test_loader:
        imgs, lbls = batch['image'].to(device), batch['label'].to(device)
        out = sliding_window_inference(
            imgs, roi_size=infer_roi_size, sw_batch_size=1, predictor=model, overlap=0.5
        )
        seg = torch.argmax(out.softmax(1), 1)
        onehot_pred = F.one_hot(seg, 2).permute(0,4,1,2,3).float()
        gt = lbls.squeeze(1).long()
        onehot_gt = F.one_hot(gt, 2).permute(0,4,1,2,3).float()
        for b in range(onehot_pred.shape[0]):
            dice_metric(y_pred=[onehot_pred[b]], y=[onehot_gt[b]])
final_test_dice = dice_metric.aggregate().item()
print(f"Test Mean Dice: {final_test_dice:.4f}")

# Plot
print("\n--- Sample Predictions ---")
plot_count = 0
for batch in val_loader:
    if plot_count >= 3: break
    img_np = batch['image'][0,0].cpu().numpy()
    lbl_np = batch['label'][0,0].cpu().numpy()
    with torch.no_grad(), autocast():
        out = sliding_window_inference(
            batch['image'].to(device), roi_size=infer_roi_size, sw_batch_size=1, predictor=model, overlap=0.5
        )
        pred_np = torch.argmax(out.softmax(1), 1)[0].cpu().numpy()
    slice_idx = img_np.shape[2] // 2
    plt.figure(figsize=(12,4))
    plt.subplot(1,3,1); plt.imshow(img_np[:,:,slice_idx], cmap="gray"); plt.title("Image"); plt.axis('off')
    plt.subplot(1,3,2); plt.imshow(lbl_np[:,:,slice_idx]); plt.title("Label"); plt.axis('off')
    plt.subplot(1,3,3); plt.imshow(pred_np[:,:,slice_idx]); plt.title("Prediction"); plt.axis('off')
    plt.tight_layout(); plt.show()
    plot_count += 1
print("Done.")


In [6]:
import os
import glob
import random
import numpy as np
import torch
import torch.nn.functional as F
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.cuda.amp import autocast, GradScaler
import matplotlib.pyplot as plt

# MONAI imports
from monai.data import CacheDataset, Dataset, DataLoader
from monai.inferers import sliding_window_inference
from monai.transforms import (
    LoadImaged, EnsureChannelFirstd, Spacingd, Orientationd,
    ScaleIntensityRanged, CropForegroundd, RandCropByPosNegLabeld,
    RandFlipd, RandRotate90d, RandShiftIntensityd,
    EnsureTyped, Compose
)
from monai.losses import DiceCELoss
from monai.metrics import DiceMetric
from monai.config import print_config
from monai.utils import set_determinism

# Configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print_config(); print(f"Using device: {device}")

# Reproducibility
seed = 42
set_determinism(seed=seed)
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

# Hyperparameters
max_epochs = 300
train_batch_size = 1
val_batch_size = 1
test_batch_size = 1
learning_rate = 5e-5
patience = 20
train_roi_size = (64, 64, 64)
infer_roi_size = (128, 128, 128)
save_interval = 20  # periodic model save
resume_interval = 20  # periodic resume checkpoint

# Paths
base_data_dir = "/kaggle/input/spleen"
data_root = os.path.join(base_data_dir, "data", "Spleen")
checkpoint_dir = "/kaggle/working/checkpoints_attunet++"
os.makedirs(checkpoint_dir, exist_ok=True)

# Dataset helper
def make_dataset(phase):
    img_dir = os.path.join(data_root, phase, "images")
    lbl_dir = os.path.join(data_root, phase, "labels")
    imgs = sorted(glob.glob(os.path.join(img_dir, "*.nii*")))
    lbls = sorted(glob.glob(os.path.join(lbl_dir, "*.nii*")))
    assert len(imgs) == len(lbls) and len(imgs) > 0, f"Data error in {phase}"
    return [{"image": i, "label": l} for i, l in zip(imgs, lbls)]

train_files = make_dataset("train")
val_files = make_dataset("val")
test_files = make_dataset("test")
print(f"Train/Val/Test: {len(train_files)}/{len(val_files)}/{len(test_files)}")

# Transforms
train_transforms = Compose([
    LoadImaged(keys=["image","label"]), EnsureChannelFirstd(keys=["image","label"]),
    Orientationd(keys=["image","label"], axcodes="RAS"),
    Spacingd(keys=["image","label"], pixdim=(1.5,1.5,2.0), mode=("bilinear","nearest")),
    ScaleIntensityRanged(keys=["image"], a_min=-57,a_max=164,b_min=0,b_max=1,clip=True),
    CropForegroundd(keys=["image","label"], source_key="image"),
    RandCropByPosNegLabeld(keys=["image","label"],label_key="label",spatial_size=train_roi_size,pos=1,neg=1,num_samples=4),
    RandFlipd(keys=["image","label"], spatial_axis=[0], prob=0.1),
    RandFlipd(keys=["image","label"], spatial_axis=[1], prob=0.1),
    RandFlipd(keys=["image","label"], spatial_axis=[2], prob=0.1),
    RandRotate90d(keys=["image","label"], prob=0.1, max_k=3),
    RandShiftIntensityd(keys=["image"], offsets=0.1, prob=0.5),
    EnsureTyped(keys=["image","label"], track_meta=False),
])
val_transforms = Compose([
    LoadImaged(keys=["image","label"]), EnsureChannelFirstd(keys=["image","label"]),
    Orientationd(keys=["image","label"], axcodes="RAS"),
    Spacingd(keys=["image","label"], pixdim=(1.5,1.5,2.0), mode=("bilinear","nearest")),
    ScaleIntensityRanged(keys=["image"], a_min=-57,a_max=164,b_min=0,b_max=1,clip=True),
    CropForegroundd(keys=["image","label"], source_key="image"),
    EnsureTyped(keys=["image","label"], track_meta=True),
])

# DataLoaders
data_train = CacheDataset(data=train_files, transform=train_transforms, cache_rate=1.0, num_workers=4)
train_loader = DataLoader(data_train, batch_size=train_batch_size, shuffle=True, num_workers=4)
data_val = CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0, num_workers=4)
val_loader = DataLoader(data_val, batch_size=val_batch_size, shuffle=False, num_workers=4)
test_ds = Dataset(data=test_files, transform=val_transforms)
test_loader = DataLoader(test_ds, batch_size=test_batch_size, shuffle=False, num_workers=4)

# Attention U-Net definition
from monai.networks.nets import AttentionUnet

model = AttentionUnet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(8, 16, 32, 64),   # adjust for memory vs performance
    strides=(2, 2, 2),
    kernel_size=3,
    up_kernel_size=3,
    dropout=0.0,
).to(device)

# Optimizer & utilities
optimizer = Adam(model.parameters(), lr=learning_rate)
scheduler = ReduceLROnPlateau(optimizer,'max',patience=patience,factor=0.5,verbose=True)
loss_fn = DiceCELoss(to_onehot_y=True,softmax=True)
dice_metric = DiceMetric(include_background=False,reduction='mean')
scaler = GradScaler()

# Resume logic
start_epoch = 1
best_metric = -1.0
best_epoch = 0
no_improve = 0
latest_ckpt = os.path.join(checkpoint_dir,'latest.pth')
if os.path.exists(latest_ckpt):
    ckpt = torch.load(latest_ckpt,map_location=device)
    model.load_state_dict(ckpt['model'])
    optimizer.load_state_dict(ckpt['optimizer'])
    scheduler.load_state_dict(ckpt['scheduler'])
    scaler.load_state_dict(ckpt['scaler'])
    best_metric = ckpt['best_metric']; best_epoch = ckpt['best_epoch']; no_improve = ckpt['no_improve']
    start_epoch = ckpt['epoch'] + 1
    print(f"Resumed from epoch {start_epoch-1}")

# Training loop
for epoch in range(start_epoch, max_epochs+1):
    model.train(); sum_loss=0.0
    for i,batch in enumerate(train_loader,1):
        imgs, lbls = batch['image'].to(device), batch['label'].to(device)
        optimizer.zero_grad()
        with autocast():
            preds = model(imgs)
            loss = loss_fn(preds,lbls)
        scaler.scale(loss).backward(); scaler.step(optimizer); scaler.update()
        sum_loss += loss.item()
    print(f"Epoch {epoch} | Train Loss: {sum_loss/i:.4f}")

    model.eval(); dice_metric.reset()
    with torch.no_grad(), autocast():
        for batch in val_loader:
            imgs, lbls = batch['image'].to(device), batch['label'].to(device)
            out = sliding_window_inference(imgs, roi_size=infer_roi_size, sw_batch_size=1, predictor=model, overlap=0.5)
            seg = torch.argmax(out.softmax(1),1)
            oh = F.one_hot(seg,2).permute(0,4,1,2,3).float()
            gt = lbls.squeeze(1).long()
            ohg = F.one_hot(gt,2).permute(0,4,1,2,3).float()
            for b in range(oh.shape[0]): dice_metric(y_pred=[oh[b]],y=[ohg[b]])
    val_dice = dice_metric.aggregate().item()
    print(f"Epoch {epoch} | Val Dice: {val_dice:.4f}")
    scheduler.step(val_dice)

    # best model
    if val_dice>best_metric:
        best_metric,best_epoch,no_improve=val_dice,epoch,0
        best_path=os.path.join(checkpoint_dir,f"best_epoch{epoch}.pth")
        torch.save(model.state_dict(),best_path)
        print(f"Saved best model: {best_path}")
    else:
        no_improve+=1
        if no_improve>=patience:
            print("Early stopping.")
            break

    # periodic checkpoint
    if epoch%save_interval==0:
        path=os.path.join(checkpoint_dir,f"ckpt_epoch{epoch}.pth")
        torch.save(model.state_dict(),path)
        print(f"Saved periodic: {path}")

    # resume checkpoint
    if epoch%resume_interval==0:
        torch.save({
            'epoch':epoch,
            'model':model.state_dict(),
            'optimizer':optimizer.state_dict(),
            'scheduler':scheduler.state_dict(),
            'scaler':scaler.state_dict(),
            'best_metric':best_metric,
            'best_epoch':best_epoch,
            'no_improve':no_improve
        }, latest_ckpt)
        print(f"Saved resume checkpoint: {latest_ckpt}")

print(f"Training done | Best Val Dice: {best_metric:.4f} at epoch {best_epoch}")

# Test phase
print("\n--- Testing ---")
model.load_state_dict(torch.load(best_path,map_location=device))
model.eval(); dice_metric.reset()
with torch.no_grad(), autocast():
    for batch in test_loader:
        imgs,lbls = batch['image'].to(device), batch['label'].to(device)
        out = sliding_window_inference(imgs,roi_size=infer_roi_size,sw_batch_size=1,predictor=model,overlap=0.5)
        seg=torch.argmax(out.softmax(1),1)
        oh=F.one_hot(seg,2).permute(0,4,1,2,3).float()
        gt=lbls.squeeze(1).long()
        ohg=F.one_hot(gt,2).permute(0,4,1,2,3).float()
        for b in range(oh.shape[0]): dice_metric(y_pred=[oh[b]],y=[ohg[b]])
final= dice_metric.aggregate().item()
print(f"Test Mean Dice: {final:.4f}")


MONAI version: 1.4.1rc1+46.gb58e883c
Numpy version: 1.26.4
Pytorch version: 2.5.1+cu121
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: b58e883c887e0f99d382807550654c44d94f47bd
MONAI __file__: /usr/local/lib/python3.10/dist-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: 0.5.1
ITK version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: 5.3.2
scikit-image version: 0.25.0
scipy version: 1.13.1
Pillow version: 11.0.0
Tensorboard version: 2.17.1
gdown version: 5.2.0
TorchVision version: 0.20.1+cu121
tqdm version: 4.67.1
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 5.9.5
pandas version: 2.2.3
einops version: 0.8.0
transformers version: 4.47.0
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.
clearml version: NOT INSTALLED or UNKNOWN VERSION.

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/instal

Loading dataset: 100%|██████████| 29/29 [00:48<00:00,  1.68s/it]
Loading dataset: 100%|██████████| 6/6 [00:09<00:00,  1.65s/it]

Resumed from epoch 120



  scaler = GradScaler()
  ckpt = torch.load(latest_ckpt,map_location=device)
  with autocast():


Epoch 121 | Train Loss: 0.7981


  with torch.no_grad(), autocast():


Epoch 121 | Val Dice: 0.3760
Epoch 122 | Train Loss: 0.8224
Epoch 122 | Val Dice: 0.4063
Epoch 123 | Train Loss: 0.8217
Epoch 123 | Val Dice: 0.4605
Epoch 124 | Train Loss: 0.8171
Epoch 124 | Val Dice: 0.4526
Epoch 125 | Train Loss: 0.8034
Epoch 125 | Val Dice: 0.4158
Epoch 126 | Train Loss: 0.8210
Epoch 126 | Val Dice: 0.4510
Early stopping.
Training done | Best Val Dice: 0.5279 at epoch 106

--- Testing ---


  model.load_state_dict(torch.load(best_path,map_location=device))
  with torch.no_grad(), autocast():


Test Mean Dice: 0.3852


In [6]:
import os
import glob
import random
import numpy as np
import torch
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.cuda.amp import autocast, GradScaler
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn
from monai.data import CacheDataset, Dataset, DataLoader
from monai.inferers import sliding_window_inference
from monai.transforms import (
    LoadImaged, EnsureChannelFirstd, Orientationd, Spacingd,
    ScaleIntensityRanged, CropForegroundd, RandCropByPosNegLabeld,
    RandFlipd, RandRotate90d, RandShiftIntensityd,
    RandGaussianNoised, Rand3DElasticd,
    EnsureTyped, Compose
)
from monai.losses import TverskyLoss
from monai.metrics import DiceMetric
from monai.networks.nets import UNet
from monai.config import print_config
from monai.utils import set_determinism
import time

# -------------------- Configuration --------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print_config()
print(f"Using device: {device}")

seed = 42
set_determinism(seed=seed)
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

data_root = "/kaggle/input/spleen/data/Spleen"
checkpoint_dir = "/kaggle/working/checkpoints_unet_swa"
os.makedirs(checkpoint_dir, exist_ok=True)

max_epochs = 200
train_batch_size = 2
val_batch_size = 1
test_batch_size = 1
learning_rate = 1e-3
train_roi_size = (128, 128, 128)
infer_roi_size = (192, 192, 192)
patience = 30
accumulation_steps = 2
swa_start_epoch = int(max_epochs * 0.75)
save_every = 15

# -------------------- Dataset --------------------
def make_dataset(phase):
    img_dir = os.path.join(data_root, phase, "images")
    lbl_dir = os.path.join(data_root, phase, "labels")
    imgs = sorted(glob.glob(os.path.join(img_dir, "*.nii*")))
    lbls = sorted(glob.glob(os.path.join(lbl_dir, "*.nii*")))
    assert len(imgs) == len(lbls) and len(imgs) > 0, f"Data error in {phase}"
    return [{"image": i, "label": l} for i, l in zip(imgs, lbls)]

train_files = make_dataset("train")
val_files = make_dataset("val")
test_files = make_dataset("test")
print(f"Train/Val/Test: {len(train_files)}/{len(val_files)}/{len(test_files)}")

# -------------------- Transforms --------------------
train_transforms = Compose([
    LoadImaged(keys=["image", "label"]),
    EnsureChannelFirstd(keys=["image", "label"]),
    Orientationd(keys=["image", "label"], axcodes="RAS"),
    Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
    ScaleIntensityRanged(keys=["image"], a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True),
    CropForegroundd(keys=["image", "label"], source_key="image"),
    RandCropByPosNegLabeld(
        keys=["image", "label"], label_key="label",
        spatial_size=train_roi_size, pos=1, neg=1, num_samples=8,
        allow_smaller=True
    ),
    RandFlipd(keys=["image", "label"], spatial_axis=[0, 1, 2], prob=0.5),
    RandRotate90d(keys=["image", "label"], prob=0.5, max_k=3),
    RandShiftIntensityd(keys=["image"], offsets=0.15, prob=0.5),
    RandGaussianNoised(keys=["image"], prob=0.3, mean=0.0, std=0.1),
    Rand3DElasticd(keys=["image", "label"], sigma_range=(5, 8), magnitude_range=(100, 200),
                   spatial_size=train_roi_size, prob=0.3, mode=("bilinear", "nearest")),
    EnsureTyped(keys=["image", "label"], track_meta=False),
])
val_transforms = Compose([
    LoadImaged(keys=["image", "label"]),
    EnsureChannelFirstd(keys=["image", "label"]),
    Orientationd(keys=["image", "label"], axcodes="RAS"),
    Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
    ScaleIntensityRanged(keys=["image"], a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True),
    CropForegroundd(keys=["image", "label"], source_key="image"),
    EnsureTyped(keys=["image", "label"], track_meta=True),
])

# -------------------- DataLoaders --------------------
train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=1.0, num_workers=4)
train_loader = DataLoader(train_ds, batch_size=train_batch_size, shuffle=True, num_workers=4)
val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0, num_workers=2)
val_loader = DataLoader(val_ds, batch_size=val_batch_size, shuffle=False, num_workers=2)
test_ds = Dataset(data=test_files, transform=val_transforms)
test_loader = DataLoader(test_ds, batch_size=test_batch_size, shuffle=False, num_workers=2)

# -------------------- Model & Optimization --------------------
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(32, 64, 128, 256, 512),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    dropout=0.1
).to(device)
optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-5)
scheduler = CosineAnnealingLR(optimizer, T_max=max_epochs - swa_start_epoch, eta_min=1e-6)
loss_fn = TverskyLoss(alpha=0.3, beta=0.7, to_onehot_y=True, smooth_nr=1e-5, smooth_dr=1e-5)
dice_metric = DiceMetric(include_background=False, reduction="mean")
scaler = GradScaler()

swa_model = AveragedModel(model)
swa_scheduler = SWALR(optimizer, swa_lr=learning_rate * 0.1)

# -------------------- Resume from latest checkpoint --------------------
start_epoch = 1
best_metric = -1.0
checkpoint_files = sorted(glob.glob(os.path.join(checkpoint_dir, "checkpoint_epoch*.pth")))
if checkpoint_files:
    latest_checkpoint = checkpoint_files[-1]
    print(f"Resuming training from {latest_checkpoint}")
    checkpoint = torch.load(latest_checkpoint, map_location=device)
    model.load_state_dict(checkpoint["model_state"])
    optimizer.load_state_dict(checkpoint["optimizer_state"])
    scaler.load_state_dict(checkpoint["scaler_state"])
    scheduler.load_state_dict(checkpoint["scheduler_state"])
    start_epoch = checkpoint["epoch"] + 1
    best_metric = checkpoint.get("best_metric", -1.0)
else:
    print("No checkpoint found. Starting training from scratch.")

# -------------------- Training & Validation --------------------
best_epoch, no_improve = 0, 0
for epoch in range(start_epoch, max_epochs + 1):
    epoch_start = time.time()
    model.train()
    epoch_loss = 0.0
    optimizer.zero_grad()
    for i, batch in enumerate(train_loader, 1):
        imgs = batch['image'].to(device)
        lbls = batch['label'].to(device)
        with autocast():
            outputs = model(imgs)
            loss = loss_fn(outputs, lbls) / accumulation_steps
        scaler.scale(loss).backward()
        if i % accumulation_steps == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
        epoch_loss += float(loss) * accumulation_steps
        if epoch <= swa_start_epoch:
            scheduler.step()
        else:
            swa_scheduler.step()
    epoch_loss /= i
    print(f"Epoch {epoch} | Train Loss: {epoch_loss:.4f}")

    if epoch > swa_start_epoch:
        swa_model.update_parameters(model)

    # Validation
    eval_model = swa_model.module if epoch > swa_start_epoch else model
    eval_model.eval()
    dice_metric.reset()
    with torch.no_grad():
        for batch in val_loader:
            imgs = batch['image'].to(device)
            lbls = batch['label'].to(device)
            out = sliding_window_inference(
                imgs, roi_size=infer_roi_size, sw_batch_size=1,
                predictor=eval_model, overlap=0.5
            )
            seg = torch.argmax(torch.softmax(out, dim=1), dim=1)
            oh = F.one_hot(seg, num_classes=2).permute(0, 4, 1, 2, 3).float()
            gt = lbls.squeeze(1).long()
            ohg = F.one_hot(gt, num_classes=2).permute(0, 4, 1, 2, 3).float()
            for b in range(oh.shape[0]):
                dice_metric(y_pred=[oh[b]], y=[ohg[b]])
    val_dice = float(dice_metric.aggregate())
    print(f"Epoch {epoch} | Val Dice: {val_dice:.4f}")

    if val_dice > best_metric:
        best_metric, best_epoch, no_improve = val_dice, epoch, 0
        torch.save({
            "model_state": eval_model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "scaler_state": scaler.state_dict(),
            "scheduler_state": scheduler.state_dict(),
            "epoch": epoch,
            "best_metric": best_metric
        }, os.path.join(checkpoint_dir, f"best_epoch{epoch}.pth"))
        print(f"Saved best model at epoch {epoch}")
    else:
        no_improve += 1
        if no_improve >= patience:
            print("Early stopping triggered.")
            break

    if epoch % save_every == 0:
        torch.save({
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "scaler_state": scaler.state_dict(),
            "scheduler_state": scheduler.state_dict(),
            "epoch": epoch,
            "best_metric": best_metric
        }, os.path.join(checkpoint_dir, f"checkpoint_epoch{epoch}.pth"))
        print(f"Saved checkpoint at epoch {epoch}")

print(f"Training complete. Best Val Dice: {best_metric:.4f} at epoch {best_epoch}")

update_bn(train_loader, swa_model, device=device)

# -------------------- Testing --------------------
print("--- Testing ---")
checkpoint = torch.load(os.path.join(checkpoint_dir, f"best_epoch{best_epoch}.pth"))
model.load_state_dict(checkpoint["model_state"])
model.eval()
dice_metric.reset()
with torch.no_grad():
    for batch in test_loader:
        imgs = batch['image'].to(device)
        lbls = batch['label'].to(device)
        out = sliding_window_inference(
            imgs, roi_size=infer_roi_size, sw_batch_size=1,
            predictor=swa_model.module, overlap=0.5
        )
        seg = torch.argmax(torch.softmax(out, dim=1), dim=1)
        oh = F.one_hot(seg, num_classes=2).permute(0, 4, 1, 2, 3).float()
        gt = lbls.squeeze(1).long()
        ohg = F.one_hot(gt, num_classes=2).permute(0, 4, 1, 2, 3).float()
        for b in range(oh.shape[0]):
            dice_metric(y_pred=[oh[b]], y=[ohg[b]])
final_dice = float(dice_metric.aggregate())
print(f"Test Mean Dice: {final_dice:.4f}")


MONAI version: 1.4.1rc1+46.gb58e883c
Numpy version: 1.26.4
Pytorch version: 2.5.1+cu121
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: b58e883c887e0f99d382807550654c44d94f47bd
MONAI __file__: /usr/local/lib/python3.10/dist-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: 0.5.1
ITK version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: 5.3.2
scikit-image version: 0.25.0
scipy version: 1.13.1
Pillow version: 11.0.0
Tensorboard version: 2.17.1
gdown version: 5.2.0
TorchVision version: 0.20.1+cu121
tqdm version: 4.67.1
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 5.9.5
pandas version: 2.2.3
einops version: 0.8.0
transformers version: 4.47.0
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.
clearml version: NOT INSTALLED or UNKNOWN VERSION.

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/instal



Using device: cuda
Train/Val/Test: 29/6/6


Loading dataset: 100%|██████████| 29/29 [00:51<00:00,  1.79s/it]
Loading dataset: 100%|██████████| 6/6 [00:12<00:00,  2.11s/it]


No checkpoint found. Starting training from scratch.


  scaler = GradScaler()
  with autocast():


Epoch 1 | Train Loss: 0.6560
Epoch 1 | Val Dice: 0.0326
Saved best model at epoch 1
Epoch 2 | Train Loss: 0.4045
Epoch 2 | Val Dice: 0.0352
Saved best model at epoch 2
Epoch 3 | Train Loss: 0.2488
Epoch 3 | Val Dice: 0.0314
Epoch 4 | Train Loss: 0.1887
Epoch 4 | Val Dice: 0.0350
Epoch 5 | Train Loss: 0.1438
Epoch 5 | Val Dice: 0.0495
Saved best model at epoch 5
Epoch 6 | Train Loss: 0.6568
Epoch 6 | Val Dice: 0.0972
Saved best model at epoch 6
Epoch 7 | Train Loss: -0.1423
Epoch 7 | Val Dice: 0.0379
Epoch 8 | Train Loss: 0.5313
Epoch 8 | Val Dice: 0.0909
Epoch 9 | Train Loss: 0.6881
Epoch 9 | Val Dice: 0.0853
Epoch 10 | Train Loss: 1.0396
Epoch 10 | Val Dice: 0.0881
Epoch 11 | Train Loss: 0.9814
Epoch 11 | Val Dice: 0.0872
Epoch 12 | Train Loss: 0.9076
Epoch 12 | Val Dice: 0.0708
Epoch 13 | Train Loss: 0.7427
Epoch 13 | Val Dice: 0.0342
Epoch 14 | Train Loss: 0.6254
Epoch 14 | Val Dice: 0.0058
Epoch 15 | Train Loss: 0.5468
Epoch 15 | Val Dice: 0.0003
Saved checkpoint at epoch 15
Epoch 

  checkpoint = torch.load(os.path.join(checkpoint_dir, f"best_epoch{best_epoch}.pth"))


Test Mean Dice: 0.0106


GAN

In [1]:
import os, re
import nibabel as nib
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.functional import pad
from torch.amp import autocast, GradScaler
from tqdm import tqdm

# ----------------------------
# 0) CONFIGURATION
# ----------------------------
DATA_DIR         = "/kaggle/input/spleen/data/Spleen"
CKPT_DIR         = "/kaggle/working/gan_checkpoints"
os.makedirs(CKPT_DIR, exist_ok=True)

BATCH_SIZE       = 1
EPOCHS           = 50
LR               = 2e-4
CHECKPOINT_EVERY = 10
DEVICE           = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SCALER           = GradScaler()

# Desired volume shape (D, H, W)
TARGET_SHAPE     = (50, 512, 512)


# ----------------------------
# 1) DATASET
# ----------------------------
class NiiDataset(Dataset):
    def __init__(self, base_dir, split="train"):
        self.img_dir = os.path.join(base_dir, split, "images")
        self.lbl_dir = os.path.join(base_dir, split, "labels")
        self.files   = sorted(f for f in os.listdir(self.img_dir) if f.endswith(".nii"))

    def __len__(self):
        return len(self.files)

    def _process(self, vol_np):
        vol = torch.from_numpy(vol_np.transpose(2,0,1)).float()
        # center‐crop if too large
        for ax, tgt in enumerate(TARGET_SHAPE):
            if vol.size(ax) > tgt:
                start = (vol.size(ax) - tgt)//2
                vol = vol.narrow(ax, start, tgt)
        # pad if too small
        pads = []
        for sz, tgt in zip(vol.shape, TARGET_SHAPE):
            diff = tgt - sz
            b = diff//2 if diff>0 else 0
            a = diff - b if diff>0 else 0
            pads.extend([b, a])
        vol = pad(vol, pads[::-1], mode="constant", value=0.0)
        return vol.unsqueeze(0)

    def __getitem__(self, idx):
        fn   = self.files[idx]
        img  = nib.load(os.path.join(self.img_dir, fn)).get_fdata()
        mask = nib.load(os.path.join(self.lbl_dir, fn)).get_fdata()
        img  = (img - img.min()) / (img.max() - img.min() + 1e-8)
        return self._process(img), self._process(mask)


# ----------------------------
# 2) MODELS (no final Sigmoid)
# ----------------------------
class UNet3D(nn.Module):
    def __init__(self, in_ch=1, out_ch=1):
        super().__init__()
        def C(ic, oc):
            return nn.Sequential(
                nn.Conv3d(ic,  oc, 3, padding=1), nn.BatchNorm3d(oc), nn.ReLU(inplace=True),
                nn.Conv3d(oc, oc, 3, padding=1),   nn.BatchNorm3d(oc), nn.ReLU(inplace=True),
            )
        self.enc1 = C(in_ch,   32)
        self.enc2 = C(32,      64)
        self.enc3 = C(64,     128)
        self.pool = nn.MaxPool3d(2)
        self.up2  = nn.ConvTranspose3d(128, 64, kernel_size=2, stride=2, output_padding=(1,0,0))
        self.dec2 = C(64+64,   64)
        self.up1  = nn.ConvTranspose3d(64,  32, kernel_size=2, stride=2)
        self.dec1 = C(32+32,   32)
        self.final= nn.Conv3d(32, out_ch, 1)  # raw logits

    def forward(self, x):
        c1 = self.enc1(x); p1 = self.pool(c1)
        c2 = self.enc2(p1); p2 = self.pool(c2)
        c3 = self.enc3(p2)

        u2 = self.up2(c3)
        diff = [c2.size(i+2) - u2.size(i+2) for i in range(3)]
        slices = []
        for i, d in enumerate(diff):
            if d>0:
                s = d//2
                slices.append(slice(s, s+u2.size(i+2)))
            else:
                slices.append(slice(None))
        c2c = c2[:,:, slices[0], slices[1], slices[2]]
        d2  = self.dec2(torch.cat([u2, c2c], dim=1))

        u1 = self.up1(d2)
        diff = [c1.size(i+2) - u1.size(i+2) for i in range(3)]
        slices = []
        for i, d in enumerate(diff):
            if d>0:
                s = d//2
                slices.append(slice(s, s+u1.size(i+2)))
            else:
                slices.append(slice(None))
        c1c = c1[:,:, slices[0], slices[1], slices[2]]
        d1  = self.dec1(torch.cat([u1, c1c], dim=1))

        return self.final(d1)  # logits


class Disc3D(nn.Module):
    def __init__(self, in_ch=2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv3d(in_ch,  32, 4, stride=2, padding=1), nn.LeakyReLU(0.2, True),
            nn.Conv3d(32,    64, 4, stride=2, padding=1), nn.BatchNorm3d(64), nn.LeakyReLU(0.2, True),
            nn.Conv3d(64,   128, 4, stride=2, padding=1), nn.BatchNorm3d(128), nn.LeakyReLU(0.2, True),
            nn.Conv3d(128,    1, 4, stride=1, padding=1)   # logits
        )

    def forward(self, x, y):
        return self.net(torch.cat([x, y], dim=1))


# ----------------------------
# 3) SETUP & AUTO-RESUME
# ----------------------------
gen  = UNet3D().to(DEVICE)
disc = Disc3D().to(DEVICE)
adv  = nn.BCEWithLogitsLoss()
seg  = nn.BCEWithLogitsLoss()
optG = optim.Adam(gen.parameters(),  lr=LR, betas=(0.5,0.999))
optD = optim.Adam(disc.parameters(), lr=LR, betas=(0.5,0.999))

start_epoch = 0
ckpts = [int(m.group(1)) for fn in os.listdir(CKPT_DIR)
         if (m := re.match(r"ckpt_(\d+)\.pth", fn))]
if ckpts:
    last = max(ckpts)
    data = torch.load(os.path.join(CKPT_DIR, f"ckpt_{last}.pth"), map_location=DEVICE)
    gen.load_state_dict(data["gen"])
    disc.load_state_dict(data["disc"])
    optG.load_state_dict(data["optG"])
    optD.load_state_dict(data["optD"])
    start_epoch = last
    print(f"→ Resumed from epoch {start_epoch}")

train_ds = NiiDataset(DATA_DIR, "train")
train_ld = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                      num_workers=2, pin_memory=True)


# ----------------------------
# 4) TRAINING LOOP
# ----------------------------
for ep in range(start_epoch, EPOCHS):
    gen.train(); disc.train()
    d_sum = g_sum = 0.0
    bar = tqdm(train_ld, desc=f"Epoch [{ep+1}/{EPOCHS}]")
    for imgs, real in bar:
        imgs, real = imgs.to(DEVICE, non_blocking=True), real.to(DEVICE, non_blocking=True)

        # -- Discriminator step --
        optD.zero_grad()
        with autocast("cuda"):
            pr    = disc(imgs, real)
            fake  = gen(imgs)
            pf    = disc(imgs, fake.detach())
            d_loss= 0.5*(adv(pr, torch.ones_like(pr)) + adv(pf, torch.zeros_like(pf)))
        SCALER.scale(d_loss).backward()
        SCALER.step(optD)

        # -- Generator step --
        optG.zero_grad()
        with autocast("cuda"):
            pf2    = disc(imgs, fake)
            g_loss = adv(pf2, torch.ones_like(pf2)) + seg(fake, real)
        SCALER.scale(g_loss).backward()
        SCALER.step(optG)
        SCALER.update()

        d_sum += d_loss.item()
        g_sum += g_loss.item()
        bar.set_postfix(D=d_sum/len(train_ld), G=g_sum/len(train_ld))

    # checkpoint every N epochs
    if (ep+1) % CHECKPOINT_EVERY == 0 or (ep+1) == EPOCHS:
        path = os.path.join(CKPT_DIR, f"ckpt_{ep+1}.pth")
        torch.save({
            "gen":  gen.state_dict(),
            "disc": disc.state_dict(),
            "optG": optG.state_dict(),
            "optD": optD.state_dict(),
        }, path)
        print(f"→ Saved checkpoint at epoch {ep+1}")


# ----------------------------
# 5) FINAL EVALUATION
# ----------------------------
def evaluate():
    gen.eval()
    ds   = NiiDataset(DATA_DIR, "test")
    ld   = DataLoader(ds, batch_size=1, shuffle=False)
    dices, ious = [], []
    with torch.no_grad():
        for imgs, real in tqdm(ld, desc="Eval"):
            imgs, real = imgs.to(DEVICE), real.to(DEVICE)
            logits = gen(imgs)
            probs  = torch.sigmoid(logits)
            pred   = (probs > 0.5).float()
            inter  = (pred * real).sum()
            union  = pred.sum() + real.sum()
            dices.append(((2*inter)/(union+1e-7)).item())
            ious.append((inter/(union-inter+1e-7)).item())
    print(f"\nFinal → Dice: {np.mean(dices):.4f}±{np.std(dices):.4f},  IoU: {np.mean(ious):.4f}±{np.std(ious):.4f}")

if __name__ == "__main__":
    evaluate()


Epoch [1/50]: 100%|██████████| 29/29 [02:13<00:00,  4.60s/it, D=0.486, G=2.01] 
Epoch [2/50]: 100%|██████████| 29/29 [02:09<00:00,  4.48s/it, D=0.183, G=3.23]  
Epoch [3/50]: 100%|██████████| 29/29 [02:09<00:00,  4.48s/it, D=0.102, G=3.65]  
Epoch [4/50]: 100%|██████████| 29/29 [02:09<00:00,  4.47s/it, D=0.0371, G=4.49] 
Epoch [5/50]: 100%|██████████| 29/29 [02:09<00:00,  4.48s/it, D=0.0638, G=4.91] 
Epoch [6/50]: 100%|██████████| 29/29 [02:09<00:00,  4.47s/it, D=0.233, G=4.74]   
Epoch [7/50]: 100%|██████████| 29/29 [02:09<00:00,  4.47s/it, D=0.116, G=4.32] 
Epoch [8/50]: 100%|██████████| 29/29 [02:10<00:00,  4.49s/it, D=0.0422, G=5.47]  
Epoch [9/50]: 100%|██████████| 29/29 [02:09<00:00,  4.47s/it, D=0.0664, G=5.41] 
Epoch [10/50]: 100%|██████████| 29/29 [02:09<00:00,  4.48s/it, D=0.099, G=4.56]  


→ Saved checkpoint at epoch 10


Epoch [11/50]: 100%|██████████| 29/29 [02:09<00:00,  4.48s/it, D=0.376, G=4.75]   
Epoch [12/50]: 100%|██████████| 29/29 [02:09<00:00,  4.47s/it, D=0.202, G=4.54] 
Epoch [13/50]: 100%|██████████| 29/29 [02:09<00:00,  4.48s/it, D=0.108, G=5.36]   
Epoch [14/50]: 100%|██████████| 29/29 [02:09<00:00,  4.47s/it, D=0.332, G=3.93]  
Epoch [15/50]: 100%|██████████| 29/29 [02:09<00:00,  4.47s/it, D=0.321, G=3.48]  
Epoch [16/50]: 100%|██████████| 29/29 [02:09<00:00,  4.47s/it, D=0.319, G=3.7]  
Epoch [17/50]: 100%|██████████| 29/29 [02:09<00:00,  4.48s/it, D=0.294, G=3.47]  
Epoch [18/50]: 100%|██████████| 29/29 [02:09<00:00,  4.47s/it, D=0.302, G=3.4]  
Epoch [19/50]: 100%|██████████| 29/29 [02:10<00:00,  4.49s/it, D=0.287, G=3.05] 
Epoch [20/50]: 100%|██████████| 29/29 [02:09<00:00,  4.47s/it, D=0.399, G=2.63]  


→ Saved checkpoint at epoch 20


Epoch [21/50]: 100%|██████████| 29/29 [02:09<00:00,  4.47s/it, D=0.408, G=2.61]   
Epoch [22/50]: 100%|██████████| 29/29 [02:09<00:00,  4.46s/it, D=0.419, G=2.57]  
Epoch [23/50]: 100%|██████████| 29/29 [02:09<00:00,  4.47s/it, D=0.297, G=2.67] 
Epoch [24/50]: 100%|██████████| 29/29 [02:09<00:00,  4.47s/it, D=0.263, G=2.68]  
Epoch [25/50]: 100%|██████████| 29/29 [02:10<00:00,  4.50s/it, D=0.247, G=3.53]  
Epoch [26/50]: 100%|██████████| 29/29 [02:09<00:00,  4.46s/it, D=0.386, G=2.86]  
Epoch [27/50]: 100%|██████████| 29/29 [02:09<00:00,  4.47s/it, D=0.263, G=2.96]  
Epoch [28/50]: 100%|██████████| 29/29 [02:09<00:00,  4.47s/it, D=0.346, G=2.93] 
Epoch [29/50]: 100%|██████████| 29/29 [02:09<00:00,  4.47s/it, D=0.318, G=2.95]  
Epoch [30/50]: 100%|██████████| 29/29 [02:09<00:00,  4.47s/it, D=0.296, G=3.03]  


→ Saved checkpoint at epoch 30


Epoch [31/50]: 100%|██████████| 29/29 [02:09<00:00,  4.47s/it, D=0.274, G=3.04] 
Epoch [32/50]: 100%|██████████| 29/29 [02:09<00:00,  4.46s/it, D=0.273, G=3.22]  
Epoch [33/50]: 100%|██████████| 29/29 [02:09<00:00,  4.48s/it, D=0.112, G=3.64]  
Epoch [34/50]: 100%|██████████| 29/29 [02:09<00:00,  4.48s/it, D=0.134, G=3.68]  
Epoch [35/50]: 100%|██████████| 29/29 [02:09<00:00,  4.48s/it, D=0.268, G=3.3]   
Epoch [36/50]: 100%|██████████| 29/29 [02:09<00:00,  4.47s/it, D=0.127, G=3.66]  
Epoch [37/50]: 100%|██████████| 29/29 [02:09<00:00,  4.48s/it, D=0.376, G=3.87]  
Epoch [38/50]: 100%|██████████| 29/29 [02:09<00:00,  4.46s/it, D=0.425, G=3.09]  
Epoch [39/50]: 100%|██████████| 29/29 [02:09<00:00,  4.47s/it, D=0.143, G=3.63]  
Epoch [40/50]: 100%|██████████| 29/29 [02:09<00:00,  4.48s/it, D=0.0979, G=4.07] 


→ Saved checkpoint at epoch 40


Epoch [41/50]: 100%|██████████| 29/29 [02:09<00:00,  4.47s/it, D=0.397, G=3.66]  
Epoch [42/50]: 100%|██████████| 29/29 [02:09<00:00,  4.48s/it, D=0.329, G=2.52]  
Epoch [43/50]: 100%|██████████| 29/29 [02:09<00:00,  4.47s/it, D=0.337, G=2.72]  
Epoch [44/50]: 100%|██████████| 29/29 [02:09<00:00,  4.48s/it, D=0.221, G=3]     
Epoch [45/50]: 100%|██████████| 29/29 [02:09<00:00,  4.47s/it, D=0.218, G=3.37]  
Epoch [46/50]: 100%|██████████| 29/29 [02:09<00:00,  4.46s/it, D=0.209, G=3.4]   
Epoch [47/50]: 100%|██████████| 29/29 [02:09<00:00,  4.47s/it, D=0.249, G=3.18]  
Epoch [48/50]: 100%|██████████| 29/29 [02:09<00:00,  4.47s/it, D=0.27, G=3.17]   
Epoch [49/50]: 100%|██████████| 29/29 [02:09<00:00,  4.46s/it, D=0.267, G=3.86] 
Epoch [50/50]: 100%|██████████| 29/29 [02:09<00:00,  4.47s/it, D=0.239, G=3.34]  


→ Saved checkpoint at epoch 50


Eval: 100%|██████████| 6/6 [00:18<00:00,  3.12s/it]


Final → Dice: 0.0255±0.0260,  IoU: 0.0131±0.0135





In [3]:
!python -c "import monai" || pip install -q "monai-weekly[gdown, nibabel, tqdm, ignite]"
!python -c "import matplotlib" || pip install -q matplotlib
%matplotlib inline

2025-04-29 15:00:09.591744: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-04-29 15:00:09.780743: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-04-29 15:00:09.838993: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [4]:
from monai.utils import first, set_determinism
from monai.transforms import (
    AsDiscrete,
    AsDiscreted,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandCropByPosNegLabeld,
    SaveImaged,
    ScaleIntensityRanged,
    Spacingd,
    Invertd,
)
from monai.handlers.utils import from_engine
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.metrics import DiceMetric
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch
from monai.config import print_config
from monai.apps import download_and_extract
import torch
import matplotlib.pyplot as plt
import tempfile
import shutil
import os
import glob

print_config()

MONAI version: 1.4.1rc1+46.gb58e883c
Numpy version: 1.26.4
Pytorch version: 2.5.1+cu121
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: b58e883c887e0f99d382807550654c44d94f47bd
MONAI __file__: /usr/local/lib/python3.10/dist-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: 0.4.11
ITK version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: 5.3.2
scikit-image version: 0.25.0
scipy version: 1.13.1
Pillow version: 11.0.0
Tensorboard version: 2.17.1
gdown version: 5.2.0
TorchVision version: 0.20.1+cu121
tqdm version: 4.67.1
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 5.9.5
pandas version: 2.2.3
einops version: 0.8.0
transformers version: 4.47.0
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.
clearml version: NOT INSTALLED or UNKNOWN VERSION.

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/insta