Conversation
Co-authored-by: Adam Kukučka <adam.kukucka4@gmail.com>
Add bitmask graph operations, MCTS/MCGS node structures, segmentation utilities, model predictor, and supporting ImageNet class names.
Add MCTS, MCGS, lookahead, and potential search algorithms for hyperpixel construction.
Improve error handling and remove legacy parameters in algorithm functions.
📝 Walkthrough

Adds the CIAO explainability package: image loading/preprocessing, segmentation, model prediction utilities, bitmask graph and node structures, four hyperpixel construction algorithms (MCTS, MCGS, potential, lookahead), a CIAOExplainer orchestrator, and ImageNet class labels.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    actor User
    participant Explainer as CIAOExplainer
    participant Preproc as Preprocessor
    participant Seg as Segmentation
    participant Predictor as ModelPredictor
    participant Alg as Algorithm
    participant Compiler as ResultCompiler
    User->>Explainer: explain(image_path, predictor, method...)
    Explainer->>Preproc: load_and_preprocess_image(image_path)
    Preproc-->>Explainer: input_batch, original_image, input_tensor
    Explainer->>Predictor: get_predicted_class(input_batch)
    Predictor-->>Explainer: target_class_idx
    Explainer->>Seg: create_segmentation(input_tensor,...)
    Seg-->>Explainer: segments, adjacency
    Explainer->>Predictor: create_surrogate_dataset(...)
    Predictor-->>Explainer: per-segment scores
    Explainer->>Alg: build_all_hyperpixels_*(predictor,input_batch,segments,adj,...)
    Alg->>Alg: iterative search (select → expand → evaluate → backup)
    Alg->>Predictor: evaluate_masks(batch_masks)
    Predictor-->>Alg: rewards
    Alg-->>Explainer: hyperpixels list
    Explainer->>Compiler: select_top_hyperpixels(hyperpixels)
    Compiler-->>Explainer: ranked hyperpixels
    Explainer-->>User: explanation dict
```
Estimated code review effort: 🎯 5 (Critical) | ⏱️ ~120 minutes
🚥 Pre-merge checks: ✅ Passed checks (3 passed)
Summary of Changes

Hello @dhalmazna, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request establishes the foundational pipeline for the CIAO (Contextual Importance Assessment via Obfuscation) explainable AI method. It introduces the main

Highlights
Code Review
This pull request introduces the core pipeline for the CIAO explainer, including data loading, preprocessing, and the main CIAOExplainer class that connects various search algorithms. The code is well-structured with clear separation of concerns. The implementations of the search algorithms (MCTS, MCGS, Lookahead, Potential) are sophisticated and optimized for performance using techniques like bitmasks and batching. My review focuses on a performance issue related to redundant computations in the explainer pipeline, removal of dead code, and suggestions to improve type safety across the new algorithm modules by using TypedDict for complex dictionary return types.
```python
segments, graph = create_segmentation(
    input_tensor,
    segmentation_type=segmentation_type,
    segment_size=segment_size,
    neighborhood=neighborhood,
)
print(
    f"Built {segmentation_type} spatial graph with {graph.number_of_nodes()} "
    f"segments and {graph.number_of_edges()} edges"
)

# Calculate scores from surrogate dataset
X, y = create_surrogate_dataset(
    predictor,
    input_batch,
    segments,
    graph,
    target_class_idx,
    batch_size=batch_size,
)
scores = calculate_scores_from_surrogate(X, y)

# Create adjacency structures (needed by all methods)
segments_list, adj_list = create_hexagonal_grid_with_list(
    input_tensor, segment_size
)
adj_masks = build_adjacency_bitmasks(adj_list)
```
There is a redundant and potentially expensive computation happening here. The pixel-to-hexagonal-segment mapping is performed twice: once in `create_segmentation` (line 141) to get the networkx graph, and again in `create_hexagonal_grid_with_list` (line 164) to get the fast adjacency structures.

This should be refactored to perform the segmentation only once. A single, efficient segmentation function could generate the segmentation map, the networkx graph (if still needed), and the bitmask-based adjacency structures simultaneously.
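To make the suggestion concrete, here is a minimal sketch of such a single-pass builder. It uses a square grid purely for illustration (the project's hexagonal mapping would slot into the same structure); `build_segmentation_once` and its exact outputs are hypothetical, not code from this PR:

```python
import numpy as np

def build_segmentation_once(height: int, width: int, segment_size: int):
    """Hypothetical unified builder: produce the segmentation map, the
    adjacency list, and the adjacency bitmasks in a single pass."""
    rows = (height + segment_size - 1) // segment_size
    cols = (width + segment_size - 1) // segment_size

    # Segmentation map: each pixel labeled with its segment id.
    ys, xs = np.mgrid[0:height, 0:width]
    segments = (ys // segment_size) * cols + (xs // segment_size)

    # Adjacency list built directly from the grid layout (4-neighborhood).
    n = rows * cols
    adj_list: list[list[int]] = [[] for _ in range(n)]
    for r in range(rows):
        for c in range(cols):
            sid = r * cols + c
            if c + 1 < cols:  # right neighbor
                adj_list[sid].append(sid + 1)
                adj_list[sid + 1].append(sid)
            if r + 1 < rows:  # bottom neighbor
                adj_list[sid].append(sid + cols)
                adj_list[sid + cols].append(sid)

    # Bitmask form of the same adjacency, for fast set operations.
    adj_masks = [0] * n
    for sid, nbrs in enumerate(adj_list):
        for nb in nbrs:
            adj_masks[sid] |= 1 << nb
    return segments, adj_list, adj_masks
```

A networkx graph, if still needed elsewhere, can be derived from `adj_list` afterwards instead of re-running the pixel mapping.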
```python
def _find_first_step(base_mask: int, target_mask: int) -> int:
    """Find the first segment added from base_mask to reach target_mask."""
    diff = target_mask & ~base_mask
    # Return the first bit in the difference
    for seg_id in iter_bits(diff):
        return seg_id
    raise ValueError("Could not find first step between base and target mask.")
```
```python
used_mask = result["mask"] | used_mask  # type: ignore[operator]

# Format for compatibility
hyperpixel = {
    "segments": result["segments"],
    "sign": result["sign"],
    "size": result["size"],
    "hyperpixel_score": result["score"],
    "stats": result.get("stats", {}),  # Include lookahead statistics
}
hyperpixels.append(hyperpixel)

logger.info(
    f"Built hyperpixel with {len(result['segments'])} segments, score={result['score']:.4f}"  # type: ignore[arg-type]
)

# Sort by absolute score
hyperpixels.sort(key=lambda x: abs(x["hyperpixel_score"]), reverse=True)  # type: ignore[arg-type]
```
The use of dict[str, object] as a return type for build_hyperpixel_greedy_lookahead leads to type: ignore comments here and reduces type safety. This can be improved by using a TypedDict to define the structure of the result dictionary.
This would provide better static analysis, autocompletion, and make the code easier to maintain.
Here is an example of how you could define the types:

```python
from typing import TypedDict

class LookaheadStats(TypedDict):
    method: str
    lookahead_distance: int
    num_steps: int
    total_evaluations: int
    prefix_evaluations: int

class HyperpixelLookaheadResult(TypedDict):
    mask: int
    segments: list[int]
    sign: int
    score: float
    size: int
    stats: LookaheadStats
```

Then, `build_hyperpixel_greedy_lookahead` can be annotated to return `HyperpixelLookaheadResult`, which would eliminate the need for these `type: ignore` comments.
```python
# Add RAVE-specific data
if mode == "rave":
    result["stats"]["rave_k"] = rave_k  # type: ignore[index]
```
The return type of build_hyperpixel_mcgs is dict[str, Any], which forces the use of type: ignore here. To improve type safety and code clarity, consider using a TypedDict to define the structure of the returned dictionary.
This is especially useful because the stats dictionary contains a conditional key (rave_k), which can be nicely handled with typing.NotRequired.
Example:

```python
from typing import TypedDict, NotRequired

class MCGSStats(TypedDict):
    # ... other stats keys
    rave_k: NotRequired[float]

class MCGSHyperpixelResult(TypedDict):
    mask: int
    score: float
    used_mask: int
    root: MCGSNode
    stats: MCGSStats
```
```python
# Add RAVE-specific data
if mode == "rave":
    result["stats"]["rave_k"] = rave_k  # type: ignore[index]
```
The return type dict[str, Any] from build_hyperpixel_mcts necessitates a type: ignore here. Using a TypedDict would make the return structure explicit, improving type safety and maintainability.
Given that the stats dictionary has a conditional key (rave_k), TypedDict with NotRequired is a good fit.
Example:

```python
from typing import TypedDict, NotRequired

class MCTSStats(TypedDict):
    # ... other stats keys
    rave_k: NotRequired[int]

class MCTSHyperpixelResult(TypedDict):
    mask: int
    score: float
    used_mask: int
    root: MCTSNode
    stats: MCTSStats
```

```python
        "size": len(hyperpixel_segments),  # type: ignore[arg-type]
        "hyperpixel_score": result["score"],
        "stats": result.get(
            "stats", {}
        ),  # Include potential method statistics
    }
)
processed_segments.update(hyperpixel_segments)  # type: ignore[arg-type]
```
The return type dict[str, object] from build_hyperpixel_using_potential leads to type: ignore comments here. This reduces type safety and code clarity.
Consider defining a TypedDict for the return value to make the code more robust, self-documenting, and to allow the type checker to catch potential errors.
Actionable comments posted: 8
🧹 Nitpick comments (11)
ciao/algorithm/__init__.py (1)

3-3: Consider adding package-level exports for consistency with the codebase pattern.

The `__all__` list is empty, requiring users to import directly from submodules (e.g., `from ciao.algorithm.mcts import GlobalStats`). However, other packages in the codebase (`ciao`, `data`, `explainer`, `utils`) consistently export their public APIs at the package level. The algorithm submodules expose public symbols that could be conveniently exported:

- `GlobalStats` (from `mcts`)
- Build functions: `build_hyperpixel_greedy_lookahead`, `build_all_hyperpixels_greedy_lookahead`, `build_hyperpixel_mcgs`, `build_all_hyperpixels_mcgs`, `build_hyperpixel_mcts`, `build_all_hyperpixels_mcts`, `build_hyperpixel_using_potential`, `build_all_hyperpixels_potential`

Adding corresponding imports and populating `__all__` would align with the established codebase pattern and improve API ergonomics.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `ciao/algorithm/__init__.py` at line 3, the package `__all__` is empty, forcing consumers to import submodules directly; update `ciao.algorithm.__init__` to re-export the public symbols by importing `GlobalStats` from `mcts` and the build functions listed above from their respective submodules, and populate `__all__` with those symbol names so users can import them from `ciao.algorithm` directly. Ensure imports use local (relative) imports and only expose the listed public names in `__all__`.

.gitignore (1)
173-174: Pre-commit config should typically be version-controlled.

Ignoring `.pre-commit-config.yaml` means contributors won't share the same pre-commit hooks, defeating the purpose of automated quality checks. This file is usually committed so the entire team uses consistent linting/formatting rules.

If this is intentional (e.g., individual developer preference), consider documenting the rationale. Otherwise, remove this entry and commit the config file.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `.gitignore` around lines 173-174, the `.gitignore` currently excludes `.pre-commit-config.yaml`, which prevents sharing pre-commit hooks; remove the `.pre-commit-config.yaml` entry from `.gitignore` and commit the `.pre-commit-config.yaml` file so all contributors use the same hooks (or, if exclusion is intentional, add a short justification comment in the repo README explaining the reason).

ciao/structures/nodes.py (2)
37-41: Consider using TypedDict for edge statistics structure.

Using `dict[str, float]` with string keys (`"N"`, `"W"`, `"Q"`, `"max_reward"`) is error-prone and lacks IDE autocompletion. A `TypedDict` or `dataclass` would provide better type safety:

♻️ Suggested improvement

```python
from typing import TypedDict

class EdgeStats(TypedDict):
    N: float
    W: float
    Q: float
    max_reward: float
```

Then use `edge_stats: dict[int, EdgeStats]`.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `ciao/structures/nodes.py` around lines 37-41, replace the unstructured `dict[str, float]` used for `self.edge_stats` with a `TypedDict` to enforce keys and improve autocompletion: define an `EdgeStats` TypedDict (keys `N`, `W`, `Q`, `max_reward`), change the annotation to `self.edge_stats: dict[int, EdgeStats]`, and update any initialization/assignments that create entries for `self.edge_stats` (and similarly consider a TypedDict for `self.rave_stats` if it uses the same keys) so all created edge-stat entries conform to the `EdgeStats` shape.
25-26: Consider UCB implications of returning 0.0 for unvisited nodes.

Returning `0.0` for unvisited nodes may cause issues in UCB selection, where unvisited nodes should typically be prioritized (often via optimistic initialization like `float('inf')` or using `prior_score`). If this is intentional and handled elsewhere in the selection logic, this is fine.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `ciao/structures/nodes.py` around lines 25-26, the `mean_value()` method currently returns 0.0 for unvisited nodes, which can demote them in UCB selection; change `mean_value` to return an optimistic value for unvisited nodes instead of 0.0 — e.g., return `self.prior_score` if the node has a `prior_score` attribute set, otherwise return `float('inf')` (or another chosen optimistic initialization) so that selection logic prioritizes unvisited nodes.

pyproject.toml (1)
9-34: Consider upper bounds for major dependencies to prevent breaking changes.

Using only lower bounds (`>=`) provides flexibility but risks compatibility issues when dependencies release breaking changes. For a library, consider adding upper bounds for critical dependencies:

```diff
-    "torch>=2.0.0",
-    "torchvision>=0.15.0",
+    "torch>=2.0.0,<3.0.0",
+    "torchvision>=0.15.0,<1.0.0",
```

This is especially relevant for ML frameworks where APIs can change significantly between major versions.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `pyproject.toml` around lines 9-34, the dependencies list uses only lower bounds, which risks breakage on major releases; update the dependencies array to include conservative upper bounds for the critical packages (e.g., torch, torchvision, hydra-core, mlflow, omegaconf, numpy, scikit-image, pillow, matplotlib, plotly) by changing entries to ranges like `>=X.Y.Z,<NextMajor.0.0` (or a chosen compatible upper bound) so each package string explicitly caps the major version while keeping patch/minor upgrades allowed.

ciao/utils/calculations.py (1)
16-20: Handle edge case where model has no parameters.

`next(model.parameters())` will raise `StopIteration` if the model has no parameters. While uncommon, this could occur with certain model wrappers.

🛡️ Suggested defensive fix

```diff
 def __init__(self, model: torch.nn.Module, class_names: list[str]) -> None:
     self.model = model
     self.class_names = class_names
-    self.device = next(model.parameters()).device
+    try:
+        self.device = next(model.parameters()).device
+    except StopIteration:
+        self.device = torch.device("cpu")
     self.replacement_image = None
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `ciao/utils/calculations.py` around lines 16-20, `__init__` currently sets `self.device = next(model.parameters()).device`, which raises `StopIteration` for parameterless models; change it to retrieve the first parameter safely (e.g., `p = next(model.parameters(), None)`) and, if `p` is None, set `self.device` to a safe default like `torch.device("cpu")`, leaving `self.model`, `self.class_names`, and `self.replacement_image` initialization unchanged.

ciao/utils/segmentation.py (1)
166-191: Code duplication with `create_hexagonal_grid`.

`create_hexagonal_grid_with_list` (lines 166-191) and `create_hexagonal_grid` (lines 234-272) share nearly identical pixel-to-hex mapping logic. Consider extracting the common hex assignment loop into a helper function to reduce duplication.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `ciao/utils/segmentation.py` around lines 166-191, extract the duplicated pixel-to-hex mapping loop from `create_hexagonal_grid_with_list` and `create_hexagonal_grid` into a single helper (e.g., `map_pixels_to_hex` or `assign_pixels_to_hex`) that accepts input dimensions (height, width), `hex_radius`, and a `pixel_to_hex` callable, and returns the segments array plus the `hex_to_id` dict and `next_id`. Replace the loops in both functions with calls to this helper, then call `build_fast_adjacency_list(hex_to_id, next_id)` as before; ensure the return types of both functions remain unchanged.

ciao/explainer/ciao_explainer.py (1)
264-271: Make the placeholder explicit.

`visualize` is part of the public class API but currently silently no-ops. Raising `NotImplementedError` is clearer for callers.

Proposed fix

```diff
 def visualize(
     self,
     image: torch.Tensor,
     explanation: dict[str, Any],
     save_path: str | Path | None = None,
     interactive: bool = True,
 ) -> Any:
-    pass
+    raise NotImplementedError("CIAOExplainer.visualize is not implemented yet.")
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `ciao/explainer/ciao_explainer.py` around lines 264-271, the `visualize` method currently silently returns None; update it to explicitly raise `NotImplementedError` with a short message (e.g., "visualize must be implemented by subclasses") so callers see a clear error when it's not implemented.

ciao/algorithm/lookahead_bitset.py (3)
217-223: Unused function `_find_first_step`.

This function is defined but never called within the module. The BFS approach in `_generate_lookahead_candidates` directly tracks the `first_step` during traversal, making this function redundant.

Consider removing this dead code or, if kept for future use, document why it's retained.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `ciao/algorithm/lookahead_bitset.py` around lines 217-223, the helper `_find_first_step` is dead code — it's defined but never used (the BFS in `_generate_lookahead_candidates` already computes `first_step`); remove the function to clean up the module, or, if you intend to keep it, add a clear comment explaining its intended future use (but prefer deleting it to avoid confusion).
29-29: Unused parameter `scores`.

The `scores` parameter is declared but never referenced in the function body. Consider removing it or documenting why it's needed (e.g., for API consistency with other builders).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `ciao/algorithm/lookahead_bitset.py` at line 29, the `scores` parameter is declared but never used; either remove it from the signature to avoid dead API surface, or retain it for API compatibility, add a brief comment explaining it's kept for interface consistency, and update any callers accordingly if you remove it.
17-20: Module docstring placement is non-standard.

The module docstring is placed after imports as a string literal rather than at the top of the file. This won't be recognized as the module's `__doc__` attribute.

📝 Suggested fix

Move the docstring to the top of the file:

```diff
+"""Greedy lookahead hyperpixel building with bitmask operations.
+
+Rolling horizon strategy: Look ahead multiple steps but only commit one step at a time.
+"""
+
 import logging
 import numpy as np
 import torch
 ...
-
-"""Greedy lookahead hyperpixel building with bitmask operations.
-
-Rolling horizon strategy: Look ahead multiple steps but only commit one step at a time.
-"""
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `ciao/algorithm/lookahead_bitset.py` around lines 17-20, the triple-quoted string currently appearing after the imports should be moved to the very top of the file so it becomes the module docstring (populating `__doc__`); cut the string literal and paste it above all import statements, before any code or comments.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@ciao/data/loader.py`:
- Around line 24-31: The code uses Path(config.data.batch_path) without
validating that it exists and is a directory, which can silently yield zero
images; update the batch path handling in loader.py so that after constructing
directory = Path(config.data.batch_path) you check directory.exists() and
directory.is_dir(), and if either check fails raise a clear error (or log and
raise) indicating the provided config.data.batch_path is invalid; keep using
directory.glob("**/*{ext}") for iteration only after the validation passes.
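A hedged sketch of that fail-fast validation (`validate_batch_dir` is a hypothetical helper name; only the `config.data.batch_path` source and the glob pattern come from the comment above):

```python
from pathlib import Path

def validate_batch_dir(batch_path: str) -> Path:
    """Fail fast with a clear error instead of silently yielding zero images."""
    directory = Path(batch_path)
    if not directory.exists():
        raise FileNotFoundError(f"config.data.batch_path does not exist: {directory}")
    if not directory.is_dir():
        raise NotADirectoryError(f"config.data.batch_path is not a directory: {directory}")
    return directory
```

Only after this check passes would the loader iterate with `directory.glob(...)`.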
In `@ciao/data/preprocessing.py`:
- Around line 34-37: Open images using a context manager to ensure the file
handle is closed: replace direct Image.open(image_path).convert("RGB") with a
with Image.open(image_path) as img: block, call img.convert("RGB") inside the
block, assign the result to a variable, then call original_image = image.copy()
(or copy the converted image) before exiting the with so that input_tensor =
preprocess(image).to(device) and subsequent input_batch creation use an image
whose data is fully loaded and the file handle has been released.
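The context-manager pattern described above, as a standalone sketch (the function name is illustrative; `Image.convert` returns a new in-memory image, so the result stays valid after the file handle closes):

```python
from PIL import Image

def load_rgb_image(image_path):
    """Open with a context manager so the file handle is released promptly."""
    with Image.open(image_path) as img:
        image = img.convert("RGB")  # returns a new, fully loaded image
    return image  # safe to use: the underlying file is already closed
```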
In `@ciao/explainer/ciao_explainer.py`:
- Around line 141-167: The segmentation used to compute surrogate scores
(create_segmentation -> segments, graph -> create_surrogate_dataset ->
calculate_scores_from_surrogate) is not guaranteed to match the adjacency
structures because the code always calls create_hexagonal_grid_with_list and
build_adjacency_bitmasks; change this so the adjacency list/masks are derived
from the same segmentation_type and segments computed by create_segmentation (or
add a factory/function that builds adj_list/masks for the given
segmentation_type), e.g., replace the unconditional
create_hexagonal_grid_with_list call with logic that: given segmentation_type
and the existing segments/graph, constructs segments_list and adj_list
consistent with create_segmentation, then pass that adj_list into
build_adjacency_bitmasks so search uses the same segmentation as scoring (update
references in this block to use create_segmentation,
create_hexagonal_grid_with_list only when segmentation_type == "hexagonal" or
use a unified builder function).
- Around line 246-258: The code prints and accesses
class_names[target_class_idx] without validating the index which can raise
IndexError; update the print and the "class_name" entry in the result dict in
ciao_explainer.py to check that target_class_idx is an int within 0 <=
target_class_idx < len(class_names) before indexing (or use a safe lookup that
falls back to f"Class {target_class_idx}" for out-of-range/invalid values), and
ensure negative indices are treated as invalid if that is desired; apply this
check wherever class_names[target_class_idx] is used (the initial print and the
result construction) so you never directly index with an unvalidated
target_class_idx.
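A possible shape for the safe lookup (hypothetical helper; the fallback label format is taken from the comment above):

```python
def safe_class_name(class_names: list[str], target_class_idx) -> str:
    """Fall back to a generic label for out-of-range or non-int indices;
    negative indices are treated as invalid rather than Python wraparound."""
    if isinstance(target_class_idx, int) and 0 <= target_class_idx < len(class_names):
        return class_names[target_class_idx]
    return f"Class {target_class_idx}"
```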
In `@ciao/structures/bitmask_graph.py`:
- Around line 18-34: The iter_bits function currently assumes a non-negative
integer; add input validation at the start of iter_bits to ensure mask is an int
and mask >= 0 and raise a ValueError (or TypeError for non-int) if not, so
negative masks cannot enter the while loop; keep the rest of the implementation
(temp, low_bit, node_id, temp ^= low_bit) untouched after the validation.
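A validated version of `iter_bits` might look like this (the loop body follows the variable names mentioned above; the exact original signature is an assumption):

```python
def iter_bits(mask: int):
    """Yield the index of each set bit, lowest first; reject invalid masks up front."""
    if not isinstance(mask, int):
        raise TypeError(f"mask must be int, got {type(mask).__name__}")
    if mask < 0:
        raise ValueError("mask must be non-negative")
    temp = mask
    while temp:
        low_bit = temp & -temp          # isolate the lowest set bit
        yield low_bit.bit_length() - 1  # its index
        temp ^= low_bit                 # clear it and continue
```

Note that because this is a generator, the validation only runs once iteration starts.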
In `@ciao/utils/calculations.py`:
- Around line 159-166: The plot_image_mean_color function is showing
ImageNet-normalized values directly (which can be negative), so denormalize the
tensor returned by calculate_image_mean_color before plotting: take the
normalized_mean, apply inverse ImageNet normalization (multiply by std
[0.229,0.224,0.225] and add mean [0.485,0.456,0.406] per channel), clamp the
result to [0,1], convert to CPU numpy and permute to HxWxC, then call plt.imshow
and plt.show; reference the functions plot_image_mean_color and
calculate_image_mean_color to locate and update the visualization path.
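The inverse-normalization step on its own, sketched with numpy standing in for the project's tensors (the constants are the standard ImageNet mean/std; the helper name is illustrative):

```python
import numpy as np

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])

def denormalize_chw(normalized: np.ndarray) -> np.ndarray:
    """Undo ImageNet normalization on a CxHxW array and clamp to [0, 1] for display."""
    img = normalized * IMAGENET_STD[:, None, None] + IMAGENET_MEAN[:, None, None]
    img = np.clip(img, 0.0, 1.0)
    return img.transpose(1, 2, 0)  # CxHxW -> HxWxC for plt.imshow
```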
In `@ciao/utils/segmentation.py`:
- Around line 275-308: Add validation for the neighborhood parameter inside
create_segmentation: when segmentation_type == "square" check that neighborhood
is one of the allowed values (4 or 8) and raise a ValueError with a clear
message if not; perform this validation before calling create_square_grid
(referencing create_segmentation and create_square_grid) so invalid neighborhood
values (e.g., 3, 6) no longer silently produce incorrect adjacency graphs.
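A minimal sketch of that guard (helper name is illustrative; per the comment, only square grids are constrained):

```python
def validate_neighborhood(segmentation_type: str, neighborhood: int) -> None:
    """Reject invalid square-grid connectivity values before building adjacency."""
    if segmentation_type == "square" and neighborhood not in (4, 8):
        raise ValueError(
            f"neighborhood must be 4 or 8 for square segmentation, got {neighborhood}"
        )
```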
In `@README.md`:
- Line 83: The fenced code block opened with ``` in README.md should include a
language specifier to satisfy markdownlint MD040; update the opening fence to
include a language token (e.g., change the opening ``` to ```text or another
appropriate language) so the tree excerpt block is fenced as ```text and the
rest of the block remains unchanged.
---
Nitpick comments:
In @.gitignore:
- Around line 173-174: The .gitignore currently excludes .pre-commit-config.yaml
which prevents sharing pre-commit hooks; remove the ".pre-commit-config.yaml"
entry from .gitignore and commit the .pre-commit-config.yaml file so all
contributors use the same hooks (or if exclusion is intentional, add a short
justification comment in the repo README explaining the reason). Ensure you
update .gitignore to delete that line and add/commit the .pre-commit-config.yaml
to version control.
In `@ciao/algorithm/__init__.py`:
- Line 3: The package __all__ is empty, forcing consumers to import submodules
directly; update ciao.algorithm.__init__ to re-export the public symbols by
importing GlobalStats from mcts and the build functions
(build_hyperpixel_greedy_lookahead, build_all_hyperpixels_greedy_lookahead,
build_hyperpixel_mcgs, build_all_hyperpixels_mcgs, build_hyperpixel_mcts,
build_all_hyperpixels_mcts, build_hyperpixel_using_potential,
build_all_hyperpixels_potential) from their respective submodules and populate
__all__ with those symbol names so users can import them from ciao.algorithm
directly. Ensure imports use local (relative) imports and only expose the listed
public names in __all__.
In `@ciao/algorithm/lookahead_bitset.py`:
- Around line 217-223: The helper function _find_first_step is dead code—it's
defined but never used (the BFS in _generate_lookahead_candidates already
computes first_step); remove the entire _find_first_step function to clean up
the module, or if you intend to keep it, add a clear comment above
_find_first_step explaining its intended future usage and why it remains despite
not being referenced (but prefer deleting it to avoid confusion).
- Line 29: The parameter scores in the function signature (in
lookahead_bitset.py) is declared but never used; either remove the scores
parameter from the function/method signature to avoid dead API surface, or
retain it for API compatibility and reference it (for example by storing it as
an attribute or validating its shape) inside the same function (lookahead
builder function/method that declares scores) and add a brief comment explaining
it's kept for interface consistency; update any callers accordingly to match the
changed signature if you remove it.
- Around line 17-20: The triple-quoted string currently appearing after the
imports should be moved to the very top of the file so it becomes the module
docstring (so __doc__ is populated); locate the string literal in
lookahead_bitset.py and cut/paste it above all import statements (before any
code or comments) so the module-level docstring sits at file start.
In `@ciao/explainer/ciao_explainer.py`:
- Around line 264-271: The visualize method currently silently returns None;
update the public method visualize(self, image: torch.Tensor, explanation:
dict[str, Any], save_path: str | Path | None = None, interactive: bool = True)
to explicitly raise NotImplementedError with a short message (e.g., "visualize
must be implemented by subclasses") so callers see a clear error when it's not
implemented; modify the method body in ciao_explainer.py (the visualize
definition) to raise that exception instead of pass.
In `@ciao/structures/nodes.py`:
- Around line 37-41: Replace the unstructured dict[str, float] used for
self.edge_stats with a TypedDict to enforce keys and improve autocompletion: add
from typing import TypedDict and define an EdgeStats TypedDict (keys N, W, Q,
max_reward) then change the annotation to self.edge_stats: dict[int, EdgeStats]
and update any initialization/assignments that create entries for
self.edge_stats (and similarly consider a TypedDict for self.rave_stats if it
uses the same keys) so all created edge-stat entries conform to EdgeStats shape.
- Around line 25-26: The mean_value() method currently returns 0.0 for unvisited
nodes which can demote them in UCB selection; change mean_value (referenced by
mean_value, value_sum, visits) to return an optimistic value for unvisited nodes
instead of 0.0 — e.g., return self.prior_score if the node has a prior_score
attribute set, otherwise return float('inf') (or another chosen optimistic
initialization) so that selection logic prioritizes unvisited nodes.
In `@ciao/utils/calculations.py`:
- Around line 16-20: The __init__ currently sets self.device =
next(model.parameters()).device which raises StopIteration for parameterless
models; change it to retrieve the first parameter safely (e.g., p =
next(model.parameters(), None)) and if p is None set self.device to a safe
default like torch.device('cpu') (or use getattr(model, 'device',
torch.device('cpu'))), leaving self.model, self.class_names and
self.replacement_image initialization unchanged; update the __init__ function
where self.device is assigned to use this defensive check.
In `@ciao/utils/segmentation.py`:
- Around line 166-191: Extract the duplicated pixel-to-hex mapping loop from
create_hexagonal_grid_with_list and create_hexagonal_grid into a single helper
(e.g., map_pixels_to_hex or assign_pixels_to_hex) that accepts input dimensions
(height, width), hex_radius and a pixel_to_hex callable and returns the segments
array plus the hex_to_id dict and next_id (or just segments and
hex_to_id/next_id). Replace the loops in both functions to call this helper and
then call build_fast_adjacency_list(hex_to_id, next_id) as before; ensure return
types of create_hexagonal_grid_with_list and create_hexagonal_grid remain
unchanged.
In `@pyproject.toml`:
- Around line 9-34: The dependencies list in pyproject.toml uses only lower
bounds which risks breakage on major releases; update the dependencies array to
include conservative upper bounds for the critical packages (e.g., torch,
torchvision, hydra-core, mlflow, omegaconf, numpy, scikit-image, pillow,
matplotlib, plotly) by changing entries in the dependencies list to ranges like
>=X.Y.Z,<NextMajor.0.0 (or a chosen compatible upper bound) so each package
string in the dependencies array explicitly caps the major version to prevent
breaking changes while keeping patch/minor upgrades allowed.
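For illustration, a capped dependency entry might look like the fragment below (the version numbers are placeholders, not the project's actual pins):

```toml
[project]
dependencies = [
    "torch>=2.1.0,<3.0.0",
    "torchvision>=0.16.0,<1.0.0",
    "numpy>=1.26.0,<2.0.0",
    "pillow>=10.0.0,<11.0.0",
]
```

The `<NextMajor.0.0` cap allows minor and patch upgrades while blocking major releases that may break the API.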
ℹ️ Review info
Configuration used: defaults
Review profile: CHILL
Plan: Pro
⛔ Files ignored due to path filters (1)
`uv.lock` is excluded by `!**/*.lock`
📒 Files selected for processing (22)
- `.gitignore`
- `README.md`
- `ciao/__init__.py`
- `ciao/algorithm/__init__.py`
- `ciao/algorithm/lookahead_bitset.py`
- `ciao/algorithm/mcgs.py`
- `ciao/algorithm/mcts.py`
- `ciao/algorithm/potential.py`
- `ciao/data/__init__.py`
- `ciao/data/loader.py`
- `ciao/data/preprocessing.py`
- `ciao/explainer/__init__.py`
- `ciao/explainer/ciao_explainer.py`
- `ciao/imagenet_classes.txt`
- `ciao/structures/__init__.py`
- `ciao/structures/bitmask_graph.py`
- `ciao/structures/nodes.py`
- `ciao/utils/__init__.py`
- `ciao/utils/calculations.py`
- `ciao/utils/search_utils.py`
- `ciao/utils/segmentation.py`
- `pyproject.toml`
```python
image = Image.open(image_path).convert("RGB")
original_image = image.copy()
input_tensor = preprocess(image).to(device)  # (3, 224, 224) - on correct device
input_batch = input_tensor.unsqueeze(0)  # (1, 3, 224, 224) - already on device
```
🧩 Analysis chain
🏁 Script executed:
```shell
cat -n ciao/data/preprocessing.py | head -50
```
Repository: RationAI/ciao
Length of output: 1529
🏁 Script executed:
```shell
rg -nP --type=py 'Image\.open\(' -B2 -A2
```
Repository: RationAI/ciao
Length of output: 454
🌐 Web query:
PIL Pillow Image.open context manager file handle leak
💡 Result:
PIL.Image.open() is lazy: it usually keeps the underlying file open until the image is actually loaded (e.g., via im.load() or an operation that forces decoding). (pillow.readthedocs.io)
Does with Image.open(...) leak file handles?
- If you pass a filename/path: `with Image.open(path) as im:` will deterministically close resources when the `with` block exits, and (for single-frame images) the file may also be closed earlier when `load()` happens. (pillow.readthedocs.io)
- If you pass an already-open file object: Pillow's docs explicitly say you are responsible for closing that file object (Pillow can't safely manage a handle it didn't open). So you should wrap the file handle in its own context manager. (pillow.readthedocs.io)
Common gotchas that look like “leaks”
- Multi-frame formats (animated GIF, multipage TIFF): Pillow may intentionally keep the file open to allow `seek()` to other frames, so you must keep it open (or close the image when done). (pillow.readthedocs.io)
- Relying on garbage collection to close files is discouraged; Pillow deprecated implicit closing in `Image.__del__`; use a context manager or `im.close()` for deterministic behavior. (pillow.readthedocs.io)
Patterns that don’t leak
Best practice (filename/path):

```python
from PIL import Image

with Image.open("img.png") as im:
    im.load()  # optional, but forces decode while file is open
    # use im freely here
```

If you already have a file object:
```python
from PIL import Image

with open("img.png", "rb") as f:
    with Image.open(f) as im:
        im.load()
```

(The outer `with` closes `f`; the inner `with` closes Pillow's image object.) (pillow.readthedocs.io)
Citations:
- https://pillow.readthedocs.io/en/stable/reference/open_files.html
- https://pillow.readthedocs.io/en/stable/releasenotes/6.1.0.html
- 7: https://pillow.readthedocs.io/en/stable/reference/open_files.html?utm_source=openai
Use a context manager for image loading to prevent file handle exhaustion in batch processing.
Line 34 opens an image without explicit resource management. Although convert("RGB") forces image data loading, the file handle is not deterministically closed until garbage collection. In batch or repeated processing, this can exhaust available file descriptors. Use a context manager to guarantee cleanup:
```python
with Image.open(image_path) as img:
    image = img.convert("RGB")
    original_image = image.copy()
```

The `convert()` operation completes within the `with` block while the file is open, and the resulting image object is independent with data fully loaded, allowing safe handle closure.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@ciao/data/preprocessing.py` around lines 34 - 37, Open images using a context
manager to ensure the file handle is closed: replace direct
Image.open(image_path).convert("RGB") with a with Image.open(image_path) as img:
block, call img.convert("RGB") inside the block, assign the result to a
variable, then call original_image = image.copy() (or copy the converted image)
before exiting the with so that input_tensor = preprocess(image).to(device) and
subsequent input_batch creation use an image whose data is fully loaded and the
file handle has been released.
ciao/explainer/ciao_explainer.py
Outdated
| print(f"Class name: {class_names[target_class_idx]}") | ||
|
|
||
| # Return results | ||
| result = { | ||
| "input_batch": input_batch, | ||
| "target_class_idx": target_class_idx, | ||
| "segments": segments, | ||
| "scores": scores, | ||
| "hyperpixels": hyperpixels, | ||
| "top_hyperpixels": top_hyperpixels, | ||
| "class_name": class_names[target_class_idx] | ||
| if target_class_idx < len(class_names) | ||
| else f"Class {target_class_idx}", |
Validate class index before direct access.
Line 246 indexes class_names[target_class_idx] directly. If target_class_idx is out of range (or negative), this can crash before fallback logic is applied.
Proposed fix

```diff
 # 2. Get target class
 if target_class_idx is None:
     target_class_idx = get_predicted_class(predictor, input_batch)
     print(f"Auto-selected target class: {target_class_idx}")
+if target_class_idx < 0 or target_class_idx >= len(class_names):
+    raise ValueError(
+        f"target_class_idx out of range: {target_class_idx} "
+        f"(valid: 0..{len(class_names)-1})"
+    )
 ...
 print(f"Class name: {class_names[target_class_idx]}")
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@ciao/explainer/ciao_explainer.py` around lines 246 - 258, The code prints and
accesses class_names[target_class_idx] without validating the index which can
raise IndexError; update the print and the "class_name" entry in the result dict
in ciao_explainer.py to check that target_class_idx is an int within 0 <=
target_class_idx < len(class_names) before indexing (or use a safe lookup that
falls back to f"Class {target_class_idx}" for out-of-range/invalid values), and
ensure negative indices are treated as invalid if that is desired; apply this
check wherever class_names[target_class_idx] is used (the initial print and the
result construction) so you never directly index with an unvalidated
target_class_idx.
```python
def iter_bits(mask: int) -> Iterator[int]:
    """Iterate over set bits in a mask using low-bit isolation.

    Yields node IDs in arbitrary order (depends on bit positions).
    Performance: O(k) where k is the number of set bits.

    Example:
        mask = 0b10110  # bits 1, 2, 4 are set
        list(iter_bits(mask))  # [1, 2, 4]
    """
    temp = mask
    while temp:
        low_bit = temp & -temp
        node_id = low_bit.bit_length() - 1
        yield node_id
        temp ^= low_bit
```
🧩 Analysis chain
🏁 Script executed:
```shell
#!/bin/bash
# Test iter_bits with negative inputs to verify the claim
python3 << 'EOF'
def iter_bits(mask: int):
    """Test version to trace behavior"""
    temp = mask
    iteration = 0
    max_iterations = 100  # safety limit
    while temp and iteration < max_iterations:
        iteration += 1
        low_bit = temp & -temp
        node_id = low_bit.bit_length() - 1
        print(f"Iteration {iteration}: temp={temp}, low_bit={low_bit}, node_id={node_id}")
        yield node_id
        temp ^= low_bit

# Test with positive
print("Testing with positive mask (5):")
list(iter_bits(5))

# Test with negative
print("\nTesting with negative mask (-1):")
result = list(iter_bits(-1))
print(f"Result: {result}")
print(f"Completed successfully (no infinite loop)")

print("\nTesting with negative mask (-5):")
result = list(iter_bits(-5))
print(f"Result: {result}")
EOF
```
Repository: RationAI/ciao
Length of output: 15868
🏁 Script executed:
```shell
#!/bin/bash
# Check the complete file and all usages
cat -n ciao/structures/bitmask_graph.py | head -50
rg -nP '\biter_bits\(' --type=py
```
Repository: RationAI/ciao
Length of output: 3248
🏁 Script executed:
```shell
#!/bin/bash
# Check how masks are created and if they could be negative
rg -B5 -A2 'iter_bits\(' ciao/utils/search_utils.py ciao/structures/bitmask_graph.py ciao/algorithm/mcts.py | head -60
```
Repository: RationAI/ciao
Length of output: 3426
🏁 Script executed:
```shell
#!/bin/bash
# Check mask_to_ids alternative and how masks are generated
rg -nP 'mask\s*=' ciao/structures/bitmask_graph.py ciao/algorithm/ | head -20
```
Repository: RationAI/ciao
Length of output: 1656
Add input validation to prevent infinite loops with negative masks.
The iter_bits function will loop indefinitely if passed a negative mask, as Python's two's complement representation causes temp & -temp and temp ^= low_bit to perpetually extract bits. While the current codebase doesn't produce negative masks, this function should validate its inputs to prevent misuse.
Proposed fix
def iter_bits(mask: int) -> Iterator[int]:
"""Iterate over set bits in a mask using low-bit isolation.
@@
"""
+ if mask < 0:
+ raise ValueError(f"mask must be non-negative, got {mask}")
temp = mask
while temp:
low_bit = temp & -temp
node_id = low_bit.bit_length() - 1
yield node_id
temp ^= low_bit🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@ciao/structures/bitmask_graph.py` around lines 18 - 34, The iter_bits
function currently assumes a non-negative integer; add input validation at the
start of iter_bits to ensure mask is an int and mask >= 0 and raise a ValueError
(or TypeError for non-int) if not, so negative masks cannot enter the while
loop; keep the rest of the implementation (temp, low_bit, node_id, temp ^=
low_bit) untouched after the validation.
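Putting the proposed validation and the low-bit loop together, a self-contained version of the guarded iterator behaves like this (a sketch mirroring the function under review, not the repository's final code):

```python
from collections.abc import Iterator


def iter_bits(mask: int) -> Iterator[int]:
    """Yield indices of set bits; reject negative masks up front."""
    if mask < 0:
        raise ValueError(f"mask must be non-negative, got {mask}")
    temp = mask
    while temp:
        low_bit = temp & -temp          # isolate the lowest set bit
        yield low_bit.bit_length() - 1  # its index
        temp ^= low_bit                 # clear it and continue
```

Note that because `iter_bits` is a generator, the `ValueError` surfaces on first iteration (e.g., inside `list(...)`), not at the call itself.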
```python
def plot_image_mean_color(self, input_tensor: torch.Tensor) -> None:
    """Display the mean color of the image.

    Note: The visualization shows the normalized tensor (ImageNet normalization).
    """
    normalized_mean = self.calculate_image_mean_color(input_tensor).unsqueeze(0)
    plt.imshow(normalized_mean[0].permute(1, 2, 0))
    plt.show()
```
Plot displays ImageNet-normalized values incorrectly.
plt.imshow expects values in [0,1] or [0,255] range, but ImageNet-normalized tensors can have negative values. This will produce incorrect visualizations.
🐛 Suggested fix: denormalize before display
```diff
 def plot_image_mean_color(self, input_tensor: torch.Tensor) -> None:
     """Display the mean color of the image.

     Note: The visualization shows the normalized tensor (ImageNet normalization).
     """
     normalized_mean = self.calculate_image_mean_color(input_tensor).unsqueeze(0)
-    plt.imshow(normalized_mean[0].permute(1, 2, 0))
+    # Denormalize for correct display
+    denormalized = (normalized_mean * self.imagenet_std) + self.imagenet_mean
+    # Clamp to valid range and convert to displayable format
+    display_img = denormalized[0].permute(1, 2, 0).cpu().clamp(0, 1)
+    plt.imshow(display_img)
     plt.show()
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@ciao/utils/calculations.py` around lines 159 - 166, The plot_image_mean_color
function is showing ImageNet-normalized values directly (which can be negative),
so denormalize the tensor returned by calculate_image_mean_color before
plotting: take the normalized_mean, apply inverse ImageNet normalization
(multiply by std [0.229,0.224,0.225] and add mean [0.485,0.456,0.406] per
channel), clamp the result to [0,1], convert to CPU numpy and permute to HxWxC,
then call plt.imshow and plt.show; reference the functions plot_image_mean_color
and calculate_image_mean_color to locate and update the visualization path.
```python
def create_segmentation(
    input_tensor: torch.Tensor,
    segmentation_type: str = "hexagonal",
    segment_size: int = 14,
    neighborhood: int = 8,
) -> tuple[np.ndarray, nx.Graph]:
    """Create image segmentation with specified type.

    Args:
        input_tensor: Input image tensor [C, H, W]
        segmentation_type: "square" or "hexagonal"
        segment_size: Size parameter (square_size or hex_radius)
        neighborhood: Neighborhood connectivity for squares (4, or 8)

    Returns:
        segments: 2D array mapping pixels to segment IDs
        adjacency_graph: NetworkX graph of segment relationships
    """
    if segment_size <= 0:
        raise ValueError(
            f"segment_size must be positive, got {segment_size}. "
            "Non-positive values cause division by zero or invalid range operations."
        )

    if segmentation_type == "square":
        return create_square_grid(
            input_tensor, square_size=segment_size, neighborhood=neighborhood
        )
    elif segmentation_type == "hexagonal":
        return create_hexagonal_grid(input_tensor, hex_radius=segment_size)
    else:
        raise ValueError(
            f"Unknown segmentation_type: {segmentation_type}. Use 'square' or 'hexagonal'."
        )
```
Add validation for neighborhood parameter.
The function validates segment_size but not neighborhood. Invalid values (e.g., 3, 6) would silently produce unexpected adjacency graphs.
🛡️ Suggested fix
```diff
     if segment_size <= 0:
         raise ValueError(
             f"segment_size must be positive, got {segment_size}. "
             "Non-positive values cause division by zero or invalid range operations."
         )
+
+    if segmentation_type == "square" and neighborhood not in (4, 8):
+        raise ValueError(
+            f"neighborhood must be 4 or 8 for square segmentation, got {neighborhood}."
+        )
+
     if segmentation_type == "square":
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@ciao/utils/segmentation.py` around lines 275 - 308, Add validation for the
neighborhood parameter inside create_segmentation: when segmentation_type ==
"square" check that neighborhood is one of the allowed values (4 or 8) and raise
a ValueError with a clear message if not; perform this validation before calling
create_square_grid (referencing create_segmentation and create_square_grid) so
invalid neighborhood values (e.g., 3, 6) no longer silently produce incorrect
adjacency graphs.
````text
## Project Structure

```
````
Add a language to the fenced code block.
Line 83 opens a fenced block without a language, which triggers markdownlint MD040.
Proposed fix
-```
+```text
ciao/
├── ciao/ # Main package
...
└── pyproject.toml # Project metadata and dependencies</details>
🧰 Tools
🪛 markdownlint-cli2 (0.21.0)
[warning] 83-83: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@README.md` at line 83, The fenced code block opened with ``` in README.md
should include a language specifier to satisfy markdownlint MD040; update the
opening fence to include a language token (e.g., change the opening ``` to
```text or another appropriate language) so the tree excerpt block is fenced as
```text and the rest of the block remains unchanged.
Pull request overview
This pull request introduces the core explainer pipeline for the CIAO (Contextual Importance Assessment via Obfuscation) package, building on top of the foundational data structures and search algorithms from PR #3. The PR adds data loading and preprocessing capabilities, the main CIAOExplainer class that serves as the primary user-facing API, and integrates all search algorithms (MCTS, MCGS, Lookahead, Potential) into a cohesive pipeline.
Changes:
- Added `data/` module for image loading and preprocessing with ImageNet normalization
- Implemented `explainer/` module with the main `CIAOExplainer` class that orchestrates the explanation pipeline
- Connected search algorithms and data structures through utility functions in the `utils/` module
Reviewed changes
Copilot reviewed 21 out of 23 changed files in this pull request and generated 19 comments.
| File | Description |
|---|---|
| pyproject.toml | Updated dependencies (torch, torchvision, matplotlib, etc.) and project metadata |
| ciao/data/loader.py | Image path loading utilities with support for single images and batch directories |
| ciao/data/preprocessing.py | Image preprocessing pipeline with ImageNet normalization |
| ciao/explainer/ciao_explainer.py | Main CIAOExplainer class providing the primary API for generating explanations |
| ciao/utils/calculations.py | ModelPredictor class and scoring functions for segment importance |
| ciao/utils/segmentation.py | Hexagonal and square grid segmentation with adjacency graph construction |
| ciao/utils/search_utils.py | Shared utilities for MCTS and MCGS algorithms |
| ciao/structures/bitmask_graph.py | Bitmask operations for efficient segment manipulation |
| ciao/structures/nodes.py | Node classes for MCTS and MCGS tree/graph structures |
| ciao/algorithm/*.py | Search algorithm implementations (MCTS, MCGS, Lookahead, Potential) |
| ciao/__init__.py | Package initialization exporting CIAOExplainer |
| ciao/imagenet_classes.txt | ImageNet class labels for model predictions |
| README.md | Comprehensive documentation of CIAO methodology and usage |
```python
# ruff: noqa: RUF002
def sampling_phase(
    S_mask: int,  # noqa: N803
```
The variable name S_mask uses non-PEP8 naming convention (uppercase S). While there's a noqa comment for RUF002, consider renaming to current_mask or structure_mask to follow Python naming conventions consistently throughout the codebase.
```python
def explain(
    self,
    image_path: str | Path,
    predictor: ModelPredictor,
    method: str = "lookahead",
    target_class_idx: int | None = None,
    segment_size: int = 4,
    segmentation_type: str = "hexagonal",
    max_hyperpixels: int = 10,
    desired_length: int = 30,
    batch_size: int = 64,
    neighborhood: int = 8,
    replacement: str = "mean_color",
    replacement_kwargs: dict[str, Any] | None = None,
    method_params: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Generate CIAO explanation for an image.

    Args:
        image_path: Path to image or PIL Image object
        predictor: ModelPredictor instance
        method: Hyperpixel construction method. Options:
            - "potential": Potential field guided search
            - "mcts": Monte Carlo Tree Search
            - "mc_rave": MC-RAVE (MCTS with RAVE heuristic)
            - "lookahead": Optimized greedy lookahead with bitsets (default)
            - "mcgs": Monte Carlo Graph Search
            - "mcgs_rave": MCGS with RAVE
        target_class_idx: Target class to explain (None = auto-select)
        segment_size: Size of segments in pixels
        segmentation_type: Type of segmentation ("hexagonal")
        max_hyperpixels: Maximum number of hyperpixels to build
        desired_length: Target number of segments per hyperpixel (default=30)
        batch_size: Batch size for model evaluation
        neighborhood: Adjacency neighborhood (6 or 8 for hexagonal)
        replacement: Masking strategy for model evaluation
        replacement_kwargs: Additional kwargs for replacement method
        method_params: Dictionary of method-specific parameters:

            For "potential":
                - num_simulations: int (default=50) - Number of simulations

            For "mcts":
                - num_iterations: int (default=100) - MCTS iterations
                - exploration_c: float (default=1.4) - UCT exploration constant
                - mcts_batch_size: int (default=64) - Batch size for MCTS

            For "mc_rave":
                - num_iterations: int (default=100)
                - exploration_c: float (default=1.4)
                - mcts_batch_size: int (default=64)
                - rave_k: float (default=1000)

            For "lookahead":
                - lookahead_distance: int (default=2)

            For "mcgs":
                - num_iterations: int (default=100)
                - mcts_batch_size: int (default=64)
                - exploration_c: float (default=1.4)

            For "mcgs_rave":
                - num_iterations: int (default=100)
                - mcts_batch_size: int (default=64)
                - exploration_c: float (default=1.4)
                - rave_k: float (default=1000)

    Returns:
        Dictionary containing:
            - input_batch: Preprocessed input tensor
            - target_class_idx: Class being explained
            - segments: Segmentation map
            - scores: Individual segment scores
            - hyperpixels: List of all hyperpixels found
            - top_hyperpixels: Top-k hyperpixels by score
            - class_name: Human-readable class name
            - performance_mode: Method identifier
    """
```
The explain method lacks input validation for several parameters. For example: max_hyperpixels, desired_length, batch_size should be positive integers; segment_size is validated in create_segmentation but could be checked earlier; target_class_idx could be validated against the number of classes; method is validated later but an early check would provide better error messages. Consider adding parameter validation at the start of the method.
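As a sketch of what such early validation could look like (the function name and the exact set of checks are illustrative, not part of the current API):

```python
def validate_explain_params(
    method: str,
    max_hyperpixels: int,
    desired_length: int,
    batch_size: int,
) -> None:
    """Fail fast with clear messages before any expensive work starts."""
    valid_methods = {"potential", "mcts", "mc_rave", "lookahead", "mcgs", "mcgs_rave"}
    if method not in valid_methods:
        raise ValueError(f"Unknown method: {method}. Valid: {sorted(valid_methods)}")
    for name, value in (
        ("max_hyperpixels", max_hyperpixels),
        ("desired_length", desired_length),
        ("batch_size", batch_size),
    ):
        # bool is an int subclass, so exclude it explicitly
        if isinstance(value, bool) or not isinstance(value, int) or value <= 0:
            raise ValueError(f"{name} must be a positive integer, got {value!r}")
```

Calling a helper like this at the top of `explain()` would move the `method` check from deep in the dispatch chain to the first line of the method.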
```python
def get_image_loader(config: Any) -> Iterator[Path]:
    """Create image loader based on configuration.

    Args:
        config: Hydra config object

    Returns:
        Iterator of Path objects

    Raises:
        ValueError: If neither image_path nor batch_path is specified
    """
    if config.data.get("image_path"):
        # Single image mode
        yield Path(config.data.image_path)

    elif config.data.get("batch_path"):
        # Directory mode
        directory = Path(config.data.batch_path)
        extensions = config.data.get(
            "image_extensions", [".jpg", ".jpeg", ".png", ".bmp", ".webp"]
        )
        for ext in extensions:
            yield from directory.glob(f"**/*{ext}")

    else:
        raise ValueError("Must specify either image_path or batch_path in config")
```
The get_image_loader function accepts a config parameter typed as Any, which defeats the purpose of type checking. Consider defining a proper configuration protocol or TypedDict to specify the expected structure of the config object (e.g., with data.image_path, data.batch_path, data.image_extensions attributes).
ciao/__init__.py
Outdated
```python
from ciao.explainer.ciao_explainer import CIAOExplainer


__all__ = ["CIAOExplainer"]
```
The __init__.py file only exports CIAOExplainer, but based on the README and typical usage patterns, users may also need access to ModelPredictor for initialization. Consider exporting ModelPredictor from the main package for easier imports (e.g., from ciao import CIAOExplainer, ModelPredictor).
```diff
-from ciao.explainer.ciao_explainer import CIAOExplainer
-
-__all__ = ["CIAOExplainer"]
+from ciao.explainer.ciao_explainer import CIAOExplainer, ModelPredictor
+
+__all__ = ["CIAOExplainer", "ModelPredictor"]
```
```python
def redistribute_history(
    H_winner: list[tuple[int, float]],  # noqa: N803
```
The variable name H_winner uses non-PEP8 naming convention (uppercase H). While there's a noqa comment for RUF002, the actual issue is the naming convention. Consider renaming to winner_history to follow Python naming conventions consistently.
```python
def load_and_preprocess_image(
    image_path: str | Path, device: torch.device | None = None
) -> tuple[torch.Tensor, Image.Image, torch.Tensor]:
    """Load and preprocess an image for the model.

    Args:
        image_path: Path to image file
        device: Device to place tensor on (defaults to cuda if available, else cpu)

    Returns:
        Tuple of (input_batch, original_image, input_tensor)
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    image = Image.open(image_path).convert("RGB")
    original_image = image.copy()
    input_tensor = preprocess(image).to(device)  # (3, 224, 224) - on correct device
    input_batch = input_tensor.unsqueeze(0)  # (1, 3, 224, 224) - already on device

    return input_batch, original_image, input_tensor
```
The load_and_preprocess_image function opens image files using Image.open() without validation or error handling. This could be vulnerable to malicious image files that could cause denial-of-service or other issues. Consider adding try-except blocks to handle corrupted or malicious image files, and potentially validate file types before opening.
ciao/explainer/ciao_explainer.py
Outdated
```python
    save_path: str | Path | None = None,
    interactive: bool = True,
) -> Any:
    pass
```
The visualize method is a placeholder that only contains pass. This creates an incomplete public API - the method signature suggests it should return Any, but it actually returns None. Either implement the method or clearly mark it as not implemented by raising NotImplementedError with a descriptive message.
```diff
-    pass
+    raise NotImplementedError(
+        "CIAOExplainer.visualize is not implemented yet. "
+        "Use custom visualization based on the 'explanation' output, "
+        "or extend this method to generate and optionally save visualizations."
+    )
```
```python
def _find_first_step(base_mask: int, target_mask: int) -> int:
    """Find the first segment added from base_mask to reach target_mask."""
    diff = target_mask & ~base_mask
    # Return the first bit in the difference
    for seg_id in iter_bits(diff):
        return seg_id
    raise ValueError("Could not find first step between base and target mask.")
```
The function _find_first_step (lines 217-223) is defined but never used in the codebase. This appears to be dead code that should either be removed or integrated if it serves a purpose.
```diff
-def _find_first_step(base_mask: int, target_mask: int) -> int:
-    """Find the first segment added from base_mask to reach target_mask."""
-    diff = target_mask & ~base_mask
-    # Return the first bit in the difference
-    for seg_id in iter_bits(diff):
-        return seg_id
-    raise ValueError("Could not find first step between base and target mask.")
```
ciao/explainer/ciao_explainer.py
Outdated
| print(f"Auto-selected target class: {target_class_idx}") | ||
|
|
||
| # 3. Create segmentation | ||
| segments, graph = create_segmentation( | ||
| input_tensor, | ||
| segmentation_type=segmentation_type, | ||
| segment_size=segment_size, | ||
| neighborhood=neighborhood, | ||
| ) | ||
| print( | ||
| f"Built {segmentation_type} spatial graph with {graph.number_of_nodes()} " | ||
| f"segments and {graph.number_of_edges()} edges" | ||
| ) | ||
|
|
||
| # Calculate scores from surrogate dataset | ||
| X, y = create_surrogate_dataset( | ||
| predictor, | ||
| input_batch, | ||
| segments, | ||
| graph, | ||
| target_class_idx, | ||
| batch_size=batch_size, | ||
| ) | ||
| scores = calculate_scores_from_surrogate(X, y) | ||
|
|
||
| # Create adjacency structures (needed by all methods) | ||
| segments_list, adj_list = create_hexagonal_grid_with_list( | ||
| input_tensor, segment_size | ||
| ) | ||
| adj_masks = build_adjacency_bitmasks(adj_list) | ||
|
|
||
| # Build hyperpixels based on method | ||
| if method == "potential": | ||
| hyperpixels = build_all_hyperpixels_potential( | ||
| predictor=predictor, | ||
| input_batch=input_batch, | ||
            segments=segments_list,
            adj_masks=adj_masks,
            target_class_idx=target_class_idx,
            scores=scores,
            max_hyperpixels=max_hyperpixels,
            desired_length=desired_length,
            num_simulations=method_params.get("num_simulations", 50),
            batch_size=batch_size,
        )
    elif method in ["mcts", "mc_rave"]:
        mode_str = "rave" if method == "mc_rave" else "standard"
        hyperpixels = build_all_hyperpixels_mcts(
            predictor=predictor,
            input_batch=input_batch,
            segments=segments_list,
            adj_masks=adj_masks,
            target_class_idx=target_class_idx,
            scores=scores,
            max_hyperpixels=max_hyperpixels,
            desired_length=desired_length,
            num_iterations=method_params.get("num_iterations", 100),
            mode=mode_str,
            batch_size=method_params.get("mcts_batch_size", 64),
            exploration_c=method_params.get("exploration_c", 1.4),
            rave_k=method_params.get("rave_k", 1000),
        )
    elif method == "lookahead":
        hyperpixels = build_all_hyperpixels_greedy_lookahead(
            predictor=predictor,
            input_batch=input_batch,
            segments=segments_list,
            adj_masks=adj_masks,
            target_class_idx=target_class_idx,
            scores=scores,
            max_hyperpixels=max_hyperpixels,
            desired_length=desired_length,
            lookahead_distance=method_params.get("lookahead_distance", 2),
            batch_size=batch_size,
        )
    elif method in ["mcgs", "mcgs_rave"]:
        # Determine mode based on method
        mode_str = "rave" if method == "mcgs_rave" else "standard"
        hyperpixels = build_all_hyperpixels_mcgs(
            predictor=predictor,
            input_batch=input_batch,
            segments=segments_list,
            adj_masks=adj_masks,
            target_class_idx=target_class_idx,
            scores=scores,
            max_hyperpixels=max_hyperpixels,
            desired_length=desired_length,
            num_iterations=method_params.get("num_iterations", 100),
            mode=mode_str,
            batch_size=method_params.get("mcts_batch_size", 64),
            exploration_c=method_params.get("exploration_c", 1.4),
            rave_k=method_params.get("rave_k", 1000.0),
        )
    else:
        raise ValueError(
            f"Unknown method: {method}. Valid options: potential, mcts, "
            f"mc_rave, lookahead, mcgs, mcgs_rave"
        )

    # Select top hyperpixels
    top_hyperpixels = select_top_hyperpixels(hyperpixels, max_hyperpixels)

    print(f"Class name: {class_names[target_class_idx]}")
The method uses print() statements for user feedback (lines 138, 147-150, 246) rather than using a logger. This is inconsistent with the rest of the codebase which uses the logging module. Consider using logger.info() instead for consistency and better control over output levels.
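A minimal sketch of the suggested switch (the helper and its arguments are illustrative, not the method's actual signature):

```python
import logging

logger = logging.getLogger(__name__)


def report_class(class_names: list[str], target_class_idx: int) -> None:
    """Log the resolved class name instead of printing it."""
    # logger.info respects configured handlers and levels, unlike print();
    # the %s argument is only formatted if the record is actually emitted.
    logger.info("Class name: %s", class_names[target_class_idx])
```

Callers can then silence or redirect this output through the standard logging configuration instead of capturing stdout.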
    """Test if a node is present in the mask."""
    return (mask >> node) & 1 == 1


def add_node(mask: int, node: int) -> int:
    """Add a node to the mask."""
    return mask | (1 << node)


def remove_node(mask: int, node: int) -> int:
    """Remove a node from the mask."""
The bitmask operations (e.g., add_node, has_node) use bit shifting with node IDs directly without bounds checking. If a node ID is negative or exceeds the maximum representable bit position (typically around 63 for standard integers, though Python supports arbitrary precision), this could lead to unexpected behavior or excessive memory usage. Consider adding validation or documentation about the maximum supported node count.
Suggested change:

def has_node(mask: int, node: int) -> bool:
    """Test if a node is present in the mask."""
    if node < 0:
        # Negative indices are never considered valid node IDs in the bitmask.
        return False
    return (mask >> node) & 1 == 1


def add_node(mask: int, node: int) -> int:
    """Add a node to the mask."""
    if node < 0:
        raise ValueError(f"node index must be non-negative, got {node}")
    return mask | (1 << node)


def remove_node(mask: int, node: int) -> int:
    """Remove a node from the mask."""
    if node < 0:
        raise ValueError(f"node index must be non-negative, got {node}")
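The three guarded helpers can be exercised together; this sketch assumes `remove_node` clears the bit with `mask & ~(1 << node)`, since the diff truncates its body:

```python
def has_node(mask: int, node: int) -> bool:
    """Test if a node is present in the mask."""
    if node < 0:
        return False  # negative indices are never valid node IDs
    return (mask >> node) & 1 == 1


def add_node(mask: int, node: int) -> int:
    """Add a node to the mask."""
    if node < 0:
        raise ValueError(f"node index must be non-negative, got {node}")
    return mask | (1 << node)


def remove_node(mask: int, node: int) -> int:
    """Remove a node from the mask (assumed body: clear the bit)."""
    if node < 0:
        raise ValueError(f"node index must be non-negative, got {node}")
    return mask & ~(1 << node)
```

Because Python ints are arbitrary precision, a huge node ID does not overflow but silently allocates a proportionally large mask, so documenting an upper bound (e.g. the segment count) is still worthwhile.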
Actionable comments posted: 3
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@ciao/utils/calculations.py`:
- Around line 225-227: segment IDs from graph.nodes() may be
sparse/non-contiguous so indexing arrays by raw segment_id is unsafe; build a
dense mapping like id_to_idx = {seg_id: idx for idx, seg_id in
enumerate(segment_ids)} (using the existing segment_ids and num_segments) and
replace any direct uses of segment_id as an array index in the block around
where segment_ids/num_segments are defined (the code that writes features using
segment_id between ~lines 257–264) with id_to_idx[segment_id]; also ensure any
arrays are allocated with length=num_segments so indexing via the mapped dense
index is safe.
- Around line 335-337: Validate that input_batch contains exactly one image
before the code that repeats images and computes deltas: add a check that
input_batch.shape[0] == 1 (or raise ValueError) at the start of the section
handling hyperpixel_segment_ids_list, and abort early if not. Ensure this check
is placed before the batching/delta logic that uses input_batch to build deltas
and align with candidates (the block that repeats images and computes "deltas"
from input_batch and compares to "candidates"), so you don't duplicate images or
compute extra deltas when batch > 1.
- Around line 97-106: The code builds even indices from height but then uses
them for both column and row slicing, which breaks for non-square images; fix by
creating distinct index arrays: compute even_col_indices = torch.arange(0,
width, 2) and use it in replacement_image[:, :, even_col_indices] (the vertical
flip with dims=[1]), and compute even_row_indices = torch.arange(0, height, 2)
and use it in replacement_image[:, even_row_indices, :] (the horizontal flip
with dims=[2]); update any variable names (even_indices) accordingly so each
slice uses the correct axis-derived indices.
segment_ids = list(graph.nodes())
num_segments = len(segment_ids)
Map segment IDs to surrogate columns instead of indexing by raw ID.
Line 263 assumes segment_id is a dense zero-based column index. If graph node IDs are sparse/non-contiguous, this can write incorrect features or hit out-of-bounds.
🐛 Proposed fix
segment_ids = list(graph.nodes())
num_segments = len(segment_ids)
+ segment_to_col = {segment_id: col_idx for col_idx, segment_id in enumerate(segment_ids)}
@@
for i, masked_segments in enumerate(local_groups):
for segment_id in masked_segments:
- X[i, segment_id] = 1.0
+        X[i, segment_to_col[segment_id]] = 1.0

Also applies to: 257-264
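The failure mode is easy to reproduce with sparse IDs; a pure-Python sketch of the dense-mapping idea (the names mirror the fix, but the data is invented):

```python
# Sparse, non-contiguous segment IDs, as a graph library might return them.
segment_ids = [2, 7, 40]
num_segments = len(segment_ids)
segment_to_col = {seg_id: col for col, seg_id in enumerate(segment_ids)}

local_groups = [[2, 40], [7]]
X = [[0.0] * num_segments for _ in range(len(local_groups))]
for i, masked_segments in enumerate(local_groups):
    for segment_id in masked_segments:
        # Raw indexing X[i][40] would be out of range for a 3-column row.
        X[i][segment_to_col[segment_id]] = 1.0
```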
if not hyperpixel_segment_ids_list:
    return []
Enforce single-image input shape before batching deltas.
The function contract says input_batch is [1, 3, H, W], but it is not validated. With batch > 1, Line 367 repeats all images and Line 388 computes too many deltas, breaking alignment with candidates.
🐛 Proposed fix
def calculate_hyperpixel_deltas(
@@
) -> list[float]:
@@
if not hyperpixel_segment_ids_list:
return []
+ if input_batch.ndim != 4 or input_batch.shape[0] != 1:
+ raise ValueError(
+ f"Expected input_batch shape [1, 3, H, W], got {tuple(input_batch.shape)}"
+ )
+ if batch_size <= 0:
+        raise ValueError(f"batch_size must be > 0, got {batch_size}")

Also applies to: 363-389
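The same guard can be stated framework-agnostically on the shape tuple (a hypothetical helper, not part of the PR):

```python
def ensure_single_image_batch(shape: tuple, batch_size: int) -> None:
    """Fail fast unless the input describes a single [1, C, H, W] image."""
    if len(shape) != 4 or shape[0] != 1:
        raise ValueError(f"Expected input_batch shape [1, 3, H, W], got {tuple(shape)}")
    if batch_size <= 0:
        raise ValueError(f"batch_size must be > 0, got {batch_size}")
```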
…odelPredictor to __all__; improve image loader with directory checks; add graph to adjacency list conversion
…arameter for solid_color mode
57b34c8 to f919ec9
Actionable comments posted: 4
♻️ Duplicate comments (3)
ciao/data/preprocessing.py (1)

34-37: ⚠️ Potential issue | 🟠 Major

Use a context manager for Image.open to guarantee file-handle cleanup. Line 34 opens the image without deterministic close semantics. This is still unresolved.

Proposed fix

-image = Image.open(image_path).convert("RGB")
-original_image = image.copy()
+with Image.open(image_path) as img:
+    image = img.convert("RGB")
+    original_image = image.copy()

#!/bin/bash
# Verify all PIL opens and inspect whether they are context-managed
rg -nP --type=py 'Image\.open\(' -C2

ciao/utils/segmentation.py (1)
320-323: ⚠️ Potential issue | 🟡 Minor

Validate neighborhood for square segmentation. Invalid neighborhood values currently fall through silently. This should fail fast for square mode.

Proposed fix

 if segmentation_type == "square":
+    if neighborhood not in (4, 8):
+        raise ValueError(
+            f"neighborhood must be 4 or 8 for square segmentation, got {neighborhood}."
+        )
     return create_square_grid(
         input_tensor, square_size=segment_size, neighborhood=neighborhood
     )

ciao/explainer/ciao_explainer.py (1)
168-180: ⚠️ Potential issue | 🟠 Major

Avoid direct class_names[...] indexing before safe resolution. Line 288 can throw before the fallback logic on lines 298-300 is used when the class index is invalid for class_names.

Proposed fix

 if target_class_idx is None:
     target_class_idx = get_predicted_class(predictor, input_batch)
     logger.info(f"Auto-selected target class: {target_class_idx}")
 else:
     # Validate target_class_idx if provided
     num_classes = len(class_names) if class_names else None
     if num_classes and (
         target_class_idx >= num_classes or target_class_idx < 0
     ):
         raise ValueError(
             f"target_class_idx {target_class_idx} is out of range. "
             f"Model has {num_classes} classes (indices 0-{num_classes - 1})"
         )
+class_name = (
+    class_names[target_class_idx]
+    if 0 <= target_class_idx < len(class_names)
+    else f"Class {target_class_idx}"
+)
 ...
-logger.info(f"Class name: {class_names[target_class_idx]}")
+logger.info("Class name: %s", class_name)
 ...
-"class_name": class_names[target_class_idx]
-if target_class_idx < len(class_names)
-else f"Class {target_class_idx}",
+"class_name": class_name,

Also applies to: 288-300
🧹 Nitpick comments (4)
ciao/utils/calculations.py (1)

221-224: Avoid unconditional extra model inference in debug logging. predictor.get_predictions(...) is executed even when debug logs are disabled because it is inside an f-string expression.

Proposed fix

-logger.debug(
-    f"Probability of class {target_class_idx}: "
-    f"{predictor.get_predictions(input_batch)[0, target_class_idx].item()}"
-)
+if logger.isEnabledFor(logging.DEBUG):
+    prob = predictor.get_predictions(input_batch)[0, target_class_idx].item()
+    logger.debug("Probability of class %s: %s", target_class_idx, prob)

ciao/structures/bitmask_graph.py (1)
13-16: Align input validation across all bitmask helpers. iter_bits guards negative masks, but mask_to_ids and pick_random_set_bit currently accept them silently. Consistent validation would make invalid states fail fast.

Proposed fix

 def mask_to_ids(mask: int) -> list[int]:
     """Convert integer bitmask to list of segment indices."""
+    if mask < 0:
+        raise ValueError(f"mask must be non-negative, got {mask}")
     return [i for i in range(mask.bit_length()) if (mask >> i) & 1]

 def pick_random_set_bit(mask: int) -> int:
     """Select a random set bit from the mask in O(N) where N is the index of the bit.
     @@
     Without allocating a list. Efficient for sparse masks.
     """
+    if mask < 0:
+        raise ValueError(f"mask must be non-negative, got {mask}")
     count = mask.bit_count()

Also applies to: 63-79

ciao/algorithm/mcgs.py (2)
490-495: Verify cached reward sign consistency. The caching logic stores node.max_value directly when a terminal node has been visited before. Since max_value is updated with signed rewards during backup, this is consistent. However, consider adding a brief comment clarifying that cached values are already in optimization-signed space (matching the comment on line 538).

📝 Suggested clarification

 if (
     is_terminal(node.mask, adj_masks, used_mask, desired_length)
     and node.visits > 0
 ):
     rollout_mask = node.mask
-    cached_rewards.append(node.max_value)
+    # max_value is already in optimization-signed space from prior backups
+    cached_rewards.append(node.max_value)
240-285: Consider consistent error handling for virtual loss underflow. Lines 262-266 raise RuntimeError for root pending underflow, but lines 279-282 silently clamp edge pending underflow to zero. This asymmetry could mask bugs where edge virtual loss accounting is incorrect.

If the clamping is intentional (e.g., for robustness in edge cases), consider adding a brief comment explaining why edge underflows are tolerated while root underflows are not.
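The symmetry the comment asks for can be made concrete with a toy pending counter (illustrative only; names do not match mcgs.py):

```python
class PendingCounter:
    """Toy virtual-loss counter that fails loudly on underflow."""

    def __init__(self) -> None:
        self.pending = 0

    def add_virtual_loss(self) -> None:
        self.pending += 1

    def revert_virtual_loss(self) -> None:
        # Treating root and edge counters identically: an underflow means a
        # revert without a matching add, which signals an accounting bug.
        if self.pending <= 0:
            raise RuntimeError("virtual-loss underflow: revert without matching add")
        self.pending -= 1
```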
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@ciao/algorithm/mcts.py`:
- Around line 347-349: In build_hyperpixel_mcts, add upfront validation for
numeric hyperparameters to reject non-positive or non-integer values: check
num_iterations, batch_size, desired_length, and (when mode == "rave") rave_k to
ensure they are integers > 0 (and rave_k > 0 only needed for rave mode); if any
check fails raise ValueError with a clear message naming the offending
parameter(s). Keep the existing mode validation and reference function
build_hyperpixel_mcts, parameter names num_iterations, batch_size,
desired_length, and rave_k when locating where to add these checks.
In `@ciao/data/loader.py`:
- Around line 20-23: The loader currently yields Path(config.data.image_path)
without validating it; update the single-image branch in loader.py (the block
that checks config.data.get("image_path")) to validate that the provided path
exists and is a file before yielding: convert to a Path, check path.exists() and
path.is_file(), and if the checks fail raise a clear ValueError (or log and
raise) mentioning the invalid image_path so callers get an actionable error
instead of a downstream failure.
In `@ciao/explainer/ciao_explainer.py`:
- Around line 238-243: Validate numeric method_params before passing them into
the search dispatcher: check method_params["num_iterations"],
["mcts_batch_size"], ["lookahead_distance"], ["rave_k"] (and any other numeric
entries like ["exploration_c"]) for presence, type and positive bounds and
replace or raise on invalid values before constructing the call in
ciao_explainer (the block that builds num_iterations=..., batch_size=...,
exploration_c=..., rave_k=...); enforce min values (e.g. >=1 for
iteration/batch/lookahead and sensible >0 for exploration_c), coerce ints where
appropriate, and centralize this validation into a small helper or inline
pre-check so the dispatched call only uses validated numeric arguments.
In `@ciao/utils/calculations.py`:
- Around line 351-354: The RuntimeError raised in calculate_hyperpixel_deltas
refers to a non-existent create_replacement_image() function; update the message
to instruct callers to initialize the replacement_image via the ModelPredictor
API instead (reference replacement_image and ModelPredictor and suggest calling
the actual initializer method on ModelPredictor, e.g., "initialize
replacement_image on ModelPredictor (call the ModelPredictor method that sets up
replacement_image) before using calculate_hyperpixel_deltas").
---
Duplicate comments:
In `@ciao/data/preprocessing.py`:
- Around line 34-37: Replace the direct Image.open(image_path) call with a
context manager to ensure the file handle is closed: use "with
Image.open(image_path) as img:" then call img.convert("RGB"), assign a copy to
original_image before the with-block exits, and continue using that copy for
preprocess/convert and creating input_tensor/input_batch; specifically modify
the code around Image.open/image.copy so Image.open, convert, original_image,
preprocess, input_tensor and input_batch still work the same but the file handle
is deterministically closed.
In `@ciao/explainer/ciao_explainer.py`:
- Around line 168-180: The code may index class_names[target_class_idx] before
validating target_class_idx; modify the flow in the target-class selection block
(around get_predicted_class, target_class_idx, class_names, logger.info) so you
first compute num_classes = len(class_names) if class_names else None and
validate that target_class_idx is within 0..num_classes-1 before any use of
class_names[...] (or defer any class_names indexing until after fallback to
get_predicted_class when target_class_idx is None or invalid); if invalid,
either raise the ValueError as currently intended or auto-select via
get_predicted_class and log it with logger.info, ensuring no direct class_names
indexing happens before the bounds check.
In `@ciao/utils/segmentation.py`:
- Around line 320-323: For the "square" branch, add explicit validation of the
neighborhood parameter before calling create_square_grid: check that
neighborhood is an integer and within the allowed range (e.g., positive integer
and if your algorithm requires an odd size, enforce oddness) and raise a
ValueError with a clear message if invalid; perform this check in the block
handling segmentation_type == "square" (or inside create_square_grid if you
prefer central validation) so invalid neighborhood values fail fast rather than
silently falling through.
---
Nitpick comments:
In `@ciao/algorithm/mcgs.py`:
- Around line 490-495: The cached reward being stored with
cached_rewards.append(node.max_value) should be clearly documented as already
being in the optimization-signed space (i.e., includes sign from backups), so
add a short inline comment near this block (around the is_terminal check and the
cached_rewards append) stating that node.max_value is the signed/optimized value
and can be used directly as a rollout reward; reference the variables
node.max_value, cached_rewards, is_terminal, and rollout_mask to make locating
the change easy.
- Around line 240-285: backup_paths currently raises RuntimeError on root
virtual-loss underflow (MCGSNode.pending) but silently clamps
parent.pending_edges[action] to zero for edge underflow, causing asymmetry;
either make them consistent by replacing the clamp in backup_paths (the block
that handles parent.pending_edges[action]) with the same underflow check and
raise a RuntimeError when parent.pending_edges[action] <= 0, or—if edge clamping
is intentionally tolerated—add a brief inline comment above the clamp explaining
why edge underflows are allowed (e.g., for robustness with concurrent DAG
updates) and reference the symbols parent.pending_edges, MCGSNode.pending, and
update_edge_stats so reviewers can find the logic to change.
In `@ciao/structures/bitmask_graph.py`:
- Around line 13-16: mask_to_ids and pick_random_set_bit currently accept
negative masks while iter_bits rejects them; make behavior consistent by
validating the mask parameter in both mask_to_ids and pick_random_set_bit (and
any other bitmask helpers around lines 63-79) to raise a ValueError for negative
inputs. Locate the functions mask_to_ids and pick_random_set_bit and add an
early check like "if mask < 0: raise ValueError('mask must be non-negative')" so
invalid states fail fast and match iter_bits' validation.
In `@ciao/utils/calculations.py`:
- Around line 221-224: The debug log is forcing a model inference because
predictor.get_predictions(input_batch) is evaluated inside the f-string; to fix,
avoid calling the predictor unless debug is enabled—wrap the call in a
logger.isEnabledFor(logging.DEBUG) check (or compute preds once into a local
variable used by subsequent code) and then call logger.debug with the
precomputed value; reference the predictor.get_predictions, logger.debug,
input_batch, and target_class_idx symbols to locate and update the code.
ℹ️ Review info
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (12)
- ciao/__init__.py
- ciao/algorithm/mcgs.py
- ciao/algorithm/mcts.py
- ciao/data/__init__.py
- ciao/data/loader.py
- ciao/data/preprocessing.py
- ciao/explainer/__init__.py
- ciao/explainer/ciao_explainer.py
- ciao/structures/bitmask_graph.py
- ciao/structures/nodes.py
- ciao/utils/calculations.py
- ciao/utils/segmentation.py
🚧 Files skipped from review as they are similar to previous changes (3)
- ciao/explainer/__init__.py
- ciao/__init__.py
- ciao/data/__init__.py
if mode not in ["standard", "rave"]:
    raise ValueError(f"Invalid mode '{mode}'. Must be 'standard' or 'rave'.")
Validate core search hyperparameters at function entry.
build_hyperpixel_mcts currently only validates mode. Non-positive num_iterations, batch_size, desired_length, or rave_k can produce degenerate/no-op search behavior.
Proposed fix
if mode not in ["standard", "rave"]:
raise ValueError(f"Invalid mode '{mode}'. Must be 'standard' or 'rave'.")
+ if num_iterations <= 0:
+ raise ValueError(f"num_iterations must be positive, got {num_iterations}")
+ if batch_size <= 0:
+ raise ValueError(f"batch_size must be positive, got {batch_size}")
+ if desired_length <= 0:
+ raise ValueError(f"desired_length must be positive, got {desired_length}")
+ if rave_k <= 0:
+    raise ValueError(f"rave_k must be positive, got {rave_k}")

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
if mode not in ["standard", "rave"]:
    raise ValueError(f"Invalid mode '{mode}'. Must be 'standard' or 'rave'.")
if num_iterations <= 0:
    raise ValueError(f"num_iterations must be positive, got {num_iterations}")
if batch_size <= 0:
    raise ValueError(f"batch_size must be positive, got {batch_size}")
if desired_length <= 0:
    raise ValueError(f"desired_length must be positive, got {desired_length}")
if rave_k <= 0:
    raise ValueError(f"rave_k must be positive, got {rave_k}")
if config.data.get("image_path"):
    # Single image mode
    yield Path(config.data.image_path)
Validate image_path before yielding.
Line 22 yields a path without checking that it exists and is a file. A typoed config currently fails later with a less actionable error.
Proposed fix
if config.data.get("image_path"):
# Single image mode
- yield Path(config.data.image_path)
+ image_path = Path(config.data.image_path)
+ if not image_path.is_file():
+ raise ValueError(f"image_path must be an existing file, got: {image_path}")
+    yield image_path
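A standalone sketch of the check (helper name hypothetical):

```python
from pathlib import Path


def resolve_image_path(raw_path: str) -> Path:
    """Return the image path, failing fast if it is missing or not a file."""
    image_path = Path(raw_path)
    if not image_path.is_file():
        raise ValueError(f"image_path must be an existing file, got: {image_path}")
    return image_path
```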
num_iterations=method_params.get("num_iterations", 100),
mode=mode_str,
batch_size=method_params.get("mcts_batch_size", 64),
exploration_c=method_params.get("exploration_c", 1.4),
rave_k=method_params.get("rave_k", 1000),
)
Validate method-specific numeric params before dispatch.
method_params values like num_iterations, mcts_batch_size, lookahead_distance, and rave_k are used without bounds checks. Non-positive values can produce no-op searches or invalid stats.
Proposed fix
 if method_params is None:
     method_params = {}
+
+def _require_positive_int(name: str, value: int) -> None:
+    if not isinstance(value, int) or value <= 0:
+        raise ValueError(f"{name} must be a positive integer, got {value}")

 elif method in ["mcts", "mc_rave"]:
     mode_str = "rave" if method == "mc_rave" else "standard"
+    _require_positive_int("num_iterations", method_params.get("num_iterations", 100))
+    _require_positive_int("mcts_batch_size", method_params.get("mcts_batch_size", 64))

 elif method == "lookahead":
+    _require_positive_int("lookahead_distance", method_params.get("lookahead_distance", 2))

 elif method in ["mcgs", "mcgs_rave"]:
     # Determine mode based on method
     mode_str = "rave" if method == "mcgs_rave" else "standard"
+    _require_positive_int("num_iterations", method_params.get("num_iterations", 100))
+    _require_positive_int("mcts_batch_size", method_params.get("mcts_batch_size", 64))
+    _require_positive_int("rave_k", method_params.get("rave_k", 1000))
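The validation helper from the proposed fix can be exercised in isolation; this sketch additionally rejects booleans, since bool is a subclass of int in Python (an assumption beyond the review's wording):

```python
def require_positive_int(name: str, value) -> int:
    """Validate that a method_params entry is a strictly positive integer."""
    # bool is a subclass of int, so exclude it explicitly.
    if isinstance(value, bool) or not isinstance(value, int) or value <= 0:
        raise ValueError(f"{name} must be a positive integer, got {value!r}")
    return value
```

A dispatcher could then call `require_positive_int("num_iterations", method_params.get("num_iterations", 100))` before handing the value to the search routine.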
raise RuntimeError(
    "replacement_image is not initialized. "
    "Call create_replacement_image() before using calculate_hyperpixel_deltas."
)
Fix misleading recovery hint in the runtime error.
The message says create_replacement_image(), but that method does not exist in ModelPredictor.
Proposed fix:

```diff
 if predictor.replacement_image is None:
     raise RuntimeError(
         "replacement_image is not initialized. "
-        "Call create_replacement_image() before using calculate_hyperpixel_deltas."
+        "Call ModelPredictor.get_replacement_image(...) and assign "
+        "predictor.replacement_image before using calculate_hyperpixel_deltas."
     )
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
raise RuntimeError(
    "replacement_image is not initialized. "
    "Call ModelPredictor.get_replacement_image(...) and assign "
    "predictor.replacement_image before using calculate_hyperpixel_deltas."
)
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@ciao/utils/calculations.py` around lines 351-354: the RuntimeError raised
in calculate_hyperpixel_deltas refers to a non-existent
create_replacement_image() function; update the message to instruct callers to
initialize the replacement_image via the ModelPredictor API instead (reference
replacement_image and ModelPredictor and suggest calling the actual initializer
method on ModelPredictor, e.g., "initialize replacement_image on ModelPredictor
(call the ModelPredictor method that sets up replacement_image) before using
calculate_hyperpixel_deltas").
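The corrected guard can be exercised in isolation. The `ModelPredictor` stub below is a stand-in exposing only the attribute the guard reads, and `get_replacement_image` is taken on faith from the review comment rather than verified against the codebase:

```python
class ModelPredictor:
    """Stand-in exposing only the attribute the guard below reads."""

    def __init__(self):
        # In the real class this is populated via get_replacement_image(...)
        self.replacement_image = None


def calculate_hyperpixel_deltas(predictor):
    # Fail fast with an actionable hint that names an initializer
    # that actually exists on ModelPredictor (per the review).
    if predictor.replacement_image is None:
        raise RuntimeError(
            "replacement_image is not initialized. "
            "Call ModelPredictor.get_replacement_image(...) and assign "
            "predictor.replacement_image before using calculate_hyperpixel_deltas."
        )
    # ... delta computation would follow here ...
```

Checking the attribute at the top of the function keeps the error close to its cause instead of surfacing as an opaque `AttributeError` deeper in the delta computation.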
**DEPENDS ON PR #3**
Please review and merge PR #3 first. This PR is stacked on top of it, so it temporarily shows commits from PR #3. Once PR #3 is merged into master, those commits will automatically disappear from this diff.
What was changed:
- `data/` module for data loading and preprocessing.
- `explainer/` module containing the main `CIAOExplainer` class.

Why:

To build the primary user-facing API (`CIAOExplainer`) and handle data flow. By isolating this from visualizations and configurations, we keep the PR size manageable and focus the review strictly on how the algorithm interacts with incoming data and outputs results.

Related Task:
XAI-29
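The `explain()` entry point dispatches on `method`. The sketch below is hypothetical: the method names, the `mode_str` logic, and the defaults (`num_iterations=100`, `lookahead_distance=2`) come from the review diffs above, while the function name and return shape are purely illustrative:

```python
def resolve_method(method, method_params=None):
    """Hypothetical dispatch sketch; names and defaults follow the
    review diffs above, the return shape is illustrative only."""
    params = dict(method_params or {})
    if method in ("mcgs", "mcgs_rave"):
        # Determine mode based on method (as in the suggested diff)
        mode_str = "rave" if method == "mcgs_rave" else "standard"
        return {"mode": mode_str,
                "num_iterations": params.get("num_iterations", 100)}
    if method == "lookahead":
        return {"mode": "lookahead",
                "lookahead_distance": params.get("lookahead_distance", 2)}
    raise ValueError(f"unknown method: {method!r}")
```

Rejecting unknown method names eagerly, before any segmentation or model work starts, mirrors the fail-fast validation the review asks for.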
Summary by CodeRabbit
- New Features
- Documentation
- Chores