In [None]:
!pip install opencv-python

In [None]:
import logging
import random
from functools import partial
from pathlib import Path
from types import SimpleNamespace
import urllib.request

import cv2
import numpy as np
import torch
from fastai.vision.all import CategoryMap, load_learner

import torch
import torchvision
import torch.utils.data
import torchvision.transforms
import datetime, time
import json
from typing import Union
import pathlib
import urllib
import cv2
from PIL import Image
from pathlib import Path
import aiohttp
import asyncio
import numpy as np
from tqdm import tqdm
import pandas as pd


# URL template for the models shared on ERDA; "{}" is replaced by the model file name.
ERDA_MODELS = "https://anon.erda.au.dk/share_redirect/C1nJdS1jtA/{}"
# File name of the model used when no explicit model path is given.
DEFAULT_MODEL = "00_eulepi.pkl"
# Template for the local model path; a bare "{}" means the file lives in the current working directory.
MODEL_LOCAL_PATH = "{}"

def gen_level_idx(vocab, hierarchy):
    """
    Returns an array of integers of the size of vocab indicating the hierarchical level of the taxon at index i.
    - Levels are counted upwards from the deepest level present in `vocab`: with a
      family -> genus -> species hierarchy, species is level 0, genus 1, family 2, etc.
    - Taxa not found in `hierarchy` are noted with -1.

    Args:
    - vocab (list): List of taxa names to find levels for.
    - hierarchy (dict): Nested dictionary representing taxonomic hierarchy. Values are
      either nested dicts or lists of leaf (species) names.

    Returns:
    - np.ndarray: Integer array of level indices for each taxon in vocab
      (an empty integer array when vocab is empty).
    """
    level_lookup = {}

    def traverse(node, depth=0):
        """Recursively record the depth (top level = 0) of every taxon under `node`."""
        for key, subnode in node.items():
            level_lookup[key] = depth  # Assign depth to the taxon
            if isinstance(subnode, dict):
                traverse(subnode, depth + 1)
            elif isinstance(subnode, list):  # Leaf nodes (species level)
                for species in subnode:
                    level_lookup[species] = depth + 1

    # Depths are measured from the top of the hierarchy (0); they are inverted below.
    traverse(hierarchy)

    # Assign depths to vocab, default to -1 if missing
    indices = np.array([level_lookup.get(v, -1) for v in vocab], dtype=int)

    # Nothing to invert (and `.max()` would raise on an empty array).
    if indices.size == 0:
        return indices

    # Invert the depths so the deepest level present (species) is 0, genus is 1, etc.
    # Missing taxa keep their -1 marker.
    indices = np.where(indices < 0, indices, indices.max() - indices)

    # Warning for missing values
    missing_count = np.sum(indices == -1)
    if missing_count > 0:
        print(f"[Warning] Missing values in taxa dictionary: {missing_count}.")

    return indices

def get_pred_conf(preds:torch.Tensor, vocab:"CategoryMap", indices:np.ndarray):
    """Returns predicted labels and confidence for each pred and for each
    hierarchy level.

    Args:
    - preds: a batch of predictions, one row per item and one column per vocab
      entry (a torch tensor, or anything `np.asarray` accepts).
    - vocab: sequence of class names aligned with the columns of `preds`.
    - indices: hierarchy level of each vocab entry (as produced by
      `gen_level_idx`); columns marked -1 are ignored.

    Returns:
    - (labels, confidences): two arrays of shape (n_items, n_levels). Column i
      holds the arg-max label (resp. its score) among the vocab entries of level i.
    """
    # Move predictions to host memory once, instead of once per level.
    if isinstance(preds, torch.Tensor):
        scores = preds.detach().cpu().numpy()
    else:
        scores = np.asarray(preds)
    # Plain numpy arrays so boolean masks work for any sequence type (list, CategoryMap, ...).
    names = np.asarray(list(vocab))
    levels = np.asarray(indices)

    n_levels = int(levels.max()) + 1 if levels.size else 0
    if n_levels <= 0:
        # No known level: return empty (n_items, 0) tables instead of crashing in swapaxes.
        n_items = scores.shape[0]
        return (np.empty((n_items, 0), dtype=names.dtype),
                np.empty((n_items, 0), dtype=scores.dtype))

    out_preds = []
    out_confs = []
    for level in range(n_levels):
        mask = levels == level
        level_scores = scores[:, mask]
        out_preds.append(names[mask][level_scores.argmax(axis=1)])
        out_confs.append(level_scores.max(axis=1))
    return np.array(out_preds).swapaxes(0, 1), np.array(out_confs).swapaxes(0, 1)

class FastaiSpeciesClassifier:
    valid_type = ['all', 'species', 'best']
    f"""
    Prediciton class for the species classifier trained with fastai.

    Args:
    - speciesModelPath (str): Path to the species model, if None download it from a ERDA link.
    - device (str): Device to run the computations. Either cpu or cuda.
    - output_type (str): Type of the output: one of {valid_type}. 'all' outputs a list of the best model prediction per hierarchy level, in the following order: species, genus, family. 'species' only outputs the species level. 'best' only outputs the lowest ranked predictions with a confidence above `th`.
    - th (float): Confidence threshold.
    """
    def __init__(self, speciesModelPath:str=None, device='cuda', output_type: str='species', th: float=0.5):
        assert output_type in self.valid_type, f"Error: `output_type` must be one of {self.valid_type} but found {output_type}"

        self.log = logging.getLogger(__name__)

        # Download model from ERDA if not found locally
        if speciesModelPath is None:
            speciesModelPath = MODEL_LOCAL_PATH.format(DEFAULT_MODEL)
        if not Path(speciesModelPath).is_file() or\
            not Path(DEFAULT_MODEL).exists():
            self.log.info(f"Model not found. Downloading from {ERDA_MODELS.format(DEFAULT_MODEL)}")
            url = ERDA_MODELS.format(DEFAULT_MODEL)
            with urllib.request.urlopen(url) as response, open(speciesModelPath, 'wb') as out_file:
                out_file.write(response.read())
        else:
            print(f"Found model in {speciesModelPath}")

        self.log.info("Moth species model path %s", speciesModelPath)
        # Load fastai Learner instead of previous speciesClassifier
        self.speciesLearner = load_learner(speciesModelPath, cpu=(device == 'cpu'))
        self.speciesLearner.model.eval()
        self.id2name = self.speciesLearner.id2name

        indices = gen_level_idx(
            self.speciesLearner.dls.vocab,
            self.speciesLearner.hierarchy)
    
        self.get_pred_conf = partial(
            get_pred_conf, 
            vocab=self.speciesLearner.dls.vocab,
            indices=indices,)
    
        self.output_type = output_type
        self.th = th

    def extractCrop(self, image, bbox):
        if image is None:
            print("Error image cannot be none")
            raise Exception("None image in extract crop")
        x1 = bbox.x1
        x2 = bbox.x2
        y1 = bbox.y1
        y2 = bbox.y2
        image_crop = image[y1:y2, x1:x2]
        return image_crop

    def batchFromDetections(self, image, detections):
        if image is None:
            print("Image must not be None")
            raise Exception("None image not allowed in batchFromDetections")
        print(f"Batching {len(detections)} detections")
        print(f"Batching {detections} detections")
        imagesInBatch = []
        detection_ids = []
        for idx in range(len(detections)):
            print(f"Adding: {detections[idx]}")
            bbox = detections[idx].bbox
            im = self.extractCrop(image, bbox)
            imagesInBatch.append(np.array(im))
            detection_ids.append(detections[idx].id)
        return { "imagesInBatch": imagesInBatch, "detection_ids": detection_ids }

    def classifySpeciesBatch(self, batch):
        detection_ids = batch["detection_ids"]
        images = batch["imagesInBatch"]

        # Create fastai test dataloader
        test_dl = self.speciesLearner.dls.test_dl(images)

        # Inference without progress bar or logging
        # with self.speciesLearner.no_bar(), self.speciesLearner.no_logging():
        preds, _ = self.speciesLearner.get_preds(dl=test_dl)
        
        # Get predictions classes and confidence
        prds, cnfs = self.get_pred_conf(preds)

        results = []
        for idx, (prd, cnf) in enumerate(zip(prds, cnfs)):
            if self.output_type=='species':
                results.append({
                    "id": detection_ids[idx],
                    "label": self.id2name[prd[0]],
                    "labelId": prd[0],
                    "confidence_value": cnf[0]
                })
            elif self.output_type=='best':
                i=0 # Index of when the cnf is above 0.5
                while i < len(cnf) and cnf[i] < self.th: i += 1
                if i == len(cnf): # No prediction, outputs highest level
                    results.append({
                        "id": detection_ids[idx],
                        "label": self.id2name[prd[-1]],
                        "labelId": prd[-1],
                        "confidence_value": cnf[-1]
                    })
                else:
                    results.append({
                        "id": detection_ids[idx],
                        "label": self.id2name[prd[i]],
                        "labelId": prd[i],
                        "confidence_value": cnf[i]
                    })
            elif self.output_type=='all':
                results.append({
                    "id": detection_ids[idx],
                    "label": [self.id2name[p] for p in prd],
                    "labelId": prd,
                    "confidence_value": cnf
                })
            else:
                raise NotImplementedError("Choose a valid output type.")

        return results

In [None]:
# Expected layout: <img_dir>/<label folder>/<image>.jpg (the folder name is used as the ground-truth label).
img_dir = Path("/home/george/codes/lepinet/data/flemming_ucloud/images")
# Sort so instance ids (list positions) are stable across runs; glob order is filesystem-dependent.
img_filenames = sorted(img_dir.glob('*/*.jpg'))
img_filenames[:10]

In [None]:
# Load the local fastai model. `th` is only consulted by output_type='best';
# the default output_type here is 'species'.
classifier = FastaiSpeciesClassifier(
    speciesModelPath="/home/george/codes/lepinet/data/lepi/models/00_eulepi.pkl",
    th=0.0,
)

In [None]:
# Classify every image file in a single batch; detection ids are the positions in img_filenames.
preds = classifier.classifySpeciesBatch(
    {
        "detection_ids": list(range(len(img_filenames))),
        "imagesInBatch": img_filenames,
    }
)

In [None]:
# One entry per image: ground truth is the parent folder name, the prediction is the
# labelId returned by the classifier (level 0 = species with output_type='species').
instance_id = list(range(len(img_filenames)))
filename = list(img_filenames)
level = [0 for _ in instance_id]
label = [path.parent.name for path in img_filenames]
prediction = [str(preds[k]['labelId']) for k in instance_id]
confidence = [float(preds[k]['confidence_value']) for k in instance_id]
threshold = [0.0 for _ in instance_id]

In [None]:
# Assemble the per-image columns into a results table (column order is preserved in the CSV).
df = pd.DataFrame(dict(
    instance_id=instance_id,
    filename=filename,
    level=level,
    label=label,
    prediction=prediction,
    confidence=confidence,
    threshold=threshold,
))

In [None]:
# Quick look at the first rows of the results table.
df.head(n=5)

In [None]:
# Export the results; create the output folder first so a fresh checkout does not fail here.
out_csv = Path("/home/george/codes/lepinet/data/flemming_ucloud/old_fastai/fastai.csv")
out_csv.parent.mkdir(parents=True, exist_ok=True)
df.to_csv(out_csv, index=False)

In [None]:
# Number of classified images (rows) in the results table.
df.shape[0]