In [None]:
import config as CONFIG
import sys
sys.path.append(CONFIG.GLOBAL_MODEL_PATH)
from models.mwunet import mwunet

from csbdeep.io import load_training_data
from csbdeep.utils import plot_history, plot_some
from matplotlib import pyplot as plt

# Load Data

In [None]:
# Read the training set from CONFIG.TRAIN_DATASET_PATH and hold out 10% of it
# for validation. The axes order is passed explicitly as 'SCYX'.
training_data = load_training_data(
    CONFIG.TRAIN_DATASET_PATH,
    validation_split=0.1,
    axes='SCYX',
    verbose=True,
)
(X, Y), (X_val, Y_val), axes = training_data

# Define Model

In [None]:
# The network's input shape is taken from Y and its output shape from X
# (leading sample dimension dropped in both cases).
# NOTE(review): this is the reverse of the loader's (X, Y) naming — presumably
# intentional, since the training cell also feeds Y as input; confirm.
net_input_shape = Y.shape[1:]
net_output_shape = X.shape[1:]

model = mwunet(
    input_shape=net_input_shape,
    output_shape=net_output_shape,
    conv_kernel_size=3,
    n_filters_per_scale=[8, 16, 32],
)
model.model.summary()

# Train

In [None]:
# Hyper-parameters and checkpoint settings handed to model.train.
train_options = dict(
    train_epochs=200,
    train_batch_size=32,
    train_learning_rate=0.001,
    base_dir=CONFIG.CHECKPOINT_PATH,
    name='DX',
    train_checkpoint='DX.hdf5',
)

# Y is used as the network input and X as the target, for both the training
# and the validation pairs.
# NOTE(review): swapped relative to the loader's (X, Y) naming — presumably
# deliberate for this 'DX' model; confirm.
history = model.train(
    X=Y,
    Y=X,
    validation_data=(Y_val, X_val),
    **train_options,
)

In [None]:
# Show which quantities were recorded during training, then plot the
# loss, MSE and learning-rate curves on one wide figure.
recorded_keys = sorted(history.history.keys())
print(recorded_keys)

plt.figure(figsize=(24, 5))
plot_history(
    history,
    ['loss', 'val_loss'],
    ['mse', 'val_mse'],
    ['lr'],
)

# Validation

In [None]:
# Visual sanity check on the first five validation patches: network input
# (from Y_val), target (from X_val) and the model's prediction for that input.
plt.figure(figsize=(12, 7))

val_inputs = Y_val[:5]
val_targets = X_val[:5]
predict_res = model.predict(val_inputs)

plot_some(val_inputs, val_targets, predict_res, pmax=99.5)

figure_title = (
    '5 example validation patches\n'
    'top row: input(source), '
    'middle row: target(ground truth), '
    'bottom row: predicted from source.'
)
plt.suptitle(figure_title)
plt.show()