In [None]:
from model import MLPNet
import tf_keras as keras
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
import crystal_loader
from tqdm import tqdm
# NOTE: `import tqdm.keras` rebinds the name `tqdm` to the tqdm *package*,
# shadowing the function imported on the previous line; `tqdm.keras.TqdmCallback`
# in the training cell relies on this binding.
import tqdm.keras

from symmetry import *
import dill
from sklearn.preprocessing import StandardScaler
import h5py
from sklearn.model_selection import train_test_split


In [None]:
features_path = "../pickles/TiO2_2015_features.h5"
labels_path = "../pickles/TiO2_2015_labels.h5"


def read_indexed_arrays(path, prefix="array_"):
    """Read datasets ``{prefix}0 .. {prefix}{n-1}`` from an HDF5 file, in index order.

    ``n`` is taken as the number of top-level members of the file, so every
    member is expected to follow the ``{prefix}{i}`` naming; a file with any
    other member (or a gap in the indices) makes the lookup fail.

    Parameters
    ----------
    path : str
        Path of the HDF5 file to open read-only.
    prefix : str
        Dataset-name prefix preceding the integer index.

    Returns
    -------
    list of numpy.ndarray
        The full contents of each dataset, ordered by index.
    """
    with h5py.File(path, "r") as f:
        return [f[f"{prefix}{i}"][:] for i in range(len(f))]


features = read_indexed_arrays(features_path)
labels = np.array(read_indexed_arrays(labels_path))

# Features and labels are later split together by position, so their counts must agree.
assert len(features) == len(labels), (len(features), len(labels))

In [None]:
# Fully connected ReLU stack: two wide layers (400 -> 200 units) followed by
# N_NARROW_LAYERS identical 50-unit layers.
N_NARROW_LAYERS = 10

layers = [
    keras.layers.Dense(400, activation="relu"),
    keras.layers.Dense(200, activation="relu"),
] + [keras.layers.Dense(50, activation="relu") for _ in range(N_NARROW_LAYERS)]

# n_syms=70 agrees with the 70-column rows the inputs are packed into
# (inner_shape=(70,)); presumably MLPNet uses it as the per-row input width — confirm.
MLP1 = MLPNet(layers=layers, n_syms=70)

# Regression setup: MSE loss, RMSE reported as a metric.
MLP1.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.0003),
    loss=keras.losses.MeanSquaredError(),
    metrics=[keras.metrics.RootMeanSquaredError()],
)


In [None]:
def scale_ragged(features, fit_on=None):
    """Standardize a ragged collection of per-structure feature matrices.

    Each element of ``features`` is assumed to be a 2-D array whose column
    count is shared across structures while the row count may differ. All
    rows (of ``fit_on`` if given, otherwise of ``features``) are stacked to fit
    a single ``StandardScaler``, which is then applied to every array
    separately so the ragged structure is preserved.

    Parameters
    ----------
    features : sequence of array-like
        Structures to transform.
    fit_on : sequence of array-like, optional
        Structures used to fit the scaler. Defaults to ``features`` itself.
        Pass only training structures here to keep held-out statistics out of
        the scaling.

    Returns
    -------
    list of numpy.ndarray
        One standardized array per input structure, in input order. An empty
        ``features`` yields an empty list.
    """
    features = list(features)
    if not features:
        # np.vstack([]) raises; there is nothing to transform anyway.
        return []

    reference = features if fit_on is None else list(fit_on)
    scaler = StandardScaler().fit(np.vstack(reference))
    return [scaler.transform(struct) for struct in features]


# NOTE(review): this fits the scaler on every structure, including those the
# split below holds out for validation/test, which leaks their statistics.
# After splitting, prefer scale_ragged(split, fit_on=<training structures>).
scaled_features = scale_ragged(features)

In [None]:
# Split roughly 60 / 20 / 20 into train / validation / test. The raw structures
# are split first and the scaler is fit on the training split alone: fitting
# it on every structure (as `scaled_features` does) leaks validation/test
# statistics into the inputs. Splitting `features` instead of `scaled_features`
# yields the same partition, because both have the same length and seed.
SPLIT_SEED = 12
N_SYMS = 70  # columns per row; same value as n_syms passed to MLPNet

Xtrain_raw, Xtest_raw, y_train, y_test = train_test_split(
    features, labels, shuffle=True, random_state=SPLIT_SEED, test_size=0.4
)
Xval_raw, Xtest_raw, y_val, y_test = train_test_split(
    Xtest_raw, y_test, shuffle=True, random_state=SPLIT_SEED, test_size=0.5
)

feature_scaler = StandardScaler().fit(np.vstack(Xtrain_raw))


def _scale_to_ragged(structures):
    """Apply the train-fitted scaler to each structure and pack them into a RaggedTensor."""
    return tf.ragged.constant(
        [feature_scaler.transform(s) for s in structures],
        ragged_rank=1,
        inner_shape=(N_SYMS,),
    )


Xtrain = _scale_to_ragged(Xtrain_raw)
Xval = _scale_to_ragged(Xval_raw)
Xtest = _scale_to_ragged(Xtest_raw)

In [None]:
# Train for a fixed number of epochs, evaluating on the validation split after
# each one. verbose=0 silences Keras' built-in logging; TqdmCallback reports
# progress instead (`tqdm` here is the package, rebound by `import tqdm.keras`).
# NOTE(review): there is no early stopping or best-weight restore, so MLP1 ends
# up with the weights from the final epoch, not the best-validation epoch.
EPOCHS = 8000
BATCH_SIZE = 50

res = MLP1.fit(
    Xtrain, y_train,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    validation_data=(Xval, y_val),
    verbose=0,
    callbacks=[tqdm.keras.TqdmCallback()],
)

In [None]:
# Learning curves: per-epoch MSE loss on the training and validation splits.
fig, ax = plt.subplots()
ax.plot(res.epoch, res.history["loss"], label="training")
ax.plot(res.epoch, res.history["val_loss"], label="validation")
ax.set(title="MLP learning curve", xlabel="epoch", ylabel="MSE loss")
ax.legend();