In [10]:
import os, json, logging
from datetime import datetime

import mlflow
import mlflow.artifacts as artifacts
import numpy as np
from tqdm import tqdm
import tifffile as tiff

from flame import CAREInferenceSession
from flame.utils import get_input_and_GT_paths
import flame.eval as eval
from flame.error import FLAMEEvalError

In [9]:
# Dataset evaluated by this notebook; its test split is <DATA_ROOT>/<DATASET_NAME>/test.
DATASET_NAME = "20250618_224I_denoising_5to40F"
# Root of the processed datasets. Defaults to the original local path and can be
# overridden with the FLAME_DATA_ROOT environment variable on other machines.
DATA_ROOT = os.environ.get("FLAME_DATA_ROOT", "/mnt/d/data/processed")
DATASET_DIREC = os.path.join(DATA_ROOT, DATASET_NAME)
TEST_DIREC = os.path.join(DATASET_DIREC, "test")

# Metric names; each must be an attribute of `flame.eval`.
METRICS = [
    "mse",
    "mae",
    "ssim",
]

# MLflow tracking server. Honours MLflow's own MLFLOW_TRACKING_URI variable and
# falls back to the local server used so far.
TRACKING_URI = os.environ.get("MLFLOW_TRACKING_URI", "http://127.0.0.1:5050")
# MLflow run whose logged model is evaluated.
MLFLOW_RUN_ID = "bf9a43f3ec154c9ba2deb6de2fb0db33"

In [None]:
# Point the MLflow client at the tracking server from the configuration cell.
mlflow.set_tracking_uri(TRACKING_URI)

In [4]:
# Log to a timestamped file under ./logs. Create the directory first:
# logging.basicConfig(filename=...) opens the file immediately and raises
# FileNotFoundError if its directory does not exist.
LOG_DIREC = os.path.join(os.getcwd(), "logs")
os.makedirs(LOG_DIREC, exist_ok=True)

logger = logging.getLogger("main")
# basicConfig only configures the root logger if it has no handlers yet, so
# re-running this cell keeps writing to the first log file of the session.
logging.basicConfig(
    filename=os.path.join(LOG_DIREC, f"{datetime.now().strftime('%Y%m%d-%H%M%S')}_logger.log"),
    encoding="utf-8",
    level=logging.DEBUG,
)

In [None]:
# Fail fast, before the model is downloaded, if the test split or a requested
# metric is missing.
if not os.path.isdir(TEST_DIREC):
    # Explicit raise instead of `assert`, which is skipped under `python -O`.
    test_dir_msg = f"Could not find test set directory at path {TEST_DIREC}"
    logger.error(test_dir_msg)
    raise FileNotFoundError(test_dir_msg)

# Every metric name must resolve to a callable in `flame.eval`. Collect all
# unknown names so a single run reports every typo at once.
missing_metrics = [
    metric for metric in METRICS if not callable(getattr(eval, metric, None))
]
if missing_metrics:
    metrics_msg = (
        f"Could not find {', '.join(missing_metrics)} among available evaluation metrics."
    )
    logger.error(metrics_msg)
    raise FLAMEEvalError(metrics_msg)

### Loading the MLflow Model

In [8]:
# Build the inference session from the model logged under MLFLOW_RUN_ID on the
# configured tracking server.
engine = CAREInferenceSession.from_mlflow_uri(run_id=MLFLOW_RUN_ID, tracking_uri=TRACKING_URI)

Downloading artifacts:   0%|          | 0/8 [00:00<?, ?it/s]

Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]

### Starting Inference

In [None]:
# No earlier cell defines `config`, so on a fresh kernel this cell raised a
# NameError. Load the run's training configuration from its MLflow artifacts
# unless an earlier cell already provided it.
# NOTE(review): the artifact name (and that it is JSON, as load_dict requires)
# is an assumption — confirm against the artifacts logged for MLFLOW_RUN_ID.
CONFIG_ARTIFACT_PATH = "config.json"
if "config" not in globals():
    config = artifacts.load_dict(f"runs:/{MLFLOW_RUN_ID}/{CONFIG_ARTIFACT_PATH}")

# Frame counts of the input and ground-truth stacks, as recorded in the config.
FRAMES_LOW = config['FLAME_Dataset']['input']['n_frames']
FRAMES_GT = config['FLAME_Dataset']['output']['n_frames']
low_paths, GT_paths = get_input_and_GT_paths(
    input_direc=TEST_DIREC,
    input_frames=FRAMES_LOW,
    gt_frames=FRAMES_GT,
    logger=logger
)

# The inference loop zips these lists, and zip() silently drops unmatched
# entries, so require a one-to-one pairing here.
assert len(low_paths) == len(GT_paths), (
    f"Found {len(low_paths)} input images but {len(GT_paths)} GT images in {TEST_DIREC}"
)

In [None]:
# Run the model over the (input, GT) test pairs.
# MAX_PREDICTIONS caps how many successfully loaded pairs are predicted while
# iterating on the notebook; the default of 1 stops after the first prediction.
# Set it to None to process the whole test set.
MAX_PREDICTIONS = 1

n_predicted = 0
for low_path, gt_path in tqdm(
    iterable=zip(low_paths, GT_paths),
    total=len(low_paths),
    ascii=True,
):
    try:
        # transpose(0, 2, 3, 1) moves axis 1 to the last position.
        # NOTE(review): presumably (N, C, H, W) -> (N, H, W, C) for the model — confirm.
        low = tiff.imread(low_path).transpose(0, 2, 3, 1).astype(np.float32)
        gt = tiff.imread(gt_path).transpose(0, 2, 3, 1).astype(np.float32)
    except Exception:
        # Skip unreadable pairs, but keep the traceback in the log file.
        logger.exception(
            f"Could not load input and/or GT images from {os.path.basename(low_path)} & {os.path.basename(gt_path)}"
        )
        continue

    if low.shape != gt.shape:
        # A shape mismatch points at a dataset problem: record it, then stop.
        shape_msg = f"Input and GT image shapes do not match (found {low.shape} and {gt.shape})"
        logger.error(shape_msg)
        raise ValueError(shape_msg)

    pred = engine.predict(low)

    n_predicted += 1
    if MAX_PREDICTIONS is not None and n_predicted >= MAX_PREDICTIONS:
        break
