In [None]:
# Re-import the project-local `model` module so that edits made to it on disk
# are picked up when this cell is re-run, without restarting the kernel.
from importlib import reload
import model
model = reload(model)

import os

import keras
import pandas as pd
import tensorflow as tf
import crystal_loader
from tqdm import tqdm
import tqdm.keras
import numpy as np
from symmetry import *
import dill
from sklearn.preprocessing import StandardScaler
import h5py
from sklearn.model_selection import train_test_split
# Passing False keeps tf.function-decorated code running as traced graphs
# (True would force eager execution, which is a debugging mode).
tf.config.run_functions_eagerly(False)

import tensorflow as tf
from keras.losses import MeanSquaredError

import sys

# Silence every Python warning for this session, unless the interpreter was
# started with explicit -W options (sys.warnoptions is empty when none were given).
if not sys.warnoptions:
    import warnings
    warnings.simplefilter("ignore")

# Report the visible GPUs and switch them to on-demand memory allocation.
physical_devices = tf.config.list_physical_devices('GPU')

print("GPU:", physical_devices)
print("Num GPUs:", len(physical_devices))

if not physical_devices:
    # Checked explicitly instead of relying on an IndexError from physical_devices[0].
    print("No GPU")
else:
    try:
        # TensorFlow requires the memory-growth setting to be the same on every GPU,
        # so it is applied to all of them rather than to the first one only.
        for gpu in physical_devices:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as err:
        # Raised when the GPUs were already initialised (e.g. the cell is re-run
        # after tensors were created); report it instead of hiding it.
        print("Could not enable GPU memory growth:", err)

In [None]:
# Record the TensorFlow version used for this run.
print(tf.__version__)

In [None]:
dset_name = "TiO2_2015_angfixed_x3"

# The feature and label files share the dataset tag as a file-name prefix.
features_path = f"../pickles/{dset_name}_features.h5"
labels_path = f"../pickles/{dset_name}_labeldata.h5"

# Feature arrays are read in index order from the keys array_0, array_1, ...;
# the `[:]` slice reads each dataset into memory before the file is closed.
features = []
with h5py.File(features_path, "r") as f:
    for i in range(len(f)):
        features.append(f[f"array_{i}"][:])

label_df = pd.read_hdf(labels_path, key="labels")

# Values stored under the "n_atoms" key, reshaped into a single column.
n_atoms = pd.read_hdf(labels_path, key="n_atoms")
n_atoms = n_atoms.to_numpy().reshape(-1, 1)

# Show which label columns are available.
print(label_df.columns)

In [None]:
# Select the regression target: the "cohesive_energy" column, reshaped into a
# single-column 2-D array.
labels = label_df["cohesive_energy"].to_numpy().reshape(-1, 1)


In [None]:
# Hidden layers handed to the project's MLPNet wrapper: two ReLU layers of 10 units.
layers = [keras.layers.Dense(10, activation="relu"),
          keras.layers.Dense(10, activation="relu")]

# N_features=70 matches the inner_shape=(70,) used when the ragged inputs are built.
MLP1 = model.MLPNet(layers=layers, N_features=70, ragged_processing=False, unitwise_loss_model=None)

MLP1.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.0004),
    loss="mse",
    # Both metrics are taken from the same `keras` namespace as the layers and the
    # optimizer. Mixing in `tf.keras` objects breaks when the standalone Keras
    # package and TensorFlow's bundled Keras are different implementations.
    metrics=[keras.metrics.RootMeanSquaredError(), keras.metrics.MeanAbsoluteError()],
)


In [None]:
def scale_ragged(features, scaler=None):
    """Standardise a ragged collection of per-structure feature arrays.

    Parameters
    ----------
    features : sequence of 2-D array-likes
        One array per structure. Row counts may differ between structures;
        the column count must be the same for all of them.
    scaler : object with a ``transform`` method, optional
        An already fitted scaler to apply. When omitted, a ``StandardScaler``
        is fitted on the rows of all structures stacked together, which is the
        original behaviour. Pass a scaler fitted on the training structures
        only to keep held-out data out of the scaling statistics.

    Returns
    -------
    list
        One transformed array per input structure, in the input order.

    Raises
    ------
    ValueError
        If no scaler is given and ``features`` is empty, because there is
        nothing to fit on.
    """
    if scaler is None:
        if len(features) == 0:
            raise ValueError("scale_ragged needs at least one structure to fit a scaler")
        # Fit once on the rows of every structure stacked into one 2-D array.
        scaler = StandardScaler().fit(np.vstack(features))

    return [scaler.transform(struct) for struct in features]

# NOTE(review): the scaler is fitted on every structure before the
# train/validation/test split is made, so held-out structures contribute to the
# scaling statistics — confirm this is intended.
scaled_features = scale_ragged(features)

In [None]:
# First split: 80 % for training, 20 % held out.
Xtrain, Xtest, y_train, y_test, c_train, c_test = train_test_split(
    scaled_features, labels, n_atoms,
    test_size=0.2, random_state=12, shuffle=True,
)
# Second split: halve the held-out part into a validation and a test set.
Xval, Xtest, y_val, y_test, c_val, c_test = train_test_split(
    Xtest, y_test, c_test,
    test_size=0.5, random_state=12, shuffle=True,
)

# Pack the per-structure arrays into ragged tensors: one ragged dimension,
# with every row holding 70 feature values.
Xtrain, Xval, Xtest = (
    tf.ragged.constant(split, ragged_rank=1, inner_shape=(70,))
    for split in (Xtrain, Xval, Xtest)
)

In [None]:
# Weights are written after every epoch (save_best_only=False); {epoch:04d} is
# filled in by Keras, giving one file per epoch.
checkpoint_filepath = "./training_v2x3_weighted/cp-{epoch:04d}.weights.h5"
checkpoint_dir = os.path.dirname(checkpoint_filepath)
# Make sure the target folder exists before training starts.
os.makedirs(checkpoint_dir, exist_ok=True)
model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=True,
    monitor='val_loss',
    # A loss gets better as it gets smaller. 'max' would keep the worst epoch
    # as soon as save_best_only is switched on.
    mode='min',
    save_best_only=False)

In [None]:
# Mark the model as built by hand before loading weights.
# NOTE(review): presumably needed because the model has not been called on data
# yet at this point — confirm against model.MLPNet.
MLP1.built=True
loaded_model_name = "x3v2_weighted"
mpath = f'./saved_models/{loaded_model_name}.weights.h5'
# The weights path is relative, so show the directory it is resolved against.
print(os.getcwd())
MLP1.load_weights(mpath)

In [None]:
# Train silently (verbose=0) for 1200 epochs, validating after every epoch and
# writing a checkpoint through the callback; `res` keeps the training history.
res = MLP1.fit(
    Xtrain,
    y_train,
    validation_data=(Xval, y_val),
    validation_freq=1,
    epochs=1200,
    batch_size=32,
    callbacks=[model_checkpoint_callback],
    verbose=0,
)

In [None]:
# Open a previous model's training history.
# NOTE(review): in the cells visible here, `epochs` and `hist` are not used
# afterwards; the plotting code reads `res` instead.
# dill.load can execute arbitrary code while unpickling — only open files you
# created yourself.
modname = "x3v1_weighted"
with open(f'./saved_models/{modname}_history.dill', 'rb') as file_pi:
    epochs, hist = dill.load(file_pi)

# Alternatively, use the current model's history (the `res` object returned by MLP1.fit).

In [None]:
import matplotlib.pyplot as plt

# Start from a fresh figure 1 so that re-running the cell does not draw on an old one.
plt.close(1)
fig, ax = plt.subplots(1, 2, num=1, figsize=(10, 5))

# (panel, training-history key, validation-history key, y-axis label)
curves = (
    (ax[0], "loss", "val_loss", "loss (MSE)"),
    (ax[1], "root_mean_squared_error", "val_root_mean_squared_error", "loss (RMSE), eV"),
)
for panel, train_key, val_key, ylabel in curves:
    panel.plot(res.epoch, res.history[train_key], label="training")
    panel.plot(res.epoch, res.history[val_key], label="validation")
    panel.set_yscale("log")
    panel.set_ylabel(ylabel, fontsize=16)
    panel.set_xlabel("epoch", fontsize=16)

ax[0].legend()
ax[1].legend()

In [None]:
import MLPtools

# Predictions, targets and atom counts for the test split, with singleton axes removed.
y_pred = np.squeeze(MLP1.predict(Xtest))
y_exam = np.squeeze(y_test)
target_counts = np.squeeze(c_test)

# Per-atom quantities from the project helper MLPtools.atomic_energies
# (presumably energy divided by atom count — TODO confirm in MLPtools).
y_pred_atomic = MLPtools.atomic_energies(y_pred, target_counts)
y_exam_atomic = MLPtools.atomic_energies(y_exam, target_counts)

In [None]:
# Debug aid: show the feature rows of the first test structure.
print(Xtest[0])

In [None]:
# Roughly bin the structures by atom count: every structure gets the
# logarithmic bin index floor(5 * ln(N_atoms)).
log_bins = np.floor(5 * np.log(target_counts))

# Sorted bin indices that actually occur in the test set.
size_classes = np.unique(log_bins)

# One flat boolean mask per size class, in the same order as `size_classes`.
filters = [np.ravel(log_bins == size_class) for size_class in size_classes]

In [None]:
fig, ax = plt.subplots(2, 4, num=2, figsize=(20, 10), tight_layout=True)

# Top row: total energies; bottom row: per-atom energies.
# Each panel overlays the true and predicted distributions of one size class.
histogram_rows = (
    (ax[0, :], y_pred, y_exam, "Cohesive energy (eV)"),
    (ax[1, :], y_pred_atomic, y_exam_atomic, "Cohesive energy per atom (eV/atom)"),
)
for row_axes, predicted, truth, xlabel in histogram_rows:
    for axes, class_filt, size_class in zip(row_axes, filters, size_classes):
        pred_species = predicted[class_filt]
        exam_species = truth[class_filt]

        # exp(size_class / 5) inverts the binning formula to give a representative atom count.
        axes.set_title(fr"$N_{{atoms}} \sim {np.round(np.exp(size_class / 5)):.0f}$")
        axes.set_xlabel(xlabel)
        axes.hist(exam_species, bins="auto", label="true", alpha=0.6)
        axes.hist(pred_species, bins="auto", label="predicted", alpha=0.7)

        axes.legend()


In [None]:
# Parity plot over the whole test set.
fig, ax = plt.subplots(num=3, figsize=(10, 6), tight_layout=True)

ax.scatter(y_exam, y_pred, marker="o", edgecolor="k")
ax.set_title("Prediction vs truth plot (total cohesive energy)")
ax.set_xlabel("true (eV)")
ax.set_ylabel("predicted (eV)")

# This grid gets a figure number of its own. num=2 is already used by the
# histogram grid, and asking for an existing number returns that figure, so the
# new axes would be drawn on top of the histograms if it is still open.
fig, ax = plt.subplots(2, 4, num=4, figsize=(20, 10), tight_layout=True)
fig.suptitle("Predicted vs Actual plots in various system-size classes")

# (axes row, predictions, targets, title prefix, unit, print the per-class RMSE?)
parity_rows = (
    (ax[0, :], y_pred, y_exam, "Total Cohesive Energy", "eV", True),
    (ax[1, :], y_pred_atomic, y_exam_atomic, "eV/atom", "eV/atom", False),
)
for row_axes, predicted, truth, title_prefix, unit, report_rmse in parity_rows:
    for axes, class_filt, size_class in zip(row_axes, filters, size_classes):
        pred_species = predicted[class_filt]
        exam_species = truth[class_filt]
        species_counts = target_counts[class_filt]
        # exp(size_class / 5) inverts the binning formula to give a representative atom count.
        approx_atoms = np.round(np.exp(size_class / 5))

        axes.set_title(fr"{title_prefix}: $N_{{atoms}} \sim {approx_atoms:.0f}$")
        # These points come from the test split, so they are labelled as test instances.
        axes.set_xlabel(f"true ({unit}) (N={len(exam_species)} test instances)")
        axes.set_ylabel(f"predicted ({unit})")
        axes.set_aspect("equal")

        # Colour encodes the position of each sample within its size class.
        axes.scatter(exam_species, pred_species, lw=0, marker=".", c=np.arange(np.size(exam_species)))

        # Shared square limits plus the y = x reference line.
        axmin = np.min([np.min(exam_species), np.min(pred_species)])
        axmax = np.max([np.max(exam_species), np.max(pred_species)])
        x = np.linspace(axmin, axmax)
        axes.plot(x, x, c="r", linestyle="--", lw=0.5)
        axes.set_xlim(axmin, axmax)
        axes.set_ylim(axmin, axmax)

        if report_rmse:
            # The helper is given total energies and atom counts; the factor 1000
            # converts eV to meV. Only the total-energy row reports this.
            error = np.sqrt(MLPtools.atomic_MSE(exam_species, pred_species, species_counts)) * 1000
            print(f"SC {approx_atoms:.0f}: RMSE {error:.4f} meV/atom")


In [None]:
# Show the summary of the object returned by the project method MLP1.get_subnet().
MLP1.get_subnet().summary()

In [None]:
# Save MLP1's weights and training history.
import os
print(os.getcwd())

# NOTE: this is the same name the weights were loaded from ("x3v2_weighted"),
# so saving overwrites that file.
mname = "x3v2_weighted"

# Create the output folder up front so a finished training run is not lost to a
# missing directory.
os.makedirs('./saved_models', exist_ok=True)

# Save MLP1's weights
MLP1.save_weights(f'./saved_models/{mname}.weights.h5')

# Save training history as an (epoch list, history dict) tuple
with open(f'./saved_models/{mname}_history.dill', 'wb') as file_pi:
    dill.dump((res.epoch, res.history), file_pi)
