In [1]:
import lightning as L
import numpy as np
import torchio as tio

import wandb
from lightning.pytorch import seed_everything
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import CSVLogger, WandbLogger
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from pathlib import Path

import wandb
from dataloader import ImageMaskDataset
from model.network import Neuratt

# --- Evaluation configuration -------------------------------------------------
# Output root; passed to the Lightning Trainer as `default_root_dir`.
results_folder = "/results"

# Lightning checkpoint restored before prediction. Set to a falsy value to skip
# loading (the model is then used with freshly constructed weights).
checkpoint_path = "/results/whole_brain_seg/whole_brain_seg/gl0un99j/checkpoints/best_model.ckpt"
# Other checkpoints the author left disabled (kept verbatim for reference):
#   b1s704bd/checkpoints/best_model.ckpt
#   /data/smartspim_brain_seg_models/whole_brain_seg/whole_brain_seg/cfelpja3/checkpoints/best_model.ckpt

# DataLoader settings shared by the evaluation loader.
num_workers = 8
batch_size = 8

# Root of the split being evaluated; its 'images' and 'masks' sub-folders are
# handed to ImageMaskDataset.
validation_path = Path("/scratch/dataset_patch_128_steps_64/test")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Evaluation dataset: paired image / mask folders under `validation_path`.
val_dataset = ImageMaskDataset(
    validation_path / "images",
    validation_path / "masks",
)

# shuffle=False keeps the batch order stable, so the batch indices printed
# during prediction refer to the same data on every run.
val_dataloader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
)

# Build exactly one model: restore it from the checkpoint when one is
# configured, otherwise fall back to a freshly initialised network. The
# original always constructed a throwaway Neuratt() first and then discarded
# it when a checkpoint was loaded.
if checkpoint_path:
    print(f"Loading path from {checkpoint_path}")
    segmentation_model = Neuratt.load_from_checkpoint(checkpoint_path)
else:
    # Without a checkpoint the metrics below describe untrained weights;
    # make that visible in the cell output.
    print("No checkpoint_path set; evaluating a freshly initialised Neuratt model")
    segmentation_model = Neuratt()

Loading path from /results/whole_brain_seg/whole_brain_seg/gl0un99j/checkpoints/best_model.ckpt


In [3]:
# Trainer used only for `predict` in this notebook. The options are collected
# in a dict first so they can be inspected or tweaked before construction.
trainer_options = dict(
    default_root_dir=results_folder,
    max_epochs=100,
    max_time="00:04:00:00",
    devices=1,
    accelerator="gpu",
    log_every_n_steps=10,
    # Upper bound on the number of batches processed by `predict`.
    limit_predict_batches=200,
)
# Options the author left disabled: callbacks=callbacks, logger=logger,
# deterministic=True, overfit_batches=1.
trainer = L.Trainer(**trainer_options)

predictions = trainer.predict(segmentation_model, val_dataloader)
print("Number of predictions: ", len(predictions))

# Each prediction unpacks as (data, pred, metrics). The metrics dict is printed
# as returned and then extended in place with the batch input and the model
# output, so later cells can rank batches and export the matching arrays.
metric_list = []
for i, (data, pred, metrics) in enumerate(predictions):
    print(f"[{i}] Metrics: {metrics}")
    metrics.update(data=data, pred=pred)
    metric_list.append(metrics)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/opt/conda/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting DataLoader 0: 100%|██████████| 139/139 [01:14<00:00,  1.87it/s]
Number of predictions:  139
[0] Metrics: {'loss': 0.05729716271162033, 'dice': 0.9908859729766846, 'jacc': 0.9819364547729492}
[1] Metrics: {'loss': 0.08130361139774323, 'dice': 0.9862983226776123, 'jacc': 0.9729670286178589}
[2] Metrics: {'loss': 0.028762318193912506, 'dice': 0.9945313930511475, 'jacc': 0.9891221523284912}
[3] Metrics: {'loss': 0.11100957542657852, 'dice': 0.9902204275131226, 'jacc': 0.9806302189826965}
[4] Metrics: {'loss': 0.12523505091667175, 'dice': 0.9877841472625732, 'jacc': 0.9758630990982056}
[5] Metrics: {'loss': 0.26786550879478455, 'dice': 0.9731999039649963, 'jacc': 0.9477987289428711}
[6] Metrics: {'loss': 18.501550674438477, 'dice': 0.9607630372047424, 'jacc': 0.9244889616966248}
[7] Metrics: {'loss': 0.3353996276855469, 'dice': 0.9682934880256653, 'jacc': 0.9385356903076172}
[8] Metrics: {'loss': 52.92434310913086, 'dice': 0.9428737759590149, 'jacc': 0.8919216990470886}
[9] Metri

In [4]:
from statistics import mean, median

# Metrics to summarise, and the statistic computed for each (in print order).
keys = ["loss", "dice", "jacc"]
reducers = {"min": min, "max": max, "median": median, "mean": mean}

# stats[metric][statistic] -> value aggregated over every entry of metric_list.
stats = {}
for key in keys:
    values = [entry[key] for entry in metric_list]
    stats[key] = {name: reduce_fn(values) for name, reduce_fn in reducers.items()}

# Print the results: one header line per metric, then one indented line per
# statistic rounded to four decimals.
for key, stat in stats.items():
    report = [f"{key}:"]
    report.extend(f"  {stat_name}: {value:.4f}" for stat_name, value in stat.items())
    print("\n".join(report))

loss:
  min: 0.0259
  max: 228.0285
  median: 0.2238
  mean: 10.2313
dice:
  min: 0.3468
  max: 0.9974
  median: 0.9575
  mean: 0.9189
jacc:
  min: 0.2098
  max: 0.9948
  median: 0.9185
  mean: 0.8647


In [5]:
# Order the per-batch entries by ascending dice score: lowest-scoring batches
# end up at the front of the list, highest-scoring ones at the back.
metric_list = sorted(metric_list, key=lambda entry: entry['dice'])

In [6]:
def _save_ranked_batches(entries, label, out_dir):
    """Write the input and prediction of each entry to ``.npy`` files.

    Args:
        entries: Sequence of metric dicts holding at least the keys
            ``'dice'``, ``'data'`` and ``'pred'``.
        label: Filename prefix identifying the group (``"worse"`` or ``"best"``).
        out_dir: Directory that receives the files; created if missing.

    Files are named ``{label}_seg_{data|pred}_dice_{dice}_idx_{idx}.npy`` where
    ``idx`` is the position of the entry inside ``entries``.
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    for idx, entry in enumerate(entries):
        for field in ("data", "pred"):
            # NOTE(review): assumes entry[field] is something np.save can
            # convert to an array (NumPy array or CPU tensor) - confirm
            # against the model's predict_step.
            np.save(
                file=out_dir / f"{label}_seg_{field}_dice_{entry['dice']}_idx_{idx}.npy",
                arr=entry[field],
            )


# metric_list is sorted by ascending dice, so the head holds the lowest-scoring
# batches and the tail the highest-scoring ones. With fewer than 2 * n_extremes
# entries the two slices overlap and a batch is exported under both labels.
n_extremes = 5
_save_ranked_batches(metric_list[:n_extremes], "worse", results_folder)
_save_ranked_batches(metric_list[-n_extremes:], "best", results_folder)
