diff --git a/connectomics/data/io/io.py b/connectomics/data/io/io.py
index f42286fa..f2a9e51c 100644
--- a/connectomics/data/io/io.py
+++ b/connectomics/data/io/io.py
@@ -281,7 +281,32 @@ def read_volume(
     if image_suffix in ["h5", "hdf5"]:
         data = read_hdf5(filename, dataset)
     elif "tif" in image_suffix:
-        data = imageio.volread(filename).squeeze()
+        # Check if filename contains glob patterns
+        if "*" in filename or "?" in filename:
+            import glob  # local import; assumed not already imported at module level
+            # Expand glob pattern to get matching files
+            file_list = sorted(glob.glob(filename))
+            if len(file_list) == 0:
+                raise FileNotFoundError(f"No TIFF files found matching pattern: {filename}")
+
+            # Read each file and stack along the depth dimension
+            volumes = []
+            for filepath in file_list:
+                vol = imageio.volread(filepath).squeeze()
+                # imageio.volread returns a multi-page TIFF as (D, H, W) and a single page as (H, W)
+                # Ensure all volumes are at least 3D (D, H, W)
+                if vol.ndim == 2:
+                    vol = vol[np.newaxis, ...]  # Add depth dimension: (H, W) -> (1, H, W)
+                # vol.ndim == 3 means (D, H, W), which is what we want
+                volumes.append(vol)
+
+            # Stack all volumes along the depth dimension
+            # Each volume is (D_i, H, W); the result is (sum(D_i), H, W)
+            data = np.concatenate(volumes, axis=0)  # Stack along depth (first dimension)
+        else:
+            # Single file or multi-page TIFF
+            data = imageio.volread(filename).squeeze()
+
     if data.ndim == 4:
         # Convert (D, C, H, W) to (C, D, H, W) order
         data = data.transpose(1, 0, 2, 3)
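Note: the new glob branch is easy to exercise standalone. A minimal sketch of the equivalent logic, not part of the patch (`read_tiff_stack` is a hypothetical helper; assumes `imageio` is installed and all matched slices share the same H and W):

```python
import glob

import imageio
import numpy as np

def read_tiff_stack(pattern: str) -> np.ndarray:
    """Expand a glob pattern and stack TIFF slices depth-first, mirroring read_volume."""
    files = sorted(glob.glob(pattern))
    if not files:
        raise FileNotFoundError(f"No TIFF files found matching pattern: {pattern}")
    volumes = []
    for path in files:
        vol = imageio.volread(path).squeeze()
        if vol.ndim == 2:                 # single page: (H, W) -> (1, H, W)
            vol = vol[np.newaxis, ...]
        volumes.append(vol)               # multi-page files stay (D, H, W)
    return np.concatenate(volumes, axis=0)  # (sum(D_i), H, W)

# volume = read_tiff_stack("imagesTs/*.tif")  # e.g. a folder of single-page slices -> (N, H, W)
```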
diff --git a/connectomics/lightning/lit_model.py b/connectomics/lightning/lit_model.py
index 269fee8c..6fab85f4 100644
--- a/connectomics/lightning/lit_model.py
+++ b/connectomics/lightning/lit_model.py
@@ -521,38 +521,66 @@ def _apply_postprocessing(self, data: np.ndarray) -> np.ndarray:
         from connectomics.decoding.postprocess import apply_binary_postprocessing

         # Process each sample in batch
-        batch_size = data.shape[0] if data.ndim >= 4 else 1
-
-        # Handle different input shapes
-        if data.ndim == 2:  # (H, W) -> (1, 1, H, W)
-            data = data[np.newaxis, np.newaxis, ...]
-        elif data.ndim == 3:  # (D, H, W) or (C, H, W) -> assume (D, H, W) and add batch dim
-            data = data[np.newaxis, ...]  # (1, D, H, W)
+        # Handle both 4D (B, C, H, W) for 2D data and 5D (B, C, D, H, W) for 3D data
+        print(f" DEBUG: _apply_postprocessing - input data shape: {data.shape}, ndim: {data.ndim}")
+        if data.ndim == 4:
+            # 2D data: (B, C, H, W)
+            batch_size = data.shape[0]
+            print(f" DEBUG: _apply_postprocessing - detected 2D data, batch_size: {batch_size}")
+        elif data.ndim == 5:
+            # 3D data: (B, C, D, H, W)
+            batch_size = data.shape[0]
+            print(f" DEBUG: _apply_postprocessing - detected 3D data, batch_size: {batch_size}")
+        elif data.ndim == 3:
+            # Single sample with three dims: (C, H, W) for 2D or (D, H, W) for 3D - add batch dimension
+            batch_size = 1
+            data = data[np.newaxis, ...]  # (1, C, H, W) or (1, D, H, W)
+            print(f" DEBUG: _apply_postprocessing - single 3D sample, added batch dimension")
+        elif data.ndim == 2:
+            # Single 2D image: (H, W) - add batch and channel dimensions
+            batch_size = 1
+            data = data[np.newaxis, np.newaxis, ...]  # (1, 1, H, W)
+            print(f" DEBUG: _apply_postprocessing - single 2D sample, added batch and channel dimensions")
+        else:
+            batch_size = 1

-        # Ensure we have at least 4D: (B, ...) where ... can be (D, H, W) or (C, D, H, W)
+        # Ensure we have at least 4D: (B, ...) where ... can be (C, H, W) for 2D or (C, D, H, W) for 3D
         results = []
         for batch_idx in range(batch_size):
-            sample = data[batch_idx]  # (C, D, H, W) or (D, H, W)
-
-            # Extract foreground probability (handle both 3D and 4D)
-            if sample.ndim == 4:  # (C, D, H, W)
-                foreground_prob = sample[0]  # Use first channel
-            else:  # (D, H, W) - already single channel
+            sample = data[batch_idx]  # (C, H, W) for 2D or (C, D, H, W) for 3D
+            print(f" DEBUG: _apply_postprocessing - processing batch_idx {batch_idx}, sample shape: {sample.shape}")
+
+            # Extract foreground probability (always use first channel if a channel dimension exists)
+            if sample.ndim == 4:  # (C, D, H, W) - 3D with channel
+                foreground_prob = sample[0]  # Use first channel -> (D, H, W)
+            elif sample.ndim == 3:
+                # Could be (C, H, W) for 2D or (D, H, W) for 3D without channel
+                # If the first dim is small (<=4), assume it is channel (2D), otherwise depth (3D)
+                if sample.shape[0] <= 4:
+                    foreground_prob = sample[0]  # (C, H, W) -> use first channel -> (H, W)
+                else:
+                    foreground_prob = sample  # (D, H, W) - already single channel
+            elif sample.ndim == 2:  # (H, W) - 2D single channel
+                foreground_prob = sample
+            else:
                 foreground_prob = sample

             # Apply binary postprocessing
             processed = apply_binary_postprocessing(foreground_prob, binary_config)

-            # Expand dims to maintain shape consistency
-            if sample.ndim == 4:
+            # Expand dims to maintain shape consistency with the original sample structure
+            if sample.ndim == 4:  # (C, D, H, W) -> processed is (D, H, W)
                 processed = processed[np.newaxis, ...]  # (1, D, H, W)
-            else:
-                processed = processed  # Keep (D, H, W)
+            elif sample.ndim == 3 and sample.shape[0] <= 4:  # (C, H, W) -> processed is (H, W)
+                processed = processed[np.newaxis, ...]  # (1, H, W)
+            # else: processed is already the correct shape, (D, H, W) or (H, W)

             results.append(processed)

         # Stack results back into batch
+        print(f" DEBUG: _apply_postprocessing - stacking {len(results)} results, shapes: {[r.shape for r in results]}")
         data = np.stack(results, axis=0)
+        print(f" DEBUG: _apply_postprocessing - after stacking, data shape: {data.shape}")

         # Step 2: Apply scaling if configured (support both new and legacy names)
         intensity_scale = getattr(postprocessing, 'intensity_scale', None)
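Note: the `<= 4` channel-vs-depth heuristic above is worth pinning down with concrete shapes. A minimal sketch of the same rule in isolation (`extract_foreground` is a hypothetical helper, not part of the patch):

```python
import numpy as np

def extract_foreground(sample: np.ndarray) -> np.ndarray:
    """Mirror the dispatch above: take channel 0 when a channel axis is present."""
    if sample.ndim == 4:                      # (C, D, H, W): 3D with channel
        return sample[0]
    if sample.ndim == 3:
        # (C, H, W) if the first axis is small (<=4), else (D, H, W)
        return sample[0] if sample.shape[0] <= 4 else sample
    return sample                             # (H, W) or anything else: pass through

assert extract_foreground(np.zeros((2, 64, 64))).shape == (64, 64)       # (C, H, W) -> (H, W)
assert extract_foreground(np.zeros((32, 64, 64))).shape == (32, 64, 64)  # (D, H, W) kept
# Caveat of the heuristic: a genuine 3D volume with D <= 4 would be misread as channels.
```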
@@ -651,13 +679,29 @@ def _apply_decode_mode(self, data: np.ndarray) -> np.ndarray:
         }

         # Process each sample in batch
-        batch_size = data.shape[0] if data.ndim == 5 else 1
+        # Handle both 4D (B, C, H, W) for 2D data and 5D (B, C, D, H, W) for 3D data
+        print(f" DEBUG: _apply_decode_mode - input data shape: {data.shape}, ndim: {data.ndim}")
         if data.ndim == 4:
-            data = data[np.newaxis, ...]  # Add batch dimension
+            # 2D data: (B, C, H, W)
+            batch_size = data.shape[0]
+            print(f" DEBUG: _apply_decode_mode - detected 2D data, batch_size: {batch_size}")
+        elif data.ndim == 5:
+            # 3D data: (B, C, D, H, W)
+            batch_size = data.shape[0]
+            print(f" DEBUG: _apply_decode_mode - detected 3D data, batch_size: {batch_size}")
+        else:
+            # Single sample: add batch dimension
+            batch_size = 1
+            print(f" DEBUG: _apply_decode_mode - single sample, adding batch dimension")
+            if data.ndim == 3:
+                data = data[np.newaxis, ...]  # (C, H, W) -> (1, C, H, W)
+            elif data.ndim == 2:
+                data = data[np.newaxis, np.newaxis, ...]  # (H, W) -> (1, 1, H, W)

         results = []
         for batch_idx in range(batch_size):
-            sample = data[batch_idx]  # (C, D, H, W)
+            sample = data[batch_idx]  # (C, H, W) for 2D or (C, D, H, W) for 3D
+            print(f" DEBUG: _apply_decode_mode - processing batch_idx {batch_idx}, sample shape: {sample.shape}")

             # Apply each decode mode sequentially
             for decode_cfg in decode_modes:
@@ -718,8 +762,10 @@ def _apply_decode_mode(self, data: np.ndarray) -> np.ndarray:
             results.append(sample)

         # Stack results back into batch
-        decoded = np.stack(results, axis=0) if len(results) > 1 else results[0]
-
+        # Always preserve the batch dimension, even for batch_size=1
+        print(f" DEBUG: _apply_decode_mode - stacking {len(results)} results, shapes: {[r.shape for r in results]}")
+        decoded = np.stack(results, axis=0)
+        print(f" DEBUG: _apply_decode_mode - final decoded shape: {decoded.shape}")
         return decoded

     def _resolve_output_filenames(self, batch: Dict[str, Any]) -> List[str]:
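Note: the `np.stack` change is subtle but load-bearing. The old code returned `results[0]` for a single sample, so downstream writers saw (C, D, H, W) sometimes and (B, C, D, H, W) other times; keeping the stack unconditional pins the contract. A tiny sketch of the difference:

```python
import numpy as np

results = [np.zeros((1, 8, 64, 64))]          # one decoded sample, (C, D, H, W)

old_style = results[0] if len(results) == 1 else np.stack(results, axis=0)
new_style = np.stack(results, axis=0)         # always (B, C, D, H, W)

assert old_style.shape == (1, 8, 64, 64)      # batch axis silently dropped
assert new_style.shape == (1, 1, 8, 64, 64)   # batch axis preserved for B == 1
```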
@@ -742,26 +788,59 @@ def _resolve_output_filenames(self, batch: Dict[str, Any]) -> List[str]:
         meta = batch.get('image_meta_dict')

         filenames: List[Optional[str]] = []
+
+        print(f" DEBUG: _resolve_output_filenames - meta type: {type(meta)}, batch_size: {batch_size}")

-        if isinstance(meta, dict):
+        # Handle different metadata structures
+        if isinstance(meta, list):
+            # Multiple metadata dicts (one per sample in batch)
+            print(f" DEBUG: _resolve_output_filenames - meta is list with {len(meta)} items")
+            for idx, meta_item in enumerate(meta):
+                if isinstance(meta_item, dict):
+                    filename = meta_item.get('filename_or_obj')
+                    if filename is not None:
+                        filenames.append(filename)
+                    else:
+                        print(f" DEBUG: _resolve_output_filenames - meta_item[{idx}] has no filename_or_obj")
+                else:
+                    print(f" DEBUG: _resolve_output_filenames - meta_item[{idx}] is not a dict: {type(meta_item)}")
+            # Update batch_size from metadata if we have a list
+            batch_size = max(batch_size, len(filenames))
+            print(f" DEBUG: _resolve_output_filenames - extracted {len(filenames)} filenames from list")
+        elif isinstance(meta, dict):
+            # Single metadata dict
+            print(f" DEBUG: _resolve_output_filenames - meta is dict")
             meta_filenames = meta.get('filename_or_obj')
             if isinstance(meta_filenames, (list, tuple)):
-                filenames = list(meta_filenames)
+                filenames = [f for f in meta_filenames if f is not None]
             elif meta_filenames is not None:
                 filenames = [meta_filenames]
-        elif isinstance(meta, list):
-            for meta_item in meta:
-                if isinstance(meta_item, dict):
-                    filenames.append(meta_item.get('filename_or_obj'))
-            # Update batch_size from metadata if we have a list
-            batch_size = max(batch_size, len(filenames))
+            # Update batch_size from metadata
+            if len(filenames) > 0:
+                batch_size = max(batch_size, len(filenames))
+            print(f" DEBUG: _resolve_output_filenames - extracted {len(filenames)} filenames from dict")
+        else:
+            # Handle the case where meta is None or some other type;
+            # this can happen if metadata wasn't preserved through transforms.
+            # Fall back to generated filenames based on batch_size.
+            print(f" DEBUG: _resolve_output_filenames - meta is None or unexpected type: {type(meta)}")
+            pass

         resolved_names: List[str] = []
         for idx in range(batch_size):
             if idx < len(filenames) and filenames[idx]:
                 resolved_names.append(Path(str(filenames[idx])).stem)
             else:
+                # Generate fallback filename - this shouldn't happen if metadata is preserved correctly
                 resolved_names.append(f"volume_{self.global_step}_{idx}")
+
+        print(f" DEBUG: _resolve_output_filenames - returning {len(resolved_names)} resolved names: {resolved_names[:3]}...")
+
+        # Always return exactly batch_size filenames
+        if len(resolved_names) < batch_size:
+            print(f" WARNING: _resolve_output_filenames - Only {len(resolved_names)} filenames but batch_size is {batch_size}, padding with fallback names")
+            while len(resolved_names) < batch_size:
+                resolved_names.append(f"volume_{self.global_step}_{len(resolved_names)}")

         return resolved_names
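Note: the resolve-then-pad behavior reduces to: take the filename stem where metadata supplies a name, else synthesize `volume_{step}_{idx}`, and always return exactly `batch_size` names. A condensed sketch (`resolve_names` and the sample path are hypothetical, for illustration only):

```python
from pathlib import Path

def resolve_names(filenames, batch_size, global_step=0):
    """One output name per sample: metadata stem if present, fallback otherwise."""
    names = []
    for idx in range(batch_size):
        if idx < len(filenames) and filenames[idx]:
            names.append(Path(str(filenames[idx])).stem)
        else:
            names.append(f"volume_{global_step}_{idx}")
    return names

assert resolve_names(["/data/imagesTs/worm_001.tif"], batch_size=2) == \
    ["worm_001", "volume_0_1"]
```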
@@ -799,8 +878,42 @@ def _write_outputs(
         if hasattr(self.cfg.inference, 'postprocessing'):
             output_transpose = getattr(self.cfg.inference.postprocessing, 'output_transpose', [])

+        # Determine actual batch size from predictions
+        # Handle both batched (B, ...) and unbatched (...) predictions
+        print(f" DEBUG: _write_outputs - predictions shape: {predictions.shape}, ndim: {predictions.ndim}, filenames count: {len(filenames)}")
+
+        if predictions.ndim >= 4:
+            # Has batch dimension: (B, C, D, H, W) or (B, C, H, W) or (B, D, H, W)
+            actual_batch_size = predictions.shape[0]
+        elif predictions.ndim == 3:
+            # Could be batched 2D data (B, H, W) or a single 3D volume (D, H, W)
+            # If the first dimension matches the number of filenames, it is batched 2D data
+            if len(filenames) > 0 and predictions.shape[0] == len(filenames):
+                # Batched 2D data: (B, H, W) where B matches the number of filenames
+                actual_batch_size = predictions.shape[0]
+                print(f" DEBUG: _write_outputs - detected batched 2D data (B, H, W) with batch_size={actual_batch_size}")
+            else:
+                # Single 3D volume: (D, H, W) - treat as batch_size=1
+                actual_batch_size = 1
+                predictions = predictions[np.newaxis, ...]  # Add batch dimension
+                print(f" DEBUG: _write_outputs - detected single 3D volume, added batch dimension")
+        elif predictions.ndim == 2:
+            # Single 2D image: (H, W) - treat as batch_size=1
+            actual_batch_size = 1
+            predictions = predictions[np.newaxis, ...]  # Add batch dimension
+        else:
+            # Unexpected shape; default to batch_size=1
+            actual_batch_size = 1
+            if predictions.ndim < 2:
+                predictions = predictions[np.newaxis, ...]  # Add batch dimension
+
+        # Ensure we don't exceed the actual batch size
+        batch_size = min(actual_batch_size, len(filenames))
+        print(f" DEBUG: _write_outputs - actual_batch_size: {actual_batch_size}, batch_size: {batch_size}, will save {batch_size} predictions")
+
         # Save predictions
-        for idx, name in enumerate(filenames):
+        for idx in range(batch_size):
+            name = filenames[idx]
             prediction = predictions[idx]

             # Apply output transpose if specified
@@ -1303,6 +1416,10 @@ def test_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> STEP_OUTPUT:
         labels = batch.get('label')
         mask = batch.get('mask')  # Get test mask if available

+        # Get batch size from images
+        actual_batch_size = images.shape[0]
+        print(f" DEBUG: test_step - images shape: {images.shape}, batch_size: {actual_batch_size}")
+
         # Always use TTA (handles no-transform case) + sliding window
         # TTA preprocessing (activation, masking) is applied regardless of flip augmentation
         # Note: TTA always returns a simple tensor, not a dict (deep supervision not supported in test mode)
@@ -1310,9 +1427,17 @@ def test_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> STEP_OUTPUT:

         # Convert predictions to numpy for saving/decoding
         predictions_np = predictions.detach().cpu().float().numpy()
+        print(f" DEBUG: test_step - predictions_np shape: {predictions_np.shape}")

         # Resolve filenames once for all saving operations
         filenames = self._resolve_output_filenames(batch)
+        print(f" DEBUG: test_step - filenames count: {len(filenames)}, filenames: {filenames[:5]}...")
+
+        # Ensure filenames list matches actual batch size
+        # If we don't have enough filenames, generate default ones
+        while len(filenames) < actual_batch_size:
+            filenames.append(f"volume_{self.global_step}_{len(filenames)}")
+        print(f" DEBUG: test_step - after padding, filenames count: {len(filenames)}")

         # Check if we should save intermediate predictions (before decoding)
         save_intermediate = False
@@ -1324,10 +1449,13 @@ def test_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> STEP_OUTPUT:
             self._write_outputs(predictions_np, filenames, suffix="tta_prediction")

         # Apply decode mode (instance segmentation decoding)
+        print(f" DEBUG: test_step - before decode, predictions_np shape: {predictions_np.shape}")
         decoded_predictions = self._apply_decode_mode(predictions_np)
+        print(f" DEBUG: test_step - after decode, decoded_predictions shape: {decoded_predictions.shape}")

         # Apply postprocessing (scaling and dtype conversion) if configured
         postprocessed_predictions = self._apply_postprocessing(decoded_predictions)
+        print(f" DEBUG: test_step - after postprocess, postprocessed_predictions shape: {postprocessed_predictions.shape}")

         # Save final decoded and postprocessed predictions
         self._write_outputs(postprocessed_predictions, filenames, suffix="prediction")
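Note: one case above deserves a spelled-out example. A 3-dimensional prediction array is ambiguous between a batch of 2D images (B, H, W) and one 3D volume (D, H, W); `_write_outputs` disambiguates by comparing the leading axis against the filename count. The same rule in isolation (`detect_batch` is a hypothetical helper):

```python
import numpy as np

def detect_batch(predictions: np.ndarray, n_filenames: int):
    """Return (batch_size, predictions) with a guaranteed leading batch axis."""
    if predictions.ndim >= 4:
        return predictions.shape[0], predictions
    if predictions.ndim == 3:
        if n_filenames > 0 and predictions.shape[0] == n_filenames:
            return predictions.shape[0], predictions        # (B, H, W)
        return 1, predictions[np.newaxis, ...]              # (D, H, W) -> (1, D, H, W)
    return 1, predictions[np.newaxis, ...]                  # (H, W) -> (1, H, W)

b, p = detect_batch(np.zeros((4, 64, 64)), n_filenames=4)   # batched 2D
assert (b, p.shape) == (4, (4, 64, 64))
b, p = detect_batch(np.zeros((40, 64, 64)), n_filenames=1)  # single 3D volume
assert (b, p.shape) == (1, (1, 40, 64, 64))
```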
diff --git a/scripts/visualize_neuroglancer.py b/scripts/visualize_neuroglancer.py
index c0977eb7..d854ef92 100755
--- a/scripts/visualize_neuroglancer.py
+++ b/scripts/visualize_neuroglancer.py
@@ -218,7 +218,7 @@ def parse_args():


 def load_volumes_from_config(
-    config_path: str, mode: str = "train"
+    config_path: str, mode: str = "train", prediction_base_name: Optional[str] = None
 ) -> Dict[str, Tuple[np.ndarray, str, Optional[Tuple], None]]:
     """
     Load volumes from a config file.
@@ -343,8 +343,41 @@ def load_volumes_from_config(
         and hasattr(cfg.inference.data, "test_image")
         and cfg.inference.data.test_image
     ):
-        print(f"Loading test image: {cfg.inference.data.test_image}")
-        data = read_volume(cfg.inference.data.test_image)
+        test_image_path = cfg.inference.data.test_image
+        print(f"Loading test image: {test_image_path}")
+
+        # If prediction_base_name is provided and test_image_path contains a glob pattern,
+        # find the specific matching file
+        if prediction_base_name and ("*" in test_image_path or "?" in test_image_path):
+            print(f" 🔍 Auto-matching specific test_image for prediction base name: {prediction_base_name}")
+            test_path_obj = Path(test_image_path)
+            test_dir = test_path_obj.parent
+            print(f" Search directory: {test_dir}")
+
+            # Search for files with a matching base name (known extensions first)
+            extensions_to_try = ['.tif', '.tiff', '.h5', '.hdf5', '.png', '.jpg', '.jpeg']
+            matched_file = None
+            for ext in extensions_to_try:
+                potential_file = test_dir / f"{prediction_base_name}{ext}"
+                if potential_file.exists():
+                    matched_file = str(potential_file)
+                    print(f" ✓ Found matching test_image: {matched_file}")
+                    break
+
+            # If not found, search for any file with a matching base name
+            if not matched_file:
+                matching_files = sorted(test_dir.glob(f"{prediction_base_name}.*"))
+                if matching_files:
+                    matched_file = str(matching_files[0])
+                    print(f" ✓ Found matching test_image: {matched_file}")
+
+            if matched_file:
+                test_image_path = matched_file
+            else:
+                print(f" ⚠ No matching test_image found for base name: {prediction_base_name}")
+                print(f" Falling back to loading all files from the glob pattern")
+
+        data = read_volume(test_image_path)
         # Convert 2D to 3D if needed
         if data.ndim == 2:
             data = data[None, :, :]  # (H, W) -> (1, H, W)
@@ -629,10 +663,32 @@ def main():
     volumes = {}
     cfg = None

+    # Extract prediction base name from --volumes if provided (for auto-matching test_image)
+    prediction_base_name = None
+    if args.volumes:
+        for spec in args.volumes:
+            parts = spec.split(":")
+            # Find the path (it can sit at different positions depending on the format)
+            if len(parts) >= 3:
+                path = parts[2]  # Format: name:type:path
+            elif len(parts) == 2:
+                path = parts[1]  # Format: name:path
+            else:
+                path = parts[0]  # Format: path
+
+            # Check if this is a prediction file
+            path_obj = Path(path)
+            if "_prediction" in path_obj.stem:
+                # Extract base name by removing "_prediction" and extension
+                prediction_base_name = path_obj.stem.replace("_prediction", "")
+                print(f"📋 Detected prediction file: {path_obj.name}")
+                print(f" Extracted base name for auto-matching: {prediction_base_name}")
+                break
+
     # Load from config first (if provided)
     if args.config:
         cfg = load_config(args.config)  # Store config for interactive access
-        volumes.update(load_volumes_from_config(args.config, args.mode))
+        volumes.update(load_volumes_from_config(args.config, args.mode, prediction_base_name=prediction_base_name))

     # Add image/label (if provided and not empty strings)
     if args.image and args.image.strip():
diff --git a/tutorials/monai2d_worm.yaml b/tutorials/monai2d_worm.yaml
index 1420c4b5..55c6eb2a 100644
--- a/tutorials/monai2d_worm.yaml
+++ b/tutorials/monai2d_worm.yaml
@@ -178,7 +178,7 @@
 inference:
   data:
     do_2d: true # Enable 2D data processing for inference
     test_image: /projects/weilab/shenb/PyTC/datasets/Dataset001_worm_image96/imagesTs/*.tif
-    test_label: /projects/weilab/shenb/PyTC/datasets/Dataset001_worm_image96/imagesTs/*.tif
+    test_label:
     test_resolution: [5, 5]
   output_path: outputs/monai2d_worm/results/
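Note: the neuroglancer auto-matching boils down to stripping `_prediction` from the prediction volume's stem and looking for a sibling test image with that stem next to the config's `test_image` glob. A compact sketch of the matching rule (hypothetical helper; assumes the same extension priority as the patch):

```python
from pathlib import Path
from typing import Optional

EXTENSIONS = ['.tif', '.tiff', '.h5', '.hdf5', '.png', '.jpg', '.jpeg']

def match_test_image(prediction_path: str, test_glob: str) -> Optional[str]:
    """Map e.g. worm_001_prediction.h5 to imagesTs/worm_001.tif next to the glob."""
    stem = Path(prediction_path).stem
    if "_prediction" not in stem:
        return None
    base = stem.replace("_prediction", "")
    test_dir = Path(test_glob).parent
    for ext in EXTENSIONS:                      # known extensions first
        candidate = test_dir / f"{base}{ext}"
        if candidate.exists():
            return str(candidate)
    matches = sorted(test_dir.glob(f"{base}.*"))  # then any extension
    return str(matches[0]) if matches else None
```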