# Inference for 🌿Herbarium with Lightning⚡Flash


**This is an inference-only version of the original work: https://www.kaggle.com/jirkaborovec/herbarium-eda-baseline-flash-efficientnet**

See our story: [Best Practices to Rank on Kaggle Competition with PyTorch Lightning and Grid.ai Spot Instances](https://devblog.pytorchlightning.ai/best-practices-to-rank-on-kaggle-competition-with-pytorch-lightning-and-grid-ai-spot-instances-54aa5248aa8e)


In [None]:
# List the Kaggle input datasets attached to this notebook.
! ls -l /kaggle/input/

# Size given to `T.Resize` in the input transform (and forwarded via
# `transform_kwargs`); presumably should match the training resolution — confirm.
image_size = (512, 512)
# When True, the input transform appends a per-channel `T.Normalize` step.
normalize = True

## Browse test images 

In [None]:
import os
import json
import pandas as pd
import matplotlib.pyplot as plt

# Root of the competition dataset; `test_metadata.json` and `test_images/` are read from here.
PATH_DATASET = "/kaggle/input/herbarium-2022-fgvc9"

In [None]:

# Load the test-set metadata and index it by image id.
# JSON is UTF-8 by specification, so decode it explicitly rather than relying
# on the platform's locale-dependent default encoding.
with open(os.path.join(PATH_DATASET, "test_metadata.json"), encoding="utf-8") as fp:
    test_data = json.load(fp)

print(len(test_data))
df_test = pd.DataFrame(test_data).set_index("image_id")
display(df_test.head())

## Inference with Lightning⚡Flash


In [None]:
# Install Lightning Flash (image extras) and an upgraded timm from wheels in
# attached Kaggle datasets; `--no-index` means nothing is fetched from PyPI.
!pip install -q 'lightning-flash[image]' --find-links /kaggle/input/herbarium-eda-baseline-flash-efficientnet/frozen_packages/ --no-index
!pip install -q timm -U --find-links /kaggle/input/herbarium-submissions/packages/ --no-index
# NOTE(review): wandb is removed presumably so no logger tries to log in
# during offline inference — confirm.
!pip uninstall -y wandb

In [None]:
import torch
import flash
from flash.image import ImageClassificationData, ImageClassifier

### 1. Load the task ⚙️

In [None]:
# Sanity-check that the checkpoint file exists in the attached dataset.
ls /kaggle/input/herbariumflash/cont-v4.ckpt

In [None]:
# Restore the trained classifier from its checkpoint and switch to eval mode.
model = ImageClassifier.load_from_checkpoint(
    "/kaggle/input/herbariumflash/cont-v4.ckpt"
).eval()

# The train/val/test metric objects are not needed for prediction; drop them.
del model.train_metrics, model.val_metrics, model.test_metrics
# Only move to the GPU when one exists: an unconditional `.cuda()` raises on
# CPU-only sessions, whereas the Trainer setup falls back to CPU.
if torch.cuda.is_available():
    model = model.cuda()

In [None]:
# Trainer args: one GPU when CUDA is available, otherwise run on the CPU (0 GPUs).
GPUS = 1 if torch.cuda.is_available() else 0
trainer = flash.Trainer(gpus=GPUS)

### 2. Run predictions 🎉

In [None]:
from dataclasses import dataclass
from torchvision import transforms as T
from typing import Tuple, Callable
from flash.core.data.io.input_transform import InputTransform


@dataclass
class ImageClassificationInputTransform(InputTransform):
    """Per-sample input transform used for prediction.

    Each image is resized to ``image_size`` and converted to a tensor. When the
    notebook-level ``normalize`` flag is true, it is then normalized with the
    per-channel ``image_color_mean`` / ``image_color_std`` statistics.
    Targets are converted with ``torch.as_tensor``.
    """

    # Size handed to ``T.Resize``; defaults to the notebook-level ``image_size``.
    image_size: Tuple[int, int] = image_size
    # Per-channel (three channels) mean and std used by ``T.Normalize``.
    # Presumably computed on the Herbarium training images — confirm.
    image_color_mean: Tuple[float, float, float] = (0.781, 0.759, 0.710)
    image_color_std: Tuple[float, float, float] = (0.241, 0.245, 0.249)

    def input_per_sample_transform(self) -> Callable:
        """Return the composed resize -> to-tensor (-> normalize) pipeline."""
        tfsm = [
            T.Resize(self.image_size),
            T.ToTensor(),
        ]

        # ``normalize`` is the module-level flag, read when this method is called.
        if normalize:
            tfsm.append(T.Normalize(self.image_color_mean, self.image_color_std))

        return T.Compose(tfsm)

    def target_per_sample_transform(self) -> Callable:
        """Return the callable that converts each target to a tensor."""
        return torch.as_tensor

In [None]:
# Preview the first rows of the test metadata (shown as the cell's output).
df_test.head()

In [None]:
print(len(df_test))

# Prediction-only datamodule: image paths come from the `file_name` column of
# `df_test`, resolved under `test_images/`.
datamodule = ImageClassificationData.from_data_frame(
    input_field="file_name",
    predict_data_frame=df_test,
    predict_images_root=os.path.join(PATH_DATASET, "test_images"),
    predict_transform=ImageClassificationInputTransform,
    batch_size=64,
    # Presumably passed to the transform's constructor so its resize size stays
    # in sync with the notebook-level setting — confirm against Flash docs.
    transform_kwargs={"image_size": image_size},
    # os.cpu_count() returns None when the count cannot be determined, which a
    # DataLoader rejects; fall back to loading in the main process.
    num_workers=os.cpu_count() or 0,
)

In [None]:
# Run prediction and flatten the per-batch label lists returned by the
# Trainer into a single flat list of labels.
with torch.inference_mode():
    predictions = [
        label
        for batch_labels in trainer.predict(model, datamodule=datamodule, output="labels")
        for label in batch_labels
    ]

In [None]:
# One row per test image: index named "id" (the image id) and its "Predicted" label.
submission = pd.DataFrame(
    {"Predicted": predictions},
    index=df_test.index.rename("id"),
)
submission.to_csv("submission.csv")

# Show the first lines of the written submission file.
! head submission.csv