# Coral Experiments : up-sampling task

This notebook presents the up-sampling capabilities of the CORAL framework. **This notebook requires that models have already been trained.**

In [2]:
!nvidia-smi

import sys
sys.executable

Wed May 10 15:46:45 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.89.02    Driver Version: 525.89.02    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA TITAN Xp     On   | 00000000:06:00.0 Off |                  N/A |
| 23%   33C    P8     8W / 250W |      1MiB / 12288MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

'/data/lise.leboudec/conda/envs/coral/bin/python'

## Initialization 

In [3]:
import torch
import torch.nn as nn
import einops
from pathlib import Path
import os
from torchdiffeq import odeint
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter
from omegaconf import DictConfig, OmegaConf
import omegaconf

from coral.utils.data.load_data import get_dynamics_data, set_seed
from expe.load_models.load_models_inr import load_coral, load_dino
from expe.forwards.forwards_inr import forward_coral, forward_dino
from expe.config.run_names import RUN_NAMES
from coral.utils.data.dynamics_dataset import TemporalDatasetWithCode

### Load config

In [4]:
cuda = torch.cuda.is_available()
if cuda:
    gpu_id = torch.cuda.current_device()
    device = torch.device(f"cuda:{gpu_id}")
else:
    device = torch.device("cpu")
print("device : ", device)

device :  cuda:0


In [5]:
data_dir = "/data/serrano/"
dataset_name = 'navier-stokes-dino'
root_dir = Path(os.getenv("WANDB_DIR")) / dataset_name
sub_from = 4
sub_tr = 0.2
sub_te = 0.2
inr_run_name =  RUN_NAMES[sub_from][sub_tr]["coral"]["inr"]
dyn_run_name =  RUN_NAMES[sub_from][sub_tr]["coral"]["dyn"]
dino_run_name =  RUN_NAMES[sub_from][sub_tr]["dino"]
n_baselines = (inr_run_name != None) + \
                 (dino_run_name != None)

print("inr_run_name : ", inr_run_name)
print("dyn_run_name : ", dyn_run_name)
print("dino_run_name : ", dino_run_name)
print(f"running on {n_baselines} baselines")

inr_run_name :  legendary-sky-4204
dyn_run_name :  jedi-parsec-4274
dino_run_name :  carbonite-droid-4311
running on 2 baselines


In [6]:
if dyn_run_name is not None: 
    cfg_coral_dyn = torch.load(root_dir / "model" / f"{dyn_run_name}.pt")['cfg']
if inr_run_name is not None: 
    cfg_coral_inr = torch.load(root_dir / "inr" / f"{inr_run_name}.pt")['cfg']

upsamplings = ['0', '1', '2', '4']

In [7]:
# data
ntrain = cfg_coral_dyn.data.ntrain
ntest = cfg_coral_dyn.data.ntest
data_to_encode = cfg_coral_dyn.data.data_to_encode
try:
    sub_from = cfg_coral_dyn.data.sub_from
except omegaconf.errors.ConfigAttributeError:
    sub_from = sub_tr # firsts runs don't have a sub_from attribute ie run 4 / 1-1
    sub_tr = 1
    sub_te = 1
seed = cfg_coral_dyn.data.seed
same_grid = cfg_coral_dyn.data.same_grid
setting = cfg_coral_dyn.data.setting
sequence_length_optim = None
sequence_length_in = cfg_coral_dyn.data.sequence_length_in
sequence_length_out = cfg_coral_dyn.data.sequence_length_out

print(f"running in setting {setting} with sampling {sub_from} / {sub_tr} - {sub_te}")

# dino
n_steps = 300
lr_adapt = 0.005

# coral
code_dim_coral = cfg_coral_inr.inr.latent_dim
width_dyn_coral = cfg_coral_dyn.dynamics.width
depth_dyn_coral = cfg_coral_dyn.dynamics.depth
inner_steps = cfg_coral_inr.optim.inner_steps

# optim
batch_size = 1
batch_size_val = 1
criterion = nn.MSELoss()

if dataset_name == 'shallow-water-dino':
    multichannel = True
else:
    multichannel = False

running in setting extrapolation with sampling 4 / 0.2 - 0.2


In [8]:
# experiments
sub_from1 = 4
sub_from2 = 2
sub_from3 = 1

### Load data

In [9]:
set_seed(seed)

(u_train, u_test, grid_tr, grid_te, _, _, _, _, u_train_ext, u_test_ext, grid_tr_ext, grid_te_ext) = get_dynamics_data(
    data_dir,
    dataset_name,
    ntrain,
    ntest,
    sequence_length=sequence_length_optim,
    sub_from=sub_from,
    sub_tr=sub_tr,
    sub_te=sub_te,
    same_grid=same_grid,
    setting=setting,
    sequence_length_in=sequence_length_in,
    sequence_length_out=sequence_length_out
)

(_, _, _, _, _, _, _, _, u_train_up1, u_test_up1, grid_tr_up1, grid_te_up1) = get_dynamics_data(
    data_dir,
    dataset_name,
    ntrain,
    ntest,
    sequence_length=sequence_length_optim,
    sub_from=sub_from1,
    sub_tr=1,
    sub_te=1,
    same_grid=same_grid,
    setting=setting,
    sequence_length_in=sequence_length_in,
    sequence_length_out=sequence_length_out
)

(_, _, _, _, _, _, _, _, u_train_up4, u_test_up4, grid_tr_up4, grid_te_up4) = get_dynamics_data(
    data_dir,
    dataset_name,
    ntrain,
    ntest,
    sequence_length=sequence_length_optim,
    sub_from=sub_from2,
    sub_tr=1,
    sub_te=1,
    same_grid=same_grid,
    setting=setting,
    sequence_length_in=sequence_length_in,
    sequence_length_out=sequence_length_out
)

(_, _, _, _, _, _, _, _, u_train_up16, u_test_up16, grid_tr_up16, grid_te_up16) = get_dynamics_data(
    data_dir,
    dataset_name,
    ntrain,
    ntest,
    sequence_length=sequence_length_optim,
    sub_from=sub_from3,
    sub_tr=1,
    sub_te=1,
    same_grid=same_grid,
    setting=setting,
    sequence_length_in=sequence_length_in,
    sequence_length_out=sequence_length_out
)

# flatten spatial dims
u_train = einops.rearrange(u_train, 'B ... C T -> B (...) C T')
grid_tr = einops.rearrange(grid_tr, 'B ... C T -> B (...) C T')  
u_test = einops.rearrange(u_test, 'B ... C T -> B (...) C T')
grid_te = einops.rearrange(grid_te, 'B ... C T -> B (...) C T')  
if u_train_ext is not None:
    u_train_ext = einops.rearrange(u_train_ext, 'B ... C T -> B (...) C T')
    grid_tr_ext = einops.rearrange(
        grid_tr_ext, 'B ... C T -> B (...) C T')  
    u_test_ext = einops.rearrange(u_test_ext, 'B ... C T -> B (...) C T')
    grid_te_ext = einops.rearrange(
        grid_te_ext, 'B ... C T -> B (...) C T')  
if u_train_up1 is not None:
    u_train_up1 = einops.rearrange(u_train_up1, 'B ... C T -> B (...) C T')
    grid_tr_up1 = einops.rearrange(
        grid_tr_up1, 'B ... C T -> B (...) C T')  
    u_test_up1 = einops.rearrange(u_test_up1, 'B ... C T -> B (...) C T')
    grid_te_up1 = einops.rearrange(
        grid_te_up1, 'B ... C T -> B (...) C T') 
if u_train_up4 is not None:
    u_train_up4 = einops.rearrange(u_train_up4, 'B ... C T -> B (...) C T')
    grid_tr_up4 = einops.rearrange(
        grid_tr_up4, 'B ... C T -> B (...) C T')  
    u_test_up4 = einops.rearrange(u_test_up4, 'B ... C T -> B (...) C T')
    grid_te_up4 = einops.rearrange(
        grid_te_up4, 'B ... C T -> B (...) C T') 
if u_train_up16 is not None:
    u_train_up16 = einops.rearrange(u_train_up16, 'B ... C T -> B (...) C T')
    grid_tr_up16 = einops.rearrange(
        grid_tr_up16, 'B ... C T -> B (...) C T')  
    u_test_up16 = einops.rearrange(u_test_up16, 'B ... C T -> B (...) C T')
    grid_te_up16 = einops.rearrange(
        grid_te_up16, 'B ... C T -> B (...) C T') 

print(
    f"data: {dataset_name}, u_train: {u_train.shape}, u_test: {u_test.shape}")
print(f"grid: grid_tr: {grid_tr.shape}, grid_te: {grid_te.shape}")
if u_train_ext is not None:
    print(
        f"data: {dataset_name}, u_train_ext: {u_train_ext.shape}, u_test_ext: {u_test_ext.shape}")
    print(
        f"grid: grid_tr_ext: {grid_tr_ext.shape}, grid_te_ext: {grid_te_ext.shape}")


: 

: 

In [None]:
n_seq_train = u_train.shape[0]  # 512 en dur
n_seq_test = u_test.shape[0]  # 512 en dur
spatial_size = u_train.shape[1]  # 64 en dur
state_dim = u_train.shape[2]  # N, XY, C, T
coord_dim = grid_tr.shape[2]  # N, XY, C, T
T = u_train.shape[-1]

ntrain = u_train.shape[0]  # int(u_train.shape[0]*T)
ntest = u_test.shape[0]  # int(u_test.shape[0]*T)


In [None]:
trainset = TemporalDatasetWithCode(
    u_train, grid_tr, code_dim_coral, dataset_name, data_to_encode
)
testset = TemporalDatasetWithCode(
    u_test, grid_te, code_dim_coral, dataset_name, data_to_encode
)
if u_train_ext is not None:
    trainset_ext = TemporalDatasetWithCode(
        u_train_ext, grid_tr_ext, code_dim_coral, dataset_name, data_to_encode)
if u_test_ext is not None:
    testset_ext = TemporalDatasetWithCode(
        u_test_ext, grid_te_ext, code_dim_coral, dataset_name, data_to_encode)
if u_train_up1 is not None:
    trainset_up1 = TemporalDatasetWithCode(
        u_train_up1, grid_tr_up1, code_dim_coral, dataset_name, data_to_encode)
if u_test_up1 is not None:
    testset_up1 = TemporalDatasetWithCode(
        u_test_up1, grid_te_up1, code_dim_coral, dataset_name, data_to_encode)
if u_train_up4 is not None:
    trainset_up4 = TemporalDatasetWithCode(
        u_train_up4, grid_tr_up4, code_dim_coral, dataset_name, data_to_encode)
if u_test_up4 is not None:
    testset_up4 = TemporalDatasetWithCode(
        u_test_up4, grid_te_up4, code_dim_coral, dataset_name, data_to_encode)
if u_train_up16 is not None:
    trainset_up16 = TemporalDatasetWithCode(
        u_train_up16, grid_tr_up16, code_dim_coral, dataset_name, data_to_encode)
if u_test_up16 is not None:
    testset_up16 = TemporalDatasetWithCode(
        u_test_up16, grid_te_up16, code_dim_coral, dataset_name, data_to_encode)


In [None]:
# create torch dataset
train_loader = torch.utils.data.DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=1,
    pin_memory=True,
    drop_last=True,
)
test_loader = torch.utils.data.DataLoader(
    testset,
    batch_size=batch_size_val,
    shuffle=True,
    num_workers=1,
    drop_last=True,
)
if u_train_ext is not None:
    train_loader_ext = torch.utils.data.DataLoader(
        trainset_ext,
        batch_size=batch_size_val,
        shuffle=True,
        num_workers=1,
        drop_last=True,
    )
if u_test_ext is not None:
    test_loader_ext = torch.utils.data.DataLoader(
        testset_ext,
        batch_size=batch_size_val,
        shuffle=True,
        num_workers=1,
        drop_last=True,
    )
if u_train_up1 is not None:
    train_loader_up1 = torch.utils.data.DataLoader(
        trainset_up1,
        batch_size=batch_size_val,
        shuffle=True,
        num_workers=1,
        drop_last=True,
    )
if u_test_up1 is not None:
    test_loader_up1 = torch.utils.data.DataLoader(
        testset_up1,
        batch_size=batch_size_val,
        shuffle=True,
        num_workers=1,
        drop_last=True,
    )
if u_train_up4 is not None:
    train_loader_up4 = torch.utils.data.DataLoader(
        trainset_up4,
        batch_size=batch_size_val,
        shuffle=True,
        num_workers=1,
        drop_last=True,
    )
if u_test_up4 is not None:
    test_loader_up4 = torch.utils.data.DataLoader(
        testset_up4,
        batch_size=batch_size_val,
        shuffle=True,
        num_workers=1,
        drop_last=True,
    )
if u_train_up16 is not None:
    train_loader_up16 = torch.utils.data.DataLoader(
        trainset_up16,
        batch_size=batch_size_val,
        shuffle=True,
        num_workers=1,
        drop_last=True,
    )
if u_test_up16 is not None:
    test_loader_up16 = torch.utils.data.DataLoader(
        testset_up16,
        batch_size=batch_size_val,
        shuffle=True,
        num_workers=1,
        drop_last=True,
    )

In [None]:
T = u_train.shape[-1]
if u_test_ext is not None:
    T_EXT = u_test_ext.shape[-1]

# trainset coords of shape (N, Dx, Dy, input_dim, T)
input_dim = grid_tr.shape[-2]
# trainset images of shape (N, Dx, Dy, output_dim, T)
output_dim = u_train.shape[-2]

dt = 1
timestamps_train = torch.arange(0, T, dt).float().cuda()
timestamps_ext = torch.arange(0, T_EXT, dt).float().cuda()


### Load models

In [None]:
inr, alpha, dyn, z_mean, z_std = load_coral(root_dir, inr_run_name, dyn_run_name, data_to_encode, input_dim, output_dim, trainset, testset, multichannel, code_dim_coral, width_dyn_coral, depth_dyn_coral, inner_steps)
net_dec, net_dyn, states_params, code_dim_dino = load_dino(root_dir, dino_run_name)

Train, average loss: 7.169906737658494e-05
Test, average loss: 0.027545505203306675
dict_keys(['net.bilinear.0.A', 'net.bilinear.0.B', 'net.bilinear.0.bias', 'net.bilinear.1.A', 'net.bilinear.1.B', 'net.bilinear.1.bias', 'net.bilinear.2.A', 'net.bilinear.2.B', 'net.bilinear.2.bias', 'net.bilinear.3.A', 'net.bilinear.3.B', 'net.bilinear.3.bias', 'net.output_bilinear.weight', 'net.output_bilinear.bias', 'net.filters.0.weight', 'net.filters.1.weight', 'net.filters.2.weight', 'net.filters.3.weight'])
dict_keys(['net.net.0.weight', 'net.net.0.bias', 'net.net.1.beta', 'net.net.2.weight', 'net.net.2.bias', 'net.net.3.beta', 'net.net.4.weight', 'net.net.4.bias', 'net.net.5.beta', 'net.net.6.weight', 'net.net.6.bias'])


## Forwards

In [None]:
idx = 0

batch = next(iter(test_loader_ext))
pred_coral0 = forward_coral(inr, dyn, batch, inner_steps, alpha, True, timestamps_ext, z_mean, z_std, dataset_name).cpu().detach().numpy()
pred_dino0 = forward_dino(net_dec, net_dyn, batch, n_seq_train, states_params, code_dim_dino, n_steps, lr_adapt, device, criterion, timestamps_ext, save_best=True, method="rk4").cpu().detach().numpy()
batch = next(iter(test_loader_up1))
pred_coral1 = forward_coral(inr, dyn, batch, inner_steps, alpha, True, timestamps_ext, z_mean, z_std, dataset_name).cpu().detach().numpy()
pred_dino1 = forward_dino(net_dec, net_dyn, batch, n_seq_train, states_params, code_dim_dino, n_steps, lr_adapt, device, criterion, timestamps_ext, save_best=True, method="rk4").cpu().detach().numpy()
batch = next(iter(test_loader_up4))
pred_coral4 = forward_coral(inr, dyn, batch, inner_steps, alpha, True, timestamps_ext, z_mean, z_std, dataset_name).cpu().detach().numpy()
pred_dino4 = forward_dino(net_dec, net_dyn, batch, n_seq_train, states_params, code_dim_dino, n_steps, lr_adapt, device, criterion, timestamps_ext, save_best=True, method="rk4").cpu().detach().numpy()
# batch = next(iter(test_loader_up16))
# pred_coral16 = forward_coral(inr, dyn, batch, inner_steps, alpha, True, timestamps_ext, z_mean, z_std, dataset_name).cpu().detach().numpy()
# pred_dino16 = forward_dino(net_dec, net_dyn, batch, n_seq_train, states_params, code_dim_dino, n_steps, lr_adapt, device, criterion, timestamps_ext, save_best=True, method="rk4").cpu().detach().numpy()


NameError: name 'test_loader_ext' is not defined

## Vizualiation

In [None]:
idx = [0]
time2show = 20
path = '/home/lise.leboudec/project/coral/xp/vis/'

b0 = trainset_ext[idx]
b1 = trainset_up1[idx]
b4 = trainset_up4[idx]
b16 = trainset_up16[idx]

x = b0[2][0, ..., 0, time2show]
y = b0[2][0, ..., 1, time2show]

print("pred_coral0.shape : ", pred_coral0.shape)
print("pred_coral1.shape : ", pred_coral1.shape)
print("pred_coral4.shape : ", pred_coral4.shape)
# print("pred_coral16.shape : ", pred_coral16.shape)

fig, axs = plt.subplots(3, 4, figsize=(16, 4))
axs[0, 0].scatter(y, -x, 50, b0[0][0, ..., time2show], edgecolor="w",
    lw=0.2,)
axs[0, 1].imshow(b1[0][0, ..., time2show].reshape(64, 64))
# axs[0, 0].set_title(f"Prediction, rel mse = {100*pred_test_mse:.2f}%", fontsize=8)
axs[0, 2].imshow(b4[0][0, ..., time2show].reshape(128, 128))
# axs[0, 1].set_title(f"Ground truth", fontsize=8)
axs[0, 3].imshow(b16[0][0, ..., time2show].reshape(256, 256))

axs[1, 0].scatter(x, -y, 50, pred_coral0[0][0, ..., time2show], edgecolor="w",
    lw=0.2,)
axs[1, 1].imshow(pred_coral1[0][0, ..., time2show])
# axs[1, 0].set_title(f"Prediction, rel mse = {100*pred_test_mse:.2f}%", fontsize=8)
axs[1, 2].imshow(pred_coral4[0][0, ..., time2show])
# axs[1, 1].set_title(f"Ground truth", fontsize=8)
# axs[1, 3].imshow(pred_coral16[0][0, ..., time2show])

axs[1, 0].scatter(x, -y, 50, pred_dino0[0][0, ..., time2show], edgecolor="w",
    lw=0.2,)
axs[1, 1].imshow(pred_dino1[0][0, ..., time2show])
# axs[1, 0].set_title(f"Prediction, rel mse = {100*pred_test_mse:.2f}%", fontsize=8)
axs[1, 2].imshow(pred_dino4[0][0, ..., time2show])
# axs[1, 1].set_title(f"Ground truth", fontsize=8)
# axs[1, 3].imshow(pred_dino16[0][0, ..., time2show])

# plt.savefig(os.path.join(plot_dir, 'ns-upsampling64to256.png'), bbox_inches='tight', dpi=300)

"""
fig, ax = plt.subplots(1, 2)

def animate(i):
    ax[0].imshow(b_te[0][0, ..., i])
    ax[1].imshow(pred_te_coral[0][0, ..., i])
    return ax

ani = FuncAnimation(fig, animate, interval=40, repeat=True, frames=len(T_EXT))
ani.save(path + "vis_test_exp.gif", dpi=300, writer=PillowWriter(fps=25))"""

NameError: name 'pred_coral0' is not defined