In [3]:
# Keep imported project modules (e.g. tumor_dataset) up to date every time you hit Shift-Enter:
# `%autoreload 2` reloads all modules before each cell is executed.
# Re-running this cell once the extension is loaded only prints an
# "already loaded" notice; it is harmless and can be ignored.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [4]:
from datetime import datetime
import logging
import sys
import os
import numpy as np
from dotenv import load_dotenv
import torch
from torch.utils.tensorboard import SummaryWriter

import monai
from monai.data import list_data_collate, decollate_batch, DataLoader
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric
from monai.transforms import (
    Activations,
    RandAffined,
    AsDiscrete,
    Compose,
    Rand2DElasticd,
    RandShiftIntensityd,
)
from monai.visualize import plot_2d_or_3d_image
import wandb
from tumor_dataset import create_tumor_dataset



def _build_train_transforms(patch_size):
    """Return the random augmentations applied to each training ``img``/``seg`` sample.

    The two geometric transforms resample ``img`` bilinearly and ``seg`` with
    nearest-neighbour interpolation (so mask values are never blended). Both are
    given ``spatial_size=patch_size``. The intensity shift is applied to ``img`` only.

    NOTE(review): no ``prob`` is passed, so each transform fires with MONAI's
    default probability (0.1 per the MONAI docs). Confirm this low augmentation
    rate is intended.
    """
    return [
        RandAffined(
            keys=["img", "seg"],
            mode=["bilinear", "nearest"],
            spatial_size=patch_size,
            rotate_range=(np.pi / 20, np.pi / 20),
            scale_range=(0.1, 0.1),
            # One entry per spatial axis; MONAI samples a scalar x from [-x, x],
            # i.e. up to ±10 px of shift along each axis.
            translate_range=(10, 10),
            padding_mode="border",
        ),
        Rand2DElasticd(
            keys=["img", "seg"],
            spacing=(24, 24),
            magnitude_range=(1, 10),
            mode=["bilinear", "nearest"],
            spatial_size=patch_size,
            rotate_range=(0.1, 0.1),
            scale_range=(0.1, 0.1),
            translate_range=(10, 10),  # ±10 px per axis, same convention as above
        ),
        RandShiftIntensityd(
            keys=["img"],
            offsets=0.2,
            safe=True,
            channel_wise=False,
        ),
    ]


def _validate(model, val_loader, device, roi_size, sw_batch_size, post_trans, dice_metric):
    """Evaluate ``model`` on every batch of ``val_loader`` with sliding-window inference.

    Returns:
        ``(mean_dice, val_imgs, val_segs, preds)``.
        ``mean_dice`` is ``dice_metric`` aggregated over the whole loader.
        The other three are the inputs, labels and post-processed per-sample
        predictions of the *last* batch, which are used for example-image logging.
        ``dice_metric`` is reset before returning so it can be reused next time.
    """
    model.eval()
    val_imgs = val_segs = preds = None
    with torch.no_grad():
        for val_batch in val_loader:
            val_imgs = val_batch["img"].to(device)
            val_segs = val_batch["seg"].to(device)
            sw_out = sliding_window_inference(val_imgs, roi_size, sw_batch_size, model)
            # Apply post_trans to each sample separately (decollate splits the batch).
            preds = [post_trans(x) for x in decollate_batch(sw_out)]
            dice_metric(y_pred=preds, y=val_segs)
        mean_dice = dice_metric.aggregate().item()
        dice_metric.reset()
    return mean_dice, val_imgs, val_segs, preds


def _log_validation_examples(val_imgs, val_segs, preds, epoch_num):
    """Log channel 0 of the first input / prediction / label of a val batch to W&B."""
    img_np = val_imgs[0, 0].cpu().float().numpy()
    pred_np = preds[0][0].cpu().float().numpy()
    label_np = val_segs[0, 0].cpu().float().numpy()
    wandb.log(
        {
            "example_input": wandb.Image(img_np, caption="input"),
            "example_pred": wandb.Image(pred_np, caption="prediction"),
            "example_label": wandb.Image(label_np, caption="ground truth"),
            "epoch": epoch_num,
        }
    )


def main(
    dataset_dir,
    num_epochs=50,
    batch_size=2,
    lr=1e-4,
    patch_size=(128, 128),
    val_interval=2,
    seed=None,
):
    """Train a 2D MONAI UNet for tumor segmentation, logging to W&B and TensorBoard.

    Args:
        dataset_dir: Root directory passed to ``create_tumor_dataset``.
        num_epochs: Number of passes over the training set.
        batch_size: Training batch size (validation always uses batch size 1).
        lr: Adam learning rate.
        patch_size: ``spatial_size`` of the training augmentations and the
            sliding-window ROI used during validation.
        val_interval: Validate (and possibly checkpoint) every ``val_interval`` epochs.
        seed: If given, passed to ``monai.utils.set_determinism`` to seed the
            RNGs. ``None`` (default) leaves seeding untouched.

    Side effects:
        - Writes ``models/best_model_<dice>.pth`` each time validation Dice
          improves, and uploads it as a W&B ``model`` artifact.
        - Writes TensorBoard events to the default ``SummaryWriter()`` log
          directory.

    Raises:
        ValueError: if ``val_interval < 1`` or either dataset split is empty.
    """
    if val_interval < 1:
        raise ValueError(f"val_interval must be >= 1, got {val_interval}")

    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    monai.config.print_config()

    if seed is not None:
        monai.utils.set_determinism(seed=seed)

    SW_BATCH_SIZE = 4  # number of sliding-window patches run through the model at once

    # W&B run initialization
    run = wandb.init(
        project="tumor-segmentation",
        entity="nm-i-ki",
        name=f"unet_{datetime.now():%Y%m%d_%H%M%S}",
        config=dict(
            num_epochs=num_epochs,
            batch_size=batch_size,
            learning_rate=lr,
            patch_size=patch_size,
            val_interval=val_interval,
            seed=seed,
            architecture="UNet",
            loss="Dice(sigmoid=True)",
            optimizer="Adam",
        ),
        save_code=True,
        # NOTE(review): with sync_tensorboard=True, the SummaryWriter scalars below
        # ("train_loss", "val_mean_dice") are mirrored into W&B under the same keys
        # that wandb.log() writes directly. Confirm the duplication is intended.
        sync_tensorboard=True,
        tags=["monai", "segmentation"],
    )

    tensorboard_writer = None
    # Everything after wandb.init runs under try/finally so the W&B run is finished
    # and the TensorBoard writer is closed even if training raises.
    try:
        train_transforms = _build_train_transforms(patch_size)

        # Create datasets
        train_ds, val_ds = create_tumor_dataset(
            dataset_dir=dataset_dir, train_data_augmentation=train_transforms
        )
        print(f"Amount of images train: {len(train_ds)} val: {len(val_ds)}")
        if len(train_ds) == 0 or len(val_ds) == 0:
            # Empty splits would otherwise fail later with a NameError / empty aggregate.
            raise ValueError(
                f"Both dataset splits must be non-empty (train={len(train_ds)}, val={len(val_ds)})"
            )

        train_loader = DataLoader(
            train_ds,
            batch_size=batch_size,
            shuffle=True,
            collate_fn=list_data_collate,
            pin_memory=torch.cuda.is_available(),
        )
        val_loader = DataLoader(val_ds, batch_size=1, collate_fn=list_data_collate)

        # Model, loss, optimizer
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = monai.networks.nets.UNet(
            spatial_dims=2,
            in_channels=4,
            out_channels=4,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
        ).to(device)
        loss_function = monai.losses.DiceLoss(sigmoid=True)
        optimizer = torch.optim.Adam(model.parameters(), lr)

        # Attach gradients & parameters to W&B
        wandb.watch(model, log="all", log_freq=10)

        print("Starting training...")
        dice_metric = DiceMetric(include_background=True, reduction="mean")
        post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
        tensorboard_writer = SummaryWriter()
        best_metric = -1.0
        best_epoch = -1

        os.makedirs("models", exist_ok=True)
        global_step = 0

        for epoch in range(num_epochs):
            epoch_num = epoch + 1
            print(f"\nEpoch {epoch_num}/{num_epochs}")
            model.train()
            epoch_loss = 0.0
            num_steps = 0
            for batch in train_loader:
                imgs, segs = batch["img"].to(device), batch["seg"].to(device)
                optimizer.zero_grad()
                outputs = model(imgs)
                loss = loss_function(outputs, segs)
                loss.backward()
                optimizer.step()

                # .item() synchronises with the device; call it once and reuse the value.
                loss_value = loss.item()
                epoch_loss += loss_value
                num_steps += 1
                global_step += 1

                # Log per step
                wandb.log({"train_loss": loss_value, "epoch": epoch_num, "step": global_step})
                tensorboard_writer.add_scalar("train_loss", loss_value, global_step)

            avg_epoch_loss = epoch_loss / max(num_steps, 1)
            print(f"  Train avg loss: {avg_epoch_loss:.4f}")
            wandb.log({"train_avg_loss": avg_epoch_loss, "epoch": epoch_num})

            if epoch_num % val_interval != 0:
                continue

            metric, val_imgs, val_segs, preds = _validate(
                model, val_loader, device, patch_size, SW_BATCH_SIZE, post_trans, dice_metric
            )
            tensorboard_writer.add_scalar("val_mean_dice", metric, epoch_num)
            print(f"  Val mean Dice: {metric:.4f}")

            # Log validation metric & sample images
            wandb.log({"val_mean_dice": metric, "epoch": epoch_num})
            _log_validation_examples(val_imgs, val_segs, preds, epoch_num)

            if metric > best_metric:
                best_metric = metric
                best_epoch = epoch_num
                best_model_path = f"models/best_model_{metric:.4f}.pth"
                torch.save(model.state_dict(), best_model_path)
                print(
                    f"  Best model saved with Dice {best_metric:.4f} at epoch {best_epoch}"
                )

                # Save model to W&B as an artifact
                artifact = wandb.Artifact("best_model", type="model")
                artifact.add_file(best_model_path)
                run.log_artifact(artifact)

            # Also log example images to TensorBoard
            plot_2d_or_3d_image(val_imgs, epoch_num, tensorboard_writer, index=0, tag="image")
            plot_2d_or_3d_image(val_segs, epoch_num, tensorboard_writer, index=0, tag="label")
            plot_2d_or_3d_image(preds, epoch_num, tensorboard_writer, index=0, tag="output")

        print(
            f"\nTraining done! Best Dice {best_metric:.4f} reached on epoch {best_epoch}"
        )
    finally:
        if tensorboard_writer is not None:
            tensorboard_writer.close()
        run.finish()


if __name__ == "__main__":
    load_dotenv()  # Load environment variables (e.g. WANDB_API_KEY) from a .env file
    # wandb reads WANDB_API_KEY from the environment itself; this check only makes
    # a missing key fail fast with a clear message.
    WANDB_API_KEY = os.getenv("WANDB_API_KEY")
    if not WANDB_API_KEY:
        raise ValueError("WANDB_API_KEY not found in environment variables. Please set it in .env file.")
    # Never echo any part of the key: notebook outputs are saved and shared.
    print("WANDB_API_KEY found in environment.")
    # Dataset location can be overridden without editing the notebook.
    main(os.getenv("TUMOR_DATASET_DIR", "../data/raw/tumor-segmentation"))

WANDB_API_KEY: <redacted>
MONAI version: 1.4.0
Numpy version: 1.26.4
Pytorch version: 2.7.1+cu126
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 46a5272196a6c2590ca2589029eed8e4d56ff008
MONAI __file__: /media/<username>/AI-Mesterskap/norwegian-ai-championship-2025/segmentation/.venv/lib/python3.11/site-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: 0.4.11
ITK version: 5.4.4
Nibabel version: 5.3.2
scikit-image version: 0.25.2
scipy version: 1.16.1
Pillow version: 11.3.0
Tensorboard version: 2.20.0
gdown version: 5.2.0
TorchVision version: 0.22.1+cu126
tqdm version: 4.67.1
lmdb version: 1.7.3
psutil version: 7.0.0
pandas version: 2.3.1
einops version: 0.8.1
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: 3.1.4
pynrrd version: 1.1.3
clearml version: NOT INSTALLED or UNKNOWN VERSION.

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/installation



Found 426 control images.
Found 182 patient images.
Randomly selected 182 control samples from 426 available.
Final dataset: 182 patients + 182 controls = 364 samples (50/50 split)
Amount of images train: 328 val: 36
Starting training...

Epoch 1/50


Num foregrounds 195600, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 176800, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 353200, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 135600, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 185200, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 371200, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 170400, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 189600, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.
Num foregrounds 168400, Num backgrounds 0, unable to generate class balanced samples, setting `pos_ratio` to 1.

  Train avg loss: 0.7934

Epoch 2/50
  Train avg loss: 0.7793
  Val mean Dice: 0.2958
  Best model saved with Dice 0.2958 at epoch 2

Epoch 3/50
  Train avg loss: 0.7744

Epoch 4/50
  Train avg loss: 0.7668
  Val mean Dice: 0.2931

Epoch 5/50
  Train avg loss: 0.7614

Epoch 6/50
  Train avg loss: 0.7565
  Val mean Dice: 0.2964
  Best model saved with Dice 0.2964 at epoch 6

Epoch 7/50
  Train avg loss: 0.7537

Epoch 8/50
  Train avg loss: 0.7509
  Val mean Dice: 0.3031
  Best model saved with Dice 0.3031 at epoch 8

Epoch 9/50
  Train avg loss: 0.7487

Epoch 10/50
  Train avg loss: 0.7467
  Val mean Dice: 0.3071
  Best model saved with Dice 0.3071 at epoch 10

Epoch 11/50
  Train avg loss: 0.7440

Epoch 12/50
  Train avg loss: 0.7424
  Val mean Dice: 0.3226
  Best model saved with Dice 0.3226 at epoch 12

Epoch 13/50
  Train avg loss: 0.7410

Epoch 14/50
  Train avg loss: 0.7378
  Val mean Dice: 0.3404
  Best model saved with Dice 0.3404 at epoch 14

Epoch 15/50
  Train avg loss: 0.740

[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


  Val mean Dice: 0.5321

Training done! Best Dice 0.5516 reached on epoch 42


0,1
epoch,▁▁▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇██████
global_step,▁▁▁▂▂▂▂▂▂▁▂▃▃▃▃▄▄▄▁▄▅▅▅▅▅▅▁▆▆▆▆▇▇▇▇▇████
step,▁▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▃▃▄▄▄▄▄▅▅▆▆▆▆▇▇▇▇▇▇████
train_avg_loss,█▇▇▆▆▅▅▅▄▄▄▄▄▃▃▃▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁
train_loss,██▇▇▆▇▅▇▇▇▇▅▇▇▇▆▆▆▇▆▁▆▆▆▆▆▆▅▆▆▂▁▄▆▆▆▄▄▆▆
val_mean_dice,▁▁▁▁▁▁▁▁▁▂▂▂▃▃▄▄▄▄▆▆▆▅▅▆▆▇▇▆▆▇▇▇██▇▇▇▆▆▇

0,1
epoch,50.0
global_step,50.0
step,8200.0
train_avg_loss,0.70217
train_loss,0.7501
val_mean_dice,0.53207
