In [None]:
import argparse
import os
import copy
import sys
import time
from datetime import datetime
import torch
import pdb
import random
import numpy as np
from matplotlib import pyplot as plt
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
import pickle

import sys, os
sys.path.append(os.path.join(os.path.dirname("__file__"), '..'))
sys.path.append(os.path.join(os.path.dirname("__file__"), '..', '..'))
sys.path.append(os.path.join(os.path.dirname("__file__"), '..', '..', '..'))
sys.path.append(os.path.join(os.path.dirname("__file__"), '..', '..', '..', '..'))
sys.path.append(os.path.join(os.path.dirname("__file__"), '..', '..', '..', '..', '..'))
from MP_Neural_PDE_Solvers.equations.PDEs import *
from MP_Neural_PDE_Solvers.common.utils import HDF5Dataset, GraphCreator, p
from MP_Neural_PDE_Solvers.experiments.models_gnn import MP_PDE_Solver
from MP_Neural_PDE_Solvers.experiments.models_cnn import BaseCNN
from MP_Neural_PDE_Solvers.experiments.models_fno import FNO1d
from MP_Neural_PDE_Solvers.experiments.train_helper import *
from MP_Neural_PDE_Solvers.experiments.train import test
from lamp.pytorch_net.util import pload, get_machine_name, filter_filename, init_args, Interp1d_torch, to_np_array

def check_directory() -> None:
    """
    Ensure the output directories used by the experiments exist.

    Creates ``experiments/log`` (including the parent ``experiments`` directory
    if it is missing) and ``models``, both relative to the current working
    directory. Directories that already exist are left untouched, so calling
    this repeatedly is safe.
    """
    for path in ("experiments/log", "models"):
        # makedirs also creates missing parents (plain mkdir fails when
        # "experiments" itself is absent) and exist_ok avoids the
        # check-then-create race of os.path.exists() + os.mkdir().
        os.makedirs(path, exist_ok=True)

In [None]:
dirname = "models/"
# Choose ONE of the following (model, uniform_sample) pairs to test the
# corresponding baseline by uncommenting it (and commenting out the active one).
# uniform_sample = -1 selects the full-resolution (100-node) setting.
# model_name, uniform_sample = "BaseCNN", 2
# model_name, uniform_sample = "BaseCNN", 4
# model_name, uniform_sample = "BaseCNN", -1
# model_name, uniform_sample = "FNO", 2
# model_name, uniform_sample = "FNO", 4
# model_name, uniform_sample = "FNO", -1
# model_name, uniform_sample = "GNN", 2
# model_name, uniform_sample = "GNN", 4
model_name, uniform_sample = "GNN", -1

# Other configurations:
# Interpolate predictions onto the full grid before scoring
# (forced off later when uniform_sample == -1).
is_full_eval = True
N_X_FULL = 100  # number of spatial nodes in the full-resolution dataset
# The GNN runs on the subsampled grid x[::uniform_sample], which keeps
# ceil(N_X_FULL / uniform_sample) nodes (e.g. 34 for uniform_sample=3, not 100 // 3 = 33).
resolution = (
    len(range(0, N_X_FULL, uniform_sample))
    if uniform_sample != -1 and model_name == "GNN"
    else N_X_FULL
)
filename_core = f"GNN_CE_E2_{model_name}_xresolution{resolution}-200_uni{uniform_sample}_n3_tw25_unrolling1_server"
# Find the saved model under "lamp/MP_Neural_PDE_Solvers/models/":
filenames = filter_filename(dirname, include=[filename_core, ".pt"])
assert len(filenames) == 1, f"Find {len(filenames)} files satisfying the condition: {filenames}. Choose one of them."
filename = filenames[0]
# Fall back to CPU so the notebook also runs on machines without a GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [None]:
args = init_args({
    "experiment": "E2",
    "model": model_name,
    "base_resolution": [250,100],
    "time_window": 25,
    "uniform_sample": uniform_sample,
    "batch_size": 4,
    "neighbors": 3,
    "parameter_ablation": False,
    "nr_gt_steps": 2,
})
# The GNN runs directly on the subsampled grid x[::uniform_sample], so its
# spatial resolution is the number of points that slicing keeps,
# ceil(100 / uniform_sample): 50 for 2, 34 for 3, 25 for 4.
# Other models keep the nominal [250, 100] resolution.
if model_name == "GNN" and uniform_sample != -1:
    args.base_resolution = [250, len(range(0, 100, uniform_sample))]

In [None]:
# Test split of the CE experiment-E2 data. The loop further below unpacks six
# items per batch from this loader (base, super, original trajectories, grids, variables).
test_dataset = HDF5Dataset(
    '../data/mppde1d_data/CE_test_E2.h5',
    pde="CE",
    mode='test',
    base_resolution=args.base_resolution,
    super_resolution=[250, 200],
    uniform_sample=uniform_sample,
    is_return_super=True,
)

# Deterministic order (no shuffling, no worker processes) so losses are reproducible.
test_loader = DataLoader(
    test_dataset,
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=0,
)

# PDE description shared by the graph creator and the models:
# time domain [0, 4] on a nominal 250 (time) x 100 (space) grid.
pde = CE(device=device)
pde.tmin, pde.tmax = 0.0, 4.0
pde.grid_size = [250, 100]

# The GNN runs on the subsampled grid x[::uniform_sample]; both the args and the
# PDE grid must use the number of nodes that slicing keeps,
# ceil(100 / uniform_sample) (50 for 2, 34 for 3, 25 for 4).
if model_name == "GNN" and uniform_sample != -1:
    nx_gnn = len(range(0, 100, uniform_sample))
    args.base_resolution = [250, nx_gnn]
    pde.grid_size = [250, nx_gnn]


# Per-experiment equation parameters that are fed to the GNN solver as extra
# inputs: experiment name -> (message to print, parameters to add).
_EXPERIMENT_EQ_VARIABLES = {
    'E2': ('Beta parameter added to the GNN solver',
           {'beta': 0.2}),
    'E3': ('Alpha, beta, and gamma parameter added to the GNN solver',
           {'alpha': 3., 'beta': 0.4, 'gamma': 1.}),
    'WE3': ('Boundary parameters added to the GNN solver',
            {'bc_left': 1, 'bc_right': 1}),
}

eq_variables = {}
# Parameters are only attached when not running the parameter ablation and the
# experiment is one of the known ones; otherwise eq_variables stays empty.
_selected = None if args.parameter_ablation else _EXPERIMENT_EQ_VARIABLES.get(args.experiment)
if _selected is not None:
    _message, _params = _selected
    print(_message)
    eq_variables.update(_params)

graph_creator = GraphCreator(pde=pde,
                             neighbors=args.neighbors,
                             time_window=args.time_window,
                             t_resolution=args.base_resolution[0],
                             x_resolution=args.base_resolution[1]).to(device)

if args.model == 'GNN':
    model = MP_PDE_Solver(pde=pde,
                          time_window=graph_creator.tw,
                          eq_variables=eq_variables).to(device)
elif args.model == 'BaseCNN':
    model = BaseCNN(pde=pde,
                    time_window=args.time_window).to(device)
elif args.model == 'FNO':
    # Number of spatial nodes the FNO sees. uniform_sample == -1 means the full
    # 100-node grid; 100 // -1 would yield a negative mode count and crash.
    # Positive strides keep the original formula so existing checkpoints still match.
    nx_fno = 100 if uniform_sample == -1 else 100 // uniform_sample
    modes = min(16, nx_fno // 2 + 1)
    model = FNO1d(pde=pde,
                  modes=modes, width=64, input_size=args.time_window, output_size=args.time_window).to(device)
else:
    raise Exception("Wrong model specified")
# map_location allows a checkpoint saved on GPU to be loaded on whatever `device` is.
model.load_state_dict(torch.load(dirname + filename, map_location=device))
model.to(device)
# Evaluation only: disable any training-mode behavior (dropout, batch-norm updates).
model.eval()
# Candidate start times t such that both [t - tw, t) and [t, t + tw) lie inside the 250 steps.
steps = [t for t in range(graph_creator.tw, 250-graph_creator.tw + 1)]
criterion = torch.nn.MSELoss(reduction="sum")
# With uniform_sample == -1 the model already predicts on the full grid,
# so there is nothing to interpolate.
if uniform_sample == -1:
    is_full_eval = False

In [None]:
"""
Loss for full trajectory unrolling, we report this loss in the paper
Args:
    model (torch.nn.Module): neural network PDE solver
    steps (list): input list of possible starting (time) points
    nr_gt_steps (int): number of numerical input timesteps
    nx_base_resolution (int): spatial resolution of numerical baseline
    loader (DataLoader): dataloader [valid, test]
    graph_creator (GraphCreator): helper object to handle graph data
    criterion (torch.nn.modules.loss): criterion for training
    device (torch.cuda.device): device (cpu/gpu)
Returns:
    torch.Tensor: valid/test losses
"""

########################################################################
# Here we evaluated the loss on the full 100 nodes compared with ground-truth, 
# wheter the model was trained on the 25, 50 or 100-node dataset. In this way,
# All models are compared on the same ground-truth.
#########################################################################

model = model
steps = steps
batch_size = args.batch_size
nr_gt_steps = args.nr_gt_steps
nx_base_resolution = args.base_resolution[1]
loader = test_loader
graph_creator = graph_creator
criterion = criterion
device = device

losses = []
losses_base = []
for (u_base, u_super, u_ori, x, x_ori, variables) in loader:
    losses_tmp = []
    losses_base_tmp = []
    with torch.no_grad():
        same_steps = [graph_creator.tw * nr_gt_steps] * batch_size  # [50] * batch_size:16
        data, labels = graph_creator.create_data(u_super, same_steps)  # first time: data: from 25:50, label: from 50:75
        if f'{model}' == 'GNN':
            graph = graph_creator.create_graph(data, labels, x, variables, same_steps, uniform_sample=uniform_sample).to(device)
            pred = model(graph)
            loss = criterion(pred, graph.y) / nx_base_resolution
        else:
            data, labels = data.to(device), labels.to(device)
            pred = model(data)
            loss = criterion(pred, labels) / nx_base_resolution

        if is_full_eval:
            fnc = Interp1d_torch()
            x_expand = torch.repeat_interleave(x_ori, 25, dim=0).to(device)
            if model_name == "GNN":
                pred_permute = pred.reshape(batch_size, -1, pred.shape[-1]).permute(0,2,1)  # [B, time_step, n_nodes]
                pred_permute = pred_permute.reshape(-1, pred_permute.shape[-1])  # [B*time_step, n_nodes]
                pred_interp = fnc(x_expand[...,::uniform_sample], pred_permute.reshape(-1,pred_permute.shape[-1]), x_expand)
                pred_interp = pred_interp.reshape(batch_size, -1, pred_interp.shape[-1])
                labels_expand = u_ori[:,same_steps[0]:same_steps[0]+25].to(device)
                loss = criterion(pred_interp, labels_expand) / 100
            else:
                pred_interp = fnc(x_expand[...,::uniform_sample], pred.reshape(-1,pred.shape[-1]), x_expand)
                pred_interp = pred_interp.reshape(pred.shape[0], -1, pred_interp.shape[-1])
                # labels_narrow = u_super[:,same_steps[0]:same_steps[0]+25].to(device)
                labels_expand = u_ori[:,same_steps[0]:same_steps[0]+25].to(device)
                loss = criterion(pred_interp, labels_expand) / 100

        losses_tmp.append(loss / batch_size)

        # Unroll trajectory and add losses which are obtained for each unrolling
        # for step in list(range(50, 225, 25)):
        for step in list(range(graph_creator.tw * (nr_gt_steps + 1), graph_creator.t_res - graph_creator.tw, graph_creator.tw)):
            same_steps = [step] * batch_size
            _, labels = graph_creator.create_data(u_super, same_steps)
            if f'{model}' == 'GNN':
                graph = graph_creator.create_next_graph(graph, pred, labels, same_steps).to(device)
                pred = model(graph)
                loss = criterion(pred, graph.y) / nx_base_resolution  # pred/graph.y: [B*n_nodes, time_steps:25]
            else:
                labels = labels.to(device)
                pred = model(pred)
                loss = criterion(pred, labels) / nx_base_resolution  # pred, labels: [B, time_steps, n_nodes]

            if is_full_eval:
                fnc = Interp1d_torch()
                x_expand = torch.repeat_interleave(x_ori, 25, dim=0).to(device)
                if model_name == "GNN":
                    pred_permute = pred.reshape(batch_size, -1, pred.shape[-1]).permute(0,2,1)  # [B, time_step, n_nodes]
                    pred_permute = pred_permute.reshape(-1, pred_permute.shape[-1])  # [B*time_step, n_nodes]
                    pred_interp = fnc(x_expand[...,::uniform_sample], pred_permute.reshape(-1,pred_permute.shape[-1]), x_expand)
                    pred_interp = pred_interp.reshape(batch_size, -1, pred_interp.shape[-1])
                    labels_expand = u_ori[:,same_steps[0]:same_steps[0]+25].to(device)
                    loss = criterion(pred_interp, labels_expand) / 100
                else:
                    pred_interp = fnc(x_expand[...,::uniform_sample], pred.reshape(-1,pred.shape[-1]), x_expand)
                    pred_interp = pred_interp.reshape(pred.shape[0], -1, pred_interp.shape[-1])
                    # labels_narrow = u_super[:,same_steps[0]:same_steps[0]+25].to(device)
                    labels_expand = u_ori[:,same_steps[0]:same_steps[0]+25].to(device)
                loss = criterion(pred_interp, labels_expand) / 100
            losses_tmp.append(loss / batch_size)  # batch_size = 16

        # # Losses for numerical baseline
        # for step in range(graph_creator.tw * nr_gt_steps, graph_creator.t_res - graph_creator.tw + 1,
        #                   graph_creator.tw):
        #     same_steps = [step] * batch_size
        #     _, labels_super = graph_creator.create_data(u_super, same_steps)
        #     _, labels_base = graph_creator.create_data(u_base, same_steps)
        #     pdb.set_trace()
        #     loss_base = criterion(labels_super, labels_base) / nx_base_resolution
        #     losses_base_tmp.append(loss_base / batch_size)

    losses.append(torch.sum(torch.stack(losses_tmp)))
    losses_base.append(torch.tensor(0., dtype=torch.float32))
    # losses_base.append(torch.sum(torch.stack(losses_base_tmp)))

losses = torch.stack(losses)
losses_base = torch.stack(losses_base)
print(f'Unrolled forward losses {torch.mean(losses)}')
print(f'Unrolled forward base losses {torch.mean(losses_base)}')