In [1]:
from typing import Sequence, Union
import os
from tqdm import tqdm
import joblib

import numpy as np
import pandas as pd
import xarray as xr
from scipy import linalg

import matplotlib.pyplot as plt
import cartopy
import cartopy.crs as ccrs

import torch

import src
from src.attrs import PATHS, GLOBALS
from src import utils

from src.data import loading
from src.train import datasets, losses
from src.models import base, koopman_autoencoder, cnn
from src.tools import plot

  from pandas.core.computation.check import NUMEXPR_INSTALLED


In [2]:
# Configs
plt.style.use('custom.mplstyle')  # project style sheet file (location assumed relative to the working dir — confirm)
torch.set_grad_enabled(False)  # evaluation only: autograd is switched off globally from here on
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
DEVICE  # show which device was selected

device(type='cuda')

In [73]:
# Other globals and functions
MAX_SAMPLES = None  # cap on test samples loaded per dataset; None uses all of them

def weighted_mse(X_true, X_pred, weights):
    """Weighted mean squared error between ``X_true`` and ``X_pred``, as a float.

    Per element, the squared error is scaled by ``weights**2``. It is then summed
    over dims 2 and 3 and divided by ``weights.sum()``, and finally averaged over
    dim 0.

    Args:
        X_true: reference tensor of rank 4. Callers in this notebook pass
            ``(N, 1, ...)`` tensors (a single channel sliced with ``0:1``).
        X_pred: prediction tensor with the same shape as ``X_true``.
        weights: tensor broadcastable against ``X_true``. It is assumed to cover
            dims 2 and 3 (presumably the spatial grid — confirm).

    Returns:
        The scalar error as a Python float.

    Raises:
        RuntimeError: from ``Tensor.item()`` if more than one value remains after
            the reduction (e.g. inputs with more than one channel).
    """
    # NOTE(review): the numerator uses weights**2 but the denominator is
    # weights.sum(), not (weights**2).sum() — confirm this normalisation is intended.
    mse = torch.mean(
        torch.sum(weights**2 * (X_pred - X_true)**2, dim=(2, 3)) / weights.sum(),
        dim=0
    )
    return mse.item()

def make_latex_table(df_style):
    r"""Print LaTeX for a colour-styled pandas Styler using \cellcolor / \textcolor.

    ``df_style.to_latex()`` renders cell CSS as pseudo-commands such as
    ``\background-color#21114e \color#f1f1f1 0.308``. Every styled body cell is
    rewritten to ``\cellcolor[HTML]{21114e} \textcolor{white}{0.308}``. Only the
    text colours ``#f1f1f1`` (-> white) and ``#000000`` (-> black) are translated.
    Cells without a ``\background-color#`` marker (e.g. the index label, or an
    unstyled table) are left untouched. Finally, every underscore is escaped for
    LaTeX.

    The first three lines of the ``to_latex()`` output are never modified. These
    are the ``\begin{tabular}`` line and the two header rows pandas emits when
    the index has a name.

    Args:
        df_style: a pandas ``Styler`` (anything with a ``to_latex()`` method).

    Returns:
        None. The LaTeX source is printed.
    """
    lines = df_style.to_latex().split('\n')
    for i in range(3, len(lines)):
        cells = lines[i].split('&')
        last = len(cells) - 1
        for j, cell in enumerate(cells):
            if '\\background-color#' not in cell:
                continue  # unstyled cell: adding braces here would unbalance them
            converted = (
                cell
                .replace('\\background-color#', '\\cellcolor[HTML]{')
                .replace(' \\color#f1f1f1 ', '} \\textcolor{white}{')
                .replace(" \\\\", "} \\\\")
                .replace(" \\color#000000 ", "} \\textcolor{black}{")
            )
            # Interior cells have no trailing row terminator to receive the closing
            # brace of \textcolor{...}; append it and keep the padding space outside.
            if 0 < j < last:
                converted = (converted + "}").replace(" }", "} ")
            cells[j] = converted
        lines[i] = "&".join(cells)
    # Escape underscores for LaTeX. "\\_" is the valid spelling of the
    # two-character string; a bare "\_" is an invalid escape sequence in Python.
    output = "\n".join(lines).replace("_", "\\_")
    print(output)

## Reconstruction metrics by subproject

In [5]:
# Data
D = 20  # model dimensionality; selects the `pca_ssh_{D}` / `*_ssh.{D}` model files
subprojects = [
    'cnn_pacific_daily_subsampled',
    'cnn_north_atlantic_daily_subsampled',
    'cnn_pacific_monthly',
    'cnn_north_atlantic_monthly'
]

# Weighted test-set MSE for every (subproject, model) pair. NaN-initialised so an
# entry the loop never fills shows up as NaN instead of uninitialised memory.
mse = xr.DataArray(
    np.full((len(subprojects), 3), np.nan),
    dims=('subproject', 'model'),
    coords={'subproject': subprojects, 'model': ['PCA', 'CAE', 'KAE']}
)

# `areas` comes from the shared grid directory, so load it once, not per subproject.
areas = xr.open_dataarray(os.path.join(PATHS['grid'], 'areas.nc'))

for j, subproject in enumerate(subprojects):
    X_test = xr.open_dataarray(os.path.join(PATHS[subproject], 'X_8.nc'))
    X_test = X_test[0:MAX_SAMPLES, 0:1, ...]  # optional sample cap; keep channel 0 only
    X_test = torch.from_numpy(X_test.values)
    # `weights` deliberately stays on the CPU: it is only combined with CPU tensors.
    weights = torch.load(os.path.join(PATHS[subproject], 'weights.pt'))
    mask = xr.open_dataarray(os.path.join(PATHS[subproject], 'mask.nc'))

    # Load models. eval() ensures dropout / batch-norm layers run in inference mode
    # (assumes torch.nn.Module models; a no-op if the loader already set it).
    pca = joblib.load(os.path.join(PATHS[subproject], 'pca', f'pca_ssh_{D}.joblib'))
    cae = base.load_model_from_yaml(os.path.join(PATHS[subproject], 'cae', f'cae_ssh.{D}'))
    kae = base.load_model_from_yaml(os.path.join(PATHS[subproject], 'kae', f'kae_ssh.{D}'))
    cae = cae.to(DEVICE)
    kae = kae.to(DEVICE)
    cae.eval()
    kae.eval()

    # PCA MSE: project the weighted, flattened fields and reconstruct them, then
    # undo the weighting. Cells with zero weight cannot be un-weighted and are set to 0.
    X_shape = X_test.shape
    X_weighted = X_test * weights
    X_weighted = X_weighted.view(X_weighted.shape[0], -1)
    z = pca.transform(X_weighted)
    X_pred = torch.from_numpy(pca.inverse_transform(z)).view(*X_shape)
    X_pred = X_pred / weights
    X_pred = torch.where(weights != 0, X_pred, 0)
    mse[j, 0] = weighted_mse(X_test, X_pred, weights)

    # Autoencoder MSEs (CAE, then the KAE's autoencoder). Samples are run one at a
    # time on DEVICE, and each output is copied back into a CPU tensor shaped like X_test.
    for k, autoencoder in ((1, cae), (2, kae.autoencoder)):
        X_pred = torch.zeros_like(X_test)
        for i, sample in enumerate(tqdm(X_test)):
            sample = sample.to(DEVICE).unsqueeze(0)  # add a batch dimension of 1
            X_pred[i, :] = autoencoder(sample)
        mse[j, k] = weighted_mse(X_test, X_pred, weights)

100%|██████████| 91615/91615 [01:52<00:00, 811.45it/s]
100%|██████████| 91615/91615 [01:55<00:00, 794.12it/s]
100%|██████████| 91615/91615 [01:53<00:00, 810.59it/s]
100%|██████████| 91615/91615 [01:53<00:00, 807.09it/s]
100%|██████████| 3013/3013 [00:03<00:00, 812.99it/s]
100%|██████████| 3013/3013 [00:03<00:00, 833.89it/s]
100%|██████████| 3013/3013 [00:03<00:00, 836.65it/s]
100%|██████████| 3013/3013 [00:03<00:00, 852.42it/s]


In [8]:
# Display the weighted MSE for each (subproject, model) pair computed above.
mse

#### Save the computed MSEs in a DataArray so the slow evaluation loop above need not be re-run

In [83]:
subprojects = [
    'cnn_pacific_daily_subsampled',
    'cnn_north_atlantic_daily_subsampled',
    'cnn_pacific_monthly',
    'cnn_north_atlantic_monthly'
]

# Weighted test MSEs, one row per model and one column per subproject, kept as
# literals so this table is available without re-running the evaluation loop.
mse = xr.DataArray(
    np.array([
        [0.19147219, 0.06475546, 0.16704509, 0.0821541 ],  # PCA
        [0.18455707, 0.06346036, 0.16064069, 0.08647706],  # CAE
        [0.19773802, 0.07815582, 0.23116536, 0.13481773],  # KAE
    ]),
    coords=[('model', ['PCA', 'CAE', 'KAE']), ('subproject', subprojects)],
)
mse

In [84]:
# Human-readable column labels, mapped by subproject name so that reordering
# `subprojects` cannot silently attach the wrong label to a column.
SUBPROJECT_LABELS = {
    'cnn_pacific_daily_subsampled': 'Pacific daily',
    'cnn_north_atlantic_daily_subsampled': 'Atlantic daily',
    'cnn_pacific_monthly': 'Pacific monthly',
    'cnn_north_atlantic_monthly': 'Atlantic monthly',
}
df = (
    mse.to_pandas()
    .rename(columns=SUBPROJECT_LABELS)
    .rename_axis(columns=None)  # keep the table header free of the 'subproject' axis name
)
assert set(df.columns) <= set(SUBPROJECT_LABELS.values()), "unlabelled subproject column"
df

Unnamed: 0_level_0,Pacific daily,Atlantic daily,Pacific monthly,Atlantic monthly
model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
PCA,0.191472,0.064755,0.167045,0.082154
CAE,0.184557,0.06346,0.160641,0.086477
KAE,0.197738,0.078156,0.231165,0.134818


In [85]:
# Format every value to three decimals, print the plain LaTeX source for the
# manuscript, then show the rendered table.
df_style = df.style.format("{:.3f}")
print(df_style.to_latex())
df_style

\begin{tabular}{lrrrr}
 & Pacific daily & Atlantic daily & Pacific monthly & Atlantic monthly \\
model &  &  &  &  \\
PCA & 0.191 & 0.065 & 0.167 & 0.082 \\
CAE & 0.185 & 0.063 & 0.161 & 0.086 \\
KAE & 0.198 & 0.078 & 0.231 & 0.135 \\
\end{tabular}



Unnamed: 0_level_0,Pacific daily,Atlantic daily,Pacific monthly,Atlantic monthly
model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
PCA,0.191,0.065,0.167,0.082
CAE,0.185,0.063,0.161,0.086
KAE,0.198,0.078,0.231,0.135


Percent change in MSE relative to the PCA baseline (first row)

In [72]:
# Relative change of each model's MSE versus the first row (PCA), with explicit sign.
pct_change = df.sub(df.iloc[0]).div(df.iloc[0])
pct_change.style.format("{:+4.2%}")

Unnamed: 0_level_0,Pacific daily,Atlantic daily,Pacific monthly,Atlantic monthly
model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
PCA,+0.00%,+0.00%,+0.00%,+0.00%
CAE,-3.61%,-2.00%,-3.83%,+5.26%
KAE,+3.27%,+20.69%,+38.39%,+64.10%


## Reconstruction metrics by number of dimensions

In [23]:
# Data: channel-0 test fields, weights, mask and areas for the one subproject
# used in the dimensionality sweep below.
SUBPROJECT = 'cnn_pacific_daily_subsampled'
X_test = torch.from_numpy(
    xr.open_dataarray(os.path.join(PATHS[SUBPROJECT], 'X_8.nc'))[0:MAX_SAMPLES, 0:1, ...].values
)
weights = torch.load(os.path.join(PATHS[SUBPROJECT], 'weights.pt'))
mask = xr.open_dataarray(os.path.join(PATHS[SUBPROJECT], 'mask.nc'))
areas = xr.open_dataarray(os.path.join(PATHS['grid'], 'areas.nc'))

---

In [24]:
D_list = [10, 20, 30, 40]  # model dimensionalities to compare

# Weighted test-set MSE for every (D, model) pair. NaN-initialised so an entry the
# loop never fills shows up as NaN instead of uninitialised memory.
mse = xr.DataArray(
    np.full((len(D_list), 3), np.nan),
    dims=('D', 'model'),
    coords={'D': D_list, 'model': ['PCA', 'CAE', 'KAE']}
)

# The weighted, flattened PCA input does not depend on D, so build it once.
# `weights` deliberately stays on the CPU: it is only combined with CPU tensors.
X_shape = X_test.shape
X_weighted = X_test * weights
X_weighted = X_weighted.view(X_weighted.shape[0], -1)

for j, D in enumerate(D_list):
    # Load models. eval() ensures dropout / batch-norm layers run in inference mode
    # (assumes torch.nn.Module models; a no-op if the loader already set it).
    pca = joblib.load(os.path.join(PATHS[SUBPROJECT], 'pca', f'pca_ssh_{D}.joblib'))
    cae = base.load_model_from_yaml(os.path.join(PATHS[SUBPROJECT], 'cae', f'cae_ssh.{D}'))
    kae = base.load_model_from_yaml(os.path.join(PATHS[SUBPROJECT], 'kae', f'kae_ssh.{D}'))
    cae = cae.to(DEVICE)
    kae = kae.to(DEVICE)
    cae.eval()
    kae.eval()

    # PCA MSE: project and reconstruct, then undo the weighting. Cells with zero
    # weight cannot be un-weighted and are set to 0.
    z = pca.transform(X_weighted)
    X_pred = torch.from_numpy(pca.inverse_transform(z)).view(*X_shape)
    X_pred = X_pred / weights
    X_pred = torch.where(weights != 0, X_pred, 0)
    mse[j, 0] = weighted_mse(X_test, X_pred, weights)

    # Autoencoder MSEs (CAE, then the KAE's autoencoder). Samples are run one at a
    # time on DEVICE, and each output is copied back into a CPU tensor shaped like X_test.
    for k, autoencoder in ((1, cae), (2, kae.autoencoder)):
        X_pred = torch.zeros_like(X_test)
        for i, sample in enumerate(tqdm(X_test)):
            sample = sample.to(DEVICE).unsqueeze(0)  # add a batch dimension of 1
            X_pred[i, :] = autoencoder(sample)
        mse[j, k] = weighted_mse(X_test, X_pred, weights)

100%|██████████| 91615/91615 [01:59<00:00, 768.72it/s]
100%|██████████| 91615/91615 [01:56<00:00, 784.98it/s]
100%|██████████| 91615/91615 [01:58<00:00, 775.83it/s]
100%|██████████| 91615/91615 [01:57<00:00, 781.32it/s]
100%|██████████| 91615/91615 [01:56<00:00, 785.65it/s]
100%|██████████| 91615/91615 [01:59<00:00, 765.14it/s]
100%|██████████| 91615/91615 [01:52<00:00, 816.09it/s]
100%|██████████| 91615/91615 [01:56<00:00, 787.07it/s]


In [25]:
# Display the weighted MSE for each dimensionality D and model computed above.
mse

In [26]:
df = mse.to_pandas().T  # rows: model, columns: D

D,10,20,30,40
model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
PCA,0.308147,0.191472,0.136863,0.105674
CAE,0.300906,0.184557,0.131369,0.103439
KAE,0.310973,0.197738,0.148911,0.119119


In [76]:
# Weighted MSE per model (rows) and dimensionality D (columns), kept as literals
# so this table is available without re-running the slow sweep.
mse = xr.DataArray(
    np.array([
        [0.30814709, 0.19147219, 0.13686294, 0.10567361],  # PCA
        [0.30090588, 0.18455707, 0.13136891, 0.10343935],  # CAE
        [0.31097326, 0.19773802, 0.14891145, 0.11911869],  # KAE
    ]),
    coords=[('model', ['PCA', 'CAE', 'KAE']), ('D', [10, 20, 30, 40])],
)
mse

In [77]:
# Rows: model, columns: dimensionality D (the DataArray's dim order).
df = mse.to_pandas()
df

D,10,20,30,40
model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
PCA,0.308147,0.191472,0.136863,0.105674
CAE,0.300906,0.184557,0.131369,0.103439
KAE,0.310973,0.197738,0.148911,0.119119


In [81]:
# Three-decimal table coloured on a fixed magma_r scale from 0.05 to 0.35
# (fixed rather than data-derived limits — presumably to keep tables comparable).
df_style = (
    df.style
    .format(precision=3)
    .background_gradient(cmap='magma_r', vmin=0.05, vmax=0.35)
)
df_style

D,10,20,30,40
model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
PCA,0.308,0.191,0.137,0.106
CAE,0.301,0.185,0.131,0.103
KAE,0.311,0.198,0.149,0.119


In [80]:
# Relative change of each model's MSE versus the first row (PCA). Shown with an
# explicit sign, the same way as the per-subproject percent-change table.
pct_change = (df - df.iloc[0]) / df.iloc[0]
pct_change.style.format("{:+.2%}")

D,10,20,30,40
model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
PCA,0.00%,0.00%,0.00%,0.00%
CAE,-2.35%,-3.61%,-4.01%,-2.11%
KAE,0.92%,3.27%,8.80%,12.72%


In [68]:
# Print the colour-coded LaTeX source for `df_style`.
# NOTE(review): this cell ran as In [68], before the `df_style` cell above (In [81]).
# The saved output (4 decimals) therefore reflects an earlier styling; re-run to refresh.
make_latex_table(df_style)

\begin{tabular}{lrrrr}
D & 10 & 20 & 30 & 40 \\
model &  &  &  &  \\
PCA & \cellcolor[HTML]{21114e} \textcolor{white}{0.3081} & \cellcolor[HTML]{c23b75} \textcolor{white}{0.1915} & \cellcolor[HTML]{f8745c} \textcolor{white}{0.1369} & \cellcolor[HTML]{fea772} \textcolor{black}{0.1057} \\
CAE & \cellcolor[HTML]{2a115c} \textcolor{white}{0.3009} & \cellcolor[HTML]{cc3f71} \textcolor{white}{0.1846} & \cellcolor[HTML]{fa7d5e} \textcolor{white}{0.1314} & \cellcolor[HTML]{feaa74} \textcolor{black}{0.1034} \\
KAE & \cellcolor[HTML]{1e1149} \textcolor{white}{0.3110} & \cellcolor[HTML]{b83779} \textcolor{white}{0.1977} & \cellcolor[HTML]{f2625d} \textcolor{white}{0.1489} & \cellcolor[HTML]{fd9266} \textcolor{black}{0.1191} \\
\end{tabular}

