## Model registry - Custom model

In [1]:
import torch.nn as nn
from segwork.registry import models_reg

In [2]:
@models_reg.register
class NeuralNetworkDecorated(nn.Module):
    """Small fully connected classifier, registered in ``models_reg`` under ``'Net'``.

    Each sample is flattened (from dim 1 onwards) and mapped to 10 logits through
    two hidden layers of 512 units with ReLU activations.

    Args:
        size: side length of the input (presumably square); the first linear
            layer expects ``size * size`` features after flattening.
    """

    # Registry key for this class.
    _register_name = 'Net'

    # Default constructor kwargs exposed to the registry (mirrors the ``size`` default).
    _default_kwargs = dict(size=28)

    def __init__(self, size: int = 28):
        super().__init__()
        hidden = 512
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(size ** 2, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, 10),
        )

    def forward(self, x):
        """Return the class logits for the batch ``x``."""
        return self.linear_relu_stack(self.flatten(x))
        
@models_reg.register
class NeuralNetworkDecoratedB(NeuralNetworkDecorated):
    """Larger-input variant of ``NeuralNetworkDecorated``, registered under ``'NetBig'``.

    Inherits the layers and ``forward`` from ``NeuralNetworkDecorated``. A bare
    ``nn.Module`` subclass would reject the registered ``size`` kwarg and would
    have no forward pass. Only the registry key and the default input size differ.

    Args:
        size: side length of the input; defaults to the registered value 112.
    """

    # Registry key for this class.
    _register_name = 'NetBig'

    # Default constructor kwargs exposed to the registry.
    _default_kwargs = {
        'size': 112,
    }

    def __init__(self, size: int = 112):
        # Keep the direct-construction default in sync with _default_kwargs.
        super().__init__(size=size)

In [3]:
# Show the registry summary: its configuration and the keys registered so far.
models_reg

ConfigurableRegistry(attr_name=_register_name, unique=False)
            Number of registered classes: 3 
            Registered classes: ['unet', 'Net', 'NetBig']
	Class key: model
	Attribute args: _default_args
	Attribute kwargs: _default_kwargs
	Additional info from attributes: ['pretrained_settings']

In [3]:
# Instantiate 'Net' from the registry with no explicit kwargs. Nothing is read
# from a config file here; the next cell loads the arguments from config.json.
model_args = dict()
model = models_reg.get_instance('Net', **model_args)

In [4]:
import json

# Load the model configuration (TODO: make the path configurable).
config_path = './config.json'
# JSON text is UTF-8; do not depend on the platform's default encoding.
with open(config_path, encoding='utf-8') as f:
    config = json.load(f)
print(config)

# Expected layout (see printed output): {'model_name': <registry key>, 'model_args': {...}}
model_name = config['model_name']            # e.g. 'Net'
# 'model_args' is optional: an empty dict mirrors the defaults-only call above.
model_args = config.get('model_args', {})

# Build model from registry
print(models_reg[model_name])
model = models_reg.get_instance(model_name, **model_args)

{'model_name': 'Net', 'model_args': {'size': 28}}
{'model': <class '__main__.NeuralNetworkDecorated'>, '_default_args': [], '_default_kwargs': {'size': 28}, 'pretrained_settings': None}


In [5]:
# Inspect the entry stored under 'resnet34' in segwork's package-level backbone
# registry (presumably seeded from smp's built-in encoders — confirm).
from segwork.registry import backbones_reg
backbones_reg['resnet34']

## Backbones registry - Integration with smp

In [1]:
import typing
import copy

import torch
import torch.nn as nn
import segmentation_models_pytorch as smp

from segwork.model.layer import ConvBnAct
from segwork.registry import ConfigurableRegistry, smp_compatibility

In [2]:
# Backbone registry seeded with smp's encoder table, so that custom encoders
# registered here can presumably also be resolved by smp's own factories.
backbones_reg = ConfigurableRegistry(
    class_key='encoder',                         # entry key holding the nn.Module class
    initial_registry=smp.encoders.encoders,      # initial registry (default: None)
    attr_name='_register_name',                  # class attribute used as the registry key
    attr_kwargs='params',                        # class attribute with default constructor kwargs
    unique=True,                                 # registering an existing key raises KeyError
    additional_args=['pretrained_settings'],     # extra class attributes copied into each entry
    register_hook=smp_compatibility,             # retro-compatibility with smp
)

In [6]:
# Also copy each backbone's `_description` class attribute into its registry entry.
backbones_reg.add_additional_args('_description')

# NOTE: backbones_reg is built with unique=True, so re-running this cell in the
# same kernel raises KeyError ("Entry with key Net already exists"). Recreate the
# registry (or restart the kernel) before re-running it.
@backbones_reg.register
class DummyBackboneDecorated(nn.Module, smp.encoders._base.EncoderMixin):
    """Dummy encoder to test compatibility with smp architectures.

    Testing:
     - Custom attributes in the registry (e.g. ``_description``).
     - To be usable in the smp framework, an encoder must inherit from ``EncoderMixin``.

    Args:
        out_channels: channels of every feature map returned by ``forward``, the
            first entry being the input channels. One stage is built per
            consecutive pair, i.e. ``len(out_channels) - 1`` stages.
        depth: stored as ``self._depth``; ``forward`` below does not read it.
    """

    _register_name = 'Net'

    # Default constructor kwargs (the registry's ``attr_kwargs`` attribute).
    params = {
        'out_channels': (3, 64, 256, 512),
        'depth': 3,
    }

    # Additional settings: no pretrained weights are provided.
    pretrained_settings = None

    _description = 'Formal description of encoder'

    def __init__(self, out_channels: typing.Sequence[int], depth: int):
        super().__init__()

        # Number of channels of each encoder feature tensor.
        self._out_channels: typing.Sequence[int] = out_channels

        # Number of stages in the encoder (i.e. number of downsampling operations).
        # NOTE(review): forward() below does not truncate its outputs by depth.
        self._depth: int = depth

        # Default number of input channels of the first conv layer.
        self._in_channels: int = 3

        # Stage i maps out_channels[i] -> out_channels[i + 1]; its second,
        # stride-2 convolution downsamples (224 -> 112 -> 56 -> 28 in the demo below).
        self.stages = nn.Sequential(*[
            nn.Sequential(
                ConvBnAct(in_ch, out_ch, 3),
                ConvBnAct(out_ch, out_ch, 3, stride=2),
            )
            for in_ch, out_ch in zip(out_channels[:-1], out_channels[1:])
        ])

    def forward(self, x):
        """Return a list with the input followed by the output of every stage."""
        features = [x]
        for stage in self.stages:
            x = stage(x)
            features.append(x)
        return features

@backbones_reg.register
class BackboneDecoratedB(DummyBackboneDecorated):
    """New backbone with other defaults, configurable from within the module or with config files."""

    _register_name = 'NetB'

    # Default constructor kwargs. The inherited __init__ accepts only
    # `out_channels` and `depth`; an `output_stride` entry made
    # instantiation fail with TypeError.
    params = {
        'out_channels': (3, 256, 256, 256),
        'depth': 3,
    }



KeyError: 'Entry with key Net already exists. Set the unique attribute to false to overried items.'

In [7]:
encoder_name = 'Net'

# Framework entry point: build the encoder through the segwork registry.
backbone_fr = backbones_reg.get_instance(encoder_name)

# smp entry point: the same name resolves through smp's own factory
# (presumably because backbones_reg was seeded with smp.encoders.encoders).
backbone = smp.encoders.get_encoder(encoder_name)

# Compare the metadata stored for the custom entry with a built-in smp entry.
print(list(backbones_reg[encoder_name].keys()))
print(list(backbones_reg['resnet34'].keys()))
print(list(backbones_reg['resnet34'].keys()))

['encoder', '_default_args', 'params', 'pretrained_settings', '_description']
['encoder', 'pretrained_settings', 'params']


### Output of the registered backbone

In [8]:
# Random input of shape (1, 3, 224, 224).
x = torch.rand(1, 3, 224, 224)

# The encoder returns one tensor per stage, preceded by the input itself.
out = backbone(x)

print('Features size...')
for stage_idx, feature in enumerate(out):
    print(f'Stage {stage_idx:02d}: {feature.size()}')

Features size...
Stage 00: torch.Size([1, 3, 224, 224])
Stage 01: torch.Size([1, 64, 112, 112])
Stage 02: torch.Size([1, 256, 56, 56])
Stage 03: torch.Size([1, 512, 28, 28])


### Using the custom backbone

In [9]:
# smp U-Net that uses the custom 'Net' encoder registered above.
model_args = dict(
    encoder_name='Net',
    encoder_depth=3,                     # kept equal to len(decoder_channels)
    encoder_weights=None,                # the custom encoder has no pretrained settings
    decoder_channels=(512, 256, 64),
    in_channels=3,
    classes=20,
)

model = smp.Unet(**model_args)

In [10]:
# Forward pass through the smp U-Net; dim 1 matches `classes` (20).
out = model(x)
out.shape

torch.Size([1, 20, 224, 224])

In [16]:
# Import explicitly: this section's import cell does not bring in models_reg, so
# without this line the cell depends on state left over from the first section.
from segwork.registry import models_reg

# Build the registry's 'unet' model from the same kwargs passed to smp.Unet above.
model_fr = models_reg.get_instance('unet', **model_args)

In [17]:
# The registry-built U-Net should produce the same output shape as the smp one.
out_fr = model_fr(x)
out_fr.shape

torch.Size([1, 20, 224, 224])

## Custom registry
Create entry points for other modular objects (here, datasets).

In [18]:
from segwork.registry import ConfigurableRegistry

# A separate registry for dataset classes; each entry stores the class under 'dataset'.
# NOTE(review): attr_name is the class attribute used as the registry key (cf.
# '_register_name' in the other registries). '_dataset_params' reads like a kwargs
# attribute (attr_kwargs) instead — confirm which was intended.
datasets = ConfigurableRegistry(class_key='dataset', attr_name='_dataset_params')

In [19]:
# Show the new dataset registry (no classes have been registered in it yet).
datasets

ConfigurableRegistry(attr_name=_dataset_params, unique=False)
            Number of registered classes: 0 
            Registered classes: []