In [None]:
import timm
import torch
import torchvision
import torch.utils.data
import torchvision.transforms
import datetime, time
import json
from typing import Union
import pathlib
import urllib
import cv2
from PIL import Image
from pathlib import Path
import aiohttp
import asyncio
import numpy as np
from tqdm import tqdm
import pandas as pd

# A filesystem location given either as a plain string or as a pathlib.Path.
FilePath = Union[str, pathlib.Path]

def slugify(s):
    """
    Make an acceptable attribute name or URL part from a title.

    Lower-cases `s`, turns spaces into underscores, drops every character that
    is not an ASCII letter, digit or underscore, strips leading/trailing
    underscores, and replaces each non-overlapping "__" with "_".

    Install python-slugify for handling unicode chars, numbers at the
    beginning, etc.

    >>> slugify("UK & Denmark Species Classifier")
    'uk_denmark_species_classifier'
    """
    # `string` is not imported at module level; without this the function
    # raised NameError on every call.
    import string

    separator = "_"
    acceptable_chars = set(string.ascii_letters + string.digits + separator)
    kept = "".join(
        char
        for char in s.replace(" ", separator).lower()
        if char in acceptable_chars
    )
    return kept.strip(separator).replace(separator * 2, separator)

def get_or_download_file(
    path, destination_dir=None, prefix=None, suffix=None
) -> pathlib.Path:
    """
    Return a local copy of the file at `path`, downloading it once if needed.

    The file is stored as ``<destination_dir>/<prefix>/<last path component>``,
    with its suffix replaced by `suffix` when given. If that file already
    exists it is reused without any network access.

    Args:
        path: URL to fetch (passed to ``urllib.request.urlretrieve``).
        destination_dir: directory to store the file in; falls back to the
            ``LOCAL_WEIGHTS_PATH`` environment variable.
        prefix: optional sub-directory of `destination_dir`.
        suffix: optional file suffix to force on the local file name.

    Returns:
        Path to the local file.

    Raises:
        Exception: if `path` is empty or no destination directory is configured.

    >>> get_or_download_file("https://example.com/w.pth", "/tmp/data", prefix="models")  # doctest: +SKIP
    PosixPath('/tmp/data/models/w.pth')
    """
    # The module header imports neither `os` nor the `urllib.request`
    # submodule (a bare `import urllib` does not load it).
    import os
    import urllib.request

    if not path:
        raise Exception("Specify a URL or path to fetch file from.")

    destination_dir = destination_dir or os.environ.get("LOCAL_WEIGHTS_PATH")
    if not destination_dir:
        raise Exception(
            "No destination directory specified by LOCAL_WEIGHTS_PATH or app settings."
        )

    destination_dir = pathlib.Path(destination_dir)
    if prefix:
        destination_dir = destination_dir / prefix
    if not destination_dir.exists():
        logger.info(f"Creating local directory {str(destination_dir)}")
        destination_dir.mkdir(parents=True, exist_ok=True)

    fname = path.rsplit("/", 1)[-1]
    local_filepath = destination_dir / fname
    if suffix:
        local_filepath = local_filepath.with_suffix(suffix)

    if local_filepath.exists():
        logger.info(f"Using existing {local_filepath}")
        return local_filepath

    logger.info(f"Downloading {path} to {local_filepath}")
    # Download to a temporary name and move into place only on success, so an
    # interrupted download is never mistaken for a cached copy later.
    partial_filepath = local_filepath.with_name(local_filepath.name + ".part")
    try:
        urllib.request.urlretrieve(url=path, filename=partial_filepath)
        partial_filepath.replace(local_filepath)
    except BaseException:
        partial_filepath.unlink(missing_ok=True)
        raise
    logger.info(f"Downloaded to {local_filepath}")
    return local_filepath

def get_device(device_str=None) -> torch.device:
    """
    Resolve the torch device to run inference on.

    Args:
        device_str: explicit device spec such as "cpu" or "cuda:0". When it is
            falsy, CUDA is chosen if available, otherwise the CPU.

    Returns:
        The corresponding ``torch.device``.

    @TODO add macOS Metal?
    @TODO check Kivy settings to see if user forced use of CPU
    """
    chosen = device_str or ("cuda" if torch.cuda.is_available() else "cpu")
    device = torch.device(chosen)
    logger.info(f"Using device '{device}' for inference")
    return device

def synchronize_clocks():
    """Wait for queued CUDA work to finish; does nothing when CUDA is unavailable."""
    if not torch.cuda.is_available():
        return
    torch.cuda.synchronize()

class StopWatch:
    """
    Measure elapsed (inference) time, synchronising CUDA when available.

    `start` and `end` hold wall-clock timestamps, used for the readable repr.
    `duration` is measured with the monotonic ``time.perf_counter`` so it is
    unaffected by system clock adjustments.

    >>> with StopWatch() as t:  # doctest: +SKIP
    ...     time.sleep(1)
    >>> round(t.duration)  # doctest: +SKIP
    1
    """

    def __enter__(self):
        synchronize_clocks()
        self.start = time.time()
        self._perf_start = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        synchronize_clocks()
        self.end = time.time()
        self.duration = time.perf_counter() - self._perf_start
        # Returning None lets any exception from the with-block propagate.

    def __repr__(self):
        start = datetime.datetime.fromtimestamp(self.start).strftime("%H:%M:%S")
        end = datetime.datetime.fromtimestamp(self.end).strftime("%H:%M:%S")
        # Two decimals: int() used to truncate every sub-second batch to 0.
        return f"Started: {start}, Ended: {end}, Duration: {self.duration:.2f} seconds"


#KBE??? 
class Logger:
    """
    Minimal print-based stand-in for a ``logging.Logger``.

    Provides the methods this module calls: ``info``, ``warning`` (plus the
    ``warn`` alias), and ``debug``, which is deliberately silent.
    """

    def info(self, text):
        print("Info:", text)

    def warning(self, text):
        print("Warning:", text)

    # InferenceBaseClass calls `logger.warn(...)`; without this the call
    # raised AttributeError. Mirrors logging.Logger's (deprecated) alias.
    warn = warning

    def debug(self, text):
        # Debug output is silenced; use print("Debug:", text) to enable it.
        return


logger = Logger()
#KBE??? 

class BatchEmptyException(Exception):
    """Exception type for an empty batch (not raised in the code visible here)."""


def zero_okay_collate(batch):
    """
    Collate a batch, returning None instead of crashing on missing items.

    If the queue is cleared or shortened before the original batch count is
    complete, the dataloader would otherwise crash. This catches an empty
    batch, or one containing None placeholders, more gracefully.

    Returns:
        None for an empty batch or one containing a None item; otherwise the
        result of ``torch.utils.data.default_collate(batch)``.

    @TODO switch to streaming IterableDataset type.
    """
    if not batch:
        # default_collate indexes batch[0] and would raise IndexError.
        logger.debug("The batch is empty")
        return None
    # `item is None` rather than `not item`: truth-testing a multi-element
    # tensor raises RuntimeError, and falsy-but-valid items (e.g. 0) must pass.
    if any(item is None for item in batch):
        logger.debug(f"There's a None in the batch of len {len(batch)}")
        return None
    return torch.utils.data.default_collate(batch)


# ImageNet channel statistics ("torch preprocessing"), RGB order.
imagenet_normalization = torchvision.transforms.Normalize(
    std=[0.229, 0.224, 0.225],
    mean=[0.485, 0.456, 0.406],
)

# Maps inputs in [0, 1] to [-1, 1] per RGB channel.
tensorflow_normalization = torchvision.transforms.Normalize(
    std=[0.5] * 3,
    mean=[0.5] * 3,
)

# NOTE(review): originally labelled "0 to 1", but these values are identical
# to tensorflow_normalization and map [0, 1] inputs to [-1, 1]. An identity
# for [0, 1] data would be mean=0, std=1 — confirm which was intended.
generic_normalization = torchvision.transforms.Normalize(
    std=[0.5] * 3,
    mean=[0.5] * 3,
)


class InferenceBaseClass:
    """
    Base class for all batch-inference models.

    This outlines a common interface for all classifiers and object detectors.
    Generic methods like `get_weights` and `get_labels` are defined here, but
    methods that raise NotImplementedError must be overridden in a subclass
    that is specific to each inference model.

    Construction resolves the device, loads the label map and weights, builds
    the transforms and dataloader, and finally calls `get_model()`.

    See examples in `classification.py` and `localization.py`
    """

    #KBE??? db_path: Union[str, sqlalchemy.engine.URL]
    image_base_path: FilePath
    name = "Unknown Inference Model"
    description = str()
    model_type = None
    device = None
    weights_path = None
    weights = None
    labels_path = None
    category_map = {}
    num_classes: Union[int, None] = None  # Will use len(category_map) if None
    lookup_gbif_names: bool = False
    model: torch.nn.Module
    normalization = tensorflow_normalization
    transforms: torchvision.transforms.Compose
    batch_size = 4
    num_workers = 1
    user_data_path = None
    type = "unknown"
    stage = 0
    single = True  # True: load data in the main process (num_workers=0)
    #KBE??? queue: QueueManager
    dataset: torch.utils.data.Dataset
    dataloader: torch.utils.data.DataLoader

    def __init__(
        self,
        #KBE??? db_path: Union[str, sqlalchemy.engine.URL],
        user_data_path: FilePath,
        image_base_path: FilePath,
        **kwargs,
    ):
        """
        Args:
            user_data_path: base directory; weights and label files are cached
                in its "models" sub-directory.
            image_base_path: stored on the instance as-is.
            **kwargs: set as instance attributes before anything else runs, so
                any class attribute (batch_size, device, weights_path, ...) can
                be overridden per instance.
        """
        #KBE??? self.db_path = db_path
        self.user_data_path = user_data_path
        self.image_base_path = image_base_path

        for k, v in kwargs.items():
            setattr(self, k, v)

        logger.info(f"Initializing inference class {self.name}")

        self.device = self.device or get_device()
        self.category_map = self.get_labels(self.labels_path)
        self.num_classes = self.num_classes or len(self.category_map)
        self.weights = self.get_weights(self.weights_path)
        self.transforms = self.get_transforms()
        #KBE??? self.queue = self.get_queue()
        #KBE??? self.dataset = self.get_dataset()
        # NOTE(review): with get_dataset() disabled the dataloader wraps None;
        # presumably it cannot be iterated, so run() needs a real dataset.
        self.dataset = None
        self.dataloader = self.get_dataloader()
        logger.info(
            f"Loading {self.type} model (stage: {self.stage}) for {self.name} with {len(self.category_map or [])} categories"
        )
        self.model = self.get_model()

    @classmethod
    def get_key(cls):
        """Return `cls.key` when it is set and truthy, else a slug of `cls.name`."""
        key = getattr(cls, "key", None)  # type: ignore
        return key if key else slugify(cls.name)

    def get_weights(self, weights_path):
        """
        Return the local path of the model weights, downloading them if needed.

        Returns None (after logging a warning) when no weights path is set.
        """
        if weights_path:
            return get_or_download_file(
                weights_path, self.user_data_path, prefix="models"
            )
        logger.warn(f"No weights specified for model {self.name}")
        return None

    def get_labels(self, labels_path):
        """
        Load the JSON label file and return a ``{index: label}`` mapping.

        The file is read as a ``{label: index}`` object and inverted. Returns an
        empty dict when `labels_path` is falsy.
        """
        if not labels_path:
            return {}

        local_path = get_or_download_file(
            labels_path, self.user_data_path, prefix="models"
        )
        with open(local_path, encoding="utf-8") as f:
            labels = json.load(f)

        if self.lookup_gbif_names:
            # Use this if you want to store name strings instead of taxon IDs.
            # Taxon IDs are helpful for looking up additional information about
            # the species, such as the genus and family.
            #KBE??? from trapdata.ml.utils import replace_gbif_id_with_name
            from ml.utils import replace_gbif_id_with_name

            string_labels = {
                replace_gbif_id_with_name(label): index
                for label, index in labels.items()
            }

            logger.info(f"Replacing GBIF IDs with names in {local_path}")
            # Back up the original file (replace() also overwrites an old .bak).
            local_path.replace(local_path.with_suffix(".bak"))
            with open(local_path, "w", encoding="utf-8") as f:
                json.dump(string_labels, f)
            # Previously the names were only written to disk while the returned
            # map was still built from the raw IDs.
            labels = string_labels

        # @TODO would this be faster as a list? especially when getting the
        # labels of multiple indexes in one prediction
        return {index: label for label, index in labels.items()}

    def get_model(self) -> torch.nn.Module:
        """
        This method must be implemented by a subclass.

        Example:

        model = torch.nn.Module()
        checkpoint = torch.load(self.weights, map_location=self.device)
        model.load_state_dict(checkpoint["model_state_dict"])
        model = model.to(self.device)
        model.eval()
        return model
        """
        raise NotImplementedError

    def get_transforms(self) -> torchvision.transforms.Compose:
        """
        This method must be implemented by a subclass.

        Example:

        transforms = torchvision.transforms.Compose(
            [
                torchvision.transforms.ToTensor(),
            ]
        )
        return transforms
        """
        raise NotImplementedError

    # KBE??? get_queue is disabled in this port. The original hook was:
    #
    # def get_queue(self) -> QueueManager:
    #     """
    #     This method must be implemented by a subclass.
    #     Example:
    #
    #     from trapdata.db.models.queue import DetectedObjectQueue
    #     def get_queue(self):
    #         return DetectedObjectQueue(self.db_path, self.image_base_path)
    #     """
    #     raise NotImplementedError

    def get_dataset(self) -> torch.utils.data.Dataset:
        """
        This method must be implemented by a subclass.

        Example:

        dataset = torch.utils.data.Dataset()
        return dataset
        """
        raise NotImplementedError

    def get_dataloader(self):
        """
        Prepare dataloader for streaming/iterable datasets from database.

        Stores the loader on `self.dataloader` and returns it.
        """
        if self.single:
            logger.info(
                f"Preparing dataloader with batch size of {self.batch_size} in single worker mode."
            )
        else:
            logger.info(
                f"Preparing dataloader with batch size of {self.batch_size} and {self.num_workers} workers."
            )
        workers = 0 if self.single else self.num_workers
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            num_workers=workers,
            # DataLoader rejects persistent_workers=True when num_workers == 0.
            persistent_workers=workers > 0,
            shuffle=False,
            pin_memory=not self.single,  # @TODO review this
            batch_size=None,  # Recommended setting for streaming datasets
            batch_sampler=None,  # Recommended setting for streaming datasets
        )
        return self.dataloader

    def predict_batch(self, batch):
        """Move `batch` to `self.device` and return the raw model output."""
        # non_blocking only has an effect for pinned-memory tensors copied to CUDA.
        batch_input = batch.to(self.device, non_blocking=True)
        return self.model(batch_input)

    def post_process_single(self, item):
        """Hook for per-item post-processing; identity by default."""
        return item

    def post_process_batch(self, batch_output):
        """Apply `post_process_single` to every item, returning a list."""
        # A list rather than a generator: the generator form had problems
        # with multiprocessing.
        return [self.post_process_single(item) for item in batch_output]

    def save_results(self, item_ids, batch_output):
        """Hook for persisting results; by default only logs a warning."""
        logger.warn("No save method configured for model. Doing nothing with results")
        return None

    @torch.no_grad()
    def run(self):
        """
        Run inference over every batch from `self.dataloader` with gradients off.

        Each non-empty batch is unpacked as ``(item_ids, batch_input)``; the
        output goes through `post_process_batch` and then `save_results`.
        """
        torch.cuda.empty_cache()

        # Iterable datasets without __len__ make len(dataloader) raise
        # TypeError; in that case log progress without a total.
        try:
            num_batches = len(self.dataloader)
        except TypeError:
            num_batches = None

        for i, batch in enumerate(self.dataloader):
            if not batch:
                # @TODO review this once we switch to streaming IterableDataset
                logger.info(f"Batch {i+1} is empty, skipping")
                continue

            item_ids, batch_input = batch

            if num_batches is None:
                logger.info(f"Processing batch {i+1}")
            else:
                logger.info(
                    f"Processing batch {i+1}, about {num_batches - i - 1} remaining"
                )

            # @TODO the StopWatch doesn't seem to work when there are multiple workers,
            # it always returns 0 seconds.
            with StopWatch() as batch_time:
                #KBE??? with start_transaction(op="inference_batch", name=self.name):
                batch_output = self.predict_batch(batch_input)

            # max(..., 1) avoids ZeroDivisionError on an empty output.
            seconds_per_item = batch_time.duration / max(len(batch_output), 1)
            logger.info(
                f"Inference time for batch: {batch_time}, "
                f"Seconds per item: {round(seconds_per_item, 2)}"
            )

            batch_output = list(self.post_process_batch(batch_output))
            item_ids = item_ids.tolist()
            logger.info(f"Saving {len(item_ids)} results")
            self.save_results(item_ids, batch_output)
            logger.info(f"{self.name} Batch -- Done")

        logger.info(f"{self.name} -- Done")

class Resnet50(torch.nn.Module):
    """ResNet-50 feature extractor followed by global pooling and a bias-free linear head."""

    def __init__(self, num_classes, pretrained=True):
        """
        Args:
            num_classes: number of logits produced by the classifier head.
            pretrained: when True (the original behaviour), initialise the
                backbone with torchvision's default weights, which may trigger
                a download. Pass False when a full checkpoint is loaded
                afterwards anyway.
        """
        super().__init__()
        self.num_classes = num_classes
        resnet = torchvision.models.resnet50(weights="DEFAULT" if pretrained else None)
        out_dim = resnet.fc.in_features

        # Keep everything except torchvision's avgpool and fc; both are
        # re-created below, with the head sized to num_classes.
        self.backbone = torch.nn.Sequential(*list(resnet.children())[:-2])
        self.avgpool = torch.nn.AdaptiveAvgPool2d(output_size=(1, 1))
        self.classifier = torch.nn.Linear(out_dim, self.num_classes, bias=False)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``."""
        x = self.backbone(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return self.classifier(x)
    
class Resnet50Classifier(InferenceBaseClass):
    """Classifier built on `Resnet50`; images are resized to `input_size` x `input_size`."""

    input_size = 300

    def get_model(self):
        """Build `Resnet50`, load the checkpoint at `self.weights`, and return it in eval mode."""
        # Respect an explicitly configured num_classes, as
        # Resnet50ClassifierLowRes does, falling back to the label count.
        num_classes = self.num_classes or len(self.category_map)
        model = Resnet50(num_classes=num_classes)
        model = model.to(self.device)
        checkpoint = torch.load(self.weights, map_location=self.device)
        # The model state dict is nested in some checkpoints, and not in others
        state_dict = checkpoint.get("model_state_dict") or checkpoint
        model.load_state_dict(state_dict)
        model.eval()
        return model

    def get_transforms(self):
        """Resize, convert to a tensor, and apply ImageNet normalisation."""
        mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
        return torchvision.transforms.Compose(
            [
                torchvision.transforms.Resize((self.input_size, self.input_size)),
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize(mean, std),
            ]
        )

    def post_process_batch(self, output):
        """
        Turn model output into a list of ``(label, score, category_index)`` tuples.

        Softmax and argmax are taken over dim/axis 1, so `output` is assumed to
        be ``(batch, num_classes)``; `score` is the top softmax probability.
        """
        # detach() so this also works on tensors produced outside torch.no_grad();
        # .numpy() raises on tensors that require grad.
        predictions = torch.nn.functional.softmax(output.detach(), dim=1)
        predictions = predictions.cpu().numpy()

        categories = predictions.argmax(axis=1)
        labels = [self.category_map[cat] for cat in categories]
        scores = predictions.max(axis=1).astype(float)

        result = list(zip(labels, scores, categories))
        logger.debug(f"Post-processing result batch: {result}")
        return result


class Resnet50ClassifierLowRes(Resnet50Classifier):
    """Variant using torchvision's stock ResNet-50 (with a replaced `fc` head) at 128 px input."""

    input_size = 128

    def get_model(self):
        """
        Build a torchvision ResNet-50 with a `num_classes`-way `fc` head, load the
        checkpoint at `self.weights` onto `self.device`, and return it in eval mode.

        Raises:
            AssertionError: if the class count or the weights path is unknown.
        """
        assert (
            self.num_classes
        ), f"Number of classes could not be determined for {self.name}"
        assert self.weights, f"No weights path configured for {self.name}"

        try:
            model = torchvision.models.resnet50(weights=None)
        except TypeError:
            # torchvision < 0.13 has no `weights` argument; newer releases
            # deprecate `pretrained`, so prefer `weights` when available.
            model = torchvision.models.resnet50(pretrained=False)
        model.fc = torch.nn.Linear(model.fc.in_features, self.num_classes)
        model = model.to(self.device)

        checkpoint = torch.load(self.weights, map_location=self.device)
        # The model state dict is nested in some checkpoints, and not in others
        state_dict = checkpoint.get("model_state_dict") or checkpoint
        model.load_state_dict(state_dict)
        model.eval()
        return model

class SpeciesClassifier(InferenceBaseClass):
    """
    Stage-4 fine-grained (species-level) classifier mixin.

    Combined with a concrete model class such as Resnet50ClassifierLowRes,
    whose `post_process_batch` yields ``(label, score, category_index)`` tuples.
    """

    stage = 4
    type = "fine_grained_classifier"

    #KBE??? def get_queue(self) -> UnclassifiedObjectQueue:
    #KBE???     return UnclassifiedObjectQueue(self.db_path, self.image_base_path)

    def get_dataset(self):
        # NOTE(review): `ClassificationIterableDatabaseDataset` is not defined or
        # imported in the visible source, and `self.queue` is only set if passed
        # as a constructor kwarg (get_queue is disabled), so this presumably fails.
        dataset = ClassificationIterableDatabaseDataset(
            queue=self.queue,
            image_transforms=self.get_transforms(),
            batch_size=self.batch_size,
        )
        return dataset

    def save_results(self, object_ids, batch_output):
        """
        Build the per-object records for the species predictions.

        Persisting them is currently disabled, so nothing is saved.
        """
        # Here we are saving the specific taxon labels.
        # `*_` accepts both (label, score) pairs and the 3-tuples from
        # Resnet50Classifier.post_process_batch; unpacking exactly two values
        # raised ValueError for the latter.
        classified_objects_data = [
            {
                "specific_label": label,
                "specific_label_score": score,
                "model_name": self.name,
                "in_queue": True,  # Put back in queue for the feature extractor & tracking
            }
            for label, score, *_ in batch_output
        ]
        #KBE??? save_classified_objects(self.db_path, object_ids, classified_objects_data)


class UKDenmarkMothSpeciesClassifierMixedResolution(
    SpeciesClassifier, Resnet50ClassifierLowRes
):
    """
    UK & Denmark moth species classifier.

    Takes its model and 128 px preprocessing from Resnet50ClassifierLowRes and
    its stage/type and save hook from SpeciesClassifier.

    Training log and weights:
    https://wandb.ai/moth-ai/uk-denmark/artifacts/model/model/v0/overview

    Species checklist used for training:
    https://github.com/adityajain07/mothAI/blob/main/species_lists/UK-Denmark-Moth-List_11July2022.csv
    """

    name = "UK & Denmark Species Classifier"
    description = "Trained on April 3, 2023 using mix of low & med resolution images."
    weights_path = "https://object-arbutus.cloud.computecanada.ca/ami-models/moths/classification/uk-denmark-moths-mixedres-20230403_140131_30.pth"
    labels_path = "https://object-arbutus.cloud.computecanada.ca/ami-models/moths/classification/01-moths-ukdenmark_v2_category_map_species_names.json"

def classifySpeciesBatch(classifier, batch):
    """
    Run `classifier` on one preprocessed image batch and post-process the output.

    Args:
        classifier: object exposing `predict_batch` and `post_process_batch`
            (e.g. an InferenceBaseClass subclass).
        batch: input tensor passed to `predict_batch`.

    Returns:
        Whatever `classifier.post_process_batch` returns; for Resnet50Classifier
        subclasses, a list of ``(label, score, category_index)`` tuples.
    """
    # Disable autograd: the previous version built a full graph for every batch
    # and only detached afterwards, wasting memory during inference.
    with torch.no_grad():
        predictions = classifier.predict_batch(batch)
    # The unreachable "label,index,confidence" formatting that followed the
    # return in the old version has been removed.
    return classifier.post_process_batch(predictions)

In [None]:
# Instantiate the UK & Denmark classifier; weights and labels are cached in
# the "models" sub-directory of user_data_path.
classifier = UKDenmarkMothSpeciesClassifierMixedResolution(
    user_data_path="/home/george/tmp/models",
    image_base_path="",
)

In [None]:
# Directory of images to classify, one sub-directory per ground-truth label.
img_dir = Path("/home/george/codes/lepinet") / "data" / "flemming_ucloud" / "images"

In [None]:
# Sort so the image order, and hence instance_id in the exported CSV, is
# reproducible: Path.glob yields files in a filesystem-dependent order.
img_filenames = sorted(img_dir.glob("*/*.jpg"))
img_filenames[:10]

In [None]:
# Make a prediction on all images, batch by batch.
batch_size = 64
# Use the classifier's own preprocessing (resize to its input_size, ToTensor
# and ImageNet normalisation). The previous cv2 pipeline skipped the
# normalisation the model was loaded with.
preprocess = classifier.transforms
img_size = classifier.input_size

preds = []  # was commented out, so the first `preds +=` raised NameError

for start in tqdm(range(0, len(img_filenames), batch_size)):
    batch_paths = img_filenames[start:start + batch_size]
    images = []
    for path in batch_paths:
        # convert("RGB") also handles greyscale / RGBA files.
        with Image.open(path) as img:
            images.append(preprocess(img.convert("RGB")))
    batch = torch.stack(images)
    preds += classifySpeciesBatch(classifier, batch)

In [None]:
# Sanity check: the last value should be 0 (one prediction per image).
(len(preds), len(img_filenames), len(img_filenames) - len(preds))

In [None]:
# Eyeball the first predictions next to their source files.
(preds[:10], img_filenames[:10])

In [None]:
async def get_key(session, scientificName=None, usageKey=None, rank='SPECIES', order='Lepidoptera'):
    """
    Look up a GBIF species key with the GBIF species-match API.

    Args:
        session: an open aiohttp client session.
        scientificName: name to match.
        usageKey: GBIF usage key to match.
        rank: rank filter; omitted from the query when None.
        order: order filter; omitted from the query when None.

    Returns:
        The ``speciesKey`` from the response if present, otherwise the full
        decoded JSON response (so unmatched names can be inspected).
    """
    assert usageKey is not None or scientificName is not None, "One of scientificName or usageKey must be defined."
    if scientificName == 'Tethea or':
        # Hard-coded key; presumably the match API resolves this name
        # incorrectly — TODO confirm.
        return 5142971

    url = "https://api.gbif.org/v1/species/match"
    # Let the client URL-encode the values (names can contain spaces, '&', ...);
    # None-valued filters are left out, as in the old hand-built query string.
    params = {
        key: value
        for key, value in (
            ("usageKey", usageKey),
            ("scientificName", scientificName),
            ("rank", rank),
            ("order", order),
        )
        if value is not None
    }

    async with session.get(url, params=params) as response:
        r = await response.json()
    return r.get('speciesKey', r)


async def get_all_keys(vocab):
    """Concurrently resolve every name in `vocab` to its GBIF key (no rank filter)."""
    async with aiohttp.ClientSession() as session:
        tasks = [get_key(session, scientificName=k, rank=None) for k in vocab]
        return await asyncio.gather(*tasks)

In [None]:
# Distinct predicted labels (first element of each prediction), sorted by np.unique.
unique_names = np.unique([label for label, *_ in preds])

In [None]:
# Allow asyncio.run() below to be called from inside an environment that
# already has a running event loop (presumably Jupyter's).
import nest_asyncio
nest_asyncio.apply()

In [None]:
# Resolve every distinct predicted name to its GBIF key, concurrently.
gbif_keys = asyncio.run(
    get_all_keys(unique_names)
)

In [None]:
# Predicted name (as a plain str) -> GBIF key or unmatched-response dict.
name2key = dict(zip(map(str, unique_names), gbif_keys))

In [None]:
# One independent, empty list per output column.
instance_id, filename, level, label, prediction, confidence, threshold = (
    [] for _ in range(7)
)

In [None]:
# One row per image: ground truth from the parent folder name, the prediction
# mapped to its GBIF key, and fixed level / threshold values.
for idx, path in enumerate(img_filenames):
    instance_id.append(idx)
    filename.append(path)
    level.append(0)
    label.append(path.parent.name)
    prediction.append(name2key[preds[idx][0]])
    confidence.append(float(preds[idx][1]))
    threshold.append(0.0)

In [None]:
# Assemble the per-image results into a table, in this column order.
df = pd.DataFrame(
    dict(
        instance_id=instance_id,
        filename=filename,
        level=level,
        label=label,
        prediction=prediction,
        confidence=confidence,
        threshold=threshold,
    )
)

In [None]:
# Preview the first six rows.
df.iloc[:6]

In [None]:
# Export the predictions table without the pandas index column.
df.to_csv(
    Path("/home/george/codes/lepinet/data/flemming_ucloud") / "ukdenmark_model" / "ukdk.csv",
    index=False,
)