# 3D Rayleigh-Bénard: Equivariant Forecasting

In [5]:
%load_ext autoreload
%autoreload 2

import os
import sys

sys.path.append('..')

from IPython.display import Video, display
from ipywidgets import interact_manual, FloatSlider, IntSlider, BoundedIntText, Dropdown, SelectMultiple, Checkbox, Textarea
import glob
import json

from utils.model_building import build_autoencoder, build_forecaster, build_and_load_trained_model
from utils.evaluation import load_latent_sensitivity
from utils import visualization

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Hyperparameter Selection

### Autoencoder

In [None]:
# Input widgets for the autoencoder hyperparameters; they are bound to the
# "Build model" form defined further down in this cell.
conv_type_widget = Dropdown(options=['SteerableConv', 'SteerableConv3D', 'Conv', 'Conv3D'], value='SteerableConv')
# glob gives no ordering guarantee, so sort the dataset files to get a stable
# dropdown, as the other file listings in this notebook do.
simulation_widget = Dropdown(options=sorted(glob.glob('*.h5', root_dir='../data/datasets')))
rots_widget = IntSlider(min=4, max=16, step=4, value=4)
flips_widget = Checkbox(value=True)
# Comma-separated integers; created without an initial value.
encoder_channels_widget = Textarea()
# Comma-separated flags; the form requires as many flags as encoder channel entries.
pooling_layers_widget = Textarea(value='y,n,y,n,y')
latent_channels_widget = BoundedIntText(min=1, max=1000, value=32)
h_ksize_widget = Dropdown(options=[3, 5], value=5)
v_ksize_widget = Dropdown(options=[3, 5], value=5)
# Comma-separated integers (parsed by the form below).
v_shares_widget = Textarea(value='1,1,1,1,1,1')
latent_h_ksize_widget = Dropdown(options=[3, 5], value=3)
latent_v_ksize_widget = Dropdown(options=[3, 5], value=3)
drop_rate_widget = FloatSlider(min=0, max=1, value=0.2)
nonlinearity_widget = Dropdown(options=['ReLU', 'ELU', 'LeakyReLU'], value='ELU')

def set_default_channels(*args):
    """Write the preset encoder channels for the selected conv type into the text box.

    Any positional arguments are ignored, so the function works both as a
    widget observer and when called directly.
    """
    default_channels = {'SteerableConv': "6,12,12,24,24",
                        'SteerableConv3D': "20,40,40,80,80",
                        'Conv': "10,20,20,40,40",
                        'Conv3D': "30,60,60,118,118"}[conv_type_widget.value]
    encoder_channels_widget.value = default_channels


# Only react to a change of the selected value; without `names` the handler
# would run for every trait of the dropdown that changes.
conv_type_widget.observe(set_default_channels, names='value')
# The text box is created empty, which cannot be parsed as a list of integers,
# so fill in the preset for the initially selected conv type right away.
set_default_channels()


@interact_manual.options(manual_name="Build model")(
    conv_type=conv_type_widget, simulation=simulation_widget,
    rots=rots_widget, flips=flips_widget, encoder_channels=encoder_channels_widget,
    pooling_layers=pooling_layers_widget, latent_channels=latent_channels_widget,
    h_ksize=h_ksize_widget, v_ksize=v_ksize_widget,
    latent_h_ksize=latent_h_ksize_widget,
    latent_v_ksize=latent_v_ksize_widget,
    v_shares=v_shares_widget,
    drop_rate=drop_rate_widget,
    nonlinearity=nonlinearity_widget)
def show_patterns(conv_type, simulation, rots, flips, encoder_channels, pooling_layers, latent_channels, h_ksize, v_ksize,
                  v_shares, latent_h_ksize, latent_v_ksize, drop_rate, nonlinearity):
    """Build an autoencoder from the widget values and print its summary.

    `encoder_channels` and `v_shares` are comma-separated integers.
    `pooling_layers` is a comma-separated list of flags where any of
    1/t/y/true/yes/p (case-insensitive) counts as true and anything else as
    false.

    Raises:
        ValueError: if an entry of `encoder_channels` or `v_shares` is not an
            integer, or if the number of pooling flags differs from the
            number of encoder channel entries.
    """
    encoder_channels = [int(c.strip()) for c in encoder_channels.split(',')]
    v_shares = [int(c.strip()) for c in v_shares.split(',')]
    true_strings = {'1', 't', 'y', 'true', 'yes', 'p'}
    pool_layers = [p.strip().lower() in true_strings for p in pooling_layers.split(',')]
    # Explicit check instead of `assert`: assertions are stripped under -O and
    # give the user no hint about what went wrong.
    if len(encoder_channels) != len(pool_layers):
        raise ValueError(f'Got {len(encoder_channels)} encoder channel entries but {len(pool_layers)} '
                         f'pooling flags; both lists must have the same length')

    hps = {
        'simulation_name': simulation,
        'rots': rots,
        'flips': flips,
        'encoder_channels': encoder_channels,
        'latent_channels': latent_channels,
        'h_kernel_size': h_ksize,
        'v_kernel_size': v_ksize,
        'latent_h_kernel_size': latent_h_ksize,
        'latent_v_kernel_size': latent_v_ksize,
        'v_shares': v_shares,
        'drop_rate': drop_rate,
        'nonlinearity': nonlinearity,
        'pool_layers': pool_layers
    }

    model = build_autoencoder(conv_type, **hps)

    model.summary()

interactive(children=(Dropdown(description='model_type', options=('steerableCNN', 'steerable3DCNN', 'CNN', '3D…

### Forecaster

In [None]:
# Input widgets for the forecaster hyperparameters; they are bound to the
# "Build model" form defined further down in this cell.
conv_type_widget = Dropdown(
    options=['SteerableConv', 'Conv3D'],
    value='SteerableConv',
)
# Starts without options; update_autoencoders (below) fills them in.
autoencoder_widget = Dropdown(options=[])
rots_widget = IntSlider(min=4, max=16, step=4, value=4)
flips_widget = Checkbox(value=True)
# Comma-separated integers (parsed by the form below).
lstm_channels_widget = Textarea(value='8,8')
residual_connection_widget = Checkbox(value=True)
h_ksize_widget = Dropdown(options=[3, 5], value=3)
v_ksize_widget = Dropdown(options=[3, 5], value=3)
drop_rate_widget = FloatSlider(min=0, max=1, value=0.2)
recurrent_drop_rate_widget = FloatSlider(min=0, max=1, value=0)
nonlinearity_widget = Dropdown(
    options=['ReLU', 'tanh', 'ELU'],
    value='tanh',
)
include_autoencoder_widget = Checkbox(value=False)
use_lstm_encoder_widget = Checkbox(value=True)
    
def update_autoencoders(*args):
    """Refresh the autoencoder dropdown for the current conv type / rots / flips selection.

    The selection determines a directory name below ./trained_models/AE; every
    directory underneath it that holds at least one `epoch*.tar` file is
    offered as an option. A previously selected autoencoder stays selected if
    it is still available, otherwise the first option is selected. Any
    positional arguments are ignored, so the function works both as a widget
    observer and when called directly.
    """
    group = f'{"D" if flips_widget.value else "C"}{rots_widget.value}'
    model_name = {'SteerableConv': f'{group}cnn',
                  'SteerableConv3D': f'3D-{group}cnn',
                  'Conv': 'cnn',
                  'Conv3D': '3Dcnn'}[conv_type_widget.value]
    checkpoints = glob.glob(f'{model_name}/**/epoch*.tar', root_dir='./trained_models/AE', recursive=True)
    # A directory with several epoch checkpoints matches several times; list it once.
    trained_autoencoders = sorted({os.path.join(*path.split('/')[:-1]) for path in checkpoints})

    previous = autoencoder_widget.value
    autoencoder_widget.options = trained_autoencoders
    # Indexing an empty option list would raise, so only select when something was found.
    if trained_autoencoders:
        autoencoder_widget.value = previous if previous in trained_autoencoders else trained_autoencoders[0]


# The searched directory depends on all three widgets, so refresh on each of
# them, and only when the selected value itself changes.
conv_type_widget.observe(update_autoencoders, names='value')
rots_widget.observe(update_autoencoders, names='value')
flips_widget.observe(update_autoencoders, names='value')
# Populate the dropdown for the initial selection.
update_autoencoders()


@interact_manual.options(manual_name="Build model")(
    conv_type=conv_type_widget, autoencoder=autoencoder_widget,
    rots=rots_widget, flips=flips_widget, lstm_channels=lstm_channels_widget,
    h_ksize=h_ksize_widget, v_ksize=v_ksize_widget,
    drop_rate=drop_rate_widget, recurrent_drop_rate=recurrent_drop_rate_widget,
    nonlinearity=nonlinearity_widget, residual_connection=residual_connection_widget,
    include_autoencoder=include_autoencoder_widget,
    use_lstm_encoder=use_lstm_encoder_widget)
def show_patterns(conv_type, autoencoder, rots, flips, lstm_channels, h_ksize, v_ksize, drop_rate, recurrent_drop_rate,
                  nonlinearity, residual_connection, include_autoencoder, use_lstm_encoder):
    """Build a forecaster from the widget values and print its summary.

    `autoencoder` is a '/'-separated path as offered by the autoencoder
    dropdown; `lstm_channels` is a comma-separated list of integers.

    Raises:
        ValueError: if an entry of `lstm_channels` is not an integer.
    """
    # The dropdown has no value while no trained autoencoder is available.
    if not autoencoder:
        print('Please select a trained autoencoder')
        return

    # NOTE(review): the first two path components are used as the model name
    # here, while the latent-sensitivity cell (whose paths are also relative to
    # an AE directory) uses only the first - confirm this split is what
    # build_forecaster expects. A path with fewer than three components makes
    # the second join fail.
    path_parts = autoencoder.split('/')
    ae_model_name = os.path.join(*path_parts[:2])
    ae_train_name = os.path.join(*path_parts[2:])

    lstm_channels = [int(c.strip()) for c in lstm_channels.split(',')]

    hps = {
        'conv_type': conv_type,
        'ae_model_name': ae_model_name,
        'ae_train_name': ae_train_name,
        'h_kernel_size': h_ksize,
        'v_kernel_size': v_ksize,
        'drop_rate': drop_rate,
        'recurrent_drop_rate': recurrent_drop_rate,
        'nonlinearity': nonlinearity,
        'flips': flips,
        'rots': rots,
        'lstm_channels': lstm_channels,
        'parallel_ops': True,
        'residual_connection': residual_connection,
        'include_autoencoder': include_autoencoder,
        'use_lstm_encoder': use_lstm_encoder
    }

    model = build_forecaster(models_dir='./trained_models', **hps)

    model.summary()

interactive(children=(Dropdown(description='model_type', options=('steerableCNN', '3DCNN'), value='steerableCN…

## Trained Model Summaries

In [5]:
# Directories below ./trained_models that hold at least one epoch checkpoint.
# A directory with several checkpoints matches the glob several times, so the
# duplicates are removed before the dropdown is filled.
trained_models_w_epoch = [os.path.join(*path.split('/')[:-1])
                          for path in glob.glob('**/epoch*.tar', root_dir='./trained_models', recursive=True)]
trained_models_widget = Dropdown(options=sorted(set(trained_models_w_epoch)))

@interact_manual.options(manual_name="Build model")(trained_model=trained_models_widget)
def show_summary(trained_model, show_summary: bool = True, show_hps: bool = True):
    """Print the summary and/or the stored hyperparameters of a trained model.

    `trained_model` is a '/'-separated path relative to ./trained_models; its
    first two components form the model name and the rest the training name.
    `show_summary` loads the model and prints its summary; `show_hps` prints
    the contents of the run's hyperparameters.json.
    """
    # The dropdown has no value when no trained model was found.
    if not trained_model:
        print('Please select a trained model')
        return

    path_parts = trained_model.split('/')
    model_name = os.path.join(*path_parts[:2])
    train_name = os.path.join(*path_parts[2:])

    if show_summary:
        model = build_and_load_trained_model('trained_models', model_name, train_name, epoch=-1)
        model.summary()

    if show_hps:
        print('\nHyperparameters:')
        hp_file = os.path.join('trained_models', model_name, train_name, 'hyperparameters.json')
        if os.path.isfile(hp_file):
            with open(hp_file, 'r') as f:
                hps = json.load(f)
            # Strip the first and last two characters, i.e. the enclosing
            # braces of the pretty-printed JSON object and their newlines.
            print(json.dumps(hps, indent=4)[2:-2])
        else:
            # Say so explicitly instead of leaving an empty section.
            print(f'No hyperparameters file found at {hp_file}')

interactive(children=(Dropdown(description='trained_model', options=('AE/3D-D4cnn/hp_search/001_lrschedule_lr1…

## Evaluation

### Loss Evolution During Training

In [2]:
# Directories below ./trained_models that contain a log.json file.
trained_models_w_log = [
    os.path.join(*log_path.split('/')[:-1])
    for log_path in glob.glob('**/log.json', root_dir='./trained_models', recursive=True)
]

trained_models_widget = SelectMultiple(
    options=sorted(trained_models_w_log),
)
smoothing_widget = FloatSlider(min=0, max=0.9, step=0.05)
x_axis_widget = Dropdown(
    options=['time', 'epochs'],
    value='epochs',
)

@interact_manual.options(manual_name="Visualize Loss")(
    trained_models=trained_models_widget, smoothing=smoothing_widget, x_axis=x_axis_widget)
def show_loss(trained_models, two_plots=True, log_scale=False, remove_outliers=True, x_axis='epochs', smoothing=0):
    """Plot the loss history of the selected training runs.

    Each selected path is split into a model name (its first two components)
    and a training name (the remaining components) before being handed to
    `visualization.plot_loss_history`.
    """
    if not trained_models:
        print('Please select a trained model')
        return

    split_paths = [selection.split('/') for selection in trained_models]
    model_names = [os.path.join(*parts[:2]) for parts in split_paths]
    train_names = [os.path.join(*parts[2:]) for parts in split_paths]

    x_axis_is_time = x_axis == 'time'
    visualization.plot_loss_history('trained_models', model_names, train_names, two_plots, log_scale,
                                    x_axis_is_time, remove_outliers, smoothing)

interactive(children=(SelectMultiple(description='trained_models', options=('AE/3D-D4cnn/hp_search/001_lrsched…

### Test Performance

In [3]:
# Directories below ./results that contain a performance.json file.
trained_models_w_performance = [
    os.path.join(*result_path.split('/')[:-1])
    for result_path in glob.glob('**/performance.json', root_dir='./results', recursive=True)
]
trained_models_widget = SelectMultiple(
    options=sorted(trained_models_w_performance),
)
metric_widget = Dropdown(
    options=['MSE', 'RMSE', 'MAE'],
    value='RMSE',
)
show_train_widget = Checkbox(value=False)
group_same_model_widget = Checkbox(value=False)

@interact_manual.options(manual_name="Plot performance")(
    trained_models=trained_models_widget, metric=metric_widget,
    show_train=show_train_widget,
    group_same_model=group_same_model_widget)
def show_performance(trained_models, metric, show_train, group_same_model):
    """Plot the test performance of the selected training runs.

    Each selected path is split into a model name (its first two components)
    and a training name (the remaining components); the metric name is passed
    on in lower case.
    """
    if not trained_models:
        print('Please select a trained model')
        return

    split_paths = [selection.split('/') for selection in trained_models]
    model_names = [os.path.join(*parts[:2]) for parts in split_paths]
    train_names = [os.path.join(*parts[2:]) for parts in split_paths]

    metric_key = metric.lower()
    visualization.plot_performance('./results', model_names, train_names, metric_key, group_same_model, show_train)

interactive(children=(SelectMultiple(description='trained_models', options=('AE/3Dcnn/hp_search/008_3layer_lc3…

### Autoregressive Test Performance

In [None]:
# Directories below ./results that contain an autoregressive_performance.json file.
trained_models_w_performance = [
    os.path.join(*result_path.split('/')[:-1])
    for result_path in glob.glob('**/autoregressive_performance.json', root_dir='./results', recursive=True)
]
trained_models_widget = SelectMultiple(
    options=sorted(trained_models_w_performance),
)
metric_widget = Dropdown(
    options=['MSE', 'RMSE', 'MAE'],
    value='MAE',
)
show_train_widget = Checkbox(value=False)
show_bounds_widget = Checkbox(value=False)
median_widget = Checkbox(value=False)

@interact_manual.options(manual_name="Plot performance")(
    trained_models=trained_models_widget, metric=metric_widget,
    show_train=show_train_widget, show_bounds=show_bounds_widget,
    median=median_widget)
def show_performance(trained_models, metric, show_train, show_bounds, median):
    """Plot the autoregressive test performance of the selected training runs.

    Each selected path is split into a model name (its first two components)
    and a training name (the remaining components); the metric name is passed
    on in lower case.
    """
    if not trained_models:
        print('Please select a trained model')
        return

    split_paths = [selection.split('/') for selection in trained_models]
    model_names = [os.path.join(*parts[:2]) for parts in split_paths]
    train_names = [os.path.join(*parts[2:]) for parts in split_paths]

    metric_key = metric.lower()
    visualization.plot_autoregressive_performance('./results', model_names, train_names, metric_key,
                                                  show_train, show_bounds, median=median)

interactive(children=(SelectMultiple(description='trained_models', options=('FC/3Dcnn/test/002_encoder_decoder…

### Test Performance per Simulation

In [None]:
# Directories below ./results that contain a performance_per_sim.json file.
trained_models_w_performance = [
    os.path.join(*result_path.split('/')[:-1])
    for result_path in glob.glob('**/performance_per_sim.json', root_dir='./results', recursive=True)
]
trained_models_widget = SelectMultiple(
    options=sorted(trained_models_w_performance),
)
metric_widget = Dropdown(
    options=['MSE', 'RMSE', 'MAE'],
    value='RMSE',
)

@interact_manual.options(manual_name="Plot performance")(
    trained_models=trained_models_widget, metric=metric_widget)
def show_performance(trained_models, metric):
    """Plot the per-simulation test performance of the selected training runs.

    Each selected path is split into a model name (its first two components)
    and a training name (the remaining components); the metric name is passed
    on in lower case.
    """
    if not trained_models:
        print('Please select a trained model')
        return

    split_paths = [selection.split('/') for selection in trained_models]
    model_names = [os.path.join(*parts[:2]) for parts in split_paths]
    train_names = [os.path.join(*parts[2:]) for parts in split_paths]

    metric_key = metric.lower()
    visualization.plot_performance_per_sim('./results', model_names, train_names, metric_key)

interactive(children=(SelectMultiple(description='trained_models', options=('AE/3Dcnn/parameters/07_200_000', …

### Test Performance per Channel

In [None]:
# Directories below ./results that contain a performance_per_channel.json file.
trained_models_w_performance = [
    os.path.join(*result_path.split('/')[:-1])
    for result_path in glob.glob('**/performance_per_channel.json', root_dir='./results', recursive=True)
]
trained_models_widget = SelectMultiple(
    options=sorted(trained_models_w_performance),
)
metric_widget = Dropdown(
    options=['MSE', 'RMSE', 'MAE'],
    value='RMSE',
)
show_train_widget = Checkbox(value=False)

@interact_manual.options(manual_name="Plot performance")(
    trained_models=trained_models_widget, metric=metric_widget,
    show_train=show_train_widget)
def show_performance(trained_models, metric, show_train):
    """Plot the per-channel test performance of the selected training runs.

    Each selected path is split into a model name (its first two components)
    and a training name (the remaining components); the metric name is passed
    on in lower case.
    """
    if not trained_models:
        print('Please select a trained model')
        return

    split_paths = [selection.split('/') for selection in trained_models]
    model_names = [os.path.join(*parts[:2]) for parts in split_paths]
    train_names = [os.path.join(*parts[2:]) for parts in split_paths]

    metric_key = metric.lower()
    visualization.plot_performance_per_channel('./results', model_names, train_names, metric_key, show_train)

interactive(children=(SelectMultiple(description='trained_models', options=('AE/3Dcnn/parameters/07_200_000', …

### Test Performance per Height

In [None]:
# Directories below ./results that contain a performance_per_height.json file.
trained_models_w_performance = [
    os.path.join(*result_path.split('/')[:-1])
    for result_path in glob.glob('**/performance_per_height.json', root_dir='./results', recursive=True)
]
trained_models_widget = SelectMultiple(
    options=sorted(trained_models_w_performance),
)
metric_widget = Dropdown(
    options=['MSE', 'RMSE', 'MAE'],
    value='RMSE',
)
channel_widget = Dropdown(
    options=['all', 't', 'u', 'v', 'w'],
    value='t',
)
show_train_widget = Checkbox(value=False)

@interact_manual.options(manual_name="Plot performance")(
    trained_models=trained_models_widget, metric=metric_widget,
    channel=channel_widget, show_train=show_train_widget)
def show_performance(trained_models, metric, channel, show_train):
    """Plot the per-height test performance of the selected training runs.

    Each selected path is split into a model name (its first two components)
    and a training name (the remaining components). The channel letter is
    converted to its position in "tuvw"; 'all' is passed on as None. The
    metric name is passed on in lower case.
    """
    if not trained_models:
        print('Please select a trained model')
        return

    split_paths = [selection.split('/') for selection in trained_models]
    model_names = [os.path.join(*parts[:2]) for parts in split_paths]
    train_names = [os.path.join(*parts[2:]) for parts in split_paths]

    if channel == 'all':
        channel = None
    else:
        channel = "tuvw".index(channel)

    metric_key = metric.lower()
    visualization.plot_performance_per_height('./results', model_names, train_names, metric_key, channel, show_train)

interactive(children=(SelectMultiple(description='trained_models', options=('AE/3Dcnn/parameters/07_200_000', …

### Test Performance per Latent Size

In [None]:
# Directories below ./results/AE that contain a performance.json file
# (paths are kept relative to ./results/, i.e. they start with 'AE').
trained_models_w_performance = [
    os.path.join(*result_path.split('/')[:-1])
    for result_path in glob.glob('AE/**/performance.json', root_dir='./results/', recursive=True)
]
trained_models_widget = SelectMultiple(
    options=sorted(trained_models_w_performance),
)
metric_widget = Dropdown(
    options=['MSE', 'RMSE', 'MAE'],
    value='RMSE',
)

@interact_manual.options(manual_name="Plot performance")(
    trained_models=trained_models_widget, metric=metric_widget)
def show_performance(trained_models, metric):
    """Plot the test performance of the selected autoencoder runs against their latent size.

    Each selected path is split into a model name (its first two components)
    and a training name (the remaining components); the metric name is passed
    on in lower case.
    """
    if not trained_models:
        # Same wording as the other cells (the original read "a trained models").
        print('Please select a trained model')
        return

    split_paths = [selection.split('/') for selection in trained_models]
    model_names = [os.path.join(*parts[:2]) for parts in split_paths]
    train_names = [os.path.join(*parts[2:]) for parts in split_paths]
    visualization.plot_performance_per_hp('./trained_models', './results', model_names, train_names,
                                          metric.lower(), 'latent_size', 'relative latent size (%)')

### Test Performance per Parameter Count

In [None]:
# Directories below ./results that contain a performance.json file.
trained_models_w_performance = [
    os.path.join(*result_path.split('/')[:-1])
    for result_path in glob.glob('**/performance.json', root_dir='./results', recursive=True)
]
trained_models_widget = SelectMultiple(
    options=sorted(trained_models_w_performance),
)
metric_widget = Dropdown(
    options=['MSE', 'RMSE', 'MAE'],
    value='RMSE',
)
rounding_widget = Dropdown(
    options=['1M', '100k', '10k', '1k'],
    value='100k',
)
show_train_widget = Checkbox(value=False)
fill_error_widget = Checkbox(value=True)

@interact_manual.options(manual_name="Plot performance")(
    trained_models=trained_models_widget, metric=metric_widget,
    rounding=rounding_widget, show_train=show_train_widget,
    fill_error=fill_error_widget)
def show_performance(trained_models, metric, rounding, show_train, fill_error):
    """Plot the test performance of the selected training runs against their parameter count.

    At least two runs must be selected. Each selected path is split into a
    model name (its first two components) and a training name (the remaining
    components). The rounding label is translated into a negative exponent
    ('1M' -> -6 ... '1k' -> -3) and the metric name is passed on in lower case.
    """
    if len(trained_models) < 2:
        print('Please select at least 2 trained models')
        return

    rounding_by_label = {'1M': -6, '100k': -5, '10k': -4, '1k': -3}
    rounding = rounding_by_label[rounding]

    split_paths = [selection.split('/') for selection in trained_models]
    model_names = [os.path.join(*parts[:2]) for parts in split_paths]
    train_names = [os.path.join(*parts[2:]) for parts in split_paths]

    metric_key = metric.lower()
    visualization.plot_performance_per_hp('./trained_models', './results', model_names, train_names,
                                          metric_key, 'parameters', rounding=rounding, show_train=show_train,
                                          fill_error=fill_error)

interactive(children=(SelectMultiple(description='trained_models', options=('AE/3Dcnn/hp_search/008_3layer_lc3…

### Model Output Animation

In [14]:
# Directories below ./results that contain an `animations` entry.
trained_models_w_animation = [
    os.path.join(*animation_path.split('/')[:-1])
    for animation_path in glob.glob('**/animations', root_dir='./results', recursive=True)
]
trained_model_widget = Dropdown(
    options=sorted(trained_models_w_animation),
)
feature_widget = Dropdown(
    options=['t', 'u', 'v', 'w'],
)
dim_widget = Dropdown(
    options=['width', 'depth', 'height'],
)

@interact_manual.options(manual_name="Show animation")(
    trained_model=trained_model_widget, feature=feature_widget, dim=dim_widget)
def show_animation(trained_model, feature='t', dim='height'):
    """Display the stored animation of a trained model for one feature and dimension.

    The video is looked up at
    ./results/<model name>/<training name>/animations/<feature>/<dim>.mp4,
    where the model name is the first two components of `trained_model` and
    the training name the remaining ones. A message is printed when that file
    does not exist.
    """
    # The dropdown has no value when no animation directory was found.
    if not trained_model:
        print('Please select a trained model')
        return

    path_parts = trained_model.split('/')
    model_name = os.path.join(*path_parts[:2])
    train_name = os.path.join(*path_parts[2:])

    anim_file = os.path.join('.', 'results', model_name, train_name, 'animations', feature, f'{dim}.mp4')

    if os.path.isfile(anim_file):
        display(Video(anim_file))
    else:
        print('There is no animation for this selection')

interactive(children=(Dropdown(description='trained_model', options=('AE/3Dcnn/latent_sizes/8channels', 'AE/3D…

### Latent Space Visualization (Input Sensitivity)

In [15]:
# Directories below ./results/AE that contain a latent_sensitivity.pt file.
trained_models_w_sensitivity = [
    os.path.join(*sensitivity_path.split('/')[:-1])
    for sensitivity_path in glob.glob('**/latent_sensitivity.pt', root_dir='./results/AE', recursive=True)
]
trained_model_widget = Dropdown(
    options=sorted(trained_models_w_sensitivity),
)
feature_widget = Dropdown(
    options=['t', 'u', 'v', 'w'],
    value='t',
)
dim_widget = Dropdown(
    options=['width', 'depth', 'height'],
    value='height',
)
# Upper bound matches the one update_slice_range (below) uses for 'height'.
slice_widget = IntSlider(min=0, max=31, value=0)
auto_slice_widget = Checkbox(value=True)
num_patterns_widget = BoundedIntText(min=1, max=576, value=50)
cols_widget = IntSlider(min=1, max=10, value=10)
unified_cbar_widget = Checkbox(value=False)


def update_slice_range(*args):
    """Adapt the upper bound of the slice slider to the selected dimension.

    'height' limits the slider to 31, the other dimensions to 47.
    NOTE(review): these bounds presumably are the data extents minus one
    along the respective axes - confirm against the dataset.
    Any positional arguments are ignored, so the function works both as a
    widget observer and when called directly.
    """
    if dim_widget.value == 'height':
        slice_widget.max = 31
        slice_widget.value = min(31, slice_widget.value)
    else:
        slice_widget.max = 47


# Only react to a change of the selected value; without `names` the handler
# would run for every trait of the dropdown that changes.
dim_widget.observe(update_slice_range, names='value')

# Single-entry cache of the most recently loaded sensitivity data, keyed by the
# selected model path, so that re-rendering the same model skips the reload.
cached, cached_avg_sensitivity, cached_avg_abs_sensitivity = None, None, None

@interact_manual.options(manual_name="Show patterns")(
    trained_model=trained_model_widget, feature=feature_widget,
    dim=dim_widget, slice=slice_widget, auto_slice=auto_slice_widget,
    num_patterns=num_patterns_widget, cols=cols_widget,
    unified_cbar=unified_cbar_widget)
def show_patterns(trained_model, feature='t', dim='height', slice=16, auto_slice=False,
                  abs_sensitivity=False, contour=False, num_patterns=25, cols=5,
                  unified_cbar=True):
    """Visualize the latent input-sensitivity patterns of a trained autoencoder.

    `trained_model` is a '/'-separated path relative to ./results/AE whose
    first component is the model name and whose remaining components form the
    training name. `feature` is converted to its position in 'tuvw' and `dim`
    to its position in ['width', 'depth', 'height']. With `auto_slice` set,
    None is passed on instead of `slice`. `abs_sensitivity` selects the
    absolute instead of the signed average sensitivity data.

    The parameter name `slice` shadows the builtin but has to stay, because
    the form binds its widgets by keyword.
    """
    global cached, cached_avg_sensitivity, cached_avg_abs_sensitivity

    # The dropdown has no value when no sensitivity file was found.
    if not trained_model:
        print('Please select a trained model')
        return

    if cached != trained_model:
        path_parts = trained_model.split('/')
        model_name = path_parts[0]
        train_name = os.path.join(*path_parts[1:])
        sensitivity_dir = os.path.join('.', 'results/AE', model_name, train_name)
        # The third return value is not needed here.
        avg_sensitivity, avg_abs_sensitivity, _ = load_latent_sensitivity(sensitivity_dir, 'latent_sensitivity')
        cached = trained_model
        cached_avg_sensitivity, cached_avg_abs_sensitivity = avg_sensitivity, avg_abs_sensitivity

    sensitivity_data = cached_avg_abs_sensitivity if abs_sensitivity else cached_avg_sensitivity
    channel = 'tuvw'.index(feature)
    axis = ['width', 'depth', 'height'].index(dim)
    if auto_slice:
        slice = None

    visualization.show_latent_patterns(sensitivity_data, abs_sensitivity, num_patterns, channel,
                                       slice, axis, cols=cols, contour=contour, unified_cbar=unified_cbar)

interactive(children=(Dropdown(description='trained_model', options=('3Dcnn/latent_sizes/8channels', '3Dcnn/pa…