<a href="https://colab.research.google.com/github/PyTorchLightning/lightning-flash/blob/master/flash_notebooks/audio_classification.ipynb" target="_parent">
    <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

In this notebook, we'll go over the basics of Lightning Flash by finetuning an ImageClassifier and making predictions with it on the [Urban Sound 8k Images Dataset](https://www.kaggle.com/gokulrejith/urban-sound-8k-images), which contains mel spectrograms of urban sounds from 10 classes: *air_conditioner, car_horn, children_playing, dog_bark, drilling, engine_idling, gun_shot, jackhammer, siren, and street_music*.

# Finetuning

Finetuning consists of four steps:
 
 1. Train a source neural network model on a source dataset. In this notebook we rely on [Torchvision](https://pytorch.org/docs/stable/torchvision/index.html) models pretrained on the [ImageNet dataset](http://www.image-net.org) and finetune them to fit our dataset of mel spectrograms. The specific architecture used is [ResNet-18](https://pytorch.org/hub/pytorch_vision_resnet/).
 
 2. Create a new neural network, called the target model. Its architecture and parameters replicate the source model, except for the last layer, which is removed. This model without its last layer is traditionally called a backbone.
 
 3. Add new layers after the backbone, with the output size of the last layer equal to the number of categories in the target dataset. These new layers, traditionally called the head, are randomly initialized, while the backbone keeps its ImageNet-pretrained weights.
 
 4. Train the target model on the target dataset, here Urban Sound 8k Images. Freezing some layers, such as the backbone, at the start of training tends to make training more stable. In Flash, this is done with `trainer.finetune(..., strategy="freeze")`. It is also common to freeze the backbone first and unfreeze it later, which Flash supports with `trainer.finetune(..., strategy="freeze_unfreeze")`. For more control over the unfreezing schedule, use `trainer.finetune(..., strategy=MyFinetuningStrategy())`, where `MyFinetuningStrategy` subclasses `pytorch_lightning.callbacks.BaseFinetuning`. This notebook uses the freeze/unfreeze strategy. Because mel spectrograms differ so much from ImageNet photos, we first train only the head for the first few epochs (a single epoch in this notebook), then unfreeze the whole model, backbone included, so it can better fit our dataset. Keeping the backbone frozen at the start ensures that the randomly initialized head does not propagate random noise into the backbone, so we can actually leverage the features the backbone has already learned.
 
 

 

---
  - Give us a ⭐ [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)
  - Check out [Flash documentation](https://lightning-flash.readthedocs.io/en/latest/)
  - Check out [Lightning documentation](https://pytorch-lightning.readthedocs.io/en/latest/)
  - Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-pw5v393p-qRaDgEk24~EjiZNBpSQFgQ)

In [None]:
%%capture
! pip install git+https://github.com/PyTorchLightning/lightning-flash.git

### The notebook runtime has to be re-started once Flash is installed.

In [None]:
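# Kill the current process to restart the Colab runtime so the freshly installed Flash can be imported.
# See https://github.com/streamlit/demo-self-driving/issues/17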
if 'google.colab' in str(get_ipython()):
    import os
    os.kill(os.getpid(), 9)

In [None]:
import flash
from flash.core.data.utils import download_data
from flash.audio import AudioClassificationData
from flash.image import ImageClassifier
from flash.core.finetuning import FreezeUnfreeze

### 1. Download data
The zipped dataset is downloaded from a URL and extracted into the `./data` directory.

In [None]:
download_data("https://pl-flash-data.s3.amazonaws.com/urban8k_images.zip", "./data")

### 2. Load the data

Flash Tasks have built-in DataModules that you can use to organize your data. Pass in train, validation, and test folders and Flash will take care of the rest.
`AudioClassificationData.from_folders` creates an `AudioClassificationData` object from folders of images arranged like this:


    train/dog/xxx.png
    train/dog/xxy.png
    train/dog/xxz.png
    train/cat/123.png
    train/cat/nsdf3.png
    train/cat/asd932.png


Note: each sub-folder is treated as a separate class, and the images it contains as samples of that class.
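The folder-per-class convention can be illustrated with the standard library alone. This is a sketch, not Flash's code: the `index_folder` helper is hypothetical, and assigning labels in alphabetical order of the sub-folder names mirrors torchvision's `ImageFolder` convention, which is an assumption about how Flash orders its labels.

```python
import tempfile
from pathlib import Path
from typing import Dict, List, Tuple, Union


def index_folder(root: Union[str, Path]) -> Tuple[Dict[str, int], List[Tuple[Path, int]]]:
    """Hypothetical helper: map sub-folder names to labels and list (file, label) samples."""
    root = Path(root)
    # Every direct sub-folder of `root` becomes one class, labelled in alphabetical order.
    class_names = sorted(p.name for p in root.iterdir() if p.is_dir())
    class_to_idx = {name: idx for idx, name in enumerate(class_names)}
    # Every file inside a class sub-folder becomes one sample of that class.
    samples = [
        (file, class_to_idx[file.parent.name])
        for file in sorted(root.glob("*/*"))
        if file.is_file()
    ]
    return class_to_idx, samples


# Recreate the dog/cat layout shown above in a temporary directory.
with tempfile.TemporaryDirectory() as tmp:
    for rel in ["dog/xxx.png", "dog/xxy.png", "dog/xxz.png",
                "cat/123.png", "cat/nsdf3.png", "cat/asd932.png"]:
        path = Path(tmp, "train", rel)
        path.parent.mkdir(parents=True, exist_ok=True)
        path.touch()
    class_to_idx, samples = index_folder(Path(tmp, "train"))
    print(class_to_idx)  # {'cat': 0, 'dog': 1}
    print(len(samples))  # 6
```

Applied to `data/urban8k_images/train`, the same logic would yield the 10 classes that `datamodule.num_classes` reports below.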

In [None]:
datamodule = AudioClassificationData.from_folders(
    train_folder="data/urban8k_images/train",
    val_folder="data/urban8k_images/val",
    test_folder="data/urban8k_images/test",
)

### 3. Build the model

Create the ImageClassifier task. By default, the ImageClassifier task uses a [ResNet-18](https://pytorch.org/hub/pytorch_vision_resnet/) backbone to train or finetune your model.
For the [Urban Sound 8k Images Dataset](https://www.kaggle.com/gokulrejith/urban-sound-8k-images), `datamodule.num_classes` is 10.
The backbone can easily be changed with `ImageClassifier(backbone="resnet50")`, or you can provide your own with `ImageClassifier(backbone=my_backbone)`.

In [None]:
model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)

### 4. Create the trainer. Run once on data

The trainer object can be used for training or fine-tuning tasks on new sets of data. 

You can pass in parameters to control the training routine: limit the number of epochs, run on GPUs or TPUs, etc.

For more details, read the [Trainer Documentation](https://pytorch-lightning.readthedocs.io/en/latest/trainer.html).

In this demo, we limit the fine-tuning to 3 epochs using `max_epochs=3`.

In [None]:
trainer = flash.Trainer(max_epochs=3)

### 5. Finetune the model 

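The `FreezeUnfreeze` strategy keeps the backbone frozen during the first epoch (epoch 0), so only the head is trained, and unfreezes it from epoch 1 onwards (`unfreeze_epoch=1`).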
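The resulting schedule can be spelled out with a tiny pure-Python sketch. The `backbone_trainable_per_epoch` function is hypothetical; it assumes zero-indexed epochs (as in Lightning's `current_epoch`) and that the backbone is unfrozen at the start of epoch `unfreeze_epoch`. The head is trained in every epoch.

```python
from typing import List, Optional


def backbone_trainable_per_epoch(max_epochs: int, unfreeze_epoch: Optional[int]) -> List[bool]:
    """Hypothetical helper: for each epoch index, is the backbone being trained?

    unfreeze_epoch=None models the plain "freeze" strategy (backbone never unfrozen).
    """
    return [
        unfreeze_epoch is not None and epoch >= unfreeze_epoch
        for epoch in range(max_epochs)
    ]


# Settings used in this notebook: max_epochs=3, FreezeUnfreeze(unfreeze_epoch=1).
print(backbone_trainable_per_epoch(3, unfreeze_epoch=1))     # [False, True, True]
# strategy="freeze": the backbone stays frozen for the whole run.
print(backbone_trainable_per_epoch(3, unfreeze_epoch=None))  # [False, False, False]
```

With `max_epochs=3`, this gives one head-only epoch followed by two epochs in which the whole model is finetuned.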

In [None]:
trainer.finetune(model, datamodule=datamodule, strategy=FreezeUnfreeze(unfreeze_epoch=1))

### 6. Test the model

In [None]:
trainer.test()

### 7. Save it!

In [None]:
trainer.save_checkpoint("audio_classification_model.pt")

# Predicting

### 1. Load the model from a checkpoint

In [None]:
model = ImageClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/audio_classification_model.pt")

### 2a. Predict what's on a few images!

In [None]:
predictions = model.predict([
    "data/urban8k_images/test/air_conditioner/13230-0-0-5.wav.jpg",
    "data/urban8k_images/test/children_playing/9223-2-0-15.wav.jpg",
    "data/urban8k_images/test/jackhammer/22883-7-10-0.wav.jpg",
    "data/urban8k_images/test/street_music/7390-9-0-6.wav.jpg",
    "data/urban8k_images/test/car_horn/7389-1-0-6.wav.jpg",
    "data/urban8k_images/test/dog_bark/344-3-4-0.wav.jpg",
    "data/urban8k_images/test/drilling/22962-4-0-0.wav.jpg",
    "data/urban8k_images/test/engine_idling/6988-5-0-2.wav.jpg",
    "data/urban8k_images/test/gun_shot/7063-6-0-0.wav.jpg",
    "data/urban8k_images/test/siren/22601-8-0-9.wav.jpg",
])
print(predictions)

### 2b. Or generate predictions for a whole folder!

In [None]:
datamodule = AudioClassificationData.from_folders(predict_folder="data/urban8k_images/test")
predictions = flash.Trainer().predict(model, datamodule=datamodule)
print(predictions)

<code style="color:#792ee5;">
    <h1> <strong> Congratulations - Time to Join the Community! </strong>  </h1>
</code>

Congratulations on completing this notebook tutorial! If you enjoyed it and would like to join the Lightning movement, you can do so in the following ways!

### Help us build Flash by adding support for new data-types and new tasks.
Flash aims to become the first task hub, so anyone can get started building great applications with deep learning.
If you are interested, please open a PR with your contributions!


### Star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) on GitHub
The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we're building.

* Please, star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning)

### Join our [Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-pw5v393p-qRaDgEk24~EjiZNBpSQFgQ)!
The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself and share your interests in the `#general` channel.

### Interested in SOTA AI models? Check out [Bolt](https://github.com/PyTorchLightning/lightning-bolts)
Bolts has a collection of state-of-the-art models, all implemented in [Lightning](https://github.com/PyTorchLightning/pytorch-lightning), that can be easily integrated into your own projects.

* Please, star [Bolt](https://github.com/PyTorchLightning/lightning-bolts)

### Contributions!
The best way to contribute to our community is to become a code contributor! At any time you can go to [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) or [Bolt](https://github.com/PyTorchLightning/lightning-bolts) GitHub Issues page and filter for "good first issue". 

* [Lightning good first issue](https://github.com/PyTorchLightning/pytorch-lightning/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)
* [Bolt good first issue](https://github.com/PyTorchLightning/lightning-bolts/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)
* You can also contribute your own notebooks with useful examples!

### Great thanks from the entire PyTorch Lightning team for your interest!

<img src="https://raw.githubusercontent.com/PyTorchLightning/lightning-flash/18c591747e40a0ad862d4f82943d209b8cc25358/docs/source/_static/images/logo.svg" width="800" height="200" />