Mypy type checking #81

Draft: wants to merge 3 commits into master

Changes from all commits:
.github/workflows/build.yml (3 additions, 0 deletions)

@@ -38,6 +38,9 @@ jobs:
           poetry run flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
           # exit-zero treats all errors as warnings.
           poetry run flake8 . --count --exit-zero --max-complexity=10 --statistics
+      - name: Type check with mypy
+        run: |
+          poetry run mypy .
       - name: Test with pytest
         run: |
           poetry run pytest tests --cov ./geneeval --cov-report=xml --cov-config=./.coveragerc
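This new step gates CI on a clean `poetry run mypy .` over the whole repository. A minimal sketch of the kind of error the step catches (hypothetical module, not from this repository):

    # hypothetical_example.py
    from typing import List

    def total(values: List[float]) -> float:
        return sum(values)

    # mypy reports: Argument 1 to "total" has incompatible type "str";
    # expected "List[float]"
    total("not a list")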
geneeval/classifiers/__init__.py (1 addition, 2 deletions)

@@ -1,2 +1 @@
-from geneeval.classifiers.supervised_classifiers import (MLPClassifier,
-                                                         SupervisedClassifier)
+from geneeval.classifiers.supervised_classifiers import MLPClassifier, SupervisedClassifier
geneeval/classifiers/auto_classifier.py (2 additions, 7 deletions)

@@ -1,17 +1,12 @@
-from typing import Tuple
-
-from geneeval.classifiers.supervised_classifiers import (MLPClassifier,
-                                                         SupervisedClassifier)
+from geneeval.classifiers.supervised_classifiers import MLPClassifier, SupervisedClassifier
 from geneeval.common.utils import CLASSIFICATION, TASKS
 from geneeval.data import PreprocessedData


 class AutoClassifier:
     """A factory function, which returns the correct classifiers for a given `task`."""

-    def __new__(
-        self, task: str, data: PreprocessedData
-    ) -> Tuple[SupervisedClassifier, SupervisedClassifier]:
+    def __new__(self, task: str, data: PreprocessedData) -> SupervisedClassifier:

         if task not in TASKS:
             raise ValueError(f"task must be one of: {', '.join(TASKS)}. Got: {task}")
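Note on the narrowed return type: mypy special-cases `__new__`, so the annotated return type both constrains every `return` in the body and becomes the inferred type of the `AutoClassifier(...)` call expression. A minimal sketch of the factory-via-`__new__` pattern, with hypothetical names:

    class Base:
        ...

    class Concrete(Base):
        ...

    class AutoFactory:
        # mypy checks each return against Base, and types
        # `AutoFactory(...)` expressions as Base.
        def __new__(cls, task: str) -> Base:
            if task == "classification":
                return Concrete()
            raise ValueError(f"Unknown task: {task}")

    obj = AutoFactory("classification")  # inferred type: Base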
geneeval/classifiers/supervised_classifiers.py (1 addition, 1 deletion)

@@ -39,7 +39,7 @@ def fit(self) -> None:
         """Wrapper around `self.estimator.fit`."""
         self.estimator.fit(self.data.X_train, self.data.y_train)

-    def score(self) -> Dict[str, float]:
+    def score(self) -> Dict[str, Dict[str, float]]:
         """Wrapper around `self.estimator.score`."""
         X_valid = self.data.X_train[self.data.splits.test_fold == 0]
         y_valid = self.data.y_train[self.data.splits.test_fold == 0]
geneeval/common/utils.py (1 addition, 1 deletion)

@@ -7,7 +7,7 @@
 CLASSIFICATION = {
     "subcellular_localization",
 }
-REGRESSION = set()
+REGRESSION: Set[str] = set()
 TASKS = CLASSIFICATION | REGRESSION

 TRAIN_SIZE = 0.7
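The annotation here is required rather than cosmetic: mypy cannot infer an element type for an empty `set()` and reports "Need type annotation for 'REGRESSION'" without it. A short sketch (assuming `Set` is imported from `typing` higher in the file):

    from typing import Set

    CLASSIFICATION = {"subcellular_localization"}  # inferred as Set[str]
    REGRESSION: Set[str] = set()  # empty, so the element type must be declared
    TASKS = CLASSIFICATION | REGRESSION  # the union is again Set[str]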
geneeval/data.py (2 additions, 2 deletions)

@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Dict, Union
+from typing import Union

 import numpy as np
 import pandas as pd
@@ -24,7 +24,7 @@ class DatasetReader:
     """Given a dataframe of gene features, returns a `PreprocessedData` containing everything we
     need to train and evaluate with Sklearn."""

-    def __new__(self, features: pd.DataFrame, task: str) -> Dict[str, PreprocessedData]:
+    def __new__(self, features: pd.DataFrame, task: str) -> PreprocessedData:

         if task not in TASKS:
             raise ValueError(f"task must be one of: {', '.join(TASKS)}. Got: {task}")
geneeval/engine.py (6 additions, 5 deletions)

@@ -1,10 +1,11 @@
-from typing import List, Optional, Union
+from typing import Dict, List, Optional, Union

 import pandas as pd

+from geneeval.classifiers import SupervisedClassifier
 from geneeval.classifiers.auto_classifier import AutoClassifier
 from geneeval.common.utils import resolve_tasks
-from geneeval.data import DatasetReader
+from geneeval.data import DatasetReader, PreprocessedData


 class Engine:
@@ -20,13 +21,13 @@ def __init__(
     ) -> None:
         self._features = features
         self._tasks = resolve_tasks(include_tasks, exclude_tasks)
-        self.results = {}
+        self.results: Dict[str, Dict[str, Dict[str, float]]] = {}

     def run(self) -> None:
         for task in self._tasks:
-            data = DatasetReader(self._features, task)
+            data: PreprocessedData = DatasetReader(self._features, task)

-            classifier = AutoClassifier(task, data)
+            classifier: SupervisedClassifier = AutoClassifier(task, data)
             estimator = classifier
             estimator.fit()
             results = estimator.score()
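The annotation on `self.results` documents a three-level mapping ending in float scores. A sketch of one plausible shape it implies; the inner key names below are assumptions for illustration, not values from the repository:

    from typing import Dict

    results: Dict[str, Dict[str, Dict[str, float]]] = {
        "subcellular_localization": {  # task name
            "valid": {"f1_micro_score": 0.87},  # assumed: split -> metric -> score
        },
    }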
geneeval/fetcher/auto_fetcher.py (2 additions, 2 deletions)

@@ -1,4 +1,4 @@
-from typing import List
+from typing import Iterable

 from geneeval.fetcher.fetchers import Fetcher, LocalizationFetcher, SequenceFetcher, UniprotFetcher

@@ -14,7 +14,7 @@ class AutoFetcher:
     present. This file can be created by running the `get_protein_ids.py` file in `scripts`.
     """

-    def __new__(cls, tasks: List[str]) -> Fetcher:
+    def __new__(cls, tasks: Iterable[str]) -> Fetcher:

         fetcher = Fetcher()
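Widening `tasks` from `List[str]` to `Iterable[str]` means callers may pass any iterable of task names, which is what lets `main.py` below hand `TASKS` to `AutoFetcher` directly instead of building a list first. A minimal sketch with a hypothetical function:

    from typing import Iterable

    def build(tasks: Iterable[str]) -> None:
        for task in tasks:
            print(task)

    build({"subcellular_localization"})  # a set now type-checks
    build(["subcellular_localization"])  # lists still do too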
geneeval/fetcher/fetchers.py (6 additions, 6 deletions)

@@ -4,7 +4,7 @@
 from abc import ABCMeta, abstractmethod
 from collections import Counter
 from tempfile import TemporaryFile
-from typing import Any, Callable, Dict, List, Tuple
+from typing import IO, Any, Callable, Dict, List, Tuple

 import numpy as np
 import pandas as pd
@@ -85,7 +85,7 @@ def register(self, fetcher_class):
     def _build_columns(self) -> str:
         return ",".join(["id"] + [callback() for callback in self.fetch_callbacks])

-    def _create_file(self) -> TemporaryFile:
+    def _create_file(self) -> IO[bytes]:
         benchmark = json.load(BENCHMARK_FILEPATH.open())
         protein_ids = benchmark["inputs"].keys()
         protein_id_file = TemporaryFile()
@@ -123,12 +123,12 @@ def fetch_callback() -> Any:

     @staticmethod
     @abstractmethod
-    def parse_callback() -> Any:
+    def parse_callback(df: pd.DataFrame) -> Any:
         pass

     @staticmethod
     @abstractmethod
-    def process_callback() -> Any:
+    def process_callback(parsed_dct: Dict[str, Any]) -> Any:
         pass
@@ -180,13 +180,13 @@ def parse_callback(df: pd.DataFrame) -> Dict[str, Dict[str, List[str]]]:
         return {"subcellular_localization": parsed}

     @staticmethod
-    def process_callback(parsed_dct: Dict[str, Any]) -> Dict[str, Dict[str, List[str]]]:
+    def process_callback(parsed_dct: Dict[str, Any]) -> Dict[str, Dict[str, Dict[str, List[str]]]]:

         LOCALIZATION_COUNT_UPPER_LIMIT = 1000
         LOCALIZATION_COUNT_LOWER_LIMIT = 50

         localization_dct = parsed_dct["subcellular_localization"]
-        counter = Counter()
+        counter: Counter = Counter()
         for localization_list in localization_dct.values():
             counter.update(localization_list)
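Two notes on this file. First, `tempfile.TemporaryFile` is a factory function rather than a class, so it is not valid in an annotation; `IO[bytes]` names the binary file object it actually returns. A minimal sketch with a hypothetical helper:

    from tempfile import TemporaryFile
    from typing import IO

    def create_scratch() -> IO[bytes]:
        f = TemporaryFile()  # returns a binary file object, typed IO[bytes]
        f.write(b"P12345\n")
        f.seek(0)
        return f

Second, `counter: Counter = Counter()` satisfies mypy's need for an annotation on an empty container; something like `typing.Counter[str]` would be a slightly tighter choice, since the counter holds localization strings.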
geneeval/main.py (11 additions, 11 deletions)

@@ -1,10 +1,11 @@
 from collections import defaultdict
 from pathlib import Path
-from typing import Dict, List
+from typing import Dict, List, Optional

 import numpy as np
 import orjson
 import typer
+from sklearn import metrics
 from sklearn.preprocessing import LabelBinarizer, MultiLabelBinarizer

 from geneeval.common.data_utils import load_benchmark, load_features
@@ -24,7 +25,7 @@ def build_benchmark() -> None:
     `exclude_tasks` respectively.
     """

-    fetcher = AutoFetcher(list(TASKS.keys()))
+    fetcher = AutoFetcher(TASKS)
     benchmark = fetcher.fetch()
     benchmark = {**benchmark, **orjson.loads(BENCHMARK_FILEPATH.read_bytes())}
     benchmark = orjson.dumps(benchmark, option=orjson.OPT_INDENT_2)
@@ -33,7 +34,7 @@ def build_benchmark() -> None:

 @app.command()
 def prepare(
-    filepath: Path = typer.Argument(
+    filepath: str = typer.Argument(
         ..., writable=True, help="Filepath to save prepared benchmark file."
     ),
     include_tasks: List[str] = typer.Option(
@@ -42,7 +43,7 @@ def prepare(
     exclude_tasks: List[str] = typer.Option(
         None, help="A task name (or list of task names) to exclude in the prepared data."
     ),
-):
+) -> None:
     tasks = resolve_tasks(include_tasks, exclude_tasks)
     benchmark = load_benchmark()
     prepared_benchmark = {task: benchmark[task] for task in tasks}
@@ -57,19 +58,18 @@ def prepare(
         )
     )  # Find the subset of genes specific to the `tasks` specified
     prepared_benchmark["inputs"] = {gene: benchmark["inputs"][gene] for gene in task_specific_genes}
-    prepared_benchmark = orjson.dumps(prepared_benchmark, option=orjson.OPT_INDENT_2)
-    filepath.write_bytes(prepared_benchmark)
+    Path(filepath).write_bytes(orjson.dumps(prepared_benchmark, option=orjson.OPT_INDENT_2))


 @evaluate_app.command("features")
 def evaluate_features(
-    filepath: Path = typer.Argument(
+    filepath: str = typer.Argument(
         ..., exists=True, dir_okay=False, help="Filepath to the gene features."
     ),
-    include_tasks: List[str] = typer.Option(
+    include_tasks: Optional[List[str]] = typer.Option(
         None, help="A task name (or list of task names) to include in the evaluation."
     ),
-    exclude_tasks: List[str] = typer.Option(
+    exclude_tasks: Optional[List[str]] = typer.Option(
         None, help="A task name (or list of task names) to exclude in the evaluation."
     ),
 ) -> Dict:
@@ -92,7 +92,7 @@ def evaluate_features(

 @evaluate_app.command("predictions")
 def evaluate_predictions(
-    filepath: Path = typer.Argument(
+    filepath: str = typer.Argument(
         ..., exists=True, dir_okay=False, help="Filepath to the gene label predictions."
     ),
 ) -> Dict:
@@ -110,7 +110,7 @@ def recursive_defaultdict():
     results = recursive_defaultdict()

     for task, partitions in predictions.items():
-        metric = AutoMetric(task)
+        metric: metrics = AutoMetric(task)
         # Fit the label binarizer on the train set of the benchmark.
         multilabel = (
             isinstance(list(benchmark[task]["train"].values())[0], list)
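The `Path` to `str` switch on the typer arguments, paired with `Path(filepath)` at the point of use, keeps the CLI layer on plain strings and converts only where a filesystem call needs a `Path`. A sketch of the pattern with a hypothetical helper:

    from pathlib import Path

    def save(filepath: str, payload: bytes) -> None:
        # convert at the boundary where the filesystem API is used
        Path(filepath).write_bytes(payload)

One caveat worth flagging in review: `metric: metrics = AutoMetric(task)` annotates the variable with the `sklearn.metrics` module object, which is unconventional; a `Callable` annotation would be the more typical choice.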
geneeval/metrics/__init__.py (3 additions, 2 deletions)

@@ -1,5 +1,6 @@
 from functools import partial
-from sklearn import metrics

-f1_micro_score = partial(metrics.f1_score, average="micro")
+from sklearn.metrics import f1_score
+
+f1_micro_score: f1_score = partial(f1_score, average="micro")
 f1_micro_score.__name__ = "f1_micro_score"
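As with the module annotation above, annotating `f1_micro_score` with the function `f1_score` is unusual, since a function is not a type. A sketch of a more conventional annotation, assuming the scorer returns a float:

    from functools import partial
    from typing import Callable

    from sklearn.metrics import f1_score

    f1_micro_score: Callable[..., float] = partial(f1_score, average="micro")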
mypy.ini (6 additions, 0 deletions)

@@ -0,0 +1,6 @@
+[mypy]
+ignore_missing_imports = true
+no_site_packages = true
+
+[mypy-tests.*]
+strict_optional = false
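For reference: `ignore_missing_imports` silences errors about third-party packages that ship no type information, `no_site_packages` stops mypy from searching installed packages for types altogether, and the `[mypy-tests.*]` override disables `strict_optional` only for the test suite, so tests may pass `None` where a non-optional type is declared.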