In [None]:
import os
from datetime import datetime
import json
import logging

from csbdeep.io import load_training_data
from csbdeep.models import Config, CARE
from csbdeep.utils import axes_dict, plot_history, plot_some

from matplotlib import pyplot as plt
import tensorflow as tf
import tf2onnx
import onnx
import mlflow

from flame.error import CAREInferenceError

In [None]:
# Name of the MLflow experiment this notebook records under.
EXPERIMENT_NAME = "CARE Denoising 1 Channel"

# Point the MLflow client at the local tracking server, then make the
# experiment above the active one.
mlflow.set_tracking_uri(uri="http://127.0.0.1:8080")
mlflow.set_experiment(EXPERIMENT_NAME)

In [None]:
# --- Training hyper-parameters (recorded in the saved training config) ---
RANDOM_STATE = 8888
TRAIN_BATCH_SIZE = 16
UNET_KERN_SIZE = 3

# --- Filesystem locations ---
# Root directory under which per-model / per-run folders are created.
SAVE_DIREC = "/mnt/d/models/"
# Directory holding the processed patches and their patch_config.json.
DATA_DIREC = "/mnt/d/data/processed/20250527_112I_denoising_5to40F"
PATCH_CONFIG_JSON = os.path.join(DATA_DIREC, "patch_config.json")

In [None]:
# Route log records (DEBUG and above) to a timestamped, UTF-8 encoded log
# file; the relative filename means it is created in the working directory.
logging.basicConfig(
    level=logging.DEBUG,
    encoding="utf-8",
    filename=datetime.now().strftime("%Y%m%d-%H%M%S") + "_logger.log",
)

# Named logger used by every cell below.
logger = logging.getLogger("main")

In [None]:
# Fail fast, with a message naming the offending path, if the data directory
# or the patch config JSON is missing (same message style as the NPZ check).
assert os.path.isdir(DATA_DIREC), f"Data directory {DATA_DIREC} is not a directory"
assert os.path.isfile(PATCH_CONFIG_JSON), f"Patch config path {PATCH_CONFIG_JSON} is not a file"

### Creating training config by building on patch_config

In [None]:
# Load the patch configuration written alongside the data; every later cell
# extends this dictionary with training- and model-specific sections.
try:
    # Context manager guarantees the file handle is closed even if parsing fails.
    with open(PATCH_CONFIG_JSON, 'r') as patch_config_file:
        config_json = json.load(patch_config_file)
    logger.info(f"Successfully loaded patch config from {PATCH_CONFIG_JSON}")
except Exception as e:
    logger.error(f"Could not load patch config json from {PATCH_CONFIG_JSON}.\nERROR: {e}")
    # Chain the original exception so the root cause stays in the traceback.
    raise CAREInferenceError(f"Could not load patch config json from {PATCH_CONFIG_JSON}.\nERROR: {e}") from e

In [None]:
# Derive the model name from the input/output frame counts stored in the
# patch config, in the form "FLAME_CARE_<in>F-<out>F".
try:
    # Single-quoted keys inside double-quoted f-strings: reusing the outer
    # quote character inside the braces is a SyntaxError before Python 3.12.
    MODEL_NAME = (
        "FLAME_CARE_"
        f"{config_json['FLAME_Dataset']['input']['n_frames']}F"
        "-"
        f"{config_json['FLAME_Dataset']['output']['n_frames']}F"
    )
    logger.info(f"Training a model with NAME: '{MODEL_NAME}'...")
except Exception as e:
    logger.error(f"Failed to dynamically load model name.\nERROR: {e}")
    # Chain the original exception so the root cause stays in the traceback.
    raise CAREInferenceError(f"Failed to dynamically load model name.\nERROR: {e}") from e


In [None]:
# Pick the next run index and create the directory this run saves into:
# <SAVE_DIREC>/<MODEL_NAME>/<RUN_ID>.
try:
    # The first positional parameter of mlflow.search_runs is `experiment_ids`,
    # so passing MODEL_NAME there queried a non-existent experiment id. Count
    # the runs recorded under the configured experiment instead.
    # NOTE(review): assumes the run index should be the number of runs in
    # EXPERIMENT_NAME — confirm this matches the intended numbering scheme.
    RUN_ID = mlflow.search_runs(experiment_names=[EXPERIMENT_NAME]).shape[0]
    MODEL_DIREC = os.path.join(SAVE_DIREC, MODEL_NAME, str(RUN_ID))
    # exist_ok=True is tolerated on the assumption that RUN_ID only advances
    # once a run has been recorded, so a re-run before that reuses the folder.
    os.makedirs(MODEL_DIREC, exist_ok=True)
    logger.info(f"Training run id is {RUN_ID}.")
    logger.info(f"Model saving to {MODEL_DIREC}")
except Exception as e:
    logger.error(f"Failed to load run id and/or set up model save directory.\nERROR: {e}")
    # Chain the original exception so the root cause stays in the traceback.
    raise CAREInferenceError(f"Failed to load run id and/or set up model save directory.\nERROR: {e}") from e

In [None]:
# Attach the training-stage settings to the shared config so they are saved
# together with the patch settings they were built from.
config_json['Train_Config'] = dict(
    npz_path=os.path.join(DATA_DIREC, config_json['Patch_Config']['name']),
    name=MODEL_NAME,
    model_direc=MODEL_DIREC,
    unet_kern_size=UNET_KERN_SIZE,
    train_batch_size=TRAIN_BATCH_SIZE,
    random_state=RANDOM_STATE,
)

In [None]:
# Pull the patch archive path back out of the config and make sure the file
# is actually there before attempting to load it.
NPZ_PATH = config_json['Train_Config']['npz_path']
assert os.path.isfile(NPZ_PATH), f"NPZ path {NPZ_PATH} is not a file"

### Training and Validation Data

In [None]:
# Load the patch archive with validation_split=0.1; the loader returns the
# training pair, the validation pair and the axes string describing them.
training_data = load_training_data(NPZ_PATH, validation_split=0.1, verbose=True)
(X, Y), (X_val, Y_val), axes = training_data

In [None]:
# Index of the channel axis, used to read how many channels the source (X)
# and target (Y) patches carry.
c = axes_dict(axes)['C']
channels_in = X.shape[c]
channels_out = Y.shape[c]

### CARE Model

In [None]:
# Describe the model being trained: where it lives, which MLflow experiment
# and run it belongs to, and the keyword arguments handed to CSBDeep's Config.
config_json['CARE_Model'] = {
    'name': MODEL_NAME,
    'experiment_name': EXPERIMENT_NAME,
    'run_id': RUN_ID,
    'base_dir': SAVE_DIREC,
    'run_dir': MODEL_DIREC,
    'CSBDeep_Config': {
        'axes': axes,
        'n_channel_in': channels_in,
        'n_channel_out': channels_out,
        'probabilistic': False,  # default from CSBDeep
        'allow_new_parameters': False,  # default from CSBDeep
        'unet_kern_size': UNET_KERN_SIZE,
        'train_batch_size': TRAIN_BATCH_SIZE,
    },
}

In [None]:
# Build the CSBDeep Config from the recorded keyword arguments, then store
# the attributes of the resulting object back in the shared config.
csbdeep_kwargs = config_json['CARE_Model']['CSBDeep_Config']
config = Config(**csbdeep_kwargs)
config_json['CARE_Model']['Model_Arch'] = vars(config)

In [None]:
# Persist the accumulated patch/train/model configuration next to the model.
JSON_CONFIG_PATH = os.path.join(MODEL_DIREC, "model_config.json")
# Context manager ensures the file is flushed and closed once written.
with open(JSON_CONFIG_PATH, 'w') as model_config_file:
    json.dump(config_json, model_config_file)

### Training the Model

In [None]:
# Instantiate the CARE model; its name is the run id and its base directory
# is the per-model folder, so the two together match MODEL_DIREC.
model_basedir = os.path.join(SAVE_DIREC, MODEL_NAME)
model = CARE(config, str(RUN_ID), basedir=model_basedir)

In [None]:
# Fit on the training patches, passing the held-out pair as validation data.
history = model.train(
    X,
    Y,
    validation_data=(X_val, Y_val),
)

In [None]:
# List the recorded metric names, then plot loss and error curves and save
# the figure into the run directory.
print(sorted(history.history))

plt.figure(figsize=(16, 5))
plot_history(history, ['loss', 'val_loss'], ['mse', 'val_mse', 'mae', 'val_mae']);

history_png = os.path.join(MODEL_DIREC, "training_history.png")
plt.savefig(history_png)

### Model Evaluation from Validation Set

In [None]:
# Predict on the first five validation patches for a visual sanity check.
_P = model.keras_model.predict(
    X_val[:5]
)

In [None]:
# For a probabilistic model keep only the first half of the last axis of the
# prediction before plotting.
if config.probabilistic:
    half = _P.shape[-1] // 2
    _P = _P[..., :half]

# Show source, target and prediction for the five sample patches, then save
# the figure into the run directory.
plt.figure(figsize=(20, 12))
plot_some(X_val[:5], Y_val[:5], _P, pmax=99.5)
plt.suptitle(
    '5 example validation patches\n'
    'top row: input (source),  '
    'middle row: target (ground truth),  '
    'bottom row: predicted from source'
);

plt.savefig(os.path.join(MODEL_DIREC, "val_set_predict_sample.png"))

### Quantitative Model Evaluation from Test Set

### Logging Model in MLFlow Database

### Export to ONNX

In [None]:
# Copy the training array's shape and replace the sample ('S') axis with None
# so the exported model does not fix the number of patches per call.
batch_dim = axes_dict(axes)['S']
input_shape = [*X.shape]
input_shape[batch_dim] = None
print(input_shape)

In [None]:
# Single float32 input named 'patch' with the shape computed above; this is
# the signature handed to the ONNX converter.
patch_spec = tf.TensorSpec(input_shape, tf.float32, name='patch')
input_signature = [patch_spec]

In [None]:
# Convert the trained Keras model to ONNX (opset 13); the second value
# returned by the converter is not needed here.
onnx_model, _ = tf2onnx.convert.from_keras(model.keras_model, input_signature, opset=13)

In [None]:
# Write the ONNX model to <SAVE_DIREC>/<MODEL_NAME>/<MODEL_NAME>.onnx.
# The original referenced MODEL_SAVEDIR, which is never defined in this
# notebook; SAVE_DIREC is the configured model root and yields the same layout.
# NOTE(review): this path is shared by all runs of the same model name, so a
# later run overwrites the file — confirm that is intended.
ONNX_PATH = os.path.join(SAVE_DIREC, MODEL_NAME, f"{MODEL_NAME}.onnx")
onnx.save(onnx_model, ONNX_PATH)