From 0413df297577f606301746806bb04d494b585db0 Mon Sep 17 00:00:00 2001 From: dongyang0122 Date: Tue, 11 Mar 2025 12:21:58 -0600 Subject: [PATCH 01/10] init Signed-off-by: dongyang0122 --- generation/maisi/scripts/compute_fid2p5d.py | 682 ++++++++++++++++++++ 1 file changed, 682 insertions(+) create mode 100644 generation/maisi/scripts/compute_fid2p5d.py diff --git a/generation/maisi/scripts/compute_fid2p5d.py b/generation/maisi/scripts/compute_fid2p5d.py new file mode 100644 index 000000000..d5fb11499 --- /dev/null +++ b/generation/maisi/scripts/compute_fid2p5d.py @@ -0,0 +1,682 @@ +#!/usr/bin/env python +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, +# either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. + +""" +Compute 2.5D FID using distributed GPU processing, **without** external fid_utils dependencies. + +SHELL Usage Example: +------------------- + #!/bin/bash + + export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6 + NUM_GPUS=7 + + torchrun --nproc_per_node=${NUM_GPUS} compute_fid2d_mgpu.py \ + --model_name "radimagenet_resnet50" \ + --data0_dataroot "path/to/datasetA" \ + --data0_filelist "path/to/filelistA.txt" \ + --data0_folder "datasetA" \ + --data1_dataroot "path/to/datasetB" \ + --data1_filelist "path/to/filelistB.txt" \ + --data1_folder "datasetB" \ + --enable_center_slices False \ + --enable_padding True \ + --enable_center_cropping True \ + --enable_resampling True \ + --ignore_existing True \ + --num_images 100 \ + --output_root "./features/features-512x512x512" \ + --target_shape "512x512x512" + +This script loads two datasets (real vs. synthetic) in 3D medical format (NIfTI) and extracts +feature maps via a 2.5D approach. It then computes the Frechet Inception Distance (FID) +across three orthogonal planes. Data parallelism is implemented using +torch.distributed with an NCCL backend. + +Function Arguments (main): +-------------------------- + data0_dataroot (str): Root folder for dataset 0 (real images). + data0_filelist (str): Text file listing 3D images for dataset 0. + data0_folder (str): Name for dataset 0 output folder under `output_root`. + data1_dataroot (str): Root folder for dataset 1 (synthetic images). + data1_filelist (str): Text file listing 3D images for dataset 1. + data1_folder (str): Name for dataset 1 output folder under `output_root`. + enable_center_slices (bool): Whether to slice around the center region of each axis. + enable_center_slices_ratio (float): Ratio of slices to take from the center if `enable_center_slices` is True. + enable_padding (bool): If True, pad images to `target_shape` before feature extraction. + enable_center_cropping (bool): If True, center-crop images to `target_shape`. + enable_resampling (bool): If True, resample images to `enable_resampling_spacing`. + enable_resampling_spacing (str): Target voxel spacing, e.g. "1.0x1.0x1.0". + ignore_existing (bool): If True, ignore existing .pt feature files and recompute. + model_name (str): Model identifier. Uses "radimagenet_resnet50" or a default squeezenet1_1. + num_images (int): Max number of images from each dataset to process (truncate if larger). + output_root (str): Folder where extracted feature .pt files and logs are saved. + target_shape (str): "XxYxZ" shape to which images are padded/cropped/resampled. + +Example: +-------- + python compute_fid2p5d.py --model_name=radimagenet_resnet50 \\ + --data0_dataroot=/data/real_images \\ + --data0_filelist=/data/real_list.txt \\ + --data0_folder=real \\ + --data1_dataroot=/data/synth_images \\ + --data1_filelist=/data/synth_list.txt \\ + --data1_folder=synth \\ + --enable_center_slices=True \\ + --enable_padding=True \\ + --target_shape=512x512x512 + +""" + +from __future__ import annotations + +import os +import sys +import torch +import fire +import monai +import re +import torch.distributed as dist +import torch.nn.functional as F + +from datetime import timedelta +from pathlib import Path +from generative.metrics import FIDMetric +from monai.transforms import Compose + +# ------------------------------------------------------------------------------ +# Below are the core utilities originally in fid_utils.py, now inlined here +# to remove external dependency. +# ------------------------------------------------------------------------------ + +def drop_empty_slice(slices, empty_threshold: float): + """ + Decide which 2D slices to keep by checking if their maximum intensity + is below a certain threshold. + + Args: + slices (tuple or list of Tensors): Each element is (B, C, H, W). + empty_threshold (float): If the slice's maximum value is below this threshold, + it is considered "empty". + + Returns: + list[bool]: A list of booleans indicating for each slice whether to keep it. + """ + outputs = [] + n_drop = 0 + for s in slices: + largest_unique = torch.max(torch.unique(s)) + if largest_unique < empty_threshold: + outputs.append(False) + n_drop += 1 + else: + outputs.append(True) + + print(f"Empty slice drop rate {round((n_drop/len(slices))*100,1)}%") + return outputs + + +def subtract_mean(x: torch.Tensor) -> torch.Tensor: + """ + Subtract per-channel means (ImageNet-like: [0.406, 0.456, 0.485]) + from the input 4D or 5D tensor. Expects channels in the first dimension + after the batch dimension: (B, C, H, W) or (B, C, H, W, D). + """ + mean = [0.406, 0.456, 0.485] + x[:, 0, ...] -= mean[0] + x[:, 1, ...] -= mean[1] + x[:, 2, ...] -= mean[2] + return x + + +def spatial_average(x: torch.Tensor, keepdim: bool = True) -> torch.Tensor: + """ + Average out the spatial dimensions of a tensor, preserving or removing them + according to `keepdim`. This is used to produce a 1D feature vector + out of a feature map. + + Args: + x (torch.Tensor): Input tensor (B, C, H, W, ...) or (B, C, H, W). + keepdim (bool): Whether to keep dimension or not after averaging. + + Returns: + torch.Tensor: Tensor with reduced spatial dimensions. + """ + dim = len(x.shape) + # 2D -> no average + if dim == 2: + return x + # 3D -> average over last dim + if dim == 3: + return x.mean([2], keepdim=keepdim) + # 4D -> average over H,W + if dim == 4: + return x.mean([2, 3], keepdim=keepdim) + # 5D -> average over H,W,D + if dim == 5: + return x.mean([2, 3, 4], keepdim=keepdim) + return x + + +def medicalnet_intensity_normalisation(volume: torch.Tensor) -> torch.Tensor: + """ + Intensity normalization approach from MedicalNet: + (volume - mean) / (std + 1e-5) across spatial dims. + Expects (B, C, H, W) or (B, C, H, W, D). + """ + dim = len(volume.shape) + if dim == 4: + mean = volume.mean([2, 3], keepdim=True) + std = volume.std([2, 3], keepdim=True) + elif dim == 5: + mean = volume.mean([2, 3, 4], keepdim=True) + std = volume.std([2, 3, 4], keepdim=True) + else: + return volume + return (volume - mean) / (std + 1e-5) + + +def radimagenet_intensity_normalisation(volume: torch.Tensor, norm2d: bool = False) -> torch.Tensor: + """ + Intensity normalization for radimagenet_resnet. Optionally normalizes each 2D slice individually. + + Args: + volume (torch.Tensor): Input (B, C, H, W) or (B, C, H, W, D). + norm2d (bool): If True, normalizes each (H,W) slice to [0,1], then subtracts the ImageNet mean. + """ + print(f"norm2d: {norm2d}") + dim = len(volume.shape) + # If norm2d is True, only meaningful for 4D data (B, C, H, W): + if dim == 4 and norm2d: + max2d, _ = torch.max(volume, dim=2, keepdim=True) + max2d, _ = torch.max(max2d, dim=3, keepdim=True) + min2d, _ = torch.min(volume, dim=2, keepdim=True) + min2d, _ = torch.min(min2d, dim=3, keepdim=True) + # Scale each slice to 0..1 + volume = (volume - min2d) / (max2d - min2d + 1e-10) + # Subtract channel mean + return subtract_mean(volume) + elif dim == 4: + # 4D but no per-slice normalization + max3d = torch.max(volume) + min3d = torch.min(volume) + volume = (volume - min3d) / (max3d - min3d + 1e-10) + return subtract_mean(volume) + # Fallback for e.g. 5D data is simply a min-max over entire volume + if dim == 5: + maxval = torch.max(volume) + minval = torch.min(volume) + volume = (volume - minval) / (maxval - minval + 1e-10) + return subtract_mean(volume) + return volume + + +def get_features_2p5d( + image: torch.Tensor, + feature_network: torch.nn.Module, + center_slices: bool = False, + center_slices_ratio: float = 1.0, + sample_every_k: int = 1, + xy_only: bool = True, + drop_empty: bool = False, + empty_threshold: float = -700, +) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]: + """ + Extract 2.5D features from a 3D image by slicing it along XY, YZ, ZX planes. + + Args: + image (torch.Tensor): Input 5D tensor in shape (B, C, H, W, D). + feature_network (torch.nn.Module): Model that processes 2D slices (C,H,W). + center_slices (bool): Whether to slice only the center portion of each axis. + center_slices_ratio (float): Ratio of slices to keep in the center if `center_slices` is True. + sample_every_k (int): Downsampling factor along each axis when slicing. + xy_only (bool): If True, return only the XY-plane features. + drop_empty (bool): Drop slices that are deemed "empty" below `empty_threshold`. + empty_threshold (float): Threshold to decide emptiness of slices. + + Returns: + tuple of torch.Tensor or None: (XY_features, YZ_features, ZX_features). + """ + print(f"center_slices: {center_slices}, ratio: {center_slices_ratio}") + + # If there's only 1 channel, replicate to 3 channels + if image.shape[1] == 1: + image = image.repeat(1, 3, 1, 1, 1) + + # Convert from 'RGB'→(R,G,B) to (B, G, R) ordering + image = image[:, [2, 1, 0], ...] + + B, C, H, W, D = image.size() + with torch.no_grad(): + # --------------------------------------------------------------------- + # 1) XY-plane slicing along D + # --------------------------------------------------------------------- + if center_slices: + start_d = int((1.0 - center_slices_ratio) / 2.0 * D) + end_d = int((1.0 + center_slices_ratio) / 2.0 * D) + slices = torch.unbind(image[:, :, :, :, start_d:end_d:sample_every_k], dim=-1) + else: + slices = torch.unbind(image, dim=-1) + + if drop_empty: + mapping_index = drop_empty_slice(slices, empty_threshold) + else: + mapping_index = [True for _ in range(len(slices))] + + images_2d = torch.cat(slices, dim=0) + images_2d = radimagenet_intensity_normalisation(images_2d) + images_2d = images_2d[mapping_index] + + feature_image_xy = feature_network.forward(images_2d) + feature_image_xy = spatial_average(feature_image_xy, keepdim=False) + + if xy_only: + return feature_image_xy, None, None + + # --------------------------------------------------------------------- + # 2) YZ-plane slicing along H + # --------------------------------------------------------------------- + if center_slices: + start_h = int((1.0 - center_slices_ratio) / 2.0 * H) + end_h = int((1.0 + center_slices_ratio) / 2.0 * H) + slices = torch.unbind(image[:, :, start_h:end_h:sample_every_k, :, :], dim=2) + else: + slices = torch.unbind(image, dim=2) + + if drop_empty: + mapping_index = drop_empty_slice(slices, empty_threshold) + else: + mapping_index = [True for _ in range(len(slices))] + + images_2d = torch.cat(slices, dim=0) + images_2d = radimagenet_intensity_normalisation(images_2d) + images_2d = images_2d[mapping_index] + + feature_image_yz = feature_network.forward(images_2d) + feature_image_yz = spatial_average(feature_image_yz, keepdim=False) + + # --------------------------------------------------------------------- + # 3) ZX-plane slicing along W + # --------------------------------------------------------------------- + if center_slices: + start_w = int((1.0 - center_slices_ratio) / 2.0 * W) + end_w = int((1.0 + center_slices_ratio) / 2.0 * W) + slices = torch.unbind(image[:, :, :, start_w:end_w:sample_every_k, :], dim=3) + else: + slices = torch.unbind(image, dim=3) + + if drop_empty: + mapping_index = drop_empty_slice(slices, empty_threshold) + else: + mapping_index = [True for _ in range(len(slices))] + + images_2d = torch.cat(slices, dim=0) + images_2d = radimagenet_intensity_normalisation(images_2d) + images_2d = images_2d[mapping_index] + + feature_image_zx = feature_network.forward(images_2d) + feature_image_zx = spatial_average(feature_image_zx, keepdim=False) + + return feature_image_xy, feature_image_yz, feature_image_zx + +# ------------------------------------------------------------------------------ +# End inline fid_utils code +# ------------------------------------------------------------------------------ + + +def pad_to_max_size(tensor: torch.Tensor, max_size: int, padding_value: float = 0.0) -> torch.Tensor: + """ + Zero-pad a 2D feature map or other tensor along the first dimension to match a specified size. + + Args: + tensor (torch.Tensor): The feature tensor to pad. + max_size (int): Desired size along the first dimension. + padding_value (float): Value to fill during padding. + + Returns: + torch.Tensor: Padded tensor matching `max_size` along dim=0. + """ + # For a shape (B, C, ...), we only pad the B dimension up to `max_size`. + pad_size = [0, 0] * (len(tensor.shape) - 1) + [0, max_size - tensor.shape[0]] + return F.pad(tensor, pad_size, "constant", padding_value) + + +def main( + data0_dataroot: str = "path/to/datasetA", + data0_filelist: str = "path/to/filelistA.txt", + data0_folder: str = "datasetA", + data1_dataroot: str = "path/to/datasetB", + data1_filelist: str = "path/to/filelistB.txt", + data1_folder: str = "datasetB", + enable_center_slices: bool = False, + enable_center_slices_ratio: float = 0.4, + enable_padding: bool = True, + enable_center_cropping: bool = True, + enable_resampling: bool = False, + enable_resampling_spacing: str = "1.0x1.0x1.0", + ignore_existing: bool = False, + model_name: str = "radimagenet_resnet50", + num_images: int = 100, + output_root: str = "./features/features-512x512x512", + target_shape: str = "512x512x512", +): + """ + Main function to compute 2.5D features for two datasets (e.g. real vs. synthetic) + and calculate FID across XY, YZ, ZX planes. + + Args: + data0_dataroot (str): Root path of dataset 0 (real images). + data0_filelist (str): Text file listing the 3D images in dataset 0. + data0_folder (str): Name (subdir) for dataset 0 outputs under `output_root`. + data1_dataroot (str): Root path of dataset 1 (synthetic images). + data1_filelist (str): Text file listing the 3D images in dataset 1. + data1_folder (str): Name (subdir) for dataset 1 outputs under `output_root`. + enable_center_slices (bool): If True, only slices around the center ratio are used. + enable_center_slices_ratio (float): Ratio of center slices to keep if `enable_center_slices` is True. + enable_padding (bool): Whether to pad images to `target_shape`. + enable_center_cropping (bool): Whether to center-crop images to `target_shape`. + enable_resampling (bool): Whether to resample images to `enable_resampling_spacing`. + enable_resampling_spacing (str): Target voxel spacing as "XxYxZ", e.g. "1.0x1.0x1.0". + ignore_existing (bool): If True, re-extract features even if .pt files exist. + model_name (str): Model to use for feature extraction (radimagenet_resnet50 or squeezenet1_1). + num_images (int): Number of images to use from each dataset (truncate if exceeding). + output_root (str): Output directory to store extracted features and final results. + target_shape (str): Desired shape, as "XxYxZ", for padding/cropping/resampling. + + Returns: + None + """ + # ------------------------------------------------------------------------- + # Initialize Process Group (Distributed) + # ------------------------------------------------------------------------- + dist.init_process_group(backend="nccl", init_method="env://", timeout=timedelta(seconds=7200)) + + local_rank = int(os.environ["LOCAL_RANK"]) + world_size = int(dist.get_world_size()) + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + print(f"[INFO] Running process on {device} of total {world_size} ranks.") + + # ------------------------------------------------------------------------- + # Convert potential string bools to actual bools (Fire sometimes passes strings) + # ------------------------------------------------------------------------- + if not isinstance(enable_center_slices, bool): + enable_center_slices = enable_center_slices.lower() == "true" + if not isinstance(enable_padding, bool): + enable_padding = enable_padding.lower() == "true" + if not isinstance(enable_center_cropping, bool): + enable_center_cropping = enable_center_cropping.lower() == "true" + if not isinstance(enable_resampling, bool): + enable_resampling = enable_resampling.lower() == "true" + if not isinstance(ignore_existing, bool): + ignore_existing = ignore_existing.lower() == "true" + + # Print out some flags on rank 0 + if local_rank == 0: + print(f"[INFO] enable_center_slices: {enable_center_slices}") + print(f"[INFO] enable_padding: {enable_padding}") + print(f"[INFO] enable_center_cropping: {enable_center_cropping}") + print(f"[INFO] enable_resampling: {enable_resampling}") + print(f"[INFO] ignore_existing: {ignore_existing}") + + # ------------------------------------------------------------------------- + # Load feature extraction model + # ------------------------------------------------------------------------- + if model_name == "radimagenet_resnet50": + # Using a model from Warvito/radimagenet-models on Torch Hub + feature_network = torch.hub.load( + "Warvito/radimagenet-models", model="radimagenet_resnet50", verbose=True, trust_repo=True + ) + suffix = "radimagenet_resnet50" + else: + import torchvision + + feature_network = torchvision.models.squeezenet1_1(pretrained=True) + suffix = "squeezenet1_1" + + feature_network.to(device) + feature_network.eval() + + # ------------------------------------------------------------------------- + # Parse shape/spacings from string + # ------------------------------------------------------------------------- + t_shape = [int(x) for x in target_shape.split("x")] + target_shape_tuple = tuple(t_shape) + if enable_resampling: + rs_spacing = [float(x) for x in enable_resampling_spacing.split("x")] + rs_spacing_tuple = tuple(rs_spacing) + if local_rank == 0: + print(f"[INFO] resampling spacing: {rs_spacing_tuple}") + else: + rs_spacing_tuple = (1.0, 1.0, 1.0) + + center_slices_ratio_final = enable_center_slices_ratio if enable_center_slices else 1.0 + if local_rank == 0: + print(f"[INFO] center_slices_ratio: {center_slices_ratio_final}") + + # ------------------------------------------------------------------------- + # Prepare dataset 0 + # ------------------------------------------------------------------------- + output_root0 = os.path.join(output_root, data0_folder) + with open(data0_filelist, "r") as f0: + lines0 = [l.strip() for l in f0.readlines()] + lines0.sort() + lines0 = lines0[:num_images] + + filenames0 = [{"image": os.path.join(data0_dataroot, f)} for f in lines0] + filenames0 = monai.data.partition_dataset( + data=filenames0, shuffle=False, num_partitions=world_size, even_divisible=False + )[local_rank] + + # ------------------------------------------------------------------------- + # Prepare dataset 1 + # ------------------------------------------------------------------------- + output_root1 = os.path.join(output_root, data1_folder) + with open(data1_filelist, "r") as f1: + lines1 = [l.strip() for l in f1.readlines()] + lines1.sort() + lines1 = lines1[:num_images] + + filenames1 = [{"image": os.path.join(data1_dataroot, f)} for f in lines1] + filenames1 = monai.data.partition_dataset( + data=filenames1, shuffle=False, num_partitions=world_size, even_divisible=False + )[local_rank] + + # ------------------------------------------------------------------------- + # Build MONAI transforms + # ------------------------------------------------------------------------- + transform_list = [ + monai.transforms.LoadImaged(keys=["image"]), + monai.transforms.EnsureChannelFirstd(keys=["image"]), + monai.transforms.Orientationd(keys=["image"], axcodes="RAS"), + ] + + if enable_resampling: + transform_list.append( + monai.transforms.Spacingd(keys=["image"], pixdim=rs_spacing_tuple, mode=["bilinear"]) + ) + if enable_padding: + transform_list.append( + monai.transforms.SpatialPadd(keys=["image"], spatial_size=target_shape_tuple, mode=["constant"], value=-1000) + ) + if enable_center_cropping: + transform_list.append( + monai.transforms.CenterSpatialCropd(keys=["image"], roi_size=target_shape_tuple) + ) + + # Intensity scaling to clamp between [-1000, 1000] + transform_list.append( + monai.transforms.ScaleIntensityRanged( + keys=["image"], a_min=-1000, a_max=1000, b_min=-1000, b_max=1000, clip=True + ) + ) + + transforms = Compose(transform_list) + + # ------------------------------------------------------------------------- + # Create DataLoaders + # ------------------------------------------------------------------------- + real_ds = monai.data.Dataset(data=filenames0, transform=transforms) + real_loader = monai.data.DataLoader(real_ds, num_workers=6, batch_size=1, shuffle=False) + + synt_ds = monai.data.Dataset(data=filenames1, transform=transforms) + synt_loader = monai.data.DataLoader(synt_ds, num_workers=6, batch_size=1, shuffle=False) + + # ------------------------------------------------------------------------- + # Extract features for dataset 0 + # ------------------------------------------------------------------------- + real_features_xy, real_features_yz, real_features_zx = [], [], [] + for idx, batch_data in enumerate(real_loader, start=1): + img = batch_data["image"].to(device) + fn = img.meta["filename_or_obj"][0] + print(f"[Rank {local_rank}] Real data {idx}/{len(filenames0)}: {fn}") + + out_fp = fn.replace(data0_dataroot, output_root0).replace(".nii.gz", ".pt") + out_fp = Path(out_fp) + out_fp.parent.mkdir(parents=True, exist_ok=True) + + if (not ignore_existing) and os.path.isfile(out_fp): + feats = torch.load(out_fp) + else: + img_t = img.as_tensor() + print(f"[INFO] image shape: {tuple(img_t.shape)}") + + # Inline get_features_2p5d + feats = get_features_2p5d( + img_t, + feature_network, + center_slices=enable_center_slices, + center_slices_ratio=center_slices_ratio_final, + xy_only=False, + ) + print(f"[INFO] feats shapes: {feats[0].shape}, {feats[1].shape}, {feats[2].shape}") + torch.save(feats, out_fp) + + real_features_xy.append(feats[0]) + real_features_yz.append(feats[1]) + real_features_zx.append(feats[2]) + + real_features_xy = torch.vstack(real_features_xy) + real_features_yz = torch.vstack(real_features_yz) + real_features_zx = torch.vstack(real_features_zx) + print(f"[INFO] Real feature shapes: {real_features_xy.shape}, {real_features_yz.shape}, {real_features_zx.shape}") + + # ------------------------------------------------------------------------- + # Extract features for dataset 1 + # ------------------------------------------------------------------------- + synth_features_xy, synth_features_yz, synth_features_zx = [], [], [] + for idx, batch_data in enumerate(synt_loader, start=1): + img = batch_data["image"].to(device) + fn = img.meta["filename_or_obj"][0] + print(f"[Rank {local_rank}] Synthetic data {idx}/{len(filenames1)}: {fn}") + + out_fp = fn.replace(data1_dataroot, output_root1).replace(".nii.gz", ".pt") + out_fp = Path(out_fp) + out_fp.parent.mkdir(parents=True, exist_ok=True) + + if (not ignore_existing) and os.path.isfile(out_fp): + feats = torch.load(out_fp) + else: + img_t = img.as_tensor() + print(f"[INFO] image shape: {tuple(img_t.shape)}") + + feats = get_features_2p5d( + img_t, + feature_network, + center_slices=enable_center_slices, + center_slices_ratio=center_slices_ratio_final, + xy_only=False, + ) + print(f"[INFO] feats shapes: {feats[0].shape}, {feats[1].shape}, {feats[2].shape}") + torch.save(feats, out_fp) + + synth_features_xy.append(feats[0]) + synth_features_yz.append(feats[1]) + synth_features_zx.append(feats[2]) + + synth_features_xy = torch.vstack(synth_features_xy) + synth_features_yz = torch.vstack(synth_features_yz) + synth_features_zx = torch.vstack(synth_features_zx) + print( + f"[INFO] Synthetic feature shapes: {synth_features_xy.shape}, " + f"{synth_features_yz.shape}, {synth_features_zx.shape}" + ) + + # ------------------------------------------------------------------------- + # All-reduce / gather features across ranks + # ------------------------------------------------------------------------- + features = [ + real_features_xy, real_features_yz, real_features_zx, + synth_features_xy, synth_features_yz, synth_features_zx + ] + + # 1) Gather local feature sizes across ranks + local_sizes = [] + for ft_idx in range(len(features)): + local_size = torch.tensor([features[ft_idx].shape[0]], dtype=torch.int64, device=device) + local_sizes.append(local_size) + + all_sizes = [] + for ft_idx in range(len(features)): + rank_sizes = [torch.tensor([0], dtype=torch.int64, device=device) for _ in range(world_size)] + dist.all_gather(rank_sizes, local_sizes[ft_idx]) + all_sizes.append(rank_sizes) + + # 2) Pad and gather all features + all_tensors_list = [] + for ft_idx, ft in enumerate(features): + max_size = max(all_sizes[ft_idx]).item() + ft_padded = pad_to_max_size(ft, max_size) + + gather_list = [torch.empty_like(ft_padded) for _ in range(world_size)] + dist.all_gather(gather_list, ft_padded) + + # Trim each gather back to the real size + for rk in range(world_size): + gather_list[rk] = gather_list[rk][: all_sizes[ft_idx][rk], :] + + all_tensors_list.append(gather_list) + + # On rank 0, compute FID + if local_rank == 0: + real_xy = torch.vstack(all_tensors_list[0]) + real_yz = torch.vstack(all_tensors_list[1]) + real_zx = torch.vstack(all_tensors_list[2]) + + synth_xy = torch.vstack(all_tensors_list[3]) + synth_yz = torch.vstack(all_tensors_list[4]) + synth_zx = torch.vstack(all_tensors_list[5]) + + print(f"[INFO] Final Real shapes: {real_xy.shape}, {real_yz.shape}, {real_zx.shape}") + print(f"[INFO] Final Synth shapes: {synth_xy.shape}, {synth_yz.shape}, {synth_zx.shape}") + + fid = FIDMetric() + print(f"\n[INFO] Computing FID for: {output_root0} | {output_root1}") + fid_res_xy = fid(synth_xy, real_xy) + fid_res_yz = fid(synth_yz, real_yz) + fid_res_zx = fid(synth_zx, real_zx) + + print(f" FID XY: {fid_res_xy}") + print(f" FID YZ: {fid_res_yz}") + print(f" FID ZX: {fid_res_zx}") + fid_avg = (fid_res_xy + fid_res_yz + fid_res_zx) / 3.0 + print(f" FID Avg: {fid_avg}") + + dist.destroy_process_group() + + +if __name__ == "__main__": + # Using python-fire for command-line interface. + # e.g., python compute_fid2d_mgpu.py --model_name=radimagenet_resnet50 --num_images=100 ... + fire.Fire(main) From bbae5161649de5a11aaedc2197806a780d1557cc Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 11 Mar 2025 18:26:11 +0000 Subject: [PATCH 02/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- generation/maisi/scripts/compute_fid2p5d.py | 22 ++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/generation/maisi/scripts/compute_fid2p5d.py b/generation/maisi/scripts/compute_fid2p5d.py index d5fb11499..7ba5cffa6 100644 --- a/generation/maisi/scripts/compute_fid2p5d.py +++ b/generation/maisi/scripts/compute_fid2p5d.py @@ -100,6 +100,7 @@ # to remove external dependency. # ------------------------------------------------------------------------------ + def drop_empty_slice(slices, empty_threshold: float): """ Decide which 2D slices to keep by checking if their maximum intensity @@ -330,6 +331,7 @@ def get_features_2p5d( return feature_image_xy, feature_image_yz, feature_image_zx + # ------------------------------------------------------------------------------ # End inline fid_utils code # ------------------------------------------------------------------------------ @@ -503,17 +505,15 @@ def main( ] if enable_resampling: - transform_list.append( - monai.transforms.Spacingd(keys=["image"], pixdim=rs_spacing_tuple, mode=["bilinear"]) - ) + transform_list.append(monai.transforms.Spacingd(keys=["image"], pixdim=rs_spacing_tuple, mode=["bilinear"])) if enable_padding: transform_list.append( - monai.transforms.SpatialPadd(keys=["image"], spatial_size=target_shape_tuple, mode=["constant"], value=-1000) + monai.transforms.SpatialPadd( + keys=["image"], spatial_size=target_shape_tuple, mode=["constant"], value=-1000 + ) ) if enable_center_cropping: - transform_list.append( - monai.transforms.CenterSpatialCropd(keys=["image"], roi_size=target_shape_tuple) - ) + transform_list.append(monai.transforms.CenterSpatialCropd(keys=["image"], roi_size=target_shape_tuple)) # Intensity scaling to clamp between [-1000, 1000] transform_list.append( @@ -617,8 +617,12 @@ def main( # All-reduce / gather features across ranks # ------------------------------------------------------------------------- features = [ - real_features_xy, real_features_yz, real_features_zx, - synth_features_xy, synth_features_yz, synth_features_zx + real_features_xy, + real_features_yz, + real_features_zx, + synth_features_xy, + synth_features_yz, + synth_features_zx, ] # 1) Gather local feature sizes across ranks From ed3d194389c5e410cb4e718a00fa471528a2b714 Mon Sep 17 00:00:00 2001 From: dongyang0122 Date: Tue, 11 Mar 2025 12:31:04 -0600 Subject: [PATCH 03/10] update Signed-off-by: dongyang0122 --- .../{compute_fid2p5d.py => compute_fid2p5d_ct.py} | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) rename generation/maisi/scripts/{compute_fid2p5d.py => compute_fid2p5d_ct.py} (97%) diff --git a/generation/maisi/scripts/compute_fid2p5d.py b/generation/maisi/scripts/compute_fid2p5d_ct.py similarity index 97% rename from generation/maisi/scripts/compute_fid2p5d.py rename to generation/maisi/scripts/compute_fid2p5d_ct.py index d5fb11499..709134ca0 100644 --- a/generation/maisi/scripts/compute_fid2p5d.py +++ b/generation/maisi/scripts/compute_fid2p5d_ct.py @@ -22,7 +22,7 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6 NUM_GPUS=7 - torchrun --nproc_per_node=${NUM_GPUS} compute_fid2d_mgpu.py \ + torchrun --nproc_per_node=${NUM_GPUS} compute_fid2p5d_ct.py \ --model_name "radimagenet_resnet50" \ --data0_dataroot "path/to/datasetA" \ --data0_filelist "path/to/filelistA.txt" \ @@ -64,19 +64,6 @@ output_root (str): Folder where extracted feature .pt files and logs are saved. target_shape (str): "XxYxZ" shape to which images are padded/cropped/resampled. -Example: --------- - python compute_fid2p5d.py --model_name=radimagenet_resnet50 \\ - --data0_dataroot=/data/real_images \\ - --data0_filelist=/data/real_list.txt \\ - --data0_folder=real \\ - --data1_dataroot=/data/synth_images \\ - --data1_filelist=/data/synth_list.txt \\ - --data1_folder=synth \\ - --enable_center_slices=True \\ - --enable_padding=True \\ - --target_shape=512x512x512 - """ from __future__ import annotations From 466e9c62e40903c827256c3b1aa1898f4c669bdb Mon Sep 17 00:00:00 2001 From: dongyang0122 Date: Tue, 11 Mar 2025 12:41:18 -0600 Subject: [PATCH 04/10] udpate FIDMetric source Signed-off-by: dongyang0122 --- generation/maisi/scripts/compute_fid2p5d_ct.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/generation/maisi/scripts/compute_fid2p5d_ct.py b/generation/maisi/scripts/compute_fid2p5d_ct.py index d4ba6b8f5..7cd8c768f 100644 --- a/generation/maisi/scripts/compute_fid2p5d_ct.py +++ b/generation/maisi/scripts/compute_fid2p5d_ct.py @@ -79,7 +79,7 @@ from datetime import timedelta from pathlib import Path -from generative.metrics import FIDMetric +from monai.metrics.fid import FIDMetric from monai.transforms import Compose # ------------------------------------------------------------------------------ From 52a04fcfd22b5bd42ef8f332513c8aec063c90e2 Mon Sep 17 00:00:00 2001 From: dongyang0122 Date: Wed, 12 Mar 2025 07:21:59 -0600 Subject: [PATCH 05/10] clean Signed-off-by: dongyang0122 --- ...te_fid2p5d_ct.py => compute_fid2-5d_ct.py} | 119 +++++++++--------- 1 file changed, 56 insertions(+), 63 deletions(-) rename generation/maisi/scripts/{compute_fid2p5d_ct.py => compute_fid2-5d_ct.py} (87%) diff --git a/generation/maisi/scripts/compute_fid2p5d_ct.py b/generation/maisi/scripts/compute_fid2-5d_ct.py similarity index 87% rename from generation/maisi/scripts/compute_fid2p5d_ct.py rename to generation/maisi/scripts/compute_fid2-5d_ct.py index 7cd8c768f..1d9b30823 100644 --- a/generation/maisi/scripts/compute_fid2p5d_ct.py +++ b/generation/maisi/scripts/compute_fid2-5d_ct.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python # Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,7 +12,7 @@ # and limitations under the License. """ -Compute 2.5D FID using distributed GPU processing, **without** external fid_utils dependencies. +Compute 2.5D FID using distributed GPU processing. SHELL Usage Example: ------------------- @@ -22,7 +21,7 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6 NUM_GPUS=7 - torchrun --nproc_per_node=${NUM_GPUS} compute_fid2p5d_ct.py \ + torchrun --nproc_per_node=${NUM_GPUS} compute_fid_2-5d_ct.py \ --model_name "radimagenet_resnet50" \ --data0_dataroot "path/to/datasetA" \ --data0_filelist "path/to/filelistA.txt" \ @@ -82,10 +81,16 @@ from monai.metrics.fid import FIDMetric from monai.transforms import Compose +import logging + # ------------------------------------------------------------------------------ -# Below are the core utilities originally in fid_utils.py, now inlined here -# to remove external dependency. +# Create logger # ------------------------------------------------------------------------------ +logger = logging.getLogger("fid_2-5d_ct") +if not logger.handlers: + # Configure logger only if it has no handlers (avoid reconfiguring in multi-rank scenarios) + logging.basicConfig(stream=sys.stdout, level=logging.INFO) +logger.setLevel(logging.INFO) def drop_empty_slice(slices, empty_threshold: float): @@ -111,7 +116,7 @@ def drop_empty_slice(slices, empty_threshold: float): else: outputs.append(True) - print(f"Empty slice drop rate {round((n_drop/len(slices))*100,1)}%") + logger.info(f"Empty slice drop rate {round((n_drop/len(slices))*100,1)}%") return outputs @@ -183,7 +188,7 @@ def radimagenet_intensity_normalisation(volume: torch.Tensor, norm2d: bool = Fal volume (torch.Tensor): Input (B, C, H, W) or (B, C, H, W, D). norm2d (bool): If True, normalizes each (H,W) slice to [0,1], then subtracts the ImageNet mean. """ - print(f"norm2d: {norm2d}") + logger.info(f"norm2d: {norm2d}") dim = len(volume.shape) # If norm2d is True, only meaningful for 4D data (B, C, H, W): if dim == 4 and norm2d: @@ -236,20 +241,18 @@ def get_features_2p5d( Returns: tuple of torch.Tensor or None: (XY_features, YZ_features, ZX_features). """ - print(f"center_slices: {center_slices}, ratio: {center_slices_ratio}") + logger.info(f"center_slices: {center_slices}, ratio: {center_slices_ratio}") # If there's only 1 channel, replicate to 3 channels if image.shape[1] == 1: image = image.repeat(1, 3, 1, 1, 1) - # Convert from 'RGB'→(R,G,B) to (B, G, R) ordering + # Convert from 'RGB'→(R,G,B) to (B,G,R) image = image[:, [2, 1, 0], ...] B, C, H, W, D = image.size() with torch.no_grad(): - # --------------------------------------------------------------------- - # 1) XY-plane slicing along D - # --------------------------------------------------------------------- + # ---------------------- XY-plane slicing along D ---------------------- if center_slices: start_d = int((1.0 - center_slices_ratio) / 2.0 * D) end_d = int((1.0 + center_slices_ratio) / 2.0 * D) @@ -268,13 +271,10 @@ def get_features_2p5d( feature_image_xy = feature_network.forward(images_2d) feature_image_xy = spatial_average(feature_image_xy, keepdim=False) - if xy_only: return feature_image_xy, None, None - # --------------------------------------------------------------------- - # 2) YZ-plane slicing along H - # --------------------------------------------------------------------- + # ---------------------- YZ-plane slicing along H ---------------------- if center_slices: start_h = int((1.0 - center_slices_ratio) / 2.0 * H) end_h = int((1.0 + center_slices_ratio) / 2.0 * H) @@ -294,9 +294,7 @@ def get_features_2p5d( feature_image_yz = feature_network.forward(images_2d) feature_image_yz = spatial_average(feature_image_yz, keepdim=False) - # --------------------------------------------------------------------- - # 3) ZX-plane slicing along W - # --------------------------------------------------------------------- + # ---------------------- ZX-plane slicing along W ---------------------- if center_slices: start_w = int((1.0 - center_slices_ratio) / 2.0 * W) end_w = int((1.0 + center_slices_ratio) / 2.0 * W) @@ -319,11 +317,6 @@ def get_features_2p5d( return feature_image_xy, feature_image_yz, feature_image_zx -# ------------------------------------------------------------------------------ -# End inline fid_utils code -# ------------------------------------------------------------------------------ - - def pad_to_max_size(tensor: torch.Tensor, max_size: int, padding_value: float = 0.0) -> torch.Tensor: """ Zero-pad a 2D feature map or other tensor along the first dimension to match a specified size. @@ -336,7 +329,6 @@ def pad_to_max_size(tensor: torch.Tensor, max_size: int, padding_value: float = Returns: torch.Tensor: Padded tensor matching `max_size` along dim=0. """ - # For a shape (B, C, ...), we only pad the B dimension up to `max_size`. pad_size = [0, 0] * (len(tensor.shape) - 1) + [0, max_size - tensor.shape[0]] return F.pad(tensor, pad_size, "constant", padding_value) @@ -395,11 +387,9 @@ def main( world_size = int(dist.get_world_size()) device = torch.device("cuda", local_rank) torch.cuda.set_device(device) - print(f"[INFO] Running process on {device} of total {world_size} ranks.") + logger.info(f"[INFO] Running process on {device} of total {world_size} ranks.") - # ------------------------------------------------------------------------- # Convert potential string bools to actual bools (Fire sometimes passes strings) - # ------------------------------------------------------------------------- if not isinstance(enable_center_slices, bool): enable_center_slices = enable_center_slices.lower() == "true" if not isinstance(enable_padding, bool): @@ -413,24 +403,22 @@ def main( # Print out some flags on rank 0 if local_rank == 0: - print(f"[INFO] enable_center_slices: {enable_center_slices}") - print(f"[INFO] enable_padding: {enable_padding}") - print(f"[INFO] enable_center_cropping: {enable_center_cropping}") - print(f"[INFO] enable_resampling: {enable_resampling}") - print(f"[INFO] ignore_existing: {ignore_existing}") + logger.info(f"enable_center_slices: {enable_center_slices}") + logger.info(f"enable_padding: {enable_padding}") + logger.info(f"enable_center_cropping: {enable_center_cropping}") + logger.info(f"enable_resampling: {enable_resampling}") + logger.info(f"ignore_existing: {ignore_existing}") # ------------------------------------------------------------------------- # Load feature extraction model # ------------------------------------------------------------------------- if model_name == "radimagenet_resnet50": - # Using a model from Warvito/radimagenet-models on Torch Hub feature_network = torch.hub.load( "Warvito/radimagenet-models", model="radimagenet_resnet50", verbose=True, trust_repo=True ) suffix = "radimagenet_resnet50" else: import torchvision - feature_network = torchvision.models.squeezenet1_1(pretrained=True) suffix = "squeezenet1_1" @@ -438,7 +426,7 @@ def main( feature_network.eval() # ------------------------------------------------------------------------- - # Parse shape/spacings from string + # Parse shape/spacings # ------------------------------------------------------------------------- t_shape = [int(x) for x in target_shape.split("x")] target_shape_tuple = tuple(t_shape) @@ -446,13 +434,13 @@ def main( rs_spacing = [float(x) for x in enable_resampling_spacing.split("x")] rs_spacing_tuple = tuple(rs_spacing) if local_rank == 0: - print(f"[INFO] resampling spacing: {rs_spacing_tuple}") + logger.info(f"resampling spacing: {rs_spacing_tuple}") else: rs_spacing_tuple = (1.0, 1.0, 1.0) center_slices_ratio_final = enable_center_slices_ratio if enable_center_slices else 1.0 if local_rank == 0: - print(f"[INFO] center_slices_ratio: {center_slices_ratio_final}") + logger.info(f"center_slices_ratio: {center_slices_ratio_final}") # ------------------------------------------------------------------------- # Prepare dataset 0 @@ -490,25 +478,20 @@ def main( monai.transforms.EnsureChannelFirstd(keys=["image"]), monai.transforms.Orientationd(keys=["image"], axcodes="RAS"), ] - if enable_resampling: transform_list.append(monai.transforms.Spacingd(keys=["image"], pixdim=rs_spacing_tuple, mode=["bilinear"])) if enable_padding: transform_list.append( - monai.transforms.SpatialPadd( - keys=["image"], spatial_size=target_shape_tuple, mode=["constant"], value=-1000 - ) + monai.transforms.SpatialPadd(keys=["image"], spatial_size=target_shape_tuple, mode="constant", value=-1000) ) if enable_center_cropping: transform_list.append(monai.transforms.CenterSpatialCropd(keys=["image"], roi_size=target_shape_tuple)) - # Intensity scaling to clamp between [-1000, 1000] transform_list.append( monai.transforms.ScaleIntensityRanged( keys=["image"], a_min=-1000, a_max=1000, b_min=-1000, b_max=1000, clip=True ) ) - transforms = Compose(transform_list) # ------------------------------------------------------------------------- @@ -527,7 +510,7 @@ def main( for idx, batch_data in enumerate(real_loader, start=1): img = batch_data["image"].to(device) fn = img.meta["filename_or_obj"][0] - print(f"[Rank {local_rank}] Real data {idx}/{len(filenames0)}: {fn}") + logger.info(f"[Rank {local_rank}] Real data {idx}/{len(filenames0)}: {fn}") out_fp = fn.replace(data0_dataroot, output_root0).replace(".nii.gz", ".pt") out_fp = Path(out_fp) @@ -537,9 +520,8 @@ def main( feats = torch.load(out_fp) else: img_t = img.as_tensor() - print(f"[INFO] image shape: {tuple(img_t.shape)}") + logger.info(f"image shape: {tuple(img_t.shape)}") - # Inline get_features_2p5d feats = get_features_2p5d( img_t, feature_network, @@ -547,7 +529,10 @@ def main( center_slices_ratio=center_slices_ratio_final, xy_only=False, ) - print(f"[INFO] feats shapes: {feats[0].shape}, {feats[1].shape}, {feats[2].shape}") + logger.info( + f"feats shapes: {feats[0].shape}, " + f"{feats[1].shape}, {feats[2].shape}" + ) torch.save(feats, out_fp) real_features_xy.append(feats[0]) @@ -557,7 +542,10 @@ def main( real_features_xy = torch.vstack(real_features_xy) real_features_yz = torch.vstack(real_features_yz) real_features_zx = torch.vstack(real_features_zx) - print(f"[INFO] Real feature shapes: {real_features_xy.shape}, {real_features_yz.shape}, {real_features_zx.shape}") + logger.info( + f"Real feature shapes: {real_features_xy.shape}, " + f"{real_features_yz.shape}, {real_features_zx.shape}" + ) # ------------------------------------------------------------------------- # Extract features for dataset 1 @@ -566,7 +554,7 @@ def main( for idx, batch_data in enumerate(synt_loader, start=1): img = batch_data["image"].to(device) fn = img.meta["filename_or_obj"][0] - print(f"[Rank {local_rank}] Synthetic data {idx}/{len(filenames1)}: {fn}") + logger.info(f"[Rank {local_rank}] Synthetic data {idx}/{len(filenames1)}: {fn}") out_fp = fn.replace(data1_dataroot, output_root1).replace(".nii.gz", ".pt") out_fp = Path(out_fp) @@ -576,7 +564,7 @@ def main( feats = torch.load(out_fp) else: img_t = img.as_tensor() - print(f"[INFO] image shape: {tuple(img_t.shape)}") + logger.info(f"image shape: {tuple(img_t.shape)}") feats = get_features_2p5d( img_t, @@ -585,7 +573,10 @@ def main( center_slices_ratio=center_slices_ratio_final, xy_only=False, ) - print(f"[INFO] feats shapes: {feats[0].shape}, {feats[1].shape}, {feats[2].shape}") + logger.info( + f"feats shapes: {feats[0].shape}, " + f"{feats[1].shape}, {feats[2].shape}" + ) torch.save(feats, out_fp) synth_features_xy.append(feats[0]) @@ -595,8 +586,8 @@ def main( synth_features_xy = torch.vstack(synth_features_xy) synth_features_yz = torch.vstack(synth_features_yz) synth_features_zx = torch.vstack(synth_features_zx) - print( - f"[INFO] Synthetic feature shapes: {synth_features_xy.shape}, " + logger.info( + f"Synthetic feature shapes: {synth_features_xy.shape}, " f"{synth_features_yz.shape}, {synth_features_zx.shape}" ) @@ -649,25 +640,27 @@ def main( synth_yz = torch.vstack(all_tensors_list[4]) synth_zx = torch.vstack(all_tensors_list[5]) - print(f"[INFO] Final Real shapes: {real_xy.shape}, {real_yz.shape}, {real_zx.shape}") - print(f"[INFO] Final Synth shapes: {synth_xy.shape}, {synth_yz.shape}, {synth_zx.shape}") + logger.info( + f"Final Real shapes: {real_xy.shape}, {real_yz.shape}, {real_zx.shape}" + ) + logger.info( + f"Final Synth shapes: {synth_xy.shape}, {synth_yz.shape}, {synth_zx.shape}" + ) fid = FIDMetric() - print(f"\n[INFO] Computing FID for: {output_root0} | {output_root1}") + logger.info(f"Computing FID for: {output_root0} | {output_root1}") fid_res_xy = fid(synth_xy, real_xy) fid_res_yz = fid(synth_yz, real_yz) fid_res_zx = fid(synth_zx, real_zx) - print(f" FID XY: {fid_res_xy}") - print(f" FID YZ: {fid_res_yz}") - print(f" FID ZX: {fid_res_zx}") + logger.info(f"FID XY: {fid_res_xy}") + logger.info(f"FID YZ: {fid_res_yz}") + logger.info(f"FID ZX: {fid_res_zx}") fid_avg = (fid_res_xy + fid_res_yz + fid_res_zx) / 3.0 - print(f" FID Avg: {fid_avg}") + logger.info(f"FID Avg: {fid_avg}") dist.destroy_process_group() if __name__ == "__main__": - # Using python-fire for command-line interface. - # e.g., python compute_fid2d_mgpu.py --model_name=radimagenet_resnet50 --num_images=100 ... fire.Fire(main) From ebdce1f309f7d1219c889a494a6833b16de61620 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 12 Mar 2025 13:23:59 +0000 Subject: [PATCH 06/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../maisi/scripts/compute_fid2-5d_ct.py | 25 ++++++------------- 1 file changed, 7 insertions(+), 18 deletions(-) diff --git a/generation/maisi/scripts/compute_fid2-5d_ct.py b/generation/maisi/scripts/compute_fid2-5d_ct.py index 1d9b30823..f684cabe5 100644 --- a/generation/maisi/scripts/compute_fid2-5d_ct.py +++ b/generation/maisi/scripts/compute_fid2-5d_ct.py @@ -419,6 +419,7 @@ def main( suffix = "radimagenet_resnet50" else: import torchvision + feature_network = torchvision.models.squeezenet1_1(pretrained=True) suffix = "squeezenet1_1" @@ -529,10 +530,7 @@ def main( center_slices_ratio=center_slices_ratio_final, xy_only=False, ) - logger.info( - f"feats shapes: {feats[0].shape}, " - f"{feats[1].shape}, {feats[2].shape}" - ) + logger.info(f"feats shapes: {feats[0].shape}, " f"{feats[1].shape}, {feats[2].shape}") torch.save(feats, out_fp) real_features_xy.append(feats[0]) @@ -543,8 +541,7 @@ def main( real_features_yz = torch.vstack(real_features_yz) real_features_zx = torch.vstack(real_features_zx) logger.info( - f"Real feature shapes: {real_features_xy.shape}, " - f"{real_features_yz.shape}, {real_features_zx.shape}" + f"Real feature shapes: {real_features_xy.shape}, " f"{real_features_yz.shape}, {real_features_zx.shape}" ) # ------------------------------------------------------------------------- @@ -573,10 +570,7 @@ def main( center_slices_ratio=center_slices_ratio_final, xy_only=False, ) - logger.info( - f"feats shapes: {feats[0].shape}, " - f"{feats[1].shape}, {feats[2].shape}" - ) + logger.info(f"feats shapes: {feats[0].shape}, " f"{feats[1].shape}, {feats[2].shape}") torch.save(feats, out_fp) synth_features_xy.append(feats[0]) @@ -587,8 +581,7 @@ def main( synth_features_yz = torch.vstack(synth_features_yz) synth_features_zx = torch.vstack(synth_features_zx) logger.info( - f"Synthetic feature shapes: {synth_features_xy.shape}, " - f"{synth_features_yz.shape}, {synth_features_zx.shape}" + f"Synthetic feature shapes: {synth_features_xy.shape}, " f"{synth_features_yz.shape}, {synth_features_zx.shape}" ) # ------------------------------------------------------------------------- @@ -640,12 +633,8 @@ def main( synth_yz = torch.vstack(all_tensors_list[4]) synth_zx = torch.vstack(all_tensors_list[5]) - logger.info( - f"Final Real shapes: {real_xy.shape}, {real_yz.shape}, {real_zx.shape}" - ) - logger.info( - f"Final Synth shapes: {synth_xy.shape}, {synth_yz.shape}, {synth_zx.shape}" - ) + logger.info(f"Final Real shapes: {real_xy.shape}, {real_yz.shape}, {real_zx.shape}") + logger.info(f"Final Synth shapes: {synth_xy.shape}, {synth_yz.shape}, {synth_zx.shape}") fid = FIDMetric() logger.info(f"Computing FID for: {output_root0} | {output_root1}") From 3764b4b5eedec2c96cfcc99aa26556e5f11c8a76 Mon Sep 17 00:00:00 2001 From: dongyang0122 Date: Wed, 12 Mar 2025 07:26:56 -0600 Subject: [PATCH 07/10] rename Signed-off-by: dongyang0122 --- .../scripts/{compute_fid2-5d_ct.py => compute_fid_2-5d_ct.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename generation/maisi/scripts/{compute_fid2-5d_ct.py => compute_fid_2-5d_ct.py} (100%) diff --git a/generation/maisi/scripts/compute_fid2-5d_ct.py b/generation/maisi/scripts/compute_fid_2-5d_ct.py similarity index 100% rename from generation/maisi/scripts/compute_fid2-5d_ct.py rename to generation/maisi/scripts/compute_fid_2-5d_ct.py From be96180acbdeadcdd699d65f918214c867c29450 Mon Sep 17 00:00:00 2001 From: dongyang0122 Date: Thu, 13 Mar 2025 10:01:05 -0600 Subject: [PATCH 08/10] update README and python script Signed-off-by: dongyang0122 --- generation/maisi/README.md | 56 ++- .../maisi/scripts/compute_fid_2-5d_ct.py | 322 ++++++++++++------ 2 files changed, 272 insertions(+), 106 deletions(-) diff --git a/generation/maisi/README.md b/generation/maisi/README.md index f51cdf61f..038378ff0 100644 --- a/generation/maisi/README.md +++ b/generation/maisi/README.md @@ -247,13 +247,65 @@ torchrun \ ``` Please also check [maisi_train_controlnet_tutorial.ipynb](./maisi_train_controlnet_tutorial.ipynb) for more details about data preparation and training parameters. -### 4. License +### 4. FID Score Computation + +We provide the `compute_fid_2-5d_ct.py` script that calculates the Frechet Inception Distance (FID) between two 3D medical datasets (e.g., **real** vs. **synthetic** images). It uses a **2.5D** feature-extraction approach across three orthogonal planes (XY, YZ, ZX) and leverages **distributed GPU processing** (via PyTorch’s `torch.distributed` and NCCL) for efficient, large-scale computations. + +#### Key Features + +- **Distributed Processing** + Scales to multiple GPUs and larger datasets by splitting the workload across devices. + +- **2.5D Feature Extraction** + Uses a slice-based technique, applying a 2D model across all slices in each dimension. + +- **Flexible Preprocessing** + Supports optional center-cropping, padding, and resampling to target shapes or voxel spacings. + +#### Usage Example + +Suppose your **real** dataset root is `path/to/real_images`, and you have a `real_filelist.txt` that lists filenames line by line, such as: + +case001.nii.gz +case002.nii.gz +case003.nii.gz + +You also have a **synthetic** dataset in `path/to/synth_images` with a corresponding `synth_filelist.txt`. You can run the script as follows: + +```bash +torchrun --nproc_per_node=2 compute_fid_2-5d_ct.py \ + --model_name "radimagenet_resnet50" \ + --real_dataset_root "path/to/real_images" \ + --real_filelist "path/to/real_filelist.txt" \ + --real_features_dir "datasetA" \ + --synth_dataset_root "path/to/synth_images" \ + --synth_filelist "path/to/synth_filelist.txt" \ + --synth_features_dir "datasetB" \ + --enable_center_slices_ratio 0.4 \ + --enable_padding True \ + --enable_center_cropping True \ + --enable_resampling_spacing "1.0x1.0x1.0" \ + --ignore_existing True \ + --num_images 100 \ + --output_root "./features/features-512x512x512" \ + --target_shape "512x512x512" +``` + +This command will: +1. Launch a distributed run with 2 GPUs. +2. Load each `.nii.gz` file from your specified `real_filelist` and `synth_filelist`. +3. Apply 2.5D feature extraction across the XY, YZ, and ZX planes. +4. Compute FID to compare **real** vs. **synthetic** feature distributions. + +For more details, see the in-code docstring in [`compute_fid_2-5d_ct.py`](./scripts/compute_fid_2-5d_ct.py) or consult our documentation for a deeper dive into function arguments and the underlying implementation. + +### 5. License The code is released under Apache 2.0 License. The model weight is released under [NSCLv1 License](./LICENSE.weights). -### 5. Questions and Bugs +### 6. Questions and Bugs - For questions relating to the use of MONAI, please use our [Discussions tab](https://github.com/Project-MONAI/MONAI/discussions) on the main repository of MONAI. - For bugs relating to MONAI functionality, please create an issue on the [main repository](https://github.com/Project-MONAI/MONAI/issues). diff --git a/generation/maisi/scripts/compute_fid_2-5d_ct.py b/generation/maisi/scripts/compute_fid_2-5d_ct.py index f684cabe5..ababf63d1 100644 --- a/generation/maisi/scripts/compute_fid_2-5d_ct.py +++ b/generation/maisi/scripts/compute_fid_2-5d_ct.py @@ -23,48 +23,83 @@ torchrun --nproc_per_node=${NUM_GPUS} compute_fid_2-5d_ct.py \ --model_name "radimagenet_resnet50" \ - --data0_dataroot "path/to/datasetA" \ - --data0_filelist "path/to/filelistA.txt" \ - --data0_folder "datasetA" \ - --data1_dataroot "path/to/datasetB" \ - --data1_filelist "path/to/filelistB.txt" \ - --data1_folder "datasetB" \ - --enable_center_slices False \ + --real_dataset_root "path/to/datasetA" \ + --real_filelist "path/to/filelistA.txt" \ + --real_features_dir "datasetA" \ + --synth_dataset_root "path/to/datasetB" \ + --synth_filelist "path/to/filelistB.txt" \ + --synth_features_dir "datasetB" \ + --enable_center_slices_ratio 0.4 \ --enable_padding True \ --enable_center_cropping True \ - --enable_resampling True \ + --enable_resampling_spacing "1.0x1.0x1.0" \ --ignore_existing True \ --num_images 100 \ --output_root "./features/features-512x512x512" \ --target_shape "512x512x512" -This script loads two datasets (real vs. synthetic) in 3D medical format (NIfTI) and extracts -feature maps via a 2.5D approach. It then computes the Frechet Inception Distance (FID) -across three orthogonal planes. Data parallelism is implemented using -torch.distributed with an NCCL backend. +This script loads two datasets (real vs. synthetic) in 3D medical format (NIfTI) +and extracts feature maps via a 2.5D approach. It then computes the Frechet +Inception Distance (FID) across three orthogonal planes. Data parallelism +is implemented using torch.distributed with an NCCL backend. Function Arguments (main): -------------------------- - data0_dataroot (str): Root folder for dataset 0 (real images). - data0_filelist (str): Text file listing 3D images for dataset 0. - data0_folder (str): Name for dataset 0 output folder under `output_root`. - data1_dataroot (str): Root folder for dataset 1 (synthetic images). - data1_filelist (str): Text file listing 3D images for dataset 1. - data1_folder (str): Name for dataset 1 output folder under `output_root`. - enable_center_slices (bool): Whether to slice around the center region of each axis. - enable_center_slices_ratio (float): Ratio of slices to take from the center if `enable_center_slices` is True. - enable_padding (bool): If True, pad images to `target_shape` before feature extraction. - enable_center_cropping (bool): If True, center-crop images to `target_shape`. - enable_resampling (bool): If True, resample images to `enable_resampling_spacing`. - enable_resampling_spacing (str): Target voxel spacing, e.g. "1.0x1.0x1.0". - ignore_existing (bool): If True, ignore existing .pt feature files and recompute. - model_name (str): Model identifier. Uses "radimagenet_resnet50" or a default squeezenet1_1. - num_images (int): Max number of images from each dataset to process (truncate if larger). - output_root (str): Folder where extracted feature .pt files and logs are saved. - target_shape (str): "XxYxZ" shape to which images are padded/cropped/resampled. + real_dataset_root (str): + Root folder for the real dataset. + real_filelist (str): + Text file listing 3D images for the real dataset. + + real_features_dir (str): + Subdirectory (under `output_root`) in which to store feature files + extracted from the real dataset. + + synth_dataset_root (str): + Root folder for the synthetic dataset. + + synth_filelist (str): + Text file listing 3D images for the synthetic dataset. + + synth_features_dir (str): + Subdirectory (under `output_root`) in which to store feature files + extracted from the synthetic dataset. + + enable_center_slices_ratio (float or None): + - If not None, only slices around the specified center ratio will be used + (analogous to "enable_center_slices=True" with that ratio). + - If None, no center-slice selection is performed + (analogous to "enable_center_slices=False"). + + enable_padding (bool): + Whether to pad images to `target_shape`. + + enable_center_cropping (bool): + Whether to center-crop images to `target_shape`. + + enable_resampling_spacing (str or None): + - If not None, resample images to the specified voxel spacing (e.g. "1.0x1.0x1.0") + (analogous to "enable_resampling=True" with that spacing). + - If None, resampling is skipped + (analogous to "enable_resampling=False"). + + ignore_existing (bool): + If True, ignore any existing .pt feature files and force re-extraction. + + model_name (str): + Model identifier. Typically "radimagenet_resnet50" or "squeezenet1_1". + + num_images (int): + Max number of images to process from each dataset (truncate if more are present). + + output_root (str): + Folder where extracted .pt feature files, logs, and results are saved. + + target_shape (str): + Target shape as "XxYxZ" for padding, cropping, or resampling operations. """ + from __future__ import annotations import os @@ -334,18 +369,16 @@ def pad_to_max_size(tensor: torch.Tensor, max_size: int, padding_value: float = def main( - data0_dataroot: str = "path/to/datasetA", - data0_filelist: str = "path/to/filelistA.txt", - data0_folder: str = "datasetA", - data1_dataroot: str = "path/to/datasetB", - data1_filelist: str = "path/to/filelistB.txt", - data1_folder: str = "datasetB", - enable_center_slices: bool = False, - enable_center_slices_ratio: float = 0.4, + real_dataset_root: str = "path/to/datasetA", + real_filelist: str = "path/to/filelistA.txt", + real_features_dir: str = "datasetA", + synth_dataset_root: str = "path/to/datasetB", + synth_filelist: str = "path/to/filelistB.txt", + synth_features_dir: str = "datasetB", + enable_center_slices_ratio: float = None, enable_padding: bool = True, enable_center_cropping: bool = True, - enable_resampling: bool = False, - enable_resampling_spacing: str = "1.0x1.0x1.0", + enable_resampling_spacing: str = None, ignore_existing: bool = False, model_name: str = "radimagenet_resnet50", num_images: int = 100, @@ -353,27 +386,74 @@ def main( target_shape: str = "512x512x512", ): """ - Main function to compute 2.5D features for two datasets (e.g. real vs. synthetic) - and calculate FID across XY, YZ, ZX planes. + Compute 2.5D FID using distributed GPU processing. + + This function loads two datasets (real vs. synthetic) in 3D medical format (NIfTI) + and extracts feature maps via a 2.5D approach, then computes the Frechet Inception + Distance (FID) across three orthogonal planes. Data parallelism is implemented + using torch.distributed with an NCCL backend. Args: - data0_dataroot (str): Root path of dataset 0 (real images). - data0_filelist (str): Text file listing the 3D images in dataset 0. - data0_folder (str): Name (subdir) for dataset 0 outputs under `output_root`. - data1_dataroot (str): Root path of dataset 1 (synthetic images). - data1_filelist (str): Text file listing the 3D images in dataset 1. - data1_folder (str): Name (subdir) for dataset 1 outputs under `output_root`. - enable_center_slices (bool): If True, only slices around the center ratio are used. - enable_center_slices_ratio (float): Ratio of center slices to keep if `enable_center_slices` is True. - enable_padding (bool): Whether to pad images to `target_shape`. - enable_center_cropping (bool): Whether to center-crop images to `target_shape`. - enable_resampling (bool): Whether to resample images to `enable_resampling_spacing`. - enable_resampling_spacing (str): Target voxel spacing as "XxYxZ", e.g. "1.0x1.0x1.0". - ignore_existing (bool): If True, re-extract features even if .pt files exist. - model_name (str): Model to use for feature extraction (radimagenet_resnet50 or squeezenet1_1). - num_images (int): Number of images to use from each dataset (truncate if exceeding). - output_root (str): Output directory to store extracted features and final results. - target_shape (str): Desired shape, as "XxYxZ", for padding/cropping/resampling. + real_dataset_root (str): + Root folder for the real dataset. + real_filelist (str): + Path to a text file listing 3D images (e.g., NIfTI files) for the real dataset. + Each line in this file should contain a relative path (or filename) to a NIfTI file. + For example, your "real_filelist.txt" could look like: + case001.nii.gz + case002.nii.gz + case003.nii.gz + ... + These entries will be appended to `real_dataset_root`. + real_features_dir (str): + Name of the directory under `output_root` in which to store + extracted features for the real dataset. + + synth_dataset_root (str): + Root folder for the synthetic dataset. + synth_filelist (str): + Path to a text file listing 3D images (e.g., NIfTI files) for the synthetic dataset. + The format is the same as the real dataset file list, for example: + synth_case001.nii.gz + synth_case002.nii.gz + synth_case003.nii.gz + ... + These entries will be appended to `synth_dataset_root`. + synth_features_dir (str): + Name of the directory under `output_root` in which to store + extracted features for the synthetic dataset. + + enable_center_slices_ratio (float or None): + - If not None, only slices around the specified center ratio are used. + (similar to "enable_center_slices=True" with that ratio in an earlier script). + - If None, no center-slice selection is performed + (similar to "enable_center_slices=False"). + + enable_padding (bool): + Whether to pad images to `target_shape`. + + enable_center_cropping (bool): + Whether to center-crop images to `target_shape`. + + enable_resampling_spacing (str or None): + - If not None, resample images to this voxel spacing (e.g. "1.0x1.0x1.0") + (similar to "enable_resampling=True" with that spacing). + - If None, skip resampling (similar to "enable_resampling=False"). + + ignore_existing (bool): + If True, ignore any existing .pt feature files and force re-computation. + + model_name (str): + Model identifier. Typically "radimagenet_resnet50" or "squeezenet1_1". + + num_images (int): + Maximum number of images to load from each dataset (truncate if more are present). + + output_root (str): + Parent folder where extracted .pt files and logs will be saved. + + target_shape (str): + Target shape, e.g. "512x512x512", for padding, cropping, or resampling operations. Returns: None @@ -389,23 +469,29 @@ def main( torch.cuda.set_device(device) logger.info(f"[INFO] Running process on {device} of total {world_size} ranks.") - # Convert potential string bools to actual bools (Fire sometimes passes strings) - if not isinstance(enable_center_slices, bool): - enable_center_slices = enable_center_slices.lower() == "true" + # Convert potential string bools to actual bools (if using Fire or similar) if not isinstance(enable_padding, bool): enable_padding = enable_padding.lower() == "true" if not isinstance(enable_center_cropping, bool): enable_center_cropping = enable_center_cropping.lower() == "true" - if not isinstance(enable_resampling, bool): - enable_resampling = enable_resampling.lower() == "true" if not isinstance(ignore_existing, bool): ignore_existing = ignore_existing.lower() == "true" + # Merge logic for center slices + enable_center_slices = enable_center_slices_ratio is not None + + # Merge logic for resampling + enable_resampling = enable_resampling_spacing is not None + # Print out some flags on rank 0 if local_rank == 0: + logger.info(f"Real dataset root: {real_dataset_root}") + logger.info(f"Synth dataset root: {synth_dataset_root}") + logger.info(f"enable_center_slices_ratio: {enable_center_slices_ratio}") logger.info(f"enable_center_slices: {enable_center_slices}") logger.info(f"enable_padding: {enable_padding}") logger.info(f"enable_center_cropping: {enable_center_cropping}") + logger.info(f"enable_resampling_spacing: {enable_resampling_spacing}") logger.info(f"enable_resampling: {enable_resampling}") logger.info(f"ignore_existing: {ignore_existing}") @@ -414,12 +500,14 @@ def main( # ------------------------------------------------------------------------- if model_name == "radimagenet_resnet50": feature_network = torch.hub.load( - "Warvito/radimagenet-models", model="radimagenet_resnet50", verbose=True, trust_repo=True + "Warvito/radimagenet-models", + model="radimagenet_resnet50", + verbose=True, + trust_repo=True ) suffix = "radimagenet_resnet50" else: import torchvision - feature_network = torchvision.models.squeezenet1_1(pretrained=True) suffix = "squeezenet1_1" @@ -431,44 +519,53 @@ def main( # ------------------------------------------------------------------------- t_shape = [int(x) for x in target_shape.split("x")] target_shape_tuple = tuple(t_shape) + + # If not None, parse the resampling spacing if enable_resampling: rs_spacing = [float(x) for x in enable_resampling_spacing.split("x")] rs_spacing_tuple = tuple(rs_spacing) if local_rank == 0: - logger.info(f"resampling spacing: {rs_spacing_tuple}") + logger.info(f"Resampling spacing: {rs_spacing_tuple}") else: rs_spacing_tuple = (1.0, 1.0, 1.0) + # Use the ratio if provided, otherwise 1.0 center_slices_ratio_final = enable_center_slices_ratio if enable_center_slices else 1.0 if local_rank == 0: logger.info(f"center_slices_ratio: {center_slices_ratio_final}") # ------------------------------------------------------------------------- - # Prepare dataset 0 + # Prepare Real Dataset # ------------------------------------------------------------------------- - output_root0 = os.path.join(output_root, data0_folder) - with open(data0_filelist, "r") as f0: - lines0 = [l.strip() for l in f0.readlines()] - lines0.sort() - lines0 = lines0[:num_images] - - filenames0 = [{"image": os.path.join(data0_dataroot, f)} for f in lines0] - filenames0 = monai.data.partition_dataset( - data=filenames0, shuffle=False, num_partitions=world_size, even_divisible=False + output_root_real = os.path.join(output_root, real_features_dir) + with open(real_filelist, "r") as rf: + real_lines = [l.strip() for l in rf.readlines()] + real_lines.sort() + real_lines = real_lines[:num_images] + + real_filenames = [{"image": os.path.join(real_dataset_root, f)} for f in real_lines] + real_filenames = monai.data.partition_dataset( + data=real_filenames, + shuffle=False, + num_partitions=world_size, + even_divisible=False )[local_rank] # ------------------------------------------------------------------------- - # Prepare dataset 1 + # Prepare Synthetic Dataset # ------------------------------------------------------------------------- - output_root1 = os.path.join(output_root, data1_folder) - with open(data1_filelist, "r") as f1: - lines1 = [l.strip() for l in f1.readlines()] - lines1.sort() - lines1 = lines1[:num_images] - - filenames1 = [{"image": os.path.join(data1_dataroot, f)} for f in lines1] - filenames1 = monai.data.partition_dataset( - data=filenames1, shuffle=False, num_partitions=world_size, even_divisible=False + output_root_synth = os.path.join(output_root, synth_features_dir) + with open(synth_filelist, "r") as sf: + synth_lines = [l.strip() for l in sf.readlines()] + synth_lines.sort() + synth_lines = synth_lines[:num_images] + + synth_filenames = [{"image": os.path.join(synth_dataset_root, f)} for f in synth_lines] + synth_filenames = monai.data.partition_dataset( + data=synth_filenames, + shuffle=False, + num_partitions=world_size, + even_divisible=False )[local_rank] # ------------------------------------------------------------------------- @@ -479,14 +576,25 @@ def main( monai.transforms.EnsureChannelFirstd(keys=["image"]), monai.transforms.Orientationd(keys=["image"], axcodes="RAS"), ] + if enable_resampling: - transform_list.append(monai.transforms.Spacingd(keys=["image"], pixdim=rs_spacing_tuple, mode=["bilinear"])) + transform_list.append( + monai.transforms.Spacingd( + keys=["image"], pixdim=rs_spacing_tuple, mode=["bilinear"] + ) + ) + if enable_padding: transform_list.append( - monai.transforms.SpatialPadd(keys=["image"], spatial_size=target_shape_tuple, mode="constant", value=-1000) + monai.transforms.SpatialPadd( + keys=["image"], spatial_size=target_shape_tuple, mode="constant", value=-1000 + ) ) + if enable_center_cropping: - transform_list.append(monai.transforms.CenterSpatialCropd(keys=["image"], roi_size=target_shape_tuple)) + transform_list.append( + monai.transforms.CenterSpatialCropd(keys=["image"], roi_size=target_shape_tuple) + ) transform_list.append( monai.transforms.ScaleIntensityRanged( @@ -498,22 +606,22 @@ def main( # ------------------------------------------------------------------------- # Create DataLoaders # ------------------------------------------------------------------------- - real_ds = monai.data.Dataset(data=filenames0, transform=transforms) + real_ds = monai.data.Dataset(data=real_filenames, transform=transforms) real_loader = monai.data.DataLoader(real_ds, num_workers=6, batch_size=1, shuffle=False) - synt_ds = monai.data.Dataset(data=filenames1, transform=transforms) - synt_loader = monai.data.DataLoader(synt_ds, num_workers=6, batch_size=1, shuffle=False) + synth_ds = monai.data.Dataset(data=synth_filenames, transform=transforms) + synth_loader = monai.data.DataLoader(synth_ds, num_workers=6, batch_size=1, shuffle=False) # ------------------------------------------------------------------------- - # Extract features for dataset 0 + # Extract features for Real Dataset # ------------------------------------------------------------------------- real_features_xy, real_features_yz, real_features_zx = [], [], [] for idx, batch_data in enumerate(real_loader, start=1): img = batch_data["image"].to(device) fn = img.meta["filename_or_obj"][0] - logger.info(f"[Rank {local_rank}] Real data {idx}/{len(filenames0)}: {fn}") + logger.info(f"[Rank {local_rank}] Real data {idx}/{len(real_filenames)}: {fn}") - out_fp = fn.replace(data0_dataroot, output_root0).replace(".nii.gz", ".pt") + out_fp = fn.replace(real_dataset_root, output_root_real).replace(".nii.gz", ".pt") out_fp = Path(out_fp) out_fp.parent.mkdir(parents=True, exist_ok=True) @@ -530,7 +638,9 @@ def main( center_slices_ratio=center_slices_ratio_final, xy_only=False, ) - logger.info(f"feats shapes: {feats[0].shape}, " f"{feats[1].shape}, {feats[2].shape}") + logger.info( + f"feats shapes: {feats[0].shape}, {feats[1].shape}, {feats[2].shape}" + ) torch.save(feats, out_fp) real_features_xy.append(feats[0]) @@ -541,19 +651,20 @@ def main( real_features_yz = torch.vstack(real_features_yz) real_features_zx = torch.vstack(real_features_zx) logger.info( - f"Real feature shapes: {real_features_xy.shape}, " f"{real_features_yz.shape}, {real_features_zx.shape}" + f"Real feature shapes: {real_features_xy.shape}, " + f"{real_features_yz.shape}, {real_features_zx.shape}" ) # ------------------------------------------------------------------------- - # Extract features for dataset 1 + # Extract features for Synthetic Dataset # ------------------------------------------------------------------------- synth_features_xy, synth_features_yz, synth_features_zx = [], [], [] - for idx, batch_data in enumerate(synt_loader, start=1): + for idx, batch_data in enumerate(synth_loader, start=1): img = batch_data["image"].to(device) fn = img.meta["filename_or_obj"][0] - logger.info(f"[Rank {local_rank}] Synthetic data {idx}/{len(filenames1)}: {fn}") + logger.info(f"[Rank {local_rank}] Synth data {idx}/{len(synth_filenames)}: {fn}") - out_fp = fn.replace(data1_dataroot, output_root1).replace(".nii.gz", ".pt") + out_fp = fn.replace(synth_dataset_root, output_root_synth).replace(".nii.gz", ".pt") out_fp = Path(out_fp) out_fp.parent.mkdir(parents=True, exist_ok=True) @@ -570,7 +681,9 @@ def main( center_slices_ratio=center_slices_ratio_final, xy_only=False, ) - logger.info(f"feats shapes: {feats[0].shape}, " f"{feats[1].shape}, {feats[2].shape}") + logger.info( + f"feats shapes: {feats[0].shape}, {feats[1].shape}, {feats[2].shape}" + ) torch.save(feats, out_fp) synth_features_xy.append(feats[0]) @@ -581,7 +694,8 @@ def main( synth_features_yz = torch.vstack(synth_features_yz) synth_features_zx = torch.vstack(synth_features_zx) logger.info( - f"Synthetic feature shapes: {synth_features_xy.shape}, " f"{synth_features_yz.shape}, {synth_features_zx.shape}" + f"Synth feature shapes: {synth_features_xy.shape}, " + f"{synth_features_yz.shape}, {synth_features_zx.shape}" ) # ------------------------------------------------------------------------- @@ -637,7 +751,7 @@ def main( logger.info(f"Final Synth shapes: {synth_xy.shape}, {synth_yz.shape}, {synth_zx.shape}") fid = FIDMetric() - logger.info(f"Computing FID for: {output_root0} | {output_root1}") + logger.info(f"Computing FID for: {output_root_real} | {output_root_synth}") fid_res_xy = fid(synth_xy, real_xy) fid_res_yz = fid(synth_yz, real_yz) fid_res_zx = fid(synth_zx, real_zx) From 3ae25bae2dbfa9c1f5ad22e9824ae8784b2e300b Mon Sep 17 00:00:00 2001 From: dongyang0122 Date: Thu, 13 Mar 2025 10:04:14 -0600 Subject: [PATCH 09/10] update readme Signed-off-by: dongyang0122 --- generation/maisi/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/generation/maisi/README.md b/generation/maisi/README.md index 038378ff0..398c44b20 100644 --- a/generation/maisi/README.md +++ b/generation/maisi/README.md @@ -265,11 +265,11 @@ We provide the `compute_fid_2-5d_ct.py` script that calculates the Frechet Incep #### Usage Example Suppose your **real** dataset root is `path/to/real_images`, and you have a `real_filelist.txt` that lists filenames line by line, such as: - +``` case001.nii.gz case002.nii.gz case003.nii.gz - +``` You also have a **synthetic** dataset in `path/to/synth_images` with a corresponding `synth_filelist.txt`. You can run the script as follows: ```bash From 2bff1d994f3423e2a577b13d30958e008b28f002 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 13 Mar 2025 16:04:25 +0000 Subject: [PATCH 10/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- generation/maisi/README.md | 6 +-- .../maisi/scripts/compute_fid_2-5d_ct.py | 50 ++++++------------- 2 files changed, 17 insertions(+), 39 deletions(-) diff --git a/generation/maisi/README.md b/generation/maisi/README.md index 398c44b20..727b1a266 100644 --- a/generation/maisi/README.md +++ b/generation/maisi/README.md @@ -253,13 +253,13 @@ We provide the `compute_fid_2-5d_ct.py` script that calculates the Frechet Incep #### Key Features -- **Distributed Processing** +- **Distributed Processing** Scales to multiple GPUs and larger datasets by splitting the workload across devices. -- **2.5D Feature Extraction** +- **2.5D Feature Extraction** Uses a slice-based technique, applying a 2D model across all slices in each dimension. -- **Flexible Preprocessing** +- **Flexible Preprocessing** Supports optional center-cropping, padding, and resampling to target shapes or voxel spacings. #### Usage Example diff --git a/generation/maisi/scripts/compute_fid_2-5d_ct.py b/generation/maisi/scripts/compute_fid_2-5d_ct.py index ababf63d1..123af7c10 100644 --- a/generation/maisi/scripts/compute_fid_2-5d_ct.py +++ b/generation/maisi/scripts/compute_fid_2-5d_ct.py @@ -390,7 +390,7 @@ def main( This function loads two datasets (real vs. synthetic) in 3D medical format (NIfTI) and extracts feature maps via a 2.5D approach, then computes the Frechet Inception - Distance (FID) across three orthogonal planes. Data parallelism is implemented + Distance (FID) across three orthogonal planes. Data parallelism is implemented using torch.distributed with an NCCL backend. Args: @@ -406,7 +406,7 @@ def main( ... These entries will be appended to `real_dataset_root`. real_features_dir (str): - Name of the directory under `output_root` in which to store + Name of the directory under `output_root` in which to store extracted features for the real dataset. synth_dataset_root (str): @@ -420,7 +420,7 @@ def main( ... These entries will be appended to `synth_dataset_root`. synth_features_dir (str): - Name of the directory under `output_root` in which to store + Name of the directory under `output_root` in which to store extracted features for the synthetic dataset. enable_center_slices_ratio (float or None): @@ -500,14 +500,12 @@ def main( # ------------------------------------------------------------------------- if model_name == "radimagenet_resnet50": feature_network = torch.hub.load( - "Warvito/radimagenet-models", - model="radimagenet_resnet50", - verbose=True, - trust_repo=True + "Warvito/radimagenet-models", model="radimagenet_resnet50", verbose=True, trust_repo=True ) suffix = "radimagenet_resnet50" else: import torchvision + feature_network = torchvision.models.squeezenet1_1(pretrained=True) suffix = "squeezenet1_1" @@ -545,10 +543,7 @@ def main( real_filenames = [{"image": os.path.join(real_dataset_root, f)} for f in real_lines] real_filenames = monai.data.partition_dataset( - data=real_filenames, - shuffle=False, - num_partitions=world_size, - even_divisible=False + data=real_filenames, shuffle=False, num_partitions=world_size, even_divisible=False )[local_rank] # ------------------------------------------------------------------------- @@ -562,10 +557,7 @@ def main( synth_filenames = [{"image": os.path.join(synth_dataset_root, f)} for f in synth_lines] synth_filenames = monai.data.partition_dataset( - data=synth_filenames, - shuffle=False, - num_partitions=world_size, - even_divisible=False + data=synth_filenames, shuffle=False, num_partitions=world_size, even_divisible=False )[local_rank] # ------------------------------------------------------------------------- @@ -578,23 +570,15 @@ def main( ] if enable_resampling: - transform_list.append( - monai.transforms.Spacingd( - keys=["image"], pixdim=rs_spacing_tuple, mode=["bilinear"] - ) - ) + transform_list.append(monai.transforms.Spacingd(keys=["image"], pixdim=rs_spacing_tuple, mode=["bilinear"])) if enable_padding: transform_list.append( - monai.transforms.SpatialPadd( - keys=["image"], spatial_size=target_shape_tuple, mode="constant", value=-1000 - ) + monai.transforms.SpatialPadd(keys=["image"], spatial_size=target_shape_tuple, mode="constant", value=-1000) ) if enable_center_cropping: - transform_list.append( - monai.transforms.CenterSpatialCropd(keys=["image"], roi_size=target_shape_tuple) - ) + transform_list.append(monai.transforms.CenterSpatialCropd(keys=["image"], roi_size=target_shape_tuple)) transform_list.append( monai.transforms.ScaleIntensityRanged( @@ -638,9 +622,7 @@ def main( center_slices_ratio=center_slices_ratio_final, xy_only=False, ) - logger.info( - f"feats shapes: {feats[0].shape}, {feats[1].shape}, {feats[2].shape}" - ) + logger.info(f"feats shapes: {feats[0].shape}, {feats[1].shape}, {feats[2].shape}") torch.save(feats, out_fp) real_features_xy.append(feats[0]) @@ -651,8 +633,7 @@ def main( real_features_yz = torch.vstack(real_features_yz) real_features_zx = torch.vstack(real_features_zx) logger.info( - f"Real feature shapes: {real_features_xy.shape}, " - f"{real_features_yz.shape}, {real_features_zx.shape}" + f"Real feature shapes: {real_features_xy.shape}, " f"{real_features_yz.shape}, {real_features_zx.shape}" ) # ------------------------------------------------------------------------- @@ -681,9 +662,7 @@ def main( center_slices_ratio=center_slices_ratio_final, xy_only=False, ) - logger.info( - f"feats shapes: {feats[0].shape}, {feats[1].shape}, {feats[2].shape}" - ) + logger.info(f"feats shapes: {feats[0].shape}, {feats[1].shape}, {feats[2].shape}") torch.save(feats, out_fp) synth_features_xy.append(feats[0]) @@ -694,8 +673,7 @@ def main( synth_features_yz = torch.vstack(synth_features_yz) synth_features_zx = torch.vstack(synth_features_zx) logger.info( - f"Synth feature shapes: {synth_features_xy.shape}, " - f"{synth_features_yz.shape}, {synth_features_zx.shape}" + f"Synth feature shapes: {synth_features_xy.shape}, " f"{synth_features_yz.shape}, {synth_features_zx.shape}" ) # -------------------------------------------------------------------------