In [None]:
import sys
sys.path.append('..')
from pathlib import Path
import numpy as np
import h5py
import jax
import jax.numpy as jnp
import yaml
import matplotlib.pyplot as plt
plt.style.use('../flowrec/utils/ppt.mplstyle')

import flowrec.training_and_states as state_utils
import flowrec.data as data_utils
import flowrec.physics_and_derivatives as derivatives

from mpl_toolkits.axes_grid1 import ImageGrid, make_axes_locatable
from scipy.interpolate import RBFInterpolator
from flowrec.utils import simulation
from flowrec import losses
from flowrec.utils.py_helper import slice_from_tuple
from flowrec.utils.system import set_gpu
from flowrec.utils.myplots import truegrey, create_custom_colormap
# Colour map used to colour the per-component line plots further down (cmap(i) for component i).
cmap = create_custom_colormap('trafficlight')
# NOTE(review): presumably selects GPU 0 and limits JAX to half of its memory — confirm in flowrec.utils.system.set_gpu.
set_gpu(0,0.5)
# Uncomment to force JAX onto the CPU instead.
# jax.config.update('jax_platform_name', 'cpu')

In [None]:
# Training run to analyse (swap with the commented line to inspect the other run).
# results_dir = Path('../local_results/3dkol/kmplffmse250217185002') 
results_dir = Path('../local_results/3dkol/kmplffmse250218181248') 

# NOTE(review): yaml.UnsafeLoader can construct arbitrary Python objects (presumably
# required because the config stores objects such as `cfg.case`). Only open configs
# written by our own training runs.
_config_path = Path(results_dir, 'config.yml')
with open(_config_path, 'r') as f:
    cfg = yaml.load(f, Loader=yaml.UnsafeLoader)

# Prefix the stored data path with '.'; presumably the stored path is relative to the
# directory the training script ran from, one level above this notebook. TODO confirm.
cfg.data_config.update({'data_dir': '.' + cfg.data_config.data_dir})
# The configured inlet slice with its last entry replaced by (None, None, None),
# presumably meaning "every index" along the last axis (used for the volume test).
pressure_inlet_plane = cfg.data_config['pressure_inlet_slice'][:-1] + ((None, None, None),)
# cfg.data_config.update({'data_dir':'../local_data/kolmogorov/dim2_re34_k32_f4_dt1_grid128_25619.h5'})
print(cfg.data_config.data_dir)

# Loss histories are only available once results.h5 has been written.
_result_file = Path(results_dir, 'results.h5')
inprogress = not _result_file.exists()
if not inprogress:
    _loss_names = ('loss_train', 'loss_val', 'loss_div', 'loss_momentum', 'loss_sensors')
    with h5py.File(_result_file, 'r') as hf:
        loss_train, loss_val, loss_div, loss_momentum, loss_sensors = (
            np.array(hf.get(_name)) for _name in _loss_names
        )

datacfg = cfg.data_config
traincfg = cfg.train_config

In [None]:
# Dump the training, data and model configuration sections of this run.
for _section in (traincfg, datacfg, cfg.model_config):
    print(_section.to_dict())

In [None]:
if not inprogress:
    # Total training/validation losses (top) and the momentum and sensor terms (bottom).
    # plt.subplots creates its own figure, so no separate plt.figure() call is needed
    # (it would only leave an empty extra figure behind).
    fig, axes = plt.subplots(2, 1)
    axes[0].plot(loss_train, label='Training loss')
    axes[0].plot(loss_val, label='Validation loss')
    axes[0].set_yscale('log')
    axes[0].legend()
    axes[1].plot(loss_momentum, label='Momentum loss')
    axes[1].plot(loss_sensors, label='Sensor loss')
    axes[1].set_yscale('log')
    axes[1].legend()
    plt.show()

# Load Data

## Load training and validation data

In [None]:
# Load the training/validation data for this case.
data, datainfo = cfg.case.dataloader(datacfg)
print(data.keys())
if datacfg.shuffle:
    # Recreate the shuffle used in training (same seed), keeping the inverse
    # permutation so shuffled snapshots can be put back in time order.
    _n_snapshots = np.sum(datacfg.train_test_split)
    _shuffle_rng = np.random.default_rng(datacfg.randseed)
    idx_shuffle, idx_unshuffle = data_utils.shuffle_with_idx(_n_snapshots, rng=_shuffle_rng)

In [None]:
# Everything in `data` except the fields/inputs themselves and the outputs produced by
# this cell is forwarded to the case-specific observation builder. Excluding the
# outputs (y_* and *_minmax) keeps the cell idempotent: on a re-run they would
# otherwise be passed back into `cfg.case.observe` as extra keyword arguments.
_keys_to_exclude = [
    'u_train_clean',
    'u_val_clean',
    'train_minmax',
    'val_minmax',
    'u_train',
    'u_val',
    'inn_train',
    'inn_val',
    'y_train',
    'y_val',
]
observe_kwargs = {key: value for key, value in data.items() if key not in _keys_to_exclude}
take_observation, insert_observation = cfg.case.observe(
    datacfg,
    example_pred_snapshot=data['u_train'][0, ...],
    example_pin_snapshot=data['inn_train'][0, ...],
    **observe_kwargs
)
# With init=True the observation function returns the observed field together with
# (presumably) its min/max range.
observed_train, train_minmax = take_observation(data['u_train'], init=True)
observed_val, val_minmax = take_observation(data['u_val'], init=True)
data.update({
    'y_train': observed_train,
    'y_val': observed_val,
    'train_minmax': train_minmax,
    'val_minmax': val_minmax,
})
print(data.keys())

In [None]:
# Sanity check: rank of the (squeezed) inlet-pressure input and the shapes of the full and observed training fields.
print(np.squeeze(data['inn_train']).ndim, data['u_train'].shape, data['y_train'].shape)

## Load test data from the last dataset in the list


In [None]:
from flowrec.utils.simulation import read_data_kolsol

In [None]:
# `data_dir` is a text file listing one dataset per line, relative to that file's
# directory; the last listed dataset is used as test data.
_dataset_list = Path(datacfg.data_dir)
with open(_dataset_list) as f:
    # Skip blank lines: an empty line would resolve to the directory itself, which
    # exists and would silently be chosen as the "test dataset".
    datasets_path = [_dataset_list.parent / line.rstrip() for line in f if line.strip()]
testdata_path = datasets_path[-1]
print(testdata_path.exists(), testdata_path)
u_test, _, _ = read_data_kolsol(testdata_path)

In [None]:
# Observe the test field; `measured_shape` keeps every axis except the leading (snapshot) one.
y_test = take_observation(u_test)
measured_shape = (-1, *y_test.shape[1:])
# Pressure inlet index: all snapshots, the configured inlet slice, last channel (pressure).
inn_loc = slice_from_tuple(datacfg.pressure_inlet_slice)
s_pressure = (np.s_[:], *inn_loc, np.s_[-1])
inn_test = u_test[s_pressure].reshape((y_test.shape[0], -1))
print(inn_test.shape)

In [None]:
def reorganise(train, val, idx):
    """Stack `train` and `val` along the first axis and reorder rows by `idx`.

    With `idx` set to the unshuffle permutation this restores the original
    (time) ordering of the shuffled train/validation snapshots.
    """
    stacked = np.concatenate((train, val))
    return stacked[idx]

In [None]:
## shuffle
if datacfg.shuffle:
    # Undo the shuffle so the observed signal at index (10, 10) reads as a time series.
    # Validation samples (black dots) are drawn at their original times and dotted red
    # lines mark the boundaries between the concatenated datasets; test data on the right.
    data_reorganised = reorganise(data['y_train'], data['y_val'], idx_unshuffle)
    _n_train = datacfg.train_test_split[0]
    _n_train_val = np.sum(datacfg.train_test_split[:2])
    _t_all = datacfg.dt * np.arange(len(idx_shuffle))
    _t_val = datacfg.dt * idx_shuffle[_n_train:_n_train_val]
    _t_bounds = datacfg.dt * np.cumsum(data['sets_index'][:-1])
    _t_test = datacfg.dt * np.arange(y_test.shape[0])

    fig, axes = plt.subplots(3, 2, figsize=(10, 5), width_ratios=[0.7, 0.3], sharey=True)
    for i in range(3):
        axes[i, 0].plot(_t_all, data_reorganised[:, 10, 10, i], color=cmap(i), zorder=1)
        axes[i, 0].scatter(_t_val, data['y_val'][:, 10, 10, i], color='k', s=3, zorder=2)
        axes[i, 0].vlines(_t_bounds, -2, 2, 'r', linestyle=':')
        axes[i, 0].set_xlim([0, datacfg.dt * len(idx_shuffle)])
        axes[i, 1].plot(_t_test, y_test[:, 10, 10, i], color=cmap(i))
    axes[-1, 0].set(xlabel='t (train + validation)')
    axes[-1, 1].set(xlabel='t (test)')
    plt.show()

In [None]:
# Observed signal at index (10, 10), one panel per training dataset (solid). The fifth
# panel continues with the validation data (dotted); the sixth shows test data (dashed).
fig, axes = plt.subplots(2, 3, sharex=True, sharey=True)
axes = axes.flatten()
for i in range(5):
    i1 = int(np.sum(data['sets_index'][:i]))
    i2 = int(np.sum(data['sets_index'][:i + 1]))
    _t_set = datacfg.dt * np.arange(i2 - i1)
    for _c in range(3):
        axes[i].plot(_t_set, data['y_train'][i1:i2, 10, 10, _c], color=cmap(_c))
_t_val = datacfg.dt * np.arange(i2 - i1, i2 - i1 + data['y_val'].shape[0])
for _c in range(3):
    axes[4].plot(_t_val, data['y_val'][:, 10, 10, _c], linestyle=':', color=cmap(_c))
_t_test = datacfg.dt * np.arange(y_test.shape[0])
for _c in range(3):
    axes[5].plot(_t_test, y_test[:, 10, 10, _c], color=cmap(_c), linestyle='--')
plt.show()

In [None]:
# Visualise where data are observed: the first test inlet-pressure profile written into a
# zero snapshot (right), and the first test observation inserted into zeros (left, u1).
_empty_data = jnp.zeros_like(data['u_train'][[0], ...])
_inlet_values = inn_test[[0], jnp.newaxis, :, jnp.newaxis]
_empty_pressure = _empty_data.at[s_pressure].set(_inlet_values)[0, ..., -1]
# Index of the observed plane along the last spatial axis (NOTE: assumes 64 grid points).
z_plane = int(np.arange(64)[inn_loc[-1]])
_empty_data = insert_observation(_empty_data, y_test[[0], ...])[0, ..., 0]

# Grid coordinates of every cell, flattened for a 3D scatter plot.
x, y, z = (coord.flatten() for coord in np.indices(_empty_data.shape))
values = _empty_data.flatten()

fig = plt.figure(figsize=(8, 4))
for _pos, _azim, _colour, _label in (
    (121, 120, values, "u1"),
    (122, 140, _empty_pressure.flatten(), "p"),
):
    ax = fig.add_subplot(_pos, projection='3d')
    ax.view_init(elev=30, azim=_azim)
    # Marker colour encodes the cell value; the colourbar shows its scale.
    sc = ax.scatter(x, y, z, c=_colour, marker='o')
    cbar = plt.colorbar(sc, ax=ax)
    cbar.set_label(_label)
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    ax.set_zlabel("Z")

fig.suptitle('Sensor location')
plt.show()


# Load model

In [None]:
from flowrec.training_and_states import restore_trainingstate
from flowrec.data import unnormalise_group, normalise

In [None]:
# Build the model for this case and restore the trained parameters.
prep_data, make_model = cfg.case.select_model(datacfg=datacfg, mdl=cfg.model_config, traincfg=traincfg)
# NOTE(review): `data` is replaced by its prepared version, so re-running this cell alone
# presumably prepares it twice — re-run from the data-loading cells instead.
data = prep_data(data, datainfo)
inn_train, inn_val = data['inn_train'], data['inn_val']
y_train, y_val = data['y_train'], data['y_val']
_mdl_output_shape = y_train.shape[1:]
y_test = np.reshape(y_test, (-1, *_mdl_output_shape))
mdl = make_model(cfg.model_config)
state = restore_trainingstate(results_dir, 'state')

# Parameter shapes (one per leaf), total parameter count and top-level parameter names.
_param_leaves = jax.tree_util.tree_leaves(state.params)
for _leaf in _param_leaves:
    print(_leaf.shape)
param_count = sum(_leaf.size for _leaf in _param_leaves)
print(f'Total number of parameters {param_count}')
print(list(state.params))

# Results

In [None]:
# Un-normalising predictions is not implemented: fail before the expensive prediction,
# not after it.
if cfg.data_config.normalise:
    raise NotImplementedError

# Predict the training set in fixed-size batches to bound device memory.
_batch_size = 500
pred_train = jnp.concatenate(
    [
        mdl.predict(state.params, inn_train[_start:_start + _batch_size, :])
        for _start in range(0, inn_train.shape[0], _batch_size)
    ],
    axis=0,
)
print(pred_train.shape)
pred_train = pred_train.reshape(measured_shape)
y_train = y_train.reshape(measured_shape)

In [None]:
# Un-normalising predictions is not implemented: fail before predicting either set.
if cfg.data_config.normalise:
    raise NotImplementedError

pred_val = mdl.predict(state.params, inn_val)
print(pred_val.shape)
pred_val = pred_val.reshape(measured_shape)
y_val = y_val.reshape(measured_shape)

pred_test = mdl.predict(state.params, inn_test)
print(pred_test.shape)
pred_test = pred_test.reshape(measured_shape)
y_test = y_test.reshape(measured_shape)

In [None]:
# Mean-squared error between model output and observation for each split.
l_mse, l_mse_val, l_mse_test = (
    losses.mse(_pred, _ref)
    for _pred, _ref in ((pred_train, y_train), (pred_val, y_val), (pred_test, y_test))
)
print(f'MSE of training slice: {l_mse:.3e}')
print(f'MSE of validation slice: {l_mse_val:.3e}')
print(f'MSE of testing slice: {l_mse_test:.3e}')

In [None]:
plt_step = 560
component = 0
# Colour limits and inlet y-limits from the training data (the later validation/test
# plots reuse vmin, vmax and ylims).
_train_snaps = y_train[::plt_step, :, :, component]
vmin, vmax = _train_snaps.min(), _train_snaps.max()
ylims = [inn_train.min(), inn_train.max()]

fig, axes = plt.subplots(3, 5, figsize=(12, 6), height_ratios=(0.4, 0.4, 0.2))
fig.suptitle(f'Training Ref (top), model output (middle) and inlet pressure on observed plane z={z_plane}')
for i in range(5):
    # Rows 0/1: reference and model output, each with its own colourbar axis.
    for _row, _field in enumerate((y_train, pred_train)):
        _im = axes[_row, i].imshow(_field[i*plt_step, :, :, component].T, vmin=vmin, vmax=vmax)
        _cax = make_axes_locatable(axes[_row, i]).append_axes("right", size="5%", pad=0.0)
        cbar = plt.colorbar(_im, cax=_cax)
        axes[_row, i].set(xlabel='x')
    # Row 2: the inlet-pressure input for the same snapshot.
    axes[2, i].plot(inn_train[i*plt_step, :])
    axes[2, i].set(xlabel=f'y at x=0 t={i*plt_step}', ylim=ylims)
for _row in (0, 1):
    axes[_row, 0].set(ylabel='y')
plt.show()

In [None]:
plt_step = 45


def _plot_ref_pred_inlet(ref, pred, inn, set_name, step, vlims, inlet_ylims, comp, z_obs):
    """Plot snapshots 0, step, ..., 4*step of one data set and return (fig, axes).

    Top row: reference `ref`; middle row: model output `pred` (component `comp`,
    both drawn with colour limits `vlims`). Bottom row: inlet-pressure input `inn`
    with y-limits `inlet_ylims`. `z_obs` is the observed plane shown in the title.
    """
    fig, axes = plt.subplots(3, 5, figsize=(12, 6), height_ratios=(0.4, 0.4, 0.2))
    fig.suptitle(f'{set_name} ref (top), model output (middle) and inlet pressure on observed plane z={z_obs}')
    for i in range(5):
        t = i * step
        for row, field in ((0, ref), (1, pred)):
            im = axes[row, i].imshow(field[t, :, :, comp].T, vmin=vlims[0], vmax=vlims[1])
            cax = make_axes_locatable(axes[row, i]).append_axes("right", size="5%", pad=0.0)
            plt.colorbar(im, cax=cax)
            axes[row, i].set(xlabel='x')
        axes[2, i].plot(inn[t, :])
        axes[2, i].set(xlabel=f'y at x=0 t={t}', ylim=inlet_ylims)
    axes[0, 0].set(ylabel='y')
    axes[1, 0].set(ylabel='y')
    return fig, axes


# Validation and test deliberately reuse the training colour scale (vmin, vmax), inlet
# y-limits and component, so the three sets can be compared directly.
for _name, _ref, _pred, _inn in (
    ('Validation', y_val, pred_val, inn_val),
    ('Testing', y_test, pred_test, inn_test),
):
    fig, axes = _plot_ref_pred_inlet(
        _ref, _pred, _inn, _name, plt_step, (vmin, vmax), ylims, component, z_plane
    )
    plt.show()

## Test on volume

In [None]:
# Inlet slice extended over the last spatial axis: one inlet-pressure profile per plane.
inn_volume_index = slice_from_tuple(pressure_inlet_plane)
# NOTE(review): hard-coded 64^3 grid with 4 channels (velocity + pressure) — confirm against datainfo.
u_train = data['u_train'].reshape((-1, 64, 64, 64, 4))
_inlet_pressure_index = (slice(None), *inn_volume_index, -1)
inn_volume = np.squeeze(u_train[_inlet_pressure_index])
print(inn_volume.shape)

In [None]:
# Apply the plane-wise model to every z-plane: vmap maps over axis 2 of the inlet volume
# (parameters are shared) and stacks the per-plane outputs along axis 2.
mdl_predict_overz = jax.vmap(mdl.predict, in_axes=(None, 2), out_axes=2)
pred_train_volume = mdl_predict_overz(state.params, inn_volume).reshape([*measured_shape, 64])
# Reorder (t, x, y, u, z) -> (t, x, y, z, u) so components are last, like u_train.
pred_train_volume = np.moveaxis(np.asarray(pred_train_volume), 3, -1)

In [None]:
# Errors over the whole volume vs. on the observed plane (pressure channel excluded).
print(f'volume mse {losses.mse(pred_train_volume,u_train[...,:-1])}, measured plane {losses.mse(pred_train_volume[:,:,:,z_plane,:],u_train[:,:,:,z_plane,:-1])}')
print(f'volume relative loss {losses.relative_error(pred_train_volume,u_train[...,:-1])}, measured plane {losses.relative_error(pred_train_volume[:,:,:,z_plane,:],u_train[:,:,:,z_plane,:-1])}')
for r in [10, 20, 30]:
    # Planes z_plane-r (clamped at 0) up to, but excluding, z_plane+r. Without the clamp,
    # a negative start would wrap around and select an empty or wrong range of planes.
    _z_lo, _z_hi = max(z_plane - r, 0), z_plane + r
    print(f'volume relative loss for planes z_plane+-{r} {losses.relative_error(pred_train_volume[..., _z_lo:_z_hi, :], u_train[..., _z_lo:_z_hi, :-1])}')

In [None]:
plt_z = [10, 25, 30, 34, 42, 60]
plt_t = 100
component = 0
# One column per requested plane (previously only the first 5 of 6 were drawn).
_n_planes = len(plt_z)
# Colour limits from the reference snapshot actually shown, so this cell does not depend
# on a `plt_step` left over from whichever cell happened to run before it.
_ref_planes = u_train[plt_t][:, :, plt_z, component]
vmin, vmax = _ref_planes.min(), _ref_planes.max()
ylims = [inn_volume[:, :, plt_z].min(), inn_volume[:, :, plt_z].max()]
fig, axes = plt.subplots(3, _n_planes, figsize=(2.4 * _n_planes, 6), height_ratios=(0.4, 0.4, 0.2))
fig.suptitle('Ref, prediction and inlet pressure')
for i, _z in enumerate(plt_z):
    axes[0, i].imshow(u_train[plt_t, :, :, _z, component].T, vmin=vmin, vmax=vmax)
    axes[0, i].set(xlabel='x', ylabel='y')
    axes[1, i].imshow(pred_train_volume[plt_t, :, :, _z, component].T, vmin=vmin, vmax=vmax)
    axes[1, i].set(xlabel='x', ylabel='y')
    axes[2, i].plot(inn_volume[plt_t, :, _z])
    axes[2, i].set(xlabel=f'y at x=0 t={plt_t}, z={_z}', ylim=ylims)
fig.tight_layout()
plt.show()

### Statistics over volume

In [None]:
def plot_stats(pred, true):
    """Compare reconstructed and reference volumes statistically.

    Figure 1: the first component averaged over axes 0 and 3 (time and z), reference
    left and reconstruction right, sharing the reference colour limits and one colourbar.
    Figure 2: histograms of components 0-2 (each minus its global mean) and the TKE
    spectrum of the fluctuations about the temporal mean.

    Args:
        pred, true: arrays with time first, z on axis 3 and components last
            (presumably (t, x, y, z, 3) — confirm against the caller).

    Returns:
        ((fig1, image_grid), (fig2, axes)).

    Uses the notebook globals `datainfo` and `truegrey`.
    """
    fig1 = plt.figure(figsize=(4, 2))
    grid = ImageGrid(fig1, 111, (1, 2), cbar_mode='single')
    ref_mean = np.mean(true, axis=(0, 3))[:, :, 0].T
    pred_mean = np.mean(pred, axis=(0, 3))[:, :, 0].T
    im_ref = grid.axes_all[0].imshow(ref_mean)
    clim_lo, clim_hi = im_ref.get_clim()
    grid.axes_all[1].imshow(pred_mean, vmin=clim_lo, vmax=clim_hi)
    grid.cbar_axes[0].colorbar(im_ref)
    fig1.suptitle('ref and reconstructed averaged over z & time')

    fig2, axes = plt.subplots(1, 4, figsize=(8, 2))
    ref_style = dict(label='true', linewidth=3, color=truegrey, alpha=0.5)
    for comp, ax in enumerate(axes[:3]):
        for field, style in ((true, ref_style), (pred, dict(label='recons'))):
            samples = field[..., comp].flatten()
            counts, edges = np.histogram(samples - np.mean(samples), density=True, bins=1000)
            ax.stairs(counts, edges, **style)

    spectrum_true, kbins = derivatives.get_tke(true - np.mean(true, axis=0), datainfo)
    spectrum, _ = derivatives.get_tke(pred - np.mean(pred, axis=0), datainfo)
    spec_ax = axes[3]
    spec_ax.loglog(kbins, spectrum_true, **ref_style)
    spec_ax.loglog(kbins, spectrum, label='recons')
    spec_ax.grid(which='both', axis='x')
    spec_ax.legend()

    return (fig1, grid), (fig2, axes)

def print_losses(pred, true, chunk_size=80):
    """Print physics residuals and reconstruction errors for a predicted volume.

    Args:
        pred: reconstructed field with time first, z on the second-to-last axis and
            channels last. The last channel is dropped for the divergence and
            measured-plane MSE (presumably pressure — confirm).
        true: reference field with the same layout as `pred`.
        chunk_size: snapshots per call to `losses.momentum_residual_field`. Chunking
            bounds memory while still covering every snapshot; previously only the
            first 10 * 80 snapshots were used.

    Uses the notebook globals `data` (for 'forcing'), `datainfo` and `z_plane`.
    """
    forcing = data['forcing']

    def _momentum_mse(field):
        # Assemble the residual field chunk by chunk along time, then reduce it with the MSE.
        residual = jnp.concatenate(
            [
                losses.momentum_residual_field(field[t0:t0 + chunk_size, ...], datainfo, forcing=forcing)
                for t0 in range(0, field.shape[0], chunk_size)
            ],
            axis=0,
        )
        return losses.mse(residual)

    # Momentum residuals are evaluated on the CPU (presumably to avoid exhausting GPU memory).
    with jax.default_device(jax.devices('cpu')[0]):
        l_momentum = _momentum_mse(pred)
        l_momentum_ref = _momentum_mse(true)
    l_div = losses.divergence(pred[..., :-1], datainfo)
    l_div_ref = losses.divergence(true[..., :-1], datainfo)
    l_rel = losses.relative_error(pred, true)
    l_mse_slice = losses.mse(pred[..., z_plane, :-1], true[..., z_plane, :-1])
    # Planes 20..49 along z, around the measured plane.
    l_rel_near = losses.relative_error(pred[..., 20:50, :], true[..., 20:50, :])
    l_mse_near = losses.mse(pred[..., 20:50, :], true[..., 20:50, :])

    print(f'ref momentum loss: {l_momentum_ref:.5f}, ref divergence loss: {l_div_ref:.5f}')
    print(f'pred momentum loss: {l_momentum:.5f}, pred divergence loss: {l_div:.5f}')
    print(f'Relative error of the domain {l_rel*100:.3f}%')
    print(f'Relative error of the domain close to the measured plane from z=20 to z=50 {100*l_rel_near:.5f}%')
    # MSE is not a percentage, so it is printed unscaled.
    print(f'MSE of the domain close to the measured plane from z=20 to z=50 {l_mse_near:.5f}')
    print(f'MSE of the measured plane {l_mse_slice:.5f}')

In [None]:
# Named placeholders instead of `_`: once user code assigns `_`, IPython stops updating it
# with the last cell output.
(fig1, _stats_grid), (fig2, _stats_axes) = plot_stats(pred_train_volume, u_train[..., :-1])
plt.show()