In [None]:
import json
import os
import random
from datetime import datetime

import numpy as np

from csbdeep.io import load_training_data
from csbdeep.models import Config, CARE
from csbdeep.utils import axes_dict, plot_history, plot_some

from matplotlib import pyplot as plt
import tensorflow as tf
import tf2onnx
import onnx
import mlflow

In [None]:
# Point mlflow at the tracking server. An MLFLOW_TRACKING_URI exported in the
# environment wins; otherwise fall back to the local development server.
# (The previous hard-coded call silently overrode any environment setting.)
mlflow.set_tracking_uri(uri=os.environ.get("MLFLOW_TRACKING_URI", "http://127.0.0.1:8080"))

In [None]:
# Run configuration: everything a reader might tune lives here.
# Training data: an .npz of patch pairs consumed by csbdeep's load_training_data.
NPZ_PATH = "/mnt/d/data/processed/20250513_40I_denoising_7to40F/20250513_40I_denoising_7to40F_patch128_18PpI.npz"
# Artifacts (config.json, eval_sample.png, the .onnx export) are written to MODEL_SAVEDIR/MODEL_NAME.
MODEL_NAME = "test_model"
MODEL_SAVEDIR = "/mnt/d/models/CARE"
UNET_KERN_SIZE = 3
TRAIN_BATCH_SIZE = 16
RANDOM_STATE = 8888

# Seed every RNG that can influence training so RANDOM_STATE (also recorded in
# model_params) actually determines the run. Previously it was declared but unused.
# NOTE(review): fully bit-identical GPU runs may additionally require deterministic TF ops.
random.seed(RANDOM_STATE)
np.random.seed(RANDOM_STATE)
tf.random.set_seed(RANDOM_STATE)

In [None]:
# Snapshot of the run's paths and hyper-parameters in one dict.
# NOTE(review): not referenced elsewhere in the visible notebook; presumably meant for mlflow logging.
model_params = dict(
    npz_path=NPZ_PATH,
    name=MODEL_NAME,
    save_direc=MODEL_SAVEDIR,
    unet_kern_size=UNET_KERN_SIZE,
    train_batch_size=TRAIN_BATCH_SIZE,
    random_state=RANDOM_STATE,
)

### Training and Validation Data

In [None]:
# Load the saved patch pairs and hold out a 0.1 fraction for validation.
# `axes` is the axes string describing the array layout; later cells index into it.
_train_pairs, _val_pairs, axes = load_training_data(NPZ_PATH, verbose=True, validation_split=0.1)
X, Y = _train_pairs
X_val, Y_val = _val_pairs

In [None]:
# Index of the channel ('C') axis; its size gives the network's in/out channel counts.
c = axes_dict(axes)['C']
channels_in = X.shape[c]
channels_out = Y.shape[c]

### CARE Model

In [None]:
# CARE configuration: data layout and channel counts from the loaded arrays, plus the
# two tuned hyper-parameters; every other option keeps csbdeep's default.
config = Config(axes, channels_in, channels_out,
                unet_kern_size=UNET_KERN_SIZE,
                train_batch_size=TRAIN_BATCH_SIZE)

# Display the complete configuration as this cell's output.
vars(config)

In [None]:
# Build the CARE model named MODEL_NAME under the MODEL_SAVEDIR base directory.
model = CARE(config, MODEL_NAME, basedir=MODEL_SAVEDIR)

In [None]:
# Persist the training configuration alongside the model's other artifacts.
# The `with` block guarantees the file is flushed and closed; the previous
# `json.dump(..., open(...))` left the handle for the garbage collector to close.
JSON_CONFIG_PATH = os.path.join(MODEL_SAVEDIR, MODEL_NAME, "config.json")
os.makedirs(os.path.dirname(JSON_CONFIG_PATH), exist_ok=True)
with open(JSON_CONFIG_PATH, "w") as config_file:
    json.dump(vars(config), config_file, indent=2)

### Training the Model

In [None]:
# Train on (X, Y), monitoring the held-out pairs; the returned object's `.history`
# dict is plotted in the next cell.
validation_pairs = (X_val, Y_val)
history = model.train(X, Y, validation_data=validation_pairs)

In [None]:
# List every recorded metric name, then plot two groups of curves:
# losses, and MSE/MAE metrics (training vs. validation).
print(sorted(history.history))
plt.figure(figsize=(16, 5))
loss_keys = ['loss', 'val_loss']
metric_keys = ['mse', 'val_mse', 'mae', 'val_mae']
plot_history(history, loss_keys, metric_keys);

### Model Evaluation

In [None]:
# Run the underlying Keras network directly on the first five validation patches.
example_inputs = X_val[:5]
_P = model.keras_model.predict(example_inputs)

In [None]:
# Visual check: input / ground truth / prediction for the patches predicted above,
# saved to MODEL_SAVEDIR/MODEL_NAME/eval_sample.png.
# The number shown is taken from `_P`, so the figure and title stay consistent with the
# prediction cell (previously "5" was hard-coded in the title).
n_show = len(_P)

# Probabilistic CARE models emit twice as many output channels; only the first half
# (presumably the per-pixel mean) is displayed. A NEW name is used so that re-running
# this cell does not slice `_P` again (the old in-place rebinding halved it on every run).
pred_show = _P[..., :_P.shape[-1] // 2] if config.probabilistic else _P

plt.figure(figsize=(20, 12))
plot_some(X_val[:n_show], Y_val[:n_show], pred_show, pmax=99.5)
plt.suptitle(f'{n_show} example validation patches\n'
             'top row: input (source),  '
             'middle row: target (ground truth),  '
             'bottom row: predicted from source')
plt.savefig(os.path.join(MODEL_SAVEDIR, MODEL_NAME, "eval_sample.png"))

### Export to ONNX

In [None]:
# ONNX input shape: the training-array shape with the sample ('S') axis set to None,
# so the exported graph accepts any batch size. All other dimensions stay fixed at
# the training-patch sizes.
# NOTE(review): the spatial dims are therefore fixed too; confirm that inference
# only ever feeds patches of this exact size.
batch_dim = axes_dict(axes)['S']
input_shape = [*X.shape]
input_shape[batch_dim] = None
print(input_shape)

In [None]:
# One float32 graph input named 'patch', with the shape computed above.
input_signature = [tf.TensorSpec(input_shape, tf.float32, name='patch')]

In [None]:
# Convert the trained Keras network to an ONNX ModelProto at opset 13;
# the converter's second return value is discarded.
onnx_model, _ = tf2onnx.convert.from_keras(model.keras_model, input_signature=input_signature, opset=13)

In [None]:
# Validate the converted graph before writing it, so a malformed export fails here
# rather than later at inference time. The model is saved as MODEL_SAVEDIR/MODEL_NAME/MODEL_NAME.onnx.
ONNX_PATH = os.path.join(MODEL_SAVEDIR, MODEL_NAME, f"{MODEL_NAME}.onnx")
onnx.checker.check_model(onnx_model)
onnx.save(onnx_model, ONNX_PATH)