Flash is a high-level deep learning framework for fast prototyping, baselining, finetuning and solving deep learning problems. It features a set of tasks that you can use for inference and finetuning out of the box, and an easy-to-implement API for customizing every step of the process.
Flash is built for beginners, with a simple API that requires very little deep learning background, and for data scientists, Kagglers, applied ML practitioners and deep learning researchers who want a quick way to get a deep learning baseline with the advanced features PyTorch Lightning offers.
If you are just getting started with deep learning, Flash offers common deep learning tasks that you can use out of the box in a few lines of code. No math, fancy nn.Modules or research experience required!
Flash is built on top of PyTorch Lightning, a powerful deep learning research framework for training models at scale. With the power of Lightning, you can train your Flash tasks on any hardware (CPUs, GPUs or TPUs) without any code changes.
If you want to create more complex and customized models, you can refactor any part of Flash with PyTorch or PyTorch Lightning components to get all the flexibility you need. Lightning is just organized PyTorch with the unnecessary engineering details abstracted away.
- Flash (high-level)
- Lightning (mid-level)
- PyTorch (low-level)
When you need more flexibility you can build your own tasks or simply use Lightning directly.
PyTorch Lightning is designed to abstract away unnecessary boilerplate while enabling maximal flexibility. Because it provides full flexibility, solving very common deep learning problems such as classification in Lightning still requires some boilerplate, and it can still take quite some time to get a baseline model running on a new dataset or out-of-domain task. We created Flash to answer our users' need for a super quick way to baseline with Lightning using proven backbones for common data patterns. Flash aims to be the easiest starting point for your research: start with a Flash task to benchmark against, and override any part of Flash with Lightning or PyTorch components on your way to SOTA research.
Flash tasks are essentially LightningModules, and the Flash Trainer is a thin wrapper around the Lightning Trainer. You can use your own LightningModule instead of the Flash task, the Lightning Trainer instead of the Flash Trainer, and so on. Flash helps you focus more on your research and less on everything else.
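The interchangeability described here follows from plain inheritance. The sketch below illustrates the pattern with hypothetical stand-in classes (`BaseModule`, `BaseTrainer`, `ToyTask`, `ThinTrainer`, `CustomModule`); none of them are the real Lightning or Flash classes, and the "loss" is a dummy number chosen only to make the loop observable.

```python
# Hypothetical stand-ins; these are NOT the real Lightning or Flash classes.
class BaseModule:
    """Plays the role of a LightningModule: owns the per-batch logic."""

    def training_step(self, batch):
        raise NotImplementedError


class BaseTrainer:
    """Plays the role of the Lightning Trainer: owns the training loop."""

    def __init__(self, max_epochs=1):
        self.max_epochs = max_epochs

    def fit(self, module, batches):
        losses = []
        for _ in range(self.max_epochs):
            for batch in batches:
                losses.append(module.training_step(batch))
        return losses


class ToyTask(BaseModule):
    """Plays the role of a Flash task: a module with the details filled in."""

    def training_step(self, batch):
        # Dummy loss: mean of the squared batch values.
        return sum(x * x for x in batch) / len(batch)


class ThinTrainer(BaseTrainer):
    """Plays the role of the Flash Trainer: same loop plus one convenience."""

    def fit_and_summarize(self, module, batches):
        losses = self.fit(module, batches)
        return {"steps": len(losses), "mean_loss": sum(losses) / len(losses)}


class CustomModule(BaseModule):
    """A hand-written module that can replace the ready-made task."""

    def training_step(self, batch):
        # Dummy loss: largest absolute value in the batch.
        return max(abs(x) for x in batch)


batches = [[1.0, 2.0], [3.0, 4.0]]

# A ready-made task runs on the base trainer...
print(BaseTrainer(max_epochs=2).fit(ToyTask(), batches))
# ...and a custom module runs on the thin wrapper.
print(ThinTrainer().fit_and_summarize(CustomModule(), batches))
```

Because the task is a module and the wrapper is a trainer, either side can be swapped without touching the other.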
Flash tasks implement the standard best practices for a variety of different models and domains, to save you time digging through different implementations. Flash abstracts even more details than Lightning, allowing deep learning experts to share their tips and tricks for solving scoped deep learning problems.
Flash consists of a collection of tasks. Flash tasks are laser-focused objects designed to solve a well-defined type of problem using state-of-the-art methods.
A Flash task contains all the relevant information to solve the task at hand: the number of class labels you want to predict, the number of columns in your dataset, and details of the model architecture such as the loss function and optimizers.
Here are examples of tasks:
from flash.text import TextClassifier
from flash.image import ImageClassifier
from flash.tabular import TabularClassifier
Note
Tasks are inflexible by definition! To get more flexibility, you can simply use a LightningModule (pytorch_lightning.core.lightning.LightningModule) directly or modify an existing task in just a few lines.
Inference is the process of generating predictions from trained models. To use a task for inference:
- Init your task with pretrained weights using a checkpoint (a checkpoint is simply a file that captures the exact value of all parameters used by a model). Both local files and URLs work.
- Load your data into a DataModule (flash.core.data.data_module.DataModule) and pass it to Trainer.predict (flash.core.trainer.Trainer.predict).
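To make the idea of a checkpoint concrete, here is a minimal sketch of a parameter round trip in plain Python. It is only an illustration: the toy parameter dictionary, the helper names `save_checkpoint` and `load_checkpoint`, and the JSON file format are all assumptions made for this example, and real checkpoints store tensors in a binary format rather than JSON.

```python
import json
import os
import tempfile

# Hypothetical toy "model": nothing more than a dict of named parameters.
params = {"weight": [0.25, -1.5, 3.0], "bias": 0.125}


def save_checkpoint(path, parameters):
    """Write every parameter value to a file."""
    with open(path, "w") as f:
        json.dump(parameters, f)


def load_checkpoint(path):
    """Read the parameter values back from the file."""
    with open(path) as f:
        return json.load(f)


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "toy_model.json")
    save_checkpoint(path, params)
    restored = load_checkpoint(path)

# The restored parameters match the saved ones exactly.
print(restored == params)
```

A model rebuilt from `restored` would behave the same as the one that was saved, which is what makes a checkpoint enough to run inference without retraining.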
Here's an example of inference:
# import our libraries
from flash import Trainer
from flash.text import TextClassifier, TextClassificationData

# 1. Init the finetuned task from URL
model = TextClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/0.9.0/text_classification_model.pt")

# 2. Perform inference from list of sequences
trainer = Trainer()
datamodule = TextClassificationData.from_lists(
    predict_data=[
        "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.",
        "The worst movie in the history of cinema.",
        "This guy has done a great job with this movie!",
    ],
    batch_size=4,
)
predictions = trainer.predict(model, datamodule=datamodule, output="labels")
print(predictions)
We get the following output:
[["negative", "negative", "positive"]]
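The output is nested: there is one inner list per batch, and here all three sentences fit into a single batch because `batch_size=4`. If you want one label per input, you can flatten the batches first. The sketch below uses only plain Python and copies the inputs and output from the inference example; the variable names `sentences` and `labels` are introduced here for illustration.

```python
# Inputs and output copied from the inference example above.
sentences = [
    "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.",
    "The worst movie in the history of cinema.",
    "This guy has done a great job with this movie!",
]
predictions = [["negative", "negative", "positive"]]

# Flatten the per-batch lists into one label per input sentence.
labels = [label for batch in predictions for label in batch]

# Pair each input with its predicted label.
for sentence, label in zip(sentences, labels):
    print(f"{label:>8}  {sentence}")
```

The same flattening works unchanged when the inputs span several batches.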
Finetuning (or transfer learning) is the process of adapting a model trained on a large dataset to your particular (likely much smaller) dataset. All Flash tasks have backbones that are pre-trained on large datasets such as ImageNet. Finetuning a pretrained model significantly decreases training time compared with training from scratch.
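One common way to finetune is to keep the pretrained backbone frozen and train only a small new head on your data. The toy sketch below shows that idea in plain Python so the mechanics are visible. It is not Flash's implementation: the 2x2 "backbone", the dataset, the learning rate and the number of epochs are all made up for illustration.

```python
import copy

# "Pretrained" backbone: a fixed 2x2 linear feature extractor (made-up weights).
backbone = [[1.0, 1.0], [1.0, -1.0]]
# New head for the small target dataset, initialized to zero.
head = {"w": [0.0, 0.0], "b": 0.0}


def features(x):
    """Run the frozen backbone on one input."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in backbone]


def predict(x):
    """Apply the trainable linear head to the backbone features."""
    f = features(x)
    return sum(w * fi for w, fi in zip(head["w"], f)) + head["b"]


def mse(data):
    """Mean squared error of the current model over a dataset."""
    return sum((predict(x) - y) ** 2 for x, y in data) / len(data)


# Tiny target dataset generated from y = 2 * (x0 + x1) + 1.
data = [([0.0, 0.0], 1.0), ([1.0, 0.0], 3.0), ([0.0, 1.0], 3.0), ([1.0, 1.0], 5.0)]

frozen_copy = copy.deepcopy(backbone)
loss_before = mse(data)

lr = 0.05
for _ in range(200):
    for x, y in data:
        f = features(x)
        err = predict(x) - y
        # Gradient step on the head only; the backbone is never updated.
        head["w"] = [w - lr * 2 * err * fi for w, fi in zip(head["w"], f)]
        head["b"] -= lr * 2 * err

loss_after = mse(data)

print("loss before:", loss_before)
print("loss after:", loss_after)
print("backbone unchanged:", backbone == frozen_copy)
```

Only three numbers are trained here (two head weights and a bias), which hints at why reusing a pretrained backbone can be so much cheaper than training every parameter.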
When you have enough data, you're likely better off training from scratch instead of finetuning.
- ImageClassification (reference/image_classification)
- ImageEmbedder (reference/image_embedder)
- TextClassification (reference/text_classification)
- SummarizationTask (reference/summarization)
- TranslationTask (reference/translation)
- TabularClassification (reference/tabular_classification)
More tasks coming soon!
The Lightning + Flash team is hard at work building more tasks for common deep learning use cases, and we're looking for incredible contributors like you to submit new tasks!
Join our Slack to get help becoming a contributor!