In [None]:
import os
import json
import numpy as np

from csbdeep.models import Config, CARE
from matplotlib import pyplot as plt
import tifffile as tiff

from flame import FLAMEImage
from flame.utils import min_max_norm

In [None]:
# --- Run configuration ------------------------------------------------------
# Directory that is walked for input TIFFs, and the root folder for results.
INFERENCE_DATA_DIR = "/mnt/d/data/raw/0013_250514_HS6307_CAREtest_NA"
INFERENCE_OUTPUT = "/mnt/d/data/output/0013_250514_HS6307_CAREtest_NA"

# Dataset description JSON; the pixel percentile statistics used for
# normalisation are read from it later in the notebook.
DATASET_DIRECTORY = "/mnt/d/code/Balu_CARE/datasets"
DATASET_NAME = "20250513_40I_denoising_7to40F"
DATASET_JSON = os.path.join(DATASET_DIRECTORY, f"{DATASET_NAME}.json")

# Model folder; its base name doubles as the model name and as the name of
# the per-model output sub-folder.
MODEL_DIRECTORY = "/mnt/d/models/CARE/test_model"
MODEL_NAME = os.path.basename(MODEL_DIRECTORY)
INFERENCE_OUTPUT_DIRECTORY = os.path.join(INFERENCE_OUTPUT, MODEL_NAME)

ONNX_PATH = os.path.join(MODEL_DIRECTORY, f"{MODEL_NAME}.onnx")
JSON_PATH = os.path.join(MODEL_DIRECTORY, "config.json")

# Fail early, naming the offending path, if any required file is missing.
for required_file in (DATASET_JSON, ONNX_PATH, JSON_PATH):
    assert os.path.isfile(required_file), f"Required file not found: {required_file}"

In [None]:
# Create the per-model output folder (and any missing parents) up front;
# exist_ok makes this a no-op when the folder is already there.
os.makedirs(INFERENCE_OUTPUT_DIRECTORY, exist_ok=True)

In [None]:
# Walk INFERENCE_DATA_DIR recursively and wrap every TIFF file in a FLAMEImage.
images = []
for root, dirs, files in os.walk(INFERENCE_DATA_DIR):
    for f in files:
        # Match on the file extension (case-insensitively) rather than on a
        # substring, so that names that merely contain "tif" are not picked up.
        if not f.lower().endswith((".tif", ".tiff")):
            continue
        try:
            this_image = FLAMEImage(
                impath = os.path.join(root, f),
                jsonext = "tileData.txt",
                overrideNFrames = 1,
                checkFrames = False,
                checkZs = True
            )
        # NOTE(review): FLAMEImageError is not imported in the visible import
        # cell; unless it is defined elsewhere, this clause raises NameError as
        # soon as FLAMEImage fails. Import it from the flame package (module
        # path to be confirmed).
        except FLAMEImageError as e:
            # Report the skipped file instead of dropping it silently.
            print(f"Skipping {os.path.join(root, f)}: {e}")
            continue
        images.append(this_image)

In [None]:
# Instantiate the CARE model from MODEL_DIRECTORY. No config object is passed
# (config=None); presumably csbdeep then loads the saved configuration from
# the model folder - confirm against the csbdeep documentation.
model = CARE(name=MODEL_DIRECTORY, config=None)

In [None]:
# Parse the dataset description; the context manager guarantees the file
# handle is closed even if parsing fails.
with open(DATASET_JSON, 'r') as dataset_file:
    dataset_dict = json.load(dataset_file)

In [None]:
# Pull the 1st / 99th percentile pixel statistics for the input and output
# sides of the dataset; they serve as the min / max bounds for clipping,
# normalisation and de-normalisation in the inference cell.
input_min, input_max = (
    np.array(dataset_dict['FLAME_Dataset']['input'][stat_key])
    for stat_key in ('pixel_1pct', 'pixel_99pct')
)
output_min, output_max = (
    np.array(dataset_dict['FLAME_Dataset']['output'][stat_key])
    for stat_key in ('pixel_1pct', 'pixel_99pct')
)

In [None]:
# Run the CARE model over every collected image and write each result as a TIFF.
for image in images:
    image.openImage()
    try:
        model_input = image.imageData.copy()
    finally:
        # Always release the image, even if reading the pixel data fails.
        image.closeImage()

    # Move axis 1 to the end; presumably (Z, C, Y, X) -> (Z, Y, X, C) - TODO confirm.
    model_input = model_input.transpose(0, 2, 3, 1)
    # Clip to the dataset's input percentile range, then min-max normalise with it.
    model_input = np.clip(model_input, input_min, input_max)
    model_input = min_max_norm(model_input, input_min, input_max, dtype=np.float32)

    # Predict one slice along axis 0 at a time and re-assemble along axis 0.
    outputs = [
        model.predict(model_input[zdx, ...], axes='YXC')
        for zdx in range(model_input.shape[0])
    ]
    outputs = np.stack(outputs, axis=0)

    # Map the predictions back to pixel values using the output percentile range.
    outputs = (outputs * (output_max - output_min + 1e-20)) + output_min
    # Per-channel rescaling of channels 0 and 2; the same factors are applied
    # to the raw images in the visualisation cell.
    # NOTE(review): the origin of 4/3 and 2/3 is not visible here - document it.
    outputs[..., 0] *= 4/3
    outputs[..., 2] *= 2/3

    # Build a flat file name from the path components starting at index 4.
    # NOTE(review): this depends on the directory depth of the image path.
    imname = "_".join(image.impath.split(os.path.sep)[4:])
    # exist_ok=True keeps the cell re-runnable; without it a second run raises
    # FileExistsError. This folder is created but not written to in this cell.
    os.makedirs(
        os.path.join(INFERENCE_OUTPUT_DIRECTORY, os.path.splitext(imname)[0]),
        exist_ok=True,
    )
    tiff.imwrite(os.path.join(INFERENCE_OUTPUT_DIRECTORY, imname), outputs)

In [None]:
# Folder that receives the 8-bit previews of the raw images.
DIR = "/mnt/c/Users/BaluLab/Desktop/raw_vis"
# exist_ok=True keeps the cell re-runnable; a bare makedirs raises
# FileExistsError once the folder exists.
os.makedirs(DIR, exist_ok=True)

In [None]:
# Write 8-bit previews of selected slices of every raw image into DIR.
for idx, image in enumerate(images):
    image.openImage()
    try:
        raw = image.imageData.copy()
    finally:
        # Always release the image, even if reading the pixel data fails.
        image.closeImage()

    # Move axis 1 to the end; presumably (Z, C, Y, X) -> (Z, Y, X, C) - TODO confirm.
    raw = raw.transpose(0, 2, 3, 1).astype(np.float32)
    # Same per-channel factors as applied to the model outputs in the inference cell.
    raw[..., 0] *= 4/3
    raw[..., 2] *= 2/3
    # Stretch to the image's own min/max, then scale by 255 and cast to uint8.
    raw = min_max_norm(raw, np.min(raw), np.max(raw), dtype=np.float32)
    raw = (raw * 255).astype(np.uint8)

    for zdx in [0, 4, 9]:
        # Stacks with fewer slices than the requested index are skipped
        # (with a notice) instead of aborting the whole export with IndexError.
        if zdx >= raw.shape[0]:
            print(f"Image {idx}: no slice {zdx} (only {raw.shape[0]} slices), skipping")
            continue
        impath = os.path.join(DIR, f"im_i{idx}_z{zdx}.tif")
        tiff.imwrite(impath, raw[zdx, ...])
