In [None]:
import glob
import os

import matplotlib.pyplot as plt
import numpy as np
from sklearn.model_selection import train_test_split
from src.classes import HitBox, TankProperties32x2
from src.dataprocessing import get_measured_potential
from src.vae_model import vae_model
from src.visualization import (
    plot_latent_space_with_tsne,
    plot_loss_history,
    plot_meas_coords,
    plot_meas_coords_wball,
    plot_mesh,
    plot_rball,
    plot_voxel,
)
from src.voxel_util import (
    gen_voxel_ball_data,
    gen_voxel_brick_data,
    random_voxel_ball,
    random_voxel_brick,
    read_json_file,
    scale_realworld_to_intdomain,
    voxel_ball,
)
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tqdm import tqdm

In [None]:
# Reference implementations consulted for this voxel VAE:
# https://github.com/IsaacGuan/3D-VAE/blob/master/train.py
# https://github.com/ffriese/voxel_vae/tree/master
# https://github.com/ajbrock/Generative-and-Discriminative-Voxel-Modeling/blob/master/Generative/VAE.py

In [None]:
# Reproducibility: seed Python's `random`, NumPy's global RNG and TensorFlow
# in one call. NOTE(review): assumes the src.voxel_util generators draw from
# one of these RNGs — confirm.
SEED = 42
tf.keras.utils.set_random_seed(SEED)

# Anomaly shape for the VAE training data: a ball with size parameter `d`,
# or a brick with per-axis sizes `d_xyz`.
ball = True
if ball:
    d = 4
    plot_voxel(random_voxel_ball(d=d))
else:
    d_xyz = [5, 5, 5]
    plot_voxel(random_voxel_brick(d_xyz=d_xyz))

In [None]:
# Generate the VAE training set of random voxel anomalies of the chosen shape.
n_gammas = 500
gamma_train = (
    gen_voxel_ball_data(num=n_gammas, d=d)
    if ball
    else gen_voxel_brick_data(num=n_gammas, d_xyz=d_xyz)
)

print(gamma_train.shape)

In [None]:
# Spot-check three random training samples drawn from the whole training set
# (the upper bound follows the data, so any number of samples works).
ns = gamma_train.shape[0]
for idx in np.random.randint(0, ns, size=3):
    # Channel 0 of a (sample, x, y, z, channel) batch.
    plot_voxel(gamma_train[idx, :, :, :, 0])

In [None]:
# Train the beta-VAE on the generated voxel grids.
epoch_num = 250
batch_size = 128

vae = vae_model(input_shape=(32, 32, 32, 1), beta=1.02)
vae.compile(optimizer=Adam())
history = vae.fit(
    gamma_train,
    batch_size=batch_size,
    epochs=epoch_num,
)

In [None]:
# Project helper for the VAE loss curves; the KL term gets its own figure.
plot_loss_history(history)

fig, ax = plt.subplots(figsize=(6, 4))
ax.plot(history.history["kl_loss"])
ax.set_title("Kl Loss History")
ax.set_xlabel("Epochs")
ax.set_ylabel("Loss")
ax.grid()
fig.tight_layout()
plt.show()

In [None]:
# Fresh samples (not seen in training) for inspecting the latent space.
n_test = 100
gamma_test = (
    gen_voxel_ball_data(num=n_test, d=d)
    if ball
    else gen_voxel_brick_data(num=n_test, d_xyz=d_xyz)
)

In [None]:
# Embed the posterior means with t-SNE: z_mean is the deterministic encoding
# of each sample, whereas z_log_var only describes the encoder's uncertainty.
z_mean, z_log_var, z = vae.encoder.predict(gamma_test)
plot_latent_space_with_tsne(z_mean)

In [None]:
# Draw one new sample, show it, and encode it for the reconstruction check.
sgl_data = (
    gen_voxel_ball_data(1, d=d) if ball else gen_voxel_brick_data(1, d_xyz=d_xyz)
)
plot_voxel(sgl_data[0, :, :, :, 0], azim=0, elev=30)
z_mean, z_log_var, z = vae.encoder.predict(sgl_data)

In [None]:
# Presumably (1, 32, 32, 32, 1): a batch of one single-channel voxel grid.
np.shape(sgl_data)

In [None]:
# Decode the latent vector and drop the batch (0) and channel (4) axes.
decoded = vae.decoder.predict(z)
sgl_pred = np.squeeze(decoded, axis=(0, 4))
# Clip into [0, 1] before plotting (assumes voxel values are occupancies).
sgl_pred = np.clip(sgl_pred, 0, 1)

plot_voxel(sgl_pred, azim=0, elev=30)

In [None]:
# Persist the trained VAE. The output directory is created if missing.
# Keras 3 only accepts `save_weights` paths ending in `.weights.h5`; that
# suffix still ends in `.h5`, so Keras 2 writes HDF5 as before.
os.makedirs("models", exist_ok=True)
vae.save_weights("models/vae_weights_beta.weights.h5")
vae.save("models/vae_beta.keras")

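## Load VAE

TODO: this section has no loading code yet; the cells below reuse the `vae` trained earlier in the same kernel session.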

## Build Mapper

In [None]:
# Measurement campaign used to train the mapper. info.json carries the
# hit-box (used below to scale anomaly positions) and the tank properties.
l_path = "measurements/acryl_skip_8_d_30/"

json_data = read_json_file(f"{l_path}info.json")
hitbox = HitBox(**json_data["HitBox"])
tank = TankProperties32x2(**json_data["TankProperties32x2"])

In [None]:
# Build the mapper dataset: one |potential| vector per measurement file and,
# as its target, the VAE latent code of the ball anomaly recorded in it.
# Loop variables are named so they do not clobber the `ball` / `d` config
# globals that the VAE cells above branch on.
d_anomaly = 4  # NOTE(review): should match the `d` the VAE was trained with
pots = []
perm_voxels = []

data_paths = sorted(glob.glob(l_path + "data/*"))
if not data_paths:
    raise FileNotFoundError(f"no measurement files found in {l_path}data/")

for path in tqdm(data_paths):
    # allow_pickle: unpickling can execute arbitrary code — trusted data only.
    sample = np.load(path, allow_pickle=True)
    pots.append(np.abs(get_measured_potential(sample, shape_type="vector")))
    anomaly = sample["anomaly"].tolist()
    # NOTE(review): x and y are swapped when building `coordinate` and swapped
    # back in the `voxel_ball` call — confirm this matches the axis
    # convention of scale_realworld_to_intdomain / voxel_ball.
    coordinate = [anomaly.y, anomaly.x, anomaly.z]
    x0, y0, z0 = scale_realworld_to_intdomain(coordinate, hitbox, d=d_anomaly)
    perm_voxels.append(voxel_ball(y0, x0, z0, d=d_anomaly))

# Encode the voxel grids (adding a trailing channel axis) and keep the third
# encoder output (sampled z) as regression targets.
# NOTE(review): sampled z is noisy; z_mean (first output) may be a better target.
_, _, perm = vae.encoder.predict(np.expand_dims(np.array(perm_voxels), axis=4))
pots = np.array(pots)

In [None]:
# Hold out 5 % of the measurements; fixed random_state keeps the split stable.
split = train_test_split(pots, perm, test_size=0.05, random_state=42)
pots_train, pots_test, perm_train, perm_test = split

In [None]:
# Sanity-check the split: train/test inputs, then train/test targets.
print(*(arr.shape for arr in (pots_train, pots_test, perm_train, perm_test)))

In [None]:
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam


def build_mapper(
    input_shape=(4096,),
    latent_dim=10,
    hidden_activation="relu",
    output_activation=None,
):
    """Build a 1-D CNN that regresses potential vectors onto VAE latent codes.

    Args:
        input_shape: Shape of one flattened potential vector, ``(n_features,)``.
        latent_dim: Size of the output vector. Must equal the VAE latent size
            (10 assumed here — TODO confirm against ``vae_model``).
        hidden_activation: Activation of the Conv1D stack. Without one, the
            four convolutions compose into a single linear map.
        output_activation: Activation of the final Dense layer. Linear by
            default because the VAE latents (regularized toward N(0, I) by the
            KL term) take negative values, which a ReLU output cannot produce.

    Returns:
        An uncompiled ``tf.keras`` ``Model``.
    """
    mapper_inputs = layers.Input(shape=input_shape)
    # Add a channel axis so Conv1D sees (steps, channels).
    x = layers.Reshape((input_shape[0], 1))(mapper_inputs)
    x = layers.Conv1D(32, strides=2, kernel_size=9, activation=hidden_activation)(x)
    x = layers.Conv1D(16, strides=4, kernel_size=9, activation=hidden_activation)(x)
    x = layers.Conv1D(8, strides=4, kernel_size=9, activation=hidden_activation)(x)
    x = layers.Conv1D(4, strides=4, kernel_size=9, activation=hidden_activation)(x)
    x = layers.Flatten()(x)
    outputs = layers.Dense(latent_dim, activation=output_activation)(x)

    return Model(mapper_inputs, outputs)


# The factory keeps its own name so this cell can be re-run; `mapper` is the
# built model used by the training and plotting cells below.
mapper = build_mapper()
mapper.compile(Adam(), loss=tf.keras.losses.mean_squared_error)
mapper.summary()

In [None]:
# Train the mapper. The held-out split is passed as validation data so the
# history also tracks generalization; it is not used for the weight updates.
history_mapper = mapper.fit(
    pots_train,
    perm_train,
    validation_data=(pots_test, perm_test),
    epochs=90,
    batch_size=64,
)

In [None]:
# Mapper training curve. The validation curve is drawn only when the fit ran
# with validation data, so this cell works either way.
fig, ax = plt.subplots()
ax.plot(history_mapper.history["loss"], label="train")
if "val_loss" in history_mapper.history:
    ax.plot(history_mapper.history["val_loss"], label="validation")
ax.set_title("Mapper Loss History")
ax.set_xlabel("Epochs")
ax.set_ylabel("Loss")
ax.legend()
ax.grid()
plt.show()