In [None]:
import config as CONFIG

from csbdeep.io import load_training_data
from csbdeep.models import IsotropicCARE, Config
from csbdeep.utils import axes_dict, plot_some, plot_history

from matplotlib import pyplot as plt

# Load Data

In [None]:
# Read the training patches from the configured dataset file and hold out a
# 0.1 fraction of them as the validation set.
# NOTE(review): `axes` as returned may be ordered differently from the
# 'SCYX' string passed in — confirm against the verbose output.
(X, Y), (X_val, Y_val), axes = load_training_data(
    CONFIG.TRAIN_DATASET_PATH,
    validation_split=0.1,
    axes='SCYX',
    verbose=True,
)

In [None]:
# Index of the channel axis within the loaded arrays.
c = axes_dict(axes)['C']
# The training cell feeds Y as the network input and X as the target
# (model.train(X=Y, Y=X, ...)), so the input channel count has to come from Y
# and the output channel count from X. Taking them the other way round only
# works when both arrays happen to have the same number of channels.
n_channel_in, n_channel_out = Y.shape[c], X.shape[c]

# Define Model

In [None]:
# Collect the network and training settings used to build the 'DX' model.
model_config = Config(
    axes=axes,
    n_channel_in=n_channel_in,
    n_channel_out=n_channel_out,
    train_batch_size=32,
    train_epochs=200,
    train_checkpoint='DX.hdf5',
)
# Show every configuration attribute as the cell output.
vars(model_config)

In [None]:
# Instantiate the network named 'DX' under the configured checkpoint directory.
model = IsotropicCARE(
    model_config,
    name='DX',
    basedir=CONFIG.CHECKPOINT_PATH,
)

# Train

In [None]:
# NOTE(review): the loaded pair is used in reverse here — Y is the network
# input and X is the target, for both the training and the validation data.
# This agrees with the validation plot, which labels Y_val as the source;
# confirm the direction is intended.
history = model.train(
    X=Y,
    Y=X,
    validation_data=(Y_val, X_val),
)

In [None]:
# List the metric names recorded during training, in alphabetical order.
print(sorted(history.history.keys()))

# Plot three groups of curves from the training history on a wide figure:
# losses, mean squared errors, and the learning rate.
plt.figure(figsize=(24,5))
plot_history(
    history,
    ['loss', 'val_loss'],
    ['mse', 'val_mse'],
    ['lr'],
)

# Validation

In [None]:
# Run the trained network on the first five validation inputs (Y_val is the
# source side of the pair in this notebook).
predict_res = model.keras_model.predict(Y_val[:5])

# Show source, ground truth and prediction as three rows of patches.
plt.figure(figsize=(12,7))
plot_some(
    Y_val[:5],
    X_val[:5],
    predict_res,
    pmax=99.5,
)
plt.suptitle(
    '5 example validation patches\n'
    'top row: input(source), '
    'middle row: target(ground truth), '
    'bottom row: predicted from source.'
)
plt.show()