In [1]:
import tensorflow as tf
from tensorflow import keras

import numpy as np

import rb_equivariant_cnn as conv
import rb_equivariant_gcnn as gconv
import rb_equivariant_se2ncnn as dn_conv

2024-08-05 07:42:17.608146: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


# Setup

In [2]:
# Grid / data dimensions of the Rayleigh-Bénard samples fed to every model below.
HORIZONTAL_SIZE = 64   # grid points along each of the two horizontal axes
HEIGHT = 32            # grid points along the vertical axis
RB_CHANNELS = 4        # input channels per grid point

# Fixed batch size declared on every model's InputLayer.
BATCH_SIZE = 64

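# Data Augmentation

Keras' image augmentation layers expect 4D `(batch, rows, cols, channels)` tensors. The wrappers below therefore temporarily merge the last two axes of each 5D sample (height and channels) into a single channel axis. They then rotate/flip the horizontal plane and restore the original shape.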

In [3]:
# Rotate and Flip Vectors
#
# Keras' RandomRotation / RandomFlip operate on 4D `(batch, rows, cols, channels)`
# images. The wrappers below fold the last two axes of the input (for the 5D
# `(batch, x, y, height, channels)` fields used in this notebook: height and
# channels) into one channel axis. They then apply the parent augmentation to the
# two leading spatial axes and restore the original shape afterwards.
#
# NOTE(review): only grid positions are transformed. If some channels hold
# horizontal vector components (e.g. velocity), their values are not rotated or
# sign-flipped here -- confirm whether that is intended.


def _fold_last_two_axes(inputs: tf.Tensor):
    """Merge the last two axes of `inputs` into a single trailing axis.

    Returns:
        (folded, in_shape): the reshaped tensor and the original dynamic shape,
        so the caller can undo the fold with `tf.reshape(..., in_shape)`.
    """
    in_shape = tf.shape(inputs)
    # Build the target shape with tf.concat. The former `in_shape[:-2] + [...]`
    # broadcast-ADDED the merged size to every leading dimension instead of
    # appending it, which produced an invalid target shape.
    folded_shape = tf.concat(
        [in_shape[:-2], tf.reduce_prod(in_shape[-2:], keepdims=True)], axis=0)
    return tf.reshape(inputs, folded_shape), in_shape


class RandomRot(keras.layers.RandomRotation):
    """`keras.layers.RandomRotation` applied to the two leading spatial axes of a field tensor."""

    # `training` is an explicit parameter on purpose. Keras 3 forwards the
    # training flag only to layers whose `call` signature names it. With just
    # `*args, **kwargs`, the parent's default (`training=True`) always applied,
    # so the augmentation also ran during evaluation/inference.
    def call(self, inputs: tf.Tensor, training=True, **kwargs) -> tf.Tensor:
        folded, in_shape = _fold_last_two_axes(inputs)
        outputs = super().call(folded, training=training, **kwargs)
        return tf.reshape(outputs, in_shape)


class RandomFlip(keras.layers.RandomFlip):
    """`keras.layers.RandomFlip` applied to the two leading spatial axes of a field tensor."""

    # See RandomRot.call: `training` must be named so Keras forwards the flag.
    def call(self, inputs: tf.Tensor, training=True, **kwargs) -> tf.Tensor:
        folded, in_shape = _fold_last_two_axes(inputs)
        outputs = super().call(folded, training=training, **kwargs)
        return tf.reshape(outputs, in_shape)

# 3D Rayleigh-Bénard Convolution
- Equivariant to horizontal translations
- __No vertical parameter sharing__
- Height-dependent bias
- Supports horizontal wrap and same padding
    - Wrap makes sense when using periodic boundary conditions for Rayleigh-Bénard
    - Caution: this may break exact rotation equivariance in our experiments (nevertheless, WRAP is preferable in practice)
- Also supports vertical same padding
- Supports stride (including vertical stride)
- Uses 2D convolutions under the hood

In [4]:
def _rb_conv_bn_relu(index):
    """Return the `Conv{index}` -> `BatchNorm{index}` -> `NonLinearity{index}` stack.

    The convolution uses strides=(2,2,2), horizontal WRAP and vertical SAME padding.
    """
    return [
        conv.RB3D_Conv(h_ksize=3, v_ksize=5, channels=RB_CHANNELS,
                       h_padding='WRAP', v_padding='SAME', strides=(2,2,2),
                       name=f'Conv{index}'),
        conv.BatchNorm(name=f'BatchNorm{index}'),
        keras.layers.Activation('relu', name=f'NonLinearity{index}'),
    ]


model = keras.Sequential([
    keras.layers.InputLayer(shape=(HORIZONTAL_SIZE, HORIZONTAL_SIZE, HEIGHT, RB_CHANNELS),
                            batch_size=BATCH_SIZE),

    # Data augmentation of the horizontal plane.
    # NOTE(review): per the Keras docs `factor` is a fraction of 2*pi, so factor=1
    # samples arbitrary angles -- not only the 90° rotations of a square periodic domain.
    RandomRot(factor=1, fill_mode='wrap', value_range=(0,1)),
    RandomFlip(mode='horizontal_and_vertical'),

    *_rb_conv_bn_relu(1),

    keras.layers.Dropout(rate=0.2),
    *_rb_conv_bn_relu(2),

    conv.SpatialPooling(ksize=(2,2,2), strides=(2,2,2), pooling_type='MAX'),

    keras.layers.Dropout(rate=0.2),
    *_rb_conv_bn_relu(3),

    keras.layers.Dropout(rate=0.2),
    *_rb_conv_bn_relu(4),
])

# output shape: batch_size, width, depth, height, channels
model.summary()

# 3D Rayleigh-Bénard $D_4$ Group Equivariant Convolution
- Equivariant to all symmetries of 3D Rayleigh-Bénard:
    - __90° rotations around a vertical axis__
    - __reflections through a vertical plane__
    - __horizontal translations__

In [5]:
# Symmetry group of the equivariant layers:
#   'C4' -> rotations, 'D4' -> rotations and reflections
G = 'D4'

# Settings shared by all four group convolutions (strides=(2, 2, 2) each).
_g_cnn_kwargs = dict(h_ksize=3, v_ksize=5, channels=RB_CHANNELS,
                     h_padding='WRAP', v_padding='SAME', strides=(2, 2, 2))

model = keras.Sequential([
    keras.layers.InputLayer(shape=(HORIZONTAL_SIZE, HORIZONTAL_SIZE, HEIGHT, RB_CHANNELS),
                            batch_size=BATCH_SIZE),
    # Insert a singleton transformation axis: (x, y, 1, height, channels).
    keras.layers.Reshape((HORIZONTAL_SIZE, HORIZONTAL_SIZE, 1, HEIGHT, RB_CHANNELS)),

    # Lifting convolution Z2 -> G ...
    gconv.RB3D_G_Conv('Z2', G, name=f'Lift_{G}_Conv1', **_g_cnn_kwargs),
    # ... followed by three G -> G convolutions.
    *(gconv.RB3D_G_Conv(G, G, name=f'{G}_Conv{i}', **_g_cnn_kwargs) for i in range(2, 5)),
])

# output shape: batch_size, width, depth, transformations, height, channels
model.summary()

# 3D Rayleigh-Bénard $D_N$ Group Equivariant Convolution
- Equivariant to all symmetries of 3D Rayleigh-Bénard:
    - __arbitrary discrete__ rotations around a vertical axis
    - reflections through a vertical plane
    - horizontal translations

In [6]:
ORIENTATIONS = 8

# Settings shared by all four D_N convolutions (h_ksize=5, v_ksize=5, strides=(2, 2, 2)).
_dn_cnn_kwargs = dict(h_ksize=5, v_ksize=5, channels=RB_CHANNELS,
                      h_padding='WRAP', v_padding='SAME', strides=(2, 2, 2))

model = keras.Sequential([
    keras.layers.InputLayer(shape=(HORIZONTAL_SIZE, HORIZONTAL_SIZE, HEIGHT, RB_CHANNELS),
                            batch_size=BATCH_SIZE),

    # Lifting convolution with `orientations=ORIENTATIONS` ...
    dn_conv.RB3D_LiftDN_Conv(orientations=ORIENTATIONS, name='Lift_DN_Conv1', **_dn_cnn_kwargs),
    # ... followed by three D_N convolutions.
    *[dn_conv.RB3D_DN_Conv(name=f'DN_Conv{i}', **_dn_cnn_kwargs) for i in (2, 3, 4)],
])

# output shape: batch_size, width, depth, transformations, height, channels
model.summary()

# Autoencoder

### Convolutional Autoencoder

In [7]:
# Keyword arguments shared by every (unstrided) convolution of the autoencoder.
_ae_conv_kwargs = dict(h_ksize=3, v_ksize=5, channels=RB_CHANNELS,
                       h_padding='WRAP', v_padding='SAME', strides=(1,1,1))


def _ae_encoder_stage(i):
    """Convolution `En_Conv{i}` followed by 2x2x2 max-pooling `Pool{i}`."""
    return [
        conv.RB3D_Conv(name=f'En_Conv{i}', **_ae_conv_kwargs),
        conv.SpatialPooling(ksize=(2,2,2), pooling_type='MAX', strides=(2,2,2), padding='VALID',
                            name=f'Pool{i}'),
    ]


def _ae_decoder_stage(i):
    """2x2x2 upsampling `UpSampling{i}` followed by convolution `De_Conv{i}`."""
    return [
        conv.UpSampling(size=(2,2,2), name=f'UpSampling{i}'),
        conv.RB3D_Conv(name=f'De_Conv{i}', **_ae_conv_kwargs),
    ]


model = keras.Sequential([
    keras.layers.InputLayer(shape=(HORIZONTAL_SIZE, HORIZONTAL_SIZE, HEIGHT, RB_CHANNELS),
                            batch_size=BATCH_SIZE),

    # Encoder: 4 x (conv -> pool)
    *(layer for i in range(1, 5) for layer in _ae_encoder_stage(i)),

    # Decoder: 4 x (upsample -> conv)
    *(layer for i in range(1, 5) for layer in _ae_decoder_stage(i)),
])

# output shape: batch_size, width, depth, height, channels
model.summary()

### $D_4$ Group Equivariant Convolutional Autoencoder

In [8]:
# This section is about D4 specifically, so declare the group here instead of
# relying on `G` having been set by the earlier D4 G-CNN cell. Previously this
# cell raised NameError when run without that cell.
G = 'D4' # 'C4' for rotations or 'D4' for rotations and reflections

# Keyword arguments shared by every group convolution below (unstrided: the
# spatial resolution changes only in the pooling / upsampling layers).
_d4_ae_conv_kwargs = dict(h_ksize=3, v_ksize=5, channels=RB_CHANNELS,
                          h_padding='WRAP', v_padding='SAME', strides=(1,1,1))

model = keras.Sequential([
            keras.layers.InputLayer(shape=(HORIZONTAL_SIZE, HORIZONTAL_SIZE, HEIGHT, RB_CHANNELS),
                                    batch_size=BATCH_SIZE),
            
            # add transformation dimension
            keras.layers.Reshape((HORIZONTAL_SIZE, HORIZONTAL_SIZE, 1, HEIGHT, RB_CHANNELS)), 
            
            ###############
            #   Encoder   #
            ###############
            gconv.RB3D_G_Conv('Z2', G, name=f'En_Lift_{G}_Conv1', **_d4_ae_conv_kwargs),
            gconv.SpatialPooling(ksize=(2,2,2), pooling_type='MAX', strides=(2,2,2), padding='VALID', name='SpatialPool1'),
            gconv.RB3D_G_Conv(G, G, name=f'En_{G}-Conv2', **_d4_ae_conv_kwargs),
            gconv.SpatialPooling(ksize=(2,2,2), pooling_type='MAX', strides=(2,2,2), padding='VALID', name='SpatialPool2'),
            gconv.RB3D_G_Conv(G, G, name=f'En_{G}-Conv3', **_d4_ae_conv_kwargs),
            gconv.SpatialPooling(ksize=(2,2,2), pooling_type='MAX', strides=(2,2,2), padding='VALID', name='SpatialPool3'),
            gconv.RB3D_G_Conv(G, G, name=f'En_{G}-Conv4', **_d4_ae_conv_kwargs),
            gconv.SpatialPooling(ksize=(2,2,2), pooling_type='MAX', strides=(2,2,2), padding='VALID', name='SpatialPool4'),
            
            ###############
            #   Decoder   #
            ###############
            # NOTE(review): unlike the plain autoencoder above (upsample -> conv),
            # each stage here is conv -> upsample, so the final UpSampling output is
            # not refined by any convolution -- confirm this ordering is intended.
            gconv.RB3D_G_Conv(G, G, name=f'De_{G}_Conv1', **_d4_ae_conv_kwargs),
            gconv.UpSampling(size=(2,2,2), name='UpSampling1'),
            gconv.RB3D_G_Conv(G, G, name=f'De_{G}-Conv2', **_d4_ae_conv_kwargs),
            gconv.UpSampling(size=(2,2,2), name='UpSampling2'),
            gconv.RB3D_G_Conv(G, G, name=f'De_{G}-Conv3', **_d4_ae_conv_kwargs),
            gconv.UpSampling(size=(2,2,2), name='UpSampling3'),
            gconv.RB3D_G_Conv(G, G, name=f'De_{G}-Conv4', **_d4_ae_conv_kwargs),
            gconv.UpSampling(size=(2,2,2), name='UpSampling4'),
        ])

# output shape: batch_size, width, depth, transformations, height, channels
# (the group/transformations axis added by the Reshape above is never removed)
model.summary()

### $D_N$ Group Equivariant Convolutional Autoencoder

In [9]:
ORIENTATIONS = 8

# Keyword arguments shared by every (unstrided) D_N convolution of the autoencoder.
_dn_ae_kwargs = dict(h_ksize=3, v_ksize=5, channels=RB_CHANNELS,
                     h_padding='WRAP', v_padding='SAME', strides=(1,1,1))

# Convolution layer names, reproduced exactly (including the mixed '_' / '-' separators).
_dn_en_names = ['En_Lift_DN_Conv1', 'En_DN-Conv2', 'En_DN-Conv3', 'En_DN-Conv4']
_dn_de_names = ['De_DN_Conv1', 'De_DN-Conv2', 'De_DN-Conv3', 'De_DN-Conv4']

_dn_ae_layers = [
    keras.layers.InputLayer(shape=(HORIZONTAL_SIZE, HORIZONTAL_SIZE, HEIGHT, RB_CHANNELS),
                            batch_size=BATCH_SIZE),
]

# Encoder: 4 x (conv -> 2x2x2 max-pool); the first conv is the lifting layer.
for stage, conv_name in enumerate(_dn_en_names, start=1):
    if stage == 1:
        _dn_ae_layers.append(
            dn_conv.RB3D_LiftDN_Conv(orientations=ORIENTATIONS, name=conv_name, **_dn_ae_kwargs))
    else:
        _dn_ae_layers.append(dn_conv.RB3D_DN_Conv(name=conv_name, **_dn_ae_kwargs))
    _dn_ae_layers.append(
        dn_conv.SpatialPooling(ksize=(2,2,2), pooling_type='MAX', strides=(2,2,2), padding='VALID',
                               name=f'SpatialPool{stage}'))

# Decoder: 4 x (conv -> 2x2x2 upsampling).
for stage, conv_name in enumerate(_dn_de_names, start=1):
    _dn_ae_layers.append(dn_conv.RB3D_DN_Conv(name=conv_name, **_dn_ae_kwargs))
    _dn_ae_layers.append(dn_conv.UpSampling(size=(2,2,2), name=f'UpSampling{stage}'))

model = keras.Sequential(_dn_ae_layers)

# output shape: batch_size, width, depth, transformations, height, channels
model.summary()