diff --git a/generation/maisi/README.md b/generation/maisi/README.md index f51cdf61f..727b1a266 100644 --- a/generation/maisi/README.md +++ b/generation/maisi/README.md @@ -247,13 +247,65 @@ torchrun \ ``` Please also check [maisi_train_controlnet_tutorial.ipynb](./maisi_train_controlnet_tutorial.ipynb) for more details about data preparation and training parameters. -### 4. License +### 4. FID Score Computation + +We provide the `compute_fid_2-5d_ct.py` script that calculates the Frechet Inception Distance (FID) between two 3D medical datasets (e.g., **real** vs. **synthetic** images). It uses a **2.5D** feature-extraction approach across three orthogonal planes (XY, YZ, ZX) and leverages **distributed GPU processing** (via PyTorch’s `torch.distributed` and NCCL) for efficient, large-scale computations. + +#### Key Features + +- **Distributed Processing** + Scales to multiple GPUs and larger datasets by splitting the workload across devices. + +- **2.5D Feature Extraction** + Uses a slice-based technique, applying a 2D model across all slices in each dimension. + +- **Flexible Preprocessing** + Supports optional center-cropping, padding, and resampling to target shapes or voxel spacings. + +#### Usage Example + +Suppose your **real** dataset root is `path/to/real_images`, and you have a `real_filelist.txt` that lists filenames line by line, such as: +``` +case001.nii.gz +case002.nii.gz +case003.nii.gz +``` +You also have a **synthetic** dataset in `path/to/synth_images` with a corresponding `synth_filelist.txt`. You can run the script as follows: + +```bash +torchrun --nproc_per_node=2 compute_fid_2-5d_ct.py \ + --model_name "radimagenet_resnet50" \ + --real_dataset_root "path/to/real_images" \ + --real_filelist "path/to/real_filelist.txt" \ + --real_features_dir "datasetA" \ + --synth_dataset_root "path/to/synth_images" \ + --synth_filelist "path/to/synth_filelist.txt" \ + --synth_features_dir "datasetB" \ + --enable_center_slices_ratio 0.4 \ + --enable_padding True \ + --enable_center_cropping True \ + --enable_resampling_spacing "1.0x1.0x1.0" \ + --ignore_existing True \ + --num_images 100 \ + --output_root "./features/features-512x512x512" \ + --target_shape "512x512x512" +``` + +This command will: +1. Launch a distributed run with 2 GPUs. +2. Load each `.nii.gz` file from your specified `real_filelist` and `synth_filelist`. +3. Apply 2.5D feature extraction across the XY, YZ, and ZX planes. +4. Compute FID to compare **real** vs. **synthetic** feature distributions. + +For more details, see the in-code docstring in [`compute_fid_2-5d_ct.py`](./scripts/compute_fid_2-5d_ct.py) or consult our documentation for a deeper dive into function arguments and the underlying implementation. + +### 5. License The code is released under Apache 2.0 License. The model weight is released under [NSCLv1 License](./LICENSE.weights). -### 5. Questions and Bugs +### 6. Questions and Bugs - For questions relating to the use of MONAI, please use our [Discussions tab](https://github.com/Project-MONAI/MONAI/discussions) on the main repository of MONAI. - For bugs relating to MONAI functionality, please create an issue on the [main repository](https://github.com/Project-MONAI/MONAI/issues). 
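+As a follow-up to the FID section above: since per-case features are cached as `.pt` files (one `(XY, YZ, ZX)` tuple per case) under `--output_root`, you can re-compute FID offline without re-running feature extraction. Below is a minimal single-process sketch, assuming the directory layout produced by the command above; the `load_plane_features` helper is illustrative and not part of the script:
+
+```python
+from pathlib import Path
+
+import torch
+from monai.metrics.fid import FIDMetric
+
+
+def load_plane_features(feature_dir: str) -> list[torch.Tensor]:
+    """Stack per-case feature matrices into one (N, feat_dim) matrix per plane."""
+    planes = [[], [], []]
+    for pt_file in sorted(Path(feature_dir).rglob("*.pt")):
+        feats = torch.load(pt_file, map_location="cpu")  # (xy, yz, zx) tuple
+        for plane, f in zip(planes, feats):
+            plane.append(f)
+    return [torch.vstack(p) for p in planes]
+
+
+real = load_plane_features("./features/features-512x512x512/datasetA")
+synth = load_plane_features("./features/features-512x512x512/datasetB")
+
+fid = FIDMetric()
+per_plane = [fid(s, r) for s, r in zip(synth, real)]  # XY, YZ, ZX
+print("FID per plane:", per_plane, "average:", sum(per_plane) / 3.0)
+```
+
+This mirrors what rank 0 of the distributed script does after gathering features, so the offline result should match the logged values.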
diff --git a/generation/maisi/scripts/compute_fid_2-5d_ct.py b/generation/maisi/scripts/compute_fid_2-5d_ct.py new file mode 100644 index 000000000..123af7c10 --- /dev/null +++ b/generation/maisi/scripts/compute_fid_2-5d_ct.py @@ -0,0 +1,747 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, +# either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. + +""" +Compute 2.5D FID using distributed GPU processing. + +SHELL Usage Example: +------------------- + #!/bin/bash + + export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6 + NUM_GPUS=7 + + torchrun --nproc_per_node=${NUM_GPUS} compute_fid_2-5d_ct.py \ + --model_name "radimagenet_resnet50" \ + --real_dataset_root "path/to/datasetA" \ + --real_filelist "path/to/filelistA.txt" \ + --real_features_dir "datasetA" \ + --synth_dataset_root "path/to/datasetB" \ + --synth_filelist "path/to/filelistB.txt" \ + --synth_features_dir "datasetB" \ + --enable_center_slices_ratio 0.4 \ + --enable_padding True \ + --enable_center_cropping True \ + --enable_resampling_spacing "1.0x1.0x1.0" \ + --ignore_existing True \ + --num_images 100 \ + --output_root "./features/features-512x512x512" \ + --target_shape "512x512x512" + +This script loads two datasets (real vs. synthetic) in 3D medical format (NIfTI) +and extracts feature maps via a 2.5D approach. It then computes the Frechet +Inception Distance (FID) across three orthogonal planes. Data parallelism +is implemented using torch.distributed with an NCCL backend. + +Function Arguments (main): +-------------------------- + real_dataset_root (str): + Root folder for the real dataset. + + real_filelist (str): + Text file listing 3D images for the real dataset. + + real_features_dir (str): + Subdirectory (under `output_root`) in which to store feature files + extracted from the real dataset. + + synth_dataset_root (str): + Root folder for the synthetic dataset. + + synth_filelist (str): + Text file listing 3D images for the synthetic dataset. + + synth_features_dir (str): + Subdirectory (under `output_root`) in which to store feature files + extracted from the synthetic dataset. + + enable_center_slices_ratio (float or None): + - If not None, only slices around the specified center ratio will be used + (analogous to "enable_center_slices=True" with that ratio). + - If None, no center-slice selection is performed + (analogous to "enable_center_slices=False"). + + enable_padding (bool): + Whether to pad images to `target_shape`. + + enable_center_cropping (bool): + Whether to center-crop images to `target_shape`. + + enable_resampling_spacing (str or None): + - If not None, resample images to the specified voxel spacing (e.g. "1.0x1.0x1.0") + (analogous to "enable_resampling=True" with that spacing). + - If None, resampling is skipped + (analogous to "enable_resampling=False"). + + ignore_existing (bool): + If True, ignore any existing .pt feature files and force re-extraction. + + model_name (str): + Model identifier. Typically "radimagenet_resnet50" or "squeezenet1_1". 
+
+    num_images (int):
+        Max number of images to process from each dataset (truncated if more are present).
+
+    output_root (str):
+        Folder where extracted .pt feature files, logs, and results are saved.
+
+    target_shape (str):
+        Target shape as "XxYxZ" for padding, cropping, or resampling operations.
+"""
+
+
+from __future__ import annotations
+
+import os
+import sys
+import torch
+import fire
+import monai
+import torch.distributed as dist
+import torch.nn.functional as F
+
+from datetime import timedelta
+from pathlib import Path
+from monai.metrics.fid import FIDMetric
+from monai.transforms import Compose
+
+import logging
+
+# ------------------------------------------------------------------------------
+# Create logger
+# ------------------------------------------------------------------------------
+logger = logging.getLogger("fid_2-5d_ct")
+if not logger.handlers:
+    # Configure logger only if it has no handlers (avoid reconfiguring in multi-rank scenarios)
+    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
+logger.setLevel(logging.INFO)
+
+
+def drop_empty_slice(slices, empty_threshold: float):
+    """
+    Decide which 2D slices to keep: a slice whose maximum intensity falls
+    below `empty_threshold` is considered empty and dropped.
+
+    Args:
+        slices (tuple or list of Tensors): Each element is (B, C, H, W).
+        empty_threshold (float): If the slice's maximum value is below this threshold,
+            it is considered "empty".
+
+    Returns:
+        list[bool]: A list of booleans indicating for each slice whether to keep it.
+    """
+    outputs = []
+    n_drop = 0
+    for s in slices:
+        slice_max = torch.max(s)
+        if slice_max < empty_threshold:
+            outputs.append(False)
+            n_drop += 1
+        else:
+            outputs.append(True)
+
+    logger.info(f"Empty slice drop rate {round((n_drop / len(slices)) * 100, 1)}%")
+    return outputs
+
+
+def subtract_mean(x: torch.Tensor) -> torch.Tensor:
+    """
+    Subtract per-channel means (ImageNet-like: [0.406, 0.456, 0.485])
+    from the input 4D or 5D tensor, in place. Expects channels in the first
+    dimension after the batch dimension: (B, C, H, W) or (B, C, H, W, D).
+    """
+    mean = [0.406, 0.456, 0.485]
+    x[:, 0, ...] -= mean[0]
+    x[:, 1, ...] -= mean[1]
+    x[:, 2, ...] -= mean[2]
+    return x
+
+
+def spatial_average(x: torch.Tensor, keepdim: bool = True) -> torch.Tensor:
+    """
+    Average out the spatial dimensions of a tensor, preserving or removing them
+    according to `keepdim`. This is used to produce a 1D feature vector
+    out of a feature map.
+
+    Args:
+        x (torch.Tensor): Input tensor (B, C, H, W, ...) or (B, C, H, W).
+        keepdim (bool): Whether to keep the reduced spatial dimensions (as size 1) after averaging.
+
+    Returns:
+        torch.Tensor: Tensor with reduced spatial dimensions.
+    """
+    dim = len(x.shape)
+    # 2D -> no average
+    if dim == 2:
+        return x
+    # 3D -> average over last dim
+    if dim == 3:
+        return x.mean([2], keepdim=keepdim)
+    # 4D -> average over H,W
+    if dim == 4:
+        return x.mean([2, 3], keepdim=keepdim)
+    # 5D -> average over H,W,D
+    if dim == 5:
+        return x.mean([2, 3, 4], keepdim=keepdim)
+    return x
+
+
+def medicalnet_intensity_normalisation(volume: torch.Tensor) -> torch.Tensor:
+    """
+    Intensity normalization approach from MedicalNet:
+    (volume - mean) / (std + 1e-5) across spatial dims.
+    Expects (B, C, H, W) or (B, C, H, W, D).
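+
+    Illustrative example (random input; the output has approximately zero mean
+    and unit variance):
+
+        >>> v = torch.randn(1, 1, 8, 8, 8)
+        >>> out = medicalnet_intensity_normalisation(v)
+        >>> bool(out.mean().abs() < 1e-4)
+        True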
+ """ + dim = len(volume.shape) + if dim == 4: + mean = volume.mean([2, 3], keepdim=True) + std = volume.std([2, 3], keepdim=True) + elif dim == 5: + mean = volume.mean([2, 3, 4], keepdim=True) + std = volume.std([2, 3, 4], keepdim=True) + else: + return volume + return (volume - mean) / (std + 1e-5) + + +def radimagenet_intensity_normalisation(volume: torch.Tensor, norm2d: bool = False) -> torch.Tensor: + """ + Intensity normalization for radimagenet_resnet. Optionally normalizes each 2D slice individually. + + Args: + volume (torch.Tensor): Input (B, C, H, W) or (B, C, H, W, D). + norm2d (bool): If True, normalizes each (H,W) slice to [0,1], then subtracts the ImageNet mean. + """ + logger.info(f"norm2d: {norm2d}") + dim = len(volume.shape) + # If norm2d is True, only meaningful for 4D data (B, C, H, W): + if dim == 4 and norm2d: + max2d, _ = torch.max(volume, dim=2, keepdim=True) + max2d, _ = torch.max(max2d, dim=3, keepdim=True) + min2d, _ = torch.min(volume, dim=2, keepdim=True) + min2d, _ = torch.min(min2d, dim=3, keepdim=True) + # Scale each slice to 0..1 + volume = (volume - min2d) / (max2d - min2d + 1e-10) + # Subtract channel mean + return subtract_mean(volume) + elif dim == 4: + # 4D but no per-slice normalization + max3d = torch.max(volume) + min3d = torch.min(volume) + volume = (volume - min3d) / (max3d - min3d + 1e-10) + return subtract_mean(volume) + # Fallback for e.g. 5D data is simply a min-max over entire volume + if dim == 5: + maxval = torch.max(volume) + minval = torch.min(volume) + volume = (volume - minval) / (maxval - minval + 1e-10) + return subtract_mean(volume) + return volume + + +def get_features_2p5d( + image: torch.Tensor, + feature_network: torch.nn.Module, + center_slices: bool = False, + center_slices_ratio: float = 1.0, + sample_every_k: int = 1, + xy_only: bool = True, + drop_empty: bool = False, + empty_threshold: float = -700, +) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]: + """ + Extract 2.5D features from a 3D image by slicing it along XY, YZ, ZX planes. + + Args: + image (torch.Tensor): Input 5D tensor in shape (B, C, H, W, D). + feature_network (torch.nn.Module): Model that processes 2D slices (C,H,W). + center_slices (bool): Whether to slice only the center portion of each axis. + center_slices_ratio (float): Ratio of slices to keep in the center if `center_slices` is True. + sample_every_k (int): Downsampling factor along each axis when slicing. + xy_only (bool): If True, return only the XY-plane features. + drop_empty (bool): Drop slices that are deemed "empty" below `empty_threshold`. + empty_threshold (float): Threshold to decide emptiness of slices. + + Returns: + tuple of torch.Tensor or None: (XY_features, YZ_features, ZX_features). + """ + logger.info(f"center_slices: {center_slices}, ratio: {center_slices_ratio}") + + # If there's only 1 channel, replicate to 3 channels + if image.shape[1] == 1: + image = image.repeat(1, 3, 1, 1, 1) + + # Convert from 'RGB'→(R,G,B) to (B,G,R) + image = image[:, [2, 1, 0], ...] 
+
+    B, C, H, W, D = image.size()
+    with torch.no_grad():
+        # ---------------------- XY-plane slicing along D ----------------------
+        if center_slices:
+            start_d = int((1.0 - center_slices_ratio) / 2.0 * D)
+            end_d = int((1.0 + center_slices_ratio) / 2.0 * D)
+            slices = torch.unbind(image[:, :, :, :, start_d:end_d:sample_every_k], dim=-1)
+        else:
+            slices = torch.unbind(image, dim=-1)
+
+        if drop_empty:
+            mapping_index = drop_empty_slice(slices, empty_threshold)
+        else:
+            mapping_index = [True for _ in range(len(slices))]
+
+        images_2d = torch.cat(slices, dim=0)
+        images_2d = radimagenet_intensity_normalisation(images_2d)
+        images_2d = images_2d[mapping_index]
+
+        feature_image_xy = feature_network(images_2d)
+        feature_image_xy = spatial_average(feature_image_xy, keepdim=False)
+        if xy_only:
+            return feature_image_xy, None, None
+
+        # ---------------------- YZ-plane slicing along H ----------------------
+        if center_slices:
+            start_h = int((1.0 - center_slices_ratio) / 2.0 * H)
+            end_h = int((1.0 + center_slices_ratio) / 2.0 * H)
+            slices = torch.unbind(image[:, :, start_h:end_h:sample_every_k, :, :], dim=2)
+        else:
+            slices = torch.unbind(image, dim=2)
+
+        if drop_empty:
+            mapping_index = drop_empty_slice(slices, empty_threshold)
+        else:
+            mapping_index = [True for _ in range(len(slices))]
+
+        images_2d = torch.cat(slices, dim=0)
+        images_2d = radimagenet_intensity_normalisation(images_2d)
+        images_2d = images_2d[mapping_index]
+
+        feature_image_yz = feature_network(images_2d)
+        feature_image_yz = spatial_average(feature_image_yz, keepdim=False)
+
+        # ---------------------- ZX-plane slicing along W ----------------------
+        if center_slices:
+            start_w = int((1.0 - center_slices_ratio) / 2.0 * W)
+            end_w = int((1.0 + center_slices_ratio) / 2.0 * W)
+            slices = torch.unbind(image[:, :, :, start_w:end_w:sample_every_k, :], dim=3)
+        else:
+            slices = torch.unbind(image, dim=3)
+
+        if drop_empty:
+            mapping_index = drop_empty_slice(slices, empty_threshold)
+        else:
+            mapping_index = [True for _ in range(len(slices))]
+
+        images_2d = torch.cat(slices, dim=0)
+        images_2d = radimagenet_intensity_normalisation(images_2d)
+        images_2d = images_2d[mapping_index]
+
+        feature_image_zx = feature_network(images_2d)
+        feature_image_zx = spatial_average(feature_image_zx, keepdim=False)
+
+    return feature_image_xy, feature_image_yz, feature_image_zx
+
+
+def pad_to_max_size(tensor: torch.Tensor, max_size: int, padding_value: float = 0.0) -> torch.Tensor:
+    """
+    Pad a 2D feature map or other tensor with `padding_value` along the first
+    dimension to match a specified size.
+
+    Args:
+        tensor (torch.Tensor): The feature tensor to pad.
+        max_size (int): Desired size along the first dimension.
+        padding_value (float): Value to fill during padding.
+
+    Returns:
+        torch.Tensor: Padded tensor matching `max_size` along dim=0.
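+
+    Illustrative example:
+
+        >>> t = torch.ones(3, 5)
+        >>> pad_to_max_size(t, 4).shape
+        torch.Size([4, 5])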
+ """ + pad_size = [0, 0] * (len(tensor.shape) - 1) + [0, max_size - tensor.shape[0]] + return F.pad(tensor, pad_size, "constant", padding_value) + + +def main( + real_dataset_root: str = "path/to/datasetA", + real_filelist: str = "path/to/filelistA.txt", + real_features_dir: str = "datasetA", + synth_dataset_root: str = "path/to/datasetB", + synth_filelist: str = "path/to/filelistB.txt", + synth_features_dir: str = "datasetB", + enable_center_slices_ratio: float = None, + enable_padding: bool = True, + enable_center_cropping: bool = True, + enable_resampling_spacing: str = None, + ignore_existing: bool = False, + model_name: str = "radimagenet_resnet50", + num_images: int = 100, + output_root: str = "./features/features-512x512x512", + target_shape: str = "512x512x512", +): + """ + Compute 2.5D FID using distributed GPU processing. + + This function loads two datasets (real vs. synthetic) in 3D medical format (NIfTI) + and extracts feature maps via a 2.5D approach, then computes the Frechet Inception + Distance (FID) across three orthogonal planes. Data parallelism is implemented + using torch.distributed with an NCCL backend. + + Args: + real_dataset_root (str): + Root folder for the real dataset. + real_filelist (str): + Path to a text file listing 3D images (e.g., NIfTI files) for the real dataset. + Each line in this file should contain a relative path (or filename) to a NIfTI file. + For example, your "real_filelist.txt" could look like: + case001.nii.gz + case002.nii.gz + case003.nii.gz + ... + These entries will be appended to `real_dataset_root`. + real_features_dir (str): + Name of the directory under `output_root` in which to store + extracted features for the real dataset. + + synth_dataset_root (str): + Root folder for the synthetic dataset. + synth_filelist (str): + Path to a text file listing 3D images (e.g., NIfTI files) for the synthetic dataset. + The format is the same as the real dataset file list, for example: + synth_case001.nii.gz + synth_case002.nii.gz + synth_case003.nii.gz + ... + These entries will be appended to `synth_dataset_root`. + synth_features_dir (str): + Name of the directory under `output_root` in which to store + extracted features for the synthetic dataset. + + enable_center_slices_ratio (float or None): + - If not None, only slices around the specified center ratio are used. + (similar to "enable_center_slices=True" with that ratio in an earlier script). + - If None, no center-slice selection is performed + (similar to "enable_center_slices=False"). + + enable_padding (bool): + Whether to pad images to `target_shape`. + + enable_center_cropping (bool): + Whether to center-crop images to `target_shape`. + + enable_resampling_spacing (str or None): + - If not None, resample images to this voxel spacing (e.g. "1.0x1.0x1.0") + (similar to "enable_resampling=True" with that spacing). + - If None, skip resampling (similar to "enable_resampling=False"). + + ignore_existing (bool): + If True, ignore any existing .pt feature files and force re-computation. + + model_name (str): + Model identifier. Typically "radimagenet_resnet50" or "squeezenet1_1". + + num_images (int): + Maximum number of images to load from each dataset (truncate if more are present). + + output_root (str): + Parent folder where extracted .pt files and logs will be saved. + + target_shape (str): + Target shape, e.g. "512x512x512", for padding, cropping, or resampling operations. 
+
+    Returns:
+        None
+    """
+    # -------------------------------------------------------------------------
+    # Initialize Process Group (Distributed)
+    # -------------------------------------------------------------------------
+    dist.init_process_group(backend="nccl", init_method="env://", timeout=timedelta(seconds=7200))
+
+    local_rank = int(os.environ["LOCAL_RANK"])
+    world_size = dist.get_world_size()
+    device = torch.device("cuda", local_rank)
+    torch.cuda.set_device(device)
+    logger.info(f"Running process on {device} of {world_size} total ranks.")
+
+    # Convert potential string bools to actual bools (Fire may pass CLI flags through as strings)
+    if not isinstance(enable_padding, bool):
+        enable_padding = enable_padding.lower() == "true"
+    if not isinstance(enable_center_cropping, bool):
+        enable_center_cropping = enable_center_cropping.lower() == "true"
+    if not isinstance(ignore_existing, bool):
+        ignore_existing = ignore_existing.lower() == "true"
+
+    # Center-slice selection is active whenever a ratio is provided
+    enable_center_slices = enable_center_slices_ratio is not None
+
+    # Resampling is active whenever a spacing is provided
+    enable_resampling = enable_resampling_spacing is not None
+
+    # Print out some flags on rank 0
+    if local_rank == 0:
+        logger.info(f"Real dataset root: {real_dataset_root}")
+        logger.info(f"Synth dataset root: {synth_dataset_root}")
+        logger.info(f"enable_center_slices_ratio: {enable_center_slices_ratio}")
+        logger.info(f"enable_center_slices: {enable_center_slices}")
+        logger.info(f"enable_padding: {enable_padding}")
+        logger.info(f"enable_center_cropping: {enable_center_cropping}")
+        logger.info(f"enable_resampling_spacing: {enable_resampling_spacing}")
+        logger.info(f"enable_resampling: {enable_resampling}")
+        logger.info(f"ignore_existing: {ignore_existing}")
+
+    # -------------------------------------------------------------------------
+    # Load feature extraction model
+    # -------------------------------------------------------------------------
+    if model_name == "radimagenet_resnet50":
+        feature_network = torch.hub.load(
+            "Warvito/radimagenet-models", model="radimagenet_resnet50", verbose=True, trust_repo=True
+        )
+        suffix = "radimagenet_resnet50"
+    else:
+        import torchvision
+
+        # `pretrained=True` is deprecated in recent torchvision; request weights explicitly
+        feature_network = torchvision.models.squeezenet1_1(weights=torchvision.models.SqueezeNet1_1_Weights.DEFAULT)
+        suffix = "squeezenet1_1"
+
+    feature_network.to(device)
+    feature_network.eval()
+
+    # -------------------------------------------------------------------------
+    # Parse shape/spacings
+    # -------------------------------------------------------------------------
+    t_shape = [int(x) for x in target_shape.split("x")]
+    target_shape_tuple = tuple(t_shape)
+
+    # If not None, parse the resampling spacing
+    if enable_resampling:
+        rs_spacing = [float(x) for x in enable_resampling_spacing.split("x")]
+        rs_spacing_tuple = tuple(rs_spacing)
+        if local_rank == 0:
+            logger.info(f"Resampling spacing: {rs_spacing_tuple}")
+    else:
+        rs_spacing_tuple = (1.0, 1.0, 1.0)
+
+    # Use the ratio if provided, otherwise 1.0
+    center_slices_ratio_final = enable_center_slices_ratio if enable_center_slices else 1.0
+    if local_rank == 0:
+        logger.info(f"center_slices_ratio: {center_slices_ratio_final}")
+
+    # -------------------------------------------------------------------------
+    # Prepare Real Dataset
+    # -------------------------------------------------------------------------
+    output_root_real = os.path.join(output_root, real_features_dir)
+    with open(real_filelist, "r") as rf:
+        real_lines = [line.strip() for line in rf if line.strip()]
+    real_lines.sort()
+    real_lines = real_lines[:num_images]
+
+    real_filenames = [{"image": os.path.join(real_dataset_root, f)} for f in real_lines]
+    real_filenames = monai.data.partition_dataset(
+        data=real_filenames, shuffle=False, num_partitions=world_size, even_divisible=False
+    )[local_rank]
+
+    # -------------------------------------------------------------------------
+    # Prepare Synthetic Dataset
+    # -------------------------------------------------------------------------
+    output_root_synth = os.path.join(output_root, synth_features_dir)
+    with open(synth_filelist, "r") as sf:
+        synth_lines = [line.strip() for line in sf if line.strip()]
+    synth_lines.sort()
+    synth_lines = synth_lines[:num_images]
+
+    synth_filenames = [{"image": os.path.join(synth_dataset_root, f)} for f in synth_lines]
+    synth_filenames = monai.data.partition_dataset(
+        data=synth_filenames, shuffle=False, num_partitions=world_size, even_divisible=False
+    )[local_rank]
+
+    # -------------------------------------------------------------------------
+    # Build MONAI transforms
+    # -------------------------------------------------------------------------
+    transform_list = [
+        monai.transforms.LoadImaged(keys=["image"]),
+        monai.transforms.EnsureChannelFirstd(keys=["image"]),
+        monai.transforms.Orientationd(keys=["image"], axcodes="RAS"),
+    ]
+
+    if enable_resampling:
+        transform_list.append(monai.transforms.Spacingd(keys=["image"], pixdim=rs_spacing_tuple, mode=["bilinear"]))
+
+    if enable_padding:
+        transform_list.append(
+            monai.transforms.SpatialPadd(keys=["image"], spatial_size=target_shape_tuple, mode="constant", value=-1000)
+        )
+
+    if enable_center_cropping:
+        transform_list.append(monai.transforms.CenterSpatialCropd(keys=["image"], roi_size=target_shape_tuple))
+
+    transform_list.append(
+        monai.transforms.ScaleIntensityRanged(
+            keys=["image"], a_min=-1000, a_max=1000, b_min=-1000, b_max=1000, clip=True
+        )
+    )
+    transforms = Compose(transform_list)
+
+    # -------------------------------------------------------------------------
+    # Create DataLoaders
+    # -------------------------------------------------------------------------
+    real_ds = monai.data.Dataset(data=real_filenames, transform=transforms)
+    real_loader = monai.data.DataLoader(real_ds, num_workers=6, batch_size=1, shuffle=False)
+
+    synth_ds = monai.data.Dataset(data=synth_filenames, transform=transforms)
+    synth_loader = monai.data.DataLoader(synth_ds, num_workers=6, batch_size=1, shuffle=False)
+
+    # -------------------------------------------------------------------------
+    # Extract features for Real Dataset
+    # -------------------------------------------------------------------------
+    real_features_xy, real_features_yz, real_features_zx = [], [], []
+    for idx, batch_data in enumerate(real_loader, start=1):
+        img = batch_data["image"].to(device)
+        fn = img.meta["filename_or_obj"][0]
+        logger.info(f"[Rank {local_rank}] Real data {idx}/{len(real_filenames)}: {fn}")
+
+        out_fp = fn.replace(real_dataset_root, output_root_real).replace(".nii.gz", ".pt")
+        out_fp = Path(out_fp)
+        out_fp.parent.mkdir(parents=True, exist_ok=True)
+
+        if (not ignore_existing) and os.path.isfile(out_fp):
+            # Map cached features onto this rank's device; they may have been saved from another GPU
+            feats = torch.load(out_fp, map_location=device)
+        else:
+            img_t = img.as_tensor()
+            logger.info(f"image shape: {tuple(img_t.shape)}")
+
+            feats = get_features_2p5d(
+                img_t,
+                feature_network,
+                center_slices=enable_center_slices,
+                center_slices_ratio=center_slices_ratio_final,
+                xy_only=False,
+            )
+            logger.info(f"feats shapes: {feats[0].shape}, {feats[1].shape}, {feats[2].shape}")
+            torch.save(feats, out_fp)
+
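+        # Each element of `feats` is a (num_kept_slices, feature_dim) matrix for
+        # one plane (XY, YZ, ZX); collect them per case and stack per plane below.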
+        real_features_xy.append(feats[0])
+        real_features_yz.append(feats[1])
+        real_features_zx.append(feats[2])
+
+    real_features_xy = torch.vstack(real_features_xy)
+    real_features_yz = torch.vstack(real_features_yz)
+    real_features_zx = torch.vstack(real_features_zx)
+    logger.info(f"Real feature shapes: {real_features_xy.shape}, {real_features_yz.shape}, {real_features_zx.shape}")
+
+    # -------------------------------------------------------------------------
+    # Extract features for Synthetic Dataset
+    # -------------------------------------------------------------------------
+    synth_features_xy, synth_features_yz, synth_features_zx = [], [], []
+    for idx, batch_data in enumerate(synth_loader, start=1):
+        img = batch_data["image"].to(device)
+        fn = img.meta["filename_or_obj"][0]
+        logger.info(f"[Rank {local_rank}] Synth data {idx}/{len(synth_filenames)}: {fn}")
+
+        out_fp = fn.replace(synth_dataset_root, output_root_synth).replace(".nii.gz", ".pt")
+        out_fp = Path(out_fp)
+        out_fp.parent.mkdir(parents=True, exist_ok=True)
+
+        if (not ignore_existing) and os.path.isfile(out_fp):
+            # Map cached features onto this rank's device; they may have been saved from another GPU
+            feats = torch.load(out_fp, map_location=device)
+        else:
+            img_t = img.as_tensor()
+            logger.info(f"image shape: {tuple(img_t.shape)}")
+
+            feats = get_features_2p5d(
+                img_t,
+                feature_network,
+                center_slices=enable_center_slices,
+                center_slices_ratio=center_slices_ratio_final,
+                xy_only=False,
+            )
+            logger.info(f"feats shapes: {feats[0].shape}, {feats[1].shape}, {feats[2].shape}")
+            torch.save(feats, out_fp)
+
+        synth_features_xy.append(feats[0])
+        synth_features_yz.append(feats[1])
+        synth_features_zx.append(feats[2])
+
+    synth_features_xy = torch.vstack(synth_features_xy)
+    synth_features_yz = torch.vstack(synth_features_yz)
+    synth_features_zx = torch.vstack(synth_features_zx)
+    logger.info(f"Synth feature shapes: {synth_features_xy.shape}, {synth_features_yz.shape}, {synth_features_zx.shape}")
+
+    # -------------------------------------------------------------------------
+    # Gather features from all ranks (variable-length all_gather via padding)
+    # -------------------------------------------------------------------------
+    features = [
+        real_features_xy,
+        real_features_yz,
+        real_features_zx,
+        synth_features_xy,
+        synth_features_yz,
+        synth_features_zx,
+    ]
+
+    # 1) Gather local feature sizes across ranks
+    local_sizes = []
+    for ft_idx in range(len(features)):
+        local_size = torch.tensor([features[ft_idx].shape[0]], dtype=torch.int64, device=device)
+        local_sizes.append(local_size)
+
+    all_sizes = []
+    for ft_idx in range(len(features)):
+        rank_sizes = [torch.tensor([0], dtype=torch.int64, device=device) for _ in range(world_size)]
+        dist.all_gather(rank_sizes, local_sizes[ft_idx])
+        all_sizes.append(rank_sizes)
+
+    # 2) Pad every local matrix to the largest per-rank size, gather, then trim
+    all_tensors_list = []
+    for ft_idx, ft in enumerate(features):
+        max_size = max(int(s.item()) for s in all_sizes[ft_idx])
+        ft_padded = pad_to_max_size(ft, max_size)
+
+        gather_list = [torch.empty_like(ft_padded) for _ in range(world_size)]
+        dist.all_gather(gather_list, ft_padded)
+
+        # Trim each gathered tensor back to its real (unpadded) size
+        for rk in range(world_size):
+            gather_list[rk] = gather_list[rk][: int(all_sizes[ft_idx][rk].item()), :]
+
+        all_tensors_list.append(gather_list)
+
+    # On rank 0, compute FID
+    if local_rank == 0:
+        real_xy = torch.vstack(all_tensors_list[0])
+        real_yz = torch.vstack(all_tensors_list[1])
+        real_zx = torch.vstack(all_tensors_list[2])
+
+        synth_xy = torch.vstack(all_tensors_list[3])
+        synth_yz = torch.vstack(all_tensors_list[4])
+        synth_zx = torch.vstack(all_tensors_list[5])
+
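+        # Each stacked matrix now holds one row per 2D slice over the entire
+        # dataset; FID below compares real vs. synthetic distributions per plane.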
logger.info(f"Final Real shapes: {real_xy.shape}, {real_yz.shape}, {real_zx.shape}") + logger.info(f"Final Synth shapes: {synth_xy.shape}, {synth_yz.shape}, {synth_zx.shape}") + + fid = FIDMetric() + logger.info(f"Computing FID for: {output_root_real} | {output_root_synth}") + fid_res_xy = fid(synth_xy, real_xy) + fid_res_yz = fid(synth_yz, real_yz) + fid_res_zx = fid(synth_zx, real_zx) + + logger.info(f"FID XY: {fid_res_xy}") + logger.info(f"FID YZ: {fid_res_yz}") + logger.info(f"FID ZX: {fid_res_zx}") + fid_avg = (fid_res_xy + fid_res_yz + fid_res_zx) / 3.0 + logger.info(f"FID Avg: {fid_avg}") + + dist.destroy_process_group() + + +if __name__ == "__main__": + fire.Fire(main)