In [2]:
from utils import mppde_create_data, mppde_pushforward, mppde_test_rollout, GraphTemporalDataset, KEY_TO_INDEX
from coral.utils.data.load_data import get_dynamics_data, set_seed
from model import MP_PDE_Solver
from omegaconf import DictConfig, OmegaConf
import wandb
import torch.nn as nn
import torch.nn.functional as F
import torch
import numpy as np
import hydra
import os
import yaml
from pathlib import Path
from pickletools import OpcodeInfo
from torch.nn.utils import weight_norm

from torch_geometric.loader import DataLoader
from torch_geometric.data import Data

In [7]:
cfg = DictConfig(yaml.safe_load(open("config.yaml")))
dataset_name = cfg.data.dataset_name

run_name = cfg.model.run_name
run_name= 'fanciful-leaf-5019'
dataset_name = 'navier-stokes-dino'
root_dir = Path(os.getenv("WANDB_DIR")) / dataset_name

tmp = torch.load(root_dir / "model" / f"{run_name}.pt")
cfg = tmp['cfg']

dataset_name = cfg.data.dataset_name
data_dir = cfg.data.dir
ntrain = cfg.data.ntrain
ntest = cfg.data.ntest
data_to_encode = cfg.data.data_to_encode
sub_tr = cfg.data.sub_tr
sub_from = cfg.data.sub_from
sub_te = cfg.data.sub_te
seed = cfg.data.seed
same_grid = cfg.data.same_grid
seq_inter_len = 20
seq_extra_len = 20

# optim
batch_size = cfg.optim.batch_size
time_window = 1

# model
hidden_features = 64

set_seed(seed)


(u_train, u_eval_extrapolation, u_test, grid_tr, grid_tr_extra, grid_te) = get_dynamics_data(
    data_dir,
    dataset_name,
    ntrain,
    ntest,
    seq_inter_len=seq_inter_len,
    seq_extra_len=seq_extra_len,
    sub_from=sub_from,
    sub_tr=sub_tr,
    sub_te=sub_te,
    same_grid=same_grid,
)

print(
    f"data: {dataset_name}, u_train: {u_train.shape}, u_test: {u_test.shape}")
print(f"grid: grid_tr: {grid_tr.shape}, grid_te: {grid_te.shape}")

# total frames = num_trajectories * sequence_length
T = u_train.shape[-1]

ntrain = u_train.shape[0]  # int(u_train.shape[0]*T)
ntest = u_test.shape[0]  # int(u_test.shape[0]*T)

testset = GraphTemporalDataset(
    u_test, grid_te
)

dt = 1
timestamps = torch.arange(0, T, dt).float().cuda()#0.1
T_in = 20
T_out = 40

# create torch dataset
test_loader = DataLoader(
    testset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=1,
)

if dataset_name == "navier-stokes-dino":
    pos_dim = 2
    input_dim = 1
    output_dim = 1
    time_window = 1

elif dataset_name == "shallow-water-dino":
    pos_dim = 3
    input_dim = 2 
    output_dim = 2
    time_window = 1

model = MP_PDE_Solver(pos_dim=pos_dim,
                        input_dim=input_dim,
                        output_dim=output_dim,
                        time_window=time_window,
                        hidden_features=hidden_features,
                        hidden_layer=6).cuda()

model.load_state_dict(tmp['model'])

pred_test_mse = 0
pred_test_out_mse = 0
pred_test_in_mse = 0
pred_test_mse = 0
code_test_mse = 0

for graph, idx in test_loader:
    model.eval()
    n_samples = len(graph)

    graph = graph.cuda()
    with torch.no_grad():
        loss = mppde_pushforward(model, graph)
    code_test_mse += loss.item() * n_samples

    with torch.no_grad():
        u_pred = mppde_test_rollout(model, graph, bundle_size=1)
        pred_test_mse += ((u_pred - graph.images) ** 2).mean() * n_samples
        pred_test_in_mse += ((u_pred[..., :T_in] - graph.images[..., :T_in]) ** 2).mean() * n_samples
        pred_test_out_mse += ((u_pred[..., T_in:] - graph.images[..., T_in:]) ** 2).mean() * n_samples

code_test_mse = code_test_mse / ntest
pred_test_mse = pred_test_mse / ntest
pred_test_in_mse = pred_test_in_mse / ntest
pred_test_out_mse = pred_test_out_mse / ntest

data: navier-stokes-dino, u_train: torch.Size([256, 819, 1, 20]), u_test: torch.Size([16, 819, 1, 40])
grid: grid_tr: torch.Size([256, 819, 2, 20]), grid_te: torch.Size([16, 819, 2, 40])
self values torch.Size([16, 819, 1, 40])
self grid torch.Size([16, 819, 2, 40])
data before torch.Size([13104, 1, 40])


In [8]:
print(sub_tr, sub_from, sub_te)

0.2 4 0.2


In [9]:
print("pred_test_in_mse : ", pred_test_in_mse)

pred_test_in_mse :  tensor(0.6107, device='cuda:0')


In [10]:
print("pred_test_out_mse : ", pred_test_out_mse)

pred_test_out_mse :  tensor(0.6100, device='cuda:0')


sub tr = 5 %

6.3e-3 & 2.74e-2

sub_te = 4
6.4e-3 & 2.49e-2

sub_te = 2
9.8e-3 & 4.89e-2

sub_te = 1
1.12e-2 & 6.4e-2

sub tr 20 %


20% : 0.0133 et 0.0713
5% : 0.0078 et 0.0276
