In [1]:
import os
from datetime import datetime
import json
import logging

from csbdeep.io import load_training_data
from csbdeep.models import Config, CARE
from csbdeep.utils import axes_dict, plot_history, plot_some

from matplotlib import pyplot as plt
import tensorflow as tf
import tf2onnx
import onnx
import numpy as np

import mlflow
from mlflow.models import infer_signature
from mlflow.pyfunc import log_model as pyfunc_log_model
from mlflow.onnx import log_model as onnx_log_model

from flame.error import CAREInferenceError
from flame.utils import _compress_dict_fields

2025-06-19 03:40:49.841855: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-06-19 03:40:49.854133: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1750329649.862331 2028509 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1750329649.864740 2028509 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1750329649.871048 2028509 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking 

In [2]:
# Explicitly pass False to tf.config.run_functions_eagerly for this session.
tf.config.run_functions_eagerly(False)

In [3]:
# Name of the MLflow experiment that every run from this notebook is filed under.
EXPERIMENT_NAME = "CARE Denoising 1 Channel"

# Point the MLflow client at the local tracking server, then select the experiment.
# The set_experiment call stays last so its return value is shown as the cell output.
mlflow.set_tracking_uri(uri="http://127.0.0.1:5050")
mlflow.set_experiment(EXPERIMENT_NAME)

<Experiment: artifact_location='mlflow-artifacts:/859033271774544446', creation_time=1748986871785, experiment_id='859033271774544446', last_update_time=1748986871785, lifecycle_stage='active', name='CARE Denoising 1 Channel', tags={}>

In [4]:
# Dataset directory; patch_config.json is expected directly inside it.
DATA_DIREC = "/mnt/d/data/processed/20250618_224I_denoising_5to40F"
PATCH_CONFIG_JSON = os.path.join(DATA_DIREC, "patch_config.json")
# Root under which the per-model / per-run output directories are created.
SAVE_DIREC = "/mnt/d/models/"
# Forwarded to the CSBDeep Config as `unet_kern_size` and `train_batch_size`.
UNET_KERN_SIZE = 3
TRAIN_BATCH_SIZE = 16
# NOTE(review): INFER_BATCH_SIZE is not referenced by any cell visible here.
INFER_BATCH_SIZE = 1
# NOTE(review): RANDOM_STATE is only recorded in Train_Config; no visible cell
# seeds NumPy or TensorFlow with it -- confirm whether seeding was intended.
RANDOM_STATE = 8888

In [5]:
# Directory that receives one timestamped log file per notebook session.
LOG_DIREC = os.path.join(os.getcwd(), "logs")
# logging.basicConfig(filename=...) opens the file immediately, so a missing
# directory would raise FileNotFoundError; create it up front.
os.makedirs(LOG_DIREC, exist_ok=True)

# Named logger used by the cells below; basicConfig configures the root logger
# to write DEBUG-and-above records to the timestamped UTF-8 log file.
logger = logging.getLogger("main")
logging.basicConfig(
    filename=os.path.join(LOG_DIREC, f"{datetime.now().strftime('%Y%m%d-%H%M%S')}_logger.log"),
    encoding="utf-8",
    level=logging.DEBUG
)

In [6]:
# Fail fast if the dataset directory or its patch config file is missing.
# The messages name the offending path, matching the NPZ-path assertion later on.
assert os.path.isdir(DATA_DIREC), f"Data directory {DATA_DIREC} is not a directory"
assert os.path.isfile(PATCH_CONFIG_JSON), f"Patch config {PATCH_CONFIG_JSON} is not a file"

### Creating training config by building on patch_config

In [7]:
# Load the patch config JSON into `config_json`; later cells extend this dict
# with training and model sections.
try:
    # The context manager closes the file handle even when json.load raises.
    with open(PATCH_CONFIG_JSON, 'r') as patch_config_file:
        config_json = json.load(patch_config_file)
    logger.info(f"Successfully loaded patch config from {PATCH_CONFIG_JSON}")
except Exception as e:
    logger.error(f"Could not load patch config json from {PATCH_CONFIG_JSON}.\nERROR: {e}")
    # Chain explicitly so the original exception is kept as __cause__.
    raise CAREInferenceError(f"Could not load patch config json from {PATCH_CONFIG_JSON}.\nERROR: {e}") from e

In [None]:
# Build the model name from the input/output frame counts in the dataset config,
# e.g. "FLAME_CARE_5to40F" for 5 input frames and 40 output frames.
try:
    # Read the values outside the f-string: reusing the enclosing quote character
    # inside a replacement field is a SyntaxError on Python versions before 3.12.
    n_frames_in = config_json['FLAME_Dataset']['input']['n_frames']
    n_frames_out = config_json['FLAME_Dataset']['output']['n_frames']
    MODEL_NAME = f"FLAME_CARE_{n_frames_in}to{n_frames_out}F"
    logger.info(f"Training a model with NAME: '{MODEL_NAME}'...")
except Exception as e:
    logger.error(f"Failed to dynamically load model name.\nERROR: {e}")
    raise CAREInferenceError(f"Failed to dynamically load model name.\nERROR: {e}") from e


In [9]:
# Derive the run identifier and make sure the per-run model directory exists.
try:
    # RUN_ID is the number of rows mlflow.search_runs() currently returns.
    RUN_ID = len(mlflow.search_runs())
    MODEL_DIREC = os.path.join(SAVE_DIREC, MODEL_NAME, str(RUN_ID))
    # exist_ok=True is presumed safe because the run count should only grow once
    # a training run has started or finished -- TODO confirm.
    os.makedirs(MODEL_DIREC, exist_ok=True)
    logger.info(f"Training run id is {RUN_ID}.")
    logger.info(f"Model saving to {MODEL_DIREC}")
except Exception as e:
    failure_message = f"Failed to load run id and/or set up model save directory.\nERROR: {e}"
    logger.error(failure_message)
    raise CAREInferenceError(failure_message)

In [10]:
# Attach the training section to the config: where the patch archive lives,
# where the model is written, and the hyperparameters chosen above.
config_json['Train_Config'] = dict(
    npz_path=os.path.join(DATA_DIREC, config_json['Patch_Config']['name']),
    name=MODEL_NAME,
    model_direc=MODEL_DIREC,
    unet_kern_size=UNET_KERN_SIZE,
    train_batch_size=TRAIN_BATCH_SIZE,
    random_state=RANDOM_STATE,
)

In [11]:
# Show the assembled configuration (dataset stats plus Train_Config) as the cell output.
config_json

{'FLAME_Dataset': {'name': '20250618_224I_denoising_5to40F',
  'type': 'denoising',
  'image_shapes': [[1, 3, 1200, 1200], [1, 3, 1200, 1196]],
  'input': {'n_frames': 5,
   'pixel_mean': 378.72098887921106,
   'pixel_min': 0,
   'pixel_max': 28400,
   'pixel_p1pct': [0.0, 0.0, 0.0],
   'pixel_1pct': [0.0, 0.0, 0.0],
   'pixel_5pct': [0.0, 0.0, 0.0],
   'pixel_95pct': [832.0, 1000.0, 1400.0],
   'pixel_99pct': [1400.0, 1800.0, 2200.0],
   'pixel_99p9pct': [2400.0, 3141.0, 4165.0],
   'pixel_std': 420.2389718127719},
  'output': {'n_frames': 40,
   'pixel_mean': 3022.801186980542,
   'pixel_min': 0,
   'pixel_max': 206600,
   'pixel_p1pct': [0.0, 0.0, 0.0],
   'pixel_1pct': [0.0, 0.0, 200.0],
   'pixel_5pct': [200.0, 375.0, 710.0],
   'pixel_95pct': [5800.0, 7545.0, 9653.0],
   'pixel_99pct': [9000.0, 12600.0, 15400.0],
   'pixel_99p9pct': [17200.0, 22555.0, 31400.0],
   'pixel_std': 2790.1654790982434},
  'image_ids': [11,
   12,
   13,
   14,
   15,
   16,
   17,
   18,
   19,
   20,


In [None]:
# Confirm the training archive exists before trying to load it.
NPZ_PATH = config_json['Train_Config']['npz_path']
assert os.path.isfile(NPZ_PATH), f"NPZ path {NPZ_PATH} is not a file"

### Training and Validation Data

In [None]:
# Load the patch archive, holding out 10% of it as the validation set.
# `axes` is the axis string returned by the loader and reused by later cells.
load_kwargs = dict(validation_split=0.1, verbose=True)
(X, Y), (X_val, Y_val), axes = load_training_data(NPZ_PATH, **load_kwargs)

In [None]:
# Locate the channel axis, then read the channel counts of inputs and targets.
c = axes_dict(axes)['C']
channels_in = X.shape[c]
channels_out = Y.shape[c]

### CARE Model

In [None]:
# Identify this model/run inside the shared config document.
config_json['CARE_Model'] = {
    'name': MODEL_NAME,
    'experiment_name': EXPERIMENT_NAME,
    'run_id': RUN_ID,
    'base_dir': SAVE_DIREC,
    'run_dir': MODEL_DIREC
}
# Keyword arguments that are unpacked into csbdeep's Config(...) in the next cell.
config_json['CARE_Model']['CSBDeep_Config'] = {
    'axes': axes,
    'n_channel_in': channels_in,
    'n_channel_out': channels_out,
    'probabilistic': False, # default from CSBDeep
    # This key used to appear twice in the literal (False, then True). Python keeps
    # the last value of a duplicated key, so True was always the effective setting;
    # it is now stated once, in the key's original position.
    'allow_new_parameters': True,
    'unet_kern_size': UNET_KERN_SIZE,
    'train_batch_size': TRAIN_BATCH_SIZE,
    'unet_input_shape': tuple(config_json['Patch_Config']['patch_shape']),
}

In [None]:
# Build the CSBDeep Config from the keyword arguments collected above.
csbdeep_kwargs = config_json['CARE_Model']['CSBDeep_Config']
config = Config(**csbdeep_kwargs)

# Record the resolved configuration next to the requested one. vars() returns the
# instance's own attribute dict, not a copy.
config_json['CARE_Model']['Model_Arch'] = vars(config)

In [None]:
# Persist the full config document inside the run directory.
JSON_CONFIG_PATH = os.path.join(MODEL_DIREC, "model_config.json")
# The context manager flushes and closes the file deterministically; this path is
# read back later when the inference session is constructed.
with open(JSON_CONFIG_PATH, 'w') as model_config_file:
    json.dump(config_json, model_config_file)

### Training the Model

In [None]:
# Instantiate CARE with the run id as the model name, nested under
# <SAVE_DIREC>/<MODEL_NAME>.
model_basedir = os.path.join(SAVE_DIREC, MODEL_NAME)
model = CARE(config, str(RUN_ID), basedir=model_basedir)

In [None]:
# Train on (X, Y), validating on the held-out (X_val, Y_val) split; the returned
# history object is used below for plotting and for the metrics logged to MLflow.
history = model.train(X, Y, validation_data=(X_val, Y_val))

In [None]:
# model.keras_model.save(os.path.join(MODEL_DIREC, 'saved_model.keras'))

### Some quick visualizations

In [None]:
# List the recorded metric names, then plot the loss and error curves and save
# the figure into the run directory.
print(sorted(history.history))
plt.figure(figsize=(16, 5))
loss_keys = ['loss', 'val_loss']
error_keys = ['mse', 'val_mse', 'mae', 'val_mae']
plot_history(history, loss_keys, error_keys);
plt.savefig(os.path.join(MODEL_DIREC, "training_history.png"))

### Model Evaluation from Validation Set

In [None]:
# Run the underlying Keras model on the first five validation patches; the result
# is plotted in the next cell.
_P = model.keras_model.predict(X_val[:5])

In [None]:
# Compare source, ground truth and prediction for the five sample patches and
# save the figure into the run directory.
plt.figure(figsize=(20, 12))
if config.probabilistic:
    # Keep only the first half of the last axis of the prediction.
    half_width = _P.shape[-1] // 2
    _P = _P[..., :half_width]
plot_some(X_val[:5], Y_val[:5], _P, pmax=99.5)
plt.suptitle(
    '5 example validation patches\n'
    'top row: input (source),  '
    'middle row: target (ground truth),  '
    'bottom row: predicted from source'
);

plt.savefig(os.path.join(MODEL_DIREC, "val_set_predict_sample.png"))

### Logging Model in MLFlow Database

In [None]:
# Pick the epoch with the lowest validation loss and read the MAE / MSE recorded
# at that same epoch.
val_loss_curve = history.history['val_loss']
val_loss = np.min(val_loss_curve)
idx = val_loss_curve.index(val_loss)
val_mae, val_mse = (history.history[metric][idx] for metric in ('val_mae', 'val_mse'))

In [None]:
# # with mlflow.start_run():

# mlflow.start_run()

# logger.info(f"Run started")
# print("Run started")

# # Log the hyperparameters
# mlflow.log_params(config_json)

# # Log the validation performance metrics
# mlflow.log_metric("val_loss", np.min(val_loss))
# mlflow.log_metric("val_mae", np.min(val_mae))
# mlflow.log_metric("val_mse", np.min(val_mse))

# # infer the model signature
# signature = infer_signature(X, pyfunc_model.predict(X))

# # Log the model
# model_info = pyfunc_log_model(
#     artifact_path='care model',
#     python_model=pyfunc_model,
#     # model_code_path=os.path.join(os.getcwd(), "flame", "model.py"),
#     code_paths=[
#         os.path.join(os.getcwd(), "flame", "__init__.py"),
#         os.path.join(os.getcwd(), "flame", "mlflow_pyfunc.py"),
#         os.path.join(MODEL_DIREC, "model_config.json"),
#         os.path.join(MODEL_DIREC, "config.json")
#     ],
#     conda_env=os.path.join(os.getcwd(), "environment.yml"),
#     signature=signature,
#     input_example=X
# )
    
# logger.info(f"Run finished.")
# print("Run finished.")
# mlflow.end_run()

### Export to ONNX

In [None]:
# Start from the training array's shape and replace the sample axis with None so
# the exported model accepts any batch size.
batch_dim = axes_dict(axes)['S']
input_shape = [*X.shape]
input_shape[batch_dim] = None
print(input_shape)

In [None]:
# Single float32 input named 'patch', using the batch-agnostic shape built above.
patch_spec = tf.TensorSpec(shape=input_shape, dtype=tf.float32, name='patch')
input_signature = [patch_spec]

In [None]:
# Convert the trained Keras model to ONNX at opset 13; the second return value
# is not used.
onnx_model, _ = tf2onnx.convert.from_keras(
    model.keras_model,
    input_signature=input_signature,
    opset=13,
)

In [None]:
# Write the ONNX model into the run directory as <MODEL_NAME>.onnx.
onnx_filename = f"{MODEL_NAME}.onnx"
ONNX_PATH = os.path.join(MODEL_DIREC, onnx_filename)
onnx.save(onnx_model, ONNX_PATH)

In [None]:
from flame.engine import CAREInferenceSession

In [None]:
# Build an inference session around the exported ONNX file.
# NOTE(review): the same JSON file is passed as both the model config and the
# dataset config -- presumably because model_config.json also contains the
# FLAME_Dataset section; confirm against CAREInferenceSession.
engine = CAREInferenceSession(
    model_path=ONNX_PATH,
    model_config_path=JSON_CONFIG_PATH,
    dataset_config_path=JSON_CONFIG_PATH
)

In [None]:
# Show the training array's shape as the cell output.
X.shape

In [None]:
# Record the training run in MLflow: hyperparameters, best-epoch validation
# metrics, the ONNX model, the local `flame` package and the model config JSON.
#
# The context manager ends the run even when a logging step raises. A bare
# start_run()/end_run() pair would leave the run active after a failure, and the
# next mlflow.start_run() in the same kernel would then fail.
with mlflow.start_run():
    logger.info("Run started")
    print("Run started")

    # Log the hyperparameters
    # NOTE(review): config_json is a nested dict; confirm log_params handles the
    # nested values as intended (flame.utils._compress_dict_fields is imported
    # at the top of the notebook but never used).
    mlflow.log_params(config_json)

    # Log the validation performance metrics. The three values are already scalars
    # taken at the best-val_loss epoch, so np.min leaves them unchanged.
    mlflow.log_metric("val_loss", np.min(val_loss))
    mlflow.log_metric("val_mae", np.min(val_mae))
    mlflow.log_metric("val_mse", np.min(val_mse))

    # One training patch (batch axis kept) serves as both the logged input example
    # and the input from which the model signature is inferred.
    input_example = X[[0], ...]

    model_info = onnx_log_model(
        onnx_model=onnx_model,
        artifact_path="model",
        conda_env=os.path.join(os.getcwd(), "environment.yml"),
        input_example=input_example,

        # If given, create a model version under registered_model_name,
        # also creating a registered model if one with the given name does not exist.
        # registered_model_name=_,

        metadata=config_json,
        signature=infer_signature(input_example, engine.predict(input_example)),
        onnx_execution_providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
    )

    # Attach the local `flame` package directory to the run.
    mlflow.log_artifacts(
        local_dir=os.path.join(os.getcwd(), 'flame'),
        artifact_path="flame"
    )

    # Attach the exact model config JSON written for this run.
    mlflow.log_artifact(
        local_path=JSON_CONFIG_PATH,
        artifact_path="model_config"
    )