In [None]:
import os
from pathlib import Path

# Hide all CUDA GPUs so TensorFlow runs on the CPU only. This is set before
# TensorFlow is imported so it is guaranteed to be in place before TensorFlow
# enumerates devices.
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.keras.losses import MSE
from tensorflow.keras.models import Model

import utils
from social_dynamics.autoencoder_utils import create_dataset


In [None]:
# Root folder of the experiment results, chosen by which user's checkout the
# notebook runs from. Add a branch here when running on a new machine.
current_dir = str(Path(".").resolve())
if "maler" in current_dir:
    absolute_path = Path("C:/Users/maler/Federico/Lavoro/ZIB/experiments_results")
elif "luziehel" in current_dir:
    absolute_path = Path("C:/Users/luziehel/Code/experiments_results")
else:
    # Fail early with a clear message instead of a NameError on `absolute_path`.
    raise RuntimeError(
        f"Unknown machine (working directory: {current_dir}); "
        "add its experiments_results path to this cell."
    )

# Experiment series to analyse. Previously selected series, kept for reference:
#   "2_opt-h_luzie-gamma_delta_expl"
#   "2_opt-h_luzie-alpha_beta_gamma_delta_expl"
experiment_series_folder = absolute_path.joinpath("2_opt-h_luzie-alpha_beta_gamma_delta_expl-0.0001t")

# One sub-folder per simulation run. The list is sorted because iterdir() order
# is filesystem-dependent. Together with the seed below, this makes the random
# sampling reproducible. Stray files in the series folder are skipped.
experiment_runs_folders = sorted(
    folder for folder in experiment_series_folder.iterdir() if folder.is_dir()
)

model_type = "dnn"  # passed to create_dataset in the Testing section
# NOTE(review): machine-specific path, only valid on the "maler" machine.
model_path = Path("C:/Users/maler/Federico/Lavoro/ZIB/autoencoder_clustering/autoencoder_model/model.h5")

# Seed numpy's global RNG (used by np.random.choice in the cells below) so every
# Restart & Run All samples the same runs. Change the value to inspect others.
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)

In [None]:
# Load the trained model, then build a sub-model that maps the model input to the
# output of the layer named "embedding".
# NOTE: despite its name, `autoencoder` is only the encoder half (input ->
# embedding). It is not used in the cells shown here.
model = tf.keras.models.load_model(model_path)
autoencoder = Model(inputs=model.input, outputs=model.get_layer(name="embedding").output)

In [None]:
# Draw one run at random and load its StateMetric, keeping every 4th entry along
# the first axis. This presumably matches downsampling=4 used for the dataset
# below.
folder = np.random.choice(experiment_runs_folders)
state_metric = utils.load_metrics(folder)["StateMetric"][::4]

# Flatten into the model's 1-D input. Check that the flattening round-trips so the
# prediction can be reshaped back to the StateMetric layout.
inputs = state_metric.flatten()
assert np.array_equal(inputs.reshape(state_metric.shape), state_metric)

# The model expects a leading batch axis; this is a batch of one.
preds = model.predict(np.expand_dims(inputs, axis=0))

In [None]:
# Plot the (downsampled) StateMetric of the run sampled above.
# NOTE(review): plot_agents_option comes from the project's `utils`. It
# presumably expects a (time, agent, option) array, matching the indexing in the
# plotting cells below. Confirm.
utils.plot_agents_option(state_metric)

In [None]:
# Plot the model's reconstruction of the same run. The flat prediction (a batch of
# one) is reshaped back to the StateMetric layout.
utils.plot_agents_option(preds.reshape(state_metric.shape))

# Random plots: original vs. reconstructed trajectories of randomly sampled runs

In [None]:
# Compare original (left column) and reconstructed (right column) trajectories for
# a random sample of distinct runs. There is one row per (run, option) pair.
n_samples = min(10, len(experiment_runs_folders))
folders = np.random.choice(experiment_runs_folders, size=n_samples, replace=False)
state_metrics = [utils.load_metrics(folder)["StateMetric"][::4] for folder in folders]
inputs = np.array([state_metric.flatten() for state_metric in state_metrics])

preds = model.predict(inputs)

n_agents, n_options = state_metrics[0].shape[1:]
n_rows = n_options * len(state_metrics)

fig, axes = plt.subplots(n_rows, 2, figsize=(20, 3 * n_rows), squeeze=False)
axes[0, 0].set_title("Original")
axes[0, 1].set_title("Reconstruction")
for i, (original, pred) in enumerate(zip(state_metrics, preds)):
    # Flat prediction -> same (time, agent, option) layout as the original.
    reconstruction = np.reshape(pred, original.shape)
    for option in range(n_options):
        # Each run occupies n_options consecutive rows, one per option.
        row = i * n_options + option
        axes[row, 0].set_ylabel(f"{folders[i].name}\noption {option + 1}")
        for agent in range(n_agents):
            axes[row, 0].plot(original[:, agent, option], label=str(agent))
            axes[row, 1].plot(reconstruction[:, agent, option], label=str(agent))

plt.tight_layout()
plt.show()

# Testing: reconstruction error over the whole series and the worst-reconstructed runs

In [None]:
# Full dataset of the series. Presumably it yields (input, target) pairs with the
# same downsampling as the `[::4]` slicing above. TODO confirm in create_dataset.
dataset = create_dataset(experiment_series_folder, model_type=model_type, downsampling=4)

# Sorted list of per-run result files, used to map per-sample errors back to files.
# NOTE(review): this only works if create_dataset yields samples in the same
# sorted file order as list_files(shuffle=False). Verify in
# social_dynamics.autoencoder_utils.
file_pattern = f"{experiment_series_folder}/*/StateMetric/results_t200000.npy"
file_dataset = tf.data.Dataset.list_files(file_pattern, shuffle=False)

In [None]:
# Materialise the dataset once. Inputs, targets and predictions then come from the
# same pass, so they stay aligned even if `dataset` does not iterate
# deterministically (its construction is not visible here).
samples = list(dataset.as_numpy_iterator())
x_all = np.stack([sample[0] for sample in samples])
y_true = np.stack([sample[1] for sample in samples])

y_preds = model.predict(x_all, batch_size=128)

# Per-sample mean squared error (mean over the last axis), as a numpy array.
mses = MSE(y_true, y_preds).numpy()

In [None]:
# Indices of the samples with the largest reconstruction error, in ascending
# order of error (the last entry is the worst one).
n_worst = 100
worse_preds = np.argsort(mses)[-n_worst:]

# Map sample indices back to file names (bytes). This is only meaningful when
# `file_dataset` lists files in the same order in which `dataset` yields samples.
all_files = np.array(list(file_dataset.as_numpy_iterator()))
assert len(all_files) == len(mses), (
    f"{len(all_files)} files vs {len(mses)} samples: cannot align errors with files"
)
worse_files = all_files[worse_preds]

In [None]:
# Compare original (left) and reconstructed (right) trajectories for a random
# subset of distinct worst-reconstructed files. There is one row per
# (file, option) pair.
n_samples = min(10, len(worse_files))
files = np.random.choice(worse_files, size=n_samples, replace=False)
state_metrics = [np.load(file)[::4] for file in files]
inputs = np.array([state_metric.flatten() for state_metric in state_metrics])

preds = model.predict(inputs)

n_agents, n_options = state_metrics[0].shape[1:]
n_rows = n_options * len(state_metrics)

fig, axes = plt.subplots(n_rows, 2, figsize=(20, 3 * n_rows), squeeze=False)
axes[0, 0].set_title("Original")
axes[0, 1].set_title("Reconstruction")
for i, (original, pred) in enumerate(zip(state_metrics, preds)):
    # Flat prediction -> same (time, agent, option) layout as the original.
    reconstruction = np.reshape(pred, original.shape)
    # Files look like <run>/StateMetric/results_t200000.npy; label rows by <run>.
    run_name = Path(os.fsdecode(files[i])).parent.parent.name
    for option in range(n_options):
        # Each file occupies n_options consecutive rows, one per option.
        row = i * n_options + option
        axes[row, 0].set_ylabel(f"{run_name}\noption {option + 1}")
        for agent in range(n_agents):
            axes[row, 0].plot(original[:, agent, option], label=str(agent))
            axes[row, 1].plot(reconstruction[:, agent, option], label=str(agent))

plt.tight_layout()
plt.show()