In [1]:
# Keep modules up to date every time you hit Shift-Enter 
%load_ext autoreload
%autoreload 2

In [None]:
import logging
import os
import sys

import torch
from torch.utils.tensorboard import SummaryWriter

import monai
from monai.data import list_data_collate, decollate_batch, DataLoader
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric
from monai.transforms import (
    Activations,
    AsDiscrete,
    Compose,
)
from monai.visualize import plot_2d_or_3d_image

from tumor_dataset import create_tumor_dataset

def _train_one_epoch(model, loader, loss_function, optimizer, device, writer, epoch):
    """Run one training epoch over ``loader`` and return the mean batch loss.

    Each batch loss is logged to TensorBoard as ``train_loss`` with global
    step ``epoch * len(loader) + step`` (``step`` is 1-based), matching the
    original inline loop. The caller must ensure ``loader`` is non-empty.
    """
    model.train()
    epoch_loss = 0.0
    for step, batch in enumerate(loader, 1):
        imgs, segs = batch["img"].to(device), batch["seg"].to(device)
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = loss_function(outputs, segs)
        loss.backward()
        optimizer.step()
        loss_value = loss.item()  # single host sync per step
        epoch_loss += loss_value
        writer.add_scalar("train_loss", loss_value, epoch * len(loader) + step)
    return epoch_loss / len(loader)


def _validate(model, loader, dice_metric, post_trans, device,
              roi_size=(96, 96), sw_batch_size=4):
    """Compute the aggregated Dice metric over ``loader``.

    Uses sliding-window inference and applies ``post_trans`` to each
    decollated prediction. Returns ``(metric, imgs, segs, preds)``, where the
    last three come from the final batch so the caller can log example images.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    if len(loader) == 0:
        raise ValueError("Validation dataset is empty; cannot compute Dice")
    model.eval()
    with torch.no_grad():
        for val_batch in loader:
            val_imgs, val_segs = val_batch["img"].to(device), val_batch["seg"].to(device)
            sw_out = sliding_window_inference(val_imgs, roi_size, sw_batch_size, model)
            preds = [post_trans(x) for x in decollate_batch(sw_out)]
            dice_metric(y_pred=preds, y=val_segs)
        metric = dice_metric.aggregate().item()
        dice_metric.reset()  # clear the accumulated buffer for the next validation
    return metric, val_imgs, val_segs, preds


def main(dataset_dir, num_epochs=10, batch_size=2, val_interval=2, model_dir="models"):
    """Train a 2D MONAI UNet on the tumor-segmentation dataset.

    Builds train/val loaders from ``create_tumor_dataset``, trains with
    ``DiceLoss(sigmoid=True)`` and Adam (lr 1e-3), and evaluates mean Dice
    every ``val_interval`` epochs. Whenever the metric improves, the weights
    are saved to ``model_dir``. Losses, metrics and example images are logged
    to TensorBoard.

    Args:
        dataset_dir: Root directory passed to ``create_tumor_dataset``.
        num_epochs: Number of training epochs (default 10).
        batch_size: Training batch size (default 2).
        val_interval: Validate every this many epochs (default 2).
        model_dir: Directory for best-model checkpoints; created if missing
            (default ``"models"``).

    Raises:
        ValueError: if ``val_interval`` < 1 or the training set is empty.
    """
    if val_interval < 1:
        raise ValueError(f"val_interval must be >= 1, got {val_interval}")

    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    monai.config.print_config()

    # Create datasets
    train_ds, val_ds = create_tumor_dataset(dataset_dir=dataset_dir)

    print(f"Amount of images train: {len(train_ds)} val: {len(val_ds)}")
    if len(train_ds) == 0:
        raise ValueError(f"No training images found in {dataset_dir!r}")
    # NOTE(review): the captured training log reports "Num backgrounds 0" for
    # every crop, i.e. the class-balanced cropper never sees background in the
    # "seg" label. The labels may not be binary/one-hot as DiceLoss(sigmoid=True)
    # expects -- verify the label transforms in create_tumor_dataset.
    train_loader = DataLoader(
        train_ds, batch_size=batch_size, shuffle=True,
        num_workers=4, collate_fn=list_data_collate,
        pin_memory=torch.cuda.is_available(),
    )
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate)

    # Model, loss, optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # in/out channels of 4 presumably match the dataset's img/seg channel
    # counts -- TODO confirm against create_tumor_dataset.
    model = monai.networks.nets.UNet(
        spatial_dims=2, in_channels=4, out_channels=4,
        channels=(16, 32, 64, 128, 256), strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss_function = monai.losses.DiceLoss(sigmoid=True)
    optimizer = torch.optim.Adam(model.parameters(), 1e-3)

    # Train model
    print("Starting training...")
    dice_metric = DiceMetric(include_background=True, reduction="mean")
    post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
    os.makedirs(model_dir, exist_ok=True)  # torch.save does not create parent dirs
    writer = SummaryWriter()
    best_metric = -1.0
    best_epoch = -1

    try:
        for epoch in range(num_epochs):
            print(f"\nEpoch {epoch + 1}/{num_epochs}")
            avg_loss = _train_one_epoch(
                model, train_loader, loss_function, optimizer, device, writer, epoch
            )
            print(f"  Train avg loss: {avg_loss:.4f}")

            if (epoch + 1) % val_interval != 0:
                continue

            metric, val_imgs, val_segs, preds = _validate(
                model, val_loader, dice_metric, post_trans, device
            )
            writer.add_scalar("val_mean_dice", metric, epoch + 1)
            print(f"  Val mean Dice: {metric:.4f}")

            if metric > best_metric:
                best_metric = metric
                best_epoch = epoch + 1
                torch.save(
                    model.state_dict(),
                    os.path.join(model_dir, f"best_model_{metric}.pth"),
                )
                print(f"  Best model saved with Dice {best_metric:.4f} at epoch {best_epoch}")
            # Log images from the last validation batch
            plot_2d_or_3d_image(val_imgs, epoch + 1, writer, index=0, tag="image")
            plot_2d_or_3d_image(val_segs, epoch + 1, writer, index=0, tag="label")
            plot_2d_or_3d_image(preds, epoch + 1, writer, index=0, tag="output")

        print(f"\nTraining done! Beste Dice {best_metric:.4f} reached on epoch {best_epoch}")
    finally:
        # Flush and release the TensorBoard writer even if training fails.
        writer.close()


if __name__ == "__main__":
    # Entry point: train on the raw tumor-segmentation data. The path is
    # relative to the current working directory (the notebook's folder).
    main(dataset_dir="../data/raw/tumor-segmentation")

MONAI version: 1.4.0
Numpy version: 1.26.4
Pytorch version: 2.7.1+cu126
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 46a5272196a6c2590ca2589029eed8e4d56ff008
MONAI __file__: /home/<username>/Documents/cogito/norwegian-ai-championship-2025/segmentation/.venv/lib/python3.12/site-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: 0.4.11
ITK version: 5.4.4
Nibabel version: 5.3.2
scikit-image version: 0.25.2
scipy version: 1.16.1
Pillow version: 11.3.0
Tensorboard version: 2.20.0
gdown version: 5.2.0
TorchVision version: 0.22.1+cu126
tqdm version: 4.67.1
lmdb version: 1.7.3
psutil version: 7.0.0
pandas version: 2.3.1
einops version: 0.8.1
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: 3.1.4
pynrrd version: 1.1.3
clearml version: NOT INSTALLED or UNKNOWN VERSION.

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/installation.html#installing-the-

Num foregrounds 194400, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 195600, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 170400, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 220800, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 195600, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 170400, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 195600, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 170400, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 170800, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio

  Train avg loss: 0.7768

Epoch 2/10


Num foregrounds 195600, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 195600, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 195600, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 391600, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 170400, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 170400, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 182800, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 170400, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 346000, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio

  Train avg loss: 0.7521
  Val mean Dice: 0.2800
  Best model saved with Dice 0.2800 at epoch 2

Epoch 3/10


Num foregrounds 170400, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 195600, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 195600, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 170400, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 170400, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 389200, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 170400, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 195600, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 195600, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio

  Train avg loss: 0.7302

Epoch 4/10


Num foregrounds 170400, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 170400, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 195600, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 187600, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 189600, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 192000, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 170400, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 195600, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 194400, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio

  Train avg loss: 0.7271
  Val mean Dice: 0.2957
  Best model saved with Dice 0.2957 at epoch 4

Epoch 5/10


Num foregrounds 170400, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 195600, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 195600, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 170400, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 191200, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 195600, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 182800, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 195600, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 389200, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio

  Train avg loss: 0.7145

Epoch 6/10


Num foregrounds 195600, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 195600, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 191200, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 346000, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 170400, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 170400, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 195600, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 195600, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 170400, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio

  Train avg loss: 0.7153
  Val mean Dice: 0.3853
  Best model saved with Dice 0.3853 at epoch 6

Epoch 7/10


Num foregrounds 194400, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 220800, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 195600, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 195600, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 205200, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 170400, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 195600, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 170400, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 170400, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio

  Train avg loss: 0.7119

Epoch 8/10


Num foregrounds 194400, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 195600, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 195600, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 195600, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 170400, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 170800, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 170400, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 170400, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 195600, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio

  Train avg loss: 0.6967
  Val mean Dice: 0.4539
  Best model saved with Dice 0.4539 at epoch 8

Epoch 9/10


Num foregrounds 195600, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 170400, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 391600, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 350400, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 195600, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 195600, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 170800, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 389200, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 170400, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio

  Train avg loss: 0.6901

Epoch 10/10


Num foregrounds 195600, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 195600, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 195600, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 391600, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 170400, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 170400, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 170400, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 220800, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 195600, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio

  Train avg loss: 0.6784
  Val mean Dice: 0.4221

Training done! Beste Dice 0.4539 reached on epoch 8
