
MONAI_Network_Design_Discussion

Ben Murray edited this page Feb 14, 2020 · 1 revision

Introduction

Networks are where a number of design decisions converge. They are a choke point in the design: any problems in the design of layers and layer factories will surface as problems when we try to identify reusable network elements.

Network functionality represents a major design opportunity for MONAI. PyTorch is largely unopinionated about how networks are defined. It provides Module as a base class from which to create a network, and the forward method that must be implemented, but there is no prescribed pattern and little helper functionality for initialising networks. This leaves a lot of room for defining useful 'best practice' patterns for constructing new networks in MONAI. Trivial, inflexible network implementations are easy enough to write, but we can give users a toolset that makes it much easier to build well-engineered, flexible networks, and demonstrate its value by committing to use it in the networks that we build.

Areas of functionality

  • layers (covered here)
    • layer factories (covered here)
  • network definition
    • compatibility with torchscript
    • reference and configurable versions
    • configurability
      • modules
        • blocks
        • structure
      • width
      • depth
      • rules-based configuration
    • network structure
      • dimension-agnosticism
      • recursion vs iteration
    • initialisation / model restore
  • network support
    • hyper-parameter search techniques
    • higher-order training techniques
      • adaptive loss schemes
      • dynamic curriculum learning techniques

Roadmap

  • MVP
    • provision of layers
    • provision of layer factories (if agreed upon)
    • provision of reference network implementations
    • provision of configurable network implementations sufficient for MVP
      • configuration of modular blocks
      • configuration of structure (layer counts etc.)
    • provision of network utilities
  • post-MVP
    • higher-order training techniques
    • refinement of configurable network implementations
    • model size calculation
  • longer term
    • higher-order training techniques
    • hyper-parameter search techniques

Topics

Layers and layer factories

This topic is covered in detail here and here.

Network definition

Compatibility with torchscript

Compatibility with TorchScript is a key capability. All mechanisms that we create for constructing networks must be TorchScript compatible, and this restricts how such mechanisms can be implemented. This can affect models that use features such as:

  • tied weights
  • densely-connected network blocks (Eric please elaborate)

We should do a full survey of the TorchScript-related issues that our network functionality must avoid.
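One building block for such a survey is an automated check that a network scripts cleanly and that the scripted version matches eager mode. The sketch below is a hypothetical example (the names `StackedBlocks` and `check_scriptable` are illustrative, not existing MONAI code) using a depth-configurable stack; TorchScript supports iterating over an `nn.ModuleList` in `forward` by unrolling the loop.

```python
import torch
import torch.nn as nn


class StackedBlocks(nn.Module):
    """Hypothetical depth-configurable stack of conv layers."""

    def __init__(self, channels: int, depth: int):
        super().__init__()
        # nn.ModuleList registers every layer and can be iterated inside scripted forward()
        self.blocks = nn.ModuleList(
            [nn.Conv2d(channels, channels, kernel_size=3, padding=1) for _ in range(depth)]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            x = torch.relu(block(x))
        return x


def check_scriptable(module: nn.Module, example: torch.Tensor) -> bool:
    """Script the module (raises if unsupported) and compare against eager-mode output."""
    scripted = torch.jit.script(module)
    with torch.no_grad():
        return torch.allclose(module(example), scripted(example))
```

Running a check like this over every network in the library would flag constructs (such as the tied-weight or densely-connected patterns listed above) that fail to script, without having to enumerate them by hand first.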

Reference Networks

Every network type that comes from a paper should have a plain, 'unconfigurable' implementation whose purpose is to let people replicate published results and understand the network easily. Such networks can also serve as regression-test references for more configurable network implementations.
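One possible way to use a reference network as a regression source: if a configurable implementation, at its default settings, produces the same parameter names and shapes as the reference, the reference weights can be loaded into it and the outputs compared. This is a minimal sketch with hypothetical two-layer networks (`ReferenceNet`, `ConfigurableNet` and `matches_reference` are illustrative names), assuming matching `state_dict` keys is a design constraint we choose to impose.

```python
import torch
import torch.nn as nn


class ReferenceNet(nn.Module):
    """Plain, fixed implementation, written to be read and to reproduce results."""

    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=3, padding=1), nn.ReLU(),
            nn.Conv2d(8, 8, kernel_size=3, padding=1), nn.ReLU(),
        )

    def forward(self, x):
        return self.layers(x)


class ConfigurableNet(nn.Module):
    """Configurable version; its defaults reproduce ReferenceNet's structure."""

    def __init__(self, in_channels=1, channels=8, depth=2):
        super().__init__()
        modules = []
        for i in range(depth):
            in_ch = in_channels if i == 0 else channels
            modules += [nn.Conv2d(in_ch, channels, kernel_size=3, padding=1), nn.ReLU()]
        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        return self.layers(x)


def matches_reference(reference, candidate, example):
    # load_state_dict is strict by default, so mismatched names or shapes raise an error
    candidate.load_state_dict(reference.state_dict())
    with torch.no_grad():
        return torch.allclose(reference(example), candidate(example))
```

Keeping parameter names aligned also makes it possible to load published reference weights into the configurable network directly.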

Configurability

Outside of reference networks, which are implemented to serve their particular purpose, our network implementations should have strong configurability as a primary goal. We should set standards of configurability that all such network implementations must meet; each is covered in a subsection below.

Restrictions of pytorch

Layers must be attributes of a module (or otherwise registered on it) for PyTorch to recognise them; layers held in a plain Python list are not registered, so their parameters are missed by parameters(), state_dict() and .to(). This affects every aspect of configurability: configuring the number of layers or the number of downsamples/upsamples requires more than appending layers to a list. PyTorch provides (at least) the following mechanisms:

  • Module.__setattr__
  • Module.add_module
  • self.layers = nn.ModuleList()
  • nn.Sequential

Other patterns also exist, such as modular recursion (TODO: Eric already has examples like this in the codebase; reference them here).

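import torch.nn as nn

class RecursiveModule(nn.Module):

  def __init__(self, entries):
    super().__init__()  # must run before any submodule is assigned
    self.cur = entries[0]
    # the remaining entries are wrapped recursively; the last level has no successor
    self.next = RecursiveModule(entries[1:]) if len(entries) > 1 else None

  def forward(self, t_input):
    t_output = self.cur(t_input)
    if self.next is not None:
      return self.next(t_output)
    return t_output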

Modules

Many networks have a natural modularity in terms of what is considered a 'unit' of computation; the variants of the ResNet block are a good example. Such blocks should be replaceable modular elements within the overall network structure, because innovations on a base architecture often differ mainly in the design of their blocks. Making modules a configurable element of a network design lets developers readily experiment with innovations on existing networks.
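A minimal sketch of block-level configurability, assuming the network takes a block factory (a callable from channel count to module) as a constructor argument. `PlainBlock`, `ResidualBlock` and `BlockNet` are hypothetical names for illustration; the network fixes the overall structure and the caller chooses the block type.

```python
from typing import Callable

import torch
import torch.nn as nn


class PlainBlock(nn.Module):
    def __init__(self, channels: int):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        return torch.relu(self.conv(x))


class ResidualBlock(nn.Module):
    """A basic ResNet-style block: two convolutions plus an identity shortcut."""

    def __init__(self, channels: int):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        return torch.relu(x + self.conv2(torch.relu(self.conv1(x))))


class BlockNet(nn.Module):
    """Fixed stem / body / head structure; the body's block type is injected."""

    def __init__(self, block_fn: Callable[[int], nn.Module], channels: int = 16, n_blocks: int = 3):
        super().__init__()
        self.stem = nn.Conv2d(1, channels, kernel_size=3, padding=1)
        self.blocks = nn.Sequential(*[block_fn(channels) for _ in range(n_blocks)])
        self.head = nn.Conv2d(channels, 1, kernel_size=1)

    def forward(self, x):
        return self.head(self.blocks(self.stem(x)))
```

`BlockNet(PlainBlock)` and `BlockNet(ResidualBlock)` share everything except the blocks; a block that needs extra arguments can be supplied with `functools.partial`.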

Structure

Structural configurability can be separated into two subtypes:

  • Configuration of layer counts
  • Recursive network architectures

Structural configurability is the area most affected by PyTorch's requirement that layers be attributes of nn.Module instances. Layer counts are affected by this, and any of the mechanisms listed above is a candidate solution.
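Once a registering container such as `nn.Sequential` is used, a layer count is just a loop bound. The minimal sketch below (the helpers `make_stack` and `make_stack_shared` are hypothetical) also illustrates a common pitfall: multiplying a one-element list repeats a single module instance, so every position shares the same weights. That is occasionally intended (tied weights) but usually a bug.

```python
import torch.nn as nn


def make_stack(block_fn, count):
    """Build `count` independent blocks: block_fn is called once per block."""
    return nn.Sequential(*[block_fn() for _ in range(count)])


def make_stack_shared(block_fn, count):
    """Pitfall: list multiplication repeats ONE instance, so all positions share weights."""
    return nn.Sequential(*([block_fn()] * count))
```

With `block_fn = lambda: nn.Linear(4, 4)` and `count = 3`, both stacks have three entries, but `make_stack` has 60 parameters while `make_stack_shared` has only 20, because `parameters()` yields a shared parameter once.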

Recursive network architectures are slightly more complex. Whereas layer counts represent 'horizontal' configurability (i.e. network depth), downsample counts are an example of 'vertical' configurability.

Example: unet

A UNet can be thought of as a recursive structure in which each level of recursion corresponds to one resolution.

Each resolution (reached by downsampling) consists of three concentric layers:

  • A convolutional block (CB) that contains the convolutional Modules
  • A convolutional layer (CL) that wraps the convolutional block with downsampling and upsampling modules
  • A skipped layer (SL) with two paths, one through a skip module (nn.Identity or some other module that does work on the skip connection) and one through the convolutional layer that does the main work, whose outputs are then concatenated

Eric's UNet implementation, which we have used as the baseline configurable UNet, follows this layering by calling recursive functions.
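The recursive-function approach can be sketched as follows. This is a hypothetical, 2D-only illustration, not Eric's implementation: `conv_block` provides the convolutional modules of each CB, an `nn.Sequential` of max-pooling, the next level and a transposed convolution acts as the CL, and `SkipConcat` is the SL. The choice of `MaxPool2d` and `ConvTranspose2d` for down/upsampling is an assumption; any matching pair would do.

```python
import torch
import torch.nn as nn


def conv_block(in_ch: int, out_ch: int) -> nn.Module:
    """Convolutional modules at one resolution: two 3x3 conv + ReLU pairs."""
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1), nn.ReLU(),
        nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1), nn.ReLU(),
    )


class SkipConcat(nn.Module):
    """SL: a skip path and an inner path whose outputs are concatenated on channels."""

    def __init__(self, inner: nn.Module):
        super().__init__()
        self.skip = nn.Identity()
        self.inner = inner

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.cat([self.skip(x), self.inner(x)], dim=1)


def build_level(in_ch: int, channels: list) -> nn.Module:
    """Recursively build one resolution level and everything below it.

    Maps `in_ch` input channels to `channels[0]` output channels at the same
    spatial size; `channels[i]` is the width of the i-th level down.
    """
    c = channels[0]
    if len(channels) == 1:
        return conv_block(in_ch, c)  # bottom of the U: nothing further to recurse into
    # CL: downsample, recurse into the next level, upsample back to c channels
    conv_layer = nn.Sequential(
        nn.MaxPool2d(2),
        build_level(c, channels[1:]),
        nn.ConvTranspose2d(channels[1], c, kernel_size=2, stride=2),
    )
    # encoder convs, then SL (skip + CL, concatenated to 2c channels), then decoder convs
    return nn.Sequential(conv_block(in_ch, c), SkipConcat(conv_layer), conv_block(2 * c, c))


def build_unet(in_ch: int, out_ch: int, channels: list) -> nn.Module:
    # a final 1x1 conv maps the top-level features to the requested output channels
    return nn.Sequential(build_level(in_ch, channels), nn.Conv2d(channels[0], out_ch, kernel_size=1))
```

The number of downsamples is simply `len(channels) - 1`, so the input's spatial size must be divisible by `2 ** (len(channels) - 1)`; for example `build_unet(1, 2, [8, 16, 32])` accepts a 32x32 input.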

Ben has been experimenting with ways of achieving a similar design but through an iterative approach. This is not part of any PR at this point, but looks like this:

from functools import partial

import torch.nn as nn


class SkippedLayer(nn.Module):

   def __init__(self, inner):
       super(SkippedLayer, self).__init__()
       self.skip = nn.Identity()
       self.inner = inner

   def forward(self, t_input):
       # skip and inner paths are summed, so the channel count is unchanged; a concatenating
       # variant (torch.cat(..., dim=1)) would need the enclosing decoder blocks to accept twice the channels
       return self.skip(t_input) + self.inner(t_input)


class ConvolutionalLayer(nn.Module):

   def __init__(self, inner, down_fn, up_fn):
       super(ConvolutionalLayer, self).__init__()
       # down_fn / up_fn are factories for the downsampling and upsampling modules
       self.model = nn.Sequential(down_fn(), inner, up_fn())

   def forward(self, t_input):
       return self.model(t_input)


class ConvolutionalBlock(nn.Module):

   def __init__(self, block_fn, enc_count, inner, dec_count):
       super(ConvolutionalBlock, self).__init__()
       # call block_fn once per block so that every block gets its own weights
       encoders = [block_fn() for _ in range(enc_count)]
       decoders = [block_fn() for _ in range(dec_count)]
       middle = [inner] if inner is not None else []
       self.model = nn.Sequential(*encoders, *middle, *decoders)

   def forward(self, t_input):
       return self.model(t_input)

   @staticmethod
   def factory(*args, **kwargs):
       return ConvolutionalBlock(*args, **kwargs)


class UnetFramework(nn.Module):

   # block_fn, down_fn and up_fn are zero-argument factories (e.g. for ResNetBlock, DownSample, UpSample)
   def __init__(self, block_fn, down_fn, up_fn):
       super(UnetFramework, self).__init__()

       encoder_counts = [1, 2, 2, 4]
       decoder_counts = [1, 1, 1, 2]
       layers = []
       for i in range(len(encoder_counts)):
           layers.append(SkippedLayer)
           layers.append(partial(ConvolutionalLayer, down_fn=down_fn, up_fn=up_fn))
           layers.append(partial(ConvolutionalBlock.factory, block_fn,
                                 enc_count=encoder_counts[i], dec_count=decoder_counts[i]))

       # TODO: refactor out into a function that wires up the blocks recursively from the specified array
       inner = None
       for layer in reversed(layers):
           inner = layer(inner=inner)
       self.model = inner

   def forward(self, t_input):
       return self.model(t_input)

Model Initialisation

TODO

Model restore

TODO

Network support

Network support refers to everything around network design that supports sophisticated use of networks. This includes:

  • hyperparameter exploration
  • adaptive loss techniques, especially those involving network outputs
  • adaptive sample selection techniques such as curriculum learning, especially those involving network outputs