Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Rename ClassificationInput to ClassificationInputMixin (#1116)
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris committed Jan 14, 2022
1 parent 428cdb8 commit 5d41d1a
Show file tree
Hide file tree
Showing 11 changed files with 38 additions and 26 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Changed `Wav2Vec2Processor` to `AutoProcessor` and seperate it from backbone [optional] ([#1075](https://github.com/PyTorchLightning/lightning-flash/pull/1075))

- Renamed `ClassificationInput` to `ClassificationInputMixin` ([#1116](https://github.com/PyTorchLightning/lightning-flash/pull/1116))

### Deprecated

### Fixed
Expand Down
11 changes: 11 additions & 0 deletions docs/source/api/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,17 @@ ___________________________
~flash.core.data.io.input.InputFormat
~flash.core.data.io.input.ImageLabelsMap

flash.core.data.io.classification
_________________________________

.. autosummary::
:toctree: generated/
:nosignatures:
:template: classtemplate.rst

~flash.core.data.io.classification_input.ClassificationState
~flash.core.data.io.classification_input.ClassificationInputMixin

flash.core.data.process
_______________________

Expand Down
4 changes: 2 additions & 2 deletions docs/source/template/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ Each :class:`~flash.core.data.io.input.Input` has 2 methods:
By default these methods just return their input, so you don't need both a :meth:`~flash.core.data.io.input.Input.load_data` and a :meth:`~flash.core.data.io.input.Input.load_sample` to create a :class:`~flash.core.data.io.input.Input`.
Where possible, you should override one of our existing :class:`~flash.core.data.io.input.Input` classes.

Let's start by implementing a ``TemplateNumpyClassificationInput``, which overrides :class:`~flash.core.data.io.classification_input.ClassificationInput`.
Let's start by implementing a ``TemplateNumpyClassificationInput``, which overrides :class:`~flash.core.data.io.classification_input.ClassificationInputMixin`.
The main :class:`~flash.core.data.io.input.Input` method that we have to implement is :meth:`~flash.core.data.io.input.Input.load_data`.
:class:`~flash.core.data.io.classification_input.ClassificationInput` provides utilities for handling targets within flash which need to be called from the :meth:`~flash.core.data.io.input.Input.load_data` and :meth:`~flash.core.data.io.input.Input.load_sample`.
:class:`~flash.core.data.io.classification_input.ClassificationInputMixin` provides utilities for handling targets within flash which need to be called from the :meth:`~flash.core.data.io.input.Input.load_data` and :meth:`~flash.core.data.io.input.Input.load_sample`.
In this ``Input``, we'll also set the ``num_features`` attribute so that we can access it later.

Here's the code for our ``TemplateNumpyClassificationInput.load_data`` method:
Expand Down
6 changes: 3 additions & 3 deletions flash/audio/classification/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
import numpy as np
import pandas as pd

from flash.core.data.io.classification_input import ClassificationInput, ClassificationState
from flash.core.data.io.input import DataKeys
from flash.core.data.io.classification_input import ClassificationInputMixin, ClassificationState
from flash.core.data.io.input import DataKeys, Input
from flash.core.data.utilities.classification import TargetMode
from flash.core.data.utilities.data_frame import read_csv, resolve_files, resolve_targets
from flash.core.data.utilities.paths import filter_valid_files, has_file_allowed_extension, make_dataset, PATH_TYPE
Expand All @@ -37,7 +37,7 @@ def spectrogram_loader(filepath: str):
return data


class AudioClassificationInput(ClassificationInput):
class AudioClassificationInput(Input, ClassificationInputMixin):
@requires("audio")
def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
h, w = sample[DataKeys.INPUT].shape[-2:] # H x W
Expand Down
11 changes: 5 additions & 6 deletions flash/core/data/io/classification_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@
from functools import lru_cache
from typing import Any, List, Optional, Sequence

from flash.core.data.io.input import Input
from flash.core.data.properties import ProcessState
from flash.core.data.properties import ProcessState, Properties
from flash.core.data.utilities.classification import (
get_target_details,
get_target_formatter,
Expand All @@ -34,9 +33,9 @@ class ClassificationState(ProcessState):
num_classes: Optional[int] = None


class ClassificationInput(Input):
"""The ``ClassificationInput`` class provides utility methods for handling classification targets.
:class:`~flash.core.data.io.input.Input` objects that extend ``ClassificationInput`` should do the following:
class ClassificationInputMixin(Properties):
"""The ``ClassificationInputMixin`` class provides utility methods for handling classification targets.
:class:`~flash.core.data.io.input.Input` objects that extend ``ClassificationInputMixin`` should do the following:
* In the ``load_data`` method, include a call to ``load_target_metadata``. This will determine the format of the
targets and store metadata like ``labels`` and ``num_classes``.
Expand All @@ -47,7 +46,7 @@ class ClassificationInput(Input):
@property
@lru_cache(maxsize=None)
def target_formatter(self) -> TargetFormatter:
"""Get the :class:`~flash.core.data.utiltiies.classification.TargetFormatter` to use when formatting
"""Get the :class:`~flash.core.data.utilities.classification.TargetFormatter` to use when formatting
targets.
This property uses ``functools.lru_cache`` so that we only instantiate the formatter once.
Expand Down
4 changes: 2 additions & 2 deletions flash/graph/classification/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from torch.utils.data import Dataset

from flash.core.data.data_module import DatasetInput
from flash.core.data.io.classification_input import ClassificationInput, ClassificationState
from flash.core.data.io.classification_input import ClassificationInputMixin, ClassificationState
from flash.core.data.io.input import DataKeys
from flash.core.utilities.imports import _GRAPH_AVAILABLE, requires

Expand All @@ -26,7 +26,7 @@
from torch_geometric.data import InMemoryDataset


class GraphClassificationDatasetInput(DatasetInput, ClassificationInput):
class GraphClassificationDatasetInput(DatasetInput, ClassificationInputMixin):
@requires("graph")
def load_data(self, dataset: Dataset) -> Dataset:
if not self.predicting:
Expand Down
8 changes: 4 additions & 4 deletions flash/image/classification/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import pandas as pd

from flash.core.data.io.classification_input import ClassificationInput, ClassificationState
from flash.core.data.io.classification_input import ClassificationInputMixin, ClassificationState
from flash.core.data.io.input import DataKeys
from flash.core.data.utilities.classification import TargetMode
from flash.core.data.utilities.data_frame import read_csv, resolve_files, resolve_targets
Expand All @@ -34,7 +34,7 @@
SampleCollection = None


class ImageClassificationFilesInput(ClassificationInput, ImageFilesInput):
class ImageClassificationFilesInput(ClassificationInputMixin, ImageFilesInput):
def load_data(self, files: List[PATH_TYPE], targets: Optional[List[Any]] = None) -> List[Dict[str, Any]]:
if targets is None:
return super().load_data(files)
Expand Down Expand Up @@ -74,7 +74,7 @@ def predict_load_data(data: SampleCollection) -> List[Dict[str, Any]]:
return super().load_data(data.values("filepath"))


class ImageClassificationTensorInput(ClassificationInput, ImageTensorInput):
class ImageClassificationTensorInput(ClassificationInputMixin, ImageTensorInput):
def load_data(self, tensor: Any, targets: Optional[List[Any]] = None) -> List[Dict[str, Any]]:
if targets is not None:
self.load_target_metadata(targets)
Expand All @@ -87,7 +87,7 @@ def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
return sample


class ImageClassificationNumpyInput(ClassificationInput, ImageNumpyInput):
class ImageClassificationNumpyInput(ClassificationInputMixin, ImageNumpyInput):
def load_data(self, array: Any, targets: Optional[List[Any]] = None) -> List[Dict[str, Any]]:
if targets is not None:
self.load_target_metadata(targets)
Expand Down
6 changes: 3 additions & 3 deletions flash/tabular/classification/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
# limitations under the License.
from typing import Any, Dict, List, Optional, Union

from flash import DataKeys
from flash.core.data.io.classification_input import ClassificationInput
from flash.core.data.io.classification_input import ClassificationInputMixin
from flash.core.data.io.input import DataKeys
from flash.core.data.utilities.data_frame import read_csv, resolve_targets
from flash.core.utilities.imports import _PANDAS_AVAILABLE
from flash.tabular.input import TabularDataFrameInput
Expand All @@ -25,7 +25,7 @@
DataFrame = object


class TabularClassificationDataFrameInput(TabularDataFrameInput, ClassificationInput):
class TabularClassificationDataFrameInput(TabularDataFrameInput, ClassificationInputMixin):
def load_data(
self,
data_frame: DataFrame,
Expand Down
2 changes: 1 addition & 1 deletion flash/tabular/regression/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import numpy as np

from flash import DataKeys
from flash.core.data.io.input import DataKeys
from flash.core.data.utilities.data_frame import read_csv
from flash.core.utilities.imports import _PANDAS_AVAILABLE
from flash.tabular.input import TabularDataFrameInput
Expand Down
4 changes: 2 additions & 2 deletions flash/template/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from flash.core.data.callback import BaseDataFetcher
from flash.core.data.data_module import DataModule
from flash.core.data.data_pipeline import DataPipelineState
from flash.core.data.io.classification_input import ClassificationInput
from flash.core.data.io.classification_input import ClassificationInputMixin
from flash.core.data.io.input import DataKeys, Input
from flash.core.data.io.input_transform import InputTransform
from flash.core.data.utilities.samples import to_samples
Expand All @@ -34,7 +34,7 @@
Bunch = object


class TemplateNumpyClassificationInput(ClassificationInput):
class TemplateNumpyClassificationInput(Input, ClassificationInputMixin):
"""An example data source that records ``num_features`` on the dataset."""

def load_data(
Expand Down
6 changes: 3 additions & 3 deletions flash/text/classification/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@

import pandas as pd

from flash.core.data.io.classification_input import ClassificationInput, ClassificationState
from flash.core.data.io.input import DataKeys
from flash.core.data.io.classification_input import ClassificationInputMixin, ClassificationState
from flash.core.data.io.input import DataKeys, Input
from flash.core.data.utilities.classification import TargetMode
from flash.core.data.utilities.paths import PATH_TYPE
from flash.core.integrations.transformers.states import TransformersBackboneState
Expand All @@ -29,7 +29,7 @@
Dataset = object


class TextClassificationInput(ClassificationInput):
class TextClassificationInput(Input, ClassificationInputMixin):
@staticmethod
def _resolve_target(target_keys: Union[str, List[str]], element: Dict[str, Any]) -> Dict[str, Any]:
if not isinstance(target_keys, List):
Expand Down

0 comments on commit 5d41d1a

Please sign in to comment.