Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

DataPipeline PoC #141

Merged
merged 191 commits into from
Mar 29, 2021
Merged

DataPipeline PoC #141

merged 191 commits into from
Mar 29, 2021

Conversation

tchaton
Copy link
Contributor

@tchaton tchaton commented Feb 22, 2021

What does this PR do?

This PR introduces the new API for DataPipeline.

Objective:
Provide a flexible API which organise user processing code toward higher readability, debugging and performance.

DataPipeline are composed from 2 parts: Preprocess and Postprocess.

Preprocess implements the following hooks:

  • load_data
  • load_sample
  • per_sample_pre_tensor_transform
  • per_sample_to_tensor_transform
  • per_sample_post_tensor_transform
  • per_batch_transform
  • collate
  • per_sample_transform_on_device
  • per_batch_transform_on_device

Postprocess implements the following hooks:

  • per_batch_transform
  • per_sample_transform
  • uncollate
  • export_data
  • export_sample

The DataPipeline are aware of the Trainer RunningStage, meaning they know if they are running training, validation, testing, predicting,

The users can customise each hooks for a specific RunningStage by adding train, validation, test, predict as prefix before every hooks: Example. train_load_data function would be used for Training stage only or use boolean self.training, self.validating, self.testing and self.predicting

@mock.patch("torch.save")  # need to mock torch.save or we get pickle error
def test_dummy_example(tmpdir):

    class ImageClassificationPeprocess(Preprocess):

        def __init__(self, to_tensor_transform, train_per_sample_transform_on_device):
            super().__init__()
            self._to_tensor = to_tensor_transform# T.ToTensor()
            self._train_per_sample_transform_on_device = train_per_sample_transform_on_device# T.RandomHorizontalFlip()

        def load_data(self, folder: str):
            # from folder -> return files paths
            return ["a.jpg", "b.jpg"]

        def load_sample(self, path: str) -> Image.Image:
            # from a file path, load the associated image
            img8Bit = np.uint8(np.random.uniform(0, 1, (64, 64, 3)) * 255.0)
            return Image.fromarray(img8Bit)

        def per_sample_to_tensor_transform(self, pil_image: Image.Image) -> torch.Tensor:
            # convert pil image into a tensor
            return self._to_tensor(pil_image)

        def train_per_sample_transform_on_device(self, sample: Any) -> Any:
            # apply an augmentation per sample on gpu for train only
            return self._train_per_sample_transform_on_device(sample)

    class CustomModel(Task):

        def __init__(self):
			# This would be a CNN and the loss cross entropy :)
            super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss())

        def training_step(self, batch, batch_idx):
            assert batch.shape == torch.Size([2, 3, 64, 64])

        def validation_step(self, batch, batch_idx):
            assert batch.shape == torch.Size([2, 3, 64, 64])

        def test_step(self, batch, batch_idx):
            assert batch.shape == torch.Size([2, 3, 64, 64])

    class CustomDataModule(DataModule):

        preprocess_cls = ImageClassificationPeprocess

        @property
        def preprocess(self):
            return self.preprocess_cls(
                self.to_tensor_transform,
                self.train_per_sample_transform_on_device)

        @classmethod
        def from_folders(
            cls, 
            train_folder: Optional[str], 
            val_folder: Optional[str], 
            test_folder: Optional[str], 
            predict_folder: Optional[str], 
            to_tensor_transform: torch.nn.Module, 
            train_per_sample_transform_on_device: torch.nn.Module, 
            batch_size: int):

            # attach the arguments for the preprocess onto the cls
            cls.to_tensor_transform = to_tensor_transform
            cls.train_per_sample_transform_on_device = train_per_sample_transform_on_device
            
            # call ``from_load_data_inputs``
            return cls.from_load_data_inputs(
                train_load_data_input=train_folder, 
                valid_load_data_input=val_folder, 
                test_load_data_input=test_folder, 
                predict_load_data_input=predict_folder, 
                batch_size=batch_size)
            
    datamodule = CustomDataModule.from_folders(
        "train_folder", "val_folder", "test_folder", None, T.ToTensor(), T.RandomHorizontalFlip(), batch_size=2)

    assert isinstance(datamodule.train_dataloader().dataset[0], Image.Image)
    batch = next(iter(datamodule.train_dataloader()))
    assert batch[0].shape == torch.Size([3, 64, 64])

    model = CustomModel()
    trainer = Trainer(
        max_epochs=1,
        limit_train_batches=2,
        limit_val_batches=1,
        limit_test_batches=2,
        limit_predict_batches=2,
        num_sanity_val_steps=1
    )
    trainer.fit(model, datamodule=datamodule)
    trainer.test(model

TODOs:

  • Add support for per_sample_pre_tensor_transform, per_sample_to_tensor_transform, per_sample_post_tensor_transform
  • Add tests for the above hooks
  • Add tests for reload_dataloaders_every_n_epochs
  • Add check that per_sample_post_tensor_transform receives tensors + tests
  • Convert to new API for ImageClassifier.
  • Convert to new API for ObjectDetector (WIP) @kaushikb11 Mind having a look ?
  • Convert to new API for SummarizationTask
  • Convert to new API for TabularClassifier.
  • Convert to new API for TextClassifier
  • Convert to new API for TranslationTask
  • Resolve CI failing test + update outdated tests

Before submitting

  • Was this discussed/approved via a Github issue? (no need for typos and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests? [not needed for typos/docs]
  • Did you verify new and existing tests pass locally with your changes?
  • If you made a notable change (that affects users), did you update the CHANGELOG?

PR review

  • Is this pull request ready for review? (if not, please submit in draft mode)

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃

@pep8speaks
Copy link

pep8speaks commented Feb 22, 2021

Hello @tchaton! Thanks for updating this PR.

There are currently no PEP 8 issues detected in this Pull Request. Cheers! 🍻

Comment last updated at 2021-03-29 18:33:32 UTC

@codecov
Copy link

codecov bot commented Feb 23, 2021

Codecov Report

Merging #141 (e2f24dc) into master (3b4c6b6) will increase coverage by 3.37%.
The diff coverage is 78.15%.

❗ Current head e2f24dc differs from pull request most recent head de3327b. Consider uploading reports for the commit de3327b to get more accurate results
Impacted file tree graph

@@            Coverage Diff             @@
##           master     #141      +/-   ##
==========================================
+ Coverage   76.52%   79.89%   +3.37%     
==========================================
  Files          56       55       -1     
  Lines        2334     2447     +113     
==========================================
+ Hits         1786     1955     +169     
+ Misses        548      492      -56     
Flag Coverage Δ
unittests 79.89% <78.15%> (+3.37%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

Impacted Files Coverage Δ
flash/core/data/utils.py 29.26% <ø> (-58.83%) ⬇️
flash/vision/detection/model.py 72.46% <ø> (-0.14%) ⬇️
flash/text/seq2seq/core/model.py 61.53% <25.00%> (-2.53%) ⬇️
flash/text/seq2seq/core/data.py 42.62% <32.55%> (-42.31%) ⬇️
flash/text/classification/data.py 40.90% <36.55%> (-44.81%) ⬇️
flash/text/classification/model.py 63.63% <40.00%> (-33.34%) ⬇️
flash/text/seq2seq/summarization/data.py 57.14% <47.61%> (-27.48%) ⬇️
flash/text/seq2seq/summarization/model.py 82.35% <66.66%> (+5.88%) ⬆️
flash/vision/utils.py 69.23% <69.23%> (ø)
flash/tabular/classification/data/dataset.py 75.75% <75.00%> (ø)
... and 32 more

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 3b4c6b6...de3327b. Read the comment docs.

@kaushikb11 kaushikb11 mentioned this pull request Feb 24, 2021
8 tasks
@Borda Borda added the Priority label Feb 24, 2021
@Borda
Copy link
Member

Borda commented Feb 24, 2021

@tchaton @justusschock can we get this done asap as it blocks transition to proper PL version

@justusschock
Copy link
Member

@Borda why does it block? This should be independent.

@Borda
Copy link
Member

Borda commented Feb 24, 2021

@Borda why does it block? This should be independent.

pls, check #133 (comment)

README.md Show resolved Hide resolved
docs/source/general/data.rst Outdated Show resolved Hide resolved
flash/core/imports.py Outdated Show resolved Hide resolved
flash/core/model.py Outdated Show resolved Hide resolved
flash/vision/classification/data.py Outdated Show resolved Hide resolved
flash/vision/detection/data.py Outdated Show resolved Hide resolved
flash/vision/detection/data.py Show resolved Hide resolved
flash_notebooks/generic_task.ipynb Show resolved Hide resolved
requirements.txt Show resolved Hide resolved
flash_examples/generic_task.py Show resolved Hide resolved
requirements.txt Outdated Show resolved Hide resolved
requirements.txt Outdated Show resolved Hide resolved
docs/source/general/data.rst Show resolved Hide resolved
docs/source/custom_task.rst Show resolved Hide resolved
flash/data/data_module.py Show resolved Hide resolved
flash/data/data_pipeline.py Show resolved Hide resolved
@tchaton tchaton merged commit ba34bf4 into master Mar 29, 2021
@tchaton tchaton deleted the datapipeline_poc_1 branch March 29, 2021 19:10
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

8 participants