Skip to content

Commit

Permalink
use direct access to groups in match objects
Browse files Browse the repository at this point in the history
Co-authored-by: Emil Melnikov <emilmelnikov@users.noreply.github.com>
  • Loading branch information
k-dominik and emilmelnikov committed Apr 26, 2024
1 parent b8e3761 commit cf13402
Showing 1 changed file with 21 additions and 25 deletions.
46 changes: 21 additions & 25 deletions ilastik/applets/base/appletSerializer/legacyClassifiers.py
Expand Up @@ -121,15 +121,15 @@ def deserialize_classifier_type(ds: h5py.Dataset) -> LazyflowClassifierTypeABC:

# legacy support - ilastik used to pickle the classifier type
if class_string.isascii() and (m := classifier_pickle_string_matcher.match(class_string)):
groupdict = m.groupdict()
m

if groupdict["submodule_name"] not in _lazyflow_classifier_factory_submodule_allow_list:
raise ValueError(f"Could not load classifier: submodule {groupdict['submodule_name']} not allowed.")
if m["submodule_name"] not in _lazyflow_classifier_factory_submodule_allow_list:
raise ValueError(f"Could not load classifier: submodule {m['submodule_name']} not allowed.")

if groupdict["type_name"] not in _lazyflow_classifier_type_allow_list:
raise ValueError(f"Could not load classifier: type {groupdict['type_name']} not allowed.")
if m["type_name"] not in _lazyflow_classifier_type_allow_list:
raise ValueError(f"Could not load classifier: type {m['type_name']} not allowed.")

return ClassifierInfo(**groupdict).classifier_type
return ClassifierInfo(**m.groupdict()).classifier_type

raise ValueError(f"Could not load classifier type {class_string=}")

Expand Down Expand Up @@ -232,9 +232,8 @@ def _deserialize_classifier_factory_type(pickle_string: str) -> ClassifierFactor
)

if pickle_string.isascii() and (m := classifier_factory_pickle_string_matcher.search(pickle_string)):
groupdict = m.groupdict()
submodule = groupdict["factory_submodule"]
typename = groupdict["type_name"]
submodule = m["factory_submodule"]
typename = m["type_name"]

if submodule not in _lazyflow_classifier_factory_submodule_allow_list:
raise ValueError(f"Could not load classifier: submodule {submodule} not allowed. {pickle_string=}")
Expand All @@ -245,7 +244,7 @@ def _deserialize_classifier_factory_type(pickle_string: str) -> ClassifierFactor
raise ValueError(f"Could not load classifier factory type submodule and type not found {pickle_string=}")

if m := classifier_factory_version_pickle_string_matcher.search(pickle_string):
version = int(m.groupdict()["factory_version"])
version = int(m["factory_version"])
else:
raise ValueError(f"Could not load classifier type, no version found {pickle_string=}")

Expand Down Expand Up @@ -292,7 +291,7 @@ def _deserialize_VigraRfClassifierFactory(pickle_string: str) -> VigraRfLazyflow
)

if m := classifier_factory_args_pickle_string_matcher.search(pickle_string):
arg = int(m.groupdict()["arg"])
arg = int(m["arg"])
else:
raise ValueError(
f"Could not load VigraRfLazyflowClassifierFactory, no argument found not found in {pickle_string=}"
Expand Down Expand Up @@ -335,7 +334,7 @@ def _deserialize_ParallelVigraRfLazyflowClassifierFactory(
)

if m := classifier_factory_num_trees_pickle_string_matcher.search(pickle_string):
num_trees = int(m.groupdict()["num_trees"])
num_trees = int(m["num_trees"])
else:
raise ValueError(
f"Could not load ParallelVigraRfLazyflowClassifierFactory, _num_trees not found in {pickle_string=}"
Expand All @@ -352,7 +351,7 @@ def _deserialize_ParallelVigraRfLazyflowClassifierFactory(
)

if m := classifier_factory_label_proportion_pickle_string_matcher.search(pickle_string):
label_prop_string = m.groupdict()["label_proportion"]
label_prop_string = m["label_proportion"]
label_proportion = None if label_prop_string == "N" else float(label_prop_string)
else:
raise ValueError(
Expand All @@ -370,7 +369,7 @@ def _deserialize_ParallelVigraRfLazyflowClassifierFactory(
)

if m := classifier_factory_variable_importance_path_pickle_string_matcher.search(pickle_string):
variable_importance_pth_string = m.groupdict()["variable_importance_path"]
variable_importance_pth_string = m["variable_importance_path"]
variable_importance_path = None if variable_importance_pth_string == "N" else variable_importance_pth_string
else:
raise ValueError(
Expand All @@ -388,7 +387,7 @@ def _deserialize_ParallelVigraRfLazyflowClassifierFactory(
)

if m := classifier_factory_variable_importance_enabled_pickle_string_matcher.search(pickle_string):
variable_importance_enabled = bool(int(m.groupdict()["variable_importance_enabled"]))
variable_importance_enabled = bool(int(m["variable_importance_enabled"]))
else:
raise ValueError(
f"Could not load ParallelVigraRfLazyflowClassifierFactory, _variable_importance_enabled not found in {pickle_string=}"
Expand All @@ -404,7 +403,7 @@ def _deserialize_ParallelVigraRfLazyflowClassifierFactory(
)

if m := classifier_factory_num_forests_pickle_string_matcher.search(pickle_string):
num_forests = int(m.groupdict()["num_forests"])
num_forests = int(m["num_forests"])
else:
raise ValueError(
f"Could not load ParallelVigraRfLazyflowClassifierFactory, _num_forests not found in {pickle_string=}"
Expand Down Expand Up @@ -499,9 +498,8 @@ def _deserialize_SklearnLazyflowClassifierFactory(pickle_string) -> SklearnClass
]

if m := classifier_factory_sklearn_type_pickle_string_matcher.search(pickle_string):
groupdict = m.groupdict()
submodules = groupdict["submodules"]
typename = groupdict["typename"]
submodules = m["submodules"]
typename = m["typename"]

if submodules not in sklearn_submodule_allow_list or typename not in sklearn_classifier_allow_list:
raise ValueError(f"Classifier of type sklearn.{submodules}.{typename} not permitted.")
Expand Down Expand Up @@ -554,9 +552,7 @@ def _deserialize_sklearn_RandomForest_details(pickle_str: str) -> SklearnClassif
)

if m := classifier_factory_args_pickle_string_matcher.search(pickle_str):
return SklearnClassifierFactoryInfo(
classifier_type=RandomForestClassifier, args=[int(m.groupdict()["arg"])], kwargs={}
)
return SklearnClassifierFactoryInfo(classifier_type=RandomForestClassifier, args=[int(m["arg"])], kwargs={})
else:
raise ValueError("Could not deserialize sklearn RandomForest classifier.")

Expand All @@ -572,7 +568,7 @@ def _deserialize_sklearn_AdaBoostClassifier_details(pickle_str: str) -> SklearnC
)
if m := classifier_factory_n_estimators_pickle_string_matcher.search(pickle_str):
return SklearnClassifierFactoryInfo(
classifier_type=AdaBoostClassifier, args=[], kwargs={"n_estimators": int(m.groupdict()["n_estimators"])}
classifier_type=AdaBoostClassifier, args=[], kwargs={"n_estimators": int(m["n_estimators"])}
)
else:
raise ValueError("Could not deserialize sklearn AdaBoostClassifier.")
Expand All @@ -589,7 +585,7 @@ def _deserialize_sklearn_DecisionTreeClassifier_details(pickle_str: str) -> Skle
)
if m := classifier_factory_max_depth_pickle_string_matcher.search(pickle_str):
return SklearnClassifierFactoryInfo(
classifier_type=DecisionTreeClassifier, args=[], kwargs={"max_depth": int(m.groupdict()["max_depth"])}
classifier_type=DecisionTreeClassifier, args=[], kwargs={"max_depth": int(m["max_depth"])}
)
else:
raise ValueError("Could not deserialize sklearn DecisionTreeClassifier")
Expand All @@ -608,7 +604,7 @@ def _deserialize_sklearn_SVC_details(
)
if m := classifier_factory_probability_pickle_string_matcher.search(pickle_str):
return SklearnClassifierFactoryInfo(
classifier_type=classifier_type, args=[], kwargs={"probability": int(m.groupdict()["probability"]) != 0}
classifier_type=classifier_type, args=[], kwargs={"probability": int(m["probability"]) != 0}
)
else:
raise ValueError("Could not deserialize sklearn SVC/NuSVC classifier.")

0 comments on commit cf13402

Please sign in to comment.