In [1]:
# %matplotlib inline
import datetime
import os
import seaborn as sns

from keras.optimizers import Adam
from keras.callbacks import EarlyStopping, ModelCheckpoint
from keras.losses import MeanSquaredError
import tensorflow as tf

import numpy as np
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt

In [2]:
import matplotlib.colors as colors

def _sparse_tick_labels(n, anchors):
    """Return ``n + 1`` tick labels that are blank (``" "``) except at the anchors.

    The anchors are spread evenly over ticks ``0..n`` (first on tick 0, last on
    tick ``n``); e.g. ``n=64, anchors=('-400', '0', '400')`` labels ticks 0, 32, 64.
    """
    labels = [" "] * (n + 1)
    positions = np.rint(np.linspace(0, n, len(anchors))).astype(int)
    for pos, text in zip(positions, anchors):
        labels[pos] = text
    return labels


def Model(m, w):
    """Show a 3-D voxel rendering of every cell of ``m`` whose value is >= ``w``.

    Parameters
    ----------
    m : np.ndarray, shape (H, W, L)
        3-D grid of cell values. Axis 0 (H) is drawn as depth (z, increasing
        downwards), axis 1 (W) as northing (y) and axis 2 (L) as easting (x).
        Values are presumably in [0, 1] (TODO confirm); each drawn cell is
        coloured by ``round(10 * value)``.
    w : float
        Display threshold: cells with ``m < w`` (and NaNs) are left empty.

    Notes
    -----
    Tick labels assume the grid spans -400..400 m horizontally and 0..400 m in
    depth. Labelled ticks sit at both ends and the middle of each axis, with an
    unlabelled tick on every cell boundary.
    """
    H, W, L = m.shape
    # ax.voxels indexes (x, y, z), so reverse the axes to (easting, northing, depth).
    m = m.T
    palette = np.array(
        ["#D1FEFE", "#D1FEFE", "#00FEF9", "#00FDFE", "#50FB7F",
         "#D3F821", "#FFDE00", "#FF9D00", "#F03A00", "#E10000"],
        dtype=object,
    )

    # A single boolean mask replaces the former per-voxel loop, which built a
    # full-size single-cell mask (O(N) work) for every filled cell.
    filled = m >= w
    # round(10 * v) in 1..10 selects palette[0..9]. Clip so values rounding to 0
    # (or below) no longer wrap to palette[-1] (red, the *highest* bin) and
    # values rounding above 10 no longer raise IndexError.
    bins = np.clip(np.rint(10 * m[filled]).astype(int) - 1, 0, len(palette) - 1)
    color = np.empty(m.shape, dtype=object)
    color[filled] = palette[bins]

    fig = plt.figure(figsize=(20, 10))
    # Same full-figure rectangle Axes3D(fig) used by default, without the
    # construct-then-add pattern that warns on Matplotlib 3.4/3.5.
    ax = fig.add_axes([0, 0, 1, 1], projection="3d")
    ax.voxels(filled, facecolors=color, edgecolors="w", linewidth=0.5)

    # One tick per cell boundary; only the ends and the middle are labelled.
    ax.set_xticks(np.arange(L + 1))
    ax.set_xticklabels(_sparse_tick_labels(L, ("-400", "0", "400")))
    ax.set_xlabel("Easting (m)", labelpad=20)
    ax.set_yticks(np.arange(W + 1))
    ax.set_yticklabels(_sparse_tick_labels(W, ("-400", "0", "400")))
    ax.set_ylabel("Northing (m)", labelpad=15)
    ax.set_zticks(np.arange(H + 1))
    ax.set_zticklabels(_sparse_tick_labels(H, ("0", "200", "400")))
    ax.set_zlabel("Depth (m)")
    ax.invert_zaxis()  # depth grows downwards
    plt.show()

def colormap():
    """Build the discrete 10-colour map used for the slice plots.

    Returns
    -------
    matplotlib.colors.ListedColormap
        A map named ``'indexed'`` whose colours run from light grey
        (``#F2F2F2``, lowest end of the scale) through cyan, green, yellow and
        orange to red (``#E10000``, highest end).
    """
    palette = (
        "#F2F2F2",  # lowest end of the scale
        "#D1FEFE",
        "#00FEF9",
        "#00FDFE",
        "#50FB7F",
        "#D3F821",
        "#FFDE00",
        "#FF9D00",
        "#F03A00",
        "#E10000",  # highest end of the scale
    )
    return colors.ListedColormap(list(palette), name='indexed')

def plot_xoy(model, index, factor=0.1, ylabel=True, vmin=None, vmax=None):
    """Draw the horizontal slice ``model[:, :, index]`` on the current axes.

    Parameters
    ----------
    model : np.ndarray
        3-D grid. Axis 0 is drawn along easting (x), axis 1 along northing (y),
        and axis 2 selects the slice.
    index : int
        Index along axis 2 of the slice to draw.
    factor : float, default 0.1
        Cells with values ``<= factor`` are set to 0 before plotting.
    ylabel : bool, default True
        Whether to draw the northing axis label (e.g. off for inner panels).
    vmin, vmax : float, optional
        Colour-scale limits forwarded to ``imshow``. With the default ``None``,
        the scale stretches to this slice's own min/max, so one colour can mean
        different values in different slices. Pass ``vmin=0, vmax=1`` for a
        fixed scale shared across slices.

    Notes
    -----
    Tick labels assume the slice spans -400..400 m on both axes, the same
    extent that ``Model`` labels in metres.
    """
    ax = plt.gca()
    # Threshold only the slice being drawn, not the whole 3-D grid.
    layer = model[:, :, index]
    layer = np.where(layer > factor, layer, 0)
    # pyplot.imshow (not ax.imshow) so the image becomes pyplot's "current
    # image" and a following plt.colorbar() keeps working.
    plt.imshow(layer.T, cmap=colormap(), vmin=vmin, vmax=vmax)
    ax.invert_yaxis()  # northing index 0 at the bottom

    # Ticks go on pixel *edges* (-0.5 ... n - 0.5), so -400/400 mark the outer
    # boundary and 0 the true centre. The former ticks at 0/32/64 sat on pixel
    # centres, and 64 lay past the last pixel (63.5), which made Matplotlib widen
    # the view and leave a blank half-pixel strip along the right/top edges.
    nx, ny = layer.shape
    ax.set_xticks([-0.5, nx / 2 - 0.5, nx - 0.5])
    ax.set_xticklabels(('-400', '0', '400'))
    ax.set_xlabel('Easting (m)')
    ax.set_yticks([-0.5, ny / 2 - 0.5, ny - 0.5])
    ax.set_yticklabels(('-400', '0', '400'))
    if ylabel:
        ax.set_ylabel('Northing (m)')
    ax.tick_params(bottom=False, top=False, left=False, right=False)

In [None]:
# from utils.A_network_model import create_conv_autoencoder_with_skip_connections
from utils.A_network_model_imporved import create_conv_autoencoder_with_skip_connections
# from utils.C_network_model import create_conv_autoencoder_with_skip_connections

In [None]:
# Held-out test set: network inputs and their target labels.
X_test = np.load("data/test_model_data.npy")
y_test = np.load("data/test_model_label.npy")
# Every input needs exactly one label; fail here with a clear message instead
# of deep inside Keras' predict()/evaluate().
assert X_test.shape[0] == y_test.shape[0], (
    f"sample count mismatch: X_test {X_test.shape} vs y_test {y_test.shape}"
)
# NOTE(review): a leftover comment here once min-max scaled X_train. If the saved
# model was trained on scaled inputs, X_test must be scaled with the *training*
# min/max before predict() -- confirm against the training notebook.

In [None]:
# Instantiate the network architecture from the project module imported above;
# its weights are then overwritten by load_weights() from the checkpoint directory.
autoencoder_with_skip = create_conv_autoencoder_with_skip_connections()

In [None]:
# Checkpoint directory of the training run whose final weights are evaluated here.
model_dir = "./models/20231129-135011/"
weights_path = os.path.join(model_dir, "final_model.h5")
autoencoder_with_skip.load_weights(weights_path)
# evaluate() requires a compiled model; the optimizer is only needed for that call.
autoencoder_with_skip.compile(loss=MeanSquaredError(), optimizer="adam")

In [None]:
# Model output for every test sample, followed by the MSE loss against the labels.
predict = autoencoder_with_skip.predict(x=X_test)
eval_loss = autoencoder_with_skip.evaluate(x=X_test, y=y_test)