In [1]:
from datetime import datetime

import os

from torch import nn
import matplotlib.pyplot as plt
import torch

from pathlib import Path
from omegaconf import DictConfig, OmegaConf
import einops
import yaml

from coral.utils.data.load_data import (set_seed, get_dynamics_data)
from coral.utils.data.dynamics_dataset import TemporalDatasetWithCode, KEY_TO_INDEX
from dynamics_modeling.train import DetailedMSE
from deeponet import DeepONet, AR_forward
from eval import eval_deeponet
from forwards_operator import forward_deeponet_up

The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  @hydra.main(config_path="config/", config_name="ode.yaml")


In [12]:
# Bootstrap configuration: it is only read to obtain default names and is
# replaced further down by the configuration stored inside the checkpoint.
# The file is opened through a context manager so the handle is always closed.
with open("config/deeponet.yaml") as config_file:
    cfg = DictConfig(yaml.safe_load(config_file))

# The YAML defaults are immediately overridden to pin the run analysed here.
dataset_name = cfg.data.dataset_name
dataset_name = 'navier-stokes-dino'
run_name = cfg.deeponet.run_name
run_name = 'lilac-star-5018'

# Fail with an explicit message instead of the opaque TypeError that
# Path(None) raises when the environment variable is missing.
wandb_dir = os.getenv("WANDB_DIR")
if wandb_dir is None:
    raise RuntimeError(
        "The WANDB_DIR environment variable is not set; it must point to the "
        "directory that contains the saved runs."
    )
root_dir = Path(wandb_dir) / dataset_name

# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code; only load checkpoints that come from a trusted source.
# `tmp` holds the whole checkpoint and is read again when the model weights
# are restored, so the name is kept.
tmp = torch.load(root_dir / "deeponet" / f"{run_name}_tr.pt")
cfg = tmp['cfg']

In [13]:
# Unpack the checkpoint configuration into flat notebook variables.

# data
data_dir = cfg.data.dir
dataset_name = cfg.data.dataset_name
ntrain = cfg.data.ntrain
ntest = cfg.data.ntest
data_to_encode = cfg.data.data_to_encode
sub_from = 4  # hard-coded here, not read from cfg
sub_tr = cfg.data.sub_tr
sub_te = cfg.data.sub_te
seed = cfg.data.seed
same_grid = cfg.data.same_grid
seq_inter_len = cfg.data.seq_inter_len
seq_extra_len = cfg.data.seq_extra_len

# deeponet
model_type = cfg.deeponet.model_type
code_dim = 1  # hard-coded here, not read from cfg
branch_depth = cfg.deeponet.branch_depth
trunk_depth = cfg.deeponet.trunk_depth
width = cfg.deeponet.width

# optim
batch_size = cfg.optim.batch_size
# Fall back to the training batch size when no validation batch size is set.
# Identity comparison (`is None`) is the correct test for None.
batch_size_val = (
    batch_size if cfg.optim.batch_size_val is None else cfg.optim.batch_size_val
)

# When True, per-channel DetailedMSE trackers are built in the data cell.
multichannel = False

In [14]:
# Echo the three sub-sampling settings that drive the data loading below.
print(
    "sub_tr, sub_from, sub_te : ",
    sub_tr,
    sub_from,
    sub_te,
    sep=" ",
)

sub_tr, sub_from, sub_te :  0.2 4 0.2


In [15]:
# Seed first so that everything below is reproducible.
set_seed(seed)

# Train fields, extended-horizon train fields, test fields and their grids.
(
    u_train,
    u_eval_extrapolation,
    u_test,
    grid_tr,
    grid_tr_extra,
    grid_te,
) = get_dynamics_data(
    data_dir, dataset_name, ntrain, ntest,
    seq_inter_len=seq_inter_len, seq_extra_len=seq_extra_len,
    sub_from=sub_from, sub_tr=sub_tr, sub_te=sub_te,
    same_grid=same_grid,
)

# Second, differently sub-sampled test set (currently disabled):
# (_, _, u_test_2, _, _, grid_te_2) = get_dynamics_data(
#     data_dir,
#     dataset_name,
#     ntrain,
#     ntest,
#     seq_inter_len=seq_inter_len,
#     seq_extra_len=seq_extra_len,
#     sub_tr=sub_tr,
#     sub_te=sub_te_2,
#     same_grid=same_grid,
# )

# Collapse every spatial axis into a single one: (B, ..., C, T) -> (B, XY, C, T).
flatten_pattern = 'B ... C T -> B (...) C T'
u_train, u_eval_extrapolation, u_test = (
    einops.rearrange(field, flatten_pattern)
    for field in (u_train, u_eval_extrapolation, u_test)
)
grid_tr, grid_tr_extra, grid_te = (
    einops.rearrange(grid, flatten_pattern)
    for grid in (grid_tr, grid_tr_extra, grid_te)
)

# u_test_2 = einops.rearrange(u_test_2, 'B ... C T -> B (...) C T')
# grid_te_2 = einops.rearrange(grid_te_2, 'B ... C T -> B (...) C T')  # * 0.5

print(f"data: {dataset_name}, u_train: {u_train.shape}, u_train_eval: {u_eval_extrapolation.shape}, u_test: {u_test.shape}")
print(f"grid: grid_tr: {grid_tr.shape}, grid_tr_extra: {grid_tr_extra.shape}, grid_te: {grid_te.shape}")

# Sizes derived from the flattened (B, XY, C, T) layout.
n_seq_train, n_seq_test = u_train.shape[0], u_test.shape[0]  # number of sequences
state_dim = u_train.shape[2]  # channels of the field
spatial_size = u_train.shape[1] * state_dim  # flattened points times channels
coord_dim = grid_tr.shape[2]  # channels of the coordinate grid
T = u_train.shape[-1]  # time steps in the training window
T_EXT = u_test.shape[-1]  # time steps in the test fields

# From here on the sample counts are the number of loaded sequences.
ntrain, ntest = n_seq_train, n_seq_test

# Datasets: extended-horizon train trajectories and the test trajectories.
trainset_out = TemporalDatasetWithCode(
    u_eval_extrapolation, grid_tr_extra, code_dim, dataset_name, data_to_encode
)
testset = TemporalDatasetWithCode(
    u_test, grid_te, code_dim, dataset_name, data_to_encode
)

# testset_2 = TemporalDatasetWithCode(
#     u_test_2, grid_te_2, code_dim, dataset_name, data_to_encode
# )

# TODO (original note): consider shuffle=False because of a CUDA error (unclear).
valid_loader = torch.utils.data.DataLoader(
    trainset_out, batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=True
)
test_loader = torch.utils.data.DataLoader(
    testset, batch_size=batch_size_val, shuffle=True, num_workers=1
)

# test_loader_2 = torch.utils.data.DataLoader(
#     testset_2,
#     batch_size=batch_size_val,
#     shuffle=True, # TODO : here shuffle to False because error cuda (?!)
#     num_workers=1,
# )

# Per-channel error trackers are only built for multichannel datasets.
detailed_train_eval_mse = None
detailed_test_mse = None
if multichannel:
    detailed_train_eval_mse = DetailedMSE(
        list(KEY_TO_INDEX[dataset_name].keys()),
        dataset_name,
        mode="train_extra",
        n_trajectories=n_seq_train,
    )
    detailed_test_mse = DetailedMSE(
        list(KEY_TO_INDEX[dataset_name].keys()),
        dataset_name,
        mode="test",
        n_trajectories=n_seq_test,
    )

# Integer time stamps 0..T-1 and 0..T_EXT-1 as float tensors on the GPU.
dt = 1
timestamps_train = torch.arange(0, T, dt).float().cuda()
timestamps_ext = torch.arange(0, T_EXT, dt).float().cuda()

# Keyword arguments used to build the DeepONet.
net_dyn_params = dict(
    branch_dim=spatial_size,
    branch_depth=branch_depth,
    trunk_dim=coord_dim,
    trunk_depth=trunk_depth,
    width=width,
)

data: navier-stokes-dino, u_train: torch.Size([256, 819, 1, 20]), u_train_eval: torch.Size([256, 819, 1, 40]), u_test: torch.Size([16, 819, 1, 40])
grid: grid_tr: torch.Size([256, 819, 2, 20]), grid_tr_extra: torch.Size([256, 819, 2, 40]), grid_te: torch.Size([16, 819, 2, 40])


In [16]:
# Rebuild the network with the same hyper-parameters as the saved run, restore
# its weights from the checkpoint and move it to the GPU.
deeponet = DeepONet(**net_dyn_params, logger=None, input_dataset=dataset_name)
deeponet.load_state_dict(tmp['deeponet_state_dict'])
deeponet = deeponet.to('cuda')
# This notebook only evaluates the model: switch to inference mode so that any
# train-time-only layer (dropout, batch norm, ...) behaves deterministically.
deeponet.eval()
# Mean squared error, used for every metric reported below.
criterion = nn.MSELoss()

In [17]:
# Evaluation on the extended-horizon training trajectories.
print("Evaluating train...")
pred_train_mse, pred_train_inter_mse, pred_train_extra_mse = eval_deeponet(
    deeponet,
    valid_loader,
    'cuda',
    timestamps_ext,
    criterion,
    n_seq_train,
    seq_inter_len,
    seq_extra_len,
    detailed_train_eval_mse,
)

# Out-of-domain evaluation on the held-out test trajectories.
print("Evaluating test...")
pred_test_mse, pred_test_inter_mse, pred_test_extra_mse = eval_deeponet(
    deeponet,
    test_loader,
    'cuda',
    timestamps_ext,
    criterion,
    n_seq_test,
    seq_inter_len,
    seq_extra_len,
    detailed_test_mse,
)

Evaluating train...
torch.Size([40, 40, 819, 1])
torch.Size([40, 40, 819, 1])
torch.Size([40, 40, 819, 1])
torch.Size([40, 40, 819, 1])
torch.Size([40, 40, 819, 1])
torch.Size([40, 40, 819, 1])
torch.Size([16, 40, 819, 1])
Evaluating test...
torch.Size([16, 40, 819, 1])


In [18]:
# Report the two train-set errors returned by the evaluation above.
for label, value in (
    ("pred_train_inter_mse : ", pred_train_inter_mse),
    ('pred_train_extra_mse :', pred_train_extra_mse),
):
    print(label, value)

pred_train_inter_mse :  0.0313153974711895
pred_train_extra_mse : 0.031304750591516495


In [19]:
# Report the two test-set errors returned by the evaluation above.
for label, value in (
    ("pred_test_inter_mse : ", pred_test_inter_mse),
    ('pred_test_extra_mse :', pred_test_extra_mse),
):
    print(label, value)

pred_test_inter_mse :  0.5221408009529114
pred_test_extra_mse : 0.5004575848579407


In [19]:
def eval_deeponet_up(deeponet, dataloader, testset, device, timestamps, criterion, n_seq, n_frames_in, n_frames_out, detailed_mse, multichannel=False):
    """Evaluate ``deeponet`` against targets taken from a second dataset.

    For every batch of ``dataloader`` the samples with the same indices are
    fetched from ``testset``; ``forward_deeponet_up`` is run on both and the
    prediction is compared with the fields coming from ``testset``.

    Args:
        deeponet: model to evaluate. Gradients of its parameters are disabled
            during the evaluation and re-enabled for ALL parameters afterwards.
        dataloader: yields ``(images, _, coords, idx)`` batches.
        testset: indexable dataset; ``testset[idx][0]`` gives the target fields
            and ``testset[idx][2]`` their coordinates.
            NOTE(review): presumably a finer sampling of the same
            trajectories - confirm against the caller.
        device: device the fields are moved to, also given to the forward pass.
        timestamps: time stamps given to ``forward_deeponet_up``.
        criterion: loss callable applied to (prediction, target).
        n_seq: number of sequences used to normalise the accumulated losses.
        n_frames_in: number of leading frames counted as "in horizon".
        n_frames_out: number of following frames counted as "out of horizon";
            when 0 only the global loss is computed.
        detailed_mse: object whose ``aggregate`` method is called when
            ``multichannel`` is true.
        multichannel: whether to feed ``detailed_mse``.

    Returns:
        Tuple ``(loss, loss_in_t, loss_out_t)``: batch-size weighted sums of
        the criterion over all batches, divided by ``n_seq``. ``loss_in_t`` and
        ``loss_out_t`` stay at 0.0 when ``n_frames_out`` is 0.
    """
    loss, loss_in_t, loss_out_t = 0.0, 0.0, 0.0

    set_requires_grad(deeponet, False)
    try:
        for images, _, coords, idx in dataloader:
            # Matching samples of the second dataset, fetched once per batch.
            sample_up = testset[idx]
            images_up, coords_up = sample_up[0], sample_up[2]

            # Flatten the spatial dims, then move the last axis to position 1:
            # [B, XY, C, T] -> [B, T, XY, C].
            ground_truth = torch.permute(
                einops.rearrange(images, 'B ... C T -> B (...) C T'), (0, 3, 1, 2)
            ).to(device)
            ground_truth_up = torch.permute(
                einops.rearrange(images_up, 'B ... C T -> B (...) C T'), (0, 3, 1, 2)
            ).to(device)
            b_size = ground_truth.shape[0]

            # NOTE(review): the coordinates are handed over unprocessed, as in
            # the original code; forward_deeponet_up presumably reshapes them
            # itself - confirm.
            model_output = forward_deeponet_up(deeponet, ground_truth, ground_truth_up, coords, coords_up, timestamps, device)
            print("model_output.shape, ground_truth_up.shape : ", model_output.shape, ground_truth_up.shape)

            # Accumulate (not overwrite) so that every batch contributes to the
            # final average, whatever the value of n_frames_out.
            loss += criterion(model_output, ground_truth_up).item() * b_size
            if n_frames_out != 0:
                loss_in_t += criterion(
                    model_output[:, :n_frames_in, :, :],
                    ground_truth_up[:, :n_frames_in, :, :],
                ).item() * b_size
                loss_out_t += criterion(
                    model_output[:, n_frames_in:n_frames_in + n_frames_out, :, :],
                    ground_truth_up[:, n_frames_in:n_frames_in + n_frames_out, :, :],
                ).item() * b_size

            if multichannel:
                detailed_mse.aggregate(model_output.detach(),
                                       ground_truth_up.detach())
    finally:
        # Re-enable gradients even if the evaluation is interrupted.
        set_requires_grad(deeponet, True)

    loss /= n_seq
    loss_in_t /= n_seq
    loss_out_t /= n_seq

    return (
        loss,
        loss_in_t,
        loss_out_t,
    )


def set_requires_grad(module, tf=False):
    """Switch gradient tracking on or off for ``module``.

    Stores ``tf`` in a ``requires_grad`` attribute on the module itself and
    sets the ``requires_grad`` flag of every parameter it yields to ``tf``.
    """
    setattr(module, "requires_grad", tf)
    for weight in module.parameters():
        weight.requires_grad = tf

In [20]:
# Out-of-domain evaluation
print("Evaluating test...")
(
    pred_test_mse,
    pred_test_inter_mse,
    pred_test_extra_mse
) = eval_deeponet_up(
    deeponet, 
    test_loader,
    testset_2,
    'cuda', 
    timestamps_ext,
    criterion, 
    n_seq_test, 
    seq_inter_len, 
    seq_extra_len, 
    detailed_test_mse
)

Evaluating test...
model_output.shape, ground_truth_up.shape :  torch.Size([8, 40, 2048, 2]) torch.Size([8, 40, 2048, 2])


In [21]:
# Show the in-horizon and out-of-horizon test errors side by side.
print(
    pred_test_inter_mse,
    pred_test_extra_mse,
    sep=" ",
)

0.01177526917308569 0.01664809323847294


In [28]:
# Display the value of `sub_te_2`.
# NOTE(review): `sub_te_2` is never assigned in the visible cells; it is only
# referenced by the commented-out get_dynamics_data call, so this cell relies
# on leftover kernel state - presumably the test sub-sampling used to build
# `testset_2`; confirm.
sub_te_2

4

In [None]:
# LaTeX table entries for the two test errors printed above: 1.18e-2 & 1.66e-2

0.015738634392619133 0.019326908513903618