In [None]:
%load_ext autoreload
%autoreload 2

import argparse
from collections import OrderedDict
import datetime
import gc
get_ipython().run_line_magic('matplotlib', 'inline')
import matplotlib
import matplotlib.pylab as plt
from numbers import Number
import numpy as np
import pandas as pd
pd.options.display.max_rows = 1500
pd.options.display.max_columns = 200
pd.options.display.width = 1000
pd.set_option('max_colwidth', 400)
import pdb
import pickle
import pprint as pp
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader
from deepsnap.batch import Batch as deepsnap_Batch
import xarray as xr

import sys, os
sys.path.append(os.path.join(os.path.dirname("__file__"), '..'))
sys.path.append(os.path.join(os.path.dirname("__file__"), '..', '..'))
from le_pde.argparser import arg_parse
from le_pde.datasets.load_dataset import load_data
from le_pde.models import load_model
from le_pde.pytorch_net.util import groupby_add_keys, filter_df, get_unique_keys_df, Attr_Dict, Printer, get_num_params, get_machine_name, pload, pdump, to_np_array, get_pdict, reshape_weight_to_matrix, ddeepcopy as deepcopy, plot_vectors, record_data, filter_filename, Early_Stopping, str2bool, get_filename_short, print_banner, plot_matrices, get_num_params, init_args, filter_kwargs, to_string, COLOR_LIST
from le_pde.utils import update_legacy_default_hyperparam, EXP_PATH
from le_pde.utils import deepsnap_to_pyg, LpLoss, to_cpu, to_tuple_shape, parse_multi_step, loss_op, get_device

# Prefer the first GPU, but fall back to CPU so the notebook still runs on machines without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Shared printer used below for banner-style progress messages (p.print(..., banner_size=...)).
p = Printer()

## 0. Functions:

In [None]:
# Plotting:
def plot_learning_curve(data_record):
    """Draw the train/val/test loss curves twice: on linear and on log y-axes.

    Args:
        data_record: mapping holding "epoch", "train_loss", "val_loss" and
            "test_loss". If it also holds "test_epoch", that entry is used as the
            x-axis of the val/test curves; otherwise "epoch" is used for all three.
    """
    plt.figure(figsize=(16,6))
    # (subplot position, drawing function, panel title) for the two panels.
    panels = (
        (1, plt.plot, "Learning curve, linear scale"),
        (2, plt.semilogy, "Learning curve, log scale"),
    )
    for position, draw, title in panels:
        plt.subplot(1, 2, position)
        eval_epochs = data_record["test_epoch"] if "test_epoch" in data_record else data_record["epoch"]
        draw(data_record["epoch"], data_record["train_loss"], label="train")
        draw(eval_epochs, data_record["val_loss"], label="val")
        draw(eval_epochs, data_record["test_loss"], label="test")
        plt.title(title)
        plt.legend()
    plt.show()


def plot_colorbar(matrix, vmax=None, vmin=None, cmap="seismic", label=None):
    """Show `matrix` as an image on the current axes, with a title and a colorbar.

    Args:
        matrix: 2D array-like exposing .max(), .min() and .shape.
        vmax: upper color limit; defaults to matrix.max() when None.
        vmin: lower color limit; defaults to matrix.min() when None.
        cmap: colormap name forwarded to plt.imshow.
        label: text used as the axes title.
    """
    # Resolve each limit independently so that a caller-supplied vmin is not
    # overwritten merely because vmax was left unspecified.
    if vmax is None:
        vmax = matrix.max()
    if vmin is None:
        vmin = matrix.min()
    im = plt.imshow(matrix, vmax=vmax, vmin=vmin, cmap=cmap)
    plt.title(label)
    # Scale the colorbar with the image aspect ratio so its height tracks the image.
    im_ratio = matrix.shape[0] / matrix.shape[1]
    plt.colorbar(im, fraction=0.046*im_ratio, pad=0.04)


def visualize(pred, gt, animate=False):
    """Plot ground truth, prediction, their difference and an MSE curve in one row.

    Args:
        pred: predicted field (torch tensor or numpy array). It is indexed as
            [:, :, 0], so it needs at least three axes.
        gt: ground-truth field with the same layout as `pred`.
        animate: when True nothing is drawn; only the static figure is implemented.
    """
    # Convert each input on its own so a mixed tensor/array pair is handled too.
    if torch.is_tensor(gt):
        gt = to_np_array(gt)
    if torch.is_tensor(pred):
        pred = to_np_array(pred)
    # Squared error averaged over the first and the last axis. The remaining axis is
    # presumably time (the panel is titled "mse over t") -- TODO confirm.
    mse_over_t = ((gt-pred)**2).mean(axis=0).mean(axis=-1)

    if animate:
        return

    diff = pred - gt
    # Symmetric color limits so that zero error sits at the centre of the colormap.
    diff_bound = np.abs(diff).max()

    plt.figure(figsize=[15,5])
    plt.subplot(1,4,1)
    plot_colorbar(gt[:,:,0].T, label="gt")
    plt.subplot(1,4,2)
    plot_colorbar(pred[:,:,0].T, label="pred")
    plt.subplot(1,4,3)
    plot_colorbar(diff[:,:,0].T, vmax=diff_bound, vmin=-1*diff_bound, label="diff")
    plt.subplot(1,4,4)
    plt.plot(mse_over_t)
    plt.title("mse over t")
    plt.yscale('log')
    plt.tight_layout()
    plt.show()

def _plot_snapshot_panel(field, idx_list, color_list, x_axis, cmap, fontsize, title, show_legend):
    """Draw the snapshots field[..., idx, :] for idx in idx_list on the current axes."""
    for i, idx in enumerate(idx_list):
        curve = to_np_array(field[..., idx, :].squeeze())
        rgb = cmap(color_list[i])[:3]
        plt.plot(x_axis, curve, color=rgb, label=f"t={np.round(i*0.3, 1)}s")
    plt.ylabel("u(t,x)", fontsize=fontsize)
    plt.xlabel("x", fontsize=fontsize)
    plt.tick_params(labelsize=fontsize)
    if show_legend:
        plt.legend(fontsize=10, bbox_to_anchor=[1,1])
    plt.xticks([0,8,16], [0,8,16])
    plt.ylim([-2.5,2.5])
    plt.title(title)


def visualize_paper(pred, gt, is_save=False):
    """Plot prediction and ground-truth snapshots side by side, colored by time.

    Every 15th index along the second-to-last axis is drawn as one curve over an
    x-axis spanning [0, 16]. At most the first 200 indices are considered.

    Args:
        pred: predicted field; pred.shape[0] is used as the number of x samples and
            the second-to-last axis is the one snapshots are taken along.
        gt: ground-truth field, indexed the same way as `pred`.
        is_save: when True, write "1D_E2-{nx}.pdf" right after the prediction panel
            is drawn (so that file holds only the left panel) and "1D_gt-{nx}.pdf"
            after the ground-truth panel is drawn (whole figure).
    """
    nx = pred.shape[0]

    fontsize = 14
    # Clamp to the available snapshots so shorter rollouts do not raise IndexError;
    # with 200 or more snapshots this is identical to np.arange(0, 200, 15).
    n_snapshots = min(200, pred.shape[-2], gt.shape[-2])
    idx_list = np.arange(0, n_snapshots, 15)
    color_list = np.linspace(0.01, 0.9, len(idx_list))
    x_axis = np.linspace(0,16,nx)
    # matplotlib.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9;
    # pyplot's get_cmap is available across versions.
    cmap = plt.get_cmap('jet')

    plt.figure(figsize=(16,5))
    plt.subplot(1,2,1)
    _plot_snapshot_panel(pred, idx_list, color_list, x_axis, cmap, fontsize,
                         title="Prediction", show_legend=False)
    if is_save:
        plt.savefig(f"1D_E2-{nx}.pdf", bbox_inches='tight')

    plt.subplot(1,2,2)
    _plot_snapshot_panel(gt, idx_list, color_list, x_axis, cmap, fontsize,
                         title="Ground-truth", show_legend=True)
    if is_save:
        plt.savefig(f"1D_gt-{nx}.pdf", bbox_inches='tight')
    plt.show()

In [None]:
# Analysis:
def get_results_1d(
    all_hash,
    mode="best",
    exclude_idx=(None,),
    n_rollout_steps=-1,
    dirname=None,
    suffix="",
):
    """
    Perform analysis on the 1D Burgers' benchmark.

    For every hash: load the saved experiment record, plot its learning curve,
    rebuild the model, load the test dataset, roll the model out on 128 test
    samples, plot the error curves and a few example predictions, and collect the
    loss together with the experiment's args. The collected rows are returned as
    a DataFrame and also pickled to "df_1d{suffix}.p".

    Args:
        all_hash: a list of hashes which indicates the experiments to load for analysis
        mode: choose from "best" (load the best model with lowest validation loss) or an integer,
            e.g. -1 (last saved model), -2 (second last saved model)
        exclude_idx: iterable of tags. The evaluation is repeated once per element; the
            element is only used to name the "loss_cumu_{tag}" column and in the printed summary.
        n_rollout_steps: number of prediction steps to roll out in the evaluation; -1 uses
            the largest step of the test-time multi_step setting.
        dirname: sub-directory of EXP_PATH that holds the experiment files, e.g. "user4".
            Must be provided whenever all_hash is non-empty.
        suffix: suffix for saving the analysis result.

    Returns:
        pandas.DataFrame with one row per hash that was found and loaded.

    Raises:
        ValueError: if dirname is None while all_hash is non-empty.
    """
    all_hash = list(all_hash)
    if dirname is None and all_hash:
        # Previously this surfaced as an opaque TypeError from `EXP_PATH + None`.
        raise ValueError("dirname must be provided (a sub-directory of EXP_PATH, e.g. dirname='user4').")

    isplot = True
    n_test_traj = 128
    df_dict_list = []
    for hash_str in all_hash:
        df_dict = {}
        df_dict["hash"] = hash_str
        # Locate the experiment file; exactly one match is required.
        filename = filter_filename(EXP_PATH + dirname, include=hash_str)
        if len(filename) != 1:
            print(f"hash {hash_str} does not exist in {dirname}! Please pass in the correct dirname.")
            continue
        # Keep the caller's `dirname` untouched so the message above stays accurate
        # on later iterations; the resolved path lives in its own variable.
        dirpath = EXP_PATH + dirname
        if not dirpath.endswith("/"):
            dirpath += "/"

        try:
            data_record = pload(dirpath + filename[0])
        except Exception as e:
            # Best effort: report the unreadable record and move on to the next hash.
            print(f"error {e} in hash_str {hash_str}")
            continue
        p.print(f"Hash {hash_str}, best model at epoch {data_record['best_epoch']}:", banner_size=160)
        if isplot:
            plot_learning_curve(data_record)
        args = init_args(update_legacy_default_hyperparam(data_record["args"]))
        args.filename = filename
        if mode == "best":
            model = load_model(data_record["best_model_dict"], device=device)
            print("Load the model with best validation loss.")
        else:
            assert isinstance(mode, int)
            print(f'Load the model at epoch {data_record["epoch"][mode]}')
            model = load_model(data_record["model_dict"][mode], device=device)
        model.eval()
        # Extra keyword arguments forwarded to every model call.
        kwargs = {}
        if data_record["best_model_dict"]["type"].startswith("GNNPolicy"):
            kwargs["is_deepsnap"] = True

        # Load test dataset:
        args_test = deepcopy(args)
        # NOTE(review): (250 - 50) is presumably the number of time steps to predict -- confirm.
        multi_step = (250 - 50) // args_test.temporal_bundle_steps
        args_test.multi_step = f"1^{multi_step}"
        args_test.is_test_only = True
        (dataset_train_val, dataset_test), (train_loader, val_loader, test_loader) = load_data(args_test)
        nx = int(args.dataset.split("-")[2])
        time_stamps_effective = len(dataset_test) // n_test_traj
        # Loop-invariant: largest step of the test-time multi_step setting, and the
        # number of steps actually rolled out in the evaluation loop.
        max_pred_step = max(parse_multi_step(args_test.multi_step).keys())
        n_eval_steps = n_rollout_steps if n_rollout_steps != -1 else max_pred_step
        for exclude_idx_ele in exclude_idx:
            loss_list = []
            pred_list = []
            y_list = []
            for i in range(n_test_traj):
                idx = i * time_stamps_effective + args_test.temporal_bundle_steps
                data = deepcopy(dataset_test[idx])
                data = data.to(device)
                preds, info = model(
                    data,
                    pred_steps=np.arange(1, n_eval_steps+1),
                    latent_pred_steps=None,
                    is_recons=False,
                    use_grads=False,
                    use_pos=args.use_pos,
                    is_y_diff=False,
                    is_rollout=False,
                    **kwargs
                )
                y = data.node_label["n0"]
                if n_rollout_steps != -1:
                    # NOTE(review): 25 is hard-coded; presumably it should equal
                    # args_test.temporal_bundle_steps -- confirm before changing.
                    y = y[:,:25*n_rollout_steps]
                pred = preds["n0"].reshape(y.shape)
                pred_list.append(pred.detach())
                y_list.append(y.detach())
                # Summed squared error divided by the grid size parsed from the dataset name.
                loss_ele = nn.MSELoss(reduction="sum")(pred, y) / nx
                loss_list.append(loss_ele.item())

            loss_mean = np.mean(loss_list)
            pred_list = torch.stack(pred_list).squeeze(-1)
            y_list = torch.stack(y_list).squeeze(-1)
            df_dict[f"loss_cumu_{exclude_idx_ele}"] = loss_mean
            print("\nTest for {} for exclude_idx={} is: {:.9f} at epoch {}, for {}/{} epochs".format(hash_str, exclude_idx_ele, loss_mean, data_record['best_epoch'], len(data_record["train_loss"]), args.epochs))

            # Per-step error: average over the first two axes (the stacked samples and
            # the next axis), leaving one value per rollout step.
            mse_full = nn.MSELoss(reduction="none")(pred_list, y_list)
            mse_time = to_np_array(mse_full.mean((0,1)))
            p.print("Learning curve:", is_datetime=False, banner_size=100)
            plt.figure(figsize=(12,5))
            plt.subplot(1,2,1)
            plt.plot(mse_time)
            plt.xlabel("rollout step")
            plt.ylabel("MSE")
            plt.title("MSE vs. rollout step (linear scale)")
            plt.subplot(1,2,2)
            plt.semilogy(mse_time)
            plt.xlabel("rollout step")
            plt.ylabel("MSE")
            plt.title("MSE vs. rollout step (log scale)")
            plt.show()
            plt.figure(figsize=(6,5))
            plt.plot(mse_time.cumsum())
            plt.title("cumulative MSE vs. rollout step")
            plt.xlabel("rollout step")
            plt.ylabel("cumulative MSE")
            plt.show()

            # Visualization: two fixed examples, always rolled out for the full horizon.
            for idx in range(6,8):
                p.print(f"Example {idx*128}:", banner_size=100, is_datetime=False)
                data = deepcopy(dataset_test[idx*128]).to(device)
                preds, info = model(
                    data,
                    pred_steps=np.arange(1, max_pred_step+1),
                    latent_pred_steps=None,
                    is_recons=False,
                    use_grads=False,
                    use_pos=args.use_pos,
                    is_y_diff=False,
                    is_rollout=False,
                    **kwargs
                )
                y = data.node_label["n0"]
                pred = preds["n0"].reshape(y.shape)
                visualize(pred, y)
                visualize_paper(pred, y)

            # `y` here is the label of the last visualized example.
            p.print(f"Individual prediction at rollout step {y.shape[1]}:", banner_size=100, is_datetime=False)
            time_step = -1
            for idx in range(0, 20, 5):
                plt.figure(figsize=(6,4))
                plt.plot(to_np_array(pred_list[idx,:,time_step]), label="pred")
                plt.plot(to_np_array(y_list[idx,:,time_step]), "--", label="y")
                plt.legend()
                plt.show()
        df_dict["best_epoch"] = data_record['best_epoch']
        df_dict["epoch"] = len(data_record["train_loss"])
        df_dict.update(args.__dict__)
        df_dict_list.append(df_dict)
    df = pd.DataFrame(df_dict_list)
    pdump(df, f"df_1d{suffix}.p")
    return df

## 1. Analysis:

In [None]:
# all_hash is a list of hashes, each of which identifies one experiment.
# For example, if an experiment is saved under ./results/user4/mppde1d-E2-50_train_-1_algo_contrast_ebm_False_ebmt_cd_enc_cnn-s_evo_cnn_act_elu_hid_128_lo_rmse_recef_1.0_conef_1.0_nconv_4_nlat_1_clat_3_lf_True_reg_None_id_0_Hash_Un6ae7ja_turing2.p
# then "Un6ae7ja_turing2" (located at the end of the filename) is the {hash}_{machine-name} of that file,
# and "user4" is the "{--exp_id}_{--date_time}" of the training command.
# all_hash may contain several hashes; they are analyzed sequentially.
all_hash = [
    "Un6ae7ja_turing2",
]
# `dirname` is joined onto EXP_PATH inside get_results_1d to locate the experiment files.
get_results_1d(all_hash, dirname="user4")

## 2. Plotting:

In [None]:
# Select the experiment to plot. Only one assignment is active; the alternatives
# are kept for reference.
# hash_str = "Un6ae7ja_turing2" # E2-50
# hash_str = "nIa6UCdr_turing2" # E2-40
hash_str = "tdu+jfKw"  # E2-100
dirname = EXP_PATH + "user4/"
all_dict = {}
# Defined here because `isplot` only exists as a local inside the functions above;
# without it this cell raises NameError on a fresh kernel.
isplot = True

# Load model:
filename = filter_filename(dirname, include=hash_str)
if len(filename) == 0:
    # A bare `raise` outside an except block only yields
    # "RuntimeError: No active exception to reraise"; raise a meaningful error instead.
    raise FileNotFoundError(f"hash {hash_str} does not exist!")

try:
    data_record = pload(dirname + filename[0])
except Exception as e:
    # Report which hash failed, then propagate the original error.
    print(f"error {e} in hash_str {hash_str}")
    raise
p.print(f"Hash {hash_str}, best model at epoch {data_record['best_epoch']}:", banner_size=100)
if isplot:
    plot_learning_curve(data_record)
args = init_args(update_legacy_default_hyperparam(data_record["args"]))
args.filename = filename
args.is_test_only = True
model = load_model(data_record["best_model_dict"], device=device)
# model = load_model(data_record["model_dict"][-1], device=device)
model.eval()


# Load test dataset:
args_test = deepcopy(args)
# NOTE(review): (250 - 50) is presumably the number of time steps to predict -- confirm.
multi_step = (250 - 50) // args_test.temporal_bundle_steps
args_test.multi_step = f"1^{multi_step}"
(dataset_train_val, dataset_test), (train_loader, val_loader, test_loader) = load_data(args_test)


In [None]:
idx = 6
# Number of points on the x-axis; it must match the spatial size of the loaded
# experiment (the active hash in the previous cell is annotated "E2-100").
nx = 100
data = dataset_test[idx*128]
# Use the returned object, consistent with how get_results_1d moves data to the device.
data = data.to(device)
preds, info = model(
    data,
    pred_steps=np.arange(1,max(parse_multi_step(args_test.multi_step).keys())+1),
    latent_pred_steps=None,
    is_recons=False,
    use_grads=False,
    is_y_diff=False,
    is_rollout=False,
)

fontsize = 14
# Every 15th index along the second-to-last axis is drawn as one curve.
idx_list = np.arange(0, 200, 15)
color_list = np.linspace(0.01, 0.9, len(idx_list))
x_axis = np.linspace(0,16,nx)
# matplotlib.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9;
# pyplot's get_cmap is available across versions.
cmap = plt.get_cmap('jet')

# Prediction snapshots. A dedicated loop variable keeps `idx` (the example index) intact.
for i, t_idx in enumerate(idx_list):
    pred = to_np_array(preds["n0"][...,t_idx,:].squeeze())
    rgb = cmap(color_list[i])[:3]
    plt.plot(x_axis, pred, color=rgb, label=f"t={np.round(i*0.3, 1)}s")
plt.ylabel("u(t,x)", fontsize=fontsize)
plt.xlabel("x", fontsize=fontsize)
plt.tick_params(labelsize=fontsize)
plt.legend(fontsize=10, bbox_to_anchor=[1,1])
plt.xticks([0,8,16], [0,8,16])
plt.ylim([-2.2,2.])
plt.savefig(f"1D_E2-{nx}.pdf", bbox_inches='tight')
plt.show()

# Ground-truth snapshots at the same indices.
print("gt:")
for i, t_idx in enumerate(idx_list):
    y = to_np_array(data.node_label["n0"][...,t_idx,:])
    rgb = cmap(color_list[i])[:3]
    plt.plot(x_axis, y, color=rgb, label=f"t={np.round(i*0.3, 1)}s")
plt.ylabel("u(t,x)", fontsize=fontsize)
plt.xlabel("x", fontsize=fontsize)
plt.tick_params(labelsize=fontsize)
plt.legend(fontsize=10, bbox_to_anchor=[1,1])
plt.xticks([0,8,16], [0,8,16])
plt.ylim([-2.2,2.])
plt.savefig(f"1D_gt-{nx}.pdf", bbox_inches='tight')
plt.show()

## 3. Timing:

In [None]:
def _sync_device():
    """Wait for queued CUDA work on `device` to finish; does nothing on CPU."""
    if device.type == "cuda":
        torch.cuda.synchronize(device)


def _time_forward(model, data, n_repeats=100, **model_kwargs):
    """Return the wall-clock durations (seconds) of n_repeats calls model(data, **model_kwargs)."""
    durations = []
    for _ in range(n_repeats):
        # CUDA kernels are launched asynchronously, so synchronize on both sides of
        # the call; otherwise the timer mostly measures the launch overhead.
        _sync_device()
        t_start = time.perf_counter()
        preds, info = model(data, **model_kwargs)
        _sync_device()
        durations.append(time.perf_counter() - t_start)
        # Release the outputs outside the timed region.
        del preds
        gc.collect()
    return durations


def get_timing_1d(all_hash, suffix=""):
    """Time the forward pass of each experiment in `all_hash` on one fixed test batch.

    The first hash decides which experiment directory is used (EXP_PATH + "user4/",
    falling back to EXP_PATH + "user7/") and supplies the args used to build the
    test batch. Every hash is then loaded from that same directory and its model is
    called 100 times on the batch. For models whose class is named "Contrastive",
    a second run passes the steps as latent_pred_steps with pred_steps=[] and is
    reported as "evo". Timings and parameter counts are printed and pickled to
    "all_dict_1d_timing{suffix}.p".

    Args:
        all_hash: non-empty list of experiment hashes.
        suffix: suffix appended to the name of the output pickle.

    Raises:
        FileNotFoundError: if the first hash is found in neither directory.
        AssertionError: if a hash does not match exactly one file in the chosen directory.
    """
    isplot = True

    dirname = EXP_PATH + "user4/"
    hash_str = all_hash[0]

    # Locate the first experiment; it determines the directory for all hashes.
    filename = filter_filename(dirname, include=hash_str)
    if len(filename) == 0:
        dirname = EXP_PATH + "user7/"
        filename = filter_filename(dirname, hash_str)
        if len(filename) == 0:
            # A bare `raise` here has no active exception to re-raise.
            raise FileNotFoundError(f"hash {hash_str} does not exist!")

    try:
        data_record = pload(dirname + filename[0])
    except Exception as e:
        # Report which hash failed, then propagate the original error.
        print(f"error {e} in hash_str {hash_str}")
        raise
    p.print(f"Hash {hash_str}, best model at epoch {data_record['best_epoch']}:", banner_size=100)
    if isplot:
        plot_learning_curve(data_record)
    args = init_args(update_legacy_default_hyperparam(data_record["args"]))
    args.filename = filename
    args.is_test_only = True
    # The model of the first record is not loaded here: each hash (including the
    # first) gets its own model loaded in the timing loop below.

    # Load test dataset with a fixed 8-step setting:
    args_test = deepcopy(args)
    args_test.multi_step = "1^8"
    (dataset_train_val, dataset_test), (train_loader, val_loader, test_loader) = load_data(args_test)
    # NOTE(review): 26 / 25 / 128 are hard-coded; presumably one sample per test
    # trajectory at a fixed offset -- confirm against the dataset layout.
    idx_list = [i * 26 + 25 for i in range(128)]
    dataset_test_selected = dataset_test[idx_list]
    test_loader = DataLoader(dataset_test_selected, num_workers=0, collate_fn=deepsnap_Batch.collate(),
                             batch_size=len(dataset_test_selected), shuffle=False, drop_last=False)
    # The loader holds a single batch containing every selected sample.
    data = next(iter(test_loader))
    # assumes .to() moves the batch in place, as the original code relied on -- TODO confirm
    data.to(device)

    # Built once so that constructing the step list is not part of the timed region.
    max_pred_step = max(parse_multi_step(args_test.multi_step).keys())

    all_list_dict = {}
    for hash_str in all_hash:
        filename = filter_filename(dirname, hash_str)
        assert len(filename) == 1
        data_record = pload(dirname + filename[0])
        model = load_model(data_record["best_model_dict"], device=device)
        model.eval()

        t_list = _time_forward(
            model,
            data,
            pred_steps=np.arange(1, max_pred_step+1),
            latent_pred_steps=None,
            is_recons=False,
            use_grads=False,
            is_y_diff=False,
            is_rollout=False,
        )
        full_time = np.mean(t_list)
        n_params = get_num_params(model)

        if model.__class__.__name__ == "Contrastive":
            # Same steps, but requested as latent_pred_steps with no pred_steps.
            t_list_evo = _time_forward(
                model,
                data,
                pred_steps=[],
                latent_pred_steps=np.arange(1, max_pred_step+1),
                is_recons=False,
                use_grads=False,
                is_y_diff=False,
                is_rollout=False,
            )
            evo_time = np.mean(t_list_evo)
            n_params_evo = get_num_params(model.evolution_op)
            print("hash {}, full time: {:.6f} +- {:.6f}  evo time: {:.6f} +- {:.6f}. #params: {}  #params_evo: {}.".format(
                hash_str, full_time, np.std(t_list),
                evo_time, np.std(t_list_evo),
                n_params, n_params_evo,
            ))
            all_list_dict[hash_str] = {
                "evo": t_list_evo,
                "full": t_list,
                "n_params": n_params,
                "n_params_evo": n_params_evo,
            }
        else:
            print("hash {}, full time: {:.6f} +- {:.6f}. #params: {}".format(
                hash_str, full_time, np.std(t_list), n_params))
            all_list_dict[hash_str] = {
                "full": t_list,
                "n_params": n_params,
            }
    pdump(all_list_dict, f"all_dict_1d_timing{suffix}.p")

In [None]:
# Ablation latent size @ turing3:
# Times the forward pass of every listed experiment; get_timing_1d pickles the
# result to "all_dict_1d_timing_ablation.p".
# The trailing number on each line is presumably the latent size of that run -- confirm.
all_hash = [
    "ldDNKnog_turing2", # 512
    "JLf4tEYC_turing2", # 256
    "Un6ae7ja_turing2", # 128
    "5TFpW2r7", # 64
    "xuaWUuBJ", # 32
    "WO0JMG5U", # 16
    "9dhW7XBI_turing2", # 8
    "aKXmat5Z_turing2", # 4
]
get_timing_1d(all_hash, suffix="_ablation")

In [None]:
# FNO @ turing3:
# Times the forward pass of every listed experiment; get_timing_1d pickles the
# result to "all_dict_1d_timing_fno.p".
all_hash = [
    "N+XmBTW+_turing3",
    "OlaHQmVh_turing3",
    "95ODTwFm_turing3",
    "k6pUtyDT_turing3",
    "tMtuBQbF_turing3",
    "6AEMxxdG_turing3",
    "94vcmAJH_turing3",
    "ajb+NVzF_turing3",
    "0di76Adz_turing3",
    "F0ge5kGj_turing3",
    "Jr2biWNP_turing3",
    "kxU74861_turing3",
]
get_timing_1d(all_hash, suffix="_fno")