In [1]:
import json

import pandas as pd
import numpy as np
import os

from hydra import compose, initialize
from omegaconf import OmegaConf
from sklearn.metrics import accuracy_score, balanced_accuracy_score
from tqdm import tqdm
from PIL import Image
import torch
from pathlib import Path
from PIL import ImageFile

from closedset_model import build_model
from competition_metrics import evaluate
from datasets import get_valid_transform
from paths import METADATA_DIR, VAL_DATA_DIR
from utils import copy_config, get_device

np.set_printoptions(precision=5)
ImageFile.LOAD_TRUNCATED_IMAGES = True

class PytorchWorker:
    """Run inference using PyTorch."""

    def __init__(self, model_path: str, number_of_categories: int = 1784, model_id="efficientnet_b0", device="cpu", transforms=None):

        ########################################
        # must be set before calling _load_model
        self.number_of_categories = number_of_categories
        self.model_id = model_id
        self.device = device
        ########################################

        self.transforms = transforms
        # most other attributes must be set before calling _load_model, so call last
        self.model = self._load_model(model_path)

    def _load_model(self, model_path):
        print("Setting up Pytorch Model")
        # model = models.efficientnet_b0()
        # model.classifier[1] = nn.Linear(in_features=1280, out_features=self.number_of_categories)
        model = build_model(
            model_id=self.model_id,
            pretrained=False,
            fine_tune=False,
            num_classes=self.number_of_categories,
            # this is all that matters. everything else will be overwritten by checkpoint state
            dropout_rate=0.5,
        ).to(self.device)
        model_ckpt = torch.load(model_path, map_location=self.device)
        model.load_state_dict(model_ckpt['model_state_dict'])

        return model.to(self.device).eval()

    # def predict_image(self, image: np.ndarray) -> list():
    #     """Run inference using ONNX runtime.

    #     :param image: Input image as numpy array.
    #     :return: A list with logits and confidences.
    #     """

    #     img = self.transforms(image)
        
    #     if isinstance(img, tuple):
    #         img = torch.cat([instance.unsqueeze(0) for instance in img])
    #         img = torch.unique(img, dim=0)

    #     if img.dim() < 4:
    #         img = img.unsqueeze(0)
        
    #     img = img.to(self.device)
        
    #     logits = self.model(img)

    #     return logits

    def prepare_for_tensor_concat(self, image):
        """
        Transforms a single image to a single image tensor of a tensor of multiple crops. 
        """
        
        img = self.transforms(image)
        
        if isinstance(img, tuple):
            img = torch.cat([instance.unsqueeze(0) for instance in img])
            img = torch.unique(img, dim=0)
    
        if img.dim() < 4:
            img = img.unsqueeze(0)
        
        return img
        

    def predict_image(self, image_batch: np.ndarray) -> list():
        """Run inference using ONNX runtime.
    
        :param image: Input image as numpy array.
        :return: A list with logits and confidences.
        """

        with torch.no_grad():
            logits = self.model(image_batch.to(self.device))
    
        return logits


def get_probas(test_metadata, model_id, model_path, images_root_path, batch_size, device, transforms):
    """Make submission file"""

    model = PytorchWorker(model_path, model_id=model_id, device=device, transforms=transforms)

    probas_total = []
    image_paths = test_metadata["image_path"]

    def order_preserved_unique(arr):
        # print(arr)
        _, idx = np.unique(arr, return_index=True)
        # print(idx)
        return arr[np.sort(idx)]

    with torch.no_grad():
        batch = []
        batch_labels = []
        for i, image_path in enumerate(tqdm(image_paths)):
            image_path = os.path.join(images_root_path, image_path)
            test_image = Image.open(image_path).convert("RGB")
            transformed_image = model.prepare_for_tensor_concat(test_image)
            batch.append(transformed_image)
            batch_labels.extend([i]*transformed_image.shape[0])
            if i + 1 % batch_size == 0 or i == len(image_paths) - 1:
                # only run model for a full batch or if the end of image_paths is reached
                batch_labels = np.array(batch_labels)
                image_batch = torch.cat(batch)
                logits = model.predict_image(image_batch)
                probas = torch.nn.functional.softmax(logits, dim=-1).cpu().numpy()
                for label in order_preserved_unique(batch_labels):
                    label_probas = probas[batch_labels == label]
                    # print(label_probas)
                    if label_probas.shape[0] > 1:
                        label_probas = np.mean(label_probas, axis=0)
                    label_probas = label_probas.squeeze()
                    probas_total.append(label_probas)
                    # TODO: handle case where absolute batch size is exceeded, in which case add it to the next batch
                batch = []
                batch_labels = []
    
    return probas_total


def evaluate_experiment(cfgs, multi_instance=False, device="cpu", multicrop=False, batch_size=1, debug=False):

    submission_file_path = "batched-test-time-augmengations-submission.csv"

    metadata_file_path = METADATA_DIR / "SnakeCLEF2023-ValMetadata.csv"
    test_metadata = pd.read_csv(metadata_file_path)
    if debug:
        test_metadata = test_metadata.head(20)
    if not multi_instance:
        test_metadata.drop_duplicates("observation_id", keep="first", inplace=True)
    
    probas_per_model = []
    for cfg in cfgs:
        experiment_id = cfg["experiment_id"]
        if debug: print(f"getting probas for experiment {experiment_id}")
        model_id = cfg["model_id"]
        image_size = cfg["image_size"]
        transforms = get_valid_transform(image_size=image_size, pretrained=True, fivecrop=multicrop)
        experiment_dir = Path("model_checkpoints") / experiment_id
        predictions_output_csv_path = str(experiment_dir / "submission.csv")
        model_file = "model.pth"
        model_path = str(experiment_dir / model_file)
        probas = get_probas(
            model_id=model_id,
            test_metadata=test_metadata,
            model_path=model_path,
            images_root_path=VAL_DATA_DIR,
            batch_size=batch_size,
            device=device,
            transforms=transforms,
        )
        probas_per_model.append(probas)
    probas_per_model = np.array(probas_per_model)
    if debug: print("probas_per_model.shape", probas_per_model.shape)
    if len(cfgs) > 1:
        averaged_probas = np.mean(probas_per_model, axis=0)
    else:
        averaged_probas = probas_per_model.squeeze()
    if debug: print("averaged_probas.shape", averaged_probas.shape)
    # if debug: print("np.argmax(averaged_probas)", np.argmax(averaged_probas))

    if multi_instance:
        preds = []
        # pandas unique preserves order
        for obs_id in test_metadata["observation_id"].unique():
            indices = list(test_metadata["observation_id"].loc[lambda x: x==obs_id].index)
            if len(indices) > 1:
                if debug: print("indices", indices)
                observation_probas = averaged_probas[indices, :]
                observation_average = np.mean(averaged_probas[indices], axis=0)
                if debug: print("observation_average.shape", observation_average.shape)
                preds.extend([np.argmax(observation_average)] * len(indices))
            else:
                preds.append(np.argmax(averaged_probas[indices], axis=1)[0])
    else:
        preds = np.argmax(averaged_probas, axis=1)

    if debug:
        print("preds", preds)
        if isinstance(preds, list):
            preds = np.array(preds)
        print("preds.shape", preds.shape)
    
    submission_df = test_metadata.copy()
    submission_df["class_id"] = preds
    submission_df = submission_df[["observation_id", "class_id"]]
    submission_df.drop_duplicates("observation_id", keep="first", inplace=True)
    submission_df.to_csv(submission_file_path, index=False)

    competition_metrics_scores = evaluate(
        test_annotation_file=metadata_file_path,
        user_submission_file=submission_file_path,
        phase_codename="prediction-based",
    )["submission_result"]

    return competition_metrics_scores

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
cfgs = [
    {"model_id": "caformer_s18.sail_in22k_ft_in1k_384",
     "experiment_id": "2024-05-05 22:41:41.115323",
    "image_size": 768,},]
device = get_device()
scores_three_ensemble_multi_multicrop = evaluate_experiment(cfgs=cfgs, multi_instance=False, multicrop=False, device=device, 
                                                            batch_size=1, debug=False)

Using device: cuda:0
Setting up Pytorch Model
Not loading pre-trained weights
Freezing hidden layers...


 17%|█▋        | 1334/7816 [00:43<04:03, 26.67it/s]

In [4]:
cfgs = [
    {"model_id": "caformer_s18.sail_in22k_ft_in1k_384",
     "experiment_id": "2024-05-05 22:41:41.115323",
    "image_size": 768,},]
device = get_device()
scores_three_ensemble_multi_multicrop = evaluate_experiment(cfgs=cfgs, multi_instance=False, multicrop=False, device=device, 
                                                            batch_size=2, debug=False)

Using device: cuda:0
Setting up Pytorch Model
Not loading pre-trained weights
Freezing hidden layers...


 27%|██▋       | 2115/7816 [01:08<03:04, 30.82it/s]


KeyboardInterrupt: 

In [None]:
cfgs = [
    {"model_id": "caformer_s18.sail_in22k_ft_in1k_384",
     "experiment_id": "2024-05-05 22:41:41.115323",
    "image_size": 768,},]
device = get_device()
scores_three_ensemble_multi_multicrop = evaluate_experiment(cfgs=cfgs, multi_instance=False, multicrop=False, device=device, 
                                                            batch_size=40, debug=False)

Using device: cuda:0
Setting up Pytorch Model
Not loading pre-trained weights
Freezing hidden layers...


100%|█████████▉| 7814/7816 [04:38<00:00, 33.98it/s]

In [6]:
cfgs = [
    {"model_id": "caformer_s18.sail_in22k_ft_in1k_384",
     "experiment_id": "2024-05-05 22:41:41.115323",
    "image_size": 768,},]
device = get_device()
scores_three_ensemble_multi_multicrop = evaluate_experiment(cfgs=cfgs, multi_instance=False, multicrop=False, device="cpu", 
                                                            batch_size=1, debug=False)

Using device: cuda:0
Setting up Pytorch Model
Not loading pre-trained weights
Freezing hidden layers...


 27%|██▋       | 2096/7816 [01:10<03:11, 29.90it/s]


KeyboardInterrupt: 