In [None]:
import os, logging, json
from datetime import datetime

import mlflow
import numpy as np
import tensorflow as tf
from cloudpickle import Pickler
import onnx
from mlflow.onnx import log_model as onnx_log_model
from mlflow.models import infer_signature
from csbdeep.models import Config, CARE

from flame.utils import _compress_dict_fields
from flame import CAREInferenceSession

In [None]:
# Locations of the exported CARE model artifacts. MODEL_DIR can be overridden
# with the CARE_MODEL_DIR environment variable so the notebook is not tied to
# one machine's filesystem layout; the default is the original local path.
MODEL_DIR = os.environ.get("CARE_MODEL_DIR", "/mnt/d/models/CARE/test_model")
ONNX_MODEL_PATH = os.path.join(MODEL_DIR, "test_model.onnx")    # serialized ONNX graph
JSON_CONFIG_PATH = os.path.join(MODEL_DIR, "model_config.json")  # JSON model config

In [None]:
# Parse the model config. The context manager closes the file handle
# deterministically (a bare open() passed to json.load leaks it until garbage
# collection); JSON is read as UTF-8 rather than the platform default encoding.
with open(JSON_CONFIG_PATH, "r", encoding="utf-8") as config_file:
    config_json = json.load(config_file)

In [None]:
# Build the ONNX inference session; it is used further down to produce the
# example output from which the MLflow model signature is inferred.
# NOTE(review): the same JSON file is passed as both the model config and the
# dataset config — confirm this is intentional.
engine = CAREInferenceSession(
    model_path=ONNX_MODEL_PATH,
    model_config_path=JSON_CONFIG_PATH,
    dataset_config_path=JSON_CONFIG_PATH,
)

In [None]:
# Deserialize the ONNX graph from disk so it can be handed to mlflow.onnx.log_model.
# (onnx.load is an alias of onnx.load_model.)
onnx_model = onnx.load_model(ONNX_MODEL_PATH)

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Synthetic image used only as the MLflow input example / signature probe.
# A seeded Generator makes the logged example and inferred signature
# reproducible under Restart & Run All, and sampling directly into the target
# shape avoids building a flat array and reshaping it.
# NOTE(review): values are float64; if the ONNX graph expects float32 input,
# cast with .astype(np.float32) — confirm against the model's input dtype.
RANDOM_SEED = 42
INPUT_SHAPE = (1, 1200, 1200, 3)  # presumably (batch, height, width, channels) — confirm
rng = np.random.default_rng(RANDOM_SEED)
X = rng.uniform(low=0.0, high=1.0, size=INPUT_SHAPE)

In [None]:
# Point MLflow at the tracking server. An MLFLOW_TRACKING_URI environment
# variable takes precedence (an explicit set_tracking_uri call would otherwise
# silently override it); the fallback is the local server this notebook used.
MLFLOW_TRACKING_URI = os.environ.get("MLFLOW_TRACKING_URI", "http://127.0.0.1:5050")
mlflow.set_tracking_uri(uri=MLFLOW_TRACKING_URI)

EXPERIMENT_NAME = "CARE Denoising 3 Channel"
# Make this the active experiment (MLflow creates it if it does not exist).
mlflow.set_experiment(EXPERIMENT_NAME)

In [None]:
# Log the config, the ONNX model, and supporting artifacts into one MLflow run.
# `with mlflow.start_run()` always closes the run (marking it FAILED if an
# exception escapes), so re-running this cell after an error never appends to
# a half-finished run the way an implicitly auto-started run would.
# NOTE: start_run raises if another run is already active in this kernel.
example_input = X[[0], ...]  # list index keeps the leading batch axis (unlike X[0])

with mlflow.start_run():
    # Log the hyperparameters
    mlflow.log_params(_compress_dict_fields(config_json))

    # TODO: log validation metrics once they are available in this notebook,
    # e.g. mlflow.log_metric("val_loss", np.min(val_loss)) and likewise for
    # val_mae / val_mse.

    model_info = onnx_log_model(
        onnx_model=onnx_model,
        artifact_path="model",
        conda_env=os.path.join(os.getcwd(), "environment.yml"),
        input_example=example_input,
        # Pass registered_model_name=... to also create a model version under
        # that registered model (creating the registered model if needed).
        metadata=config_json,
        signature=infer_signature(example_input, engine.predict(example_input)),
        onnx_execution_providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
    )

    # Ship the local `flame` package source alongside the model.
    mlflow.log_artifacts(
        local_dir=os.path.join(os.getcwd(), 'flame'),
        artifact_path="flame",
    )

    mlflow.log_artifact(
        local_path=JSON_CONFIG_PATH,
        artifact_path="model_config",
    )

In [None]:
# Display the inference session object (its repr) as this cell's output.
engine