In [99]:

import escnn
from escnn import gspaces
from escnn import nn
from escnn.nn import GeometricTensor
import torch
import torch.nn.functional as F
import numpy as np
import math
SEED = 0  # global torch seed: makes the random input data and the conv weight initialisation reproducible
torch.manual_seed(SEED);  # trailing ';' suppresses the returned Generator's repr in the cell output

<torch._C.Generator at 0x7f7a0e17e3d0>

In [38]:


def concat_vertical_neighborhoods(geom_tensor: GeometricTensor, height: int, ksize: int, stride: int = 1, pad: bool = True):
    """Concatenate each vertical (height) neighbourhood of a height-stacked tensor along the channels.

    The input stores ``height`` levels along its channel axis, each level carrying the same
    list of fields (C channels per level). For every output level, ``ksize`` consecutive input
    levels (a window moved by ``stride``) are concatenated along the channel axis.

    Example with 3 levels of 4 channels (t, vx, vy, vz), ksize=3, stride=1, pad=True::

        input : [t1, vx1, vy1, vz1, t2, vx2, vy2, vz2, t3, vx3, vy3, vz3]
        output: [0,  0,   0,   0,   t1, vx1, vy1, vz1, t2, vx2, vy2, vz2,   (window for level 1)
                 t1, vx1, vy1, vz1, t2, vx2, vy2, vz2, t3, vx3, vy3, vz3,   (window for level 2)
                 t2, vx2, vy2, vz2, t3, vx3, vy3, vz3, 0,  0,   0,   0  ]  (window for level 3)

    Args:
        geom_tensor (GeometricTensor): tensor of shape ``[batch, height*C, width, depth]`` whose
            field type is ``height`` repetitions of the same per-level fields.
        height (int): number of height levels stacked along the channel axis.
        ksize (int): number of consecutive levels concatenated per output level.
        stride (int, optional): step of the window along the height axis. Defaults to 1.
        pad (bool, optional): if True, zero-pad the height axis (amount from
            ``required_same_padding``) so that ``ceil(height / stride)`` windows are produced.
            Defaults to True.

    Returns:
        GeometricTensor: tensor of shape ``[batch, out_height*ksize*C, width, depth]`` whose field
        type is the per-level fields repeated ``out_height*ksize`` times.
    """
    batch_size, height_and_channels, width, depth = geom_tensor.tensor.shape
    # split the channel axis into (height, channels per level): (b, H*C, w, d) -> (b, H, C, w, d)
    tensor = geom_tensor.tensor.reshape(batch_size, height, height_and_channels // height, width, depth)

    in_height = tensor.shape[1]

    if pad:
        # F.pad takes (before, after) pairs starting from the LAST dim: three (0, 0) pairs skip
        # depth, width and C, the final pair pads the height axis (split as evenly as possible)
        padding = required_same_padding(in_height=in_height, ksize=ksize, stride=stride)
        applied_padding = (math.floor(padding / 2), math.ceil(padding / 2))
        tensor = F.pad(tensor, (*([0, 0] * 3), *applied_padding))  # shape: (b, padH, C, w, d)

    # sliding windows over the height axis
    tensor = tensor.unfold(dimension=1, size=ksize, step=stride)  # shape: (b, outH, C, w, d, ksize)
    out_height = tensor.shape[1]

    # put the window axis right after the level axis, then merge (outH, ksize, C) into channels
    tensor = tensor.permute(0, 1, 5, 2, 3, 4)  # shape: (b, outH, ksize, C, w, d)
    tensor = tensor.flatten(start_dim=1, end_dim=3)  # shape: (b, outH*ksize*C, w, d)

    # new field type: the per-level fields repeated once per (output level, window position);
    # normalise to tuples so the check does not depend on the container type escnn uses
    representations = tuple(geom_tensor.type.representations)
    num_fields = len(representations) // in_height
    fields = representations[:num_fields]
    assert in_height * fields == representations, \
        'field type must be `height` repetitions of the same per-level fields'
    new_type = nn.FieldType(geom_tensor.type.gspace, list(out_height * ksize * fields))

    return GeometricTensor(tensor, new_type)


def required_same_padding(in_height, ksize, stride):
    """Total padding (both ends combined) for "same"-style output size along one axis.

    Follows TensorFlow's SAME rule: the output has ``ceil(in_height / stride)`` positions and the
    padding is ``max((out - 1) * stride + ksize - in_height, 0)``. The clamp matters when
    ``ksize < stride``: a negative value handed to ``F.pad`` would crop the input instead of padding it.

    Args:
        in_height (int): length of the axis before padding.
        ksize (int): window size.
        stride (int): window step.

    Returns:
        int: non-negative total number of padded positions.
    """
    out_height = math.ceil(in_height / stride)
    return max((out_height - 1) * stride - in_height + ksize, 0)


def output_height(in_height: int, ksize: int, stride: int, pad: bool):
    """Number of windows produced along the height axis.

    Mirrors the height bookkeeping of ``concat_vertical_neighborhoods``: optional "same"
    padding from ``required_same_padding``, then windows of ``ksize`` moved by ``stride``.
    """
    padding = required_same_padding(in_height, ksize, stride) if pad else 0
    usable = in_height + padding - ksize  # room left for the window to move
    return usable // stride + 1

In [100]:

# TODO data augmentation (rotation of horizontal velocities)

# ---- configuration ----
BATCH_SIZE = 1
WIDTH = DEPTH = HEIGHT = 32  # horizontal grid size (width x depth) and number of height levels
RB_CHANNELS = 4              # channels per height level
KSIZE = 3                    # vertical window size and horizontal kernel size
V_STRIDE, H_STRIDE = 2, 1    # stride along the height axis / in the horizontal plane
PAD = True                   # zero-pad the height axis ("same"-style)

# ---- random stand-in input ----
# sim_data is (batch, width, depth, height, channels); the model expects the height levels
# stacked along the channel axis: (batch, height*channels, width, depth)
sim_data = torch.randn(BATCH_SIZE, WIDTH, DEPTH, HEIGHT, RB_CHANNELS)
tensor = (
    sim_data
    .permute(0, 3, 4, 1, 2)  # -> (batch, height, channels, width, depth)
    .reshape(BATCH_SIZE, HEIGHT * RB_CHANNELS, WIDTH, DEPTH)
)

# symmetry group acting on the horizontal plane: rotations by multiples of 90 degrees, with and without a flip
r2_act = gspaces.flipRot2dOnR2(N=4)
# every height level holds: scalar, 2D vector (irrep(1, 1)), scalar
type_simulation = nn.FieldType(r2_act, [r2_act.trivial_repr, r2_act.irrep(1, 1), r2_act.trivial_repr] * HEIGHT)

geom_tensor = type_simulation(tensor)

# Field layout of one height level in the input data and in the prediction:
# scalar, 2D vector (irrep(1, 1)), scalar -> 4 channels per level.
RB_FIELD_REPRS = [r2_act.trivial_repr, r2_act.irrep(1, 1), r2_act.trivial_repr]
# Number of regular-representation fields per height level produced by conv1 ... conv5.
HIDDEN_FIELDS = (10, 15, 20, 25, 30)


def make_conv_block(in_height: int, in_field_reprs: list, out_field_reprs: list, hidden: bool = True):
    """Build one grouped equivariant convolution acting on vertical windows.

    The layer consumes the output of ``concat_vertical_neighborhoods`` applied to a tensor with
    ``in_height`` levels, i.e. ``out_height`` groups of ``KSIZE`` copies of ``in_field_reprs``.
    ``groups=out_height`` restricts each output level to its own window of input levels.

    Args:
        in_height: number of height levels before the window concatenation.
        in_field_reprs: representations describing one input height level.
        out_field_reprs: representations describing one output height level.
        hidden: if True, append ReLU and 2x2 pointwise max-pooling (hidden layer);
            if False, return only the convolution (output layer).

    Returns:
        tuple: ``(out_height, in_type, out_type, module)``.
    """
    out_height = output_height(in_height, KSIZE, V_STRIDE, PAD)
    in_type = nn.FieldType(r2_act, out_height*KSIZE*in_field_reprs)
    out_type = nn.FieldType(r2_act, out_height*out_field_reprs)
    layers = [
        nn.R2Conv(in_type, out_type, KSIZE, groups=out_height, padding=KSIZE//2,
                  padding_mode='circular', stride=H_STRIDE),
    ]
    if hidden:
        layers += [nn.ReLU(out_type), nn.PointwiseMaxPool2D(out_type, kernel_size=2)]
    return out_height, in_type, out_type, nn.SequentialModule(*layers)


regular_field = [r2_act.regular_repr]

# Convolutions 1-5: hidden layers (conv + ReLU + 2x2 max-pool).
# Built in this order so parameter initialisation consumes the RNG exactly as before.
out_height1, type_conv1_in, type_conv1_out, conv1 = make_conv_block(
    HEIGHT, RB_FIELD_REPRS, HIDDEN_FIELDS[0]*regular_field)
out_height2, type_conv2_in, type_conv2_out, conv2 = make_conv_block(
    out_height1, HIDDEN_FIELDS[0]*regular_field, HIDDEN_FIELDS[1]*regular_field)
out_height3, type_conv3_in, type_conv3_out, conv3 = make_conv_block(
    out_height2, HIDDEN_FIELDS[1]*regular_field, HIDDEN_FIELDS[2]*regular_field)
out_height4, type_conv4_in, type_conv4_out, conv4 = make_conv_block(
    out_height3, HIDDEN_FIELDS[2]*regular_field, HIDDEN_FIELDS[3]*regular_field)
out_height5, type_conv5_in, type_conv5_out, conv5 = make_conv_block(
    out_height4, HIDDEN_FIELDS[3]*regular_field, HIDDEN_FIELDS[4]*regular_field)

# Convolution 6: output layer (bare conv) mapping back to the per-level simulation field layout
out_height6, type_conv6_in, type_conv6_out, conv6 = make_conv_block(
    out_height5, HIDDEN_FIELDS[4]*regular_field, RB_FIELD_REPRS, hidden=False)


def predict(geom_tensor):
    """Run the full network on a height-stacked GeometricTensor.

    Before each of conv1 ... conv6 the current tensor is regrouped into vertical windows by
    ``concat_vertical_neighborhoods``, using the number of height levels it has at that point
    (HEIGHT, then out_height1 ... out_height5).

    Args:
        geom_tensor: input GeometricTensor whose type stacks HEIGHT levels of per-level fields.

    Returns:
        GeometricTensor: output of conv6.
    """
    stages = (
        (HEIGHT, conv1),
        (out_height1, conv2),
        (out_height2, conv3),
        (out_height3, conv4),
        (out_height4, conv5),
        (out_height5, conv6),
    )
    features = geom_tensor
    for levels, layer in stages:
        features = concat_vertical_neighborhoods(features, height=levels, ksize=KSIZE, stride=V_STRIDE, pad=PAD)
        features = layer(features)
    return features

### Equivariance Check

In [None]:
# Absolute tolerance for comparing f(g.x) with g.f(x); absorbs floating-point round-off.
EQUIVARIANCE_ATOL = 1e-5

# No gradients are needed for this check; no_grad avoids building autograd graphs
# for the 1 + len(testing_elements) forward passes.
with torch.no_grad():
    y = predict(geom_tensor)

    for g in r2_act.testing_elements:
        # f(g . x): transform the input, then run the network
        y_transformed_input = predict(geom_tensor.transform(g)).tensor.flatten()
        # g . f(x): run the network, then transform its output
        transformed_y = y.transform(g).tensor.flatten()

        assert torch.allclose(y_transformed_input, transformed_y, atol=EQUIVARIANCE_ATOL), f'not equivariant for {g}'
        print(f'equivariant otuput for {g}:', transformed_y)

# -> temperature and vertical velocity remain constant
# -> horizontal velocities rotate/flip accordingly

equivariant otuput for (+, 0[2pi/4]): tensor([ 0.4489, -0.0158,  0.0138, -0.1530])
equivariant otuput for (+, 1[2pi/4]): tensor([ 0.4489, -0.0138, -0.0158, -0.1530])
equivariant otuput for (+, 2[2pi/4]): tensor([ 0.4489,  0.0158, -0.0138, -0.1530])
equivariant otuput for (+, 3[2pi/4]): tensor([ 0.4489,  0.0138,  0.0158, -0.1530])
equivariant otuput for (-, 0[2pi/4]): tensor([ 0.4489, -0.0158, -0.0138, -0.1530])
equivariant otuput for (-, 1[2pi/4]): tensor([ 0.4489,  0.0138, -0.0158, -0.1530])
equivariant otuput for (-, 2[2pi/4]): tensor([ 0.4489,  0.0158,  0.0138, -0.1530])
equivariant otuput for (-, 3[2pi/4]): tensor([ 0.4489, -0.0138,  0.0158, -0.1530])


In [102]:
def R(phi):
    """Return the 2x2 counter-clockwise rotation matrix for angle ``phi`` (radians)."""
    c, s = np.cos(phi), np.sin(phi)
    return np.array([[c, -s],
                     [s, c]])


# Rotate the horizontal-velocity pair printed by the equivariance check by each multiple
# of 90 degrees, to compare by hand with the transformed outputs above.
for phi in [k * np.pi / 2 for k in range(4)]:
    print(f'{phi=:.2f}: {R(phi) @ np.array([-0.0158,  0.0138])}')

phi=0.00: [-0.0158  0.0138]
phi=1.57: [-0.0138 -0.0158]
phi=3.14: [ 0.0158 -0.0138]
phi=4.71: [0.0138 0.0158]
