# Demo with Lightning⚡Flash: Image 🌹 Classification on TPU

**REF: https://lightning-flash.readthedocs.io/en/stable/reference/image_classification.html**

The task of identifying what is in an image is called image classification. Image classification is typically used for images that contain a single object: the model predicts which ‘class’ the image most likely belongs to, together with a degree of certainty. A class is a label that describes what is in an image, such as ‘car’, ‘house’, or ‘cat’.

## 0. Installing dependencies

By attaching this notebook as an input to another notebook, you can reuse the packages it downloads in your offline kernels.

See: [Easy Kaggle Offline Submission With Chaining Kernel Notebooks](https://towardsdatascience.com/easy-kaggle-offline-submission-with-chaining-kernels-30bba5ea5c4d)

In [None]:
# !pip download -q "icevision[all]" 'lightning-flash[image]' --dest frozen_packages --prefer-binary
# !pip download -q 'torchmetrics==0.7.*' --dest frozen_packages --prefer-binary
# !pip download -q effdet timm segmentation-models-pytorch --dest frozen_packages --prefer-binary
# !wget https://storage.googleapis.com/tpu-pytor -q -P frozen_packages/
# !rm frozen_packages/torch-*
# !ls -l frozen_packages | grep -e torch -e lightning -e timm

In [None]:
# Fetch and run the PyTorch/XLA environment setup script (pinned via --version 1.8),
# which also installs the system libraries listed after --apt-packages.
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version 1.8 --apt-packages libomp5 libopenblas-dev
# Install Lightning Flash with the image extras; --find-links additionally searches the
# local frozen_packages folder. Uncomment --no-index to install from that folder only
# (fully offline).
!pip install -q -U 'lightning-flash[image]' --find-links frozen_packages # --no-index
# NOTE(review): torchtext is removed, presumably to avoid a version conflict with the
# XLA torch build installed above — confirm.
!pip uninstall -y -q torchtext

In [None]:
import torch

import flash
from flash.image import ImageClassificationData, ImageClassifier

In [None]:
import os, glob
import tensorflow as tf 
from functools import partial
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

# Let tf.data choose the degree of parallelism at runtime; used in convert_dataset
# for both the interleaved shard reads and the parallel record parsing.
AUTO = tf.data.experimental.AUTOTUNE

def decode_image(image_data, height: int = 512, width: int = 512):
    """Decode JPEG bytes into a float32 RGB image tensor scaled to [0, 1].

    Args:
        image_data: Scalar string tensor holding the encoded JPEG bytes.
        height: Expected image height in pixels.
        width: Expected image width in pixels.

    Returns:
        A ``[height, width, 3]`` float32 tensor. The reshape fixes the static
        shape and fails if the decoded image has a different number of pixels.
    """
    rgb = tf.image.decode_jpeg(image_data, channels=3)
    rgb_unit = tf.cast(rgb, tf.float32) / 255.0
    return tf.reshape(rgb_unit, [height, width, 3])

# for tfr in tfr_dataset:
#     ex = tf.train.Example()
#     ex.ParseFromString(tfr.numpy())
#     spl = json.loads(MessageToJson(ex))['features']['feature']
#     print(spl.keys())
#     break

    
def read_tfrecord(example, with_labels: bool = False):
    """Parse one serialized ``tf.train.Example`` into a feature dict.

    Args:
        example: Serialized example, as yielded by ``tf.data.TFRecordDataset``.
        with_labels: When True, also parse the ``"class"`` feature.

    Returns:
        Dict with ``"image"`` (decoded by ``decode_image``), ``"id"`` (raw byte
        string) and, when ``with_labels`` is True, ``"class"`` cast to int32.
    """
    # Every feature is a scalar (shape []); tf.string holds raw bytes.
    feature_spec = {
        "image": tf.io.FixedLenFeature([], tf.string),
        "id": tf.io.FixedLenFeature([], tf.string),
    }
    if with_labels:
        feature_spec["class"] = tf.io.FixedLenFeature([], tf.int64)

    parsed = tf.io.parse_single_example(example, feature_spec)
    parsed["image"] = decode_image(parsed["image"])
    if with_labels:
        parsed["class"] = tf.cast(parsed["class"], tf.int32)
    return parsed


def convert_dataset(path_in: str, path_out: str, sfolder: str = "train"):
    """Unpack one split of a TFRecord dataset into JPEG files on disk.

    Reads every ``*.tfrec`` shard under ``<path_in>/<sfolder>`` and writes each
    decoded image to ``<path_out>/<sfolder>/[<class>/]<id>.jpg``. Records are
    parsed with labels for every split except ``"test"``. A per-class
    sub-folder is used only when the parsed record carries a ``"class"``
    entry, so the test images end up directly in ``<path_out>/test``.

    Args:
        path_in: Root folder containing one sub-folder of ``*.tfrec`` shards per split.
        path_out: Root folder that the JPEG images are written to.
        sfolder: Split name, e.g. ``"train"``, ``"val"`` or ``"test"``.

    Raises:
        FileNotFoundError: If no ``*.tfrec`` shards exist for the split.
    """
    split_dir = os.path.join(path_in, sfolder)
    # sorted() makes the printed shard list deterministic across file systems.
    fnames = sorted(glob.glob(os.path.join(split_dir, "*.tfrec")))
    if not fnames:
        # Fail loudly: an empty split would otherwise silently produce no images
        # and only surface later as a confusing data-loading error.
        raise FileNotFoundError(f"No *.tfrec files found in {split_dir!r}")
    print([os.path.basename(fname) for fname in fnames])

    # num_parallel_reads interleaves reads from the shards.
    dataset = tf.data.TFRecordDataset(fnames, num_parallel_reads=AUTO)
    map_tfrecord = partial(read_tfrecord, with_labels=sfolder != "test")
    dataset = dataset.map(map_tfrecord, num_parallel_calls=AUTO)

    created_dirs = set()
    for spl in tqdm(dataset):
        img_name = spl["id"].numpy().decode("utf-8") + ".jpg"
        folders = [path_out, sfolder]
        if "class" in spl:
            folders.append(str(spl["class"].numpy()))
        img_dir = os.path.join(*folders)
        # Create each output directory once rather than once per image.
        if img_dir not in created_dirs:
            os.makedirs(img_dir, exist_ok=True)
            created_dirs.add(img_dir)
        plt.imsave(os.path.join(img_dir, img_name), spl["image"].numpy())

In [None]:
import shutil

PATH_TFRECORD = "/kaggle/input/tpu-getting-started/tfrecords-jpeg-512x512"
PATH_DATASET = "/kaggle/working/jpeg-512x512"

# Start from an empty output folder so images from a previous run cannot leak
# into the splits. Deleting PATH_DATASET itself (rather than a hard-coded shell
# path) keeps the cleanup in sync if the output location changes; a missing
# folder is ignored, like `rm -rf`.
shutil.rmtree(PATH_DATASET, ignore_errors=True)

# Unpack every split: train/val get per-class sub-folders, test stays flat.
for split in ("train", "val", "test"):
    convert_dataset(PATH_TFRECORD, PATH_DATASET, split)

## 1. Create the DataModule

In [None]:
# Flash transform settings: output image size plus the standard ImageNet
# per-channel normalization statistics.
image_transform_kwargs = dict(
    image_size=(380, 380),
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225),
)

# train/val are laid out as one sub-folder per class; the flat test folder is
# used only for prediction.
datamodule = ImageClassificationData.from_folders(
    train_folder=os.path.join(PATH_DATASET, "train"),
    val_folder=os.path.join(PATH_DATASET, "val"),
    predict_folder=os.path.join(PATH_DATASET, "test"),
    transform_kwargs=image_transform_kwargs,
    batch_size=12,
    num_workers=3,
)

In [None]:
import numpy as np
# datamodule.show_train_batch()

nb_samples = 9
# Normalization statistics; must match the transform_kwargs used to build the
# datamodule so the images can be mapped back to displayable [0, 1] RGB.
# NOTE(review): assumes the train dataloader yields already-normalized CHW
# tensors — confirm for the installed Flash version.
norm_mean = np.array([0.485, 0.456, 0.406])
norm_std = np.array([0.229, 0.224, 0.225])

fig, axarr = plt.subplots(ncols=3, nrows=3, figsize=(8, 8))

batch = next(iter(datamodule.train_dataloader()))
print(batch.keys())
for i, (img, lb) in enumerate(list(zip(batch["input"], batch["target"]))[:nb_samples]):
    # CHW -> HWC, the layout matplotlib expects.
    img = np.transpose(img.numpy(), (1, 2, 0))
    # imshow ignores vmin/vmax for RGB data and clips floats to [0, 1], so undo
    # the normalization explicitly before displaying.
    img = np.clip(img * norm_std + norm_mean, 0.0, 1.0)
    ax = axarr[i % 3, i // 3]
    ax.imshow(img)
    # int() turns a 0-d tensor target into a plain number (not "tensor(n)").
    ax.set_title(int(lb))

## 2. Build the task

In [None]:
# Backbone identifier passed to Flash, loaded with pretrained weights; the
# classifier head is sized from the label names reported by the datamodule.
backbone_name = "tf_efficientnet_b4_ns"
model = ImageClassifier(
    labels=datamodule.labels,
    backbone=backbone_name,
    pretrained=True,
)

## 3. Create the trainer and finetune the model

In [None]:
from pytorch_lightning.loggers import CSVLogger

# Metrics are written as CSV under logs/ so they can be plotted after training.
csv_logger = CSVLogger(save_dir='logs/')
trainer = flash.Trainer(
    tpu_cores=1,
    precision=16,
    max_epochs=5,
    logger=csv_logger,
)

In [None]:
# Fine-tune with Flash's "freeze" strategy: the pretrained backbone stays frozen
# and only the remaining (head) parameters are trained.
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

## Save the model!
# Write a Lightning checkpoint of the fine-tuned model to the working directory.
trainer.save_checkpoint("image_classification_model.pt")

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sn

# The CSVLogger writes one metrics.csv per run into its log directory.
metrics_path = f'{trainer.logger.log_dir}/metrics.csv'
# Drop the global step and plot every remaining metric against the epoch.
metrics = (
    pd.read_csv(metrics_path)
    .drop(columns="step")
    .set_index("epoch")
)
# display(metrics.dropna(axis=1, how="all").head())
g = sn.relplot(data=metrics, kind="line")
plt.grid()

## 4. Predict the classes of the test images!

In [None]:
# Run inference on the datamodule's predict split (the "test" folder) and
# return predicted labels (output="labels") instead of raw model outputs.
predictions = trainer.predict(model, datamodule=datamodule, output="labels")
print(predictions)