In [5]:
import escnn
from escnn import gspaces
from escnn import nn
from escnn.nn import GeometricTensor
import torch
import torch.nn.functional as F
import numpy as np
import math

In [None]:


# Example: height 3, ksize 3, stride 1, pad=True, channels (t, vx, vy, vz):
# -> [t1, vx1, vy1, vz1, t2, vx2, vy2, vz2, t3, vx3, vy3, vz3]   (shape height*channels)
# -> [0, 0, 0, 0, t1, vx1, vy1, vz1, t2, vx2, vy2, vz2,          (input for predicting h1)
#     t1, vx1, vy1, vz1, t2, vx2, vy2, vz2, t3, vx3, vy3, vz3,   (input for predicting h2)
#     t2, vx2, vy2, vz2, t3, vx3, vy3, vz3, 0, 0, 0, 0]          (input for predicting h3)
def concat_vertical_neighborhoods(x, ksize, stride=1, pad=True):
    """Concatenate every vertical (height) neighborhood of size `ksize` into the last axis.

    [batch, width, depth, height, channels] -> [batch, width, depth, out_height*ksize*channels]

    Only the last two dims are interpreted (height, channels). Any number of leading
    dims is accepted and passed through unchanged.

    Args:
        x: tensor of shape [..., height, channels].
        ksize: neighborhood size along height.
        stride: step between consecutive neighborhoods along height.
        pad: if True, zero-pad height ("same" padding) targeting
            out_height == ceil(height / stride). When the total padding is odd, the
            larger half goes after the last height level.

    Returns:
        (tensor of shape [..., out_height*ksize*channels], out_height). The last axis
        is laid out out_height-major, then ksize, then channels.
    """
    if pad:
        padding = required_same_padding(in_height=x.shape[-2], ksize=ksize, stride=stride)
        # F.pad lists dims from last to first: channels untouched, height padded.
        x = F.pad(x, (0, 0, math.floor(padding / 2), math.ceil(padding / 2)))  # shape: ... x paddedHeight x channels

    # compute neighborhoods (unfold appends the window as a new trailing dim)
    x = x.unfold(dimension=-2, size=ksize, step=stride)  # shape: ... x outHeight x channels x ksize
    out_height = x.shape[-3]

    # concatenate neighborhoods
    x = x.transpose(-1, -2)  # shape: ... x outHeight x ksize x channels
    x = x.flatten(start_dim=-3)  # shape: ... x outHeight*ksize*channels

    return x, out_height

def required_same_padding(in_height, ksize, stride):
    """Total padding along one dim for TensorFlow-style "same" output size.

    Returns the number of elements to add (the caller splits it across both ends) so
    that sliding a window of `ksize` with step `stride` yields ceil(in_height / stride)
    outputs.

    The result is clamped at 0. When ksize < stride the raw formula can go negative,
    and a negative value passed to F.pad would crop the input instead of padding it.
    """
    out_height = math.ceil(in_height / stride)
    return max((out_height - 1) * stride - in_height + ksize, 0)

## Discrete symmetry group: rotations and reflections ($D_N$ with $N$ = `ROTS`)

In [9]:

# TODO data augmentation (rotation of horizontal velocities)

BATCH_SIZE = 1
WIDTH, DEPTH, HEIGHT = 48, 48, 32
RB_CHANNELS = 4  # per grid point: t, vx, vy, vz
OUT_CHANNELS = 10
KSIZE = 3
V_STRIDE = 1

ROTS = 4
ROTFLIPS = 2*ROTS  # group order: ROTS rotations, each with and without a flip

torch.manual_seed(0)  # reproducible random input
sim_data = torch.randn(BATCH_SIZE, WIDTH, DEPTH, HEIGHT, RB_CHANNELS)

x, out_height = concat_vertical_neighborhoods(sim_data, KSIZE, stride=V_STRIDE) # shape: batch,width,depth,outheight*ksize*channels
x = x.permute(0, 3, 1, 2) # batch, outheight*ksize*channels, width, depth

r2_act = gspaces.flipRot2dOnR2(N=ROTS)
# One height level of one neighborhood: t (trivial), (vx, vy) (2D irrep), vz (trivial).
feat_type_in  = nn.FieldType(r2_act,  out_height*KSIZE*[r2_act.trivial_repr, r2_act.irrep(1, 1), r2_act.trivial_repr])
feat_type_out = nn.FieldType(r2_act, out_height*OUT_CHANNELS*[r2_act.regular_repr])

x = feat_type_in(x)
# groups=out_height: with grouped-conv semantics, output height level i only sees
# input group i, i.e. its own vertical neighborhood (channels are outheight-major).
conv_layer = nn.SequentialModule(
    nn.R2Conv(feat_type_in, feat_type_out, KSIZE, groups=out_height, padding=KSIZE//2, padding_mode='circular'),
    nn.ReLU(feat_type_out)
)

y = conv_layer(x) # batch, outheight*outchannel*transformations, width, depth
# y is a GeometricTensor; the raw torch data lives in `.tensor`, so reshape that.
y_by_height = y.tensor.reshape(BATCH_SIZE, out_height, -1, WIDTH, DEPTH)  # batch, outheight, outchannel*transformations, width, depth
y_by_height.shape

#TODO concat_vertical_neighborhoods for this shape: height*outchannel*transform, width, depth

ValueError: not enough values to unpack (expected 2, got 1)

In [41]:
y.type.__class__  # class of the field type attached to the conv output

escnn.nn.field_type.FieldType

In [None]:
def concat_vertical_neighborhoods2(tensor: torch.Tensor, type: nn.FieldType, ksize, fields, stride=1, pad=True):
    """Stack vertical (height) neighborhoods of a per-level feature tensor into a GeometricTensor.

    Args:
        tensor: torch tensor of shape [batch, height, channels*transformations, width, depth].
        type: FieldType of the flattened input [batch, height*channels*transformations, width, depth].
            It must be `height` repetitions of its first `fields` representations.
        ksize: number of height levels per neighborhood.
        fields: number of representations (fields) describing one height level in `type`.
        stride: step between consecutive neighborhoods along height.
        pad: if True, zero-pad height ("same" padding) targeting out_height == ceil(height / stride).

    Returns:
        (GeometricTensor of shape [batch, out_height*ksize*channels*transformations, width, depth]
        whose type is out_height*ksize repetitions of the per-level fields, out_height)

    Raises:
        AssertionError: if `type` is not `height` repetitions of its first `fields` representations.
    """
    in_height = tensor.shape[1]

    # Validate the field type up front, before any tensor work.
    channel_representations = list(type.representations[:fields])
    assert in_height * channel_representations == list(type.representations), \
        "type must consist of `height` repetitions of its first `fields` representations"

    if pad:
        # pad height; F.pad lists dims from last to first: (d, w, c*t) untouched, then height
        padding = required_same_padding(in_height=in_height, ksize=ksize, stride=stride)
        applied_padding = (math.floor(padding / 2), math.ceil(padding / 2))
        tensor = F.pad(tensor, (*([0, 0] * 3), *applied_padding))  # shape: (b, padH, c*t, w, d)

    # compute neighborhoods (unfold appends the window as a new trailing dim)
    tensor = tensor.unfold(dimension=1, size=ksize, step=stride)  # shape: (b, outH, c*t, w, d, ksize)
    out_height = tensor.shape[1]

    # concatenate neighborhoods
    tensor = tensor.permute(0, 1, 5, 2, 3, 4)  # shape: (b, outH, ksize, c*t, w, d)
    tensor = tensor.flatten(start_dim=1, end_dim=3)  # shape: (b, outH*ksize*c*t, w, d)

    new_type = nn.FieldType(type.gspace, out_height * ksize * channel_representations)

    return GeometricTensor(tensor, new_type), out_height

# Check that concat_vertical_neighborhoods2 ([b, h, c*t, w, d] layout) produces exactly
# the same tensor as concat_vertical_neighborhoods ([b, w, d, h, c] layout) on random data.
# NOTE(review): the comparison against `concat_vertical_neighborhoods_1` was dropped.
# That function is not defined anywhere in this notebook (NameError on Restart & Run All).
TEST_HEIGHT = 12  # separate name so the HEIGHT config constant is not clobbered
torch.manual_seed(0)
sim_data1 = torch.randn(BATCH_SIZE, WIDTH, DEPTH, TEST_HEIGHT, RB_CHANNELS) # b,w,d,h,c
sim_data2 = sim_data1.permute(0, 3, 4, 1, 2) # b,h,c,w,d

r2_act = gspaces.flipRot2dOnR2(N=ROTS)
# One height level: trivial, 2D irrep, trivial -> 3 fields spanning RB_CHANNELS channels.
sim_type  = nn.FieldType(r2_act,  TEST_HEIGHT*[r2_act.trivial_repr, r2_act.irrep(1, 1), r2_act.trivial_repr])

x1, out_height1 = concat_vertical_neighborhoods(sim_data1, KSIZE, stride=V_STRIDE) # [b,w,d,outH*ksize*c]
x1 = x1.permute(0, 3, 1, 2) # [b,outH*ksize*c,w,d]
x2, out_height2 = concat_vertical_neighborhoods2(sim_data2, type=sim_type, fields=3, ksize=KSIZE, stride=V_STRIDE) # [b,outH*ksize*c*t,w,d]

assert out_height1 == out_height2
assert x1.shape == x2.tensor.shape
assert torch.equal(x1, x2.tensor)

12


In [15]:
x = torch.Tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])

# Padding the last dim directly must equal padding the same dim after it has been moved to
# position -2 (and moving it back). F.pad's tuple lists (left, right) pairs from the last dim
# backwards. Asserted: as a non-final bare expression the result was silently discarded.
assert torch.equal(F.pad(x, [1, 1]).permute(0, 2, 1), F.pad(x.permute(0, 2, 1), [0, 0, 1, 1]))

# Pad spec used by concat_vertical_neighborhoods2: no padding on the last three dims, (1, 1) on the next.
(*([0, 0] * 3), *(1, 1))

(0, 0, 0, 0, 0, 0, 1, 1)

In [12]:
x = torch.Tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])  # dims: a, b, c
x2 = x.permute(0, 2, 1)                                  # dims: a, c, b

# Pad dim c by one on each side in two ways:
#   - directly on x, where c is the last dim;
#   - on x2, where c is second-to-last, so its pair follows the (0, 0) for b.
# Then swap x2's result back so both live in (a, b, c) order.
x_pad = F.pad(x, [1, 1])
x2_pad = F.pad(x2, [0, 0, 1, 1])
x2_pad_rev = x2_pad.permute(0, 2, 1)

print('x', x)
print('x2', x2)
print(x_pad)
print("---")
print(x2_pad_rev)
print("---")
print(x2_pad)

assert torch.equal(x_pad, x2_pad_rev)

x tensor([[[1., 2.],
         [3., 4.]],

        [[5., 6.],
         [7., 8.]]])
x2 tensor([[[1., 3.],
         [2., 4.]],

        [[5., 7.],
         [6., 8.]]])
tensor([[[0., 1., 2., 0.],
         [0., 3., 4., 0.]],

        [[0., 5., 6., 0.],
         [0., 7., 8., 0.]]])
---
tensor([[[0., 1., 2., 0.],
         [0., 3., 4., 0.]],

        [[0., 5., 6., 0.],
         [0., 7., 8., 0.]]])
---
tensor([[[0., 0.],
         [1., 3.],
         [2., 4.],
         [0., 0.]],

        [[0., 0.],
         [5., 7.],
         [6., 8.],
         [0., 0.]]])
