# Explore the Inference Pipeline

In [2]:
import torch
import json
import os
import warnings
from datetime import datetime
import matplotlib.pyplot as plt
import json
import numpy as np
import pandas as pd
import torch
import torchvision.transforms as transforms
from PIL import Image, ImageDraw, ImageFont

from utils.custom_models import load_models
from utils.inference_scripts import classify_species, classify_order, flatbug, classify_box

In [None]:
regional_category_map_path = './models/03_uk_data_category_map.json'

# Load the species label -> class index mapping. Read it explicitly as UTF-8
# (the JSON interchange encoding) instead of relying on the locale default,
# which can mangle non-ASCII species names.
with open(regional_category_map_path, encoding="utf-8") as f:
    regional_category_map = json.load(f)

regional_category_map

In [None]:
# True when the class indices are exactly 0, 1, ..., N-1 in the map's insertion order
all(value == position for position, value in enumerate(regional_category_map.values()))

In [None]:
# Invert the map (class index -> species label) to check the labels are indexed correctly
index_to_label = dict(zip(regional_category_map.values(), regional_category_map.keys()))

index_to_label

In [13]:
# The inverse map must be keyed 0..N-1 and must list labels in the same order
# as the original map, i.e. it is an exact, order-preserving inversion.
assert list(index_to_label) == list(range(len(regional_category_map)))
assert [*index_to_label.values()] == [*regional_category_map]
assert [*index_to_label] == [*regional_category_map.values()]

In [7]:
def classify_species2(image_tensor, regional_model, regional_category_map, top_n=5):
    """
    Classify the species of the moth using the regional model.

    Args:
        image_tensor: Batch passed straight to ``regional_model``. Only the first
            item's predictions are used (presumably a single-crop batch).
        regional_model: Callable returning class scores with classes on dim 1
            (softmax is taken over ``dim=1``; assumed to be raw logits).
        regional_category_map: Mapping of species label -> class index.
        top_n: Number of highest-scoring species to return. Values <= 0 return
            two empty lists.

    Returns:
        ``(top_n_labels, top_n_scores)``: species labels and their softmax
        probabilities, highest probability first.
    """
    if top_n <= 0:
        # Guard: argsort()[-0:] would select *every* class rather than none.
        return [], []

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        output = regional_model(image_tensor)
    predictions = torch.nn.functional.softmax(output, dim=1).detach().cpu().numpy()[0]

    # Indices of the top_n scores, highest first.
    top_n_indices = predictions.argsort()[-top_n:][::-1]

    # Map indices to labels and fetch their confidence scores.
    index_to_label = {index: label for label, index in regional_category_map.items()}
    top_n_labels = [index_to_label[idx] for idx in top_n_indices]
    top_n_scores = [predictions[idx] for idx in top_n_indices]

    return top_n_labels, top_n_scores

In [8]:
# Candidate test frames.
image_path1, image_path2, image_path3 = (
    '/home/users/katriona/amber-inferences/data/uk/downloaded_images/20240624232150-snapshot.jpg',
    # image with blur
    '/home/users/katriona/amber-inferences/data/uk/downloaded_images/20240626230030-snapshot.jpg',
    # nice obvious moth
    '/home/users/katriona/amber-inferences/data/uk/downloaded_images/20240803003019-snapshot.jpg',
)

# Frame used by the rest of the notebook.
image_path = image_path2

In [9]:
bucket_name = 'gbr'
flatbug_model_path = './models/flat_bug_M.pt'
binary_model_path = './models/moth-nonmoth-effv2b3_20220506_061527_30.pth'
order_model_path = './models/dhc_best_128.pth'
order_labels_path = './models/thresholdsTestTrain.csv'
regional_model_path = './models/turing-uk_v03_resnet50_2024-05-13-10-03_state.pt'
regional_category_map_path = './models/03_uk_data_category_map.json'
localisation_model_path = '/home/users/katriona/amber-inferences/models/v1_localizmodel_2021-08-17-12-06.pt'
# NOTE(review): plain value on purpose (a trailing comma here would turn it into
# a 1-tuple). Not used in the cells shown and looks like a copy-paste
# placeholder -- confirm what it should actually hold before relying on it.
order_data_thresholds = image_path
# Use the first GPU when present, otherwise fall back to CPU so the notebook
# still runs on machines without CUDA.
proc_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
csv_file = './examples/gb_examples.csv'
save_crops = True
box_threshold = 0.995
top_n = 5

In [None]:
# Load every model used below. Later cells index the result with keys such as
# 'flatbug_model', 'classification_model', 'order_model', 'species_model' and
# 'localisation_model'.
models = load_models(
    proc_device,
    binary_model_path,
    order_model_path,
    order_labels_path,
    regional_model_path,
    regional_category_map_path,
    flatbug_model_path,
    localisation_model_path,
)

In [11]:
# Preprocessing for the crop classifiers: resize to 300x300, convert to a
# tensor, then normalise each channel with mean 0.5 / std 0.5.
transform_species = transforms.Compose([
    transforms.Resize((300, 300)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
])

In [None]:
# Recover the capture time from the filename. Names used here look like
# "20240624232150-snapshot.jpg" (timestamp first). Otherwise the timestamp is
# taken from the second "-"-separated field (format assumed -- confirm).
image_dt = os.path.basename(image_path)
image_dt = image_dt.split("-")[0] if image_dt.startswith('20') else image_dt.split("-")[1]

# Only ask strptime for %f when there really are sub-second digits. On a bare
# 14-digit stamp, "%S%f" backtracks: %S grabs one digit and %f the last one,
# so "20240626230030" would parse as 23:00:03 instead of 23:00:30.
timestamp_format = "%Y%m%d%H%M%S%f" if len(image_dt) > 14 else "%Y%m%d%H%M%S"
image_dt = datetime.strptime(image_dt, timestamp_format).strftime("%Y-%m-%d %H:%M:%S")

current_dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

print(image_dt)

In [None]:
# Load the frame as RGB and display it. Every later cell needs `image`, so a
# failure is reported and then re-raised; swallowing it would leave `image`
# undefined, or silently stale from a previous run of this notebook.
try:
    image = Image.open(image_path).convert("RGB")
except OSError as e:
    # Covers missing files and PIL.UnidentifiedImageError (a subclass of OSError).
    print(f"Error opening image {image_path}: {e}")
    raise
else:
    # print the image
    plt.imshow(image)



In [14]:
# Untouched copy of the frame; the detection loops crop from this.
original_image = image.copy()
original_width, original_height = original_image.size

Plot the Flatbug detections, with the model predictions for each crop.

# Example with Flatbug

In [15]:
# Run Flatbug detection on the frame (the helper is given the file path, not the PIL image).
flatbug_outputs = flatbug(image_path, models["flatbug_model"])

In [16]:
# for each detection
image2 = image.copy()

for i in range(len(flatbug_outputs["boxes"])):
    crop_status = "crop " + str(i)

    x_min, y_min, x_max, y_max = flatbug_outputs["boxes"][i]

    box_score = flatbug_outputs["scores"][i]
    box_label = flatbug_outputs["labels"][i]

    x_min = x_min #* original_width / 300)
    y_min = y_min #* original_height / 300)
    x_max = x_max #* original_width / 300)
    y_max = y_max #* original_height / 300)

    # Crop the detected region and perform classification
    cropped_image = original_image.crop((x_min, y_min, x_max, y_max))
    cropped_tensor = transform_species(cropped_image).unsqueeze(0).to(proc_device)

    class_name, class_confidence = classify_box(cropped_tensor, models['classification_model'])
    order_name, order_confidence = classify_order(
        cropped_tensor, models['order_model'], models['order_model_labels'], models['order_model_thresholds']
    )



    if (class_name == "moth") and ('Lepidoptera' not in order_name):
        col = 'orange'
    elif (class_name != "moth") and ('Lepidoptera' in order_name):
        col = 'purple'
    elif (class_name == "moth") and ('Lepidoptera' in order_name):
            col = 'green'
    else:
        col = 'red'


    # annotate the image with bounding boxes
    draw = ImageDraw.Draw(image2)
    draw.rectangle([x_min, y_min, x_max, y_max], outline=col, width=4)

        # Annotate image with bounding box and class
    if (class_name == "moth") or ("Lepidoptera" in order_name):
        species_names, species_confidences = classify_species(
            cropped_tensor, models['species_model'],
            models['species_model_labels'], top_n
        )
        # annotate the name on the image
        draw.text((x_min, y_min-50),
                f"{species_names[0]}: {species_confidences[0]:.2f}",
                fill=col, font=ImageFont.truetype("DejaVuSans.ttf", 50),
                verticalalignment="bottom")

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(10, 10))
ax.imshow(image2)
ax.axis('off')

# Colour legend written above the frame (with imshow's default origin='upper',
# negative y lies above the top edge). One row every `adj` units.
adj = 100
shift = -20
legend_entries = [
    ('red: non-moth, non-Lepidoptera', 'red'),
    ('orange: moth but not Lepidoptera', 'orange'),
    ('purple: Lepidoptera but not moth', 'purple'),
    ('green: moth and Lepidoptera', 'green'),
]
for row, (legend_text, legend_colour) in enumerate(legend_entries):
    ax.text(0, shift - row * adj, legend_text, color=legend_colour, size=8)
plt.show()

# Example with Localisation Model

In [19]:
# Whole-frame preprocessing for the localisation model: resize to 300x300,
# convert to a tensor, then apply the ImageNet per-channel mean/std.
transform_loc = transforms.Compose([
    transforms.Resize((300, 300)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

In [20]:
# Whole-frame input for the localisation model, with a leading batch dimension.
input_tensor = transform_loc(image)[None, ...].to(proc_device)

# Inference only: skip autograd bookkeeping.
with torch.no_grad():
    localisation_model = models['localisation_model']
    localisation_outputs = localisation_model(input_tensor)

In [None]:
# Raw number of boxes from the localisation model (no score threshold applied yet).
len(localisation_outputs[0]['boxes'])

In [None]:
# Print the raw box count explicitly (shown even if more statements follow in this cell).
n_loc_boxes = len(localisation_outputs[0]["boxes"])
print(n_loc_boxes)

In [36]:
# Looser score cut-off for the localisation example (overrides the 0.995 set in the config cell).
box_threshold: float = 0.8

In [None]:
# for each detection
image3 = image.copy()

for i in range(len(localisation_outputs[0]["boxes"])):
    crop_status = "crop " + str(i)

    x_min, y_min, x_max, y_max = localisation_outputs[0]["boxes"][i]
    box_score = localisation_outputs[0]["scores"].tolist()[i]
    box_label = localisation_outputs[0]["labels"].tolist()[i]

    x_min = int(int(x_min) * original_width / 300)
    y_min = int(int(y_min) * original_height / 300)
    x_max = int(int(x_max) * original_width / 300)
    y_max = int(int(y_max) * original_height / 300)

    if box_score < box_threshold:
                continue

    # Crop the detected region and perform classification
    cropped_image = original_image.crop((x_min, y_min, x_max, y_max))
    cropped_tensor = transform_species(cropped_image).unsqueeze(0).to(proc_device)

    class_name, class_confidence = classify_box(cropped_tensor, models['classification_model'])
    order_name, order_confidence = classify_order(
        cropped_tensor, models['order_model'], models['order_model_labels'], models['order_model_thresholds']
    )

    if (class_name == "moth") and ('Lepidoptera' not in order_name):
        col = 'orange'
    elif (class_name != "moth") and ('Lepidoptera' in order_name):
        col = 'purple'
    elif (class_name == "moth") and ('Lepidoptera' in order_name):
            col = 'green'
    else:
        col = 'red'


    # annotate the image with bounding boxes
    draw = ImageDraw.Draw(image3)
    draw.rectangle([x_min, y_min, x_max, y_max], outline=col, width=4)

        # Annotate image with bounding box and class
    if (class_name == "moth") or ("Lepidoptera" in order_name):
        species_names, species_confidences = classify_species(
            cropped_tensor, models['species_model'],
            models['species_model_labels'], top_n
        )
        # annotate the name on the image
        draw.text((x_min, y_min-50),
                f"{species_names[0]}: {species_confidences[0]:.2f}",
                fill=col, font=ImageFont.truetype("DejaVuSans.ttf", 50),
                verticalalignment="bottom")

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(10, 10))
ax.imshow(image3)
ax.axis('off')

# Colour legend written above the frame (with imshow's default origin='upper',
# negative y lies above the top edge). One row every `adj` units.
adj = 100
shift = -20
legend_entries = [
    ('red: non-moth, non-Lepidoptera', 'red'),
    ('orange: moth but not Lepidoptera', 'orange'),
    ('purple: Lepidoptera but not moth', 'purple'),
    ('green: moth and Lepidoptera', 'green'),
]
for row, (legend_text, legend_colour) in enumerate(legend_entries):
    ax.text(0, shift - row * adj, legend_text, color=legend_colour, size=8)
plt.show()

# Let's compare them side by side

In [None]:
# Compare the two detectors on the same frame: localisation model on top,
# Flatbug below. Each panel is titled so the figure reads on its own. Box
# colours follow the same legend as the individual plots.
fig, ax = plt.subplots(2, 1, figsize=(10, 15))
panels = [
    (image3, 'Localisation model'),
    (image2, 'Flatbug'),
]
for panel_ax, (panel_image, panel_title) in zip(ax, panels):
    panel_ax.imshow(panel_image)
    panel_ax.set_title(panel_title)
    panel_ax.axis('off')
plt.show()