In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from se3cnn.image.gated_block import GatedBlock
from keymorph import utils

In [26]:
class GatedBlockPrint(GatedBlock):
    def forward(self, x):
        output = super().forward(x)
        print(output.shape)
        return output

class RXFM_Net(nn.Module):
    def __init__(self, n_in, output_chans, norm_type):
        super(RXFM_Net, self).__init__()

#         chan_config = [[32, 32, 4], [16, 16, 4], [16, 16, 4], [16, 16, 4]]
#         features = [[n_in]] + chan_config + [[output_chans]]
        chan_config = [[16, 16, 4]]
        features = [[n_in]] + chan_config


        common_block_params = {
            "size": 5,
            "stride": 2,
            "padding": 2,
            "normalization": norm_type,
            "capsule_dropout_p": None,
            "smooth_stride": False,
        }

        block_params = [{"activation": F.relu}] * (len(features) - 2) + [
            {"activation": F.relu}
        ]

        assert len(block_params) + 1 == len(features)
        
        for i in range(len(block_params)):
            print(features[i], features[i+1])
#         for i in range
        blocks = [
            GatedBlockPrint(
                features[i],
                features[i + 1],
                **common_block_params,
                **block_params[i],
                dyn_iso=True
            )
            for i in range(len(block_params))
        ]

        self.sequence = torch.nn.Sequential(*blocks)

    def forward(self, x):
        x = self.sequence(x)
        return x

In [27]:
num_keypoints = 128
norm_type = 'instance'
network = RXFM_Net(1, num_keypoints, norm_type=norm_type)
utils.summary(network)

[1] [16, 16, 4]

Model Summary
---------------------------------------------------------------
sequence.0.scalar_act.bias
sequence.0.gate_act.bias
sequence.0.conv.gn.weight
sequence.0.conv.gn.bias
sequence.0.conv.conv.kernel.weight
Total parameters: 186
---------------------------------------------------------------



In [14]:
X = torch.randn(1, 1, 256, 256, 256, requires_grad=False)
network(X)

torch.Size([1, 148, 128, 128, 128])
torch.Size([1, 84, 64, 64, 64])
torch.Size([1, 84, 32, 32, 32])
torch.Size([1, 84, 16, 16, 16])
torch.Size([1, 128, 8, 8, 8])


tensor([[[[[0.0537, 0.0000, 0.0527,  ..., 0.0000, 0.6348, 0.8341],
           [0.6479, 0.3225, 0.0000,  ..., 0.0000, 0.0203, 0.8004],
           [0.5407, 0.0000, 0.3152,  ..., 0.0000, 0.0000, 0.0000],
           ...,
           [0.5270, 0.0000, 1.0793,  ..., 0.0000, 0.2185, 0.0000],
           [0.0000, 0.0000, 0.0000,  ..., 0.2058, 0.0000, 0.0000],
           [0.0000, 0.2443, 0.1662,  ..., 0.1826, 0.7695, 0.4777]],

          [[0.2279, 0.4449, 0.3591,  ..., 0.4409, 0.0000, 0.7959],
           [0.0000, 0.6100, 0.0000,  ..., 0.1646, 0.0000, 0.0000],
           [0.0067, 1.0561, 1.1851,  ..., 0.1220, 1.5646, 0.6398],
           ...,
           [0.0000, 0.0000, 0.0000,  ..., 0.7522, 1.2056, 0.6341],
           [1.2430, 0.2725, 0.5428,  ..., 0.4328, 0.0000, 0.0000],
           [0.6325, 0.0000, 0.0000,  ..., 0.3329, 0.3315, 0.0000]],

          [[0.4755, 0.1434, 0.0000,  ..., 1.6463, 0.0000, 0.0000],
           [0.0000, 0.0000, 0.6103,  ..., 0.8111, 0.3757, 0.0000],
           [0.5298, 0.0000