In [None]:
import os

from conabio_ml_vision.datasets.datasets import ImageDataset, ImagePredictionDataset
from conabio_ml_vision.trainer.model import run_megadetector_inference
from conabio_ml.utils.dataset_utils import read_labelmap_file

# Root of the shared volume; every other path in this cell is derived from it.
BASE_PATH = '/shared_volume/ecoinf_tests/kale_aws/'

# Parent directories shared by several of the paths below.
_files_dir = os.path.join(BASE_PATH, "files")
_data_dir = os.path.join(BASE_PATH, "data")

# Input files
snmb_json = os.path.join(_files_dir, "snmb_2021_detection-bboxes.json")
mappings_csv = os.path.join(_files_dir, "snmb_to_wcs_compet.csv")
compet_labelmap_file = os.path.join(_files_dir, "compet_labels.txt")
detector_model_path = os.path.join(_files_dir, "megadetector_v4.pb")

# Image data
snmb_images_dir = os.path.join(_data_dir, "snmb")
snmb_crops_dir = os.path.join(_data_dir, "snmb_crops_megadetector")

# Pipeline outputs
results_path = os.path.join(BASE_PATH, "results", "pipeline_1_TF1")
dataset_csv = os.path.join(results_path, "dataset.csv")
dets_md_csv = os.path.join(results_path, "detections_megadet.csv")

# The results directory must exist before any cell writes a CSV into it;
# exist_ok=True keeps this cell safe to re-run.
os.makedirs(results_path, exist_ok=True)

# Minimum score threshold forwarded to run_megadetector_inference.
min_score_threshold = 0.3

In [None]:
# Dataset creation
# The dataset is only built when its CSV is not on disk yet, so re-running
# this cell after a successful run does nothing.
if not os.path.isfile(dataset_csv):
    compet_labelmap = read_labelmap_file(compet_labelmap_file)
    # Arguments for ImageDataset.from_json, collected in one place so they
    # can be inspected before the dataset is built.
    from_json_options = dict(
        source_path=snmb_json,
        images_dir=snmb_images_dir,
        categories=list(compet_labelmap.values()),
        exclude_categories=['empty'],
        mapping_classes=mappings_csv,
        not_exist_ok=True,
    )
    dataset = ImageDataset.from_json(**from_json_options)
    # Persist the result; the inference cell reads it back from this file.
    dataset.to_csv(dataset_csv)

In [None]:
# Megadetector inference
# The dataset is always reloaded from the CSV because the in-memory `dataset`
# is only created by the previous cell when the CSV was missing.
dataset = ImageDataset.from_csv(dataset_csv, images_dir=snmb_images_dir)

# NOTE(review): unlike dataset creation, this step is not guarded by a check
# on dets_md_csv, so it is invoked on every run -- confirm whether
# run_megadetector_inference skips existing results on its own.
megadetector_options = dict(
    dataset=dataset,
    out_predictions_csv=dets_md_csv,
    images_dir=snmb_images_dir,
    model_path=detector_model_path,
    min_score_threshold=min_score_threshold,
    include_id=True,
    keep_image_id=True,
    dataset_partition=None,
    num_gpus_per_node=1,
)
run_megadetector_inference(**megadetector_options)
