#!/usr/bin/env python

# %% [markdown]
# In this notebook, we'll go over the basics of Lightning Flash by finetuning a TextClassifier on the [IMDB Dataset](https://www.imdb.com/interfaces/).
#
# # Finetuning
#
# Finetuning consists of four steps:
#
# - 1. Train a source neural network model on a source dataset. For text classification, it is traditionally a transformer model such as BERT ([Bidirectional Encoder Representations from Transformers](https://arxiv.org/abs/1810.04805)) trained on Wikipedia.
# As those models are costly to train, the [Transformers](https://github.com/huggingface/transformers) and [FairSeq](https://github.com/pytorch/fairseq) libraries provide popular pre-trained model architectures for NLP. In this notebook, we will be using [tiny-bert](https://huggingface.co/prajjwal1/bert-tiny).
#
#
# - 2. Create a new neural network, the target model. Its architecture replicates the design and parameters of the source model, except for the final layer, which is removed. This model without its final layer is traditionally called a backbone.
#
#
# - 3. Add new layers after the backbone, where the final output size is the number of categories in the target dataset. These new layers, traditionally called the head, are randomly initialized, while the backbone keeps its pre-trained weights.
#
#
# - 4. Train the target model on the target dataset, here the IMDB movie-review dataset. However, freezing some layers at the start of training, such as the backbone, tends to be more stable. In Flash, this can easily be done with `trainer.finetune(..., strategy="freeze")`. It is also common to `freeze/unfreeze` the backbone, which Flash supports with `trainer.finetune(..., strategy="freeze_unfreeze")`. If one wants more control over the unfreezing flow, Flash supports `trainer.finetune(..., strategy=MyFinetuningStrategy())`, where `MyFinetuningStrategy` subclasses `pytorch_lightning.callbacks.BaseFinetuning` (a sketch of such a strategy appears after the finetuning step below).
# %%

import flash
from flash.core.data.utils import download_data
from flash.text import TextClassificationData, TextClassifier

# %% [markdown]
# ## 1. Download the data
# The data is downloaded from a URL and saved in a 'data' directory.
# %%
download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", "data/")


# %% [markdown]
# ## 2. Load the data
#
# Flash Tasks have built-in DataModules that you can use to organize your data. Pass in train, validation, and test folders and Flash will take care of the rest.
# Create a TextClassificationData object from the CSV files.
# %%
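# NOTE: the datamodule cell is elided in this view; below is a minimal sketch
# of what it likely contains. The argument names and the CSV column names
# ("review", "sentiment") are assumptions that depend on the Flash version.
datamodule = TextClassificationData.from_csv(
    input_field="review",  # column holding the review text (assumed)
    target_fields="sentiment",  # column holding the label (assumed)
    train_file="data/imdb/train.csv",
    val_file="data/imdb/valid.csv",
    batch_size=16,
)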
# %% [markdown]
# ## 3. Build the model
#
# Create the TextClassifier task. By default, the TextClassifier task uses a [tiny-bert](https://huggingface.co/prajjwal1/bert-tiny) backbone to train or finetune your model. You could use any model from [transformers - Text Classification](https://huggingface.co/models?filter=text-classification,pytorch).
#
# The backbone can easily be changed, e.g. `TextClassifier(backbone="bert-tiny-mnli")`.
# %%
model = TextClassifier(num_classes=datamodule.num_classes, backbone="prajjwal1/bert-tiny")
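
# %% [markdown]
# ## 4. Create the trainer
# The trainer cell is elided in this view, so here is a minimal sketch. `flash.Trainer` accepts the usual `pytorch_lightning.Trainer` arguments; the `max_epochs` value here is an assumption.
# %%
trainer = flash.Trainer(max_epochs=1)  # assumed settings; adjust epochs/devices as needed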
# %% [markdown]
# ## 5. Fine-tune the model
#
# With `strategy="freeze"`, the backbone is frozen and keeps its pre-trained weights; only the head is trained on the IMDB dataset.
# %%
trainer.finetune(model, datamodule=datamodule, strategy="freeze")
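
# %% [markdown]
# As a rough sketch of the custom-strategy option from step 4, the cell below freezes the backbone at the start and unfreezes it at a chosen epoch. This is only an illustration: the exact `finetune_function` signature varies across `pytorch_lightning` versions, and we assume the task exposes its pre-trained layers as a `backbone` attribute.
# %%
from pytorch_lightning.callbacks import BaseFinetuning


class MyFinetuningStrategy(BaseFinetuning):
    """Freeze the backbone at the start of training, then unfreeze it at a chosen epoch."""

    def __init__(self, unfreeze_epoch: int = 2):
        super().__init__()
        self.unfreeze_epoch = unfreeze_epoch

    def freeze_before_training(self, pl_module):
        # Freeze the pre-trained layers so only the randomly initialized head trains at first.
        self.freeze(pl_module.backbone)

    def finetune_function(self, pl_module, current_epoch, optimizer, opt_idx):
        # At the chosen epoch, unfreeze the backbone and hand its parameters to
        # the optimizer with a smaller learning rate than the head's.
        if current_epoch == self.unfreeze_epoch:
            self.unfreeze_and_add_param_group(
                modules=pl_module.backbone,
                optimizer=optimizer,
                lr=1e-5,
            )


# Usage (sketch): trainer.finetune(model, datamodule=datamodule, strategy=MyFinetuningStrategy())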