## Model registry

In [2]:
import torch.nn as nn
from segwork.registry import  models

In [3]:
@models.register
class NeuralNetworkDecorated(nn.Module):
    """Fully connected classifier: Flatten -> (Linear -> ReLU) x 2 -> Linear.

    Registered under ``_register_name`` ('Net') so it can be built by name,
    e.g. ``models.get('Net', size=28)``.

    Args:
        size: Side length of the input; each sample must flatten to
            ``size * size`` features (e.g. a ``size x size`` image). Defaults to 28.
        hidden_size: Width of both hidden layers. Defaults to 512.
        num_classes: Number of output logits. Defaults to 10.
    """

    # Key under which the registry stores this class (registry uses attr_name=_register_name).
    _register_name = 'Net'

    def __init__(self, size: int = 28, hidden_size: int = 512, num_classes: int = 10):
        super().__init__()
        # nn.Flatten() defaults to start_dim=1: the batch dim is kept, the rest is merged.
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(size * size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, num_classes),
        )

    def forward(self, x):
        """Flatten ``x`` (batch dim first) and return raw, unnormalized logits of shape (batch, num_classes)."""
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits
        

In [8]:
import json

# Load the model spec from JSON: {"model_name": <registered name>, "model_args": {<constructor kwargs>}}.
config_path = './config.json'
# JSON is UTF-8 by spec; open()'s default encoding is locale-dependent (e.g. cp1252 on Windows).
with open(config_path, encoding='utf-8') as f:
    config = json.load(f)
print(config)

if 'model_name' not in config:
    raise KeyError(f"{config_path} must define 'model_name' (a name registered in `models`)")
model_name = config['model_name']   # e.g. 'Net'
# 'model_args' is optional: when absent, the model is built with its constructor defaults.
model_args = config.get('model_args', {})

# Build model from registry
model = models.get(model_name, **model_args)

{'model_name': 'Net', 'model_args': {'size': 28}}


In [9]:
# Display the module tree (nn.Module repr) to check the layer sizes built from model_args.
model

NeuralNetworkDecorated(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
)

In [5]:
# Inspect the registry: prints its settings and the names of all registered classes.
print(models)

ClassRegistry(attr_name=_register_name, unique=False)
            Number of registered classes: 1 
            Registered classes: ['Net']


In [6]:
from segwork.utils.mermaid import TorchFXParser

# Wrap the registry-built model in a graph parser and render it.
# NOTE(review): judging only by `segwork.utils.mermaid.TorchFXParser`, this presumably
# traces `model` with torch.fx and shows a Mermaid diagram — confirm. The graph name is a placeholder.
parser = TorchFXParser(module=model, name='My graph name')
parser.display_graph()

In [2]:
from segwork.registry import print_compati
# Presumably lists backbone names segwork reports as compatible (the output shows e.g. 'convnext_base', 'densenet121').
# NOTE(review): the output is a very long list of raw bytes ending in b'\r\n' (looks like undecoded
# subprocess output). Consider decoding/stripping and showing only a sample to keep the notebook readable.
print_compati()

[b'adv_inception_v3\r\n',
 b'bat_resnext26ts\r\n',
 b'botnet26t_256\r\n',
 b'botnet50ts_256\r\n',
 b'convnext_base\r\n',
 b'convnext_base_384_in22ft1k\r\n',
 b'convnext_base_in22ft1k\r\n',
 b'convnext_base_in22k\r\n',
 b'convnext_large\r\n',
 b'convnext_large_384_in22ft1k\r\n',
 b'convnext_large_in22ft1k\r\n',
 b'convnext_large_in22k\r\n',
 b'convnext_small\r\n',
 b'convnext_tiny\r\n',
 b'convnext_tiny_hnf\r\n',
 b'convnext_xlarge_384_in22ft1k\r\n',
 b'convnext_xlarge_in22ft1k\r\n',
 b'convnext_xlarge_in22k\r\n',
 b'cspdarknet53\r\n',
 b'cspdarknet53_iabn\r\n',
 b'cspresnet50\r\n',
 b'cspresnet50d\r\n',
 b'cspresnet50w\r\n',
 b'cspresnext50\r\n',
 b'cspresnext50_iabn\r\n',
 b'darknet53\r\n',
 b'densenet121\r\n',
 b'densenet121d\r\n',
 b'densenet161\r\n',
 b'densenet169\r\n',
 b'densenet201\r\n',
 b'densenet264\r\n',
 b'densenet264d_iabn\r\n',
 b'densenetblur121d\r\n',
 b'dla34\r\n',
 b'dla46_c\r\n',
 b'dla46x_c\r\n',
 b'dla60\r\n',
 b'dla60_res2net\r\n',
 b'dla60_res2next\r\n',
 b'dla6