In [2]:
import torch
from torch import nn
import sys
sys.path.append("../")

from model.ae import AE
from model.unet import UNet2D
from model.wrapper import Frankenstein


In [2]:
# Dummy all-zero batches: one image batch for the U-Net, one feature batch for the AE.
batch_size = 10
x_in = torch.zeros((batch_size, 1, 256, 256))
feature_in = torch.zeros((batch_size, 32, 64, 64))
print(f"Inputs: \ntest input shape: {x_in.shape} \ntest feature shape: {feature_in.shape}")

# U-Net: build it and check the output shape of one forward pass.
unet = UNet2D(1, 4, n_filters_init=8)
unet_out = unet(x_in)
print(f"Outputs: \nU-Net output shape: {unet_out.shape}")

# AE: build it and check it maps the feature batch to a same-shaped tensor.
ae = AE(32, 64)
ae_out = ae(feature_in)
print(f"AE output shape: {ae_out.shape}")

# - init and test wrapper
# declare potential attachment points (layer names inside the U-Net)
layer_ids = ['shortcut0', 'shortcut1', 'shortcut2', 'up3']

# configure AE(s) for specific layer(s); each entry is
#   [in_channels (channel), in_dim (spatial), latent_dim, depth, block_size]
ae_config = {'up3': [64, 32, 128, 2, 4]}


def build_ae(cfg):
    """Instantiate an ``AE`` from one 5-element ``ae_config`` entry.

    Unpacking (rather than indexing ``cfg[0]`` ... ``cfg[4]``) raises a
    ValueError if the entry does not have exactly five fields.
    """
    in_channels, in_dim, latent_dim, depth, block_size = cfg
    return AE(in_channels=in_channels,
              in_dim=in_dim,
              latent_dim=latent_dim,
              depth=depth,
              block_size=block_size)


# set up module dict to pass to "transformations": one AE per configured layer
AEs = nn.ModuleDict({layer_id: build_ae(cfg) for layer_id, cfg in ae_config.items()})

# For disabled ids, we attach identity functions. Since we populate the
# batch dimension, we need to make sure that features from different
# resolutions have matching batch dimensions even in cases where we do
# not alter them.
# Disabled ids are derived from ae_config so the two cannot drift apart.
disabled_ids = [layer_id for layer_id in layer_ids if layer_id not in ae_config]
for layer_id in disabled_ids:
    AEs[layer_id] = nn.Identity()

# instantiate wrapper class
model = Frankenstein(unet,
                     AEs,
                     disabled_ids=disabled_ids,
                     copy=True)

# Baseline: a plain forward pass through the wrapper, no hooks attached.
out_no_hooks = model(x_in)
print(f"wrapper output shape w/o hooks:           {out_no_hooks.shape}")

# Clear any hooks, attach the training-time transformation hooks, and rerun.
model.remove_all_hooks()
model.hook_train_transformations(model.transformations)
out_train = model(x_in)
print(f"wrapper output shape w   training hooks:  {out_train.shape}")

# Clear hooks again and attach the inference-time hooks. `n_samples` is meant
# for transformations with stochastic elements (e.g. VAEs). Per the recorded
# output, the batch dimension of the result grows here (10 -> 20).
model.remove_all_hooks()
model.hook_inference_transformations(model.transformations, n_samples=1)
out_inference = model(x_in)
print(f"wrapper output shape w   inference hooks: {out_inference.shape}")

Inputs: 
test input shape: torch.Size([10, 1, 256, 256]) 
test feature shape: torch.Size([10, 32, 64, 64])
Outputs: 
U-Net output shape: torch.Size([10, 4, 256, 256])
AE output shape: torch.Size([10, 32, 64, 64])
wrapper output shape w/o hooks:           torch.Size([10, 4, 256, 256])
wrapper output shape w   training hooks:  torch.Size([10, 4, 256, 256])
wrapper output shape w   inference hooks: torch.Size([20, 4, 256, 256])


In [3]:
# Display the AE's module tree (the notebook renders the last expression's repr).
ae

AE(
  (encoder): Sequential(
    (0): ConvBlock(
      (sample): Sequential(
        (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (1): LayerNorm((64, 32, 32), eps=1e-05, elementwise_affine=True)
        (2): LeakyReLU(negative_slope=0.01)
      )
    )
    (1): ConvBlock(
      (sample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (1): LayerNorm((128, 16, 16), eps=1e-05, elementwise_affine=True)
        (2): LeakyReLU(negative_slope=0.01)
      )
    )
    (2): ConvBlock(
      (sample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (1): LayerNorm((256, 8, 8), eps=1e-05, elementwise_affine=True)
        (2): LeakyReLU(negative_slope=0.01)
      )
    )
  )
  (intermediate_conv): ConvBlock(
    (sample): Sequential(
      (0): Conv2d(256, 32, kernel_size=(1, 1), stride=(1, 1))
      (1): LayerNorm((32, 8, 8), eps=1e-05, elementwise_affine=True)

In [39]:
from monai.networks.nets import ResNet, UNet, SwinUNETR, SegResNet
from monai.networks.blocks import ResidualUnit

In [52]:
class ResDAE(nn.Module):
    """Stack of MONAI ``ResidualUnit`` blocks with an optional outer skip connection.

    Every unit maps ``in_channels -> in_channels`` on 2D inputs, so the stack
    does not change the channel count. The module can be disabled at runtime
    with ``turn_off``, in which case ``forward`` returns its input unchanged.

    Args:
        in_channels: number of channels of the 2D input.
        depth: number of stacked ``ResidualUnit`` blocks.
        residual: if True, ``forward`` returns ``x + stack(x)``; otherwise ``stack(x)``.
    """

    def __init__(
        self,
        in_channels,
        depth,
        residual: bool = True
    ):
        super().__init__()
        self.on = True
        self.residual = residual

        self.model = nn.ModuleList(
            ResidualUnit(
                spatial_dims=2,
                in_channels=in_channels,
                out_channels=in_channels,
                act="PReLU",
                norm="BATCH",
                adn_ordering="AN"
            ) for _ in range(depth)
        )

    def turn_off(self) -> None:
        """Make ``forward`` a pass-through that returns its input."""
        self.on = False

    def turn_on(self) -> None:
        """Re-enable the residual stack in ``forward``."""
        self.on = True

    def forward(self, x):
        """Apply the units sequentially, optionally adding the input back.

        Returns ``x`` unchanged when the module is off or contains no units.
        """
        if not self.on or len(self.model) == 0:
            return x
        out = x
        # Chain the units so each consumes the previous unit's output. Feeding
        # ``x`` to every unit would leave all but the last one unused.
        for layer in self.model:
            out = layer(out)
        return x + out if self.residual else out



# ResDAE with 64 channels and 20 stacked units: parameter count, then output shape.
m = ResDAE(64, 20)
n_params = sum(param.numel() for param in m.parameters())
print(n_params)

x_in = torch.zeros((10, 64, 32, 32))
out = m(x_in)
print(f"ResNet output shape: {out.shape}")
m

1482280
ResNet output shape: torch.Size([10, 64, 32, 32])


ResDAE(
  (model): ModuleList(
    (0): ResidualUnit(
      (conv): Sequential(
        (unit0): Convolution(
          (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (adn): ADN(
            (A): PReLU(num_parameters=1)
            (N): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (unit1): Convolution(
          (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (adn): ADN(
            (A): PReLU(num_parameters=1)
            (N): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
      )
      (residual): Identity()
    )
    (1): ResidualUnit(
      (conv): Sequential(
        (unit0): Convolution(
          (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (adn): ADN(
            (A): PReLU(num_parameters=1)
            (N): BatchNorm2d(64, eps=1e-05, momentum=0.1, a

In [43]:
# A single MONAI ResidualUnit (64 -> 64 channels, PReLU initialised at 0.2, batch norm).
m = ResidualUnit(
    spatial_dims=2,
    in_channels=64,
    out_channels=64,
    act=("prelu", {"init": 0.2}),
    norm="BATCH",
    adn_ordering="AN",
)
n_params = sum(param.numel() for param in m.parameters())
print(n_params)

x_in = torch.zeros((10, 64, 32, 32))
out = m(x_in)
print(f"ResNet output shape: {out.shape}")
m

74114
ResNet output shape: torch.Size([10, 64, 32, 32])


ResidualUnit(
  (conv): Sequential(
    (unit0): Convolution(
      (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (adn): ADN(
        (A): PReLU(num_parameters=1)
        (N): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (unit1): Convolution(
      (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (adn): ADN(
        (A): PReLU(num_parameters=1)
        (N): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
  )
  (residual): Identity()
)

In [35]:
# MONAI ResNet with basic blocks and 64 planes in all four stages.
# Per the recorded output, it reduces the input to a flat (batch, 400) tensor.
m = ResNet(
    block='basic',
    layers=(2,) * 4,
    block_inplanes=(64,) * 4,
    spatial_dims=2,
    n_input_channels=64,
)
n_params = sum(param.numel() for param in m.parameters())
print(n_params)

x_in = torch.zeros((10, 64, 32, 32))
out = m(x_in)
print(f"ResNet output shape: {out.shape}")
m


831568
ResNet output shape: torch.Size([10, 400])


ResNet(
  (conv1): Conv2d(64, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): ResNetBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): ResNetBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)

In [38]:
# MONAI SegResNet mapping 64 channels to 64 channels, with 64 initial filters.
m = SegResNet(
    spatial_dims=2,
    in_channels=64,
    out_channels=64,
    init_filters=64,
)
n_params = sum(param.numel() for param in m.parameters())
print(n_params)

x_in = torch.zeros((10, 64, 32, 32))
out = m(x_in)
print(f"ResNet output shape: {out.shape}")
m

25220288
ResNet output shape: torch.Size([10, 64, 32, 32])


SegResNet(
  (act_mod): ReLU(inplace=True)
  (convInit): Convolution(
    (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  )
  (down_layers): ModuleList(
    (0): Sequential(
      (0): Identity()
      (1): ResBlock(
        (norm1): GroupNorm(8, 64, eps=1e-05, affine=True)
        (norm2): GroupNorm(8, 64, eps=1e-05, affine=True)
        (act): ReLU(inplace=True)
        (conv1): Convolution(
          (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (conv2): Convolution(
          (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
      )
    )
    (1): Sequential(
      (0): Convolution(
        (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      )
      (1): ResBlock(
        (norm1): GroupNorm(8, 128, eps=1e-05, affine=True)
        (norm2): GroupNorm(8, 128, eps=1e-05, affine=True)
        (act): ReLU(

In [23]:
# SwinUNETR (v2) for single-channel 256x256 2D images; print its parameter count.
swin_img_size = (256, 256)
# NOTE(review): this cell previously evaluated a bare `2**5`, which looks like a
# manual check that every spatial dim is divisible by 32 (SwinUNETR's stage-wise
# downsampling) — confirm against the MONAI docs. Made explicit here.
assert all(dim % 2**5 == 0 for dim in swin_img_size), "img_size must be divisible by 2**5"
net = SwinUNETR(
    img_size=swin_img_size,
    in_channels=1,
    out_channels=1,
    spatial_dims=2,
    use_v2=True,
)
print(sum(p.numel() for p in net.parameters()))

8946043


32

In [20]:
# MONAI UNet with four channel levels, stride-2 steps between them,
# and num_res_units=4; print its parameter count.
unet_channels = (64, 128, 256, 512)
net = UNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    channels=unet_channels,
    strides=[2] * (len(unet_channels) - 1),
    num_res_units=4,
)
n_params = sum(param.numel() for param in net.parameters())
print(n_params)

12672352


In [6]:
# Project UNet2D (1 input channel, 1 output channel, 8 initial filters): parameter count.
net = UNet2D(n_filters_init=8, n_chans_in=1, n_chans_out=1)
n_params = sum(param.numel() for param in net.parameters())
print(n_params)

610539


In [30]:
# Scratch check of list repetition, the idiom used to build `strides` for the MONAI UNet.
[2] * 4

[2, 2, 2, 2]