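## Model registry

This notebook registers a custom `nn.Module` in the `segwork` `backbones` registry, builds it by name from a JSON config, and renders its computation graph.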

In [1]:
import json

import torch.nn as nn
from segwork.registry import backbones

In [2]:
@backbones.register
class NeuralNetworkDecorated(nn.Module):
    """Small fully connected classifier registered in ``backbones`` as ``'Net'``.

    The input is flattened (``nn.Flatten()``, i.e. from dim 1 onward) and passed
    through a Linear -> ReLU -> Linear -> ReLU -> Linear stack.

    Args:
        size: Side length of the input. The first linear layer expects
            ``size * size`` features after flattening, so presumably the input
            is a batch of square single-channel images (e.g. 28x28) — confirm
            with callers.
        hidden_features: Width of both hidden layers (default 512, as before).
        num_classes: Number of output logits (default 10, as before).
    """

    # Registry key for this class: the registry is configured with
    # attr_name=_register_name, and the config selects it via model_name 'Net'.
    _register_name = 'Net'

    def __init__(self, size: int = 28, hidden_features: int = 512, num_classes: int = 10):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(size * size, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, num_classes),
        )

    def forward(self, x):
        """Return unnormalized class scores (logits) for input ``x``.

        Assumes the dims after the first multiply to ``size * size`` —
        TODO confirm the expected input layout with callers.
        """
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits
        

In [3]:
# Load config (TODO).
# JSON is UTF-8 by specification, so decode explicitly rather than relying on
# the platform's locale encoding (which differs on e.g. Windows).
# NOTE(review): the path is relative to the current working directory, so this
# cell assumes the kernel was started from the repository root — TODO confirm.
config_path = './notebooks/config.json'
with open(config_path, encoding='utf-8') as f:
    config = json.load(f)
print(config)
model_name = config['model_name']   # registry key, e.g. 'Net'
# A missing or null 'model_args' means "use the constructor defaults"
# instead of failing on the ** unpacking below.
model_args = config.get('model_args') or {}

# Build model from registry
model = backbones.get_instance(model_name, **model_args)

{'model_name': 'Net', 'model_args': {'size': 28}}


In [4]:
# Display the instantiated model; its repr lists every submodule and layer size.
model

NeuralNetworkDecorated(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
)

In [6]:
# List every class currently registered in the `backbones` registry
# (this includes the custom 'Net' registered above; the listing is long).
registry_text = str(backbones)
print(registry_text)

ConfigurableRegistry(attr_name=_register_name, unique=False)
            Number of registered classes: 114 
            Registered classes: ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x4d', 'resnext101_32x8d', 'resnext101_32x16d', 'resnext101_32x32d', 'resnext101_32x48d', 'dpn68', 'dpn68b', 'dpn92', 'dpn98', 'dpn107', 'dpn131', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19', 'vgg19_bn', 'senet154', 'se_resnet50', 'se_resnet101', 'se_resnet152', 'se_resnext50_32x4d', 'se_resnext101_32x4d', 'densenet121', 'densenet169', 'densenet201', 'densenet161', 'inceptionresnetv2', 'inceptionv4', 'efficientnet-b0', 'efficientnet-b1', 'efficientnet-b2', 'efficientnet-b3', 'efficientnet-b4', 'efficientnet-b5', 'efficientnet-b6', 'efficientnet-b7', 'mobilenet_v2', 'xception', 'timm-efficientnet-b0', 'timm-efficientnet-b1', 'timm-efficientnet-b2', 'timm-efficientnet-b3', 'timm-efficientnet-b4', 'timm-efficientnet-b5', 'timm-effi

In [6]:
from segwork.utils.mermaid import TorchFXParser

# Render the model's computation graph under a display title.
# Presumably traced with torch.fx and drawn as a Mermaid diagram (per the
# class/module names) — confirm against segwork.utils.mermaid.
parser = TorchFXParser(
    name='My graph name',
    module=model,
)
parser.display_graph()