
# Image Classification on Hymenoptera Dataset

* **Author:** Ethan Harris (ethan@pytorchlightning.ai)
* **License:** CC BY-SA
* **Generated:** 2023-01-05T11:45:49.866948

In this tutorial, we'll go over the basics of Lightning Flash by finetuning and predicting with an `ImageClassifier` on the [Hymenoptera Dataset](https://www.kaggle.com/ajayrana/hymenoptera-data), which contains images of ants and bees.


---
[Open in Colab](https://colab.research.google.com/github/PytorchLightning/lightning-tutorials/blob/publication/.notebooks/flash_tutorials/image_classification.ipynb)

Give us a ⭐ [on GitHub](https://www.github.com/Lightning-AI/lightning/)
| Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/stable/)
| Join us [on Slack](https://www.pytorchlightning.ai/community)

## Setup
This notebook requires some packages besides pytorch-lightning.

In [1]:
! pip install --quiet "pytorch-lightning>=1.4, <1.9" "pytorch-lightning==1.6.*" "torchmetrics<0.11" "numpy<1.24" "setuptools==65.6.3" "lightning-flash[image]>=0.7.0" "torchmetrics>=0.7, <0.12" "ipython[notebook]>=8.0.0, <8.9.0" "torch>=1.8.1, <1.14.0"

ERROR: Cannot install lightning-flash[image]==0.7.0, lightning-flash[image]==0.7.1, lightning-flash[image]==0.7.2, lightning-flash[image]==0.7.3, lightning-flash[image]==0.7.4, lightning-flash[image]==0.7.5, lightning-flash[image]==0.8.0, lightning-flash[image]==0.8.1, lightning-flash[image]==0.8.1.post0 and setuptools==65.6.3 because these package versions have conflicting dependencies.
ERROR: ResolutionImpossible: for help visit https://pip.pypa.io/en/latest/topics/dependency-resolution/#dealing-with-dependency-conflicts


# Finetuning

Finetuning consists of four steps:

 - 1. Train a source neural network model on a source dataset. For computer vision, this is traditionally the [ImageNet dataset](http://www.image-net.org). As training is costly, libraries such as [Torchvision](https://pytorch.org/vision/stable/index.html) provide popular pre-trained model architectures. In this notebook, we will be using their [resnet-18](https://pytorch.org/hub/pytorch_vision_resnet/).

 - 2. Create a new neural network called the target model. Its architecture replicates the source model and its parameters, except for the last layer, which is removed. This model without its last layer is traditionally called a backbone.

 - 3. Add new layers after the backbone, where the final output size is the number of target dataset categories. Those new layers, traditionally called the head, are randomly initialized, while the backbone keeps its pre-trained weights from ImageNet.

 - 4. Train the target model on a target dataset, such as the Hymenoptera Dataset with ants and bees. Training tends to be more stable when some layers, such as the backbone, are frozen at the start. In Flash, this can be done with `trainer.finetune(..., strategy="freeze")`. It is also common to freeze the backbone at first and unfreeze it later in training. In Flash, this can be done with `trainer.finetune(..., strategy="freeze_unfreeze")`. For more control over the unfreeze flow, Flash supports `trainer.finetune(..., strategy=MyFinetuningStrategy())`, where `MyFinetuningStrategy` subclasses `pytorch_lightning.callbacks.BaseFinetuning`.
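
The four steps can be sketched in plain PyTorch. This is a minimal illustration under stated assumptions, not how Flash implements its strategies: a tiny randomly initialized convolutional stack stands in for the pre-trained ResNet-18 backbone so that the example runs without downloading weights, and the names `backbone`, `head`, and `count_trainable` are made up for illustration.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Steps 1-2: a stand-in for a pre-trained backbone (assumption: in practice this
# would be a ResNet-18 with ImageNet weights and its last layer removed).
backbone = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
)

# Step 3: a randomly initialized head with one output per target class.
num_classes = 2  # ants and bees
head = nn.Linear(8, num_classes)

# Step 4 ("freeze"): exclude the backbone parameters from gradient updates.
for param in backbone.parameters():
    param.requires_grad = False

model = nn.Sequential(backbone, head)


def count_trainable(module: nn.Module) -> int:
    """Return the number of parameters that will receive gradient updates."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad)


# Only the still-trainable (head) parameters are handed to the optimizer.
optimizer = torch.optim.SGD(
    [p for p in model.parameters() if p.requires_grad], lr=0.01
)

images = torch.randn(4, 3, 32, 32)   # a fake batch of 4 RGB images
labels = torch.tensor([0, 1, 0, 1])  # fake ant/bee labels
loss = nn.functional.cross_entropy(model(images), labels)
loss.backward()
optimizer.step()

trainable = count_trainable(model)  # 8 * 2 weights + 2 biases in the head
print(trainable)
```

After the backward pass, only the head holds gradients; the frozen backbone parameters are untouched. A freeze-then-unfreeze schedule would, at some later epoch, set `requires_grad` back to `True` on the backbone parameters and register them with the optimizer.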

In [2]:

import flash
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageClassifier



## Download data
The data is downloaded from a URL and saved in a `data` directory.

In [3]:
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/")

data/hymenoptera_data.zip:   0%|          | 0/67334 [00:00<?, ?KB/s]

## Load the data

Flash Tasks have built-in DataModules that you can use to organize your data. Pass in train, validation and test folders and Flash will take care of the rest.
The cell below creates an `ImageClassificationData` object from folders of images arranged in this way:

   train/dog/xxx.png
   train/dog/xxy.png
   train/dog/xxz.png
   train/cat/123.png
   train/cat/nsdf3.png
   train/cat/asd932.png
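
To make the layout concrete, the sketch below builds a similar folder structure in a temporary directory and derives class names from the subfolder names. It assumes, as is conventional for folder-based image datasets, that each subfolder of `train/` is one class. The helper `list_classes` is hypothetical and is not part of Flash.

```python
import tempfile
from pathlib import Path
from typing import List


def list_classes(train_folder: Path) -> List[str]:
    """Return sorted class names, one per subfolder of train_folder."""
    return sorted(p.name for p in train_folder.iterdir() if p.is_dir())


with tempfile.TemporaryDirectory() as root:
    train = Path(root) / "train"
    # Empty placeholder files stand in for real images.
    for label, names in {"dog": ["xxx.png", "xxy.png"], "cat": ["123.png"]}.items():
        (train / label).mkdir(parents=True)
        for name in names:
            (train / label / name).touch()

    classes = list_classes(train)
    files_per_class = {c: len(list((train / c).glob("*.png"))) for c in classes}

print(classes)
print(files_per_class)
```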

In [4]:
datamodule = ImageClassificationData.from_folders(
    train_folder="data/hymenoptera_data/train/",
    val_folder="data/hymenoptera_data/val/",
    test_folder="data/hymenoptera_data/test/",
    batch_size=1,
)



## Build the model
Create the `ImageClassifier` task. By default, the `ImageClassifier` task uses a [resnet-18](https://pytorch.org/hub/pytorch_vision_resnet/) backbone to train or finetune your model.
For the [Hymenoptera Dataset](https://www.kaggle.com/ajayrana/hymenoptera-data), which contains images of ants and bees, `datamodule.num_classes` will be 2.
The backbone can easily be changed with `ImageClassifier(backbone="resnet50")`, or you can provide your own with `ImageClassifier(backbone=my_backbone)`.

In [5]:
model = ImageClassifier(num_classes=datamodule.num_classes)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


  0%|          | 0.00/44.7M [00:00<?, ?B/s]

## Create the trainer. Run once on data
The trainer object can be used for training or fine-tuning tasks on new sets of data.
You can pass in parameters to control the training routine: limit the number of epochs, run on GPUs or TPUs, etc.
For more details, read the [Trainer Documentation](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=Trainer).
In this demo, we will limit the fine-tuning to run for just one epoch using `max_epochs=1`.

In [6]:
trainer = flash.Trainer(max_epochs=1)

GPU available: True, used: False


TPU available: False, using: 0 TPU cores


IPU available: False, using: 0 IPUs


HPU available: False, using: 0 HPUs




## Finetune the model

In [7]:
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

Missing logger folder: /__w/8/s/lightning_logs



  | Name          | Type           | Params
-------------------------------------------------
0 | train_metrics | ModuleDict     | 0     
1 | val_metrics   | ModuleDict     | 0     
2 | test_metrics  | ModuleDict     | 0     
3 | adapter       | DefaultAdapter | 11.2 M
-------------------------------------------------
10.6 K    Trainable params
11.2 M    Non-trainable params
11.2 M    Total params
44.710    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]



Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]
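
The parameter summary printed above can be cross-checked with a little arithmetic. The sketch below assumes the standard torchvision ResNet-18 layout (a 512-feature backbone output and the BatchNorm channel counts listed in the code). It also assumes that the `freeze` strategy leaves BatchNorm layers trainable, which is the default behaviour (`train_bn=True`) of `BaseFinetuning.freeze` in PyTorch Lightning. Under those assumptions, the results line up with the reported 10.6 K trainable parameters and 44.710 MB model size.

```python
# Assumption: standard torchvision ResNet-18 layer sizes.
num_features = 512  # backbone output size
num_classes = 2     # ants and bees

# Head: one linear layer mapping backbone features to class scores.
head_params = num_features * num_classes + num_classes

# BatchNorm layers as (channels, number of layers with that width).
# The stem has one; the first stage has four; each later stage has four
# plus one in its downsampling shortcut.
batchnorm_layers = [(64, 1), (64, 4), (128, 5), (256, 5), (512, 5)]
# Each BatchNorm layer has a learnable weight and bias per channel.
batchnorm_params = sum(2 * channels * count for channels, count in batchnorm_layers)

trainable_params = head_params + batchnorm_params

# Assumption: ResNet-18 with its 1000-class ImageNet layer has 11,689,512 parameters.
imagenet_fc_params = num_features * 1000 + 1000
total_params = 11_689_512 - imagenet_fc_params + head_params
size_mb = total_params * 4 / 1e6  # 4 bytes per float32 parameter

print(trainable_params)   # 10626
print(total_params)       # 11177538
print(f"{size_mb:.3f}")   # 44.710
```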

## Test the model

In [8]:
trainer.test(model, datamodule=datamodule)



Testing: 0it [00:00, ?it/s]

[{'test_accuracy': 0.6339869499206543,
  'test_cross_entropy': 0.6619874835014343}]
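
To put these numbers in context, it helps to compare them against a chance-level baseline. The sketch below is plain arithmetic and does not use Flash: for a two-class problem, a model that always predicts 50/50 has a cross-entropy of ln 2. The metric values are copied from the output above.

```python
import math

test_accuracy = 0.6339869499206543
test_cross_entropy = 0.6619874835014343
num_classes = 2

# A classifier that assigns probability 1/num_classes to every class.
chance_cross_entropy = math.log(num_classes)

# exp(-cross_entropy) is the geometric mean of the probability that the
# model assigned to the correct class across the test set.
mean_true_class_prob = math.exp(-test_cross_entropy)

print(f"chance-level cross-entropy: {chance_cross_entropy:.4f}")
print(f"geometric-mean probability of the true class: {mean_true_class_prob:.3f}")
```

The measured loss is only slightly below the chance-level value, which suggests that one epoch with a frozen backbone leaves room for improvement, for example by training for more epochs.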

## Save it!

In [9]:
trainer.save_checkpoint("image_classification_model.pt")

## Predicting
**Load the model from a checkpoint**

In [10]:
model = ImageClassifier.load_from_checkpoint(
    "https://flash-weights.s3.amazonaws.com/0.7.0/image_classification_model.pt"
)

Downloading: "https://flash-weights.s3.amazonaws.com/0.7.0/image_classification_model.pt" to /root/.cache/torch/hub/checkpoints/image_classification_model.pt


  0%|          | 0.00/42.8M [00:00<?, ?B/s]

**Predict what's on a few images! Ants or bees?**

In [11]:
datamodule = ImageClassificationData.from_files(
    predict_files=[
        "data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg",
        "data/hymenoptera_data/val/bees/590318879_68cf112861.jpg",
        "data/hymenoptera_data/val/ants/540543309_ddbb193ee5.jpg",
    ],
    batch_size=1,
)
predictions = trainer.predict(model, datamodule=datamodule)
print(predictions)



Predicting: 244it [00:00, ?it/s]

[[{'input': tensor([[[-0.9799, -1.0453, -1.0359,  ..., -0.5708, -0.5411, -0.3982],
         [-0.7713, -0.8059, -0.8286,  ..., -0.6740, -0.6823, -0.5400],
         [-0.6333, -0.6244, -0.6642,  ..., -0.6824, -0.5657, -0.5417],
         ...,
         [-0.5699, -0.5031, -0.4201,  ...,  0.8444,  1.0809,  1.4173],
         [-0.5997, -0.4415, -0.3124,  ...,  1.7086,  1.5849,  0.2182],
         [-0.6386, -0.4954, -0.3901,  ...,  1.4721,  0.3956,  0.6723]],

        [[-0.2051, -0.2759, -0.2459,  ..., -0.0800,  0.0534,  0.1778],
         [ 0.0062, -0.0292, -0.0083,  ..., -0.0966, -0.0084,  0.1084],
         [ 0.1496,  0.1909,  0.2002,  ..., -0.0945,  0.0722,  0.1778],
         ...,
         [ 0.0578,  0.1934,  0.3558,  ...,  0.6774,  0.9248,  1.5000],
         [ 0.0476,  0.2255,  0.3717,  ...,  1.8122,  1.7425,  0.1570],
         [ 0.0344,  0.1822,  0.3240,  ...,  1.5007,  0.2662,  0.7724]],

        [[-0.8747, -0.9413, -0.9207,  ..., -0.5371, -0.4372, -0.3146],
         [-0.6605, -0.6957, -0.70
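
Raw prediction output like the above is hard to read directly. As a rough sketch of the post-processing step, the code below turns per-class scores into human-readable labels with a softmax and an argmax. Everything here is an assumption for illustration: the scores are made up, the label order (`ants`, then `bees`) is assumed to follow alphabetical folder order, and the helpers `softmax` and `to_label` are hypothetical rather than part of Flash.

```python
import math
from typing import List

# Assumption: class names in alphabetical folder order.
LABELS = ["ants", "bees"]


def softmax(scores: List[float]) -> List[float]:
    """Convert raw scores to probabilities that sum to 1."""
    peak = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def to_label(scores: List[float]) -> str:
    """Pick the label with the highest probability."""
    probs = softmax(scores)
    return LABELS[probs.index(max(probs))]


# Made-up scores for three images, for illustration only.
example_scores = [[-1.2, 0.8], [0.1, 2.3], [1.5, -0.4]]
labels = [to_label(s) for s in example_scores]
print(labels)
```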

## Congratulations - Time to Join the Community!

Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the Lightning
movement, you can do so in the following ways!

### Star [Lightning](https://github.com/Lightning-AI/lightning) on GitHub
The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool
tools we're building.

### Join our [Slack](https://www.pytorchlightning.ai/community)!
The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself
and share your interests in the `#general` channel.


### Contributions!
The best way to contribute to our community is to become a code contributor! At any time you can go to the
[Lightning](https://github.com/Lightning-AI/lightning) or [Bolt](https://github.com/Lightning-AI/lightning-bolts)
GitHub Issues pages and filter for "good first issue".

* [Lightning good first issue](https://github.com/Lightning-AI/lightning/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)
* [Bolt good first issue](https://github.com/Lightning-AI/lightning-bolts/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)
* You can also contribute your own notebooks with useful examples!

### Great thanks from the entire PyTorch Lightning Team for your interest!

[PyTorch Lightning](https://pytorchlightning.ai)