In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import tensorflow as tf

from src.performance_evaluation import (
    center_of_mass,
    compute_position_error,
    compute_volume_error,
)
from src.util import plot_reconstruction_set, plot_voxel, plot_voxel_c

# import vae
from src.vae import vae_model

# Expose only CUDA device 0 to this process.
# NOTE(review): this is set after `import tensorflow`; it presumably only takes
# effect if TensorFlow has not initialized its GPU devices yet — consider moving
# it above the TensorFlow import.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Stage 3

_Build the full reconstruction network architecture_

1. Load the VAE and keep its decoder $\Psi : \mathbf{z} \mapsto \hat{\gamma}$
2. Load the mapper $\Xi : \mathbf{u} \mapsto \mathbf{z}$
3. Load the material classifier $\Upsilon : \mathbf{u} \mapsto m$

The final model is described by:

$$\Gamma := \Psi \circ \Xi : \mathbf{u} \mapsto \mathbf{z} \mapsto \hat{\gamma} $$

in parallel with the material classification model:

$$\Upsilon : \mathbf{u} \mapsto m $$

## *i)* Load the VAE

In [None]:
# Rebuild the VAE architecture and restore the trained checkpoint (VAE #21).
vae = vae_model()
vae.load_weights("models/vaes/vae_21.weights.h5")
vae.summary()

# Keep handles to both halves: Φ is the encoder (γ → z), Ψ the decoder (z → γ̂).
Φ, Ψ = vae.encoder, vae.decoder

In [None]:
# Load the held-out Stage 3 test set. The context manager closes the underlying
# .npz file handle once the three arrays have been read into memory.
# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can
# execute arbitrary code — only use it with trusted files, and drop it if all
# three arrays are plain numeric.
with np.load("models/testdata_stage3.npz", allow_pickle=True) as test_npz:
    X_test = test_npz["X_test"]
    gamma_test = test_npz["gamma_test"]
    m_test = test_npz["m_test"]
print(X_test.shape, gamma_test.shape, m_test.shape)

In [None]:
# Round-trip the ground-truth voxels through the VAE: Φ encodes them and Ψ
# decodes the resulting latent code back into a voxel grid.
# NOTE(review): Φ.predict returns three outputs and only the third is kept as z —
# presumably (z_mean, z_log_var, z_sample); confirm against src.vae.
_, _, z_pred = Φ.predict(gamma_test)
γ_pred = Ψ.predict(z_pred)

In [None]:
# Visual sanity check of the VAE alone: for a few random test samples, show the
# ground-truth voxels followed by the VAE reconstruction γ_pred.
# Seeded so the same samples are shown on every run.
rng = np.random.default_rng(0)
for rdn in rng.integers(low=0, high=X_test.shape[0], size=5):
    print("True γ distribution")
    plot_voxel(gamma_test[rdn, :, :, :, 0])
    # Reconstruction of the same sample; index the single channel directly
    # instead of squeezing the whole array on every iteration.
    sgl_pred = np.clip(γ_pred[rdn, :, :, :, 0], a_min=0, a_max=1)
    print("Predicted γ distribution")
    plot_voxel(sgl_pred)
    print("----------")

## *ii)* Load the material classification network $\Upsilon$ and the mapper $\Xi$

In [None]:
def Upsilon_model(input_shape=(64, 64, 1), m_dim=1, kernel=3):
    """Build the material classifier Υ: measurement u → material probability m.

    Two strided Conv2D layers (8 and 16 filters, strides (2, 4), "same" padding)
    downsample the input, followed by a dense head with dropout. The last layer
    uses a sigmoid, so each of the ``m_dim`` outputs lies in (0, 1).

    Args:
        input_shape: Shape of one input sample, without the batch axis.
        m_dim: Number of sigmoid outputs.
        kernel: Convolution kernel size.

    Returns:
        An uncompiled ``tf.keras.Model``.
    """
    layers = tf.keras.layers
    signal = layers.Input(shape=input_shape)

    features = signal
    for n_filters in (8, 16):
        features = layers.Conv2D(n_filters, kernel, strides=(2, 4), padding="same")(features)

    features = layers.Flatten()(features)
    features = layers.Dense(128, activation="relu")(features)
    features = layers.Dropout(0.2)(features)
    # NOTE(review): this Dense(16) has no activation, so it composes linearly
    # with the output layer — confirm this is intended.
    features = layers.Dense(16)(features)
    material_prob = layers.Dense(m_dim, activation="sigmoid")(features)
    return tf.keras.Model(signal, material_prob)


Y = Upsilon_model()
Y.summary()

In [None]:
# Restore the trained weights of the material classifier Υ built above.
Y.load_weights("models/material_mapper.weights.h5")

In [None]:
def mapper_CNN(input_shape=(64, 64, 1), latent_dim=8):
    """Build the mapper Ξ: measurement u → latent code z of the VAE.

    Four stages of Conv2D (4×4 kernel, stride 1, "valid" padding, ReLU) followed
    by 2×2 max-pooling, with 4, 8, 16 and 32 filters. These are followed by a
    ReLU dense layer and a linear dense projection, both of width ``latent_dim``.
    The chosen hyper-parameters are printed when the model is built.

    Args:
        input_shape: Shape of one input sample, without the batch axis.
        latent_dim: Size of the output code; it is fed to the VAE decoder Ψ, so
            it must match the decoder's latent size.

    Returns:
        An uncompiled ``tf.keras.Model``.
    """
    filters = [4, 8, 16, 32]
    kernels = [(4, 4)] * len(filters)
    strides = [(1, 1)] * len(filters)
    pools = [(2, 2)] * len(filters)

    print(f"{filters=}, kernels={kernels[0]}, strides={strides[0]}, pools={pools[0]}")

    mapper_input = tf.keras.layers.Input(shape=input_shape)
    x = mapper_input

    for f, k, s, p in zip(filters, kernels, strides, pools):
        x = tf.keras.layers.Conv2D(
            filters=f, kernel_size=k, strides=s, padding="valid", activation="relu"
        )(x)
        x = tf.keras.layers.MaxPooling2D(pool_size=p)(x)

    x = tf.keras.layers.Flatten()(x)
    x = tf.keras.layers.Dense(latent_dim, activation="relu")(x)
    # Linear output: latent codes are not restricted to be non-negative.
    mapper_output = tf.keras.layers.Dense(latent_dim)(x)

    return tf.keras.Model(mapper_input, mapper_output)


Ξ = mapper_CNN()
Ξ.summary()

In [None]:
# Restore the trained weights of the mapper Ξ (checkpoint mapper_8).
Ξ.load_weights("models/mappers/mapper_8.weights.h5")

## *iii)* Set up the final reconstruction network

In [None]:
# Drop the trailing channel axis so gamma_test matches the squeezed predictions
# below. The ndim guard keeps this cell re-runnable: squeezing axis 4 of an
# already-squeezed 4-D array would raise.
if gamma_test.ndim == 5:
    gamma_test = np.squeeze(gamma_test, axis=4)

In [None]:
# Full pipeline Γ = Ψ ∘ Ξ: measurements u → latent code z → voxel grid γ̂.
z_hat = Ξ.predict(X_test)
γ_raw = np.squeeze(Ψ.predict(z_hat), axis=4)

# Clip to [0, 1], then mark every non-zero voxel as occupied (value 1).
# NOTE(review): the effective threshold is 0 — any positive decoder output counts
# as occupied. If Ψ ends in a sigmoid this marks almost every voxel; confirm.
γ_clipped = np.clip(γ_raw, a_min=0, a_max=1)
γ_hat = np.where(γ_clipped != 0, np.ones_like(γ_clipped), γ_clipped)
print(γ_hat.shape)

# Material class: Υ ends in a sigmoid, so rounding yields 0 or 1.
m_pred = np.round(Y.predict(X_test))

In [None]:
# Fixed sample selection so the saved figure is reproducible.
# NOTE(review): presumably these index the test set, which then must contain
# more than 7733 samples — confirm against plot_reconstruction_set.
showcase_idx = [1891, 3800, 1534, 7498, 7733]
plot_reconstruction_set(
    gamma_test, m_test, γ_hat, m_pred,
    save_fig="images/predicted_test_data_results.pdf",
    forced_sel=showcase_idx,
)

In [None]:
# Paper-style plots: "ticks" theme without top/right spines, larger font scale.
custom_params = dict.fromkeys(("axes.spines.right", "axes.spines.top"), False)
sns.set_theme(rc=custom_params, style="ticks")
sns.set_context(font_scale=1.4, context="paper")

In [None]:
# Volume discrepancy between the binarized prediction γ̂ and the ground truth,
# presumably one value per test sample (plotted below as a histogram).
# NOTE(review): the exact definition and sign convention live in
# src.performance_evaluation.compute_volume_error — confirm there.
volumen_error = compute_volume_error(γ_hat, gamma_test)

In [None]:
# Histogram (with KDE) of the volume error over the test set.
fig, ax = plt.subplots(figsize=(6, 3))
sns.histplot(volumen_error, bins=25, kde=True, color="#756bb1", ax=ax)
ax.set_xlabel("Voxel element difference")
# NOTE(review): fixed x-range; values outside [-265, 400] are cut off.
ax.set_xlim([-265, 400])
ax.set_ylabel("Count")
fig.tight_layout()
fig.savefig("images/voxel_element_deviation.pdf")
fig.savefig("images/voxel_element_deviation.png")
plt.show()

In [None]:
def compute_voxel_err(predicted_voxels, true_voxels):
    """Center-of-mass displacement between a predicted and a true voxel grid.

    Returns ``center_of_mass(true_voxels) - center_of_mass(predicted_voxels)`` as
    a NumPy array, with one signed offset per coordinate returned by
    ``center_of_mass``. A positive entry means the prediction's center lies at a
    lower index than the true one along that coordinate.
    """
    pred_center = np.array(center_of_mass(predicted_voxels))
    true_center = np.array(center_of_mass(true_voxels))
    return true_center - pred_center


# One row per test sample, one column per center-of-mass coordinate.
axial_errors = np.array(
    [compute_voxel_err(pred, truth) for pred, truth in zip(γ_hat, gamma_test)]
)

In [None]:
# Boxplot of the per-axis center-of-mass error, as a percentage of voxel_val_max.
# The original `save = False` flag was never read (the figure was always saved);
# it is now honoured, defaulting to True to keep writing the files.
save = True  # set to False to skip writing the PDF/PNG files

voxel_val_max = 32  # presumably the voxel-grid edge length — TODO confirm
errors = axial_errors / voxel_val_max * 100
df = pd.DataFrame({"x-pos": errors[:, 0], "y-pos": errors[:, 1], "z-pos": errors[:, 2]})

fig, ax = plt.subplots(figsize=(6, 3))
# One shared color instead of a palette of three identical colors: passing
# `palette` without `hue` is deprecated in seaborn >= 0.13.
sns.boxplot(data=df, showfliers=False, color="#a1d99b", ax=ax)
ax.set_ylabel("Error (%)")
fig.tight_layout()
if save:
    fig.savefig("images/reconstruction_axis_error.pdf")
    fig.savefig("images/reconstruction_axis_error.png")
plt.show()

___