# 🧠 nnU-Net Inference Script (Single Dataset)
This notebook runs nnU-Net inference on a single selected dataset and evaluates the predictions against the reference labels.

In [None]:
import json
import os
import warnings

import torch
from asmunet.evalute_utils import compute_metrics_on_folder, labels_to_list_of_regions
from batchgenerators.utilities.file_and_folder_operations import join
from nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO
from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
from nnunetv2.paths import nnUNet_raw, nnUNet_results

### 🔧 Inference Configuration
Set the key parameters for nnU-Net inference: dataset ID, fold, GPU ID, trainer name, checkpoint, and the image folder to predict.

In [None]:
# ---- Model / checkpoint selection ----
tr = "nnUNetTrainer_asmunet_enc"  # trainer class name the model was trained with
checkpoint_name = "checkpoint_final.pth"  # alternatively "checkpoint_latest.pth"
fold = "all"  # "all", or a single fold such as "1"

# ---- Data and hardware ----
Dataset_ID = 1  # numeric nnU-Net dataset ID
predicited_set = "imagesTs"  # image folder inside the raw dataset to predict (usually "imagesTs")
use_gpu = "0"  # GPU ID, written to CUDA_VISIBLE_DEVICES

### 🎯 Dataset Selection and GPU Configuration
Restrict CUDA to the selected GPU and locate the raw dataset folder that matches the given ID.

In [None]:
# CUDA_VISIBLE_DEVICES only takes effect if it is set before CUDA is first
# initialised in this process, so this cell must run before any GPU work.
os.environ["CUDA_VISIBLE_DEVICES"] = use_gpu

# nnU-Net raw dataset folders are named "DatasetXXX_<Name>". Match that exact
# prefix: a plain substring test on "001" would also hit e.g. "Dataset1001_Foo".
# Sorting makes the result independent of os.listdir() ordering.
dataset_prefix = f"Dataset{Dataset_ID:03d}_"
candidates = sorted(i for i in os.listdir(nnUNet_raw) if i.startswith(dataset_prefix))
if not candidates:
    raise FileNotFoundError(f"No dataset found for ID {Dataset_ID:03d}")
if len(candidates) > 1:
    raise RuntimeError(
        f"Multiple datasets found for ID {Dataset_ID:03d}: {candidates}"
    )
dataset_name = candidates[0]

### 🚀 Initialize nnU-Net Predictor
Set up the nnU-Net predictor and load the trained model checkpoint.

In [None]:
# Build the predictor and load the trained weights for the selected fold.
# All warnings emitted while constructing the predictor and loading the
# checkpoint are suppressed, for this cell only.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")

    predictor = nnUNetPredictor(
        tile_step_size=0.5,
        use_gaussian=True,
        use_mirroring=True,
        perform_everything_on_device=True,
        # Device index 0 is the first *visible* device, i.e. the GPU exposed
        # through CUDA_VISIBLE_DEVICES.
        device=torch.device("cuda", 0),
        verbose=False,
        verbose_preprocessing=False,
        allow_tqdm=True,
    )
    predictor.initialize_from_trained_model_folder(
        join(nnUNet_results, dataset_name, f"{tr}__nnUNetPlans__3d_fullres"),
        # `use_folds` takes a tuple of folds. `(fold)` is just `fold` itself, so
        # build a real one-element tuple. This also works if `fold` is an int.
        use_folds=(fold,),
        checkpoint_name=checkpoint_name,
    )

### 🧠 Run Inference on Test Set
Predict segmentations for every image in the configured image folder (`predicited_set`, typically `imagesTs`) and save them in the trainer's results folder for the selected fold.

In [None]:
# Predict every image in <raw dataset>/<predicited_set> and write the
# segmentations into the trainer's results folder for this fold.
pred_input_dir = join(nnUNet_raw, dataset_name, predicited_set)
pred_output_dir = join(
    nnUNet_results,
    dataset_name,
    f"{tr}__nnUNetPlans__3d_fullres",
    f"{predicited_set}_predlowres",
    f"fold_{fold}",
)

p = predictor.predict_from_files(
    pred_input_dir,
    pred_output_dir,
    save_probabilities=False,
    overwrite=False,  # keep predictions already present in the output folder
    num_processes_preprocessing=2,
    num_processes_segmentation_export=2,
    folder_with_segs_from_prev_stage=None,  # no cascade / previous stage
    num_parts=1,  # process the whole input as a single part
    part_id=0,
)

### 📊 Evaluate Segmentation Results
Compute evaluation metrics by comparing predictions with reference labels and save them to a JSON summary.

In [None]:
# Reference labels sit next to the images: "imagesTs" -> "labelsTs", etc.
folder_ref = join(nnUNet_raw, dataset_name, "labels" + predicited_set[-2:])
folder_pred = join(
    nnUNet_results,
    dataset_name,
    f"{tr}__nnUNetPlans__3d_fullres",
    f"{predicited_set}_predlowres",
    f"fold_{fold}",
)
# The metrics summary is written next to the predictions it describes.
output_file = join(folder_pred, "summary.json")
image_reader_writer = SimpleITKIO()
file_ending = ".nii.gz"

with open(join(nnUNet_raw, dataset_name, "dataset.json")) as fh:
    dataset_labels = json.load(fh)["labels"]

# nnU-Net stores an optional ignore label under the key "ignore", and it must be
# the highest label value. It is not a real class, so it is dropped from the
# scored labels and passed on as `ignore_label` so those voxels are excluded.
ignore_label = dataset_labels.get("ignore")
label_len = len(dataset_labels) - (1 if ignore_label is not None else 0)
# Score labels 1..label_len-1, i.e. every class except background (0).
regions = labels_to_list_of_regions(list(range(1, label_len)))
num_processes = 12
results = compute_metrics_on_folder(
    folder_ref,
    folder_pred,
    output_file,
    image_reader_writer,
    file_ending,
    regions,
    ignore_label,
    num_processes,
)

In [None]:
# Print the per-label mean Dice from the metrics summary, then the foreground
# mean Dice. The file is closed deterministically via `with`.
with open(output_file) as fh:
    r = json.load(fh)
for label, metrics in r["mean"].items():
    print(f"Label: {label}", f"Dice: {metrics['Dice']}")
print(r["foreground_mean"]["Dice"])