In [1]:
import torch
from batchflow.models.torch import TorchModel
from batchflow.models.torch import ResNet, EncoderDecoder
from batchflow.models.torch.layers import ConvBlock

In [2]:
downsample_depth = 4
in_channels = 3
config = {
    'inputs/images/shape': (in_channels, 32, 32), # can be commented
    'initial_block/inputs': 'images', # can be commented    
    'device': 'gpu:0',

    'body/encoder': {
        'num_stages': downsample_depth,
        'order': ['skip', 'block', 'downsampling']
    },    
    'body/encoder/blocks': {
        'layout': 'cna cna',
        'filters': [in_channels*pow(2, i) for i in range(1, downsample_depth+1)]
    },
    'body/encoder/downsample': {
        'layout': 'p'
    },    
    
    'body/embedding': {
        'layout': 'cna cna', 
        'filters': in_channels*pow(2, downsample_depth)
    },   

    'body/decoder': {
        'num_stages': downsample_depth,
        'order': ['upsampling', 'combine', 'block'] # check        
    },
    'body/decoder/upsample': {
        'layout': 't',
        'filters': [in_channels*pow(2, i-1) for i in range(downsample_depth, -1, -1)]
    },
    'body/decoder/combine': {
        'op': 'concat',
        'leading_index': 1
    },
    'body/decoder/blocks': {
        'layout': 'cna cna',
        'filters': [in_channels*pow(2, i-1) for i in range(downsample_depth, -1, -1)]
    },
    'head':{
        'layout': 'ca',
        'filters': in_channels,
        'activation': 'Softmax'
    },
    
    'loss': 'mse',
    'optimizer': 'Adam'
}

model = EncoderDecoder(config)

  return self.activation(x, *self.args, **self.kwargs)


In [3]:
model.short_repr()

Sequential(
  (body): Sequential(
    (encoder): EncoderModule(
      (skip-0): Identity()
      (block-0): ConvBlock
      DefaultBlock
        layout=cnacna
          Layer 0,  letter "c": (None, 3, 32, 32) -> (None, 6, 32, 32)
          Layer 1,  letter "n": (None, 6, 32, 32) -> (None, 6, 32, 32)
          Layer 2,  letter "a": (None, 6, 32, 32) -> (None, 6, 32, 32)
          Layer 3,  letter "c": (None, 6, 32, 32) -> (None, 6, 32, 32)
          Layer 4,  letter "n": (None, 6, 32, 32) -> (None, 6, 32, 32)
          Layer 5,  letter "a": (None, 6, 32, 32) -> (None, 6, 32, 32)
          
      (downsample-0): ConvBlock
      layout=p
        Layer 0,  letter "p": (None, 6, 32, 32) -> (None, 6, 16, 16)
        
      (skip-1): Identity()
      (block-1): ConvBlock
      DefaultBlock
        layout=cnacna
          Layer 0,  letter "c": (None, 6, 16, 16) -> (None, 12, 16, 16)
          Layer 1,  letter "n": (None, 12, 16, 16) -> (None, 12, 16, 16)
          Layer 2,  letter "a": (None, 

In [4]:
model.model

Sequential(
  (body): Sequential(
    (encoder): EncoderModule(
      (skip-0): Identity()
      (block-0): ConvBlock(
        (0): DefaultBlock(
          (0): BaseConvBlock(
            layout=cnacna
            
            (Layer 0,  letter "c": (None, 3, 32, 32) -> (None, 6, 32, 32)): Conv(
              (layer): Conv2d(3, 6, kernel_size=(3, 3), stride=(1, 1), bias=False)
            )
            (Layer 1,  letter "n": (None, 6, 32, 32) -> (None, 6, 32, 32)): BatchNorm(
              (layer): BatchNorm2d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
            (Layer 2,  letter "a": (None, 6, 32, 32) -> (None, 6, 32, 32)): Activation(
              (activation): ReLU(inplace=True)
            )
            (Layer 3,  letter "c": (None, 6, 32, 32) -> (None, 6, 32, 32)): Conv(
              (layer): Conv2d(6, 6, kernel_size=(3, 3), stride=(1, 1), bias=False)
            )
            (Layer 4,  letter "n": (None, 6, 32, 32) -> (None, 6, 32, 32)):