In [None]:
import json
import logging
import os
import re
from datetime import datetime

import tensorflow as tf
from csbdeep.models import CARE
from csbdeep.data import RawData, create_patches, no_background_patches
from csbdeep.utils import plot_some
from natsort import natsorted
import tifffile as tiff
import numpy as np
from matplotlib import pyplot as plt

from flame.utils import min_max_norm

In [None]:
# Fail fast, with a readable message, if TensorFlow cannot see any GPU.
assert len(tf.config.list_physical_devices("GPU")) > 0, "No GPU visible to TensorFlow"

In [None]:
# Dataset identity and locations. The two root directories can be overridden with
# environment variables so the notebook can run on another machine; the defaults
# reproduce the original hard-coded paths.
DATASET_NAME = "20250513_40I_denoising_7to40F"
DATASET_DIREC = os.environ.get("CARE_DATASET_DIREC", "/mnt/d/code/Balu_CARE/datasets")
DATASET_JSON_PATH = os.path.join(DATASET_DIREC, f"{DATASET_NAME}.json")
PROCESSED_DATA_ROOT = os.environ.get("CARE_PROCESSED_DATA_ROOT", "/mnt/d/data/processed/")
INPUT_DATA_DIREC = os.path.join(PROCESSED_DATA_ROOT, DATASET_NAME)
TRAIN_DATA_DIREC = os.path.join(INPUT_DATA_DIREC, "train")

# Frame counts embedded in file names ("frames<N>"): low-frame inputs vs. ground truth.
FRAMES_LOW = 7
FRAMES_GT = 40
# Edge length (pixels) of the square YX window of each training patch.
PATCH_SIZE = 128
# Patches sampled per image = (image width // PATCH_SIZE) * PATCH_MULTIPLE.
PATCH_MULTIPLE = 4

In [None]:
# Gather low-frame (input) and high-frame (ground-truth) files under the train split.
# The frame count must not be followed by another digit, so "frames7" does not also
# match e.g. "frames70".
low_frames_re = re.compile(rf"frames{FRAMES_LOW}(?!\d)")
gt_frames_re = re.compile(rf"frames{FRAMES_GT}(?!\d)")

low_paths = []
GT_paths = []
for root, _dirs, files in os.walk(TRAIN_DATA_DIREC):
    for f in files:
        if low_frames_re.search(f):
            low_paths.append(os.path.join(root, f))
        elif gt_frames_re.search(f):
            GT_paths.append(os.path.join(root, f))

# Natural sort so index i of both lists refers to the same sample.
low_paths = natsorted(low_paths)
GT_paths = natsorted(GT_paths)

In [None]:
# Sanity-check the input/ground-truth pairing. zip() silently stops at the shorter
# list, so compare lengths first, then require each pair to share its sample ID
# (the file-name prefix before the first "_").
assert len(low_paths) > 0, f"No frames{FRAMES_LOW} files found under {TRAIN_DATA_DIREC}"
assert len(low_paths) == len(GT_paths), (
    f"{len(low_paths)} low-frame files vs {len(GT_paths)} ground-truth files"
)
for f1, f2 in zip(low_paths, GT_paths):
    id1 = os.path.basename(f1).split('_')[0]
    id2 = os.path.basename(f2).split('_')[0]
    assert id1 == id2, f"Mismatched pair: {f1} vs {f2}"


In [None]:
# Load the dataset JSON. Its 'input' and 'output' sections hold the pixel percentile
# statistics (pixel_1pct / pixel_99pct) used for normalization. `with` ensures the
# file handle is closed.
with open(DATASET_JSON_PATH, 'r') as fp:
    ds_config = json.load(fp)
input_config = ds_config['FLAME_Dataset']['input']
output_config = ds_config['FLAME_Dataset']['output']

In [None]:
# Normalization bounds (1st / 99th pixel percentiles) for inputs and targets.
# NOTE(review): presumably one value per channel — confirm against the dataset JSON.
input_1pct = np.array(input_config['pixel_1pct'])
input_99pct = np.array(input_config['pixel_99pct'])
output_1pct = np.array(output_config['pixel_1pct'])
output_99pct = np.array(output_config['pixel_99pct'])

In [None]:
def _read_normalized(path, lo, hi):
    """Read a TIFF stack, move its first axis to the end, cast to float64, and
    min-max normalize it with the given lower/upper bounds."""
    # NOTE(review): assumes stacks are stored channel-first (C, Y, X), so the
    # transpose yields (Y, X, C) and the bounds broadcast over channels — confirm.
    image = tiff.imread(path).transpose(1, 2, 0).astype(np.float64)
    return min_max_norm(image, lo, hi)


low = [_read_normalized(path, input_1pct, input_99pct) for path in low_paths]
GT = [_read_normalized(path, output_1pct, output_99pct) for path in GT_paths]

In [None]:
# Stack the per-image (Y, X, C) arrays along a new leading sample axis, downcast
# to float32, and reorder to (S, C, Y, X) — the "SCYX" layout given to RawData.
low = np.stack(low, axis=0).astype(np.float32).transpose(0, 3, 1, 2)
GT = np.stack(GT, axis=0).astype(np.float32).transpose(0, 3, 1, 2)

In [None]:
# Report stacked shapes/dtypes; labels follow the configured frame counts instead
# of hard-coded numbers, so they stay correct if FRAMES_LOW / FRAMES_GT change.
print(f"Frame{FRAMES_LOW}: {low.shape}, {low.dtype}")
print(f"Frame{FRAMES_GT}: {GT.shape}, {GT.dtype}")

In [None]:
# Wrap the paired (S, C, Y, X) stacks so csbdeep can sample matching input/target patches.
raw_data = RawData.from_arrays(X=low, Y=GT, axes="SCYX")

In [None]:
# Patches sampled per image: (image width // PATCH_SIZE) * PATCH_MULTIPLE.
# Computed once and used for both sampling and the save-file name, so the name
# always reports the count actually generated.
n_patch_per_im = low.shape[-1] // PATCH_SIZE * PATCH_MULTIPLE
assert n_patch_per_im > 0, f"Image width {low.shape[-1]} is smaller than PATCH_SIZE={PATCH_SIZE}"

# Each patch spans all channels (low.shape[1]) and a PATCH_SIZE x PATCH_SIZE YX window.
# normalization=None because the stacks were already min-max normalized above.
# NOTE(review): no_background_patches(0) presumably keeps (almost) every patch — confirm intent.
X, Y, XY_axes = create_patches(
    raw_data=raw_data,
    patch_size=(low.shape[1], PATCH_SIZE, PATCH_SIZE),
    patch_axes="CYX",
    patch_filter=no_background_patches(0),
    n_patches_per_image=n_patch_per_im,
    normalization=None,
    save_file=os.path.join(INPUT_DATA_DIREC, f"{DATASET_NAME}_patch{PATCH_SIZE}_{n_patch_per_im}PpI.npz")
)

In [None]:
# Preview the first patches: X (input) and Y (ground truth) rows, one figure per group.
# Adapted from: https://nbviewer.org/url/csbdeep.bioimagecomputing.com/examples/denoising2D/1_datagen.ipynb
n_preview_figs = 2   # number of figures to draw
n_per_fig = 8        # patches shown per figure
for i in range(n_preview_figs):
    start = n_per_fig * i
    stop = min(n_per_fig * (i + 1), len(X))
    if start >= stop:
        break  # fewer patches than requested; nothing left to show
    # Second index 0 selects the first entry of axis 1 (presumably the channel, per XY_axes).
    sl = slice(start, stop), 0
    plt.figure(figsize=(16, 4))
    plot_some(
        X[sl],
        Y[sl],
        title_list=[np.arange(start, stop)],
    )
    plt.show()