In [None]:
import os
from datetime import datetime
import json
import logging

from csbdeep.io import load_training_data
from csbdeep.models import Config, CARE
from csbdeep.utils import axes_dict, plot_history, plot_some

from matplotlib import pyplot as plt
import tensorflow as tf
import tf2onnx
import onnx
import numpy as np

import mlflow
from mlflow.models import infer_signature
from mlflow.pyfunc import log_model as pyfunc_log_model
from mlflow.onnx import log_model as onnx_log_model

from flame.error import CAREInferenceError
from flame.utils import _compress_dict_fields, set_up_tracking_server, get_windows_user_path

In [None]:
# Explicitly keep tf.function-decorated code running as graphs (eager execution of
# tf.functions disabled).
tf.config.run_functions_eagerly(False)

In [None]:
# Project helper; the returned path is used as the root of the `logs` directory
# for both the training log and the MLflow server log.
WINDOWS_USER_PATH = get_windows_user_path()

In [None]:
# Main training logger: DEBUG-and-above records go to a timestamped file under
# <WINDOWS_USER_PATH>/logs.
logger = logging.getLogger("MAIN")
LOG_DIREC = os.path.join(WINDOWS_USER_PATH, "logs")
# logging.FileHandler does not create missing parent directories.
os.makedirs(LOG_DIREC, exist_ok=True)
logging.basicConfig(
    filename=os.path.join(LOG_DIREC, f"{datetime.now().strftime('%Y%m%d-%H%M%S')}_training_CARE.log"),
    encoding="utf-8",
    level=logging.DEBUG,
    # basicConfig is a silent no-op if the root logger already has handlers (e.g. a
    # re-run of this cell, or a handler installed by an imported library); force=True
    # replaces them so records always reach this file.
    force=True,
)

In [None]:
# Separate named logger for MLflow-related messages. No handler is attached to it
# here, so its records propagate to the root logger configured by basicConfig.
mlflow_logger = logging.getLogger("MLFLOW")

In [None]:
# --- Input data and output locations ---
DATA_DIREC = "/mnt/d/data/processed/20250618_224I_denoising_5to40F"
PATCH_CONFIG_JSON = os.path.join(DATA_DIREC, "1Chan_patch_config.json")
SAVE_DIREC = "/mnt/d/models/"
# --- Training settings (recorded in config_json['Train_Config'] / 'CSBDeep_Config') ---
UNET_KERN_SIZE = 3
TRAIN_BATCH_SIZE = 16
INFER_BATCH_SIZE = 1  # NOTE(review): not referenced anywhere else in this notebook
RANDOM_STATE = 8888  # recorded as 'random_state' in the training config

# --- MLflow tracking server ---
MLFLOW_SERVER_IP = "127.0.0.1"
MLFLOW_SERVER_PORT = "5050"  # kept as a string; it is interpolated into the tracking URI
MLFLOW_ARTIFACT_PATH = "/mnt/c/SynologyDrive/mlruns"

In [None]:
# Fail fast, with a readable message, if the input locations are wrong.
assert os.path.isdir(DATA_DIREC), f"Data directory {DATA_DIREC} is not a directory"
assert os.path.isfile(PATCH_CONFIG_JSON), f"Patch config {PATCH_CONFIG_JSON} is not a file"

### Creating training config by building on patch_config

In [None]:
# Load the patch-generation config; the training config below is built on top of it.
try:
    # `with` guarantees the handle is closed even if json.load raises.
    with open(PATCH_CONFIG_JSON, 'r') as patch_config_file:
        config_json = json.load(patch_config_file)
    logger.info(f"Successfully loaded patch config from {PATCH_CONFIG_JSON}")
except Exception as e:
    logger.error(f"Could not load patch config json from {PATCH_CONFIG_JSON}.\nERROR: {e}")
    # Chain the original exception so its traceback is preserved.
    raise CAREInferenceError(f"Could not load patch config json from {PATCH_CONFIG_JSON}.\nERROR: {e}") from e

In [None]:
# Set up the MLflow tracking server through the project helper. Its own log file is
# timestamped and lives in the same `logs` directory as the training log.
mlflow_server_log = os.path.join(
    WINDOWS_USER_PATH,
    "logs",
    f"{datetime.now().strftime('%Y%m%d-%H%M%S')}_training_mlflow_server.log",
)
server_process = set_up_tracking_server(
    direc=MLFLOW_ARTIFACT_PATH,
    ip=MLFLOW_SERVER_IP,
    log_path=mlflow_server_log,
    port=MLFLOW_SERVER_PORT,
)

In [None]:
# Display whatever the tracking-server helper reports as errors (project attribute;
# presumably empty when the server started cleanly — verify).
server_process.errors

In [None]:
# Point both the tracking client and the model-registry client at the local server.
mlflow_uri = f"http://{MLFLOW_SERVER_IP}:{MLFLOW_SERVER_PORT}"
mlflow.set_tracking_uri(mlflow_uri)
mlflow.set_registry_uri(mlflow_uri)
# Alternative experiment name derived from the channel count of the patch shape:
# EXPERIMENT_NAME = f"CARE Denoising {config_json['Patch_Config']['patch_shape'][-1]} Channel"

In [None]:
# NOTE(review): hard-coded experiment name that looks like a test value.
EXPERIMENT_NAME = "test_experiment 4"
mlflow_logger.info(f"Using experiment {EXPERIMENT_NAME}")
# Make it the active experiment for the run started below.
mlflow.set_experiment(EXPERIMENT_NAME)

In [None]:
# Start the MLflow run that the later mlflow.log_* calls in this notebook attach to.
run = mlflow.start_run()
MLFLOW_RUN_ID = run.info.run_id
MLFLOW_RUN_NAME = run.info.run_name
# Build the message once so the log file and the notebook output show the same timestamp.
run_start_msg = f"Run {MLFLOW_RUN_NAME} (id {MLFLOW_RUN_ID}) started at {datetime.now().strftime('%Y%m%d-%H%M%S')}"
mlflow_logger.info(run_start_msg)
print(run_start_msg)


In [None]:
# Model name encodes the dataset's input/output frame counts:
# FLAME_CARE_<input n_frames>to<output n_frames>F
try:
    # Keys are read outside the f-string: nesting double-quoted keys inside a
    # double-quoted f-string is a SyntaxError before Python 3.12.
    n_frames_in = config_json['FLAME_Dataset']['input']['n_frames']
    n_frames_out = config_json['FLAME_Dataset']['output']['n_frames']
    MODEL_NAME = f"FLAME_CARE_{n_frames_in}to{n_frames_out}F"
    logger.info(f"Training a model with NAME: '{MODEL_NAME}'...")
except Exception as e:
    logger.error(f"Failed to dynamically load model name.\nERROR: {e}")
    raise CAREInferenceError(f"Failed to dynamically load model name.\nERROR: {e}") from e


In [None]:
# Each run gets its own directory: <SAVE_DIREC>/<MODEL_NAME>/<MLflow run id>.
try:
    MODEL_DIREC = os.path.join(SAVE_DIREC, MODEL_NAME, str(MLFLOW_RUN_ID))
    # exist_ok=True: MLflow run ids are unique per run, so a pre-existing directory
    # should only come from re-running this cell within the same run.
    os.makedirs(MODEL_DIREC, exist_ok=True)
    logger.info(f"Training run id is {MLFLOW_RUN_ID}.")
    logger.info(f"Model saving to {MODEL_DIREC}")
except Exception as e:
    logger.error(f"Failed to load run id and/or set up model save directory.\nERROR: {e}")
    raise CAREInferenceError(f"Failed to load run id and/or set up model save directory.\nERROR: {e}") from e

In [None]:
# Record the training-time settings alongside the patch config.
config_json['Train_Config'] = dict(
    npz_path=os.path.join(DATA_DIREC, config_json['Patch_Config']['npz_name']),
    name=MODEL_NAME,
    model_direc=MODEL_DIREC,
    unet_kern_size=UNET_KERN_SIZE,
    train_batch_size=TRAIN_BATCH_SIZE,
    random_state=RANDOM_STATE,
)

In [None]:
# Project helper applied to the full config. NOTE(review): the return value is
# discarded here, so this only has an effect if it mutates config_json in place —
# confirm (the same helper is called again later with its return value used).
_compress_dict_fields(config_json)

In [None]:
# Fail early if the patch archive referenced by the training config is missing.
NPZ_PATH = config_json['Train_Config']['npz_path']
assert os.path.isfile(NPZ_PATH), f"NPZ path {NPZ_PATH} is not a file"

### Training and Validation Data

In [None]:
# Load the source/target patch pairs and hold out 10% of them as validation data.
VALIDATION_SPLIT = 0.1
(X, Y), (X_val, Y_val), axes = load_training_data(NPZ_PATH, validation_split=VALIDATION_SPLIT, verbose=True)

In [None]:
# Index of the channel axis ('C') in `axes`; used to read the input/output channel counts.
c = axes_dict(axes)['C']
channels_in = X.shape[c]
channels_out = Y.shape[c]

### CARE Model

In [None]:
# Run bookkeeping stored alongside the model configuration.
config_json['CARE_Model'] = {
    'name': MODEL_NAME,
    'experiment_name': EXPERIMENT_NAME,
    'run_id': MLFLOW_RUN_ID,
    'run_name': MLFLOW_RUN_NAME,
    'base_dir': SAVE_DIREC,
    'run_dir': MODEL_DIREC
}
# Keyword arguments passed to csbdeep.models.Config in the next cell.
# Each key appears exactly once (a repeated key in a dict literal silently keeps only the last value).
config_json['CARE_Model']['CSBDeep_Config'] = {
    'axes': axes,
    'n_channel_in': channels_in,
    'n_channel_out': channels_out,
    'probabilistic': False,  # default from CSBDeep
    'allow_new_parameters': True,  # CSBDeep default is False
    'unet_kern_size': UNET_KERN_SIZE,
    'train_batch_size': TRAIN_BATCH_SIZE,
    'unet_input_shape': tuple(config_json['Patch_Config']['patch_shape']),
}

In [None]:
# Build the CSBDeep configuration from the stored kwargs, then record all of the
# resulting Config object's attributes back into the run config.
csbdeep_kwargs = config_json['CARE_Model']['CSBDeep_Config']
config = Config(**csbdeep_kwargs)
config_json['CARE_Model']['Model_Arch'] = vars(config)

In [None]:
# Persist the full config next to the model; the inference session below reads it back.
JSON_CONFIG_PATH = os.path.join(MODEL_DIREC, "model_config.json")
# `with` closes (and therefore flushes) the file before anything reads it;
# the file is only written, so plain 'w' mode is sufficient.
with open(JSON_CONFIG_PATH, 'w') as config_file:
    json.dump(config_json, config_file)

### Training the Model

In [None]:
# Seed Python, NumPy and TensorFlow with the RANDOM_STATE recorded in the training
# config, so weight initialisation and data shuffling are reproducible as far as
# TensorFlow's determinism allows.
tf.keras.utils.set_random_seed(RANDOM_STATE)
# CSBDeep places the model under basedir/name, i.e. the same path as MODEL_DIREC.
model = CARE(
    config,
    str(MLFLOW_RUN_ID),
    basedir=os.path.join(SAVE_DIREC, MODEL_NAME)
)

In [None]:
# Train on the training patches, monitoring the held-out validation patches.
# NOTE(review): epochs=1 overrides the epoch count held in `config` (presumably
# config.train_epochs), so the Model_Arch values saved/logged for this run may not
# reflect how long the model actually trained — confirm before real runs.
history = model.train(X, Y, validation_data=(X_val, Y_val), epochs=1)

In [None]:
# model.keras_model.save(os.path.join(MODEL_DIREC, 'saved_model.keras'))

### Some quick visualizations

In [None]:
# List the recorded metric keys, then plot losses and error metrics side by side.
print(sorted(history.history))
plt.figure(figsize=(16, 5))
plot_history(history, ['loss', 'val_loss'], ['mse', 'val_mse', 'mae', 'val_mae'])
plt.savefig(os.path.join(MODEL_DIREC, "training_history.png"))

### Model evaluation on the validation set

In [None]:
# Predict on the first 5 validation patches for the visual check in the next cell.
_P = model.keras_model.predict(X_val[:5])

In [None]:
# Show input / target / prediction for the same 5 validation patches and save the figure.
plt.figure(figsize=(20, 12))
# For a probabilistic model only the first half of the output channels is plotted
# (presumably the location parameters — verify against CSBDeep). The result goes to
# a new name so re-running this cell never slices `_P` a second time.
_P_plot = _P[..., :(_P.shape[-1] // 2)] if config.probabilistic else _P
plot_some(X_val[:5], Y_val[:5], _P_plot, pmax=99.5)
plt.suptitle('5 example validation patches\n'
             'top row: input (source),  '
             'middle row: target (ground truth),  '
             'bottom row: predicted from source');

plt.savefig(os.path.join(MODEL_DIREC, "val_set_predict_sample.png"))

### Best-epoch validation metrics (logged to the MLflow database below)

In [None]:
# Validation metrics from the epoch with the lowest validation loss.
val_losses = history.history['val_loss']
# nanargmin ignores NaN epochs (e.g. a diverged epoch) instead of letting NaN win
# the minimum; it only raises if every epoch is NaN.
idx = int(np.nanargmin(val_losses))
val_loss = val_losses[idx]
val_mae = history.history['val_mae'][idx]
val_mse = history.history['val_mse'][idx]

### Export to ONNX

In [None]:
# ONNX input shape: the training-array shape with the sample ('S') axis made
# dynamic (None) so the exported model accepts any batch size.
batch_dim = axes_dict(axes)['S']
input_shape = [None if dim == batch_dim else size for dim, size in enumerate(X.shape)]
print(input_shape)

In [None]:
# A single float32 input named 'patch' for the Keras -> ONNX conversion.
input_signature = [tf.TensorSpec(shape=input_shape, dtype=tf.float32, name='patch')]

In [None]:
# Convert the trained Keras model to ONNX. from_keras returns a pair; only the
# model proto (first element) is used below.
ONNX_OPSET = 13
onnx_model, _ = tf2onnx.convert.from_keras(
    model.keras_model,
    input_signature=input_signature,
    opset=ONNX_OPSET,
)

In [None]:
# Save the ONNX model into this run's model directory as <MODEL_NAME>.onnx.
ONNX_PATH = os.path.join(MODEL_DIREC, MODEL_NAME + ".onnx")
onnx.save_model(onnx_model, ONNX_PATH)

In [None]:
from flame.engine import CAREInferenceSession

In [None]:
# Project inference wrapper around the exported ONNX model. The same JSON file is
# given as both model and dataset config (presumably because config_json holds
# both the dataset and the model sections).
engine = CAREInferenceSession(
    dataset_config_path=JSON_CONFIG_PATH,
    model_config_path=JSON_CONFIG_PATH,
    model_path=ONNX_PATH,
)

In [None]:
# Quick look at the shape of the training input array.
X.shape

In [None]:
# Where MLflow stores this run's artifacts.
run.info.artifact_uri

In [None]:
# Log the hyperparameters
mlflow.log_params(_compress_dict_fields(config_json))

# Log the validation performance metrics. val_loss / val_mae / val_mse are already
# single best-epoch values, so they are logged directly as floats.
mlflow.log_metric("val_loss", float(val_loss))
mlflow.log_metric("val_mae", float(val_mae))
mlflow.log_metric("val_mse", float(val_mse))

# One-sample batch (leading axis kept) used both as the input example and for
# inferring the model signature.
sample_input = X[[0], ...]

model_info = onnx_log_model(
    onnx_model=onnx_model,
    artifact_path="model",
    conda_env=os.path.join(os.getcwd(), "environment_wsl.yml"),
    input_example=sample_input,

    # If given, create a model version under registered_model_name,
    # also creating a registered model if one with the given name does not exist.
    # registered_model_name=_,

    metadata=config_json,
    signature=infer_signature(sample_input, engine.predict(sample_input)),
    onnx_execution_providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)

In [None]:
# Store the `flame` directory from the current working directory (presumably the
# project package source) and the model config JSON as run artifacts.
flame_src_direc = os.path.join(os.getcwd(), 'flame')
mlflow.log_artifacts(local_dir=flame_src_direc, artifact_path="flame")
mlflow.log_artifact(local_path=JSON_CONFIG_PATH, artifact_path="model_config")

In [None]:
# Close the MLflow run started earlier in the notebook.
mlflow.end_run()