<a href="https://colab.research.google.com/github/PyTorchLightning/lightning-flash/blob/master/flash_notebooks/text_classification.ipynb" target="_parent">
    <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

In this notebook, we'll go over the basics of Lightning Flash by finetuning a TextClassifier on the [IMDB Dataset](https://www.imdb.com/interfaces/) of movie reviews.

# Finetuning

Finetuning consists of four steps:
 
 - 1. Train a source neural network model on a source dataset. For text classification, this is traditionally a transformer model such as BERT ([Bidirectional Encoder Representations from Transformers](https://arxiv.org/abs/1810.04805)) trained on Wikipedia and BooksCorpus.
As those models are costly to train, libraries such as [Transformers](https://github.com/huggingface/transformers) or [FairSeq](https://github.com/pytorch/fairseq) provide popular pre-trained model architectures for NLP. In this notebook, we will be using [tiny-bert](https://huggingface.co/prajjwal1/bert-tiny).

 
 - 2. Create a new neural network, the target model. Its architecture replicates all the model designs and parameters of the source model, except the last layer, which is removed. This model without its last layer is traditionally called a backbone.
 

- 3. Add new layers after the backbone, where the last output size is the number of target dataset categories. Those new layers, traditionally called the head, will be randomly initialized, while the backbone keeps its pre-trained weights from the source model.
 

- 4. Train the target model on a target dataset, such as the IMDB movie reviews used here. Freezing some layers, such as the backbone, at the start of training tends to make training more stable. In Flash, this can easily be done with `trainer.finetune(..., strategy="freeze")`. It is also common to freeze the backbone at first and unfreeze it later in training. In Flash, this can be done with `trainer.finetune(..., strategy="freeze_unfreeze")`. If one wants more control over the unfreezing flow, Flash supports `trainer.finetune(..., strategy=MyFinetuningStrategy())`, where `MyFinetuningStrategy` subclasses `pytorch_lightning.callbacks.BaseFinetuning`.
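The freezing strategies above can be illustrated with a toy sketch in plain Python (no Flash or PyTorch involved) that reports which parameter groups would be updated at a given epoch. The group names `backbone`/`head`, the helper `trainable_groups`, and the `unfreeze_epoch` parameter are hypothetical; in Flash, the real strategies are built on Lightning finetuning callbacks that freeze and unfreeze the model's actual parameters.

```python
from typing import Set


def trainable_groups(strategy: str, epoch: int, unfreeze_epoch: int = 10) -> Set[str]:
    """Toy model of which parameter groups receive gradient updates.

    - "freeze": the backbone stays frozen for the whole run; only the head trains.
    - "freeze_unfreeze": the backbone is frozen until `unfreeze_epoch`,
      then it is trained together with the head.
    (Hypothetical helper for illustration, not Flash's implementation.)
    """
    if strategy == "freeze":
        return {"head"}
    if strategy == "freeze_unfreeze":
        return {"backbone", "head"} if epoch >= unfreeze_epoch else {"head"}
    raise ValueError(f"unknown strategy: {strategy!r}")


for epoch in (0, 5, 10):
    print(epoch, sorted(trainable_groups("freeze_unfreeze", epoch, unfreeze_epoch=5)))
# 0 ['head']
# 5 ['backbone', 'head']
# 10 ['backbone', 'head']
```

With `"freeze"`, the randomly initialized head is the only part that learns, so the pre-trained backbone cannot be disturbed by large early gradients; `"freeze_unfreeze"` trades some of that stability for the chance to adapt the backbone once the head has settled.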

---
  - Give us a ⭐ [on GitHub](https://www.github.com/PytorchLightning/pytorch-lightning/)
  - Check out [Flash documentation](https://lightning-flash.readthedocs.io/en/latest/)
  - Check out [Lightning documentation](https://pytorch-lightning.readthedocs.io/en/latest/)
  - Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-pw5v393p-qRaDgEk24~EjiZNBpSQFgQ)

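### Setup
Lightning Flash is easy to install. Simply run `pip install lightning-flash`. The cell below installs the latest version from GitHub with the `text` extra, which pulls in the dependencies needed for text tasks.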

In [None]:
%%capture
! pip install 'git+https://github.com/PyTorchLightning/lightning-flash.git#egg=lightning-flash[text]'

In [33]:
from transformers import AutoModel, AutoConfig
import torch

In [7]:
input_ids = torch.tensor([[1, 2, 3]])  # a toy batch: one sequence of three token ids
attention_mask = torch.tensor([[1, 1, 1]])  # 1 = attend to the token, 0 = padding to ignore
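An attention mask has the same shape as `input_ids` and marks real tokens with 1 and padding with 0, so a mask of all zeros would mask out every position. The sketch below pads a batch of token-id lists to a common length and builds the matching mask in plain Python. The padding id `0`, the token ids, and the helper name `pad_batch` are assumptions for illustration; in practice a Hugging Face tokenizer produces both tensors for you.

```python
from typing import List, Tuple

PAD_ID = 0  # assumption: id of the padding token in the vocabulary


def pad_batch(batch: List[List[int]], pad_id: int = PAD_ID) -> Tuple[List[List[int]], List[List[int]]]:
    """Right-pad every sequence to the longest one and build its attention mask."""
    max_len = max(len(seq) for seq in batch)
    input_ids, attention_mask = [], []
    for seq in batch:
        n_pad = max_len - len(seq)
        input_ids.append(seq + [pad_id] * n_pad)
        # 1 = real token the model should attend to, 0 = padding to ignore
        attention_mask.append([1] * len(seq) + [0] * n_pad)
    return input_ids, attention_mask


ids, mask = pad_batch([[101, 2023, 102], [101, 102]])  # arbitrary example ids
print(ids)   # [[101, 2023, 102], [101, 102, 0]]
print(mask)  # [[1, 1, 1], [1, 1, 0]]
```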

In [18]:
from functools import partial, wraps

In [13]:
def f(a, b, c):
    print(a, b, c)

In [14]:
a, b, c = "a, b, c".split(",")

In [17]:
g = partial(f, a)
g(b, c)

a  b  c


In [22]:
def print_provider_info(name, providers, func):
    message = f"Using '{name}' provided by {', '.join(str(provider) for provider in providers)}."

    @wraps(func)
    def wrapper(*args, **kwargs):
        print(message)
        return func(*args, **kwargs)

    return wrapper
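`functools.wraps` copies the wrapped function's metadata (such as `__name__` and `__doc__`) onto the wrapper, so a decorated function still looks like the original in logs and `help()`. A minimal standalone comparison, using hypothetical decorator names:

```python
from functools import partial, wraps


def with_wraps(func):
    @wraps(func)  # copies __name__, __doc__, __module__, ... and sets __wrapped__
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper


def without_wraps(func):
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper


def square(x):
    """Return x squared."""
    return x ** 2


print(with_wraps(square).__name__)               # square
print(without_wraps(square).__name__)            # wrapper
print(with_wraps(square).__wrapped__ is square)  # True

# Wrapping a functools.partial object also works: attributes the partial
# lacks (such as __name__) are skipped, so the wrapper keeps its own name.
square_of_three = with_wraps(partial(square, 3))
print(square_of_three())         # 9
print(square_of_three.__name__)  # wrapper
```

This is why the next cell can pass a `partial` to `print_provider_info` without errors.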

In [24]:
fn = partial(f, a)
fn = print_provider_info(a, ["mimmo"], fn)
fn(b, c)

Using 'a' provided by mimmo.
a  b  c





In [11]:
model = AutoModel.from_pretrained("prajjwal1/bert-tiny")

Some weights of the model checkpoint at prajjwal1/bert-tiny were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.decoder.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
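The unused weights in the warning all start with `cls.`. In this checkpoint they belong to the BERT pre-training heads (masked language modelling under `cls.predictions.*` and next-sentence prediction under `cls.seq_relationship.*`), which `AutoModel` discards because it builds only the encoder. This is step 2 of finetuning in practice: the backbone keeps its weights and the old head is dropped. A toy sketch of that split, using a hypothetical helper and a few key names (two taken from the warning, two illustrative backbone keys):

```python
from typing import Iterable, List, Tuple


def split_checkpoint_keys(keys: Iterable[str], head_prefix: str = "cls.") -> Tuple[List[str], List[str]]:
    """Separate backbone weights (kept) from pre-training head weights (dropped)."""
    kept, dropped = [], []
    for key in keys:
        if key.startswith(head_prefix):
            dropped.append(key)  # pre-training head: not used by the backbone
        else:
            kept.append(key)     # encoder/embedding weight: reused by the backbone
    return kept, dropped


checkpoint_keys = [
    "bert.embeddings.word_embeddings.weight",            # illustrative backbone key
    "bert.encoder.layer.0.attention.self.query.weight",  # illustrative backbone key
    "cls.predictions.decoder.weight",                    # from the warning above
    "cls.seq_relationship.bias",                         # from the warning above
]
kept, dropped = split_checkpoint_keys(checkpoint_keys)
print(kept)
print(dropped)
```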


In [16]:
from flash.core.registry import FlashRegistry


In [17]:
register = FlashRegistry("ciao")

In [30]:
register(fn=lambda x: x ** 2, name="squares")

<function __main__.<lambda>(x)>

In [32]:
register.get("squares")(3)

9
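The cells above register a lambda under the name `squares` in a `FlashRegistry` and look it up again with `get`. Flash uses registries like this, for example, to look up backbones by name. The toy class below is a hypothetical, much-simplified stand-in written only to illustrate the idea; it is not Flash's implementation and supports only registration (direct or decorator-style) and lookup.

```python
from typing import Callable, Dict, Optional


class ToyRegistry:
    """Minimal name -> function registry (illustration only)."""

    def __init__(self, name: str) -> None:
        self.name = name
        self._functions: Dict[str, Callable] = {}

    def __call__(self, fn: Optional[Callable] = None, name: Optional[str] = None):
        """Register `fn` directly, or return a decorator when `fn` is omitted."""
        def _register(func: Callable) -> Callable:
            key = name or func.__name__
            if key in self._functions:
                raise KeyError(f"{key!r} is already registered in {self.name!r}")
            self._functions[key] = func
            return func  # return the function unchanged, like a plain decorator

        return _register(fn) if fn is not None else _register

    def get(self, key: str) -> Callable:
        return self._functions[key]


functions = ToyRegistry("functions")
functions(fn=lambda x: x ** 2, name="squares")  # direct registration


@functions(name="cubes")  # decorator-style registration
def cube(x):
    return x ** 3


print(functions.get("squares")(3))  # 9
print(functions.get("cubes")(2))    # 8
```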

In [38]:
model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state[:, 0, :]
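`last_hidden_state` has shape `(batch, sequence_length, hidden_size)`, and the index `[:, 0, :]` keeps, for every sequence in the batch, the hidden vector at the first position. When the text is tokenized for BERT, that position holds the `[CLS]` token, whose vector is commonly used as a sentence-level representation for classification. The same selection on nested Python lists, with made-up numbers and a hypothetical helper name:

```python
from typing import List


def first_token_vectors(last_hidden_state: List[List[List[float]]]) -> List[List[float]]:
    """Pure-Python equivalent of `last_hidden_state[:, 0, :]`."""
    # For each sequence (dim 0), keep position 0 (dim 1) and all hidden units (dim 2).
    return [sequence[0] for sequence in last_hidden_state]


# batch = 2 sequences, sequence_length = 3 tokens, hidden_size = 2 (made-up values)
hidden = [
    [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],
    [[1.1, 1.2], [1.3, 1.4], [1.5, 1.6]],
]
print(first_token_vectors(hidden))  # [[0.1, 0.2], [1.1, 1.2]]
```

For `prajjwal1/bert-tiny`, whose hidden size is 128, the cell above therefore returns a tensor of shape `(1, 128)`.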



In [5]:
import flash
from flash.core.data.utils import download_data
from flash.text import TextClassificationData, TextClassifier


### 1. Download the data
The data is downloaded as a zip archive from a URL and extracted into a `data/` directory.

In [None]:
download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", 'data/')

### 2. Load the data

Flash tasks have built-in DataModules that you can use to organize your data. Pass in train, validation, and test files, and Flash will take care of the rest.
Here, `TextClassificationData.from_csv` creates the DataModule from CSV files, reading the text from the `review` column and the label from the `sentiment` column.

In [4]:
datamodule = TextClassificationData.from_csv(
    train_file="data/imdb/train.csv",
    val_file="data/imdb/valid.csv",
    test_file="data/imdb/test.csv",
    input_fields="review",
    target_fields="sentiment"
)
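Conceptually, `from_csv` pairs each text in the `review` column with its label in the `sentiment` column, and the number of classes follows from the distinct labels. The sketch below reproduces only that bookkeeping with the standard `csv` module on a tiny in-memory file. The rows and the `positive`/`negative` label values are made up for illustration and are not taken from the IMDB files; Flash's own label handling may differ.

```python
import csv
import io

# Made-up CSV with the same column names the notebook uses.
raw_csv = """review,sentiment
"A moving, beautifully shot film.",positive
"Two hours I will never get back.",negative
"Great cast and a clever script.",positive
"""

rows = list(csv.DictReader(io.StringIO(raw_csv)))
texts = [row["review"] for row in rows]      # corresponds to input_fields="review"
labels = [row["sentiment"] for row in rows]  # corresponds to target_fields="sentiment"

# Map each distinct label to an integer class index (sorted for determinism).
label_to_index = {label: i for i, label in enumerate(sorted(set(labels)))}
num_classes = len(label_to_index)

print(num_classes)     # 2
print(label_to_index)  # {'negative': 0, 'positive': 1}
```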


### 3. Build the model

Create the TextClassifier task. In this notebook, the TextClassifier uses a [tiny-bert](https://huggingface.co/prajjwal1/bert-tiny) backbone to keep the demo fast to train and finetune. You can use any model from [transformers - Text Classification](https://huggingface.co/models?filter=text-classification,pytorch).

The backbone can easily be changed by passing another model name, e.g. `TextClassifier(backbone="prajjwal1/bert-tiny-mnli")`.

In [2]:
model = TextClassifier(num_classes=datamodule.num_classes, backbone="prajjwal1/bert-tiny")
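The new head ends in `num_classes` outputs (logits), one per class, and the predicted class is typically the one with the largest score. A plain-Python sketch of that final step, with made-up logits, a hypothetical label order, and a numerically stable softmax:

```python
import math
from typing import List


def softmax(logits: List[float]) -> List[float]:
    """Turn raw scores into probabilities that sum to 1 (numerically stable)."""
    m = max(logits)  # subtracting the max avoids overflow in exp()
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


labels = ["negative", "positive"]  # hypothetical class order (index -> label)
logits = [-1.2, 2.3]               # made-up head output for one review, num_classes = 2
probs = softmax(logits)
predicted = labels[max(range(len(probs)), key=probs.__getitem__)]
print(predicted)  # positive
```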


### 4. Create the trainer (one epoch over the data)

In [None]:
trainer = flash.Trainer(max_epochs=1)

### 5. Fine-tune the model

With `strategy="freeze"`, the backbone is frozen and only the new head is trained on the IMDB dataset.

In [None]:
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

### 6. Test the model

In [None]:
trainer.test(model, datamodule=datamodule)

### 7. Save it!

In [None]:
trainer.save_checkpoint("text_classification_model.pt")

# Predicting

### 1. Load the model from a checkpoint

In [None]:
model = TextClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/text_classification_model.pt")

### 2a. Classify a few sentences! How was the movie?

In [None]:
predictions = model.predict([
    "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.",
    "The worst movie in the history of cinema.",
    "I come from Bulgaria where it 's almost impossible to have a tornado.",
    "Very, very afraid",
    "This guy has done a great job with this movie!",
])
print(predictions)
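Python joins adjacent string literals at compile time, so a missing comma inside a list of inputs silently merges two reviews into one instead of raising an error. A quick standalone check:

```python
reviews_missing_comma = [
    "Very, very afraid"  # no comma: merged with the next literal
    "This guy has done a great job with this movie!",
]
reviews_fixed = [
    "Very, very afraid",
    "This guy has done a great job with this movie!",
]

print(len(reviews_missing_comma))  # 1
print(len(reviews_fixed))          # 2
print(reviews_missing_comma[0])    # Very, very afraidThis guy has done a great job with this movie!
```

Checking that the number of predictions matches the number of intended inputs is a cheap way to catch this.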

### 2b. Or generate predictions from a CSV file!

In [None]:
datamodule = TextClassificationData.from_csv(
    predict_file="data/imdb/predict.csv",
    input_fields="review",
)
predictions = flash.Trainer().predict(model, datamodule=datamodule)
print(predictions)

<code style="color:#792ee5;">
    <h1> <strong> Congratulations - Time to Join the Community! </strong>  </h1>
</code>

Congratulations on completing this notebook tutorial! If you enjoyed it and would like to join the Lightning movement, you can do so in the following ways!

### Help us build Flash by adding support for new data-types and new tasks.
Flash aims to become the first task hub, so anyone can get started creating amazing applications using deep learning.
If you are interested, please open a PR with your contributions!


### Star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) on GitHub
The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we're building.

* Please, star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning)

### Join our [Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-pw5v393p-qRaDgEk24~EjiZNBpSQFgQ)!
The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself and share your interests in the `#general` channel.

### Interested in SOTA AI models? Check out [Bolt](https://github.com/PyTorchLightning/lightning-bolts)
Bolts has a collection of state-of-the-art models, all implemented in [Lightning](https://github.com/PyTorchLightning/pytorch-lightning), that can be easily integrated into your own projects.

* Please, star [Bolt](https://github.com/PyTorchLightning/lightning-bolts)

### Contributions!
The best way to contribute to our community is to become a code contributor! At any time you can go to [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) or [Bolt](https://github.com/PyTorchLightning/lightning-bolts) GitHub Issues page and filter for "good first issue". 

* [Lightning good first issue](https://github.com/PyTorchLightning/pytorch-lightning/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)
* [Bolt good first issue](https://github.com/PyTorchLightning/lightning-bolts/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)
* You can also contribute your own notebooks with useful examples!

### Many thanks from the entire PyTorch Lightning team for your interest!

<img src="https://raw.githubusercontent.com/PyTorchLightning/lightning-flash/18c591747e40a0ad862d4f82943d209b8cc25358/docs/source/_static/images/logo.svg" width="800" height="200" />