In [None]:
import os, json, logging
from time import time
from datetime import datetime

import mlflow
import numpy as np
from tqdm import tqdm
import tifffile as tiff
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt

from flame import CAREInferenceSession, FLAMEImage
from flame.utils import _compress_dict_fields
from flame.io import find_dataset_config, flame_paths_from_ids
import flame.eval as eval
from flame.error import FLAMEEvalError, CAREInferenceError

In [None]:
# --- Data locations ---------------------------------------------------------
# Root directory under which the raw FLAME images are looked up.
FLAMEImage_ROOT_DIR = "/mnt/d/data/raw"
# Dataset JSON configs (and the raw image index) live in ./datasets relative
# to the directory the notebook is executed from.
DATASET_JSON_DIREC = os.path.join(os.getcwd(), "datasets")
FLAMEImage_INDEX_PATH = os.path.join(DATASET_JSON_DIREC, "raw_image_index.csv")

# --- Evaluation settings ----------------------------------------------------
# ID of the dataset config whose test split is evaluated.
DATASET_ID = "0x0003"
# Each entry must be the name of an attribute of `flame.eval`.
METRICS = ["mse", "mae", "ssim"]
# NOTE(review): presumably the frame count of the low-frame input fed to the
# models -- confirm against how the runs were trained.
FRAMES_LOW = 5
# Upper bound of the frame range used as ground truth.
FRAMES_GT = 40

# --- MLflow -----------------------------------------------------------------
TRACKING_URI = "http://127.0.0.1:5050"
# One inference session is created per run id listed here.
MLFLOW_RUN_IDS = [
    "f6f35ad93a6a4c2b9a1a99ac7dea4094",
    "bf9a43f3ec154c9ba2deb6de2fb0db33",
]

In [None]:
# Notebook-wide logger; every record is written to a timestamped file in ./logs.
logger = logging.getLogger("MAIN")

# basicConfig opens the log file immediately, so the target directory must
# already exist -- create it up front instead of failing with FileNotFoundError
# on a fresh checkout.
LOG_DIR = os.path.join(os.getcwd(), "logs")
os.makedirs(LOG_DIR, exist_ok=True)

# NOTE: basicConfig does nothing if the root logger already has handlers, so
# re-running this cell in the same kernel keeps writing to the first log file.
logging.basicConfig(
    filename=os.path.join(LOG_DIR, f"{datetime.now().strftime('%Y%m%d-%H%M%S')}_logger.log"),
    encoding="utf-8",
    level=logging.DEBUG
)

In [None]:
# Fail fast when any of the required input locations is missing.
assert os.path.isdir(FLAMEImage_ROOT_DIR), \
    f"Could not find FLAMEImage root directory at {FLAMEImage_ROOT_DIR}"
assert os.path.isdir(DATASET_JSON_DIREC), \
    f"Could not find the dataset directory at {DATASET_JSON_DIREC}"
assert os.path.isfile(FLAMEImage_INDEX_PATH), \
    f"Could not find FLAMEImage index at {FLAMEImage_INDEX_PATH}"

# Every requested metric has to resolve to an attribute of `flame.eval`;
# an unknown name is logged and reported as a FLAMEEvalError.
for metric in METRICS:
    try:
        getattr(eval, metric)
    except AttributeError:
        unknown_metric_msg = f"Could not find {metric} among available evaluation metrics."
        logger.error(unknown_metric_msg)
        raise FLAMEEvalError(unknown_metric_msg)

In [None]:
# Look up the dataset config matching DATASET_ID and pull out the IDs of the
# images that make up its test split.
config_path, config = find_dataset_config(input_direc=DATASET_JSON_DIREC, this_id=DATASET_ID)
test_ids = config["FLAME_Dataset"]["test_ids"]

In [None]:
# Resolve the test-set image IDs to paths using the raw image index.
paths = flame_paths_from_ids(root_dir=FLAMEImage_ROOT_DIR, index_path=FLAMEImage_INDEX_PATH, id_list=test_ids)

In [None]:
# Record how many test-set images were resolved.
logger.info(
    f"Found {len(paths)} FLAME Images from {DATASET_ID} test set in {FLAMEImage_ROOT_DIR}"
)

### Loading FLAMEImages into memory

In [None]:
logger.info("Loading FLAMEImages into memory...")
# Build one FLAMEImage per resolved path, with a progress bar over the paths.
flame_image_objects = [
    FLAMEImage(impath=image_path, jsonext='tileData.txt')
    for image_path in tqdm(paths, total=len(paths), ascii=True)
]

### Loading MLflow Models

In [None]:
# Point the MLflow client at the tracking server that holds the runs.
mlflow.set_tracking_uri(TRACKING_URI)

In [None]:
# Create one CAREInferenceSession per configured MLflow run. Any failure is
# logged with its traceback and re-raised as a CAREInferenceError.
inference_engines = []
for rdx, RUN_ID in enumerate(MLFLOW_RUN_IDS):
    progress_msg = f"Evaluating run {rdx+1} / {len(MLFLOW_RUN_IDS)}..."
    logger.info(progress_msg)
    print(progress_msg)

    try:
        engine = CAREInferenceSession.from_mlflow_uri(tracking_uri=TRACKING_URI, run_id=RUN_ID)
    except Exception as e:
        failure_msg = f"Could not initialize CAREInferenceSession from MLFlow run id {RUN_ID}.\n{e.__class__.__name__}: {e}"
        logger.exception(failure_msg)
        raise CAREInferenceError(failure_msg)
    else:
        inference_engines.append(engine)
    


### Inference and Metrics

In [None]:
# TODO: Input frames should probably be dynamically sourced from the engine's
# config (as the engine was trained with a certain number of frames).

# Long-format results: one row per (image, run, metric) combination.
df_dict = {
    "image": [],
    "metric": [],
    "value": [],
    "run-name": []
}

for flame_im in tqdm(
        iterable=flame_image_objects,
        total=len(flame_image_objects),
        ascii=True
    ):

    flame_im.openImage()
    try:
        for engine in inference_engines:
            # Label used both in error messages and in the results table.
            run_name = engine.mlflow_run_name if engine.from_mlflow else hex(id(engine))

            try:
                this_pred = engine.predict_FLAME(
                    image=flame_im,
                    # Use the configured constant rather than a magic number.
                    input_frames=FRAMES_LOW
                )
            except Exception as e:
                failure_msg = f"Could not infer on {flame_im} with Inference Session {run_name}"
                # logger.exception keeps the traceback of the underlying failure,
                # matching how engine-loading failures are logged.
                logger.exception(failure_msg)
                raise CAREInferenceError(failure_msg) from e

            # Ground truth is cast to the prediction's dtype so both arrays are
            # compared in the same numeric type.
            this_GT = flame_im.get_frames((0, FRAMES_GT)).astype(this_pred.dtype)

            for metric in METRICS:
                metric_fn = getattr(eval, metric)
                if metric == "ssim":
                    # The [0, ...] indexing below drops the leading axis, so the
                    # channel position shifts down by one.
                    # NOTE(review): assumes axes_shape lists the axes in the
                    # same order as the arrays -- confirm.
                    channel_index = flame_im.axes_shape.index("C")
                    value = metric_fn(this_pred[0,...], this_GT[0,...], channel_axis=channel_index-1)
                else:
                    value = metric_fn(this_pred[0,...], this_GT[0,...])

                # Append all four fields together, only after the metric has
                # been computed, so the columns can never get out of sync.
                df_dict["image"].append(flame_im.impath)
                df_dict["metric"].append(metric)
                df_dict["value"].append(value)
                df_dict["run-name"].append(run_name)
    finally:
        # Always release the image, even when inference or a metric fails.
        flame_im.closeImage()


In [None]:
# Turn the collected column lists into a long-format DataFrame and preview it.
df = pd.DataFrame(df_dict)
df.head()

In [None]:
# Min-max normalise every metric except SSIM to the 0-1 range, pooling the
# values of all runs and images for that metric. SSIM is left untouched.
for metric in METRICS:
    if metric == "ssim":
        continue

    metric_rows = df["metric"] == metric
    values = df.loc[metric_rows, "value"].to_numpy(dtype=float)
    if values.size == 0:
        # No rows for this metric: nothing to rescale (np.min would raise).
        continue

    value_min = values.min()
    value_range = values.max() - value_min
    if value_range == 0:
        # All values identical (e.g. one image and one run): dividing would
        # turn the whole column into NaN, so map them to 0 instead.
        df.loc[metric_rows, "value"] = np.zeros_like(values)
    else:
        df.loc[metric_rows, "value"] = (values - value_min) / value_range

In [None]:
# Preview the table again, now that the normalisation cell has run.
df.head(n=5)

In [None]:
# Grouped box plot: one group per metric, one box per run within each group.
plt.style.use("ggplot")
axes = sns.boxplot(x="metric", y="value", hue="run-name", data=df)
# Anchor the legend just outside the right edge of the plotting area.
axes.legend(bbox_to_anchor=(1.01, 1))
axes.set_ylabel("Relative Score (0-1 norm)")
axes.set_xlabel("Metric")
axes.set_title(f"Prediction vs. Ground Truth Performance Comparison in\nTest Set from Dataset id{DATASET_ID} ({len(flame_image_objects)} images)")