In [1]:
import os
import logging
import json
from datetime import datetime

import tensorflow as tf
from csbdeep.models import CARE
from csbdeep.data import RawData, create_patches, no_background_patches
from csbdeep.utils import plot_some
from natsort import natsorted
import tifffile as tiff
import numpy as np
from matplotlib import pyplot as plt

from flame.utils import min_max_norm

2025-05-23 09:53:38.859287: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-05-23 09:53:41.808487: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1748019222.433881 1614179 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1748019222.596107 1614179 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1748019224.982801 1614179 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking 

In [2]:
assert len(tf.config.list_physical_devices("GPU")) > 0

In [None]:
DATASET_NAME = "20250522_192I_denoising_5to40F"
DATASET_DIREC = "/mnt/d/code/Balu_CARE/datasets"
DATASET_JSON_PATH = os.path.join(DATASET_DIREC, f"{DATASET_NAME}.json")
INPUT_DATA_DIREC = os.path.join("/mnt/d/data/processed/", DATASET_NAME)
assert os.path.isdir(INPUT_DATA_DIREC), f"Directory not found: {INPUT_DATA_DIREC}"
TRAIN_DATA_DIREC = os.path.join(INPUT_DATA_DIREC, "train")
assert os.path.isdir(TRAIN_DATA_DIREC), f"Directory not found: {TRAIN_DATA_DIREC}"
FRAMES_LOW = 5
FRAMES_GT = 40
PATCH_SIZE = 128
PATCH_MULTIPLE = 4

AssertionError: 

In [17]:
logger = logging.getLogger("main")
logging.basicConfig(
    filename=f"{datetime.now().strftime('%Y%m%d-%H%M%S')}_logger.log",
    encoding="utf-8",
    level=logging.DEBUG
)

### Getting paths of Low Frame and Ground Truth Frame Accumulations

In [18]:
low_paths = []
GT_paths = []
for root, dirs, files in os.walk(TRAIN_DATA_DIREC):
    for f in files:
        if f"frames{FRAMES_LOW}" in f:
            low_paths.append(os.path.join(root, f))
        elif f"frames{FRAMES_GT}" in f:
            GT_paths.append(os.path.join(root, f))

low_paths = natsorted(low_paths)
GT_paths = natsorted(GT_paths)

In [20]:
print(f"Found {len(low_paths)} low- and high-frame accumulation images for this dataset")
logger.info(f"Found {len(low_paths)} low- and high-frame accumulation images for this dataset")

Found 0 low- and high-frame accumulation images for this dataset


In [9]:
for f1, f2 in zip(low_paths, GT_paths):
    id1 = os.path.basename(f1).split('_')[0]
    id2 = os.path.basename(f2).split('_')[0]
    assert id1 == id2


In [10]:
ds_config = json.load(open(DATASET_JSON_PATH, 'r'))
input_config = ds_config['FLAME_Dataset']['input']
output_config = ds_config['FLAME_Dataset']['output']

In [11]:
input_1pct, input_99pct = np.array(input_config['pixel_1pct']), np.array(input_config['pixel_99pct'])
output_1pct, output_99pct = np.array(output_config['pixel_1pct']), np.array(output_config['pixel_99pct'])

In [12]:
low = [min_max_norm(
    tiff.imread(path).transpose(1, 2, 0).astype(np.float64), 
    input_1pct, input_99pct
) for path in low_paths]

GT = [min_max_norm(
    tiff.imread(path).transpose(1, 2, 0).astype(np.float64), 
    output_1pct, output_99pct
) for path in GT_paths]

In [None]:
low = np.stack(low, axis=-1).astype(np.float32).transpose(3, 2, 0, 1)
GT = np.stack(GT, axis=-1).astype(np.float32).transpose(3, 2, 0, 1)

In [None]:
print(f"Frame7: {low.shape}, {low.dtype}")
print(f"Frame40: {GT.shape}, {GT.dtype}")

In [None]:
raw_data = RawData.from_arrays(
    X=low,
    Y=GT,
    axes="SCYX"
)

In [None]:
n_patch_per_im = low.shape[-1] // PATCH_SIZE * 2

X, Y, XY_axes = create_patches(
    raw_data=raw_data,
    patch_size=(low.shape[1], PATCH_SIZE, PATCH_SIZE),
    patch_axes="CYX",
    patch_filter=no_background_patches(0),
    n_patches_per_image=low.shape[-1] // PATCH_SIZE * PATCH_MULTIPLE,
    normalization=None,
    save_file=os.path.join(INPUT_DATA_DIREC, f"{DATASET_NAME}_patch{PATCH_SIZE}_{n_patch_per_im}PpI.npz")
)

In [None]:
# from: https://nbviewer.org/url/csbdeep.bioimagecomputing.com/examples/denoising2D/1_datagen.ipynb
for i in range(2):
    plt.figure(figsize=(16,4))
    sl = slice(8*i, 8*(i+1)), 0
    plot_some(
        X[sl],
        Y[sl],
        title_list=[np.arange(sl[0].start,sl[0].stop)],
    )
    plt.show()
None;