In [None]:
import argparse
from collections import OrderedDict
import datetime
get_ipython().run_line_magic('matplotlib', 'inline')
import matplotlib.pylab as plt
from numbers import Number
import numpy as np
import pandas as pd
pd.options.display.max_rows = 1500
pd.options.display.max_columns = 200
pd.options.display.width = 1000
pd.set_option('max_colwidth', 400)
# import pdb
import pickle
import pprint as pp
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader
from deepsnap.batch import Batch as deepsnap_Batch

import sys, os
sys.path.append(os.path.join(os.path.dirname("__file__"), '..'))
sys.path.append(os.path.join(os.path.dirname("__file__"), '..', '..'))
from le_pde.argparser import arg_parse
from le_pde.datasets.load_dataset import load_data
from le_pde.models import load_model, rollout#, rollout_plasma, plot_vlasov
from le_pde.pytorch_net.util import to_np_array, get_pdict, reshape_weight_to_matrix, ddeepcopy as deepcopy, plot_vectors, record_data, filter_filename, Early_Stopping, str2bool, get_filename_short, print_banner, plot_matrices, get_num_params, init_args, filter_kwargs, to_string, COLOR_LIST
from le_pde.utils import update_legacy_default_hyperparam, EXP_PATH
from le_pde.utils import to_cpu, to_tuple_shape, parse_multi_step, loss_op, get_cholesky_inverse, get_device, get_data_comb

# Device used for model loading and rollouts. `evaluate` and `get_df` below read
# this module-level variable directly instead of taking it as an argument.
device = torch.device("cpu")

# Functions:

In [None]:
def evaluate(
    model,
    dataset,
    args,
    init_indices,
    n_rollout_steps=100,
    interval=20,
    n_plots_row=6,
    max_shift=50,
    dataset_name=None,
    isplot=0,
):
    """Roll the model out from several initial steps and summarize the errors.

    For every entry of `init_indices`, `rollout` is called on `dataset` and its
    predictions, targets and per-step losses are collected per key (the plot
    titles call these keys "species"). The collected lists are stacked along a
    new leading axis, a summary MAE plot is drawn, and the autocorrelation of
    predictions vs. targets is plotted for every dynamic dimension.

    Note: the summary and autocorrelation plots are always drawn; `isplot` is
    only forwarded to `rollout`.

    Args:
        model: model to evaluate; `model.eval()` is called before the rollouts.
        dataset: dataset forwarded to `rollout`.
        args: object providing `algo`, `use_grads`, `is_y_diff` and `loss_type`.
        init_indices: iterable of initial steps; each is cast to `int`.
        n_rollout_steps: forwarded to `rollout`.
        interval: forwarded to `rollout`.
        n_plots_row: forwarded to `rollout`.
        max_shift: number of shifts passed to `get_auto_correlation`.
        dataset_name: forwarded to `rollout`.
        isplot: forwarded to `rollout`.

    Returns:
        `((preds_all, target_all, loss_list_dict_all, MAE_list_dict_all), info)`
        where the first four items are dicts of stacked arrays and `info` is the
        info object returned by the last `rollout` call.

    Raises:
        ValueError: if `init_indices` is empty.
    """
    def plot_loss_summary(loss_list_dict_all):
        """Plot mean +/- std (over initial steps) of each error component vs. rollout step."""
        fontsize = 16
        for key, loss_list in loss_list_dict_all.items():
            loss_list_mean = loss_list.mean(0)
            loss_list_std = loss_list.std(0)
            plt.figure(figsize=(8,6))
            for k in range(loss_list.shape[-1]):
                plt.errorbar(np.arange(loss_list_mean.shape[0]), loss_list_mean[:, k], yerr=loss_list_std[:, k], alpha=0.6, label="{}".format(k))
            plt.xlabel("rollout_step", fontsize=fontsize)
            plt.ylabel("MAE", fontsize=fontsize)
            plt.tick_params(labelsize=fontsize)
            plt.title("Species '{}' Summary MAE vs. rollout steps".format(key), fontsize=fontsize)
            plt.legend()
            plt.show()

    # Materialize once so generators are supported and emptiness can be checked.
    # Without this guard an empty input left the accumulators unbound and the
    # function failed later with an unrelated NameError.
    init_indices = list(init_indices)
    if len(init_indices) == 0:
        raise ValueError("init_indices must contain at least one initial step.")

    model.eval()
    for i, init_step in enumerate(init_indices):
        (preds, target), losses_rollout_all, info = rollout(
            dataset,
            init_step=int(init_step),
            model=model,
            device=device,
            algo=args.algo,
            n_rollout_steps=n_rollout_steps,
            interval=interval,
            n_plots_row=n_plots_row,
            use_grads=args.use_grads,
            is_y_diff=args.is_y_diff,
            loss_type=args.loss_type,
            dataset_name=dataset_name,
            isplot=isplot,
        )
        if i == 0:
            # The keys of the first rollout define the keys tracked for all rollouts.
            preds_all = {key: [] for key in preds}
            target_all = {key: [] for key in target}
            loss_list_dict_all = {key: [] for key in preds}
            MAE_list_dict_all = {key: [] for key in preds}
        for key in preds_all:
            preds_all[key].append(preds[key])
            target_all[key].append(target[key])
            loss_list_dict_all[key].append(losses_rollout_all["loss_list_dict"][key])
            MAE_list_dict_all[key].append(losses_rollout_all["MAE_list_dict"][key])

    for key in preds_all:
        preds_all[key] = np.stack(preds_all[key])  # [n_indices, n_rollout_steps, n_nodes, dyn_dims]
        target_all[key] = np.stack(target_all[key])   # [n_indices, n_rollout_steps, n_nodes, dyn_dims]
        loss_list_dict_all[key] = np.stack(loss_list_dict_all[key])  # [n_indices, n_rollout_steps, dyn_dims]
        MAE_list_dict_all[key] = np.stack(MAE_list_dict_all[key])  # [n_indices, n_rollout_steps, dyn_dims]

    # Plot summary:
    plot_loss_summary(MAE_list_dict_all)
    for key in preds_all:
        for dyn_dim in range(preds_all[key].shape[-1]):
            print("Species '{}' autocorrelation dim={}:".format(key, dyn_dim))
            get_auto_correlation(preds_all[key][..., dyn_dim], target_all[key][..., dyn_dim], max_shift=max_shift, isplot=True)
    return (preds_all, target_all, loss_list_dict_all, MAE_list_dict_all), info


def get_auto_correlation(array, array_gt, max_shift=30, isplot=True):
    """Compare the circular autocorrelation of a prediction with its ground truth.

    For each shift in `range(max_shift)` the array is rolled along its last
    axis (`np.roll`, so the shift wraps around), multiplied element-wise with
    the unshifted array and averaged over the last axis. No mean subtraction or
    normalization is applied. The per-shift values are then summarized by their
    mean and standard deviation over all remaining axes.

    Args:
        array: predicted values; the last axis is the one being shifted.
        array_gt: ground-truth values, processed the same way as `array`.
        max_shift: number of shifts to evaluate (0, 1, ..., max_shift - 1).
        isplot: if True, draw an error-bar plot of prediction vs. ground truth.

    Returns:
        dict with keys "ac_mean", "ac_std", "ac_gt_mean" and "ac_gt_std",
        each holding one value per shift.

    Raises:
        ValueError: if `max_shift` is smaller than 1.
    """
    if max_shift < 1:
        # np.stack on an empty list would raise a ValueError with an unrelated message.
        raise ValueError("max_shift must be at least 1, got {}.".format(max_shift))

    def auto_correlation(array, shift):
        """Mean over the last axis of the array times its circularly shifted copy."""
        return (np.roll(array, shift, axis=-1) * array).mean(-1)

    shifts = range(max_shift)
    # The last axis of the stacked arrays indexes the shift.
    ac_list = np.stack([auto_correlation(array, shift) for shift in shifts], -1)
    ac_gt_list = np.stack([auto_correlation(array_gt, shift) for shift in shifts], -1)

    # Reduce over every axis except the trailing shift axis.
    mean_idx = tuple(range(len(ac_list.shape)-1))
    info = {}
    info["ac_mean"] = ac_list.mean(mean_idx)
    info["ac_std"] = ac_list.std(mean_idx)
    info["ac_gt_mean"] = ac_gt_list.mean(mean_idx)
    info["ac_gt_std"] = ac_gt_list.std(mean_idx)

    if isplot:
        fontsize = 14
        plt.figure(figsize=(8,6))
        plt.errorbar(np.arange(len(info["ac_mean"])), info["ac_mean"], info["ac_std"], capsize=3, elinewidth=2, markeredgewidth=1, label='pred', alpha=0.5)
        plt.errorbar(np.arange(len(info["ac_gt_mean"])), info["ac_gt_mean"], info["ac_gt_std"], capsize=3, elinewidth=2, markeredgewidth=1, label='gt', alpha=0.5)
        plt.legend(bbox_to_anchor=[1,1], fontsize=fontsize)
        # Raw string: "\D" is not a valid escape sequence in a normal string literal.
        plt.xlabel(r"$\Delta x$", fontsize=fontsize)
        plt.ylabel("autocorrelation", fontsize=fontsize)
        plt.tick_params(labelsize=fontsize)
        plt.show()
    return info

def get_df(
    filenames,
    dataset,
    arg_list,
    init_indices,
    dataset_name,
    analysis_metrics=["loss", "recons", "svd-latent", "svd-evo", "rollout"],
    n_rollout_steps=100,
    interval=20,
    n_plots_row=6,
    isplot=1,
):
    """Analyze saved experiment records and build one summary row per record.

    Each file `dirname + filename` is unpickled, its best model is loaded, the
    loss histories are summarized and, depending on `analysis_metrics`, a
    reconstruction pass and/or a rollout evaluation is run.

    Note: `dirname` is NOT a parameter; it is read from the module-level
    namespace and must be defined before this function is called.

    Args:
        filenames: iterable of record file names relative to `dirname`.
        dataset: indexable dataset; `dataset[0]` must expose `original_shape`
            (and `dyn_dims` when "rollout" is requested).
        arg_list: names of the hyperparameters to extract from the record's args
            for display.
        init_indices: initial steps; the first one is used for the
            reconstruction pass and all of them are passed to `evaluate`.
        dataset_name: forwarded to `evaluate`.
        analysis_metrics: collection of analysis names. Only "recons" and
            "rollout" are checked here; the loss summary always runs. The
            default list is never mutated.
        n_rollout_steps: forwarded to `evaluate`; also bounds the rollout steps
            reported in the table.
        interval: forwarded to `evaluate`; every `interval`-th rollout step is
            reported in the table.
        n_plots_row: forwarded to `evaluate`.
        isplot: verbosity/plot level; `>= 1` prints banners and hyperparameters
            and plots the learning curves. Also forwarded to `evaluate`.

    Returns:
        `(df, info_all)`: a `pandas.DataFrame` with one row per analyzed record
        and a dict mapping file name to the rollout info (populated only when
        "rollout" is in `analysis_metrics`).

    Raises:
        Exception: any error raised while loading a model or running the
            reconstruction pass is re-raised after a diagnostic message.
    """
    df_dict_list = []
    info_all = {}
    original_shape = dataset[0].original_shape
    for filename in filenames:
        if isplot >= 1:
            print_banner(filename, banner_size=140)
        # SECURITY: pickle can execute arbitrary code while loading; only open
        # record files that come from a trusted source.
        # The context manager closes the file handle, which was previously leaked.
        with open(dirname + filename, "rb") as f:
            data_record = pickle.load(f)
        if "epoch" not in data_record:
            continue
        if data_record["epoch"][-1] < 5:
            print("The epoch {} of {} is smaller than 5. Skip.".format(data_record["epoch"][-1], filename))
            continue
        try:
            model = load_model(data_record["best_model_dict"], device, input_shape=original_shape)
        except Exception:
            # Report which record failed, then propagate the error (the message
            # used to sit after the `raise` and was never reached).
            print("Cannot load '{}'".format(filename))
            raise

        # kwargs:
        # NOTE: df_dict aliases data_record["args"], so the additions below are
        # also visible through data_record["args"].
        df_dict = data_record["args"]
        args = init_args(data_record["args"])
        df_dict = update_legacy_default_hyperparam(df_dict)
        if df_dict["disc_coef"] > 0:
            discriminator = load_model(data_record["best_disc_model_dict"], device)
        else:
            discriminator = None
        df_dict["filename"] = filename
        kwargs = filter_kwargs(data_record["args"], arg_list)
        if len(kwargs) < len(arg_list):
            print("{} in the arg_list are not in the data_record's args!".format(set(arg_list) - set(kwargs.keys())))
        if isplot >= 1:
            pp.pprint(kwargs)

        # Best epoch:
        best_epoch = df_dict["best_epoch"] = data_record["best_epoch"]
        if best_epoch != np.argmin(data_record["val_loss"]):
            print("saved best epoch {} does not equal the best epoch {} in val_loss!".format(best_epoch, np.argmin(data_record["val_loss"])))
        df_dict["epoch"] = len(data_record["train_loss"])
        print("\nbest_epoch:\t{}/{}".format(df_dict["best_epoch"], df_dict["epoch"]))

        # Loss:
        print("losses (best, last, median):")
        for key in data_record:
            if "loss" in key and "history" not in key and type(data_record[key]) is list:
                df_dict["{}_l".format(key)] = data_record[key][-1]
                df_dict["{}_b".format(key)] = data_record[key][best_epoch]
                df_dict[key] = data_record[key]
                print("{0}  \t{1:.3e}     {2:.3e}     {3:.3e}".format("{}:".format(key).ljust(20), df_dict["{}_b".format(key)], df_dict["{}_l".format(key)], np.median(data_record[key])))
        # Plot learning curve:
        if isplot >= 1:
            plot_vectors(
                Dict=filter_kwargs(data_record, contains="loss"),
                x_range=data_record["epoch"],
                xlabel="epoch",
                ylabel=data_record["args"]["loss_type"],
                linestyle_dict={"tr": "--", "val": ":", "te": "-"},
                alpha=0.9,
            )

        # Plot reconstruction:
        if "recons" in analysis_metrics:
            data_4_recons = deepcopy(dataset[init_indices[0]]).to(device)
            try:
                # NOTE(review): the returned `preds`/`info` are not used further
                # in this function - confirm whether plotting was intended here.
                preds, info = model(
                    data_4_recons,
                    pred_steps=[],
                    latent_pred_steps=[],
                    is_recons=True,
                    use_grads=args.use_grads,
                    is_y_diff=args.is_y_diff,
                    use_pos=args.use_pos,
                    latent_noise_amp=0,
                    reg_type=args.reg_type if args.reg_coef > 0 else "None",
                    is_rollout=True,
                )
            except Exception as e:
                # Report the failing record, then propagate the error.
                print("Error {} happen when running {}.".format(e, filename))
                raise

        # Rollout:
        if "rollout" in analysis_metrics:
            dyn_dims_dict = dict(dataset[0].dyn_dims)
            (preds_all, target_all, loss_list_dict_all, MAE_list_dict_all), info_rollout = evaluate(
                model=model,
                dataset=dataset,
                args=args,
                init_indices=init_indices,
                n_rollout_steps=n_rollout_steps,
                interval=interval,
                n_plots_row=n_plots_row,
                isplot=isplot,
                dataset_name=dataset_name,
            )
            info_rollout["model"] = model
            info_rollout["discriminator"] = discriminator
            info_rollout["rollout_MSE_list"] = loss_list_dict_all
            info_rollout["rollout_MAE_list"] = MAE_list_dict_all
            info_rollout["rollout_preds_all"] = preds_all
            info_rollout["rollout_target_all"] = target_all
            info_all[filename] = info_rollout

            metrics = ["MAE", "MSE"]
            for key, dyn_dims in dyn_dims_dict.items():
                # Mean/std over the initial steps do not depend on the rollout
                # step, so compute them once per key and metric.
                step_stats = {}
                for metric in metrics:
                    values = info_rollout["rollout_{}_list".format(metric)][key]
                    step_stats[metric] = (values.mean(0), values.std(0))
                for i in range(interval, n_rollout_steps+interval, interval):
                    for metric in metrics:
                        mean_steps, std_steps = step_stats[metric]
                        if i - 1 >= mean_steps.shape[0]:
                            # When n_rollout_steps is not a multiple of interval the
                            # last step would index past the recorded rollout; skip it.
                            continue
                        df_dict["{}_{}_rollout_{}".format(key, metric, i)] = np.mean(mean_steps[i-1])
                        df_dict["{}_{}_rollout_{}_std".format(key, metric, i)] = np.mean(std_steps[i-1])
                        for k in range(dyn_dims):
                            df_dict["{}_{}_rollout_{}_{}".format(key, metric, i, k)] = mean_steps[i-1][k]

        # Save:
        df_dict_list.append(df_dict)
    df = pd.DataFrame(df_dict_list)
    return df, info_all

# Analysis:

### Load data:

In [None]:
# Start from the parser defaults, then override them for this analysis run.
args = arg_parse()

# Experiment identity and run mode:
args.exp_id = "smoke"
args.date_time = "2022-3-11"
args.id = 0
args.seed = 0
args.gpuid = 1
args.is_train = False
args.is_test_only = True

# Dataset:
args.dataset = "movinggas"
args.n_train = "-1"
args.time_interval = 1
args.batch_size = 2
args.val_batch_size = 8

# Algorithm and regularization:
args.algo = "contrast"
args.reg_type = None
args.reg_coef = 0
args.is_reg_anneal = True
args.no_latent_evo = False
args.disc_coef = 0

# Architecture:
args.encoder_type = "cnn-s"
args.evolution_type = "mlp-3-elu-2"
args.decoder_type = "cnn-tr"
args.encoder_n_linear_layers = "0"
args.n_conv_blocks = 4
args.n_latent_levs = 2
args.n_conv_layers_latent = 3
args.channel_mode = "exp-16"
args.is_latent_flatten = False
args.evo_groups = 1
args.normalization_type = "gn"
args.latent_size = 16
args.kernel_size = 4
args.stride = 2
args.padding = 1
args.padding_mode = "zeros"
args.act_name = "elu"

# Inputs / targets:
args.multi_step = "1"
args.latent_multi_step = "1"
args.use_grads = False
# NOTE(review): `use_posargs` looks like a typo for `use_pos` (the attribute
# read in `get_df`) - confirm against the argument parser before renaming.
args.use_posargs = False
args.is_y_diff = False
args.latent_noise_amp = "1e-5"

# Loss terms:
args.loss_type = "mse"
args.loss_type_consistency = "mse"
args.recons_coef = 1
args.consistency_coef = 1
args.contrastive_rel_coef = 0
args.hinge = 0
args.density_coef = 0.001

# Optimization:
args.opt = "adam"
args.weight_decay = 0
args.epochs = 52
args.save_interval = 10

### Run analysis:

In [None]:
# Settings:
exp_id = "smoke"
date_time = "2022-3-28"
analsys_metrics = ["loss", "recons", "rollout"]
# Strings that a record's filename must include to be picked up:
include = [
    ".p",
    "movinggas_train_-1_algo_contrast_ebm_False_ebmt_cd_enc_cnn-s_evo_cnn_act_elu_hid_128_lo_mse_recef_1.0_conef_1.0_nconv_4_nlat_1_clat_3_lf_True_reg_None_id_0_Hash_TgzjEJou_turing4.p",
]

dataset_name = "movinggas"
(dataset_train_val, dataset_test), (train_loader, val_loader, test_loader) = load_data(args)

init_indices = [89]   # initial time step(s) of the rollout
n_rollout_steps = 86  # how many steps to roll out
interval = 1          # visualize every `interval`-th step
is_test = True        # True: analyze the test split; False: the train/val split

# Run analysis:
dirname = EXP_PATH + "{}_{}/".format(exp_id, date_time)
filenames = filter_filename(dirname, include=include)
args_list = list(arg_parse().__dict__.keys())
# Pick the split first, then copy only the leading samples covering the rollout window.
dataset = deepcopy(
    (dataset_test if is_test else dataset_train_val)[:init_indices[0] + n_rollout_steps + 2]
)
df, info_all = get_df(
    filenames,
    dataset,
    arg_list=args_list,
    init_indices=init_indices,
    dataset_name=dataset_name,
    analysis_metrics=analsys_metrics,
    n_rollout_steps=n_rollout_steps,
    interval=interval,
    isplot=2,
)
# To persist the results, uncomment:
# pickle.dump({"df": df, "info_all": info_all}, open(dirname + "df_analysis_{}-{}.p".format(datetime.datetime.now().month, datetime.datetime.now().day), "wb"))