In [1]:
import argparse
from collections import OrderedDict
import datetime
get_ipython().run_line_magic('matplotlib', 'inline')
import matplotlib.pylab as plt
from numbers import Number
import numpy as np
import pandas as pd
import gc
pd.options.display.max_rows = 1500
pd.options.display.max_columns = 200
pd.options.display.width = 1000
pd.set_option('max_colwidth', 400)
import pdb
import pickle
import pprint as pp
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader
from deepsnap.batch import Batch as deepsnap_Batch

import sys, os
sys.path.append(os.path.join(os.path.dirname("__file__"), '..'))
sys.path.append(os.path.join(os.path.dirname("__file__"), '..', '..'))
from le_pde.argparser import arg_parse
from le_pde.datasets.load_dataset import load_data
from le_pde.models import load_model
from le_pde.pytorch_net.util import groupby_add_keys, filter_df, get_unique_keys_df, Attr_Dict, Printer, get_num_params, get_machine_name, pload, pdump, to_np_array, get_pdict, reshape_weight_to_matrix, ddeepcopy as deepcopy, plot_vectors, record_data, filter_filename, Early_Stopping, str2bool, get_filename_short, print_banner, plot_matrices, get_num_params, init_args, filter_kwargs, to_string, COLOR_LIST
from le_pde.utils import update_legacy_default_hyperparam, EXP_PATH
from le_pde.utils import deepsnap_to_pyg, LpLoss, to_cpu, to_tuple_shape, parse_multi_step, loss_op, get_device

# Use the first CUDA device when one is present; otherwise fall back to the CPU so
# the notebook can still load and inspect checkpoints on a machine without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Shared printer; used below as p.print(..., banner_size=...).
p = Printer()

## 0. Functions:

In [2]:
def plot_learning_curve(data_record):
    """Draw the train/val/test loss curves side by side on linear and log-y axes.

    Args:
        data_record: mapping with "train_loss", "val_loss" and "test_loss"
            sequences. All three are plotted against an index axis whose length
            is taken from "train_loss".
    """
    steps = np.arange(len(data_record["train_loss"]))
    plt.figure(figsize=(16,6))
    # Panel 1 uses a linear y-axis (plt.plot), panel 2 a logarithmic one (plt.semilogy).
    for panel, draw_name in enumerate(("plot", "semilogy"), start=1):
        plt.subplot(1, 2, panel)
        draw = getattr(plt, draw_name)
        for split in ("train", "val", "test"):
            draw(steps, data_record[f"{split}_loss"], label=split)
        plt.legend()
    plt.show()

In [3]:
def get_results_2d(all_hash, dirname, mode=-1, suffix="", isplot=True):
    """Evaluate saved 2D experiments on the test set and plot their rollouts.

    For every hash in `all_hash` the matching result file is located, a model is
    restored from it, the test dataset is loaded and the model is run on every
    test sample. The per-hash summary is printed, plotted and collected. A dict
    of per-hash summaries is pickled to "all_dict_2d{suffix}.p".

    Args:
        all_hash: iterable of hash strings; each one is used as the `include`
            filter when searching the experiment folder for a result file.
        dirname: experiment folder relative to EXP_PATH (e.g. "user_2022-5-8/").
            If no file matches there, EXP_PATH + "user-2d_2022-7-27/" is tried.
        mode: "best" restores data_record["best_model_dict"]; otherwise an int
            index into data_record["model_dict"] (default -1, the last entry).
        suffix: appended to the file name of the summary pickle.
        isplot: when True (default), the learning curve of every loaded record
            is plotted. The rollout plots are always produced.

    Returns:
        pandas.DataFrame with one row per evaluated hash: the fields of the
        experiment's args plus "loss" (mean test loss), "hash_str" and the last
        recorded "train_loss", "val_loss", "test_loss" and "epoch".

    Raises:
        NotImplementedError: if temporal_bundle_steps == 1 and the dataset has no
            rollout length defined here.
    """
    all_dict = {}
    df_dict_list = []
    # Build both candidate folders once, outside the loop: re-prefixing EXP_PATH on
    # every iteration would corrupt the path from the second hash on, and a fallback
    # chosen for one hash must not become the primary folder of the next.
    primary_dir = EXP_PATH + dirname
    fallback_dir = EXP_PATH + "user-2d_2022-7-27/"
    for hash_str in all_hash:
        # Locate the result file:
        exp_dir = primary_dir
        filename = filter_filename(exp_dir, include=hash_str)
        if len(filename) == 0:
            exp_dir = fallback_dir
            filename = filter_filename(exp_dir, include=hash_str)
            if len(filename) == 0:
                print(f"hash {hash_str} does not exist!")
                continue

        # Load record (best effort: an unreadable file skips this hash only):
        try:
            data_record = pload(exp_dir + filename[0])
        except Exception as e:
            print(f"error {e}")
            continue
        if "train_loss" not in data_record:
            print(f"hash {hash_str}: record has no 'train_loss' entry, skipping.")
            continue
        if isplot:
            plot_learning_curve(data_record)

        # Load model:
        args = init_args(update_legacy_default_hyperparam(data_record["args"]))
        args.filename = filename
        if mode == "best":
            model = load_model(data_record["best_model_dict"], device=device)
            print("Load the model with best validation loss.")
        else:
            assert isinstance(mode, int)
            print(f'Load the model at epoch {data_record["epoch"][mode]}')
            model = load_model(data_record["model_dict"][mode], device=device)
        model.eval()
        p.print(f"Hash {hash_str}, best model at epoch {data_record['best_epoch']}:", banner_size=100)

        # Load test dataset:
        args_test = deepcopy(args)
        if args.temporal_bundle_steps == 1:
            # Fixed rollout length per dataset.
            if args.dataset in ["fno", "fno-2", "fno-3"]:
                args_test.multi_step = "20"
            elif args.dataset in ["fno-1"]:
                args_test.multi_step = "40"
            elif args.dataset in ["fno-4"]:
                args_test.multi_step = "10"
            else:
                raise NotImplementedError(
                    f"No test rollout length is defined for dataset '{args.dataset}'."
                )
        # Otherwise the multi_step stored with the experiment is kept.
        args_test.batch_size = 20
        (dataset_train_val, dataset_test), (train_loader, val_loader, test_loader) = load_data(args_test)
        myloss = LpLoss(size_average=False)
        # Same expression the timing section of this notebook uses, so a multi_step
        # string that is not a bare integer is handled the same way in both places.
        n_rollout = max(parse_multi_step(args_test.multi_step).keys())

        # Compute loss:
        loss_list = []
        pred_list = []
        y_list = []
        for data in dataset_test:
            data.to(device)
            preds, info = model(
                data,
                pred_steps=np.arange(1, n_rollout + 1),
                latent_pred_steps=None,
                is_recons=False,
                use_grads=False,
                is_y_diff=False,
                is_rollout=False,
                use_pos=args.use_pos,
            )
            # Flatten each sample to a single row before computing its loss.
            pred_flat = preds["n0"].reshape(1,-1)
            y_flat = data.node_label["n0"].reshape(1,-1)
            pred_list.append(preds["n0"].detach())
            y_list.append(data.node_label["n0"].detach())
            loss_list.append(myloss(pred_flat, y_flat).item())
        if len(loss_list) == 0:
            # torch.stack and the plots below need at least one sample.
            print(f"hash {hash_str}: test dataset is empty, skipping.")
            continue
        loss_mean = np.mean(loss_list)
        pred_list = torch.stack(pred_list).squeeze(-1)
        y_list = torch.stack(y_list).squeeze(-1)
        all_dict[hash_str] = (data_record['best_epoch'], loss_mean, len(data_record["train_loss"]), args.epochs)
        print(all_dict[hash_str])
        print("Test for {} is: {:.6e} at epoch {}".format(hash_str, loss_mean, data_record['best_epoch']))

        # df_dict: args first, so that an args field can never overwrite a metric.
        df_dict = dict(args.__dict__)
        df_dict["loss"] = loss_mean
        df_dict["hash_str"] = hash_str
        df_dict["train_loss"] = data_record["train_loss"][-1]
        df_dict["val_loss"] = data_record["val_loss"][-1]
        df_dict["test_loss"] = data_record["test_loss"][-1]
        df_dict["epoch"] = data_record["epoch"][-1]
        df_dict_list.append(df_dict)

        # Plotting:
        print("loss_cumu:")
        mse_full = nn.MSELoss(reduction="none")(pred_list, y_list)
        # Average over the first two axes; the remaining axis is plotted as the rollout step.
        mse_time = to_np_array(mse_full.mean((0,1)))
        plt.figure(figsize=(12,5))
        plt.subplot(1,2,1)
        plt.plot(mse_time)
        plt.xlabel("rollout step")
        plt.ylabel("MSE")
        plt.title("MSE vs. rollout step (linear scale)")
        plt.subplot(1,2,2)
        plt.semilogy(mse_time)
        plt.xlabel("rollout step")
        plt.ylabel("MSE")
        plt.title("MSE vs. rollout step (log scale)")
        plt.show()

        # Per-step images of the LAST test sample (preds/data left over from the loop),
        # shown in rows of `num` frames. Each frame is reshaped to 64 x 64.
        pred_reshape = preds["n0"].permute(1,0,2).squeeze(-1)
        y_reshape = data.node_label["n0"].permute(1,0,2).squeeze(-1)
        num = 10
        print("individual:")
        for i in range(len(pred_reshape) // num):
            frames = slice(num*i, num*(i+1))
            print("pred:")
            plot_matrices(pred_reshape[frames].reshape(num, 64, 64), images_per_row=num, scale_limit="auto")
            print("y:")
            plot_matrices(y_reshape[frames].reshape(num, 64, 64), images_per_row=num, scale_limit="auto")
            print("diff:")
            plot_matrices((pred_reshape - y_reshape)[frames].reshape(num, 64, 64), images_per_row=num, scale_limit="auto")
    pdump(all_dict, f"all_dict_2d{suffix}.p")
    df = pd.DataFrame(df_dict_list)
    # The returned frame can be filtered by the caller, e.g.:
    # dff = filter_df(df, {"decoder_type": "neural-basis-mulr", "decoder_act_name": "gelu"})
    # dff[["hash_str", "epoch", "loss"]]
    return df

## 1. Analysis:

In [4]:
# Root folder that holds all experiment results. This assignment replaces the
# EXP_PATH imported from le_pde.utils at the top of the notebook. Set the
# LE_PDE_EXP_PATH environment variable to point the notebook at another location
# without editing this cell; the default is unchanged.
# os.path.join(..., "") guarantees a trailing separator, which the string
# concatenations below (EXP_PATH + "<folder>/") rely on.
EXP_PATH = os.path.join(os.environ.get("LE_PDE_EXP_PATH", "/dfs/project/plasma/results/"), "")

In [None]:
# all_hash is a list of hashes, each of which corresponds to one experiment.
# For example, if one experiment is saved under ./results/user-1d_2022-5-14/fno-4_train_-1_algo_contrast_ebm_False_ebmt_cd_enc_cnn-s_evo_cnn_act_elu_hid_256_lo_mse_recef_1.0_conef_1.0_nconv_4_nlat_1_clat_3_lf_True_reg_None_id_0_Hash_Lvdn+dA9_turing2.p
# then "Lvdn+dA9_turing2" (located at the end of the filename) is the {hash}_{machine-name} of this file,
# and the folder name is the "{--exp_id}_{--date_time}" of the training command.
# NOTE(review): the example path above names the folder "user-1d_2022-5-14", while the call below
# passes dirname="user_2022-5-8/" - confirm which folder actually holds this hash.
# all_hash can contain multiple hashes; they are analyzed sequentially.
all_hash = ["Lvdn+dA9"]
get_results_2d(all_hash, dirname="user_2022-5-8/", mode=-1, suffix="")

## 2. Timing:

In [None]:
isplot = True
all_hash = [
    "nG2NBRJH_turing2", # latent_size: 512
    "Lvdn+dA9_turing2", # latent_size: 256
    "ECshE1x1_turing2", # latent_size: 128
    "2o9axpaz_turing2", # latent_size: 64
    "QLkuuTf9_turing2", # latent_size: 32
    "hXzEAQeu_turing2", # latent_size: 16
    "YkILg8wB_turing2", # latent_size: 8
    "ck3kAARM_turing2", # latent_size: 4
]
hash_str = all_hash[0]
dirname = EXP_PATH + "user_2022-5-8/"


def _load_timing_model(hash_str, dirname):
    """Load the record, args and last saved model of `hash_str` from `dirname`.

    Returns:
        (filename, data_record, args, model), with the model in eval mode on `device`.

    Raises:
        FileNotFoundError: if no file name in `dirname` matches `hash_str`.
    """
    filename = filter_filename(dirname, include=hash_str)
    if len(filename) == 0:
        print(f"hash {hash_str} does not exist!")
        raise FileNotFoundError(f"No result file for hash {hash_str} in {dirname}")
    try:
        data_record = pload(dirname + filename[0])
    except Exception as e:
        print(f"error {e}")
        raise
    args = init_args(update_legacy_default_hyperparam(data_record["args"]))
    args.filename = filename
    # The last saved checkpoint is timed (not data_record["best_model_dict"]).
    model = load_model(data_record["model_dict"][-1], device=device)
    model.eval()
    return filename, data_record, args, model


def _time_forward(net, batch, n_repeat=100, **forward_kwargs):
    """Return the wall-clock durations (seconds) of `n_repeat` calls net(batch, **forward_kwargs)."""
    on_cuda = device.type == "cuda"
    durations = []
    for _ in range(n_repeat):
        # CUDA kernels are launched asynchronously: synchronize before reading the
        # clock on both sides, otherwise only the launch overhead is measured.
        if on_cuda:
            torch.cuda.synchronize(device)
        t_start = time.perf_counter()
        preds, info = net(batch, **forward_kwargs)
        if on_cuda:
            torch.cuda.synchronize(device)
        durations.append(time.perf_counter() - t_start)
        # Release the outputs outside the timed region.
        del preds, info
        gc.collect()
    return durations


# The first hash provides the args used to build the test data; its record, args and
# model stay available under the names `data_record`, `args` and `model`.
filename, data_record, args, model = _load_timing_model(hash_str, dirname)
if isplot:
    plot_learning_curve(data_record)
p.print(filename, banner_size=100)

# Load test dataset:
args_test = deepcopy(args)
if args.temporal_bundle_steps == 1:
    # Fixed rollout length per dataset.
    if args.dataset in ["fno", "fno-2", "fno-3"]:
        args_test.multi_step = "20"
    elif args.dataset in ["fno-1"]:
        args_test.multi_step = "40"
    elif args.dataset in ["fno-4"]:
        args_test.multi_step = "10"
    else:
        raise NotImplementedError(
            f"No test rollout length is defined for dataset '{args.dataset}'."
        )
# Otherwise the multi_step stored with the experiment is kept.
args_test.batch_size = 20


(dataset_train_val, dataset_test), (train_loader, val_loader, test_loader) = load_data(args_test)
test_loader = DataLoader(dataset_test, num_workers=0, collate_fn=deepsnap_Batch.collate(),
                         batch_size=20, shuffle=False, drop_last=False)
# Time every model on the same (first) test batch.
data = next(iter(test_loader))
data.to(device)

rollout_steps = np.arange(1,max(parse_multi_step(args_test.multi_step).keys())+1)
shared_kwargs = dict(is_recons=False, use_grads=False, is_y_diff=False, is_rollout=False)

all_list_dict = {}
for hash_str in all_hash:
    # Each hash must be timed with ITS OWN model; reusing the first model for every
    # hash would report the same network under eight different labels.
    if hash_str == all_hash[0]:
        timed_model, timed_args = model, args
    else:
        try:
            _, _, timed_args, timed_model = _load_timing_model(hash_str, dirname)
        except FileNotFoundError:
            # The loader has already printed which hash is missing.
            continue

    # Full forward pass: prediction at every rollout step.
    t_list = _time_forward(
        timed_model, data,
        pred_steps=rollout_steps,
        latent_pred_steps=None,
        use_pos=timed_args.use_pos,
        **shared_kwargs,
    )
    full_time = np.mean(t_list)

    # Latent-only pass: no pred_steps, only latent_pred_steps are requested.
    t_list_evo = _time_forward(
        timed_model, data,
        pred_steps=[],
        latent_pred_steps=rollout_steps,
        use_pos=timed_args.use_pos,
        **shared_kwargs,
    )
    evo_time = np.mean(t_list_evo)
    print("hash {}, full time: {:.6f} +- {:.6f}  evo time: {:.6f} +- {:.6f}".format(
        hash_str, full_time, np.std(t_list),
        evo_time, np.std(t_list_evo)))
    all_list_dict[hash_str] = {
        "evo": t_list_evo,
        "full": t_list,
    }
    # Drop this iteration's reference so the per-hash model can be freed; the
    # first-hash model remains reachable as `model`.
    del timed_model

## 3. Plotting:

In [None]:
k = 4
cmap = "jet"
import matplotlib.colors as clr
import numpy as np


# Run the model on test sample k. Relies on `model`, `args`, `args_test` and
# `dataset_test` defined by the timing section above.
data = dataset_test[k]
data.to(device)
preds, info = model(
    data,
    # Same rollout-length expression as the timing section.
    pred_steps=np.arange(1,max(parse_multi_step(args_test.multi_step).keys())+1),
    latent_pred_steps=None,
    is_recons=False,
    use_grads=False,
    is_y_diff=False,
    is_rollout=False,
    use_pos=args.use_pos,
)
# Move the step axis first and view every step as a 64 x 64 image.
pred_reshape = preds["n0"].permute(1,0,2).squeeze(-1)
y_reshape = data.node_label["n0"].permute(1,0,2).squeeze(-1)
pred_reshape = pred_reshape.reshape(pred_reshape.shape[0], 64, 64)
y_reshape = y_reshape.reshape(-1, 64, 64)

# Only show steps that exist in this rollout, so a dataset with fewer than 20 steps
# does not fail with an IndexError.
n_steps = min(pred_reshape.shape[0], y_reshape.shape[0])
shown_steps = [idx for idx in [5,10,15,20] if idx <= n_steps]

fig = plt.figure(figsize=(10,4))

vmax = pred_reshape.max().item()
vmin = pred_reshape.min().item()

# Top-left panel: first frame of the label sequence.
ax = fig.add_subplot(2, 5, 1)
ax.matshow(to_np_array(y_reshape[0]), cmap=cmap)
ax.set_title("Initial vorticity")
ax.set_xticks([])
ax.set_yticks([])

# Top row (slots 2-5): label frames at the selected steps.
for i, idx in enumerate(shown_steps):
    ax = fig.add_subplot(2, 5, i + 2)
    ax.matshow(to_np_array(y_reshape[idx-1]), cmap=cmap)
    ax.set_title(f"t={idx}")
    ax.set_xticks([])
    ax.set_yticks([])
# Bottom row (slots 7-10): predicted frames, directly below the matching label frame.
for i, idx in enumerate(shown_steps):
    ax = fig.add_subplot(2, 5, i + 3+4)
    ax.matshow(to_np_array(pred_reshape[idx-1]), cmap=cmap)
    ax.set_xticks([])
    ax.set_yticks([])
fig.savefig("2d_nv.pdf", bbox_inches='tight')
plt.show()