# __Import & config__

In [1]:
%load_ext autoreload
%autoreload 2
import os
os.chdir('C:\\Users\\Usuario\\TFG\\digipanca\\')

In [30]:
import os
import numpy as np
from tqdm.auto import tqdm
import torch
from monai.data import DataLoader
import monai.transforms as mt
from monai.data import CacheDataset, load_decathlon_datalist
from monai.networks.nets import UNet
from monai.losses import DiceCELoss
from monai.metrics import DiceMetric
from monai.utils import set_determinism
from scipy import ndimage

# __Configuration__

In [5]:
# Seed the run so results are repeatable.
set_determinism(seed=42)

# Input locations, relative to the current working directory.
data_dir = "data/prepared/"
json_list = "data/splits/datalist.json"

# Segmentation classes: background, pancreas, tumor, artery, vein.
num_classes = 5

# Prefer the GPU when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [18]:
def make_transforms():
    """Build the training preprocessing and augmentation pipeline.

    The pipeline has two stages:

    * Deterministic: load image/label, move channels first, reorient to RAS,
      resample to 1 mm isotropic spacing, clip the image intensities to
      [-87, 199] and rescale them to [0, 1], build a per-class dilated mask
      (``label4crop``) used only to guide cropping, and pad every volume to
      at least 96x96x96.
    * Random: one class-balanced 96x96x96 crop per volume, followed by
      flips, rotation, zoom and intensity perturbations.

    Spatial augmentations resample the label with nearest-neighbour
    interpolation so that it keeps containing only valid integer class ids.

    Reads the module-level ``num_classes``.

    Returns:
        mt.Compose: the deterministic transforms followed by the random ones.
    """
    both = ["image", "label"]
    patch_size = (96, 96, 96)

    def dilated_class_masks(x):
        # One binary mask per class, each grown by 48 dilation iterations and
        # stacked along axis 0 (the channel axis after EnsureChannelFirstd).
        return np.concatenate(tuple([
            ndimage.binary_dilation((x == c).astype(x.dtype), iterations=48).astype(float)
            for c in range(num_classes)
        ]), axis=0)

    deterministic = [
        mt.LoadImaged(keys=both),
        mt.EnsureChannelFirstd(keys=both),
        mt.Orientationd(keys=both, axcodes="RAS"),
        # Bilinear for the image, nearest for the label so class ids are not blended.
        mt.Spacingd(keys=both, pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "nearest"), align_corners=True),
        mt.CastToTyped(keys=["image"], dtype=torch.float32),
        mt.ScaleIntensityRanged(keys=["image"], a_min=-87, a_max=199, b_min=0, b_max=1, clip=True),
        # Compact dtypes while the crop-guidance mask is built; the image is
        # cast back to float32 further down.
        mt.CastToTyped(keys=both, dtype=[np.float16, np.uint8]),
        mt.CopyItemsd(keys=["label"], times=1, names=["label4crop"]),
        mt.Lambdad(keys="label4crop", func=dilated_class_masks, overwrite=True),
        mt.EnsureTyped(keys=["image", "label", "label4crop"]),
        mt.CastToTyped(keys=["image"], dtype=torch.float32),
        mt.SpatialPadd(keys=["image", "label", "label4crop"], spatial_size=patch_size, mode=["reflect", "constant", "constant"])
    ]

    random = [
        mt.RandCropByLabelClassesd(
            keys=both,
            label_key="label4crop",
            spatial_size=patch_size,
            num_classes=num_classes,
            ratios=[1.0] * num_classes,
            num_samples=1
        ),
        mt.Lambdad(keys="label4crop", func=lambda x: 0),  # clean up the label4crop
        mt.RandFlipd(keys=both, spatial_axis=0, prob=0.5),
        mt.RandFlipd(keys=both, spatial_axis=1, prob=0.5),
        mt.RandFlipd(keys=both, spatial_axis=2, prob=0.5),
        # Per-key interpolation: the label must use "nearest", otherwise the
        # resampled label map contains fractional / mixed class values.
        mt.RandRotated(
            keys=both,
            range_x=0.2,
            range_y=0.2,
            range_z=0.2,
            prob=0.2,
            mode=("bilinear", "nearest")
        ),
        mt.RandZoomd(
            keys=both,
            min_zoom=0.9,
            max_zoom=1.1,
            prob=0.2,
            mode=("area", "nearest")
        ),
        mt.RandGaussianSmoothd(keys="image", prob=0.1),
        mt.RandGaussianNoised(keys="image", prob=0.1),
        mt.RandShiftIntensityd(keys="image", offsets=0.1, prob=0.5),
        mt.RandScaleIntensityd(keys="image", factors=0.1, prob=0.5),
        mt.ToTensord(keys=both)
    ]

    return mt.Compose(deterministic + random)

## Datalist

In [7]:
# Load list
# Read the "training" entries of the JSON split file; base_dir is the
# folder the listed image/label paths are resolved against.
datalist = load_decathlon_datalist(
    json_list,
    is_segmentation=True,
    data_list_key="training",
    base_dir=data_dir,
)
print(f"Number of training images: {len(datalist)}")
print(datalist[0])

Number of training images: 70
{'label': 'data\\prepared\\labelsTr\\rtum001.nii.gz', 'image': 'data\\prepared\\imagesTr\\rtum001.nii.gz'}


## Dataset

In [20]:
# Training dataset over the loaded datalist, using the pipeline from
# make_transforms(). Only a fraction of the items (cache_rate) is cached.
train_ds = CacheDataset(
    datalist,
    transform=make_transforms(),
    num_workers=4,
    cache_rate=0.1,
)

Loading dataset: 100%|██████████| 7/7 [00:26<00:00,  3.82s/it]


## Dataloader

In [34]:
# Shuffled mini-batches of two items, loaded in the main process.
train_loader = DataLoader(
    train_ds,
    num_workers=0,
    shuffle=True,
    batch_size=2,
)

## Model

In [23]:
# 3D residual U-Net: one input channel, one output channel per class.
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=num_classes,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm="batch",
)
# Move the parameters to the selected device (returns the same module).
model = model.to(device)

## Optimization

In [24]:
# Combined Dice + cross-entropy loss on softmax outputs; integer labels are
# one-hot encoded by the loss and the background channel is excluded.
loss_fn = DiceCELoss(
    include_background=False,
    to_onehot_y=True,
    softmax=True,
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# Foreground-only Dice score, averaged.
dice_metric = DiceMetric(
    include_background=False,
    reduction="mean",
    get_not_nans=False,
)

## Training loop

In [None]:
max_epochs = 100
for epoch in range(max_epochs):
    model.train()
    epoch_loss = 0
    # Per-epoch progress bar that is cleared once the epoch finishes.
    loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{max_epochs}", leave=False)

    for batch in loop:
        images = batch["image"].to(device)
        labels = batch["label"].to(device)

        # Standard optimisation step: reset grads, forward, backward, update.
        optimizer.zero_grad()
        outputs = model(images)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        # Read the scalar once and reuse it for the running sum and the bar.
        batch_loss = loss.item()
        epoch_loss += batch_loss
        loop.set_postfix(loss=batch_loss)

    # Mean loss over the batches of this epoch.
    print(f"✅ Epoch {epoch+1} - Loss promedio: {epoch_loss / len(train_loader):.4f}")

Epoch 1/100:   0%|          | 0/35 [00:00<?, ?it/s]

{'label': metatensor([[[[[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           ...,
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.]],

          [[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           ...,
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.]],

          [[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           ...,
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.]],

          ...,

          [[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           ..

Epoch 2/100:   0%|          | 0/35 [00:00<?, ?it/s]

{'label': metatensor([[[[[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           ...,
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.]],

          [[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           ...,
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.]],

          [[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           ...,
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.]],

          ...,

          [[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           ..

KeyboardInterrupt: 