# Lightning Fast Deep Learning with Flash
---

Flash is a task-based framework for fast prototyping, baselining, and finetuning, and for solving business and scientific problems with deep learning.
It focuses on:

1. Predictions
1. Finetuning
1. Task-based training


# Some of the Deep Learning Tasks Supported by Flash

#### Image
- Classification
- Segmentation
- Object Detection
- Style Transfer

#### Text
- Text Classification
- Question Answering

#### Audio
- Classification
- Speech Recognition

#### Tabular
- Classification
- Regression

#### Video
- Classification


In [None]:
import torch
import flash
from flash.image import SemanticSegmentationData
from flash.image import SemanticSegmentation
from flash.core.data.utils import download_data
import matplotlib.pyplot as plt

## Image Classification Training

Training is a multistep process:

1. Load Dataset
2. Build Model
3. Create Loss function and Optimizer
4. Train Model

In [None]:
import torch

import flash
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageClassifier

In [None]:
# download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "/Users/aniket/data/")

In [None]:
import os

# Expand "~" explicitly, since the data loader may not resolve it.
data_root = os.path.expanduser("~/data/hymenoptera_data")

datamodule = ImageClassificationData.from_folders(
    train_folder=f"{data_root}/train/",
    val_folder=f"{data_root}/val/",
    batch_size=8,
    transform_kwargs={"image_size": (196, 196), "mean": (0.485, 0.456, 0.406), "std": (0.229, 0.224, 0.225)},
)
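
The `mean` and `std` values passed in `transform_kwargs` are the widely used ImageNet channel statistics. As a minimal sketch, assuming they are applied as the standard per-channel normalization `(value - mean) / std` to pixel values already scaled to [0, 1] (the helper `normalize_pixel` is hypothetical and not part of Flash):

```python
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Normalize one RGB pixel whose channels are already scaled to [0, 1]."""
    return tuple((value - m) / s for value, m, s in zip(rgb, mean, std))


# A pixel equal to the channel means maps to zero in every channel.
print(normalize_pixel(IMAGENET_MEAN))

# A pure white pixel ends up a few standard deviations above zero.
print(normalize_pixel((1.0, 1.0, 1.0)))
```

After this step each channel is roughly zero-centered with unit scale, which matches the input statistics ImageNet-pretrained backbones are usually trained on.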

In [None]:
datamodule.show_train_batch()

In [None]:
# Reference for choosing a backbone (timm speed benchmark):
# https://github.com/kentaroy47/timm_speed_benchmark

model = ImageClassifier(backbone="efficientnet_b0", num_classes=datamodule.num_classes)

In [None]:
trainer = flash.Trainer(max_epochs=5, accelerator="auto")
trainer.finetune(model, datamodule=datamodule, strategy="freeze")
trainer.save_checkpoint("demo_model.pt")

In [None]:
data_dir = "/Users/aniket/data"


# Predict what's on a few images: ants or bees?
datamodule = ImageClassificationData.from_files(
    predict_files=[
        f"{data_dir}/hymenoptera_data/val/bees/65038344_52a45d090d.jpg",
        f"{data_dir}/hymenoptera_data/val/bees/590318879_68cf112861.jpg",
        f"{data_dir}/hymenoptera_data/val/ants/540543309_ddbb193ee5.jpg",
    ],
    batch_size=3
)
predictions = trainer.predict(model, datamodule=datamodule, output="labels")
print(predictions)
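
With `output="labels"`, the printed predictions are class names rather than raw scores. A rough sketch of that mapping, assuming class names come from the sorted sub-folder names (`ants`, `bees`) and using made-up scores rather than real model output; `scores_to_label` is a hypothetical helper, not a Flash function:

```python
# Assumption: classes are ordered alphabetically by folder name.
CLASS_NAMES = sorted(["bees", "ants"])


def scores_to_label(scores, class_names=CLASS_NAMES):
    """Return the class name with the highest score."""
    best_index = max(range(len(scores)), key=lambda i: scores[i])
    return class_names[best_index]


# Made-up scores for three images, one [ants, bees] pair per image.
example_scores = [[0.1, 0.9], [0.3, 0.7], [0.8, 0.2]]
labels = [scores_to_label(s) for s in example_scores]
print(labels)
```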

In [None]:
from PIL import Image

In [None]:
Image.open(f"{data_dir}/hymenoptera_data/val/bees/65038344_52a45d090d.jpg")

## Training Semantic Segmentation

In [None]:
download_data(
    "https://github.com/ongchinkiat/LyftPerceptionChallenge/releases/download/v0.1/carla-capture-20180513A.zip",
    data_dir,
)

In [None]:
dm = SemanticSegmentationData.from_folders(
    train_folder=f"{data_dir}/CameraRGB",
    train_target_folder=f"{data_dir}/CameraSeg",
    val_split=0.1,
    transform_kwargs=dict(image_size=(256, 256)),
    num_classes=21,
    batch_size=4,
)
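
`val_split=0.1` holds out a fraction of the training pairs for validation. A minimal sketch of such a hold-out, assuming a seeded random split; Flash's actual splitting logic may differ, and both `split_train_val` and the file names below are hypothetical:

```python
import random


def split_train_val(items, val_split, seed=0):
    """Shuffle a copy of items and hold out a val_split fraction for validation."""
    if not 0.0 <= val_split < 1.0:
        raise ValueError("val_split must be in [0, 1)")
    shuffled = list(items)
    random.Random(seed).shuffle(shuffled)
    num_val = int(len(shuffled) * val_split)
    return shuffled[num_val:], shuffled[:num_val]


# Hypothetical (image, mask) pairs standing in for CameraRGB / CameraSeg files.
pairs = [(f"CameraRGB/img_{i}.png", f"CameraSeg/img_{i}.png") for i in range(20)]
train_pairs, val_pairs = split_train_val(pairs, val_split=0.1)
print(len(train_pairs), len(val_pairs))
```

Keeping each image paired with its mask before shuffling ensures the two never get out of sync.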

In [None]:
model = SemanticSegmentation.load_from_checkpoint(
    "https://flash-weights.s3.amazonaws.com/0.7.0/semantic_segmentation_model.pt"
)

In [None]:
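# Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
# Use `dm`, the segmentation datamodule built above; `datamodule` still
# refers to the image classification prediction data.
trainer.finetune(model, datamodule=dm)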

In [None]:
datamodule = SemanticSegmentationData.from_files(
    predict_files=[
        f"{data_dir}/CameraRGB/F61-1.png",
        f"{data_dir}/CameraRGB/F62-1.png",
        f"{data_dir}/CameraRGB/F63-1.png",
    ],
    batch_size=3,
)
predictions = trainer.predict(model, datamodule=datamodule)

In [None]:
# First batch, first sample: take the model output
x = predictions[0][0]["preds"]
# x = predictions[0][0]["input"]  # alternative: inspect the input image instead

In [None]:
# Move the first axis last (channels-first to channels-last) and convert to NumPy
x = x.permute((1, 2, 0)).numpy()

In [None]:
# Show index 2 of the last axis (the score map for class 2 when `x` holds predictions)
plt.imshow(x[..., 2])
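
Plotting `x[..., 2]` shows the values for a single class only. To view one predicted class per pixel instead, take the highest-scoring class at each pixel. A pure-Python sketch on a tiny made-up score grid, assuming the last axis holds per-class scores as in the permuted `x` above; `argmax_per_pixel` is a hypothetical helper:

```python
def argmax_per_pixel(scores_hwc):
    """Map an H x W x C nested list of scores to an H x W grid of class indices."""
    return [
        [max(range(len(pixel)), key=lambda c: pixel[c]) for pixel in row]
        for row in scores_hwc
    ]


# Made-up 2 x 2 image with 3 class scores per pixel.
scores = [
    [[0.1, 0.7, 0.2], [0.9, 0.05, 0.05]],
    [[0.2, 0.2, 0.6], [0.3, 0.4, 0.3]],
]
label_map = argmax_per_pixel(scores)
print(label_map)
```

On the NumPy array itself, `plt.imshow(x.argmax(axis=-1))` should give the equivalent label map, provided `x` has the class scores on its last axis.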
