In [None]:
import os, json, logging
from datetime import datetime

import mlflow
import mlflow.artifacts as artifacts
import numpy as np
from tqdm import tqdm
import tifffile as tiff

from flame import CAREInferenceSession
from flame.utils import get_input_and_GT_paths
import flame.eval as eval
from flame.error import (
    FLAMEImageError, 
    CAREInferenceError, 
    CAREDatasetError, 
    FLAMEEvalError
)

In [None]:
# Point MLflow at the tracking server. An externally configured server
# (MLFLOW_TRACKING_URI) takes precedence; otherwise fall back to the local
# development server. Calling set_tracking_uri with a fixed literal would
# silently override the environment variable.
mlflow.set_tracking_uri(uri=os.environ.get("MLFLOW_TRACKING_URI", "http://127.0.0.1:5050"))

In [None]:
# --- Dataset ---------------------------------------------------------------
DATASET_NAME = "20250618_224I_denoising_5to40F"
DATASET_DIREC = os.path.join("/mnt/d/data/processed", DATASET_NAME)
TEST_DIREC = os.path.join(DATASET_DIREC, "test")
CONFIG_JSON_PATH = os.path.join(DATASET_DIREC, "patch_config.json")

# --- Evaluation ------------------------------------------------------------
# Each entry is looked up by name as an attribute of `flame.eval`.
METRICS = ["mse", "mae", "ssim"]

# --- MLflow model source ---------------------------------------------------
MLFLOW_RUN_ID = "bf9a43f3ec154c9ba2deb6de2fb0db33"
# Relative paths passed as `artifact_path` to mlflow's download_artifacts.
ONNX_RELATIVE_PATH = "model"
JSON_RELATIVE_PATH = os.path.join("model_config", "model_config.json")

# --- Scratch space ---------------------------------------------------------
# Local destination directory for the downloaded artifacts.
TEMP_DIREC = os.path.join(os.getcwd(), "temp")

In [None]:
# Send DEBUG-level logs to a timestamped file under ./logs.
# The directory is created up front: basicConfig(filename=...) raises
# FileNotFoundError when the parent directory does not exist.
LOG_DIREC = os.path.join(os.getcwd(), "logs")
os.makedirs(LOG_DIREC, exist_ok=True)

logger = logging.getLogger("main")
logging.basicConfig(
    filename=os.path.join(LOG_DIREC, f"{datetime.now().strftime('%Y%m%d-%H%M%S')}_logger.log"),
    encoding="utf-8",
    level=logging.DEBUG,
    # basicConfig is a silent no-op when the root logger already has handlers
    # (e.g. this cell was re-run, or the environment installed some); force
    # replaces them so the file handler is always attached.
    force=True,
)

In [None]:
# Ensure the scratch directory used as the artifact download destination
# exists; leaves an existing directory untouched.
os.makedirs(name=TEMP_DIREC, exist_ok=True)

In [None]:
# Fail fast on missing inputs before any downloads or inference happen.
assert os.path.isdir(TEST_DIREC), f"Could not find test set directory at path {TEST_DIREC}"
assert os.path.isdir(TEMP_DIREC), f"Could not find temp directory. Look at path {TEMP_DIREC}"
assert os.path.isfile(CONFIG_JSON_PATH), f"Could not find config json at path {CONFIG_JSON_PATH}"

# Every requested metric name must resolve to an attribute of `flame.eval`
# (imported as `eval`, shadowing the builtin).
for metric in METRICS:
    try:
        getattr(eval, metric)
    except AttributeError as e:
        msg = f"Could not find {metric} among available evaluation metrics."
        logger.error(msg)
        # Explicit chaining marks the failed lookup as the direct cause.
        raise FLAMEEvalError(msg) from e

In [None]:
# Load the dataset (patch) configuration. The context manager closes the file
# deterministically; JSON is UTF-8, so don't rely on the locale's default encoding.
with open(CONFIG_JSON_PATH, "r", encoding="utf-8") as config_file:
    config = json.load(config_file)

### Getting the MLflow Model

In [None]:
# Download the run's ONNX model artifact (ONNX_RELATIVE_PATH) into TEMP_DIREC.
try:
    artifacts.download_artifacts(
        run_id=MLFLOW_RUN_ID,
        artifact_path=ONNX_RELATIVE_PATH,
        dst_path=TEMP_DIREC
    )
except Exception as e:
    msg = f"Could not load model.onnx from mlflow run of id {MLFLOW_RUN_ID}.\nEXCEPTION: {e}"
    # logger.exception logs at ERROR level and also records the traceback in the log file.
    logger.exception(msg)
    raise CAREInferenceError(msg) from e

In [None]:
# Download the run's model-config JSON artifact (JSON_RELATIVE_PATH) into TEMP_DIREC.
try:
    artifacts.download_artifacts(
        run_id=MLFLOW_RUN_ID,
        artifact_path=JSON_RELATIVE_PATH,
        dst_path=TEMP_DIREC
    )
except Exception as e:
    msg = f"Could not load model_config.json from mlflow run of id {MLFLOW_RUN_ID}.\nEXCEPTION: {e}"
    # logger.exception logs at ERROR level and also records the traceback in the log file.
    logger.exception(msg)
    raise CAREInferenceError(msg) from e

In [None]:
# Local paths of the downloaded artifacts. They are built from the same relative
# paths passed to download_artifacts, so changing a constant in the config cell
# cannot leave these out of sync.
ONNX_PATH = os.path.join(TEMP_DIREC, ONNX_RELATIVE_PATH, "model.onnx")
MODEL_CONFIG_PATH = os.path.join(TEMP_DIREC, JSON_RELATIVE_PATH)
assert os.path.isfile(ONNX_PATH), f"Could not find model ONNX at {ONNX_PATH}"
assert os.path.isfile(MODEL_CONFIG_PATH), f"Could not find model config JSON at {MODEL_CONFIG_PATH}"

In [None]:
# Build the inference session from the downloaded ONNX model and its config.
# The dataset config is the local patch_config.json (CONFIG_JSON_PATH) that this
# notebook reads its `FLAME_Dataset` settings from, not the model config JSON.
# NOTE(review): confirm CAREInferenceSession expects the patch config here.
engine = CAREInferenceSession(
    model_path=ONNX_PATH,
    model_config_path=MODEL_CONFIG_PATH,
    dataset_config_path=CONFIG_JSON_PATH
)

### Starting Inference

In [None]:
# Frame counts for the input and ground-truth images, read from the dataset config.
FRAMES_LOW = config['FLAME_Dataset']['input']['n_frames']
FRAMES_GT = config['FLAME_Dataset']['output']['n_frames']
low_paths, GT_paths = get_input_and_GT_paths(
    input_direc=TEST_DIREC,
    input_frames=FRAMES_LOW,
    gt_frames=FRAMES_GT,
    logger=logger
)

# Inputs and GT are later paired with zip(), which silently drops the surplus
# when the lengths differ, so check the pairing here before inference.
if not low_paths or len(low_paths) != len(GT_paths):
    msg = (
        f"Expected a non-empty, equal number of input and GT images in {TEST_DIREC} "
        f"(found {len(low_paths)} and {len(GT_paths)})"
    )
    logger.error(msg)
    raise CAREDatasetError(msg)

In [None]:
# Run the model on each (input, ground-truth) pair from the test set.
for low_path, gt_path in tqdm(
        iterable=zip(low_paths, GT_paths),
        total=len(low_paths),
        ascii=True
    ):
    try:
        # NOTE(review): transpose(0, 2, 3, 1) assumes 4-D stacks and moves axis 1 to
        # the end (e.g. (N, C, H, W) -> (N, H, W, C)) — confirm the on-disk layout.
        low = tiff.imread(low_path).transpose(0, 2, 3, 1).astype(np.float32)
        gt = tiff.imread(gt_path).transpose(0, 2, 3, 1).astype(np.float32)
    except Exception:
        # Skip unreadable pairs, but keep the cause and traceback in the log file.
        logger.exception(f"Could not load input and/or GT images from {os.path.basename(low_path)} & {os.path.basename(gt_path)}")
        continue

    assert low.shape == gt.shape, f"Input and GT image shapes do not match (found {low.shape} and {gt.shape})"

    pred = engine.predict(low)

    # TODO(review): debugging stop. Only the first readable pair is processed, and
    # `pred` is not scored against `gt` in this cell. Remove once evaluation is added.
    break
