Add type hints to skll.data module & improve types in skll.config (#730)

* chore(pre-commit): ignore missing imports in mypy check

* chore(pre-commit): add conventional commit check

* chore: add new custom types in `skll.types`

* chore: improve type hints in `config/__init__.py`
- Use new `ClassMap` custom type.
- Make sure `pos_label` is typed correctly.
- Fix incorrect default value for sampler parameters.

* chore: add type hints to `skll.data` module
- Refactor code to ensure correct types.
- Add a lot of checks for labels, features, vectorizer being ``None``
  that was previously handled.
- Update docstrings to use the type hints.

* chore: Update test_featureset.py
- Add new tests for merging featuresets.
- Tweak tests that were using `StringIO` to use actual files.
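
The sampler-parameter fix listed above replaces a list-shaped default (`"[]"`) with a mapping-shaped one (`"{}"`). A minimal sketch of why that could matter, assuming the parsed value is later unpacked as keyword arguments for the sampler (an assumption, not confirmed by this commit). Here `json.loads` stands in for the YAML loader the module uses, and `make_sampler` and `build_sampler` are hypothetical helpers:

```python
import json
from typing import Any, Dict


def make_sampler(**kwargs: Any) -> Dict[str, Any]:
    """Hypothetical stand-in for a sampler constructor; it just echoes its keyword arguments."""
    return kwargs


def build_sampler(raw_parameters: str) -> Dict[str, Any]:
    """Parse the raw config string and pass the result on as keyword arguments."""
    parameters = json.loads(raw_parameters)
    # ``**`` unpacking requires a mapping, so a list-shaped default fails here
    return make_sampler(**parameters)


new_default = build_sampler("{}")  # empty mapping: no extra keyword arguments

try:
    build_sampler("[]")  # the old default parses to a list
    old_default_works = True
except TypeError:
    old_default_works = False
```

With a mapping default, an empty configuration value and an explicit one such as `{"gamma": 0.5}` follow the same code path.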
desilinguist committed May 18, 2023
1 parent d7f0eb2 commit 9f26caf
Showing 7 changed files with 751 additions and 496 deletions.
6 changes: 6 additions & 0 deletions .pre-commit-config.yaml
@@ -13,6 +13,11 @@ repos:
  - id: check-ast
  - id: check-json
  - id: debug-statements
+ - repo: https://github.com/compilerla/conventional-pre-commit
+ rev: 'v2.2.0'
+ hooks:
+ - id: conventional-pre-commit
+ stages: [commit-msg]
  - repo: https://github.com/ikamensh/flynt/
  rev: '0.78'
  hooks:
@@ -31,3 +36,4 @@ repos:
  rev: 'v1.2.0'
  hooks:
  - id: mypy
+ args: [--ignore-missing-imports]
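
The `conventional-pre-commit` hook configured above runs at the `commit-msg` stage and checks that the commit message follows the Conventional Commits format. As a rough illustration of that kind of check (a simplified sketch, not the hook's actual implementation; `ALLOWED_TYPES`, `HEADER_RE`, and `is_conventional` are hypothetical names chosen here):

```python
import re

# Hypothetical, simplified list of commit types; the real hook's list may differ.
ALLOWED_TYPES = (
    "build", "chore", "ci", "docs", "feat",
    "fix", "perf", "refactor", "style", "test",
)

# <type>(<optional scope>)<optional !>: <description>
HEADER_RE = re.compile(
    rf"^(?:{'|'.join(ALLOWED_TYPES)})(?:\([\w\-./]+\))?!?: \S.*$"
)


def is_conventional(message: str) -> bool:
    """Return True if the first line of ``message`` looks like a conventional commit header."""
    first_line = message.splitlines()[0] if message else ""
    return HEADER_RE.match(first_line) is not None
```

Under this sketch, the bullet headers in the commit message above (for example `chore(pre-commit): add conventional commit check`) pass, while a plain sentence without a `type:` prefix does not.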
20 changes: 9 additions & 11 deletions skll/config/__init__.py
@@ -21,7 +21,7 @@
  import ruamel.yaml as yaml

  from skll.data.readers import safe_float
- from skll.types import FoldMapping, PathOrStr
+ from skll.types import ClassMap, FoldMapping, LabelType, PathOrStr
  from skll.utils.constants import (
  PROBABILISTIC_METRICS,
  VALID_FEATURE_SCALING_OPTIONS,
@@ -86,7 +86,7 @@ def __init__(self) -> None:
  "random_folds": "False",
  "results": "",
  "sampler": "",
- "sampler_parameters": "[]",
+ "sampler_parameters": "{}",
  "save_cv_folds": "True",
  "save_cv_models": "False",
  "shuffle": "False",
@@ -296,7 +296,7 @@ def parse_config_file(
  bool,
  bool,
  str,
- str,
+ Optional[LabelType],
  str,
  int,
  str,
@@ -317,7 +317,7 @@ def parse_config_file(
  str,
  str,
  bool,
- Optional[Dict[str, List[str]]],
+ Optional[ClassMap],
  str,
  str,
  List[int],
@@ -407,9 +407,8 @@ def parse_config_file(
  results_path : str
  Path to store result files in.
- pos_label : str
- The string label for the positive class in the binary
- classification setting.
+ pos_label : Optional[LabelType]
+ The label for the positive class in the binary classification setting.
  feature_scaling : str
  How to scale features (e.g. 'with_mean').
@@ -484,7 +483,7 @@ def parse_config_file(
  ids_to_floats : bool
  Whether to convert IDs to floats.
- class_map : Optional[Dict[str, List[str]]]
+ class_map : Optional[ClassMap]
  A class map collapsing several labels into one. The keys
  are the collapsed labels and each key's value is the list of
  labels to be collapsed into said label.
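
The docstring above describes the class map as keyed by the collapsed label. A minimal sketch of how such a map could be applied, assuming `ClassMap` is an alias for `Dict[str, List[str]]` (as the replaced annotation suggests) and using a hypothetical `collapse_labels` helper that is not part of SKLL:

```python
from typing import Dict, List

# Assumed alias; the actual definition lives in ``skll.types``.
ClassMap = Dict[str, List[str]]


def collapse_labels(labels: List[str], class_map: ClassMap) -> List[str]:
    """Replace each label with its collapsed label; unmapped labels are kept as-is."""
    # Invert the map: original label -> collapsed label
    reverse_map = {
        original: collapsed
        for collapsed, originals in class_map.items()
        for original in originals
    }
    return [reverse_map.get(label, label) for label in labels]


class_map: ClassMap = {"animal": ["cat", "dog"], "plant": ["fern"]}
collapsed = collapse_labels(["cat", "fern", "rock", "dog"], class_map)
```

Inverting the map once keeps each lookup constant-time, which matters when the label list is much longer than the map.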
@@ -704,9 +703,8 @@ def parse_config_file(
  param_grid_list = yaml.safe_load(fix_json(config.get("Tuning", "param_grids")))

  # read and normalize the value of `pos_label`
- pos_label = safe_float(config.get("Tuning", "pos_label"))
- if pos_label == "":
-     pos_label = None
+ pos_label_string = safe_float(config.get("Tuning", "pos_label"))
+ pos_label: Optional[LabelType] = pos_label_string if pos_label_string else None

  # ensure that feature_scaling is specified only as one of the
  # four available choices
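
The `pos_label` normalization in this hunk can be sketched in isolation. The example below assumes `LabelType` covers ints, floats, and strings, and uses a hypothetical `to_number_or_str` helper as a stand-in for `safe_float` (whose real implementation lives in `skll.data.readers` and may differ):

```python
from typing import Optional, Union

# Assumed alias for illustration; the actual definition lives in ``skll.types``.
LabelType = Union[int, float, str]


def to_number_or_str(text: str) -> LabelType:
    """Hypothetical stand-in for ``safe_float``: try int, then float, else keep the string."""
    try:
        return int(text)
    except ValueError:
        try:
            return float(text)
        except ValueError:
            return text


def normalize_pos_label(raw: str) -> Optional[LabelType]:
    """Mirror the truthiness-based normalization: falsy values become ``None``."""
    converted = to_number_or_str(raw)
    return converted if converted else None
```

Under this stand-in, an empty string becomes `None` and `"yes"` or `"1"` pass through as a label. A raw value of `"0"` converts to the integer `0`, which is falsy and therefore also maps to `None`, whereas a comparison against the empty string would keep it.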