In [1]:
import torch
from torch import nn
import sys
sys.path.append("../")

from model.ae import AE
from model.unet import UNet2D
from model.wrapper import Frankenstein


In [2]:
# test inputs: an all-zero image batch for the U-Net and an all-zero feature
# map for the standalone AE (only shapes are checked, so values don't matter)
x_in = torch.zeros((10,1,256,256))
feature_in = torch.zeros((10,32,64,64))
print(f"Inputs: \ntest input shape: {x_in.shape} \ntest feature shape: {feature_in.shape}")

# - init and test U-Net: 1 input channel -> 4 output channels
unet = UNet2D(1, 4, n_filters_init=8)
# shape-only smoke test, so there is no need to build an autograd graph
with torch.no_grad():
    unet_out = unet(x_in)
print(f"Outputs: \nU-Net output shape: {unet_out.shape}")
# batch and spatial size must be preserved, with 4 output channels
assert unet_out.shape == (x_in.shape[0], 4, *x_in.shape[2:]), unet_out.shape

# - init and test AE on the dummy feature map
# NOTE(review): positional args presumably are (in_channels, in_dim), matching
# feature_in's 32 channels and 64x64 spatial size — confirm against model.ae
ae = AE(32, 64)
# shape-only smoke test, so there is no need to build an autograd graph
with torch.no_grad():
    ae_out = ae(feature_in)
print(f"AE output shape: {ae_out.shape}")
# the AE output is expected to have exactly the shape of its input
assert ae_out.shape == feature_in.shape, ae_out.shape

# - init and test wrapper
# declare potential attachment points
layer_ids = ['shortcut0', 'shortcut1', 'shortcut2', 'up3']

# configure ae for specific layer(s)
#                         channel, spatial, latent,  depth, block 
ae_config   = {'up3': [        64,      32,    128,     2,      4]}
# AE keyword arguments, in the same order as the config columns above
ae_arg_names = ('in_channels', 'in_dim', 'latent_dim', 'depth', 'block_size')

# sanity-check the config: every configured layer must be a declared
# attachment point, and every row must provide one value per AE argument
assert set(ae_config) <= set(layer_ids), set(ae_config) - set(layer_ids)
assert all(len(cfg) == len(ae_arg_names) for cfg in ae_config.values()), ae_config

# set up module dict to pass to "transformations": one AE per configured layer,
# built from named keyword arguments instead of positional list indices
AEs = nn.ModuleDict({layer_id: AE(**dict(zip(ae_arg_names, cfg)))
                     for layer_id, cfg in ae_config.items()})

# for disabled ids, we attach identity functions. Since we populate the
# batch dimension, we need to make sure that features from different
# resolutions have matching batch dimensions even in cases where we do
# not alter them.
disabled_ids = [layer_id for layer_id in layer_ids if layer_id not in ae_config]
for layer_id in disabled_ids:
    AEs[layer_id] = nn.Identity()

# instantiate wrapper class
model = Frankenstein(unet, 
                     AEs, 
                     disabled_ids=disabled_ids,
                     copy=True)

# expected output of the wrapped U-Net: same batch and spatial size as the
# input, 4 output channels
expected_shape = (x_in.shape[0], 4, *x_in.shape[2:])

# test forward pass without any hooks
tmp = model(x_in)
print(f"wrapper output shape w/o hooks:           {tmp.shape}")
assert tmp.shape == expected_shape, tmp.shape

# test forward with training hooks
model.remove_all_hooks()
model.hook_train_transformations(model.transformations)
tmp = model(x_in)
print(f"wrapper output shape w   training hooks:  {tmp.shape}")
assert tmp.shape == expected_shape, tmp.shape

# test forward with inference hooks. The n_samples argument can be used
# for transformations with stochastic elements (e.g. VAEs)
model.remove_all_hooks()
model.hook_transformations(model.transformations,
                           n_samples=1)
tmp = model(x_in)
print(f"wrapper output shape w   inference hooks: {tmp.shape}")
# NOTE(review): inference hooks expand the batch dimension (observed 20 for a
# batch of 10 with n_samples=1); the exact rule lives in Frankenstein, so only
# check that non-batch dims match and the batch is a multiple of the input's
assert tmp.shape[1:] == expected_shape[1:], tmp.shape
assert tmp.shape[0] % x_in.shape[0] == 0, tmp.shape

Inputs: 
test input shape: torch.Size([10, 1, 256, 256]) 
test feature shape: torch.Size([10, 32, 64, 64])
Outputs: 
U-Net output shape: torch.Size([10, 4, 256, 256])
AE output shape: torch.Size([10, 32, 64, 64])
wrapper output shape w/o hooks:           torch.Size([10, 4, 256, 256])
wrapper output shape w   training hooks:  torch.Size([10, 4, 256, 256])
wrapper output shape w   inference hooks: torch.Size([20, 4, 256, 256])


In [3]:
# display the standalone AE's module structure (encoder ConvBlocks, intermediate conv, ...)
ae

AE(
  (encoder): Sequential(
    (0): ConvBlock(
      (sample): Sequential(
        (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (1): LayerNorm((64, 32, 32), eps=1e-05, elementwise_affine=True)
        (2): LeakyReLU(negative_slope=0.01)
      )
    )
    (1): ConvBlock(
      (sample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (1): LayerNorm((128, 16, 16), eps=1e-05, elementwise_affine=True)
        (2): LeakyReLU(negative_slope=0.01)
      )
    )
    (2): ConvBlock(
      (sample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (1): LayerNorm((256, 8, 8), eps=1e-05, elementwise_affine=True)
        (2): LeakyReLU(negative_slope=0.01)
      )
    )
  )
  (intermediate_conv): ConvBlock(
    (sample): Sequential(
      (0): Conv2d(256, 32, kernel_size=(1, 1), stride=(1, 1))
      (1): LayerNorm((32, 8, 8), eps=1e-05, elementwise_affine=True)