diff --git a/CHANGELOG.md b/CHANGELOG.md index 81da09f44f..a9ebd7defb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Changed `Wav2Vec2Processor` to `AutoProcessor` and seperate it from backbone [optional] ([#1075](https://github.com/PyTorchLightning/lightning-flash/pull/1075)) +- Renamed `ClassificationInput` to `ClassificationInputMixin` ([#1116](https://github.com/PyTorchLightning/lightning-flash/pull/1116)) + ### Deprecated ### Fixed diff --git a/docs/source/api/data.rst b/docs/source/api/data.rst index 91983d63b1..b02db82e3f 100644 --- a/docs/source/api/data.rst +++ b/docs/source/api/data.rst @@ -82,6 +82,17 @@ ___________________________ ~flash.core.data.io.input.InputFormat ~flash.core.data.io.input.ImageLabelsMap +flash.core.data.io.classification_input +_______________________________________ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + ~flash.core.data.io.classification_input.ClassificationState + ~flash.core.data.io.classification_input.ClassificationInputMixin + flash.core.data.process _______________________ diff --git a/docs/source/template/data.rst b/docs/source/template/data.rst index a8dd3e6d2b..b39bc5786c 100644 --- a/docs/source/template/data.rst +++ b/docs/source/template/data.rst @@ -27,9 +27,9 @@ Each :class:`~flash.core.data.io.input.Input` has 2 methods: By default these methods just return their input, so you don't need both a :meth:`~flash.core.data.io.input.Input.load_data` and a :meth:`~flash.core.data.io.input.Input.load_sample` to create a :class:`~flash.core.data.io.input.Input`. Where possible, you should override one of our existing :class:`~flash.core.data.io.input.Input` classes. -Let's start by implementing a ``TemplateNumpyClassificationInput``, which overrides :class:`~flash.core.data.io.classification_input.ClassificationInput`.
+Let's start by implementing a ``TemplateNumpyClassificationInput``, which extends :class:`~flash.core.data.io.input.Input` and mixes in :class:`~flash.core.data.io.classification_input.ClassificationInputMixin`. The main :class:`~flash.core.data.io.input.Input` method that we have to implement is :meth:`~flash.core.data.io.input.Input.load_data`. -:class:`~flash.core.data.io.classification_input.ClassificationInput` provides utilities for handling targets within flash which need to be called from the :meth:`~flash.core.data.io.input.Input.load_data` and :meth:`~flash.core.data.io.input.Input.load_sample`. +:class:`~flash.core.data.io.classification_input.ClassificationInputMixin` provides utilities for handling targets within flash which need to be called from the :meth:`~flash.core.data.io.input.Input.load_data` and :meth:`~flash.core.data.io.input.Input.load_sample`. In this ``Input``, we'll also set the ``num_features`` attribute so that we can access it later. Here's the code for our ``TemplateNumpyClassificationInput.load_data`` method: diff --git a/flash/audio/classification/input.py b/flash/audio/classification/input.py index 865d439a65..094a38b175 100644 --- a/flash/audio/classification/input.py +++ b/flash/audio/classification/input.py @@ -17,8 +17,8 @@ import numpy as np import pandas as pd -from flash.core.data.io.classification_input import ClassificationInput, ClassificationState -from flash.core.data.io.input import DataKeys +from flash.core.data.io.classification_input import ClassificationInputMixin, ClassificationState +from flash.core.data.io.input import DataKeys, Input from flash.core.data.utilities.classification import TargetMode from flash.core.data.utilities.data_frame import read_csv, resolve_files, resolve_targets from flash.core.data.utilities.paths import filter_valid_files, has_file_allowed_extension, make_dataset, PATH_TYPE @@ -37,7 +37,7 @@ def spectrogram_loader(filepath: str): return data -class AudioClassificationInput(Input,
ClassificationInputMixin): @requires("audio") def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: h, w = sample[DataKeys.INPUT].shape[-2:] # H x W diff --git a/flash/core/data/io/classification_input.py b/flash/core/data/io/classification_input.py index cda34d9827..7e5c9efd6f 100644 --- a/flash/core/data/io/classification_input.py +++ b/flash/core/data/io/classification_input.py @@ -15,8 +15,7 @@ from functools import lru_cache from typing import Any, List, Optional, Sequence -from flash.core.data.io.input import Input -from flash.core.data.properties import ProcessState +from flash.core.data.properties import ProcessState, Properties from flash.core.data.utilities.classification import ( get_target_details, get_target_formatter, @@ -34,9 +33,9 @@ class ClassificationState(ProcessState): num_classes: Optional[int] = None -class ClassificationInput(Input): - """The ``ClassificationInput`` class provides utility methods for handling classification targets. - :class:`~flash.core.data.io.input.Input` objects that extend ``ClassificationInput`` should do the following: +class ClassificationInputMixin(Properties): + """The ``ClassificationInputMixin`` class provides utility methods for handling classification targets. + :class:`~flash.core.data.io.input.Input` objects that extend ``ClassificationInputMixin`` should do the following: * In the ``load_data`` method, include a call to ``load_target_metadata``. This will determine the format of the targets and store metadata like ``labels`` and ``num_classes``. @@ -47,7 +46,7 @@ class ClassificationInput(Input): @property @lru_cache(maxsize=None) def target_formatter(self) -> TargetFormatter: - """Get the :class:`~flash.core.data.utiltiies.classification.TargetFormatter` to use when formatting + """Get the :class:`~flash.core.data.utilities.classification.TargetFormatter` to use when formatting targets. This property uses ``functools.lru_cache`` so that we only instantiate the formatter once. 
diff --git a/flash/graph/classification/input.py b/flash/graph/classification/input.py index 680418d210..ea326e1878 100644 --- a/flash/graph/classification/input.py +++ b/flash/graph/classification/input.py @@ -16,7 +16,7 @@ from torch.utils.data import Dataset from flash.core.data.data_module import DatasetInput -from flash.core.data.io.classification_input import ClassificationInput, ClassificationState +from flash.core.data.io.classification_input import ClassificationInputMixin, ClassificationState from flash.core.data.io.input import DataKeys from flash.core.utilities.imports import _GRAPH_AVAILABLE, requires @@ -26,7 +26,7 @@ from torch_geometric.data import InMemoryDataset -class GraphClassificationDatasetInput(DatasetInput, ClassificationInput): +class GraphClassificationDatasetInput(DatasetInput, ClassificationInputMixin): @requires("graph") def load_data(self, dataset: Dataset) -> Dataset: if not self.predicting: diff --git a/flash/image/classification/input.py b/flash/image/classification/input.py index 82dd7dc7d2..2b830cf926 100644 --- a/flash/image/classification/input.py +++ b/flash/image/classification/input.py @@ -16,7 +16,7 @@ import pandas as pd -from flash.core.data.io.classification_input import ClassificationInput, ClassificationState +from flash.core.data.io.classification_input import ClassificationInputMixin, ClassificationState from flash.core.data.io.input import DataKeys from flash.core.data.utilities.classification import TargetMode from flash.core.data.utilities.data_frame import read_csv, resolve_files, resolve_targets @@ -34,7 +34,7 @@ SampleCollection = None -class ImageClassificationFilesInput(ClassificationInput, ImageFilesInput): +class ImageClassificationFilesInput(ClassificationInputMixin, ImageFilesInput): def load_data(self, files: List[PATH_TYPE], targets: Optional[List[Any]] = None) -> List[Dict[str, Any]]: if targets is None: return super().load_data(files) @@ -74,7 +74,7 @@ def predict_load_data(data: SampleCollection) -> 
List[Dict[str, Any]]: return super().load_data(data.values("filepath")) -class ImageClassificationTensorInput(ClassificationInput, ImageTensorInput): +class ImageClassificationTensorInput(ClassificationInputMixin, ImageTensorInput): def load_data(self, tensor: Any, targets: Optional[List[Any]] = None) -> List[Dict[str, Any]]: if targets is not None: self.load_target_metadata(targets) @@ -87,7 +87,7 @@ def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: return sample -class ImageClassificationNumpyInput(ClassificationInput, ImageNumpyInput): +class ImageClassificationNumpyInput(ClassificationInputMixin, ImageNumpyInput): def load_data(self, array: Any, targets: Optional[List[Any]] = None) -> List[Dict[str, Any]]: if targets is not None: self.load_target_metadata(targets) diff --git a/flash/tabular/classification/input.py b/flash/tabular/classification/input.py index 5a39297779..3484b45920 100644 --- a/flash/tabular/classification/input.py +++ b/flash/tabular/classification/input.py @@ -13,8 +13,8 @@ # limitations under the License. 
from typing import Any, Dict, List, Optional, Union -from flash import DataKeys -from flash.core.data.io.classification_input import ClassificationInput +from flash.core.data.io.classification_input import ClassificationInputMixin +from flash.core.data.io.input import DataKeys from flash.core.data.utilities.data_frame import read_csv, resolve_targets from flash.core.utilities.imports import _PANDAS_AVAILABLE from flash.tabular.input import TabularDataFrameInput @@ -25,7 +25,7 @@ DataFrame = object -class TabularClassificationDataFrameInput(TabularDataFrameInput, ClassificationInput): +class TabularClassificationDataFrameInput(TabularDataFrameInput, ClassificationInputMixin): def load_data( self, data_frame: DataFrame, diff --git a/flash/tabular/regression/input.py b/flash/tabular/regression/input.py index 44f608ae22..1ff15ca4c1 100644 --- a/flash/tabular/regression/input.py +++ b/flash/tabular/regression/input.py @@ -15,7 +15,7 @@ import numpy as np -from flash import DataKeys +from flash.core.data.io.input import DataKeys from flash.core.data.utilities.data_frame import read_csv from flash.core.utilities.imports import _PANDAS_AVAILABLE from flash.tabular.input import TabularDataFrameInput diff --git a/flash/template/classification/data.py b/flash/template/classification/data.py index 8d1a9a19e7..88c37fa047 100644 --- a/flash/template/classification/data.py +++ b/flash/template/classification/data.py @@ -20,7 +20,7 @@ from flash.core.data.callback import BaseDataFetcher from flash.core.data.data_module import DataModule from flash.core.data.data_pipeline import DataPipelineState -from flash.core.data.io.classification_input import ClassificationInput +from flash.core.data.io.classification_input import ClassificationInputMixin from flash.core.data.io.input import DataKeys, Input from flash.core.data.io.input_transform import InputTransform from flash.core.data.utilities.samples import to_samples @@ -34,7 +34,7 @@ Bunch = object -class 
TemplateNumpyClassificationInput(ClassificationInput): +class TemplateNumpyClassificationInput(Input, ClassificationInputMixin): """An example data source that records ``num_features`` on the dataset.""" def load_data( diff --git a/flash/text/classification/input.py b/flash/text/classification/input.py index 2c7ad75b46..aeb4fb62c0 100644 --- a/flash/text/classification/input.py +++ b/flash/text/classification/input.py @@ -16,8 +16,8 @@ import pandas as pd -from flash.core.data.io.classification_input import ClassificationInput, ClassificationState -from flash.core.data.io.input import DataKeys +from flash.core.data.io.classification_input import ClassificationInputMixin, ClassificationState +from flash.core.data.io.input import DataKeys, Input from flash.core.data.utilities.classification import TargetMode from flash.core.data.utilities.paths import PATH_TYPE from flash.core.integrations.transformers.states import TransformersBackboneState @@ -29,7 +29,7 @@ Dataset = object -class TextClassificationInput(ClassificationInput): +class TextClassificationInput(Input, ClassificationInputMixin): @staticmethod def _resolve_target(target_keys: Union[str, List[str]], element: Dict[str, Any]) -> Dict[str, Any]: if not isinstance(target_keys, List):