In [2]:
import torch
import torch.nn as nn

In [87]:
import torch.nn.functional as F

def concat_vertical_neighborhoods(x, ksize, stride=1):
    """Stack sliding windows taken along dim 1 of ``x`` back into dim 1.

    Dim 1 is zero-padded with ``ksize - 1`` zeros in total, so with
    ``stride=1`` there is exactly one window per original index. Each window
    of ``ksize`` consecutive entries is laid out contiguously in the result:
    output index ``i * ksize + k`` along dim 1 holds padded entry
    ``i * stride + k``.

    Args:
        x: 4-D tensor. Windows slide over dim 1; dims 0, 2 and 3 are untouched.
        ksize: Window length (>= 1). Odd sizes pad ``ksize // 2`` zeros on both
            sides; even sizes pad one zero fewer in front than behind.
        stride: Step between consecutive window starts.

    Returns:
        Tensor of shape ``(x.shape[0], n * ksize, x.shape[2], x.shape[3])``
        with ``n = (x.shape[1] - 1) // stride + 1`` windows.

    Raises:
        ValueError: If ``x`` is not 4-D or ``ksize`` is smaller than 1.

    NOTE(review): windows are taken over dim 1 exactly as given. If the caller
    flattens two axes (e.g. channels and height) into dim 1, windows near the
    seams will mix entries of adjacent blocks -- confirm the caller's layout.
    """
    if x.dim() != 4:
        # The 6-value pad below targets dim -3, which is dim 1 only for 4-D input.
        raise ValueError(f"expected a 4-D tensor, got {x.dim()}-D")
    if ksize < 1:
        raise ValueError(f"ksize must be >= 1, got {ksize}")

    # Total padding of ksize - 1 keeps the window count equal to the input
    # length at stride 1. For even ksize the extra zero goes on the trailing side.
    total_pad = ksize - 1
    pad_front = total_pad // 2
    pad_back = total_pad - pad_front

    # F.pad lists pads from the last dim backwards: (dim 3, dim 2, dim 1).
    x_padded = F.pad(x, (0, 0, 0, 0, pad_front, pad_back), mode='constant', value=0)

    # unfold appends the window axis last: (B, n, W, D, ksize).
    neighborhoods = x_padded.unfold(dimension=1, size=ksize, step=stride)

    # Move the window axis next to dim 1 and merge the two: (B, n * ksize, W, D).
    return neighborhoods.permute(0, 1, 4, 2, 3).flatten(start_dim=1, end_dim=2)

In [95]:
# Seed the global RNG so the randomly initialised conv weights and the random
# input below are identical on every Restart & Run All.
torch.manual_seed(0)

# Spatial extent of the simulated volume and the number of input channels.
WIDTH, DEPTH, HEIGHT = 48, 48, 32
RB_CHANNELS = 4
KSIZE = 3 # only supports odd vertical sizes

V_STRIDE = 1


# Grouped 2-D convolutions: groups=HEIGHT splits the input channels into HEIGHT
# contiguous groups, so each group reads in_channels // HEIGHT input channels
# (RB_CHANNELS*KSIZE for conv1, channels1*KSIZE for conv2).
channels1 = 10
conv1 = nn.Conv2d(in_channels=RB_CHANNELS*HEIGHT*KSIZE, out_channels=channels1*HEIGHT, kernel_size=KSIZE, groups=HEIGHT,
                  padding='same')
channels2 = 20
conv2 = nn.Conv2d(in_channels=channels1*HEIGHT*KSIZE, out_channels=channels2*HEIGHT, kernel_size=KSIZE, groups=HEIGHT,
                  padding='same')

# Random stand-in for simulation data, laid out (batch, channel, width, depth, height).
sim_data = torch.randn(1, RB_CHANNELS, WIDTH, DEPTH, HEIGHT)

# Fold height into the channel axis: (batch, channel, height, width, depth)
# -> (batch, channel*height, width, depth), i.e. flat index = c*HEIGHT + h.
# NOTE(review): this flattening is channel-major, while the grouped convs take
# contiguous runs of channels per group. After concat_vertical_neighborhoods a
# conv1 group therefore appears to cover several heights of a single channel
# rather than all channels of one height -- confirm this is the intended layout.
x = sim_data.permute(0, 1, 4, 2, 3).reshape(1, RB_CHANNELS*HEIGHT, WIDTH, DEPTH)
x_transformed = concat_vertical_neighborhoods(x, KSIZE, stride=V_STRIDE)
y1 = conv1(x_transformed)
y1_transformed = concat_vertical_neighborhoods(y1, KSIZE, stride=V_STRIDE)
y2 = conv2(y1_transformed)

# Sanity checks: padding='same' keeps the spatial extent, and each layer emits
# its per-height channel count times HEIGHT.
assert y1.shape == (1, channels1*HEIGHT, WIDTH, DEPTH)
assert y2.shape == (1, channels2*HEIGHT, WIDTH, DEPTH)

In [92]:
# Shape of the first in_channels // groups channels of x_transformed, i.e. the
# slice of dim 1 that conv1's first group reads.
x_transformed[:, :RB_CHANNELS*HEIGHT*KSIZE//HEIGHT].shape # (batch, (vertical)ksize*in_channels, width, depth)

torch.Size([1, 12, 48, 48])

In [93]:
# Weight layout of the grouped conv; dim 1 is per group (in_channels // groups).
conv1.weight.shape # (out_height*out_channels, (vertical)ksize*in_channels, ksize, ksize)

torch.Size([320, 12, 3, 3])

In [94]:
# Output of conv1: spatial size is preserved by padding='same'.
y1.shape # (batch, out_height*out_channels, width, depth)

torch.Size([1, 320, 48, 48])