In [1]:
import os
import sys
from pathlib import Path
from pickletools import OpcodeInfo
import numpy as np
import torch
import torch.nn as nn
import einops
import yaml
from omegaconf import DictConfig, OmegaConf
from coral.utils.data.dynamics_dataset import (KEY_TO_INDEX, TemporalDatasetWithCode)
from coral.utils.models.load_inr import create_inr_instance, load_inr_model
from coral.utils.data.load_data import get_dynamics_data, set_seed
from utils import scheduling
from ode_model import Decoder, Derivative
from torchdiffeq import odeint
from eval_dino import *

In [17]:
# Experiment configuration: hyper-parameters come from config.yaml, saved
# artefacts live under $WANDB_DIR/<dataset_name>.
with open("config.yaml") as config_file:
    cfg = DictConfig(yaml.safe_load(config_file))

# The dataset is pinned here instead of being taken from cfg.data.dataset_name.
# Set DATASET_NAME_OVERRIDE = None to fall back to the config value.
# NOTE(review): other cfg.data / cfg.inr fields (data dir, run_name, ...) are
# still read from config.yaml, so they must correspond to the pinned dataset.
DATASET_NAME_OVERRIDE = 'navier-stokes-dino'
dataset_name = DATASET_NAME_OVERRIDE or cfg.data.dataset_name
run_name = cfg.inr.run_name

wandb_dir = os.getenv("WANDB_DIR")
if wandb_dir is None:
    raise RuntimeError("The WANDB_DIR environment variable must be set to locate saved models.")
root_dir = Path(wandb_dir) / dataset_name

In [18]:
# Restore the trained decoder / dynamics weights saved by the training run.
# The checkpoint's own config (tmp['cfg']) is not used; config.yaml is read instead.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
checkpoint_path = root_dir / "dino" / "model" / f"{run_name}.pt"
tmp = torch.load(checkpoint_path)
dec_state, dyn_state = tmp['dec_state_dict'], tmp['dyn_state_dict']

# --- data ---
data_dir, ntrain, ntest = cfg.data.dir, cfg.data.ntrain, cfg.data.ntest
sub_from, sub_tr, sub_te = cfg.data.sub_from, cfg.data.sub_tr, cfg.data.sub_te
seed, same_grid = cfg.data.seed, cfg.data.same_grid
seq_inter_len, seq_extra_len = cfg.data.seq_inter_len, cfg.data.seq_extra_len

# --- optimisation ---
batch_size, lr = cfg.optim.minibatch_size, cfg.optim.lr

# --- INR / decoder ---
state_dim, code_dim = cfg.inr.state_dim, cfg.inr.code_dim
hidden_c_enc, n_layers, coord_dim = cfg.inr.hidden_c_enc, cfg.inr.n_layers, cfg.inr.coord_dim

# --- forecaster ---
hidden_c = cfg.forecaster.hidden_c

# Constructor keyword arguments for the Decoder and Derivative networks.
net_dec_params = dict(
    state_c=state_dim,
    code_c=code_dim,
    hidden_c=hidden_c_enc,
    n_layers=n_layers,
    coord_dim=coord_dim,
)
net_dyn_params = dict(
    state_c=state_dim,
    hidden_c=hidden_c,
    code_c=code_dim,
)

In [20]:
# Evaluate one trajectory per batch, overriding cfg.optim.minibatch_size.
EVAL_BATCH_SIZE = 1
batch_size = EVAL_BATCH_SIZE

In [21]:
# Show the spatial subsampling settings taken from the config.
subsampling = (sub_tr, sub_from, sub_te)
print("sub_tr, sub_from, sub_te : ", *subsampling)

sub_tr, sub_from, sub_te :  0.2 2 0.2


In [22]:
# Load the data, build the data loaders, and restore the trained networks.
set_seed(seed)

# Only the shallow-water dataset is treated as multichannel.
multichannel = dataset_name == 'shallow-water-dino'

# NOTE(review): sub_from is read from the config but not passed to
# get_dynamics_data here — confirm that is intended.
(u_train, u_eval_extrapolation, u_test, grid_tr, grid_tr_extra, grid_te) = get_dynamics_data(
    data_dir,
    dataset_name,
    ntrain,
    ntest,
    seq_inter_len=seq_inter_len,
    seq_extra_len=seq_extra_len,
    sub_tr=sub_tr,
    sub_te=sub_te,
    same_grid=same_grid,
)


def _time_to_axis1(x):
    """Move the trailing axis (named T, presumably time) to position 1: (N, ..., T) -> (N, T, ...)."""
    return einops.rearrange(x, 'N ... T -> N T ...')


u_train, u_eval_extrapolation, u_test = (
    _time_to_axis1(u) for u in (u_train, u_eval_extrapolation, u_test)
)
grid_tr, grid_tr_extra, grid_te = (
    _time_to_axis1(g) for g in (grid_tr, grid_tr_extra, grid_te)
)

trainset = TemporalDatasetWithCode(u_train, grid_tr, code_dim, dataset_name, None)
trainset_extra = TemporalDatasetWithCode(
    u_eval_extrapolation, grid_tr_extra, code_dim, dataset_name, None
)
testset = TemporalDatasetWithCode(u_test, grid_te, code_dim, dataset_name, None)


def _make_loader(dataset, batch_size, pin_memory):
    """Build a DataLoader with the settings shared by every split in this notebook."""
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=1,
        pin_memory=pin_memory,
    )


train_loader = _make_loader(trainset, batch_size, pin_memory=True)
train_extra_loader = _make_loader(trainset_extra, batch_size, pin_memory=True)
test_loader = _make_loader(testset, batch_size, pin_memory=False)

n_seq_train = u_train.shape[0]
n_seq_test = u_test.shape[0]
T_train = u_train.shape[1]
T_test = u_test.shape[1]
dt = 1

timestamps_train = torch.arange(0, T_train, dt).float().cuda()
timestamps_test = torch.arange(0, T_test, dt).float().cuda()

# Integration method passed to eval_dino (presumably forwarded to torchdiffeq.odeint).
method = "rk4"

# Dataset-specific value passed to eval_dino as n_steps.
# NOTE(review): the meaning of n_steps is defined inside eval_dino — confirm there.
n_steps = 500 if dataset_name == "shallow-water-dino" else 300


def _load_matching_weights(net, saved_state):
    """Load every checkpoint entry whose key exists in `net` and return `net`.

    Checkpoint keys unknown to `net` are ignored. Parameters absent from the
    checkpoint keep their freshly initialised values and are reported, instead
    of making the strict load fail. Both networks are loaded the same way.
    """
    own_state = net.state_dict()
    matched = {k: v for k, v in saved_state.items() if k in own_state}
    missing = sorted(set(own_state) - set(matched))
    if missing:
        print(f"warning: {type(net).__name__}: {len(missing)} entries not found in checkpoint: {missing}")
    own_state.update(matched)
    # Load the merged dict (not just `matched`) so strict loading always has every key.
    net.load_state_dict(own_state)
    return net


net_dec = _load_matching_weights(Decoder(**net_dec_params), dec_state)
print(dict(net_dec.named_parameters()).keys())

net_dyn = _load_matching_weights(Derivative(**net_dyn_params), dyn_state)
print(dict(net_dyn.named_parameters()).keys())

states_params = tmp["states_params"]
net_dec = net_dec.to('cuda')
net_dyn = net_dyn.to('cuda')

print(f"data: {dataset_name}, u_train: {u_train.shape}, u_train_eval: {u_eval_extrapolation.shape}, u_test: {u_test.shape}")
print(f"grid: grid_tr: {grid_tr.shape}, grid_tr_extra: {grid_tr_extra.shape}, grid_te: {grid_te.shape}")

dict_keys(['net.bilinear.0.A', 'net.bilinear.0.B', 'net.bilinear.0.bias', 'net.bilinear.1.A', 'net.bilinear.1.B', 'net.bilinear.1.bias', 'net.bilinear.2.A', 'net.bilinear.2.B', 'net.bilinear.2.bias', 'net.bilinear.3.A', 'net.bilinear.3.B', 'net.bilinear.3.bias', 'net.output_bilinear.weight', 'net.output_bilinear.bias', 'net.filters.0.weight', 'net.filters.1.weight', 'net.filters.2.weight', 'net.filters.3.weight'])
dict_keys(['net.net.0.weight', 'net.net.0.bias', 'net.net.1.beta', 'net.net.2.weight', 'net.net.2.bias', 'net.net.3.beta', 'net.net.4.weight', 'net.net.4.bias', 'net.net.5.beta', 'net.net.6.weight', 'net.net.6.bias'])
data: navier-stokes-dino, u_train: torch.Size([256, 20, 13107, 1]), u_train_eval: torch.Size([256, 40, 13107, 1]), u_test: torch.Size([16, 40, 13107, 1])
grid: grid_tr: torch.Size([256, 20, 13107, 2]), grid_tr_extra: torch.Size([256, 40, 13107, 2]), grid_te: torch.Size([16, 40, 13107, 2])


In [25]:
class DetailedMSE():
    """Accumulate per-channel squared errors and report their per-trajectory mean.

    Args:
        keys: channel names; each must be a key of KEY_TO_INDEX[dataset_name].
        dataset_name: key into KEY_TO_INDEX that maps channel names to indices.
        mode: split label embedded in result keys, e.g. "train" -> "<key>_train_mse".
        n_trajectories: divisor applied by get_dic().
    """

    def __init__(self, keys, dataset_name="shallow-water-dino", mode="train", n_trajectories=256):
        self.keys = keys
        self.mode = mode
        self.dataset_name = dataset_name
        self.n_trajectories = n_trajectories
        self.reset_dic()

    def reset_dic(self):
        """Set every running sum back to zero."""
        self.dic = {f"{key}_{self.mode}_mse": 0 for key in self.keys}

    def aggregate(self, u_pred, u_true):
        """Add the batch-size-weighted MSE of each channel to its running sum.

        Channels are selected along the second-to-last axis of u_pred / u_true
        using the index from KEY_TO_INDEX.
        """
        n_samples = u_pred.shape[0]
        for key in self.keys:
            idx = KEY_TO_INDEX[self.dataset_name][key]
            self.dic[f"{key}_{self.mode}_mse"] += (
                (u_pred[..., idx, :] - u_true[..., idx, :])**2).mean()*n_samples

    def get_dic(self):
        """Return the running sums divided by n_trajectories, as a new dict.

        The running sums themselves are not modified, so repeated calls return
        the same values (dividing self.dic in place would re-divide on every call).
        """
        return {name: total / self.n_trajectories for name, total in self.dic.items()}
    
# Per-channel MSE trackers exist only for multichannel datasets; the
# evaluation calls receive None otherwise.
if not multichannel:
    detailed_train_mse = detailed_train_eval_mse = detailed_test_mse = None
else:
    channel_keys = list(KEY_TO_INDEX[dataset_name].keys())
    detailed_train_mse = DetailedMSE(
        channel_keys, dataset_name, mode="train", n_trajectories=ntrain
    )
    detailed_train_eval_mse = DetailedMSE(
        channel_keys, dataset_name, mode="train_extra", n_trajectories=ntrain
    )
    detailed_test_mse = DetailedMSE(
        channel_keys, dataset_name, mode="test", n_trajectories=ntest
    )

In [26]:
criterion = nn.MSELoss()

# In-domain trajectories, rolled out over the extrapolation horizon.
# NOTE(review): timestamps_test is reused here; this relies on
# u_eval_extrapolation and u_test having the same number of time steps
# (both 40 in the recorded run) — confirm for other datasets.
print("Evaluating train...")
train_results = eval_dino(
    train_extra_loader, net_dyn, net_dec, 'cuda', method, criterion,
    state_dim, code_dim, coord_dim,
    detailed_train_eval_mse, timestamps_test,
    n_seq_train, seq_inter_len, seq_extra_len,
    states_params,
    multichannel=multichannel, n_steps=n_steps,
)
pred_train_mse, pred_train_inter_mse, pred_train_extra_mse, detailed_train_eval_mse = train_results

# Out-of-domain evaluation
# NOTE(review): unlike the train call, `lr` is passed positionally right after
# states_params here. Check eval_dino's signature to confirm which parameter
# it binds to, and whether the train call should pass it too.
print("Evaluating test...")
test_results = eval_dino(
    test_loader, net_dyn, net_dec, 'cuda', method, criterion,
    state_dim, code_dim, coord_dim,
    detailed_test_mse, timestamps_test,
    n_seq_test, seq_inter_len, seq_extra_len,
    states_params, lr,
    multichannel=multichannel, n_steps=n_steps,
)
pred_test_mse, pred_test_inter_mse, pred_test_extra_mse, detailed_test_mse = test_results

Evaluating train...
Evaluating test...


In [27]:
# Report in-horizon (inter) and extrapolation (extra) MSE on training trajectories.
for label, value in (("pred_train_inter_mse : ", pred_train_inter_mse),
                     ('pred_train_extra_mse :', pred_train_extra_mse)):
    print(label, value)

pred_train_inter_mse :  0.010746166815806646
pred_train_extra_mse : 0.0488120950612938


In [28]:
# Report in-horizon (inter) and extrapolation (extra) MSE on test trajectories.
for label, value in (("pred_test_inter_mse : ", pred_test_inter_mse),
                     ('pred_test_extra_mse :', pred_test_extra_mse)):
    print(label, value)

pred_test_inter_mse :  0.015077762713190168
pred_test_extra_mse : 0.0685730007244274
