In [16]:
import os, json, logging
from time import time
from datetime import datetime

import mlflow
import numpy as np
from tqdm import tqdm
import tifffile as tiff
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt

from flame import CAREInferenceSession, FLAMEImage
from flame.utils import _compress_dict_fields
from flame.io import find_dataset_config, flame_paths_from_ids
import flame.eval as eval
from flame.error import FLAMEEvalError, CAREInferenceError

In [None]:
# --- Paths -----------------------------------------------------------------------------------
# Root of the raw FLAME image tree. Overridable via the FLAME_ROOT_DIR environment variable so the
# notebook is not tied to one machine's mount point; the default keeps the original location.
FLAMEImage_ROOT_DIR = os.environ.get("FLAME_ROOT_DIR", "/mnt/d/data/raw")
# Dataset JSON configs and the raw image index are looked up under ./datasets relative to the
# kernel's current working directory.
DATASET_JSON_DIREC = os.path.join(os.getcwd(), "datasets")
FLAMEImage_INDEX_PATH = os.path.join(DATASET_JSON_DIREC, "raw_image_index.csv")

# --- Evaluation setup ------------------------------------------------------------------------
# ID of the dataset config whose test split is evaluated.
DATASET_ID = "0x0003"
# Names of functions looked up on flame.eval; each is called with (image, ground_truth).
METRICS = [
    "mse",
    "mae",
    "ssim"
]
# NOTE(review): FRAMES_LOW / FRAMES_GT are not referenced by any visible cell; presumably the
# number of frames used for the input and GT images respectively — confirm.
FRAMES_LOW = 5
FRAMES_GT = 40

# --- MLflow ----------------------------------------------------------------------------------
TRACKING_URI = "http://127.0.0.1:5050"
# Runs whose trained models are loaded and evaluated on the test split.
MLFLOW_RUN_IDS = [
    "f6f35ad93a6a4c2b9a1a99ac7dea4094",
    "bf9a43f3ec154c9ba2deb6de2fb0db33"
]

In [9]:
logger = logging.getLogger("MAIN")

# The FileHandler created by basicConfig does not create missing directories, so on a fresh
# checkout without ./logs this cell would raise FileNotFoundError. Create it first.
LOG_DIR = os.path.join(os.getcwd(), "logs")
os.makedirs(LOG_DIR, exist_ok=True)

# One timestamped log file per kernel session. Per the logging docs, basicConfig does nothing if
# the root logger already has handlers (e.g. when this cell is re-run).
logging.basicConfig(
    filename=os.path.join(LOG_DIR, f"{datetime.now().strftime('%Y%m%d-%H%M%S')}_logger.log"),
    encoding="utf-8",
    level=logging.DEBUG
)

In [None]:
# Fail fast on missing inputs before any expensive image loading or model download happens.
assert os.path.isdir(FLAMEImage_ROOT_DIR), f"Could not find FLAMEImage root directory at {FLAMEImage_ROOT_DIR}"
assert os.path.isdir(DATASET_JSON_DIREC), f"Could not find the dataset directory at {DATASET_JSON_DIREC}"
assert os.path.isfile(FLAMEImage_INDEX_PATH), f"Could not find FLAMEImage index at {FLAMEImage_INDEX_PATH}"

# Every entry in METRICS is later resolved with getattr(eval, metric) and called, so verify all of
# them up front and report every unknown name in a single error instead of stopping at the first.
unknown_metrics = [metric for metric in METRICS if not callable(getattr(eval, metric, None))]
if unknown_metrics:
    msg = f"Could not find {', '.join(unknown_metrics)} among available evaluation metrics."
    logger.error(msg)
    raise FLAMEEvalError(msg)

In [4]:
# Locate the JSON config for DATASET_ID and keep the IDs of its held-out test images.
config_path, config = find_dataset_config(
    this_id=DATASET_ID,
    input_direc=DATASET_JSON_DIREC,
)
test_ids = config["FLAME_Dataset"]["test_ids"]

In [5]:
# Map the test IDs to on-disk FLAME image paths using the raw image index CSV.
paths = flame_paths_from_ids(
    id_list=test_ids,
    index_path=FLAMEImage_INDEX_PATH,
    root_dir=FLAMEImage_ROOT_DIR,
)

In [10]:
# Record how many test-set images were resolved under the root directory.
logger.info(
    "Found %d FLAME Images from %s test set in %s",
    len(paths),
    DATASET_ID,
    FLAMEImage_ROOT_DIR,
)

### Loading FLAMEImages into memory

In [None]:
# Eagerly build a FLAMEImage for every resolved test path.
# NOTE(review): `images` is not read by any later visible cell — confirm it is still needed,
# since it holds every test image in memory.
logger.info("Loading FLAMEImages into memory...")
images = []
for image_path in tqdm(paths, total=len(paths), ascii=True):
    # jsonext presumably names the per-image metadata sidecar file — confirm.
    images.append(FLAMEImage(impath=image_path, jsonext="tileData.txt"))

 86%|########6 | 19/22 [00:52<00:04,  1.66s/it]

### Loading MLflow models

In [14]:
# Point the MLflow client at the tracking server configured in the settings cell.
mlflow.set_tracking_uri(TRACKING_URI)

In [None]:
# Build one CAREInferenceSession per MLflow run, appended in MLFLOW_RUN_IDS order.
# NOTE(review): `engine` stays bound to the last session after this loop, and later cells may
# rely on that leftover name.
inference_engines = []
for run_idx, run_id in enumerate(MLFLOW_RUN_IDS, start=1):
    progress_msg = f"Evaluating run {run_idx} / {len(MLFLOW_RUN_IDS)}..."
    logger.info(progress_msg)
    print(progress_msg)

    try:
        engine = CAREInferenceSession.from_mlflow_uri(
            tracking_uri=TRACKING_URI,
            run_id=run_id,
        )
    except Exception as e:
        msg = f"Could not initialize CAREInferenceSession from MLFlow run id {run_id}.\n{e.__class__.__name__}: {e}"
        logger.exception(msg)
        # Chain the original error so it appears as the direct cause instead of
        # "another exception occurred while handling".
        raise CAREInferenceError(msg) from e
    inference_engines.append(engine)
    


Evaluating run 1 / 2...


Downloading artifacts:   0%|          | 0/8 [00:00<?, ?it/s]

Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]

Evaluating run 2 / 2...


Downloading artifacts:   0%|          | 0/8 [00:00<?, ?it/s]

Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]

### Running inference and computing metrics

In [None]:
# NOTE(review): `low_paths` and `GT_paths` are not defined in any earlier cell, so on a fresh kernel
# this cell raises NameError. They presumably need to be derived from `paths` / `images`
# (possibly using FRAMES_LOW / FRAMES_GT) — TODO define them before this cell.
assert len(low_paths) == len(GT_paths), (
    f"Found {len(low_paths)} input paths but {len(GT_paths)} GT paths"
)
assert len(inference_engines) == len(MLFLOW_RUN_IDS), (
    f"Expected {len(MLFLOW_RUN_IDS)} inference engines, found {len(inference_engines)}"
)

# Resolve each metric function once rather than on every image.
metric_fns = {metric: getattr(eval, metric) for metric in METRICS}

# Input-vs-GT metrics do not depend on the model, so they are computed once per image pair.
input_metrics = {x: [] for x in METRICS}
# Pred-vs-GT metrics are kept per MLflow run so every loaded engine gets evaluated.
eval_metrics_by_run = {run_id: {x: [] for x in METRICS} for run_id in MLFLOW_RUN_IDS}

for low_path, gt_path in tqdm(
        iterable=zip(low_paths, GT_paths),
        total=len(low_paths),
        ascii=True
    ):
    try:
        t1 = time()
        # transpose(0, 2, 3, 1) moves axis 1 to the end; assumes 4-D stacks — TODO confirm layout.
        low = tiff.imread(low_path).transpose(0, 2, 3, 1).astype(np.float32)
        gt = tiff.imread(gt_path).transpose(0, 2, 3, 1).astype(np.float32)
        t2 = time()
        logger.info(f"Loaded 2 images, taking {t2 - t1:.2f}s.")
    except Exception:
        # Best-effort: skip unreadable pairs, but keep the traceback in the log.
        logger.exception(f"Could not load input and/or GT images from {os.path.basename(low_path)} & {os.path.basename(gt_path)}")
        continue

    assert low.shape == gt.shape, f"Input and GT image shapes do not match (found {low.shape} and {gt.shape})"

    # Predict with every engine before recording anything, so a failing prediction cannot leave
    # the metric lists with mismatched lengths.
    preds = {
        run_id: run_engine.predict(low).astype(np.float32)
        for run_id, run_engine in zip(MLFLOW_RUN_IDS, inference_engines)
    }

    for metric, metric_fn in metric_fns.items():
        # Only the first entry along axis 0 is scored.
        input_metrics[metric].append(metric_fn(low[0, ...], gt[0, ...]))
        for run_id, pred in preds.items():
            eval_metrics_by_run[run_id][metric].append(metric_fn(pred[0, ...], gt[0, ...]))

# Later cells consume a single `eval_metrics` dict. Keep it bound to the last run, which is the
# engine this cell previously used implicitly through the leftover `engine` variable.
eval_metrics = (
    eval_metrics_by_run[MLFLOW_RUN_IDS[-1]] if MLFLOW_RUN_IDS else {x: [] for x in METRICS}
)


In [None]:
# Baseline: metrics of the raw input against GT, tagged so they can be plotted next to predictions.
df = pd.DataFrame(data=input_metrics).assign(
    source=lambda frame: ["input vs. gt"] * len(frame)
)

In [None]:
# Model output: metrics of the predictions against GT, tagged with their source for plotting.
eval_df = pd.DataFrame(data=eval_metrics).assign(
    source=lambda frame: ["pred vs. gt"] * len(frame)
)

In [None]:
# Stack baseline and prediction metrics into one long-form frame for plotting.
# Both inputs are indexed 0..n-1, so ignore_index=True avoids duplicate row labels, which some
# pandas/seaborn operations reject.
all_df = pd.concat([df, eval_df], ignore_index=True)

In [None]:
# `MLFLOW_RUN_ID` (singular) is not defined anywhere, so the old cell raised NameError.
# eval_df presumably holds predictions from the last engine loaded from MLFLOW_RUN_IDS, so its
# aggregate metrics are logged back to that run.
EVAL_RUN_ID = MLFLOW_RUN_IDS[-1]
with mlflow.start_run(run_id=EVAL_RUN_ID):
    # NOTE(review): MLflow params are immutable. If this run already logged one of these keys with
    # a different value, log_params will raise — confirm against what the training run logged.
    mlflow.log_params(_compress_dict_fields(config))
    for metric in METRICS:
        # One strip plot per metric comparing the input baseline with the predictions.
        sns.catplot(data=all_df, x="source", y=metric)
        mlflow.log_metric(f"ds{DATASET_ID}_test_{metric}", np.mean(eval_df[metric]))
    