In [1]:
import ast
from datetime import datetime
from pathlib import Path

import pandas as pd
import numpy as np
from PIL import Image
from tqdm.auto import tqdm

from utils.cropping import crop_normalized_bbox, crop_normalized_bbox_square
from utils.predict import predict_batch

In [2]:
from fine_tuning.speciesnet_head.speciesnet_polish_model import get_model
# Cropping helper applied to each MegaDetector bbox before classification
# (presumably produces a square crop around the normalized bbox — see utils.cropping).
crop_function = crop_normalized_bbox_square
# Fine-tuned SpeciesNet head weights loaded by get_model().
checkpoint_path = (
    'fine_tuning/speciesnet_head/'
    'speciesnet_polish_lr4_checkpoint.pt'
)

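SpeciesNet head: classify each MegaDetector animal crop with the fine-tuned SpeciesNet head and save the predictions to CSV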

In [4]:
MODEL_NAME = 'speciesnet_head'
# NOTE(review): CROP_SIZE is not used anywhere in this cell; the crop size is
# decided by crop_function / predict_batch. TODO: wire it through or remove it.
CROP_SIZE = 480  # change and check?
BATCH_SIZE = 50  # number of crops sent to the classifier per predict_batch call
# Top-1 predictions whose confidence is not above this value are labelled 'other'.
OTHER_THRESHOLD = 0.1
# MegaDetector category treated as "animal"; every other category is recorded as 'empty'.
ANIMAL_CATEGORY = 1
RESULT_COLUMNS = ['image', 'detected_animal', 'confidence']

classifier, class_names = get_model(checkpoint_path=checkpoint_path)
classifier.to('cuda')
pd.set_option('display.max_columns', None)
pd.set_option('display.max_colwidth', None)

images = pd.read_csv('megadetector_results.csv', index_col=0)
# bbox is stored as its Python-literal string in the CSV; missing values become None.
images['bbox'] = images['bbox'].apply(
    lambda b: ast.literal_eval(b) if isinstance(b, str) else None)


def _classify_batch(crops, crop_paths):
    """Classify one batch of crops and return one result record per crop.

    Args:
        crops: cropped PIL images, in the same order as ``crop_paths``.
        crop_paths: source image path for each crop.

    Returns:
        A list of dicts with keys 'image', 'detected_animal' and 'confidence'.
        The label is replaced by 'other' when the top-1 confidence is not
        above OTHER_THRESHOLD.

    Raises:
        ValueError: if predict_batch returns a different number of
            predictions than crops supplied.
    """
    # Assumes predict_batch returns, per crop, a ranked list whose first entry
    # is a (label, confidence) pair. TODO: confirm against utils.predict.
    preds = predict_batch(classifier, crops, class_names)
    if len(preds) != len(crop_paths):
        raise ValueError(
            f'predict_batch returned {len(preds)} predictions for {len(crop_paths)} crops')
    records = []
    for path, prediction in zip(crop_paths, preds):
        label, conf = prediction[0][0], prediction[0][1]
        records.append({
            'image': path,
            'detected_animal': label if conf > OTHER_THRESHOLD else 'other',
            'confidence': conf,
        })
    return records


# Rows are collected as plain dicts and turned into a DataFrame once at the end,
# avoiding quadratic row-by-row growth. Row order matches the original:
# non-animal rows are added immediately, and animal rows are added when their batch is flushed.
records = []
batch = []
paths = []
failed = []  # (image_path, error repr) for images that could not be loaded or cropped

for _, row in tqdm(images.iterrows(), total=len(images)):
    image_path = row['image_path']

    # Only animal detections are classified.
    if row['category'] != ANIMAL_CATEGORY:
        records.append({'image': image_path, 'detected_animal': 'empty', 'confidence': 0})
        continue

    try:
        with Image.open(image_path) as img:
            cropped_image = crop_function(img.convert('RGB'), row['bbox'])
    except Exception as e:
        # Best-effort: skip unreadable images or bad bboxes, but keep a record of them.
        failed.append((image_path, repr(e)))
        continue

    paths.append(image_path)
    batch.append(cropped_image)

    # Run the classifier once every BATCH_SIZE crops.
    if len(batch) == BATCH_SIZE:
        records.extend(_classify_batch(batch, paths))
        batch = []
        paths = []

# Flush the final, partially filled batch.
if batch:
    records.extend(_classify_batch(batch, paths))

results = pd.DataFrame(records, columns=RESULT_COLUMNS)

if failed:
    print(f'{len(failed)} image(s) could not be loaded/cropped and were skipped; '
          'see the `failed` list for details.')

now = datetime.now().strftime('%Y_%m_%d_%H_%M')
# Create the output directory so a long run does not fail at the very last step.
out_dir = Path('results') / MODEL_NAME
out_dir.mkdir(parents=True, exist_ok=True)
results.to_csv(out_dir / f'results_{MODEL_NAME}_{now}.csv')

  0%|          | 0/18107 [00:00<?, ?it/s]