In [None]:
import os
import logging
from datetime import datetime
import json

import numpy as np
from csbdeep.models import Config
import onnxruntime as ort
import onnx

from flame import FLAMEImage
from flame.utils import min_max_norm
from flame.error import FLAMEImageError

In [None]:
# Fail fast if this onnxruntime build does not offer the CUDA execution provider.
# NOTE: being listed here may not guarantee that the CUDA/cuDNN libraries load
# successfully when an InferenceSession is created later.
assert 'CUDAExecutionProvider' in ort.get_available_providers(), (
    f"CUDAExecutionProvider not available; found {ort.get_available_providers()}"
)

In [None]:
# Paths for this inference run: the directory of raw stacks to process, the
# dataset description JSON, and the exported ONNX model plus its config.json.
INFERENCE_DATA_DIR = "/mnt/d/data/raw/0013_250514_HS6307_CAREtest_NA"
DATASET_DIRECTORY = "/mnt/d/code/Balu_CARE/datasets"
DATASET_NAME = "20250513_40I_denoising_7to40F"
DATASET_JSON = os.path.join(DATASET_DIRECTORY, f"{DATASET_NAME}.json")
MODEL_DIRECTORY = "/mnt/d/models/CARE/test_model"
# The ONNX file is expected to be named after its containing directory.
# normpath strips any trailing separator, which would make basename() return "".
MODEL_NAME = os.path.basename(os.path.normpath(MODEL_DIRECTORY))
ONNX_PATH = os.path.join(MODEL_DIRECTORY, f"{MODEL_NAME}.onnx")
JSON_PATH = os.path.join(MODEL_DIRECTORY, "config.json")

# Fail fast and name the offending path. os.walk() on a missing directory
# silently yields nothing, which would otherwise surface later as an empty
# image list.
assert os.path.isdir(INFERENCE_DATA_DIR), (
    f"Inference data directory not found: {INFERENCE_DATA_DIR}"
)
for required_file in [DATASET_JSON, ONNX_PATH, JSON_PATH]:
    assert os.path.isfile(required_file), f"Required file not found: {required_file}"

### Loading Images

In [None]:
# Collect every TIFF under INFERENCE_DATA_DIR that FLAMEImage accepts.
# Directories and files are visited in sorted order so that images[0] refers to
# the same file on every run (os.walk's native order is not guaranteed).
images = []
for root, dirs, files in os.walk(INFERENCE_DATA_DIR):
    dirs.sort()  # in-place sort controls the order in which os.walk descends
    for f in sorted(files):
        # Match on the extension only: a substring test would also accept names
        # such as "motif.csv", and "tif" already covers "tiff".
        if not f.lower().endswith((".tif", ".tiff")):
            continue
        impath = os.path.join(root, f)
        try:
            this_image = FLAMEImage(
                impath = impath,
                jsonext = "tileData.txt",
                overrideNFrames = 1,
                checkFrames = False,
                checkZs = True
            )
        except FLAMEImageError as e:
            # Skip files FLAMEImage rejects, but leave a record of why.
            logging.warning(f"Skipping {impath}: {e}")
            continue
        images.append(this_image)

assert images, f"No loadable TIFF images found under {INFERENCE_DATA_DIR}"

In [None]:
# Display `axes_shape` of the first loaded image as a quick sanity check
# (presumably the size along each named axis; confirm in FLAMEImage).
images[0].axes_shape

In [None]:
# Scratch location for the per-frame TIFF dump further down; created on demand
# and left untouched if it already exists.
test_output_dir = "/mnt/c/Users/BaluLab/Desktop/test"
os.makedirs(name=test_output_dir, mode=0o777, exist_ok=True)

In [None]:
import tifffile as tiff

In [None]:
# Open the first image (presumably required before `raw()` is called below);
# it is closed again with `closeImage()` after the frames are written.
images[0].openImage()

In [None]:
# Write each index along the first axis of the first image as its own TIFF.
# raw() is hoisted out of the loop so the stack is fetched once, not once per
# written frame.
raw_stack = images[0].raw()
for idx in range(images[0].imShape[0]):
    # transpose(1, 2, 0) moves the leading axis of each 3-D slice to the end
    # (presumably to get a channels-last layout for viewers; confirm axis meaning).
    # NOTE(review): astype(np.uint8) neither clips nor rescales. Integer values
    # outside 0-255 wrap modulo 256, and out-of-range floats are undefined. If
    # raw() returns e.g. uint16 data, rescale first or these dumps will be garbled.
    frame = raw_stack[idx, ...].transpose(1, 2, 0).astype(np.uint8)
    tiff.imwrite(os.path.join(test_output_dir, f"{idx}.tiff"), frame)

In [None]:
# Close the first image again now that its frames have been written out.
images[0].closeImage()

In [None]:
# Load the model directory's config.json and the dataset description JSON.
# Context managers close the file handles deterministically instead of leaving
# them open until garbage collection.
with open(JSON_PATH, 'r') as fh:
    model_config_dict = json.load(fh)
with open(DATASET_JSON, 'r') as fh:
    dataset_config_dict = json.load(fh)

In [None]:
# Display the 'FLAME_Dataset' section of the dataset description JSON.
# TODO: index into a specific key here once it is decided which one is needed.
dataset_config_dict['FLAME_Dataset']

### Loading ONNX

In [None]:
# Parse the exported graph, then run the ONNX structural checker on it; the
# checker raises if the model is malformed, so this cell doubles as a sanity test.
onnx_model = onnx.load(ONNX_PATH)
onnx.checker.check_model(model=onnx_model)

In [None]:
# Build the GPU inference session. onnxruntime may fall back to the CPU
# provider (emitting only a warning) when the CUDA libraries cannot be loaded,
# so confirm CUDA actually ended up registered on the session.
ort_session = ort.InferenceSession(
    ONNX_PATH,
    providers=['CUDAExecutionProvider']
)
assert 'CUDAExecutionProvider' in ort_session.get_providers(), (
    f"Session is not using CUDA; active providers: {ort_session.get_providers()}"
)

In [None]:
# Report the name, shape and element type of the session's first input so
# that data can be prepared to match it.
input_tensor = ort_session.get_inputs()[0]
input_name = input_tensor.name
input_shape = input_tensor.shape
input_type = input_tensor.type
print(f"Input Tensor\nName: {input_name}\nShape: {input_shape}\nType: {input_type}")