In [5]:
import sys
sys.path.append("../../")
from superlayer.models import UNet, SLN_UNet
from torchsummary import summary
import torch
import numpy as np

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [6]:
"""
Setup vanilla UNet
"""

# Baseline: encoder and decoder filter lists with four 128-filter entries each.
enc_nf = [128] * 4
dec_nf = [128] * 4

model1 = UNet(
    input_ch=1,
    out_ch=6,
    enc_nf=enc_nf,
    dec_nf=dec_nf,
).to(device)
# torchsummary's input_size excludes the batch dim: (channels, height, width).
summary(model1, input_size=(1, 160, 192))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1        [-1, 128, 160, 192]           1,280
    InstanceNorm2d-2        [-1, 128, 160, 192]               0
              ReLU-3        [-1, 128, 160, 192]               0
      simple_block-4        [-1, 128, 160, 192]               0
         MaxPool2d-5          [-1, 128, 80, 96]               0
            Conv2d-6          [-1, 128, 80, 96]         147,584
    InstanceNorm2d-7          [-1, 128, 80, 96]               0
              ReLU-8          [-1, 128, 80, 96]               0
      simple_block-9          [-1, 128, 80, 96]               0
        MaxPool2d-10          [-1, 128, 40, 48]               0
           Conv2d-11          [-1, 128, 40, 48]         147,584
   InstanceNorm2d-12          [-1, 128, 40, 48]               0
             ReLU-13          [-1, 128, 40, 48]               0
     simple_block-14          [-1, 128,

In [11]:
# Super-layer UNet. W and b are passed as None; the cell still prints
# model2.W.shape, so presumably SLN_UNet creates W itself when None is given
# (defined in superlayer.models, not visible here) — TODO confirm.
model2 = SLN_UNet(
    input_ch=1,
    out_ch=6,
    superblock_size=256,
    depth=4,
    W=None,
    b=None,
    train_block=True,
).to(device)
summary(model2, input_size=(1, 160, 192))
print(model2.W.shape)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
         MaxPool2d-1            [-1, 1, 80, 96]               0
            Conv2d-2          [-1, 128, 80, 96]           1,280
    InstanceNorm2d-3          [-1, 128, 80, 96]               0
              ReLU-4          [-1, 128, 80, 96]               0
      simple_block-5          [-1, 128, 80, 96]               0
         MaxPool2d-6          [-1, 128, 40, 48]               0
    InstanceNorm2d-7          [-1, 128, 40, 48]               0
              ReLU-8          [-1, 128, 40, 48]               0
      simple_block-9          [-1, 128, 40, 48]               0
        MaxPool2d-10          [-1, 128, 20, 24]               0
   InstanceNorm2d-11          [-1, 128, 20, 24]               0
             ReLU-12          [-1, 128, 20, 24]               0
     simple_block-13          [-1, 128, 20, 24]               0
        MaxPool2d-14          [-1, 128,

In [4]:
# NOTE(review): `AEnet` is not imported in any visible cell (only UNet and
# SLN_UNet are), so this cell fails on Restart & Run All — add its import.
# NOTE(review): enc_nf/dec_nf here are the 128-filter lists from the vanilla
# UNet cell, yet the saved summary shows 64-channel layers; that output
# appears to come from a different kernel state.
model = AEnet(
    input_ch=1,
    out_ch=15,
    use_bn=True,
    enc_nf=enc_nf,
    dec_nf=dec_nf,
).to(device)
summary(model, input_size=(1, 160, 192))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 160, 192]             640
    InstanceNorm2d-2         [-1, 64, 160, 192]               0
              ReLU-3         [-1, 64, 160, 192]               0
      simple_block-4         [-1, 64, 160, 192]               0
         MaxPool2d-5           [-1, 64, 80, 96]               0
            Conv2d-6           [-1, 64, 80, 96]          36,928
    InstanceNorm2d-7           [-1, 64, 80, 96]               0
              ReLU-8           [-1, 64, 80, 96]               0
      simple_block-9           [-1, 64, 80, 96]               0
        MaxPool2d-10           [-1, 64, 40, 48]               0
           Conv2d-11           [-1, 64, 40, 48]          36,928
   InstanceNorm2d-12           [-1, 64, 40, 48]               0
             ReLU-13           [-1, 64, 40, 48]               0
     simple_block-14           [-1, 64,

In [5]:
# Super-layer variant of the autoencoder network.
# NOTE(review): `SL_AEnet` is not imported in any visible cell — add its import
# so this cell survives Restart & Run All.
model = SL_AEnet(
    input_ch=1,
    out_ch=15,
    use_bn=True,
    superblock_size=64,
    depth=4,
).to(device)
summary(model, input_size=(1, 160, 192))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 160, 192]             640
    InstanceNorm2d-2         [-1, 64, 160, 192]               0
              ReLU-3         [-1, 64, 160, 192]               0
      simple_block-4         [-1, 64, 160, 192]               0
         MaxPool2d-5           [-1, 64, 80, 96]               0
            Conv2d-6           [-1, 64, 80, 96]          36,928
    InstanceNorm2d-7           [-1, 64, 80, 96]               0
              ReLU-8           [-1, 64, 80, 96]               0
      simple_block-9           [-1, 64, 80, 96]               0
        MaxPool2d-10           [-1, 64, 40, 48]               0
           Conv2d-11           [-1, 64, 40, 48]          36,928
   InstanceNorm2d-12           [-1, 64, 40, 48]               0
             ReLU-13           [-1, 64, 40, 48]               0
     simple_block-14           [-1, 64,

In [6]:
# NOTE(review): `SuperNet` is not imported in any visible cell — add its import
# so this cell survives Restart & Run All.
model = SuperNet(
    input_ch=1,
    out_ch=15,
    use_bn=True,
    superblock_size=64,
    depth=4,
).to(device)
summary(model, input_size=(1, 160, 192))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 160, 192]             640
    InstanceNorm2d-2         [-1, 64, 160, 192]               0
              ReLU-3         [-1, 64, 160, 192]               0
      simple_block-4         [-1, 64, 160, 192]               0
         MaxPool2d-5           [-1, 64, 80, 96]               0
    InstanceNorm2d-6           [-1, 64, 80, 96]               0
              ReLU-7           [-1, 64, 80, 96]               0
      simple_block-8           [-1, 64, 80, 96]               0
         MaxPool2d-9           [-1, 64, 40, 48]               0
   InstanceNorm2d-10           [-1, 64, 40, 48]               0
             ReLU-11           [-1, 64, 40, 48]               0
     simple_block-12           [-1, 64, 40, 48]               0
        MaxPool2d-13           [-1, 64, 20, 24]               0
   InstanceNorm2d-14           [-1, 64,

In [9]:
# Prepare the vm1 or vm2 model and send to device.
# NOTE(review): `cvpr2018_net` is not imported in any visible cell — add its
# import so this cell survives Restart & Run All.
# NOTE(review): this rebinds `model1`, which the vanilla-UNet cell also uses.
nf_enc = [64, 64, 64, 64]
nf_dec = [64, 64, 64, 64, 64, 64, 64]

# NOTE(review): absolute, user-specific path — consider a configurable DATA_DIR.
atlas_file = '/home/vib9/src/voxelmorph/data/atlas_norm.npz'

# Add singleton axes at both ends, then keep index 100 of axis 3 (the third
# axis of 'vol', presumably a 3-D volume — TODO confirm) to get one 2-D slice.
# np.load returns an NpzFile that keeps the archive open; `with` closes it.
# Indexing the NpzFile reads the array into memory, so atlas_vol stays valid.
with np.load(atlas_file) as atlas_npz:
    atlas_vol = atlas_npz['vol'][np.newaxis, ..., np.newaxis][:, :, :, 100, :]
vol_size = atlas_vol.shape[1:-1]

# Move to `device` like every other model in this notebook. torchsummary builds
# its inputs as CUDA tensors by default when CUDA is available, so an unmoved
# CPU model would mismatch its inputs on a GPU machine.
model1 = cvpr2018_net(vol_size, nf_enc, nf_dec).unet_model.to(device)
# Two input channels; the spatial size follows the atlas slice that the network
# was built for instead of a hard-coded 160x192.
summary(model1, input_size=(2, *vol_size))

torch.Size([2, 64, 3, 3])
torch.Size([2])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 80, 96]           2,112
         LeakyReLU-2           [-1, 64, 80, 96]               0
        conv_block-3           [-1, 64, 80, 96]               0
            Conv2d-4           [-1, 64, 40, 48]          65,600
         LeakyReLU-5           [-1, 64, 40, 48]               0
        conv_block-6           [-1, 64, 40, 48]               0
            Conv2d-7           [-1, 64, 20, 24]          65,600
         LeakyReLU-8           [-1, 64, 20, 24]               0
        conv_block-9           [-1, 64, 20, 24]               0
           Conv2d-10           [-1, 64, 10, 12]          65,600
        LeakyReLU-11           [-1, 64, 10, 12]               0
       conv_block-12           [-1, 64, 10, 12]               0
           Conv2d-13           [-1, 64, 10, 12]          36,9

In [10]:
# VoxelMorph UNet backbone built with superblock_size=64; the last nf_dec entry
# is 32 rather than 64.
# NOTE(review): `cvpr2018_net` is not imported in any visible cell — add its
# import so this cell survives Restart & Run All.
nf_enc = [64, 64, 64, 64]
nf_dec = [64, 64, 64, 64, 64, 64, 32]

# NOTE(review): absolute, user-specific path — consider a configurable DATA_DIR.
atlas_file = '/home/vib9/src/voxelmorph/data/atlas_norm.npz'

# Add singleton axes at both ends, then keep index 100 of axis 3 (the third
# axis of 'vol', presumably a 3-D volume — TODO confirm) to get one 2-D slice.
# np.load returns an NpzFile that keeps the archive open; `with` closes it.
# Indexing the NpzFile reads the array into memory, so atlas_vol stays valid.
with np.load(atlas_file) as atlas_npz:
    atlas_vol = atlas_npz['vol'][np.newaxis, ..., np.newaxis][:, :, :, 100, :]
vol_size = atlas_vol.shape[1:-1]

# Move to `device` like every other model in this notebook. torchsummary builds
# its inputs as CUDA tensors by default when CUDA is available, so an unmoved
# CPU model would mismatch its inputs on a GPU machine.
model4 = cvpr2018_net(vol_size, nf_enc, nf_dec, superblock_size=64).unet_model.to(device)
# Two input channels; the spatial size follows the atlas slice that the network
# was built for instead of a hard-coded 160x192.
summary(model4, input_size=(2, *vol_size))

torch.Size([2, 32, 3, 3])
torch.Size([2])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
         MaxPool2d-1            [-1, 2, 80, 96]               0
            Conv2d-2           [-1, 32, 80, 96]             608
         LeakyReLU-3           [-1, 32, 80, 96]               0
        conv_block-4           [-1, 32, 80, 96]               0
         MaxPool2d-5           [-1, 32, 40, 48]               0
            Conv2d-6           [-1, 32, 40, 48]           9,248
         LeakyReLU-7           [-1, 32, 40, 48]               0
        conv_block-8           [-1, 32, 40, 48]               0
         MaxPool2d-9           [-1, 32, 20, 24]               0
           Conv2d-10           [-1, 32, 20, 24]           9,248
        LeakyReLU-11           [-1, 32, 20, 24]               0
       conv_block-12           [-1, 32, 20, 24]               0
        MaxPool2d-13           [-1, 32, 10, 12]              