Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Add more ways to load image data for classification and detection #1372

Merged
merged 3 commits into from Jul 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Expand Up @@ -18,6 +18,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added support for Flash serve to the `ObjectDetector` ([#1370](https://github.com/PyTorchLightning/lightning-flash/pull/1370))

- Added support for loading `ImageClassificationData` from PIL images with `from_images` ([#1372](https://github.com/PyTorchLightning/lightning-flash/pull/1372))

- Added support for loading `ObjectDetectionData` with `from_numpy`, `from_images`, and `from_tensors` ([#1372](https://github.com/PyTorchLightning/lightning-flash/pull/1372))

### Changed

- Changed the `ImageEmbedder` dependency on VISSL to optional ([#1276](https://github.com/PyTorchLightning/lightning-flash/pull/1276))
Expand Down
93 changes: 93 additions & 0 deletions flash/image/classification/data.py
Expand Up @@ -41,6 +41,7 @@
ImageClassificationFiftyOneInput,
ImageClassificationFilesInput,
ImageClassificationFolderInput,
ImageClassificationImageInput,
ImageClassificationNumpyInput,
ImageClassificationTensorInput,
)
Expand All @@ -64,6 +65,7 @@
"ImageClassificationData.from_files",
"ImageClassificationData.from_folders",
"ImageClassificationData.from_numpy",
"ImageClassificationData.from_images",
"ImageClassificationData.from_tensors",
"ImageClassificationData.from_data_frame",
"ImageClassificationData.from_csv",
Expand Down Expand Up @@ -385,6 +387,97 @@ def from_numpy(
**data_module_kwargs,
)

@classmethod
def from_images(
    cls,
    train_images: Optional[List[Image.Image]] = None,
    train_targets: Optional[Sequence[Any]] = None,
    val_images: Optional[List[Image.Image]] = None,
    val_targets: Optional[Sequence[Any]] = None,
    test_images: Optional[List[Image.Image]] = None,
    test_targets: Optional[Sequence[Any]] = None,
    predict_images: Optional[List[Image.Image]] = None,
    target_formatter: Optional[TargetFormatter] = None,
    input_cls: Type[Input] = ImageClassificationImageInput,
    transform: INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform,
    transform_kwargs: Optional[Dict] = None,
    **data_module_kwargs: Any,
) -> "ImageClassificationData":
    """Load the :class:`~flash.image.classification.data.ImageClassificationData` from lists of PIL images and
    corresponding lists of targets.

    The targets can be in any of our
    :ref:`supported classification target formats <formatting_classification_targets>`.
    To learn how to customize the transforms applied for each stage, read our
    :ref:`customizing transforms guide <customizing_transforms>`.

    Args:
        train_images: The list of PIL images to use when training.
        train_targets: The list of targets to use when training.
        val_images: The list of PIL images to use when validating.
        val_targets: The list of targets to use when validating.
        test_images: The list of PIL images to use when testing.
        test_targets: The list of targets to use when testing.
        predict_images: The list of PIL images to use when predicting.
        target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to
            control how targets are handled. See :ref:`formatting_classification_targets` for more details.
        input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data.
        transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use.
        transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms.
        data_module_kwargs: Additional keyword arguments to provide to the
            :class:`~flash.core.data.data_module.DataModule` constructor.

    Returns:
        The constructed :class:`~flash.image.classification.data.ImageClassificationData`.

    Examples
    ________

    .. doctest::

        >>> from PIL import Image
        >>> import numpy as np
        >>> from flash import Trainer
        >>> from flash.image import ImageClassifier, ImageClassificationData
        >>> datamodule = ImageClassificationData.from_images(
        ...     train_images=[
        ...         Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype="uint8")),
        ...         Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype="uint8")),
        ...         Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype="uint8")),
        ...     ],
        ...     train_targets=["cat", "dog", "cat"],
        ...     predict_images=[Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype="uint8"))],
        ...     transform_kwargs=dict(image_size=(128, 128)),
        ...     batch_size=2,
        ... )
        >>> datamodule.num_classes
        2
        >>> datamodule.labels
        ['cat', 'dog']
        >>> model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)
        >>> trainer = Trainer(fast_dev_run=True)
        >>> trainer.fit(model, datamodule=datamodule)  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
        Training...
        >>> trainer.predict(model, datamodule=datamodule)  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
        Predicting...
    """
    # Shared keyword arguments for every stage's Input. The training input is built
    # first so that the target formatter it resolves (e.g. inferred labels) can be
    # reused by the validation / test / predict inputs.
    input_kwargs = {"target_formatter": target_formatter}

    train_input = input_cls(RunningStage.TRAINING, train_images, train_targets, **input_kwargs)
    input_kwargs["target_formatter"] = getattr(train_input, "target_formatter", None)

    val_input = input_cls(RunningStage.VALIDATING, val_images, val_targets, **input_kwargs)
    test_input = input_cls(RunningStage.TESTING, test_images, test_targets, **input_kwargs)
    predict_input = input_cls(RunningStage.PREDICTING, predict_images, **input_kwargs)

    return cls(
        train_input,
        val_input,
        test_input,
        predict_input,
        transform=transform,
        transform_kwargs=transform_kwargs,
        **data_module_kwargs,
    )

@classmethod
def from_tensors(
cls,
Expand Down
24 changes: 23 additions & 1 deletion flash/image/classification/input.py
Expand Up @@ -24,7 +24,14 @@
from flash.core.data.utilities.samples import to_samples
from flash.core.integrations.fiftyone.utils import FiftyOneLabelUtilities
from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, lazy_import, requires
from flash.image.data import ImageFilesInput, ImageNumpyInput, ImageTensorInput, IMG_EXTENSIONS, NP_EXTENSIONS
from flash.image.data import (
ImageFilesInput,
ImageInput,
ImageNumpyInput,
ImageTensorInput,
IMG_EXTENSIONS,
NP_EXTENSIONS,
)

if _FIFTYONE_AVAILABLE:
fol = lazy_import("fiftyone.core.labels")
Expand Down Expand Up @@ -115,6 +122,21 @@ def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
return sample


class ImageClassificationImageInput(ClassificationInputMixin, ImageInput):
    """Input for loading classification samples from in-memory PIL images with optional targets."""

    def load_data(
        self, images: Any, targets: Optional[List[Any]] = None, target_formatter: Optional[TargetFormatter] = None
    ) -> List[Dict[str, Any]]:
        # Resolve target metadata (via the ClassificationInputMixin) before pairing
        # images with targets; skipped entirely for target-less (predict) data.
        if targets is not None:
            self.load_target_metadata(targets, target_formatter=target_formatter)
        return to_samples(images, targets)

    def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        loaded = super().load_sample(sample)
        # Predict samples carry no target key, so only format when one is present.
        if DataKeys.TARGET in loaded:
            loaded[DataKeys.TARGET] = self.format_target(loaded[DataKeys.TARGET])
        return loaded

class ImageClassificationDataFrameInput(ImageClassificationFilesInput):
labels: list

Expand Down