In [5]:
import lightning as L
import numpy as np
import torchio as tio

import wandb
from lightning.pytorch import seed_everything
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import CSVLogger, WandbLogger
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from pathlib import Path

import wandb
from dataloader import ImageMaskDataset
from model.network import Neuratt

# Evaluation configuration: which trained checkpoint to score, and where
# Lightning may write run artifacts.
checkpoint_path = "/data/smartspim_brain_seg_models/whole_brain_seg/whole_brain_seg/cfelpja3/checkpoints/best_model.ckpt"
results_folder = "/results"

# Held-out split; the loading cell reads its `images` and `masks` subfolders.
validation_path = Path("/scratch/dataset_patch_128_steps_64/test")

# DataLoader settings used for inference.
num_workers = 8
batch_size = 8

In [6]:
# Validation data: images and masks live in sibling subfolders of
# `validation_path`. shuffle=False keeps the batch order fixed, so batch
# indices printed later refer to the same samples on every run.
val_dataset = ImageMaskDataset(
    validation_path.joinpath('images'),
    validation_path.joinpath('masks'),
)
val_dataloader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
)

# Build the model exactly once: restore trained weights when a checkpoint is
# configured; otherwise fall back to a freshly initialised (untrained) network.
if checkpoint_path:
    print(f"Loading path from {checkpoint_path}")
    segmentation_model = Neuratt.load_from_checkpoint(checkpoint_path)
else:
    # Evaluating random weights is almost certainly unintended, so say so.
    print("No checkpoint_path set; using an untrained Neuratt model")
    segmentation_model = Neuratt()

Loading path from /data/smartspim_brain_seg_models/whole_brain_seg/whole_brain_seg/cfelpja3/checkpoints/best_model.ckpt


In [7]:
# Inference-only Trainer. Options that only matter for fitting (max_epochs,
# max_time, log_every_n_steps, callbacks, logger) are omitted because
# `trainer.predict` does not use them.
predict_batch_limit = 200  # max batches to evaluate; None evaluates the whole dataloader

trainer = L.Trainer(
    default_root_dir=results_folder,
    devices=1,
    accelerator="gpu",
    limit_predict_batches=predict_batch_limit,
)

# Make subset evaluation explicit: without this, the summary statistics built
# from `metric_list` would look like whole-split numbers even when truncated.
n_val_batches = len(val_dataloader)
if predict_batch_limit is not None and n_val_batches > predict_batch_limit:
    print(
        f"Warning: limit_predict_batches={predict_batch_limit} evaluates only "
        f"{predict_batch_limit} of {n_val_batches} validation batches."
    )

predictions = trainer.predict(segmentation_model, val_dataloader)
print("Number of predictions: ", len(predictions))

# NOTE(review): each prediction is unpacked as (data, pred, metrics), i.e. this
# assumes Neuratt.predict_step returns that 3-tuple per batch. Only the metrics
# dict is collected here; `predictions` keeps data/pred for any later use.
metric_list = []
for i, (data, pred, metrics) in enumerate(predictions):
    print(f"[{i}] Metrics: {metrics}")
    metric_list.append(metrics)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting DataLoader 0: 100%|██████████| 139/139 [01:13<00:00,  1.89it/s]
Number of predictions:  139
[0] Metrics: {'loss': 0.11746280640363693, 'dice': 0.9777041673660278, 'jacc': 0.9563809633255005}
[1] Metrics: {'loss': 0.19352303445339203, 'dice': 0.9730879664421082, 'jacc': 0.9475864171981812}
[2] Metrics: {'loss': 0.10105496644973755, 'dice': 0.9881359934806824, 'jacc': 0.9765502214431763}
[3] Metrics: {'loss': 0.22187963128089905, 'dice': 0.9825949668884277, 'jacc': 0.9657854437828064}
[4] Metrics: {'loss': 0.23958858847618103, 'dice': 0.9719198346138, 'jacc': 0.94537353515625}
[5] Metrics: {'loss': 0.5031952261924744, 'dice': 0.9520814418792725, 'jacc': 0.9085453152656555}
[6] Metrics: {'loss': 16.927101135253906, 'dice': 0.9550781846046448, 'jacc': 0.9140188097953796}
[7] Metrics: {'loss': 0.5022412538528442, 'dice': 0.949425220489502, 'jacc': 0.9037197232246399}
[8] Metrics: {'loss': 34.10587692260742, 'dice': 0.9528785347938538, 'jacc': 0.9099980592727661}
[9] Metrics: {'lo

In [8]:
from statistics import mean, median

# Metrics to summarise; each must be a key in every dict of `metric_list`.
keys = ["loss", "dice", "jacc"]


def summarize_metrics(entries, metric_names):
    """Compute min / max / median / mean of each named metric across entries.

    Each entry corresponds to one prediction batch (presumably already averaged
    over that batch), so ``mean`` is a mean of per-batch values: a smaller final
    batch is weighted the same as a full one. With heavy outliers (e.g. a few
    very large losses) the median is the more robust figure.

    Args:
        entries: sequence of dicts mapping metric name -> number.
        metric_names: metric keys to summarise.

    Returns:
        dict mapping each metric name to {"min", "max", "median", "mean"}.

    Raises:
        ValueError: if ``entries`` is empty (no predictions were collected).
        KeyError: if an entry lacks one of ``metric_names``.
    """
    if not entries:
        raise ValueError("metric_list is empty: no predictions to summarise")

    summary = {}
    for name in metric_names:
        values = [entry[name] for entry in entries]
        summary[name] = {
            "min": min(values),
            "max": max(values),
            "median": median(values),
            "mean": mean(values),
        }
    return summary


stats = summarize_metrics(metric_list, keys)

# Print the results
for key, stat in stats.items():
    print(f"{key}:")
    for stat_name, value in stat.items():
        print(f"  {stat_name}: {value:.4f}")

loss:
  min: 0.0796
  max: 147.8349
  median: 0.3755
  mean: 6.9434
dice:
  min: 0.1817
  max: 0.9910
  median: 0.9308
  mean: 0.8774
jacc:
  min: 0.0999
  max: 0.9822
  median: 0.8705
  mean: 0.8044
