
Commit 94be2c2

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 3172828 commit 94be2c2

File tree

3 files changed (+40, -44 lines)


flash_tutorials/image_classification/image_classification.py

Lines changed: 13 additions & 13 deletions
```diff
@@ -1,39 +1,37 @@
 #!/usr/bin/env python
-# coding: utf-8
 
 # %% [markdown]
 # In this tutorial, we'll go over the basics of Lightning Flash by finetuning and predicting with an ImageClassifier on the [Hymenoptera Dataset](https://www.kaggle.com/ajayrana/hymenoptera-data), which contains images of ants and bees.
-# 
+#
 # # Finetuning
-# 
+#
 # Finetuning consists of four steps:
-# 
+#
 # - 1. Train a source neural network model on a source dataset. For computer vision, this is traditionally the [ImageNet dataset](http://www.image-net.org/search?q=cat). As training is costly, libraries such as [Torchvision](https://pytorch.org/docs/stable/torchvision/index.html) provide popular pre-trained model architectures. In this notebook, we will be using their [resnet-18](https://pytorch.org/hub/pytorch_vision_resnet/).
-# 
+#
 # - 2. Create a new neural network, called the target model. Its architecture replicates the source model and its parameters, except for the last layer, which is removed. This model without its last layer is traditionally called a backbone.
-# 
+#
 # - 3. Add new layers after the backbone, where the final output size is the number of target dataset categories. Those new layers, traditionally called the head, are randomly initialized, while the backbone keeps its pre-trained weights from ImageNet.
-# 
+#
 # - 4. Train the target model on a target dataset, such as the Hymenoptera Dataset with ants and bees. However, freezing some layers at the start of training, such as the backbone, tends to be more stable. In Flash, this can easily be done with `trainer.finetune(..., strategy="freeze")`. It is also common to freeze/unfreeze the backbone; in Flash, this can be done with `trainer.finetune(..., strategy="freeze_unfreeze")`. If one wants more control over the unfreezing flow, Flash supports `trainer.finetune(..., strategy=MyFinetuningStrategy())`, where `MyFinetuningStrategy` subclasses `pytorch_lightning.callbacks.BaseFinetuning` (see the sketch after this hunk).
 
 import flash
 from flash.core.data.utils import download_data
 from flash.image import ImageClassificationData, ImageClassifier
 
-
 # %% [markdown]
 # ## 1. Download data
 # The data are downloaded from a URL and saved in a 'data' directory.
 # %%
-download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", 'data/')
+download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/")
 
 
 # %% [markdown]
 # ## 2. Load the data
-# 
+#
 # Flash Tasks have built-in DataModules that you can use to organize your data. Pass in train, validation and test folders and Flash will take care of the rest.
 # Creates an ImageClassificationData object from folders of images arranged in this way:
-# 
+#
 # train/dog/xxx.png
 # train/dog/xxy.png
 # train/dog/xxz.png
```
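Step 4 above mentions custom strategies built on `pytorch_lightning.callbacks.BaseFinetuning`. As a rough sketch of what such a strategy could look like (hook signatures follow the pytorch_lightning 1.x API; the `unfreeze_epoch` parameter and the `pl_module.backbone` attribute are assumptions, not part of this commit):

```python
from pytorch_lightning.callbacks import BaseFinetuning


class MyFinetuningStrategy(BaseFinetuning):
    """Sketch: freeze the backbone at the start, unfreeze it at a chosen epoch."""

    def __init__(self, unfreeze_epoch: int = 5):  # the epoch choice is an assumption
        super().__init__()
        self.unfreeze_epoch = unfreeze_epoch

    def freeze_before_training(self, pl_module):
        # Freeze the pre-trained backbone before training starts.
        self.freeze(pl_module.backbone)

    def finetune_function(self, pl_module, current_epoch, optimizer, opt_idx):
        # From the chosen epoch on, unfreeze the backbone and let the optimizer update it.
        if current_epoch == self.unfreeze_epoch:
            self.unfreeze_and_add_param_group(modules=pl_module.backbone, optimizer=optimizer)
```

It would then be passed as `trainer.finetune(model, datamodule=datamodule, strategy=MyFinetuningStrategy())`, exactly as the intro describes.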
```diff
@@ -60,7 +58,7 @@
 
 # %% [markdown]
 # ## 4. Create the trainer. Run once on data
-# The trainer object can be used for training or fine-tuning tasks on new sets of data. 
+# The trainer object can be used for training or fine-tuning tasks on new sets of data.
 # You can pass in parameters to control the training routine: limit the number of epochs, run on GPUs or TPUs, etc.
 # For more details, read the [Trainer Documentation](https://pytorch-lightning.readthedocs.io/en/latest/trainer.html).
 # In this demo, we will limit the fine-tuning to run for just two epochs using max_epochs=2.
```
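For readers following along, the trainer this hunk describes would be created along these lines (a minimal sketch; `max_epochs=2` comes from the text, the GPU-count line is an assumption):

```python
import torch

import flash

# Limit finetuning to two epochs and use every available GPU (zero means CPU).
trainer = flash.Trainer(max_epochs=2, gpus=torch.cuda.device_count())
```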
```diff
@@ -89,7 +87,9 @@
 # ## 8. Predicting
 # ### 1. Load the model from a checkpoint
 # %%
-model = ImageClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/0.7.0/image_classification_model.pt")
+model = ImageClassifier.load_from_checkpoint(
+    "https://flash-weights.s3.amazonaws.com/0.7.0/image_classification_model.pt"
+)
 
 # %% [markdown]
 # ### 2. Predict what's on a few images! Ants or bees?
```
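The prediction step itself falls outside this hunk. A hedged sketch of the Flash 0.7-style flow, reusing `trainer` and `model` from the steps above (the image paths are hypothetical):

```python
# Build a predict-only DataModule from a couple of image files (paths are hypothetical).
predict_datamodule = ImageClassificationData.from_files(
    predict_files=[
        "data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg",
        "data/hymenoptera_data/val/ants/8124241_36b290d372.jpg",
    ],
    batch_size=1,
)
predictions = trainer.predict(model, datamodule=predict_datamodule)
print(predictions)  # per-image class predictions: ants or bees
```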

flash_tutorials/tabular_classification/tabular_classification.py

Lines changed: 12 additions & 14 deletions
```diff
@@ -1,31 +1,28 @@
 #!/usr/bin/env python
-# coding: utf-8
 
 # %% [markdown]
 # In this notebook, we'll go over the basics of Lightning Flash by training a TabularClassifier on the [Titanic Dataset](https://www.kaggle.com/c/titanic).
 
 # # Training
 # %%
 
-from torchmetrics.classification import Accuracy, Precision, Recall
-
 import flash
 from flash.core.data.utils import download_data
-from flash.tabular import TabularClassifier, TabularClassificationData
-
+from flash.tabular import TabularClassificationData, TabularClassifier
+from torchmetrics.classification import Accuracy, Precision, Recall
 
 # %% [markdown]
 # ### 1. Download the data
 # The data are downloaded from a URL and saved in a 'data' directory.
 # %%
-download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", 'data/')
+download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", "data/")
 
 
 # %% [markdown]
 # ### 2. Load the data
 # Flash Tasks have built-in DataModules that you can use to organize your data. Pass in train, validation and test folders and Flash will take care of the rest.
-# 
-# Creating a TabularData object relies on a [Pandas DataFrame](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.html).
+#
+# Creating a TabularData object relies on a [Pandas DataFrame](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.html).
 # %%
 datamodule = TabularClassificationData.from_csv(
     ["Sex", "Age", "SibSp", "Parch", "Ticket", "Cabin", "Embarked"],
```
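The `from_csv` call is cut off at the hunk boundary. A sketch of how the complete call could plausibly look; the `Fare`/`Survived` column names, the CSV path, and the split/batch settings are assumptions rather than part of this diff:

```python
datamodule = TabularClassificationData.from_csv(
    categorical_fields=["Sex", "Age", "SibSp", "Parch", "Ticket", "Cabin", "Embarked"],
    numerical_fields="Fare",                 # assumed numerical column
    target_fields="Survived",                # assumed target column
    train_file="data/titanic/titanic.csv",   # assumed path inside the downloaded archive
    val_split=0.25,
    batch_size=8,
)
```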
```diff
@@ -40,8 +37,8 @@
 
 # %% [markdown]
 # ### 3. Build the model
-# 
-# Note: Categorical columns will be mapped to the embedding space. The embedding space is a set of trainable tensors associated with each categorical column.
+#
+# Note: Categorical columns will be mapped to the embedding space. The embedding space is a set of trainable tensors associated with each categorical column.
 # %%
 model = TabularClassifier.from_data(datamodule)
```
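To unpack the note on the embedding space: each categorical column gets its own table of trainable vectors, one row per category, and those rows are updated by backpropagation like any other weights. A standalone PyTorch illustration (the column, category count, and dimension are hypothetical):

```python
import torch
from torch import nn

# Hypothetical "Embarked" column with 3 categories (C, Q, S); each category id
# indexes a trainable 4-dimensional vector.
embarked_embedding = nn.Embedding(num_embeddings=3, embedding_dim=4)

ids = torch.tensor([0, 2, 1, 0])   # a batch of category ids
vectors = embarked_embedding(ids)  # shape (4, 4); updated during training
print(vectors.shape)
```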

```diff
@@ -73,16 +70,17 @@
 
 # %% [markdown]
 # ### 8. Load the model from a checkpoint
-# 
-# `TabularClassifier.load_from_checkpoint` supports both a URL and a local path to a checkpoint. If provided with a URL, the checkpoint will first be downloaded and loaded to re-create the model.
+#
+# `TabularClassifier.load_from_checkpoint` supports both a URL and a local path to a checkpoint. If provided with a URL, the checkpoint will first be downloaded and loaded to re-create the model.
 # %%
 model = TabularClassifier.load_from_checkpoint(
-    "https://flash-weights.s3.amazonaws.com/0.7.0/tabular_classification_model.pt")
+    "https://flash-weights.s3.amazonaws.com/0.7.0/tabular_classification_model.pt"
+)
 
 
 # %% [markdown]
 # ### 9. Generate predictions from a sheet file! Who would survive?
-# 
+#
 # `TabularClassifier.predict` supports both a DataFrame and a path to a `.csv` file.
 # %%
 datamodule = TabularClassificationData.from_csv(
```
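Section 9's call is likewise truncated. A sketch of a predict-only DataModule plus `Trainer.predict`, reusing `model`, `trainer`, and the training `datamodule` from the earlier steps; the `parameters` hand-off and the file path are assumptions about the Flash 0.7 API:

```python
predict_datamodule = TabularClassificationData.from_csv(
    predict_file="data/titanic/titanic.csv",  # assumed path; a held-out sheet in practice
    parameters=datamodule.parameters,         # reuse preprocessing state fit on the train data
    batch_size=8,
)
predictions = trainer.predict(model, datamodule=predict_datamodule)
```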

flash_tutorials/text_classification/text_classification.py

Lines changed: 15 additions & 17 deletions
```diff
@@ -1,41 +1,39 @@
 #!/usr/bin/env python
-# coding: utf-8
 
 # %% [markdown]
 # In this notebook, we'll go over the basics of Lightning Flash by finetuning a TextClassifier on the [IMDB Dataset](https://www.imdb.com/interfaces/).
-# 
+#
 # # Finetuning
-# 
+#
 # Finetuning consists of four steps:
-# 
+#
 # - 1. Train a source neural network model on a source dataset. For text classification, it is traditionally a transformer model such as BERT ([Bidirectional Encoder Representations from Transformers](https://arxiv.org/abs/1810.04805)) trained on Wikipedia.
 # As those models are costly to train, libraries such as [Transformers](https://github.com/huggingface/transformers) or [FairSeq](https://github.com/pytorch/fairseq) provide popular pre-trained model architectures for NLP. In this notebook, we will be using [tiny-bert](https://huggingface.co/prajjwal1/bert-tiny).
-# 
-# 
+#
+#
 # - 2. Create a new neural network, the target model. Its architecture replicates the source model's design and parameters, except for the last layer, which is removed. This model without its last layer is traditionally called a backbone.
-# 
-# 
+#
+#
 # - 3. Add new layers after the backbone, where the final output size is the number of target dataset categories. Those new layers, traditionally called the head, are randomly initialized, while the backbone keeps its pre-trained weights.
-# 
-# 
+#
+#
 # - 4. Train the target model on a target dataset, such as the IMDB dataset. However, freezing some layers at the start of training, such as the backbone, tends to be more stable. In Flash, this can easily be done with `trainer.finetune(..., strategy="freeze")`. It is also common to freeze/unfreeze the backbone; in Flash, this can be done with `trainer.finetune(..., strategy="freeze_unfreeze")`. If one wants more control over the unfreezing flow, Flash supports `trainer.finetune(..., strategy=MyFinetuningStrategy())`, where `MyFinetuningStrategy` subclasses `pytorch_lightning.callbacks.BaseFinetuning`.
 # %%
 
 import flash
 from flash.core.data.utils import download_data
 from flash.text import TextClassificationData, TextClassifier
 
-
 # %% [markdown]
 # ### 1. Download the data
 # The data are downloaded from a URL and saved in a 'data' directory.
 # %%
-download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", 'data/')
+download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", "data/")
 
 
 # %% [markdown]
 # ## 2. Load the data
-# 
+#
 # Flash Tasks have built-in DataModules that you can use to organize your data. Pass in train, validation and test folders and Flash will take care of the rest.
 # Creates a TextClassificationData object from a CSV file.
 # %%
```
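The hunk ends just before the `from_csv` call it introduces. A sketch of what it plausibly looks like for the downloaded IMDB CSVs; the `review`/`sentiment` column names, file paths, and batch size are assumptions:

```python
datamodule = TextClassificationData.from_csv(
    "review",                          # assumed input text column
    "sentiment",                       # assumed target column
    train_file="data/imdb/train.csv",  # assumed paths inside the downloaded archive
    val_file="data/imdb/valid.csv",
    batch_size=4,
)
```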
```diff
@@ -51,9 +49,9 @@
 
 # %% [markdown]
 # ## 3. Build the model
-# 
+#
 # Create the TextClassifier task. By default, the TextClassifier task uses a [tiny-bert](https://huggingface.co/prajjwal1/bert-tiny) backbone to train or finetune your model. You could use any model from [transformers - Text Classification](https://huggingface.co/models?filter=text-classification,pytorch).
-# 
+#
 # The backbone can easily be changed, for example with `TextClassifier(backbone='bert-tiny-mnli')`.
 # %%
 model = TextClassifier(num_classes=datamodule.num_classes, backbone="prajjwal1/bert-tiny")
```
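As a concrete instance of the backbone swap mentioned above, any Hugging Face text-classification checkpoint should be usable; `distilbert-base-uncased` is an illustrative choice, not one used by this commit:

```python
# Same task, different backbone (illustrative only).
model = TextClassifier(num_classes=datamodule.num_classes, backbone="distilbert-base-uncased")
```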
```diff
@@ -67,8 +65,8 @@
 
 # %% [markdown]
 # ## 5. Fine-tune the model
-# 
-# The backbone will be frozen during finetuning and only the model head will be trained on the IMDB dataset.
+#
+# The backbone will be frozen during finetuning and only the model head will be trained on the IMDB dataset.
 # %%
 trainer.finetune(model, datamodule=datamodule, strategy="freeze")
 
```
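For contrast with the `strategy="freeze"` call above, the intro also mentions `freeze_unfreeze`, which starts with a frozen backbone and unfreezes it partway through training. A sketch reusing the same objects (the unfreeze schedule is left to Flash's default here):

```python
# Freeze the backbone first, then unfreeze it later in training.
trainer.finetune(model, datamodule=datamodule, strategy="freeze_unfreeze")
```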
