# Inference with Lightning⚡Flash

See the official `ImageEmbedder` reference example: https://lightning-flash.readthedocs.io/en/stable/reference/image_embedder.html

**This is a follow-up to https://www.kaggle.com/jirkaborovec/whale-dolphin-embedding-lit-flash-simclr**

In [None]:
# Install the required packages from the pre-downloaded wheels in the attached
# dataset; `--no-index` stops pip from contacting PyPI, so this works offline.
!pip install -q vissl fairscale 'lightning-flash[image]' -U --pre --find-links /kaggle/input/whale-dolphin-embedding-lit-flash-simclr/frozen_packages/ --no-index
# NOTE(review): wandb is presumably removed so no Weights & Biases logger is
# picked up during inference — confirm.
!pip uninstall -y wandb

In [None]:
import torch

import flash
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageEmbedder

## 1. Load the task ⚙️

In [None]:
# Embedder checkpoint to restore for inference. The commented-out line is an
# alternative checkpoint location that can be swapped in.
CHECKPOINT_PATH = "/kaggle/input/happywhale-submissions/happywhale_embedder_model.pt"
# CHECKPOINT_PATH = "/kaggle/input/whale-dolphin-embedding-lit-flash-simclr/image_embedder_model.pt"

embedder = ImageEmbedder.load_from_checkpoint(CHECKPOINT_PATH)

# Show the architecture that was restored from the checkpoint.
print(embedder)

In [None]:
# Request one GPU when CUDA is available, otherwise run on CPU (zero GPUs).
GPUS = 1 if torch.cuda.is_available() else 0
trainer = flash.Trainer(gpus=GPUS)

## 2. Run predictions 🎉

In [None]:
# List the contents of the competition data directory.
!ls -l /kaggle/input/happy-whale-and-dolphin

# Root of the competition dataset; later cells read train.csv,
# train_images/ and test_images/ relative to this path.
PATH_DATASET = "/kaggle/input/happy-whale-and-dolphin"

In [None]:
import os
import pandas as pd
from pprint import pprint

# Training metadata; the columns used below are `image` (file name inside
# train_images/) and `individual_id` (the label to predict).
df_train = pd.read_csv(os.path.join(PATH_DATASET, "train.csv"))
display(df_train.head())

n_images = len(df_train)
# dropna=False counts a missing id as its own value, exactly like len(unique()).
n_individuals = df_train["individual_id"].nunique(dropna=False)
print(f"Dataset size: {n_images}")
print(f"Unique ids: {n_individuals}")

### Train images

In [None]:
# One path per row of df_train; the formatting step pairs train_embeddings[i]
# with df_train["individual_id"][i], so this order must be kept.
train_files = [f"{PATH_DATASET}/train_images/{im}" for im in df_train["image"]]
datamodule = ImageClassificationData.from_files(
    predict_files=train_files,
    batch_size=12,
    num_workers=4,
)

embedder.input_transform = None
# trainer.predict yields one output per batch (presumably a sequence of
# per-image embeddings — confirm); flatten them into a single list.
train_embeddings = [
    embedding
    for batch_output in trainer.predict(embedder, datamodule=datamodule)
    for embedding in batch_output
]

# One embedding per training image, in predict order.
print(len(train_embeddings))
pprint(train_embeddings[:5])

### Test images

In [None]:
import glob

# glob returns matches in arbitrary, filesystem-dependent order; sort so the
# prediction order (and hence the submission row order) is reproducible.
imgs = sorted(glob.glob(f"{PATH_DATASET}/test_images/*.jpg"))
if not imgs:
    # Fail here rather than with an opaque torch.stack error in a later cell.
    raise FileNotFoundError(f"No test images found in {PATH_DATASET}/test_images")

datamodule = ImageClassificationData.from_files(
    predict_files=imgs,
    batch_size=12,
    num_workers=4,
)

embedder.input_transform = None
# Flatten the per-batch prediction outputs. The formatting cell pairs
# test_embeddings[i] with imgs[i], so both must stay in the same order.
test_embeddings = []
for emb in trainer.predict(embedder, datamodule=datamodule):
    test_embeddings += emb

# list of embeddings for images sent to the predict function
print(len(test_embeddings))
pprint(test_embeddings[:5])

## 3. Compute distances 🛣️

In [None]:
device = "cuda" if GPUS else "cpu"

# Minkowski order used by torch.cdist. A large p (the previous value was 256)
# approximates the max-norm, but cdist evaluates sum(|d|**p)**(1/p) directly.
# In float32, |d|**256 overflows to inf for |d| > ~1.41 and underflows to 0
# for |d| < ~0.67, so most distances collapse into 0/inf ties. p=inf computes
# the exact max-norm (Chebyshev) distance stably; use 2.0 for Euclidean.
DIST_P = float("inf")

# Assumes every embedding is a 1-D tensor of the same length — TODO confirm.
train_matrix = torch.stack(train_embeddings).to(device=device, dtype=torch.float32)
test_matrix = torch.stack(test_embeddings).to(device=device, dtype=torch.float32)

# Shape (n_test, n_train): row j holds the distances from test image j to
# every training image. Computing it in this orientation avoids a transpose.
dist_embeddings = torch.cdist(test_matrix, train_matrix, p=DIST_P).cpu()
print(dist_embeddings.shape)

### Format predictions

In [None]:
from tqdm.auto import tqdm


def top_unique_ids(distances, ids, k=4):
    """Return up to ``k`` distinct labels, nearest first.

    Args:
        distances: 1-D NumPy array of distances from one query to every
            reference item.
        ids: NumPy array of reference labels aligned with ``distances``.
        k: number of distinct labels to collect.

    Returns:
        A list of at most ``k`` unique labels, ordered by the distance of each
        label's closest reference item. Order matters for MAP@k scoring, which
        is why a plain ``set`` must not be used here.
    """
    ranked = []
    seen = set()
    # A stable sort breaks distance ties by reference index, so results are
    # reproducible across runs.
    for idx in distances.argsort(kind="stable"):
        label = ids[idx]
        if label in seen:
            continue
        seen.add(label)
        ranked.append(label)
        if len(ranked) == k:
            break
    return ranked


# Hoisted out of the loop: labels aligned with the columns of dist_embeddings.
train_ids = df_train["individual_id"].to_numpy()

submission = []
for im, dist in tqdm(zip(imgs, dist_embeddings), total=len(imgs)):
    # Four nearest known individuals in rank order, then "new_individual"
    # as the fifth guess.
    predictions = top_unique_ids(dist.numpy(), train_ids) + ["new_individual"]
    submission.append({"image": os.path.basename(im), "predictions": " ".join(predictions)})


df_submission = pd.DataFrame(submission).set_index("image")
display(df_submission.head())

In [None]:
# Write the submission file: `image` (the index) and `predictions` columns.
df_submission.to_csv("submission.csv")

# Show the first lines of the written file as a quick sanity check.
!head submission.csv