In [None]:
import os
import logging
import json
from datetime import datetime
import gc

import tensorflow as tf
from csbdeep.models import CARE
from csbdeep.data import RawData, create_patches, no_background_patches
from csbdeep.utils import plot_some
from natsort import natsorted
import tifffile as tiff
import numpy as np
from matplotlib import pyplot as plt
from tqdm.notebook import tqdm

from flame.utils import min_max_norm
from flame import FLAMEImage

2025-05-27 14:14:54.063810: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-05-27 14:14:54.128284: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1748380494.150812 2680273 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1748380494.156464 2680273 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1748380494.197602 2680273 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking 

In [2]:
assert len(tf.config.list_physical_devices("GPU")) > 0

### WARNING: This Notebook may not work if processed dataset images are not YXC format.

In [None]:
DATASET_NAME = "20250527_112I_denoising_5to40F" # processed dataset directory (created by 'create_care_dataset.ipynb')
DATASET_DIREC = "/mnt/d/code/Balu_CARE/datasets"
DATASET_JSON_PATH = os.path.join(DATASET_DIREC, f"{DATASET_NAME}.json")
INPUT_DATA_DIREC = os.path.join("/mnt/d/data/processed/", DATASET_NAME)
assert os.path.isdir(INPUT_DATA_DIREC), f"Directory not found: {INPUT_DATA_DIREC}"
TRAIN_DATA_DIREC = os.path.join(INPUT_DATA_DIREC, "train")
assert os.path.isdir(TRAIN_DATA_DIREC), f"Directory not found: {TRAIN_DATA_DIREC}"
FRAMES_LOW = 5
FRAMES_GT = 40
PATCH_SIZE = 128
PATCH_MULTIPLE = 4

In [4]:
logger = logging.getLogger("main")
logging.basicConfig(
    filename=f"{datetime.now().strftime('%Y%m%d-%H%M%S')}_logger.log",
    encoding="utf-8",
    level=logging.DEBUG
)

In [5]:
logger.info(f"Using {TRAIN_DATA_DIREC} to construct the CARE training data .npz")

### Getting paths of Low Frame and Ground Truth Frame Accumulations

In [6]:
low_paths = []
GT_paths = []
for root, dirs, files in os.walk(TRAIN_DATA_DIREC):
    for f in files:
        if f"frames{FRAMES_LOW}" in f:
            low_paths.append(os.path.join(root, f))
        elif f"frames{FRAMES_GT}" in f:
            GT_paths.append(os.path.join(root, f))

low_paths = natsorted(low_paths)
GT_paths = natsorted(GT_paths)

In [7]:
print(f"Found {len(low_paths)} low- and high-frame accumulation images for this dataset")
logger.info(f"Found {len(low_paths)} low- and high-frame accumulation images for this dataset")

Found 102 low- and high-frame accumulation images for this dataset


In [8]:
for f1, f2 in zip(low_paths, GT_paths):
    id1 = os.path.basename(f1).split('_')[0]
    id2 = os.path.basename(f2).split('_')[0]
    assert id1 == id2


In [9]:
ds_config = json.load(open(DATASET_JSON_PATH, 'r'))
input_config = ds_config['FLAME_Dataset']['input']
output_config = ds_config['FLAME_Dataset']['output']

In [10]:
input_1pct, input_99pct = np.array(input_config['pixel_1pct']), np.array(input_config['pixel_99pct'])
output_1pct, output_99pct = np.array(output_config['pixel_1pct']), np.array(output_config['pixel_99pct'])

### Min-max norm of images

In [None]:
# get largest shape
largest_x = 0
largest_y = 0
for shape in ds_config['FLAME_Dataset']['image_shapes']:
    this_y = shape[0]
    this_x = shape[1]
    if this_x > largest_x: largest_x = this_x
    if this_y > largest_y: largest_y = this_y

In [None]:
low = []
GT = []
# Assumes that all image shapes are XYC format
for low_path, GT_path in tqdm(
        zip(low_paths, GT_paths),
        ascii=True,
        unit="image",
        total=len(low_paths)
    ):

    # default tiff.imread behavior is to read CYX (even if the tif is not written that way), so transpose to YXC
    # also convert to float 32 to allow for full range during min-max normalization
    this_low = tiff.imread(low_path).transpose(1,2,0).astype(np.float32)
    this_GT = tiff.imread(GT_path).transpose(1,2,0).astype(np.float32)

    imshape = this_low.shape
    this_y, this_x, this_c = imshape[0], imshape[1], imshape[2]

    # if the shape of the current input image is not correct, frame it insize the largest x-y possible
    if largest_x != this_x or largest_y != this_y:
        low_zeroes = np.zeros(shape=(largest_y, largest_x, this_c))
        GT_zeroes = low_zeroes.copy()
        low_zeroes[:this_y, :this_x, :] = this_low
        GT_zeroes[:this_y, :this_x, :] = this_GT
        # delete what is currently in this_low and this_GT, since it will be overwritten
        del this_low
        del this_GT
        this_low = low_zeroes
        this_GT = GT_zeroes
        
    low.append(np.clip(min_max_norm(this_low, input_1pct, input_99pct), 0, 1))
    GT.append(np.clip(min_max_norm(this_GT, output_1pct, output_99pct), 0, 1))
    # delete lagging pointers and collect garbage for memory management
    del this_low
    del this_GT
    gc.collect()

### Removing channel dimension

In [13]:
low = np.stack(low, axis=-1).astype(np.float32).transpose(2, 3, 0, 1)
GT = np.stack(GT, axis=-1).astype(np.float32).transpose(2, 3, 0, 1)

ValueError: all input arrays must have the same shape

In [None]:
print(f"Frame{FRAMES_LOW}: {low.shape}, {low.dtype}")
print(f"Frame{FRAMES_GT}: {GT.shape}, {GT.dtype}")

In [None]:
C_x_S = np.cumprod(np.array(low.shape[0:2]))[1]
Y = low.shape[2]
X = low.shape[3]
print(C_x_S)
low_test = np.reshape(low, (C_x_S,Y,X))
GT_test = np.reshape(GT, (C_x_S,Y,X))

In [None]:
print(f"Frame{FRAMES_LOW}: {low_test.shape}, {low_test.dtype}")
print(f"Frame{FRAMES_GT}: {GT_test.shape}, {GT_test.dtype}")

In [None]:
raw_data = RawData.from_arrays(
    X=low,
    Y=GT,
    axes="SCYX"
)

In [None]:
n_patch_per_im = low.shape[-1] // PATCH_SIZE * 2

X, Y, XY_axes = create_patches(
    raw_data=raw_data,
    patch_size=(low.shape[1], PATCH_SIZE, PATCH_SIZE),
    patch_axes="CYX",
    patch_filter=no_background_patches(0),
    n_patches_per_image=low.shape[-1] // PATCH_SIZE * PATCH_MULTIPLE,
    normalization=None,
    save_file=os.path.join(INPUT_DATA_DIREC, f"{DATASET_NAME}_patch{PATCH_SIZE}_{n_patch_per_im}PpI.npz")
)

In [None]:
# from: https://nbviewer.org/url/csbdeep.bioimagecomputing.com/examples/denoising2D/1_datagen.ipynb
for i in range(2):
    plt.figure(figsize=(16,4))
    sl = slice(8*i, 8*(i+1)), 0
    plot_some(
        X[sl],
        Y[sl],
        title_list=[np.arange(sl[0].start,sl[0].stop)],
    )
    plt.show()
None;