In [None]:

#? python train.py -h
#? python train.py steerable3DCNN 500 training1

#? python evaluate.py -h
#? python evaluate.py D4cnn training1 [-eval_performance | 
#?                                     -eval_performance_per_sim |
#?                                     -check_equivariance |
#?                                     -animate |
#?                                     -compute_latent_sensitivity]

#? tensorboard --logdir runs


# 3D Rayleigh-Bénard: Equivariant Convolutional Autoencoder

In [72]:
%load_ext autoreload
%autoreload 2

import os
import sys

sys.path.append('..')

from IPython.display import Video, display
from ipywidgets import interact_manual, FloatSlider, IntSlider, BoundedIntText, Dropdown, SelectMultiple, Checkbox
import glob
import json

from utils import training
from utils.model_building import build_model
from utils.evaluation import load_latent_sensitivity
from utils.visualization import show_latent_patterns, plot_loss_history

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Hyperparameter Selection

In [70]:
# Widgets for the model-building hyperparameters below.
model_type_widget = Dropdown(options=['steerableCNN', 'steerable3DCNN', 'CNN', '3DCNN'], value='steerableCNN')
# os.listdir returns entries in arbitrary, filesystem-dependent order; sort them so the
# default selection (the first option) is deterministic, like the other dropdowns here.
simulation_widget = Dropdown(options=sorted(os.listdir('../data/datasets')))
rots_widget = IntSlider(min=4, max=16, step=4, value=4)
flips_widget = Checkbox(value=True)
# Encoder channel widths per stage; initial values equal the 'steerableCNN' defaults
# used by set_default_channels (the initially selected model type).
ec1_widget = BoundedIntText(min=1, max=1000, value=8)
ec2_widget = BoundedIntText(min=1, max=1000, value=16)
ec3_widget = BoundedIntText(min=1, max=1000, value=32)
ec4_widget = BoundedIntText(min=1, max=1000, value=64)
# NOTE(review): no explicit max here, so the BoundedIntText default upper bound applies.
latent_channels_widget = BoundedIntText(min=1, value=32)
h_ksize_widget = Dropdown(options=[3, 5], value=3)
v_ksize_widget = Dropdown(options=[3, 5], value=5)
drop_rate_widget = FloatSlider(min=0, max=1, value=0.2)
nonlinearity_widget = Dropdown(options=['ReLU', 'ELU', 'LeakyReLU'], value='ELU')

def set_default_channels(*args):
    """Reset the four encoder-channel widgets to the defaults of the selected model type.

    Used as an observer callback on ``model_type_widget``; any positional arguments
    (the change notification) are ignored, the current dropdown value is read instead.
    """
    # Default (ec1, ec2, ec3, ec4) widths per architecture.
    # NOTE(review): irregular values such as 66 and 168 look deliberate (presumably chosen
    # to balance parameter counts between architectures) - confirm before "fixing" them.
    default_channels = {'steerableCNN': (8, 16, 32, 64),
                        'steerable3DCNN': (24, 48, 96, 192),
                        'CNN': (16, 32, 66, 160),
                        '3DCNN': (40, 80, 168, 320)}[model_type_widget.value]
    ec1_widget.value, ec2_widget.value, ec3_widget.value, ec4_widget.value = default_channels


# Only react to changes of the selected value; without `names`, the callback also fires
# for every other trait change of the dropdown (e.g. index/label updates).
model_type_widget.observe(set_default_channels, names='value')


@interact_manual.options(manual_name="Build model")(
    model_type=model_type_widget, simulation=simulation_widget,
    rots=rots_widget, flips=flips_widget,
    ec1=ec1_widget, ec2=ec2_widget, ec3=ec3_widget, ec4=ec4_widget,
    latent_channels=latent_channels_widget,
    h_ksize=h_ksize_widget, v_ksize=v_ksize_widget,
    drop_rate=drop_rate_widget, nonlinearity=nonlinearity_widget,
)
def show_patterns(model_type, simulation, rots, flips, ec1, ec2, ec3, ec4, latent_channels, h_ksize, v_ksize, drop_rate, nonlinearity):
    """Build a model from the hyperparameters chosen in the widgets and print its summary.

    The four ``ec*`` values are collected into the ``encoder_channels`` list, and the
    short widget names are mapped onto the keyword names ``build_model`` expects
    (``simulation`` -> ``simulation_name``, ``h_ksize`` -> ``h_kernel_size``, ...).

    NOTE(review): a later cell defines another ``show_patterns`` (latent sensitivity),
    which rebinds this name; consider renaming this one, e.g. ``show_model_summary``.
    """
    hyperparameters = dict(
        simulation_name=simulation,
        rots=rots,
        flips=flips,
        encoder_channels=[ec1, ec2, ec3, ec4],
        latent_channels=latent_channels,
        h_kernel_size=h_ksize,
        v_kernel_size=v_ksize,
        drop_rate=drop_rate,
        nonlinearity=nonlinearity,
    )
    build_model(model_type, **hyperparameters).summary()

interactive(children=(Dropdown(description='model_type', options=('steerableCNN', 'steerable3DCNN', 'CNN', '3D…

## Trained Model Summaries

In [77]:
trained_models_widget = Dropdown(options=sorted(glob.glob('*/*', root_dir='./trained_models')))

@interact_manual.options(manual_name="Build model")(trained_model=trained_models_widget)
def show_summary(trained_model):
    """Load a trained model, print its summary and its stored hyperparameters.

    Args:
        trained_model: ``<model_name>/<train_name>`` path relative to ``trained_models``,
            as listed by the dropdown (``None`` when no trained models were found).
    """
    if trained_model is None:
        print('Please select a trained model')
        return
    # glob joins path components with os.sep, so on Windows the entries look like
    # 'model\\train'; os.path.split handles both separators, str.split('/') does not.
    model_name, train_name = os.path.split(trained_model)
    # epoch=-1: presumably the most recent checkpoint - see training.build_and_load_trained_model.
    model = training.build_and_load_trained_model('trained_models', model_name, train_name, epoch=-1)

    model.summary()

    print('\nHyperparameters:')
    hp_file = os.path.join('trained_models', model_name, train_name, 'hyperparameters.json')
    with open(hp_file, 'r') as f:
        hps = json.load(f)
    # Strip the enclosing '{\n' and '\n}' so only the indented key/value lines are printed.
    print(json.dumps(hps, indent=4)[2:-2])

interactive(children=(Dropdown(description='trained_model', options=('3Dcnn/training1_lr1e-3', '3Dcnn/training…

## Evaluation

### Training Loss Evolution

In [60]:
trained_models_widget = SelectMultiple(options=sorted(glob.glob('*/*', root_dir='./trained_models')))
smoothing_widget = FloatSlider(min=0, max=0.9, step=0.1)

@interact_manual.options(manual_name="Visualize Loss")(trained_models=trained_models_widget, smoothing=smoothing_widget)
def show_loss(trained_models, two_plots=True, log_scale=False, time_x=False, remove_outliers=True, smoothing=0):
    """Plot the training-loss history of one or more trained models.

    Args:
        trained_models: Selected ``<model_name>/<train_name>`` paths relative to
            ``trained_models``.
        two_plots, log_scale, time_x, remove_outliers, smoothing: Plotting options,
            forwarded unchanged to ``plot_loss_history``.
    """
    if not trained_models:
        print('Please select a trained model')
        return
    # os.path.split instead of str.split('/'): glob joins with os.sep (backslash on Windows).
    model_names, train_names = zip(*(os.path.split(path) for path in trained_models))
    plot_loss_history('trained_models', model_names, train_names, two_plots, log_scale, time_x, remove_outliers, smoothing)

interactive(children=(SelectMultiple(description='trained_models', options=('3Dcnn/training1_lr1e-3', '3Dcnn/t…

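### Test Performance

**TODO:** no cells in this section yet. Test metrics can be computed with `python evaluate.py <model> <train_name> -eval_performance` (per simulation: `-eval_performance_per_sim`).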

### Encoder Output Animation

In [62]:
# Trained models whose results directory contains an 'animations' folder.
trained_models_w_animation = [path for path in glob.glob('*/*', root_dir='./results')
                              if os.path.isdir(os.path.join('results', path, 'animations'))]
trained_model_widget = Dropdown(options=sorted(trained_models_w_animation))
feature_widget = Dropdown(options=['t', 'u', 'v', 'w'])
dim_widget = Dropdown(options=['width', 'depth', 'height'])

@interact_manual.options(manual_name="Show animation")(trained_model=trained_model_widget, feature=feature_widget, dim=dim_widget)
def show_animation(trained_model, feature='t', dim='height'):
    """Display the saved encoder-output animation for one model, feature and axis.

    The video is read from ``results/<model>/<train>/animations/<feature>/<dim>.mp4``.

    Args:
        trained_model: ``<model_name>/<train_name>`` path relative to ``results``
            (``None`` when no model has animations).
        feature: Feature name, one of 't', 'u', 'v', 'w'.
        dim: Axis name, one of 'width', 'depth', 'height'.
    """
    if trained_model is None:
        print('There is no animation for this selection')
        return
    # os.path.split instead of str.split('/'): glob joins with os.sep (backslash on Windows).
    model_name, train_name = os.path.split(trained_model)
    anim_file = os.path.join('.', 'results', model_name, train_name, 'animations', feature, f'{dim}.mp4')
    if not os.path.isfile(anim_file):
        print('There is no animation for this selection')
    else:
        display(Video(anim_file))

interactive(children=(Dropdown(description='trained_model', options=('D4cnn/training2_lr1e-4',), value='D4cnn/…

### Latent Space Visualization (Input Sensitivity)

In [67]:
# Trained models whose results directory contains saved latent-sensitivity data.
# load_latent_sensitivity is given the extension-less name 'latent_sensitivity', so any
# suffix is accepted here.
# NOTE(review): the previous check for 'latent_sensitivity.py' looked wrong for a data
# file; confirm the suffix actually written by the evaluation script.
trained_models_w_sensitivity = [path for path in glob.glob('*/*', root_dir='./results')
                                if glob.glob('latent_sensitivity.*', root_dir=os.path.join('results', path))]
# Offer the models that actually have sensitivity data. The previous code used the
# animation list here (copy-paste slip), so the filter above had no effect.
trained_model_widget = Dropdown(options=sorted(trained_models_w_sensitivity))
feature_widget = Dropdown(options=['t', 'u', 'v', 'w'], value='t')
dim_widget = Dropdown(options=['width', 'depth', 'height'], value='height')
slice_widget = IntSlider(min=0, max=31, value=0)
num_patterns_widget = BoundedIntText(min=1, max=576, value=50)
cols_widget = IntSlider(min=1, max=10, value=10)


# Largest valid slice index along each axis selectable in dim_widget.
# NOTE(review): presumably the grid is 48 x 48 x 32 (width x depth x height) - confirm.
_MAX_SLICE_INDEX = {'width': 47, 'depth': 47, 'height': 31}


def update_slice_range(*args):
    """Adapt the slice slider's range to the axis currently selected in ``dim_widget``.

    Any positional arguments (the change notification) are ignored.
    """
    max_index = _MAX_SLICE_INDEX[dim_widget.value]
    slice_widget.max = max_index
    # Keep the current slice if it is still valid; otherwise move it to the last index.
    slice_widget.value = min(max_index, slice_widget.value)


# Only react to changes of the selected axis, not to every trait change of the dropdown.
dim_widget.observe(update_slice_range, names='value')

# Most recently loaded model and its sensitivity maps. Re-plotting another
# feature/axis/slice of the same model then skips reloading from disk.
cached, cached_avg_sensitivity, cached_avg_abs_sensitivity = None, None, None

@interact_manual.options(manual_name="Show patterns")(trained_model=trained_model_widget, feature=feature_widget, dim=dim_widget, 
                                                      slice=slice_widget, num_patterns=num_patterns_widget, cols=cols_widget)
def show_patterns(trained_model, feature='t', dim='height', slice=16, abs_sensitivity=False, contour=False, num_patterns=25, cols=5):
    """Plot latent-space input-sensitivity patterns of a trained model.

    Args:
        trained_model: ``<model_name>/<train_name>`` path relative to ``results``
            (``None`` when no model is available).
        feature: Input feature, one of 't', 'u', 'v', 'w' (mapped to channel 0-3).
        dim: Axis to slice along, one of 'width', 'depth', 'height' (mapped to axis 0-2).
        slice: Slice index along ``dim`` (name must match the interact keyword above).
        abs_sensitivity: Plot the averaged absolute sensitivity instead of the averaged
            signed sensitivity.
        contour, num_patterns, cols: Forwarded to ``show_latent_patterns``.
    """
    global cached, cached_avg_sensitivity, cached_avg_abs_sensitivity

    if trained_model is None:
        print('Please select a trained model')
        return

    if cached == trained_model:
        avg_sensitivity, avg_abs_sensitivity = cached_avg_sensitivity, cached_avg_abs_sensitivity
    else:
        # os.path.split instead of str.split('/'): glob joins with os.sep (backslash on Windows).
        model_name, train_name = os.path.split(trained_model)
        sensitivity_dir = os.path.join('.', 'results', model_name, train_name)
        # The third return value is not needed here.
        avg_sensitivity, avg_abs_sensitivity, _ = load_latent_sensitivity(sensitivity_dir, 'latent_sensitivity')
        cached = trained_model
        cached_avg_sensitivity, cached_avg_abs_sensitivity = avg_sensitivity, avg_abs_sensitivity

    sensitivity_data = avg_abs_sensitivity if abs_sensitivity else avg_sensitivity
    channel = 'tuvw'.index(feature)
    axis = ['width', 'depth', 'height'].index(dim)
    show_latent_patterns(sensitivity_data, abs_sensitivity, num_patterns, channel, slice, axis, cols=cols, contour=contour)

interactive(children=(Dropdown(description='trained_model', options=('D4cnn/training2_lr1e-4',), value='D4cnn/…