In [None]:
#
# Notebook: 2_inference_examples.ipynb
#
from PIL import ImageDraw, ImageFont

import ibbi

# --- Shared drawing helpers ---
def _load_label_font(size=80):
    """Return a font for box captions, preferring Arial at *size* points.

    Falls back to Pillow's bundled default font. Newer Pillow releases can
    scale that default font to *size*; older releases (or builds without
    FreeType) only offer a small fixed-size bitmap font.
    """
    try:
        return ImageFont.truetype("arial.ttf", size)
    except OSError:
        pass  # Arial is not installed (common on Linux/macOS); try a fallback.
    try:
        return ImageFont.load_default(size=size)
    except (TypeError, ImportError, OSError):
        # TypeError: this Pillow's load_default() has no ``size`` argument.
        # ImportError/OSError: the default font cannot be scaled in this build.
        return ImageFont.load_default()


def _draw_detections(image, detections, color, font, box_width=20, pad=5):
    """Return a copy of *image* with every detection drawn as a captioned box.

    Args:
        image: PIL image to annotate; it is not modified.
        detections: Mapping with parallel ``"scores"``, ``"labels"`` and
            ``"boxes"`` sequences, as produced by the ``predict`` calls below.
            Each box is assumed to be ``(x0, y0, x1, y1)`` in pixel
            coordinates -- TODO confirm against the ibbi model docs.
        color: Outline and caption-background colour (any PIL colour spec).
        font: Font used for the ``"label: score"`` caption.
        box_width: Outline thickness in pixels.
        pad: Extra pixels between the caption background and the box top.

    Returns:
        A new RGB image with the boxes and captions drawn on it.
    """
    # convert() always returns a new image; forcing RGB makes the named colours
    # render in colour even when the source image is greyscale or palette-based.
    annotated = image.convert("RGB")
    draw = ImageDraw.Draw(annotated)
    for score, label, box in zip(detections["scores"], detections["labels"], detections["boxes"]):
        # float() accepts plain numbers as well as numpy / torch scalars.
        x0, y0, x1, y1 = (float(v) for v in box)
        text = f"{label}: {float(score):.2f}"
        _, _, text_w, text_h = font.getbbox(text)

        draw.rectangle((x0, y0, x1, y1), outline=color, width=box_width)
        # Put the caption just above the box, but clamp it to the top edge so a
        # box touching the image border still gets a visible label.
        label_y = max(y0 - text_h - pad, 0)
        draw.rectangle((x0, label_y, x0 + text_w, label_y + text_h + pad), fill=color)
        draw.text((x0, label_y), text, fill="white", font=font)
    return annotated


# --- 1. Load the Test Dataset ---
print("Loading the test dataset...")
test_dataset = ibbi.get_dataset()
print("Dataset loaded.")

# Select an image from the dataset to work with.
# NOTE(review): index 1010 assumes the dataset has more than 1010 samples.
image_to_explain = test_dataset[1010]["image"]
print("\nAn image from the test dataset will be used for inference.")

# One caption font shared by all three visualizations below.
font = _load_label_font(80)


# --- 2. Feature Extraction ---
print("\n--- Feature Extraction ---")
feature_extractor = ibbi.create_model("feature_extractor", pretrained=True)
features = feature_extractor.extract_features(image_to_explain)
print("Feature tensor shape:", features.shape)


# --- 3. Single-Class Object Detection ---
print("\n--- Single-Class Detection (with Visualization) ---")
beetle_detector = ibbi.create_model("beetle_detector", pretrained=True)
results = beetle_detector.predict(image_to_explain)

image_with_boxes_single = _draw_detections(image_to_explain, results, "darkgreen", font)
image_with_boxes_single.show()
print("Detection results:", results)


# --- 4. Multi-Class Species Classification ---
print("\n--- Multi-Class Species Classification (with Visualization) ---")
species_classifier = ibbi.create_model("species_classifier", pretrained=True)
species_results = species_classifier.predict(image_to_explain)

image_with_boxes_multi = _draw_detections(image_to_explain, species_results, "darkblue", font)
image_with_boxes_multi.show()
print("Classification results:", species_results)


# --- 5. Zero-Shot Object Detection ---
print("\n--- Zero-Shot Detection (with Visualization) ---")
zero_shot_detector = ibbi.create_model("zero_shot_detector", pretrained=True)

# Period-separated phrases in the prompt.
prompt = "insect . circle"
zero_shot_results = zero_shot_detector.predict(
    image_to_explain,
    text_prompt=prompt,
)

image_with_boxes_zero_shot = _draw_detections(image_to_explain, zero_shot_results, "darkred", font)
image_with_boxes_zero_shot.show()
print("Zero-shot results:", zero_shot_results)