TODOs:
- Data augmentation (incl. rotating the velocity vectors)

In [1]:
import tensorflow as tf
from tensorflow import keras

import numpy as np
import math

import h5py
import os

import rb_equivariant_cnn as conv
import rb_equivariant_gcnn as gconv
import rb_equivariant_se2ncnn as dn_conv

# Let TensorFlow grow GPU memory on demand instead of reserving it up front. TF reads this
# variable when it creates the GPU allocator, so setting it after `import tensorflow` still
# takes effect as long as no GPU has been initialized yet.
os.environ.update(TF_FORCE_GPU_ALLOW_GROWTH='true')

2024-08-15 02:50:30.650484: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-08-15 02:50:30.682059: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
gpus = tf.config.list_physical_devices('GPU')
if gpus:
  # Restrict TensorFlow to only allocate 1GB of memory on the first GPU
  try:
    tf.config.set_logical_device_configuration(
        gpus[0],
        [tf.config.LogicalDeviceConfiguration(memory_limit=16384)])
    logical_gpus = tf.config.list_logical_devices('GPU')
    print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
  except RuntimeError as e:
    # Virtual devices must be set before GPUs have been initialized
    print(e)

1 Physical GPUs, 1 Logical GPUs


2024-08-15 02:50:32.449992: W tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc:47] Overriding orig_value setting because the TF_FORCE_GPU_ALLOW_GROWTH environment variable is set. Original config value was 0.
2024-08-15 02:50:32.451611: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1928] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 16384 MB memory:  -> device: 0, name: NVIDIA A100 80GB PCIe, pci bus id: 0000:01:00.0, compute capability: 8.0


# Setup

In [3]:
# Physical fields stored per grid point. Presumably temperature plus three velocity
# components (the TODO at the top mentions velocities) — TODO confirm.
RB_CHANNELS = 4
# Grid: HORIZONTAL_SIZE x HORIZONTAL_SIZE horizontal points by HEIGHT vertical levels.
HORIZONTAL_SIZE, HEIGHT = 48, 32

BATCH_SIZE = 64

# Base name (without extension) of the HDF5 snapshot file in data/.
SIMULATION_NAME = '48_48_32_2000_0.71_0.01_0.3_1000.2'

# Data

In [4]:
# HDF5 file holding the 'train' and 'test' entries; each stores its sample count in attrs['N'].
sim_file = os.path.join('data', SIMULATION_NAME + '.h5')

with h5py.File(sim_file, 'r') as hf:
    N_TRAIN, N_TEST = (hf[split].attrs['N'] for split in ('train', 'test'))

# Number of BATCH_SIZE batches needed to cover each split once (rounded up).
TRAIN_BATCHES, TEST_BATCHES = (math.ceil(n / BATCH_SIZE) for n in (N_TRAIN, N_TEST))

class generator:
    """Callable that streams the snapshots of one HDF5 entry forever as (input, target) pairs.

    Intended for ``tf.data.Dataset.from_generator``: every call opens the file anew and
    cycles over ``hf[dataset]`` indefinitely, yielding ``(snap, snap)`` because the models
    trained on it are autoencoders whose target equals their input.

    Args:
        filename: Path to the HDF5 file.
        dataset: Name of the entry inside the file to iterate (e.g. 'train' or 'test').

    Raises:
        ValueError: (on iteration) if the entry yields no snapshots at all.
    """

    def __init__(self, filename, dataset):
        self.filename = filename
        self.dataset = dataset

    def __call__(self):
        with h5py.File(self.filename, 'r') as hf:
            snapshots = hf[self.dataset]
            while True:
                yielded_any = False
                for snap in snapshots:
                    yielded_any = True
                    yield snap, snap
                # Without this guard an empty entry makes `while True` spin without ever
                # yielding, which silently hangs tf.data and therefore `fit`.
                if not yielded_any:
                    raise ValueError(
                        f"HDF5 entry '{self.dataset}' in '{self.filename}' contains no snapshots")

def _snapshot_dataset(split):
    """Build the batched tf.data pipeline for one HDF5 split ('train' or 'test').

    Each element is an (input, target) pair of identical float64 snapshots with shape
    (HORIZONTAL_SIZE, HORIZONTAL_SIZE, HEIGHT, RB_CHANNELS), batched to BATCH_SIZE.
    The underlying generator never ends, so every batch is full (the models declare a fixed
    batch size) and `fit` must be given steps_per_epoch / validation_steps.
    NOTE(review): batches straddle passes over the data when N is not a multiple of
    BATCH_SIZE, so consecutive "epochs" see slightly shifted sample sets.
    """
    snapshot_spec = tf.TensorSpec(shape=(HORIZONTAL_SIZE, HORIZONTAL_SIZE, HEIGHT, RB_CHANNELS),
                                  dtype=tf.float64)
    dataset = tf.data.Dataset.from_generator(
        generator(sim_file, split),
        output_signature=(snapshot_spec, snapshot_spec))
    # TODO: shuffling is currently disabled; to enable it, apply
    #       dataset = dataset.shuffle(10, reshuffle_each_iteration=True) before batching.
    return dataset.batch(BATCH_SIZE, drop_remainder=False)


train_dataset = _snapshot_dataset('train')
test_dataset = _snapshot_dataset('test')

# Data Augmentation

In [4]:

# TODO: Rotate and flip the vector components as well. These layers only move grid points;
#       presumably the velocity components stored in the channels keep their orientation.
def _merge_last_two_axes(inputs):
    """Fold the trailing (height, channels) axes of `inputs` into a single axis.

    Keras' 2D image augmentation layers expect (batch, rows, cols, channels). Folding
    (batch, width, depth, height, channels) into (batch, width, depth, height * channels)
    lets them transform only the two horizontal axes while every vertical level and field
    rides along as a channel.

    Returns:
        (merged, original_shape): the reshaped tensor and the dynamic shape that undoes it.
    """
    original_shape = tf.shape(inputs)
    # Build the new shape with tensor ops: `shape[:-2] + [...]` on a shape *tensor* is an
    # element-wise addition, not a list concatenation, and np.prod fails on symbolic shapes.
    merged_shape = tf.concat(
        [original_shape[:-2], tf.reduce_prod(original_shape[-2:], keepdims=True)], axis=0)
    return tf.reshape(inputs, merged_shape), original_shape


class RandomRot(keras.layers.RandomRotation):
    """`keras.layers.RandomRotation` applied to 5D RB snapshots (rotates the horizontal plane)."""

    def call(self, inputs: tf.Tensor, *args, **kwargs) -> tf.Tensor:
        merged, original_shape = _merge_last_two_axes(inputs)
        outputs = super().call(merged, *args, **kwargs)
        return tf.reshape(outputs, original_shape)


class RandomFlip(keras.layers.RandomFlip):
    """`keras.layers.RandomFlip` applied to 5D RB snapshots (flips the horizontal axes)."""

    def call(self, inputs: tf.Tensor, *args, **kwargs) -> tf.Tensor:
        merged, original_shape = _merge_last_two_axes(inputs)
        outputs = super().call(merged, *args, **kwargs)
        return tf.reshape(outputs, original_shape)

# 3D Rayleigh-Bénard Convolution
- Equivariant to horizontal translations
- __No vertical parameter sharing__
- Height-dependent bias
- Supports horizontal wrap and same padding
    - Wrap makes sense when using periodic boundary conditions for Rayleigh-Bénard
    - Caution: this may destroy exact rotation equivariance in our experiments (nevertheless, WRAP will be preferable in practice)
- Also supports vertical same padding
- Supports stride (including vertical stride)
- Uses 2D convolutions under the hood

In [6]:
# Shape demo for the height-dependent RB3D convolution (conv.RB3D_Conv): random horizontal
# augmentation, then four stride-2 conv -> BatchNorm -> ReLU stages with one 2x2x2 max-pooling
# after the second stage. This cell only builds the model and prints its summary.
model = keras.Sequential([
            keras.layers.InputLayer(shape=(HORIZONTAL_SIZE, HORIZONTAL_SIZE, HEIGHT, RB_CHANNELS),
                                    batch_size=BATCH_SIZE),
            
            # Data Augmentation
            # NOTE(review): Keras RandomRotation's `factor` is a fraction of 2*pi, so factor=1
            # draws arbitrary angles in [-2pi, 2pi] (interpolated), not only the 90-degree
            # rotations of the square grid. Vector channels are not rotated/flipped (see the
            # TODO above RandomRot), and if this model is trained as an autoencoder the
            # targets are not augmented together with the inputs.
            RandomRot(factor=1, fill_mode='wrap', value_range=(0,1)),
            RandomFlip(mode='horizontal_and_vertical'),
            
            conv.RB3D_Conv(h_ksize=3, v_ksize=5, channels=RB_CHANNELS, h_padding='WRAP', v_padding='SAME', strides=(2,2,2), name='Conv1'),
            conv.BatchNorm(name='BatchNorm1'),
            keras.layers.Activation('relu', name='NonLinearity1'),
            
            keras.layers.Dropout(rate=0.2),
            conv.RB3D_Conv(h_ksize=3, v_ksize=5, channels=RB_CHANNELS, h_padding='WRAP', v_padding='SAME', strides=(2,2,2), name='Conv2'),
            conv.BatchNorm(name='BatchNorm2'),
            keras.layers.Activation('relu', name='NonLinearity2'),
            
            # the only pooling layer in this stack
            conv.SpatialPooling(ksize=(2,2,2), strides=(2,2,2), pooling_type='MAX'),
            
            keras.layers.Dropout(rate=0.2),
            conv.RB3D_Conv(h_ksize=3, v_ksize=5, channels=RB_CHANNELS, h_padding='WRAP', v_padding='SAME', strides=(2,2,2), name='Conv3'),
            conv.BatchNorm(name='BatchNorm3'),
            keras.layers.Activation('relu', name='NonLinearity3'),
            
            keras.layers.Dropout(rate=0.2),
            conv.RB3D_Conv(h_ksize=3, v_ksize=5, channels=RB_CHANNELS, h_padding='WRAP', v_padding='SAME', strides=(2,2,2), name='Conv4'),
            conv.BatchNorm(name='BatchNorm4'),
            keras.layers.Activation('relu', name='NonLinearity4'),
        ])

# output shape: batch_size, width, depth, height, channels
model.summary()

# 3D Rayleigh-Bénard $D_4$ Group Equivariant Convolution
- Equivariant to all symmetries of 3D Rayleigh-Bénard:
    - __90° rotations around a vertical axis__
    - __reflections through a vertical plane__
    - __horizontal translations__

In [7]:
G = 'D4' # 'C4' for rotations or 'D4' for rotations and reflections

# Shape demo of the group-equivariant stack: one lifting convolution Z2 -> G followed by
# three G -> G convolutions, all with stride 2 along every spatial axis.
g_conv_kwargs = dict(h_ksize=3, v_ksize=5, channels=RB_CHANNELS, h_padding='WRAP',
                     v_padding='SAME', strides=(2, 2, 2))

g_layers = [
    keras.layers.InputLayer(shape=(HORIZONTAL_SIZE, HORIZONTAL_SIZE, HEIGHT, RB_CHANNELS),
                            batch_size=BATCH_SIZE),
    # insert a singleton transformation axis (before HEIGHT) for the lifting convolution
    keras.layers.Reshape((HORIZONTAL_SIZE, HORIZONTAL_SIZE, 1, HEIGHT, RB_CHANNELS)),
    gconv.RB3D_G_Conv('Z2', G, name=f'Lift_{G}_Conv1', **g_conv_kwargs),
]
for conv_idx in range(2, 5):
    g_layers.append(gconv.RB3D_G_Conv(G, G, name=f'{G}_Conv{conv_idx}', **g_conv_kwargs))

model = keras.Sequential(g_layers)

# output shape: batch_size, width, depth, transformations, height, channels
model.summary()

# 3D Rayleigh-Bénard $D_N$ Group Equivariant Convolution
- Equivariant to all symmetries of 3D Rayleigh-Bénard:
    - __arbitrary discrete__ rotations around a vertical axis
    - reflections through a vertical plane
    - horizontal translations

In [8]:
ORIENTATIONS = 8

# Shape demo of the D_N stack: a lifting convolution with ORIENTATIONS discrete rotations,
# then three D_N convolutions; every layer uses 5x5 horizontal / 5 vertical kernels, stride 2.
dn_conv_kwargs = dict(h_ksize=5, v_ksize=5, channels=RB_CHANNELS, h_padding='WRAP',
                      v_padding='SAME', strides=(2, 2, 2))

dn_layers = [
    keras.layers.InputLayer(shape=(HORIZONTAL_SIZE, HORIZONTAL_SIZE, HEIGHT, RB_CHANNELS),
                            batch_size=BATCH_SIZE),
    dn_conv.RB3D_LiftDN_Conv(orientations=ORIENTATIONS, name='Lift_DN_Conv1', **dn_conv_kwargs),
]
dn_layers.extend(
    dn_conv.RB3D_DN_Conv(name=f'DN_Conv{conv_idx}', **dn_conv_kwargs) for conv_idx in (2, 3, 4)
)

model = keras.Sequential(dn_layers)

# output shape: batch_size, width, depth, transformations, height, channels
model.summary()

# Autoencoder

#### Standard 3D Convolutional Autoencoder

In [5]:
def _ae_conv():
    """Return a fresh 3x3x5 'same'-padded Conv3D with RB_CHANNELS filters and an L2(0.01) kernel penalty."""
    return keras.layers.Conv3D(kernel_size=(3, 3, 5), filters=RB_CHANNELS, padding='same',
                               kernel_regularizer=keras.regularizers.L2(0.01))


def _ae_encoder_stage(with_dropout):
    """Conv3D -> BatchNorm -> ReLU -> 2x2x2 max-pool, optionally followed by Dropout(0.2)."""
    stage = [_ae_conv(),
             keras.layers.BatchNormalization(),
             keras.layers.Activation('relu'),
             keras.layers.MaxPooling3D(pool_size=(2, 2, 2))]
    if with_dropout:
        stage.append(keras.layers.Dropout(rate=0.2))
    return stage


def _ae_decoder_stage(with_activation):
    """2x2x2 upsampling -> Dropout(0.2) -> Conv3D, optionally followed by BatchNorm -> ReLU."""
    stage = [keras.layers.UpSampling3D(size=(2, 2, 2)),
             keras.layers.Dropout(rate=0.2),
             _ae_conv()]
    if with_activation:
        stage += [keras.layers.BatchNormalization(), keras.layers.Activation('relu')]
    return stage


# Each encoder stage halves and each decoder stage doubles every spatial axis.
N_AE_STAGES = 4

ae_layers = [keras.layers.InputLayer(shape=(HORIZONTAL_SIZE, HORIZONTAL_SIZE, HEIGHT, RB_CHANNELS),
                                     batch_size=BATCH_SIZE)]
# Encoder: the bottleneck stage has no trailing Dropout.
for stage_idx in range(N_AE_STAGES):
    ae_layers += _ae_encoder_stage(with_dropout=stage_idx < N_AE_STAGES - 1)
# Decoder: the output stage ends with a linear Conv3D (no BatchNorm/ReLU).
for stage_idx in range(N_AE_STAGES):
    ae_layers += _ae_decoder_stage(with_activation=stage_idx < N_AE_STAGES - 1)

ae = keras.Sequential(ae_layers)

# output shape: batch_size, width, depth, height, channels
ae.summary()

In [6]:
# Mean-squared reconstruction error. Pass a Loss *instance*: Keras 3 happens to instantiate a
# bare class, but tf.keras 2.x passes callables through unchanged and would invoke the class
# itself as loss(y_true, y_pred).
ae.compile(
    loss=tf.keras.losses.MeanSquaredError(),
    optimizer=keras.optimizers.Adam(learning_rate=0.0001),
)

# The tf.data pipelines repeat forever, so epoch lengths must be given explicitly.
hist = ae.fit(train_dataset, steps_per_epoch=TRAIN_BATCHES, validation_data=test_dataset,
              validation_steps=TEST_BATCHES, epochs=100)

Epoch 1/100


I0000 00:00:1723677235.303914   88675 service.cc:145] XLA service 0x7f0258009910 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1723677235.303989   88675 service.cc:153]   StreamExecutor device (0): NVIDIA A100 80GB PCIe, Compute Capability 8.0
2024-08-15 01:13:55.396663: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
2024-08-15 01:13:55.598984: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:465] Loaded cuDNN version 8907


[1m 1/42[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m7:10[0m 10s/step - loss: 0.5997

I0000 00:00:1723677244.064649   88675 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m23s[0m 300ms/step - loss: 0.3453 - val_loss: 0.5966
Epoch 2/100
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 314ms/step - loss: 0.3307 - val_loss: 0.5636
Epoch 3/100
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 295ms/step - loss: 0.2960 - val_loss: 0.4639
Epoch 4/100
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 262ms/step - loss: 0.2251 - val_loss: 0.4004
Epoch 5/100
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 293ms/step - loss: 0.1703 - val_loss: 0.4046
Epoch 6/100
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 282ms/step - loss: 0.1337 - val_loss: 0.3972
Epoch 7/100
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 279ms/step - loss: 0.1104 - val_loss: 0.3948
Epoch 8/100
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 278ms/step - loss: 0.0973 - val_loss: 0.4000
Epoch 9/100
[1m 2/42[0m [37m━━━━━

KeyboardInterrupt: 

#### Height-dependent Convolutional Autoencoder

In [8]:
def _h_ae_conv():
    """Return a fresh unstrided RB3D_Conv (3 horizontal / 5 vertical kernel, WRAP/SAME padding, L2(0.01))."""
    return conv.RB3D_Conv(h_ksize=3, v_ksize=5, channels=RB_CHANNELS, h_padding='WRAP', v_padding='SAME',
                          strides=(1, 1, 1), filter_regularizer=keras.regularizers.L2(0.01))


def _h_ae_encoder_stage(with_dropout):
    """RB3D_Conv -> BatchNorm -> ReLU -> 2x2x2 max-pool, optionally followed by Dropout(0.2)."""
    stage = [_h_ae_conv(),
             conv.BatchNorm(),
             keras.layers.Activation('relu'),
             conv.SpatialPooling(ksize=(2, 2, 2), pooling_type='MAX', strides=(2, 2, 2), padding='VALID')]
    if with_dropout:
        stage.append(keras.layers.Dropout(rate=0.2))
    return stage


def _h_ae_decoder_stage(with_activation):
    """2x2x2 upsampling -> Dropout(0.2) -> RB3D_Conv, optionally followed by BatchNorm -> ReLU."""
    stage = [conv.UpSampling(size=(2, 2, 2)),
             keras.layers.Dropout(rate=0.2),
             _h_ae_conv()]
    if with_activation:
        stage += [conv.BatchNorm(), keras.layers.Activation('relu')]
    return stage


# Each encoder stage halves and each decoder stage doubles every spatial axis.
H_AE_STAGES = 4

h_ae_layers = [keras.layers.InputLayer(shape=(HORIZONTAL_SIZE, HORIZONTAL_SIZE, HEIGHT, RB_CHANNELS),
                                       batch_size=BATCH_SIZE)]
# Encoder: the bottleneck stage has no trailing Dropout.
for stage_idx in range(H_AE_STAGES):
    h_ae_layers += _h_ae_encoder_stage(with_dropout=stage_idx < H_AE_STAGES - 1)
# Decoder: the output stage ends with a linear RB3D_Conv (no BatchNorm/ReLU).
for stage_idx in range(H_AE_STAGES):
    h_ae_layers += _h_ae_decoder_stage(with_activation=stage_idx < H_AE_STAGES - 1)

h_ae = keras.Sequential(h_ae_layers)

# output shape: batch_size, width, depth, height, channels
h_ae.summary()

In [9]:
# Mean-squared reconstruction error. Pass a Loss *instance*: Keras 3 happens to instantiate a
# bare class, but tf.keras 2.x passes callables through unchanged and would invoke the class
# itself as loss(y_true, y_pred).
h_ae.compile(
    loss=tf.keras.losses.MeanSquaredError(),
    optimizer=keras.optimizers.Adam(learning_rate=0.0001),
)

# The tf.data pipelines repeat forever, so epoch lengths must be given explicitly.
hist = h_ae.fit(train_dataset, steps_per_epoch=TRAIN_BATCHES, validation_data=test_dataset,
                validation_steps=TEST_BATCHES, epochs=100)

Epoch 1/100


I0000 00:00:1723679090.605354   98967 service.cc:145] XLA service 0x7f355c00b260 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1723679090.605423   98967 service.cc:153]   StreamExecutor device (0): NVIDIA A100 80GB PCIe, Compute Capability 8.0
2024-08-15 01:44:50.784806: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
2024-08-15 01:44:51.308987: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:465] Loaded cuDNN version 8907


[1m 2/42[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m2s[0m 66ms/step - loss: 0.7101  

I0000 00:00:1723679103.718369   98967 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


[1m34/42[0m [32m━━━━━━━━━━━━━━━━[0m[37m━━━━[0m [1m1s[0m 236ms/step - loss: 0.4756

KeyboardInterrupt: 

### $D_4$ Group Equivariant Convolutional Autoencoder

In [5]:
# Group-equivariant (D4, or C4 via G) convolutional autoencoder built from gconv layers:
#   encoder: lifting conv Z2 -> G, then three G -> G convs, each followed by BatchNorm, ReLU
#            and 2x2x2 max-pooling, with Dropout(0.2) between stages;
#   decoder: four (2x2x2 upsampling, Dropout, G -> G conv) stages, the last one linear;
#   output:  TransformationPooling(tf.reduce_mean, keepdims=False) removes the
#            transformation axis so the reconstruction has the input's shape.
G = 'D4' # 'C4' for rotations or 'D4' for rotations and reflections

d4_ae = keras.Sequential([
            keras.layers.InputLayer(shape=(HORIZONTAL_SIZE, HORIZONTAL_SIZE, HEIGHT, RB_CHANNELS),
                                    batch_size=BATCH_SIZE),
            
            # add transformation dimension
            keras.layers.Reshape((HORIZONTAL_SIZE, HORIZONTAL_SIZE, 1, HEIGHT, RB_CHANNELS)), 
            
            ###############
            #   Encoder   #
            ###############
            # lifting convolution: 'Z2' input (singleton transformation axis) -> G feature maps
            gconv.RB3D_G_Conv('Z2', G, h_ksize=3, v_ksize=5, channels=RB_CHANNELS, h_padding='WRAP', v_padding='SAME', strides=(1,1,1),
                              filter_regularizer=keras.regularizers.L2(0.01)),
            gconv.BatchNorm(),
            keras.layers.Activation('relu'),
            
            gconv.SpatialPooling(ksize=(2,2,2), pooling_type='MAX', strides=(2,2,2), padding='VALID'),
            keras.layers.Dropout(rate=0.2),
            gconv.RB3D_G_Conv(G, G, h_ksize=3, v_ksize=5, channels=RB_CHANNELS, h_padding='WRAP', v_padding='SAME', strides=(1,1,1),
                              filter_regularizer=keras.regularizers.L2(0.01)),
            gconv.BatchNorm(),
            keras.layers.Activation('relu'),
            
            gconv.SpatialPooling(ksize=(2,2,2), pooling_type='MAX', strides=(2,2,2), padding='VALID'),
            keras.layers.Dropout(rate=0.2),
            gconv.RB3D_G_Conv(G, G, h_ksize=3, v_ksize=5, channels=RB_CHANNELS, h_padding='WRAP', v_padding='SAME', strides=(1,1,1),
                              filter_regularizer=keras.regularizers.L2(0.01)),
            gconv.BatchNorm(),
            keras.layers.Activation('relu'),
            
            gconv.SpatialPooling(ksize=(2,2,2), pooling_type='MAX', strides=(2,2,2), padding='VALID'),
            keras.layers.Dropout(rate=0.2),
            gconv.RB3D_G_Conv(G, G, h_ksize=3, v_ksize=5, channels=RB_CHANNELS, h_padding='WRAP', v_padding='SAME', strides=(1,1,1),
                              filter_regularizer=keras.regularizers.L2(0.01)),
            gconv.BatchNorm(),
            keras.layers.Activation('relu'),
            
            # bottleneck: no Dropout after the last pooling
            gconv.SpatialPooling(ksize=(2,2,2), pooling_type='MAX', strides=(2,2,2), padding='VALID'),
            
            ###############
            #   Decoder   #
            ###############
            gconv.UpSampling(size=(2,2,2)),
            keras.layers.Dropout(rate=0.2),
            gconv.RB3D_G_Conv(G, G, h_ksize=3, v_ksize=5, channels=RB_CHANNELS, h_padding='WRAP', v_padding='SAME', strides=(1,1,1),
                              filter_regularizer=keras.regularizers.L2(0.01)),
            gconv.BatchNorm(),
            keras.layers.Activation('relu'),
            
            gconv.UpSampling(size=(2,2,2)),
            keras.layers.Dropout(rate=0.2),
            gconv.RB3D_G_Conv(G, G, h_ksize=3, v_ksize=5, channels=RB_CHANNELS, h_padding='WRAP', v_padding='SAME', strides=(1,1,1),
                              filter_regularizer=keras.regularizers.L2(0.01)),
            gconv.BatchNorm(),
            keras.layers.Activation('relu'),
            
            gconv.UpSampling(size=(2,2,2)),
            keras.layers.Dropout(rate=0.2),
            gconv.RB3D_G_Conv(G, G, h_ksize=3, v_ksize=5, channels=RB_CHANNELS, h_padding='WRAP', v_padding='SAME', strides=(1,1,1),
                              filter_regularizer=keras.regularizers.L2(0.01)),
            gconv.BatchNorm(),
            keras.layers.Activation('relu'),
            
            gconv.UpSampling(size=(2,2,2)),
            keras.layers.Dropout(rate=0.2),
            # linear output convolution (no BatchNorm/ReLU) ...
            gconv.RB3D_G_Conv(G, G, h_ksize=3, v_ksize=5, channels=RB_CHANNELS, h_padding='WRAP', v_padding='SAME', strides=(1,1,1),
                              filter_regularizer=keras.regularizers.L2(0.01)),
            # ... then average over the transformation axis and drop it
            gconv.TransformationPooling(tf.reduce_mean, keepdims=False)
        ])

# output shape: batch_size, width, depth, height, channels
d4_ae.summary()

In [6]:
# Mean-squared reconstruction error. Pass a Loss *instance*: Keras 3 happens to instantiate a
# bare class, but tf.keras 2.x passes callables through unchanged and would invoke the class
# itself as loss(y_true, y_pred).
# NOTE(review): this model uses lr=1e-3 while the other autoencoders use 1e-4, and the logged
# training loss oscillates between epochs — confirm the larger learning rate is intended.
d4_ae.compile(
    loss=tf.keras.losses.MeanSquaredError(),
    optimizer=keras.optimizers.Adam(learning_rate=0.001),
)

# The tf.data pipelines repeat forever, so epoch lengths must be given explicitly.
hist = d4_ae.fit(train_dataset, steps_per_epoch=TRAIN_BATCHES, validation_data=test_dataset,
                 validation_steps=TEST_BATCHES, epochs=100)

Epoch 1/100


I0000 00:00:1723679957.488550  104125 service.cc:145] XLA service 0x7f110401c0b0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1723679957.488631  104125 service.cc:153]   StreamExecutor device (0): NVIDIA A100 80GB PCIe, Compute Capability 8.0
2024-08-15 01:59:17.721194: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
2024-08-15 01:59:18.286564: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:465] Loaded cuDNN version 8907
2024-08-15 01:59:23.898404: E external/local_xla/xla/service/slow_operation_alarm.cc:65] Trying algorithm eng0{} for conv (f32[64,1024,48,48]{3,2,1,0}, u8[0]{0}) custom-call(f32[64,144,50,50]{3,2,1,0}, f32[1024,144,3,3]{3,2,1,0}), window={size=3x3}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_co

[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m163s[0m 403ms/step - loss: 0.2047 - val_loss: 0.5262
Epoch 2/100
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 350ms/step - loss: 0.2691 - val_loss: 0.5341
Epoch 3/100
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 344ms/step - loss: 0.2732 - val_loss: 0.3797
Epoch 4/100
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 322ms/step - loss: 0.1845 - val_loss: 0.5462
Epoch 5/100
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 315ms/step - loss: 0.2611 - val_loss: 0.5243
Epoch 6/100
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 328ms/step - loss: 0.2790 - val_loss: 0.5118
Epoch 7/100
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 314ms/step - loss: 0.2523 - val_loss: 0.5440
Epoch 8/100
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 315ms/step - loss: 0.2777 - val_loss: 0.4435
Epoch 9/100
[1m42/42[0m [32m━━━━

### $D_N$ Group Equivariant Convolutional Autoencoder

In [18]:
# D_N-equivariant convolutional autoencoder (N = ORIENTATIONS discrete rotations):
#   encoder: lifting conv -> three D_N convs, each followed by BatchNorm, ReLU and 2x2x2
#            max-pooling, with Dropout(0.2) between stages;
#   decoder: four (2x2x2 upsampling, Dropout, D_N conv) stages, the last one linear;
#   output:  TransformationPooling(tf.reduce_mean, keepdims=False) removes the
#            transformation axis so the reconstruction has the input's shape.
# NOTE(review): BatchNorm is taken from gconv although every other equivariant layer comes
# from dn_conv — presumably both use the same (..., transformations, height, channels) layout; confirm.
ORIENTATIONS = 8

dn_ae = keras.Sequential([
            keras.layers.InputLayer(shape=(HORIZONTAL_SIZE, HORIZONTAL_SIZE, HEIGHT, RB_CHANNELS),
                                    batch_size=BATCH_SIZE),
            
            ###############
            #   Encoder   #
            ###############
            dn_conv.RB3D_LiftDN_Conv(orientations=ORIENTATIONS, h_ksize=3, v_ksize=5, channels=RB_CHANNELS, h_padding='WRAP', v_padding='SAME', strides=(1,1,1),
                                     filter_regularizer=keras.regularizers.L2(0.01)),
            gconv.BatchNorm(),
            keras.layers.Activation('relu'),
            dn_conv.SpatialPooling(ksize=(2,2,2), pooling_type='MAX', strides=(2,2,2), padding='VALID'),
            keras.layers.Dropout(rate=0.2),
            dn_conv.RB3D_DN_Conv(h_ksize=3, v_ksize=5, channels=RB_CHANNELS, h_padding='WRAP', v_padding='SAME', strides=(1,1,1),
                                 filter_regularizer=keras.regularizers.L2(0.01)),
            gconv.BatchNorm(),
            keras.layers.Activation('relu'),
            dn_conv.SpatialPooling(ksize=(2,2,2), pooling_type='MAX', strides=(2,2,2), padding='VALID'),
            keras.layers.Dropout(rate=0.2),
            dn_conv.RB3D_DN_Conv(h_ksize=3, v_ksize=5, channels=RB_CHANNELS, h_padding='WRAP', v_padding='SAME', strides=(1,1,1),
                                 filter_regularizer=keras.regularizers.L2(0.01)),
            gconv.BatchNorm(),
            keras.layers.Activation('relu'),
            dn_conv.SpatialPooling(ksize=(2,2,2), pooling_type='MAX', strides=(2,2,2), padding='VALID'),
            keras.layers.Dropout(rate=0.2),
            dn_conv.RB3D_DN_Conv(h_ksize=3, v_ksize=5, channels=RB_CHANNELS, h_padding='WRAP', v_padding='SAME', strides=(1,1,1),
                                 filter_regularizer=keras.regularizers.L2(0.01)),
            gconv.BatchNorm(),
            keras.layers.Activation('relu'),
            # bottleneck: no Dropout after the last pooling
            dn_conv.SpatialPooling(ksize=(2,2,2), pooling_type='MAX', strides=(2,2,2), padding='VALID'),
            
            ###############
            #   Decoder   #
            ###############
            # only this UpSampling layer carries an explicit name; the others are auto-named
            dn_conv.UpSampling(size=(2,2,2), name='UpSampling1'),
            keras.layers.Dropout(rate=0.2),
            dn_conv.RB3D_DN_Conv(h_ksize=3, v_ksize=5, channels=RB_CHANNELS, h_padding='WRAP', v_padding='SAME', strides=(1,1,1),
                                 filter_regularizer=keras.regularizers.L2(0.01)),
            gconv.BatchNorm(),
            keras.layers.Activation('relu'),
            dn_conv.UpSampling(size=(2,2,2)),
            keras.layers.Dropout(rate=0.2),
            dn_conv.RB3D_DN_Conv(h_ksize=3, v_ksize=5, channels=RB_CHANNELS, h_padding='WRAP', v_padding='SAME', strides=(1,1,1),
                                 filter_regularizer=keras.regularizers.L2(0.01)),
            gconv.BatchNorm(),
            keras.layers.Activation('relu'),
            dn_conv.UpSampling(size=(2,2,2)),
            keras.layers.Dropout(rate=0.2),
            dn_conv.RB3D_DN_Conv(h_ksize=3, v_ksize=5, channels=RB_CHANNELS, h_padding='WRAP', v_padding='SAME', strides=(1,1,1),
                                 filter_regularizer=keras.regularizers.L2(0.01)),
            gconv.BatchNorm(),
            keras.layers.Activation('relu'),
            dn_conv.UpSampling(size=(2,2,2)),
            keras.layers.Dropout(rate=0.2),
            # linear output convolution (no BatchNorm/ReLU) ...
            dn_conv.RB3D_DN_Conv(h_ksize=3, v_ksize=5, channels=RB_CHANNELS, h_padding='WRAP', v_padding='SAME', strides=(1,1,1),
                                 filter_regularizer=keras.regularizers.L2(0.01)),
            # ... then average over the transformation axis and drop it
            dn_conv.TransformationPooling(tf.reduce_mean, keepdims=False)
        ])

# output shape: batch_size, width, depth, height, channels
# (TransformationPooling with keepdims=False removes the transformations axis)
dn_ae.summary()

In [19]:
# Mean-squared reconstruction error. Pass a Loss *instance*: Keras 3 happens to instantiate a
# bare class, but tf.keras 2.x passes callables through unchanged and would invoke the class
# itself as loss(y_true, y_pred).
dn_ae.compile(
    loss=tf.keras.losses.MeanSquaredError(),
    optimizer=keras.optimizers.Adam(learning_rate=0.0001),
)

# The tf.data pipelines repeat forever, so epoch lengths must be given explicitly.
hist = dn_ae.fit(train_dataset, steps_per_epoch=TRAIN_BATCHES, validation_data=test_dataset,
                 validation_steps=TEST_BATCHES, epochs=100)

Epoch 1/100
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m50s[0m 842ms/step - loss: 0.7314 - val_loss: 0.8166
Epoch 2/100
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m32s[0m 771ms/step - loss: 0.5262 - val_loss: 0.6416
Epoch 3/100
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m32s[0m 773ms/step - loss: 0.4119 - val_loss: 0.5820
Epoch 4/100
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m32s[0m 770ms/step - loss: 0.3651 - val_loss: 0.5664
Epoch 5/100
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m32s[0m 769ms/step - loss: 0.3394 - val_loss: 0.5471
Epoch 6/100
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m32s[0m 771ms/step - loss: 0.3221 - val_loss: 0.5449
Epoch 7/100
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m32s[0m 771ms/step - loss: 0.3106 - val_loss: 0.5128
Epoch 8/100
[1m15/42[0m [32m━━━━━━━[0m[37m━━━━━━━━━━━━━[0m [1m19s[0m 721ms/step - loss: 0.4255