In [2]:
from pathlib import Path
from pickletools import OpcodeInfo
import numpy as np
import os
import torch.nn as nn
import einops
import yaml
from omegaconf import DictConfig, OmegaConf
from coral.utils.data.dynamics_dataset import (KEY_TO_INDEX, TemporalDatasetWithCode)
from coral.utils.models.load_inr import create_inr_instance, load_inr_model
from coral.utils.data.load_data import get_dynamics_data, set_seed
from coral.utils.data.load_modulations import load_dynamics_modulations
from coral.utils.models.get_inr_reconstructions import get_reconstructions
from coral.utils.models.scheduling import ode_scheduling
from torchdiffeq import odeint
from dynamics_modeling.eval import batch_eval_loop
from coral.mlp import MLP, Derivative

In [3]:
import matplotlib.pyplot as plt
import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import torch

import matplotlib as mpl
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
from scipy.interpolate import interp2d
import numpy as np

fragrant-bee-4658 # 0.0125
iconic-firefly-4657 # 5%
atomic-surf-4656 # 100%

In [53]:
# Read the default ODE config. The context manager closes the file handle
# deterministically (a bare open() call would leave that to the garbage collector).
with open("config/ode.yaml") as config_file:
    cfg = DictConfig(yaml.safe_load(config_file))

# The dataset name is read from the config, then overridden by hand.
dataset_name = cfg.data.dataset_name
dataset_name = 'navier-stokes-dino'

# dyn
load_dyn_run = cfg.dynamics.run_name

# Fail with an explicit message when the environment variable is missing,
# instead of the opaque TypeError raised by Path(None).
wandb_dir = os.getenv("WANDB_DIR")
if wandb_dir is None:
    raise RuntimeError("The WANDB_DIR environment variable must be set")
root_dir = Path(wandb_dir) / dataset_name
print(root_dir)

# Manual override of the dynamics run to load.
load_dyn_run = 'smooth-firebrand-5016'

/home/kassai/wandb_logs/navier-stokes-dino


In [54]:
# Load dyn_model

# NOTE(review): torch.load unpickles the file; only load checkpoints that come
# from a trusted source.
tmp_dyn = torch.load(root_dir / "model" / f"{load_dyn_run}.pt")
# The config stored alongside the checkpoint replaces any previously loaded `cfg`.
cfg = tmp_dyn['cfg']
model_state = tmp_dyn['model']

# data
data_dir = cfg.data.dir
ntrain = cfg.data.ntrain
ntest = cfg.data.ntest
data_to_encode = cfg.data.data_to_encode
sub_tr = cfg.data.sub_tr
# Fall back to 1 when the checkpoint config has no sub_from value.
sub_from = cfg.data.sub_from if cfg.data.sub_from is not None else 1
sub_te = cfg.data.sub_te
seed = cfg.data.seed
same_grid = cfg.data.same_grid
# Sequence lengths are fixed here rather than read from the config.
seq_inter_len = 20
seq_extra_len = 20

# optim
batch_size = cfg.optim.batch_size
# Validation batch size defaults to the configured training batch size.
batch_size_val = (
    batch_size if cfg.optim.batch_size_val is None else cfg.optim.batch_size_val
)
# Manual override applied after batch_size_val was derived, so only
# `batch_size` itself is affected.
batch_size = 1

In [55]:
# Show the train / test sub-sampling settings read from the checkpoint config.
print(f"sub_tr, sub_te :  {sub_tr} {sub_te}")

sub_tr, sub_te :  0.2 0.2


In [56]:
# inr
load_run_name, inner_steps = cfg.inr.run_name, cfg.inr.inner_steps

# Prefer a plain dict copy of the per-field run mapping; when the stored value
# cannot be converted (dict() raises TypeError), keep it exactly as stored.
_raw_run_dict = cfg.inr.run_dict
try:
    load_run_dict = dict(_raw_run_dict)
except TypeError:
    load_run_dict = _raw_run_dict

In [57]:
# loading

# Artifact layout: $WANDB_DIR/<dataset_name>[/<data_to_encode>]/{inr,modulations,model}
artifact_dir = Path(os.getenv("WANDB_DIR")) / dataset_name
if data_to_encode is not None:
    artifact_dir = artifact_dir / data_to_encode

inr_dir = artifact_dir / "inr"
modulations_dir = artifact_dir / "modulations"
model_dir = artifact_dir / "model"

# we need the latent dim and the sub_tr used for training
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
if load_run_name is not None:
    # A single INR run encodes the data.
    multichannel = False
    tmp = torch.load(root_dir / "inr" / f"{load_run_name}.pt")

elif load_run_dict is not None:
    # One INR run per encoded field; the latent size is read from the first entry.
    multichannel = True
    tmp_data_to_encode = list(load_run_dict.keys())[0]
    tmp_run_name = list(load_run_dict.values())[0]
    tmp = torch.load(root_dir / tmp_data_to_encode /
                     "inr" / f"{tmp_run_name}.pt")

else:
    # Without this guard `multichannel` and `latent_dim` would stay undefined
    # and a later cell would fail with a NameError far from the cause.
    raise ValueError(
        "The checkpoint config defines neither inr.run_name nor inr.run_dict"
    )

latent_dim = tmp["cfg"].inr.latent_dim

In [58]:
# 1-based position of the test trajectory that is rolled out below
# (previously drawn with np.random.randint(1, 8)).
random = 2
preds = []
truths = []
#sub_tes = [0.05, 4, 2, 1]
sub_tes = [sub_te]
set_seed(seed)
for i, sub_te in enumerate(sub_tes):
    (u_train, u_train_eval, u_test, grid_tr, grid_tr_extra, grid_te) = get_dynamics_data(
        data_dir,
        dataset_name,
        ntrain,
        ntest,
        seq_inter_len=seq_inter_len,
        seq_extra_len=seq_extra_len,
        sub_from=sub_from,
        sub_tr=sub_tr,
        sub_te=sub_te,
        same_grid=same_grid,
    )
    print(
        f"data: {dataset_name}, u_train: {u_train.shape}, u_train_eval: {u_train_eval.shape}, u_test: {u_test.shape}")
    print(f"grid: grid_tr: {grid_tr.shape}, grid_tr_extra: {grid_tr_extra.shape}, grid_te: {grid_te.shape}")

    trainset = TemporalDatasetWithCode(
        u_train, grid_tr, latent_dim, dataset_name, data_to_encode
    )
    trainset_extra = TemporalDatasetWithCode(
        u_train_eval, grid_tr_extra, latent_dim, dataset_name, data_to_encode
    )
    testset = TemporalDatasetWithCode(
        u_test, grid_te, latent_dim, dataset_name, data_to_encode
    )

    # size of the first dimension of the codes stored by the train set
    ntrain = trainset.z.shape[0]

    # size of the first dimension of the codes stored by the test set
    ntest = testset.z.shape[0]

    # sequence length (time is the last axis)
    T_train = u_train.shape[-1]
    T_test = u_test.shape[-1]

    dt = 1
    timestamps_train = torch.arange(0, T_train, dt).float().cuda()
    timestamps_test = torch.arange(0, T_test, dt).float().cuda()

    # coordinate dimension: second-to-last axis of the grid tensor
    input_dim = grid_tr.shape[-2]
    # number of channels: second-to-last axis of the data tensor
    output_dim = u_train.shape[-2]

    if load_run_name is not None:
        # Single INR: one code per trajectory and time step.
        inr, alpha = load_inr_model(
            root_dir / "inr",
            load_run_name,
            data_to_encode,
            input_dim=input_dim,
            output_dim=output_dim,
        )
        c = 1
        modulations = load_dynamics_modulations(
            trainset,
            trainset_extra,
            testset,
            inr,
            root_dir / "modulations",
            load_run_name,
            inner_steps=inner_steps,
            alpha=alpha,
            batch_size=2,
            data_to_encode=None,
            try_reload=False,
        )
        z_train = modulations["z_train"]
        z_train_extra = modulations["z_train_extra"]
        z_test = modulations["z_test"]
        # Normalisation statistics come from the training codes only and are
        # applied to every split.
        z_mean = einops.rearrange(z_train, "b l t -> (b t) l").mean(0).reshape(1, latent_dim, 1)
        z_std = einops.rearrange(z_train, "b l t -> (b t) l").std(0).reshape(1, latent_dim, 1)
        z_train = (z_train - z_mean) / z_std
        z_train_extra = (z_train_extra - z_mean) / z_std
        z_test = (z_test - z_mean) / z_std

    elif load_run_dict is not None:
        # One INR per encoded field; codes are stacked along a channel axis.
        inr_dict = {}
        z_mean = {}
        z_std = {}
        c = len(list(load_run_dict.keys()))
        z_train = torch.zeros(ntrain, latent_dim, c, T_train)
        z_train_extra = torch.zeros(ntrain, latent_dim, c, T_test)
        z_test = torch.zeros(ntest, latent_dim, c, T_test)

        for to_encode in list(load_run_dict.keys()):
            tmp_name = load_run_dict[to_encode]
            output_dim = 1
            inr, alpha = load_inr_model(
                root_dir / to_encode / "inr",
                tmp_name,
                to_encode,
                input_dim=input_dim,
                output_dim=output_dim,
            )

            trainset.set_data_to_encode(to_encode)
            trainset_extra.set_data_to_encode(to_encode)
            testset.set_data_to_encode(to_encode)

            modulations = load_dynamics_modulations(
                trainset,
                trainset_extra,
                testset,
                inr,
                root_dir / to_encode / "modulations",
                tmp_name,
                inner_steps=inner_steps,
                alpha=alpha,
                batch_size=1,
                data_to_encode=to_encode,
                try_reload=False,
            )
            inr_dict[to_encode] = inr
            z_tr = modulations["z_train"]
            z_tr_extra = modulations["z_train_extra"]
            z_te = modulations["z_test"]
            # Per-field statistics, computed on the training codes only.
            z_m = einops.rearrange(z_tr, "b l t -> (b t) l").mean(0).reshape(1, latent_dim, 1)
            z_s = einops.rearrange(z_tr, "b l t -> (b t) l").std(0).reshape(1, latent_dim, 1)
            z_mean[to_encode] = z_m
            z_std[to_encode] = z_s
            channel_index = KEY_TO_INDEX[dataset_name][to_encode]
            # z_train used to be left as zeros in this branch; fill it like the
            # other two splits so it holds the normalised training codes.
            z_train[..., channel_index, :] = (z_tr - z_m) / z_s
            z_train_extra[..., channel_index, :] = (z_tr_extra - z_m) / z_s
            z_test[..., channel_index, :] = (z_te - z_m) / z_s

        # concat the code
        trainset_extra.set_data_to_encode(None)
        testset.set_data_to_encode(None)
        # rename inr_dict <- inr
        inr = inr_dict

    trainset_extra.z = z_train_extra
    testset.z = z_test

    print('ztrain_extra', z_train_extra.shape, z_train_extra.mean(), z_train_extra.std())
    print('ztest', z_test.shape, z_test.mean(), z_test.std())

    train_extra_loader = torch.utils.data.DataLoader(
        trainset_extra,
        batch_size=batch_size,
        shuffle=True,
        num_workers=1,
        pin_memory=True,
    )
    test_loader = torch.utils.data.DataLoader(
        testset,
        batch_size=batch_size_val,
        shuffle=True,
        num_workers=1,
    )
    # Load model
    hidden = cfg.dynamics.width
    depth = cfg.dynamics.depth

    model = Derivative(c, z_train.shape[1], hidden, depth).cuda()
    model.load_state_dict(model_state)

    model.eval()
    # Index the dataset once (it used to be indexed three times for the same
    # sample); the length-1 slice keeps a leading batch dimension.
    sample = test_loader.dataset[random - 1: random]
    ground_truth, modulations, coords = sample[0], sample[1], sample[2]

    ground_truth = ground_truth.cuda()
    modulations = modulations.cuda()
    coords = coords.cuda()

    if multichannel:
        # Fold the channel axis into the latent axis before the rollout.
        modulations = einops.rearrange(modulations, "b l c t -> b (l c) t")

    # Pure inference: no autograd graph is needed for the rollout or the decoding.
    with torch.no_grad():
        z_pred = ode_scheduling(odeint, model, modulations, timestamps_test, 0)
        pred = get_reconstructions(
            inr, coords, z_pred, z_mean, z_std, dataset_name
        )
    pred = pred.cpu()
    ground_truth = ground_truth.cpu()
    preds.append(pred)
    truths.append(ground_truth)

    if i == 0:
        # Keep the coordinates of the first sub-sampling setting.
        coords_initial = coords

data: navier-stokes-dino, u_train: torch.Size([256, 3276, 1, 20]), u_train_eval: torch.Size([256, 3276, 1, 40]), u_test: torch.Size([16, 3276, 1, 40])
grid: grid_tr: torch.Size([256, 3276, 2, 20]), grid_tr_extra: torch.Size([256, 3276, 2, 40]), grid_te: torch.Size([16, 3276, 2, 40])
Train, average loss: 0.000693938372478442
Train extra, average loss: 0.00174077078463597
Test, average loss: 0.0018292279710294679
ztrain_extra torch.Size([256, 128, 40]) tensor(0.0050) tensor(1.0676)
ztest torch.Size([16, 128, 40]) tensor(0.0128) tensor(1.0706)


In [59]:
def gif_baselines(plot_dir, title, preds, truths, channel, view=(100., 0.)):
    """Save an animated ground-truth vs. reconstruction comparison on a globe.

    The figure has three rows (one per entry of ``preds`` / ``truths``) and two
    columns: ground truth on the left, reconstruction on the right.

    Args:
        plot_dir: Directory the animation is written to.
        title: File name of the animation inside ``plot_dir``.
        preds: Sequence with at least three prediction tensors. Each is indexed
            as ``pred[0, ..., channel, i]``; the last axis supplies the frames.
        truths: Sequence with at least three ground-truth tensors, indexed like
            ``preds``. ``shape[1]`` and ``shape[2]`` give the number of latitude
            and longitude points of each row's grid.
        channel: Index of the channel to draw.
        view: Arguments forwarded to ``ccrs.Orthographic``.

    Returns:
        None. The animation is saved to ``os.path.join(plot_dir, title)``.
    """
    row_titles = ('Sub_te = 4', 'Sub_te = 2', 'Sub_te = 1')
    # Frames before this index are drawn with 'viridis', later ones with 'plasma'.
    n_in_t = 20

    T = preds[0].shape[-1]
    proj = ccrs.Orthographic(*view)

    # add subfigure per subplot
    fig = plt.figure(constrained_layout=True)
    subfigs = fig.subfigures(nrows=len(row_titles), ncols=1)

    rows = []
    for row, (subfig, row_title) in enumerate(zip(subfigs, row_titles)):
        n_lat, n_lon = truths[row].shape[1], truths[row].shape[2]
        latitude = torch.linspace(90.0, -90.0, n_lat)
        longitude = torch.linspace(0.0, 360.0 - (360.0 / n_lon), n_lon)
        longitude_grid, latitude_grid = torch.meshgrid(longitude, latitude, indexing="xy")

        # Each row gets its own title. Previously the first title was written
        # to subfigs[1] and immediately overwritten, leaving row 0 unlabelled.
        subfig.suptitle(row_title)
        ax_truth, ax_pred = subfig.subplots(1, 2, subplot_kw={'projection': proj})

        # Titles and gridlines are static decorations, so they are set once.
        # They used to be re-applied on every frame, which stacked a new set of
        # gridline artists per frame; the one-off bottom-left
        # 'Ground truth, In-t' title is dropped so all rows are labelled alike.
        ax_truth.set_title('Ground truth')
        ax_pred.set_title('Reconstruction')
        for axis in (ax_truth, ax_pred):
            axis.gridlines(linewidth=1, color='black', alpha=0.05)

        rows.append((longitude_grid, latitude_grid,
                     ((ax_truth, truths[row]), (ax_pred, preds[row]))))

    ims = []
    for i in range(T):
        cmap = 'viridis' if i < n_in_t else 'plasma'
        frame = []
        # Artist order per frame: (truth, prediction) for row 0, 1, 2.
        for longitude_grid, latitude_grid, axis_fields in rows:
            for axis, field in axis_fields:
                frame.append(
                    axis.pcolormesh(longitude_grid, latitude_grid, field[0, ..., channel, i],
                                    transform=ccrs.PlateCarree(), cmap=cmap, shading='auto')
                )
        ims.append(frame)

    ani = animation.ArtistAnimation(fig, ims, interval=150, blit=True,
                                    repeat_delay=1000)

    ani.save(os.path.join(plot_dir, title),
             dpi=300)
    
def get_grid(mask_prob, title):
    """Draw a random boolean point mask on a globe and save it as a PNG.

    Reads the module-level ``grid_te`` tensor. A random subset of flat indices
    (a fraction ``mask_prob`` of 128*256) is set to 1 in a flattened copy of
    ``grid_te[0, ..., 0, 0]``; the plotted mask is ``values == 1`` reshaped to
    (128, 256).

    Args:
        mask_prob: Fraction of the 128*256 flat indices that are set to 1.
        title: File name of the PNG, joined onto a hard-coded output directory.
    """
    n_lat, n_lon = 128, 256
    total_points = n_lat * n_lon

    values = grid_te[0, ..., 0, 0].numpy()

    # Random subset of flat indices, sorted in ascending order.
    shuffled = torch.randperm(total_points)
    n_selected = int(mask_prob * len(shuffled))
    selected = shuffled[:n_selected].clone().sort()[0]

    # flatten() returns a copy, so the assignment below works on that copy.
    flat_values = values.flatten()
    flat_values[selected] = 1
    mask = (flat_values == 1).reshape(n_lat, n_lon)
    print(np.unique(mask, return_counts=True))

    fig, ax = plt.subplots(
        1, 1,
        figsize=(15, 8),
        subplot_kw={'projection': ccrs.Orthographic(-10, 30)},
    )
    lat_values = torch.linspace(90.0, -90.0, n_lat)
    lon_values = torch.linspace(0.0, 360.0 - (360.0 / n_lon), n_lon)
    lon_mesh, lat_mesh = torch.meshgrid(lon_values, lat_values, indexing="xy")

    ax.pcolormesh(lon_mesh, lat_mesh, mask,
                  transform=ccrs.PlateCarree(), shading='auto')
    ax.gridlines(linewidth=1, color='black', alpha=0.1)
    plt.show()

    out_path = os.path.join('/home/kassai/code/coral/visualizations/table/', title)
    fig.savefig(out_path, format='png')

def plot_globe(truth, timestep, title, channel=1):
    """Show one channel of one time step on an orthographic globe.

    Args:
        truth: Tensor indexed as ``truth[0, ..., channel, timestep]``; the
            selected slice is drawn on a 128 x 256 latitude/longitude mesh.
        timestep: Index along the last axis of ``truth``.
        title: Currently unused (the save call that consumed it is disabled).
        channel: Index along the second-to-last axis of ``truth``.
    """
    n_lat, n_lon = 128, 256

    fig, ax = plt.subplots(
        1, 1,
        figsize=(15, 8),
        subplot_kw={'projection': ccrs.Orthographic(-10, 30)},
    )
    lat_values = torch.linspace(90.0, -90.0, n_lat)
    lon_values = torch.linspace(0.0, 360.0 - (360.0 / n_lon), n_lon)
    lon_mesh, lat_mesh = torch.meshgrid(lon_values, lat_values, indexing="xy")

    snapshot = truth[0, ..., channel, timestep]
    ax.pcolormesh(lon_mesh, lat_mesh, snapshot,
                  transform=ccrs.PlateCarree(), cmap='twilight', shading='auto')
    ax.gridlines(linewidth=1, color='black', alpha=0.1)

    plt.show()
    # Saving is disabled:
    # fig.savefig(os.path.join('/home/kassai/code/coral/visualizations/table/5%/', title), format='png')

In [60]:
#plot_globe(preds[0], 39, 'timestep_39.png', 1)

In [61]:
#gif_baselines('/home/kassai/code/coral/visualizations/', 'predictions_coral_5%.gif', preds[1:], truths[1:], channel = 1, view = (-10, 45))

In [62]:
class DetailedMSE():
    """Accumulate a per-key mean squared error over successive batches.

    One running total is kept per key under the name ``"{key}_{mode}_mse"``.

    Args:
        keys: Names of the fields to track; each must be a key of
            ``KEY_TO_INDEX[dataset_name]`` when ``aggregate`` is called.
        dataset_name: Dataset entry used to look up the index of each key.
        mode: Label embedded in the dictionary keys (e.g. "train", "test").
        n_trajectories: Divisor applied to the totals by ``get_dic``.
    """

    def __init__(self, keys, dataset_name="shallow-water-dino", mode="train", n_trajectories=256):
        self.keys = keys
        self.mode = mode
        self.dataset_name = dataset_name
        self.n_trajectories = n_trajectories
        self.reset_dic()

    def reset_dic(self):
        """Reset every running total to zero."""
        self.dic = {f"{key}_{self.mode}_mse": 0 for key in self.keys}

    def aggregate(self, u_pred, u_true):
        """Add one batch to the running totals.

        For each key, the mean squared error of the slice ``[..., idx, :]`` is
        weighted by the batch size (``u_pred.shape[0]``) and added to the total.
        """
        n_samples = u_pred.shape[0]
        for key in self.keys:
            idx = KEY_TO_INDEX[self.dataset_name][key]
            self.dic[f"{key}_{self.mode}_mse"] += (
                (u_pred[..., idx, :] - u_true[..., idx, :])**2).mean()*n_samples

    def get_dic(self):
        """Return the totals divided by ``n_trajectories``.

        A new dictionary is returned and the running totals are left untouched,
        so calling this method repeatedly (e.g. when a notebook cell is re-run)
        yields the same values. Previously the totals were divided in place on
        every call, so a second call returned values divided twice.
        """
        return {
            name: total / self.n_trajectories
            for name, total in self.dic.items()
        }
    
# Per-field MSE trackers exist only in the multichannel setting; otherwise all
# three names are bound to None.
if multichannel:
    detailed_train_mse, detailed_train_eval_mse, detailed_test_mse = (
        DetailedMSE(
            list(KEY_TO_INDEX[dataset_name].keys()),
            dataset_name,
            mode=split_mode,
            n_trajectories=split_size,
        )
        for split_mode, split_size in (
            ("train", ntrain),
            ("train_extra", ntrain),
            ("test", ntest),
        )
    )
else:
    detailed_train_mse = detailed_train_eval_mse = detailed_test_mse = None
    
# Eval Coral on train and test set

(pred_train_inter_mse, code_train_inter_mse, pred_train_extra_mse,
 code_train_extra_mse, pred_train_mse, detailed_train_eval_mse) = batch_eval_loop(
    model, inr, train_extra_loader,
    timestamps_test, detailed_train_eval_mse,
    ntrain, multichannel, z_mean, z_std,
    dataset_name, T_train
)

# The test sequence is split into an in-range and an out-of-range part only
# when it is not the same length as the training sequence.
test_has_extrapolation = T_train != T_test
if test_has_extrapolation:
    (pred_test_inter_mse, code_test_inter_mse, pred_test_extra_mse,
     code_test_extra_mse, pred_test_mse, detailed_test_mse) = batch_eval_loop(
        model, inr, test_loader,
        timestamps_test, detailed_test_mse,
        ntest, multichannel, z_mean, z_std,
        dataset_name, T_train
    )
else:
    pred_test_mse, code_test_mse, detailed_test_mse = batch_eval_loop(
        model, inr, test_loader,
        timestamps_test, detailed_test_mse,
        ntest, multichannel, z_mean, z_std,
        dataset_name, None
    )

print("pred_train_inter_mse : ", pred_train_inter_mse.item())
print('pred_train_extra_mse :' , pred_train_extra_mse.item())

# The inter/extra test metrics exist only in the first branch above; printing
# them unconditionally raised a NameError when T_train == T_test.
if test_has_extrapolation:
    print("pred_test_inter_mse : ", pred_test_inter_mse.item())
    print('pred_test_extra_mse :' , pred_test_extra_mse.item())
else:
    print("pred_test_mse : ", float(pred_test_mse))