diff --git a/src/neuron_proofreader/config.py b/src/neuron_proofreader/config.py
index 09b974b..4cf9ffb 100644
--- a/src/neuron_proofreader/config.py
+++ b/src/neuron_proofreader/config.py
@@ -109,8 +109,6 @@ class MLConfig:
         Name of model used to perform inference. Default is None.
     patch_shape : Tuple[int]
         Shape of image patch expected by vision model. Default is (96, 96, 96).
-    threshold : float
-        A general threshold value used in classification. Default is 0.8.
     transform : bool
         Indication of whether to apply data augmentation to image patches.
         Default is False.
@@ -121,7 +119,6 @@ class MLConfig:
     device: str = "cuda"
     model_name: str = None
     patch_shape: tuple = (96, 96, 96)
-    threshold: float = 0.8
     transform: bool = False
 
     def to_dict(self):
diff --git a/src/neuron_proofreader/machine_learning/subgraph_sampler.py b/src/neuron_proofreader/machine_learning/subgraph_sampler.py
index ea69a13..c9839bb 100644
--- a/src/neuron_proofreader/machine_learning/subgraph_sampler.py
+++ b/src/neuron_proofreader/machine_learning/subgraph_sampler.py
@@ -114,8 +114,8 @@ def __iter__(self):
             # Yield batch
             yield subgraph
 
-    def populate_via_bfs(self, subgraph, root):
-        i, j = tuple(root)
+    def populate_via_bfs(self, subgraph, root_proposal):
+        i, j = root_proposal
         queue = deque([(i, 0), (j, 0)])
         visited = {i, j}
         while queue:
diff --git a/src/neuron_proofreader/split_proofreading/split_inference.py b/src/neuron_proofreader/split_proofreading/split_inference.py
index 574f4f9..2edf4ba 100644
--- a/src/neuron_proofreader/split_proofreading/split_inference.py
+++ b/src/neuron_proofreader/split_proofreading/split_inference.py
@@ -146,7 +146,9 @@ def _load_data(self, fragments_path, img_path, segmentation_path):
         self.log(f"Module Runtime: {elapsed:.2f} {unit}\n")
 
     # --- Pipelines ---
-    def __call__(self, search_radius):
+    def __call__(
+        self, search_radius, dt=0.1, min_threshold=0.75, removal_threshold=0.3
+    ):
         """
         Executes the full inference pipeline.
 
@@ -154,44 +156,39 @@ def __call__(self, search_radius):
         ----------
         search_radius : float
            Search radius (in microns) used to generate proposals.
+        dt : float, optional
+            Increment that acceptance threshold is lowered by. Default is 0.1.
+        min_threshold : float, optional
+            Minimum threshold for accepting proposals. Default is 0.75.
+        removal_threshold : float, optional
+            Proposals with model predictions less than this value are removed.
+            Default is 0.3.
""" - # Generate proposal - t0 = time() + # Generate proposals self.generate_proposals(search_radius) - preds = self.predict_proposals() - - # Update graph - self.merge_with_threshold_schedule(preds, self.config.ml.threshold) - - # Report results - t, unit = util.time_writer(time() - t0) - self.log(self.dataset.summary(prefix="\nFinal")) - self.log(f"Total Runtime: {t:.2f} {unit}\n") - self.save_results() + total_proposals = self.dataset.n_proposals() - def multistep(self, search_radius, low_threshold=0.3, dt=0.1): + # Run inference + cnt = 0 t0 = time() - self.generate_proposals(search_radius) - total_proposals = self.dataset.n_proposals() for only_leaf2leaf in [True, False]: - cnt = 0 name = "_leaf2leaf" if only_leaf2leaf else "" new_threshold = 0.99 while self.dataset.proposals: # Generate predictons cnt += 1 - print(f"Threshold={new_threshold} w/ only_leaf2leaf={only_leaf2leaf}") - preds = self.predict_proposals(suffix=f"{name}_round={cnt}") + self.log(f"Threshold={new_threshold} w/ only_leaf2leaf={only_leaf2leaf}") + preds = self.predict_proposals(suffix=f"{name}_round={cnt}_threshold={new_threshold}") # Merge accetped proposals cur_threshold = new_threshold self.merge_with_threshold_schedule( preds, cur_threshold, only_leaf2leaf=only_leaf2leaf ) - self.filter_proposals(preds, low_threshold) + self.filter_proposals(preds, removal_threshold) # Update threshold - new_threshold = max(cur_threshold - dt, self.config.ml.threshold) + new_threshold = max(cur_threshold - dt, min_threshold) if cur_threshold == new_threshold: break @@ -205,6 +202,7 @@ def multistep(self, search_radius, low_threshold=0.3, dt=0.1): # --- Core Routines --- def filter_proposals(self, preds, threshold): + # Remove based on model predictions and mergeability cnt = 0 for proposal, pred in preds.items(): is_valid = self.dataset.is_mergeable(*proposal) @@ -212,6 +210,13 @@ def filter_proposals(self, preds, threshold): self.dataset.remove_proposal(proposal) cnt += 1 + # Sanity check + for proposal in self.dataset.list_proposals(): + i, j = proposal + if self.dataset.degree[i] > 2 or self.dataset.degree[j] > 2: + self.dataset.remove_proposal(proposal) + cnt += 1 + self.log("\nFilter Proposals") self.log(f"# Proposals Removed: {cnt}") self.log(f"# Proposals Remaining: {self.dataset.n_proposals()}\n")