In [1]:
%load_ext autoreload
%autoreload 2

In [1]:
from unet import UNet
import torch


# Use the GPU when available (torch.device was introduced in PyTorch 0.4.0).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(0)  # make the random dummy input reproducible

# NOTE(review): the original note says this UNet has skip connections removed,
# but the DownBlock summary output still returns the pre-pooling tensor —
# confirm against unet.py.
model = UNet(in_channels=3,
             out_channels=3,
             n_blocks=3,
             start_filters=32,
             activation='relu',
             normalization='batch',
             conv_mode='same',
             dim=2)
model = model.to(device)
# eval() so BatchNorm uses (and does not update) its running statistics during
# this shape check; torch.no_grad() alone does not stop those updates.
model.eval()

# Named `dummy_input` rather than `input` to avoid shadowing the builtin.
dummy_input = torch.randn(size=(8, 3, 224, 224), dtype=torch.float32, device=device)

with torch.no_grad():
    out = model(dummy_input, encode_only=False)
    features = model(dummy_input, encode_only=True)

print(f'Input: {dummy_input.shape}')
print(f'Out: {out.shape}')
print(f'Features: {features.shape}')

Input: torch.Size([8, 3, 224, 224])
Out: torch.Size([8, 3, 224, 224])
Features: torch.Size([8, 128, 56, 56])


In [2]:
from torchsummary import summary
# `model` comes from the previous cell; make sure it lives on the chosen device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# torchsummary's `summary` defaults to device="cuda"; pass the actual device type
# so this also runs on CPU-only machines. The call is made for its printed table
# only — assigning the result back to `summary` would shadow the imported
# function and break any later call.
summary(model, (3, 224, 224), device=device.type);

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 32, 224, 224]             896
              ReLU-2         [-1, 32, 224, 224]               0
       BatchNorm2d-3         [-1, 32, 224, 224]              64
            Conv2d-4         [-1, 32, 224, 224]           9,248
              ReLU-5         [-1, 32, 224, 224]               0
       BatchNorm2d-6         [-1, 32, 224, 224]              64
         MaxPool2d-7         [-1, 32, 112, 112]               0
         DownBlock-8  [[-1, 32, 112, 112], [-1, 32, 224, 224]]               0
            Conv2d-9         [-1, 64, 112, 112]          18,496
             ReLU-10         [-1, 64, 112, 112]               0
      BatchNorm2d-11         [-1, 64, 112, 112]             128
           Conv2d-12         [-1, 64, 112, 112]          36,928
             ReLU-13         [-1, 64, 112, 112]               0
      BatchNorm2d-14    

In [2]:
shape = 224


def compute_max_depth(shape, max_depth=10, print_out=True):
    """Return the spatial sizes reached by repeatedly halving ``shape``.

    Level 0 is ``shape`` itself; level ``k`` is ``shape // 2**k``. Halving
    continues while ``shape`` is divisible by ``2**k`` and the result stays
    above 1, checking at most levels ``1 .. max_depth - 1``.

    Args:
        shape: Integer spatial size (e.g. image height or width).
        max_depth: Exclusive upper bound on the levels that are checked.
        print_out: If True, print each level's size and the maximum level.

    Returns:
        List of integer sizes ``[shape, shape // 2, shape // 4, ...]``; its
        length minus one is the maximum depth.
    """
    shapes = [shape]
    for level in range(1, max_depth):
        factor = 2 ** level
        # Stop at the first level that does not halve cleanly or would reach 1.
        if shape % factor != 0 or shape // factor <= 1:
            break
        # Divisibility was just checked, so integer division is exact.
        shapes.append(shape // factor)
        if print_out:
            print(f'Level {level}: {shape // factor}')

    if print_out:
        # Reported after the loop so it is also printed when every level up to
        # max_depth - 1 halved cleanly (no break).
        print(f'Max-level: {len(shapes) - 1}')

    return shapes


# Separate name so the model output `out` from the first cell is not clobbered.
level_shapes = compute_max_depth(shape, print_out=True, max_depth=10)

Level 1: 112.0
Level 2: 56.0
Level 3: 28.0
Level 4: 14.0
Level 5: 7.0
Max-level: 5
