In [None]:
import os, json, logging
from time import time
from datetime import datetime

import mlflow
import numpy as np
from tqdm import tqdm
import tifffile as tiff
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt

from flame import CAREInferenceSession
from flame.utils import get_input_and_GT_paths, _compress_dict_fields
import flame.eval as eval
from flame.error import FLAMEEvalError

In [None]:
# Root folder holding processed datasets. Set the FLAME_DATA_ROOT environment
# variable to point elsewhere instead of editing this machine-specific default.
DATA_ROOT = os.environ.get("FLAME_DATA_ROOT", "/mnt/d/data/processed")
DATASET_NAME = "20250618_224I_denoising_5to40F"
DATASET_DIREC = os.path.join(DATA_ROOT, DATASET_NAME)
# Short tag embedded in the MLflow metric names logged at the end of the notebook.
DATASET_ID = "0x0003"
TEST_DIREC = os.path.join(DATASET_DIREC, "test")
# Each entry must be the name of a callable in `flame.eval`; it is looked up with
# getattr and called as fn(image, ground_truth).
METRICS = [
    "mse",
    "mae",
    "ssim"
]
TRACKING_URI = "http://127.0.0.1:5050"
MLFLOW_RUN_ID = "bf9a43f3ec154c9ba2deb6de2fb0db33"

In [None]:
# Point every subsequent MLflow call in this notebook at the configured tracking server.
mlflow.set_tracking_uri(TRACKING_URI)

In [None]:
logger = logging.getLogger("main")
# One timestamped log file per run of this cell, stored under ./logs.
LOG_DIREC = os.path.join(os.getcwd(), "logs")
# basicConfig does not create missing parent directories; without this the cell
# fails with FileNotFoundError on a fresh checkout.
os.makedirs(LOG_DIREC, exist_ok=True)
logging.basicConfig(
    filename=os.path.join(LOG_DIREC, f"{datetime.now().strftime('%Y%m%d-%H%M%S')}_logger.log"),
    encoding="utf-8",
    # NOTE(review): this sets the *root* logger to DEBUG, so debug output from
    # third-party libraries also lands in the file.
    level=logging.DEBUG,
    # Replace handlers left from a previous run of this cell (or added by an imported
    # library); otherwise basicConfig is a silent no-op and logs keep going to the
    # old file.
    force=True,
)

In [None]:
# Fail fast on a misconfigured path or metric name before the model is loaded and
# inference starts. The parent dataset directory is checked first so a wrong
# DATASET_DIREC is reported as such, not as a missing "test" subfolder.
assert os.path.isdir(DATASET_DIREC), f"Could not find dataset directory at path {DATASET_DIREC}"
assert os.path.isdir(TEST_DIREC), f"Could not find test set directory at path {TEST_DIREC}"
for metric in METRICS:
    # The evaluation loop calls getattr(eval, metric)(a, b), so the name must exist
    # in flame.eval AND be callable.
    if not callable(getattr(eval, metric, None)):
        message = f"Could not find {metric} among available evaluation metrics."
        logger.error(message)
        raise FLAMEEvalError(message)

### Loading the Trained Model from MLflow

In [None]:
# Build the inference session from the model logged to the given MLflow run.
engine = CAREInferenceSession.from_mlflow_uri(tracking_uri=TRACKING_URI, run_id=MLFLOW_RUN_ID)

### Running Inference and Scoring the Test Set

In [None]:
# The input / ground-truth frame counts are read from the model's own config so the
# test files are selected with the same settings the model carries.
config = engine.model_config
FRAMES_LOW = config['FLAME_Dataset']['input']['n_frames']
FRAMES_GT = config['FLAME_Dataset']['output']['n_frames']
low_paths, GT_paths = get_input_and_GT_paths(
    input_direc=TEST_DIREC,
    input_frames=FRAMES_LOW,
    gt_frames=FRAMES_GT,
    logger=logger
)
# The evaluation loop pairs files with zip(), which silently drops the tail of the
# longer list; an empty test set would later produce NaN metrics. Catch both here.
assert len(low_paths) == len(GT_paths), (
    f"Found {len(low_paths)} input files but {len(GT_paths)} GT files in {TEST_DIREC}"
)
assert len(low_paths) > 0, f"No input/GT pairs found in {TEST_DIREC}"
# NOTE(review): left disabled. If re-enabled, mlflow.log_params on the existing run
# may reject the changed value, since MLflow params cannot be overwritten.
# config['FLAME_Dataset']['id'] = DATASET_ID

In [None]:
# Per-image scores keyed by metric name. "input vs. gt" is the baseline the model
# prediction ("pred vs. gt") is compared against.
input_metrics = {x: [] for x in METRICS}
eval_metrics = {x: [] for x in METRICS}
# Resolve each metric function once rather than on every image.
metric_fns = {name: getattr(eval, name) for name in METRICS}
n_skipped = 0

for low_path, gt_path in tqdm(
        iterable=zip(low_paths, GT_paths),
        total=len(low_paths),
        desc="evaluating",
        ascii=True
    ):
    try:
        t1 = time()
        # NOTE(review): transpose(0, 2, 3, 1) moves axis 1 to the end, presumably
        # (N, C, H, W) -> (N, H, W, C); confirm against the TIFF layout.
        low = tiff.imread(low_path).transpose(0, 2, 3, 1).astype(np.float32)
        gt = tiff.imread(gt_path).transpose(0, 2, 3, 1).astype(np.float32)
        t2 = time()
        logger.info("Loaded 2 images, taking %.2fs.", t2 - t1)
    except Exception:
        # Skip unreadable pairs, but keep the traceback in the log instead of discarding it.
        n_skipped += 1
        logger.exception(f"Could not load input and/or GT images from {os.path.basename(low_path)} & {os.path.basename(gt_path)}")
        continue

    assert low.shape == gt.shape, f"Input and GT image shapes do not match (found {low.shape} and {gt.shape})"

    pred = engine.predict(low).astype(np.float32)

    # NOTE(review): only index 0 along the first axis is scored, which assumes each
    # file holds a single sample. TODO confirm.
    for metric, fn in metric_fns.items():
        input_metrics[metric].append(fn(low[0, ...], gt[0, ...]))
        eval_metrics[metric].append(fn(pred[0, ...], gt[0, ...]))

if n_skipped:
    # Logging goes to a file only, so also surface dropped pairs in the notebook output.
    skip_message = f"Skipped {n_skipped} of {len(low_paths)} image pairs that could not be loaded."
    logger.warning(skip_message)
    print(skip_message)


In [None]:
# Baseline scores: the model input compared directly with the ground truth.
df = pd.DataFrame(input_metrics).assign(source="input vs. gt")

In [None]:
# Model scores: the prediction compared with the ground truth.
eval_df = pd.DataFrame(eval_metrics).assign(source="pred vs. gt")

In [None]:
# Stack baseline and model scores into one long-format frame for plotting.
# ignore_index gives a fresh 0..2n-1 index; otherwise both halves reuse labels
# 0..n-1, which some seaborn/pandas reindexing operations reject as duplicates.
all_df = pd.concat([df, eval_df], ignore_index=True)

In [None]:
# Guard against logging NaN means when every image pair failed to load.
assert not eval_df.empty, "No test images were evaluated; nothing to log."

with mlflow.start_run(run_id=MLFLOW_RUN_ID):
    # NOTE(review): MLflow params cannot be overwritten. This assumes the run either
    # has none of these keys yet or has identical values; a differing value may raise.
    mlflow.log_params(_compress_dict_fields(config))
    for metric in METRICS:
        grid = sns.catplot(data=all_df, x="source", y=metric)
        grid.figure.suptitle(f"Test-set {metric} (dataset {DATASET_ID})", y=1.02)
        # Persist the comparison plot in the run so it survives beyond this notebook session.
        mlflow.log_figure(grid.figure, f"ds{DATASET_ID}_test_{metric}.png")
        # Only the model's mean score is logged; the input baseline is shown in the plot.
        mlflow.log_metric(f"ds{DATASET_ID}_test_{metric}", float(np.mean(eval_df[metric])))
    