# PyTorch 中的装饰器（decorator）


## 注册 registry
注册器（registry）用于管理功能相似的不同模块，比如目标检测中的 backbone、head 和 neck。许多深度学习工程都使用注册器来管理数据集和模型模块，比如
MMDetection、detectron2、detection.pytorch、ProjectAo、MMDetection3D、MMClassification、MMEditing 等等。

### 什么是注册
注册器可以看作是完成了**字符串 -> 类（或函数）对象**的一个映射。单个注册器中的这些类通常具有相似的 API，但实现了不同的算法，比如目标检测中的各种主干网络（backbone）。
使用注册器，用户可以通过对应的字符串查找并实例化该类，再根据需要使用实例化后的模块。


In [None]:
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
class Registry(object):
    """Registry Class to map modules.
    The registry provides a name -> object mapping, to support third-party users' custom modules.

    Names are case-insensitive: they are upper-cased both when registering
    and when looking up.

    To create a registry (inside detectron2):
        BACKBONE_REGISTRY = Registry('BACKBONE')

    To register an object:

        @BACKBONE_REGISTRY.register()
        class MyBackbone():
            ...

        Or, under an explicit name:

        @BACKBONE_REGISTRY.register('my_backbone')
        class MyBackbone():
            ...

        Or, without a decorator:

        BACKBONE_REGISTRY.register(MyBackbone)
    """

    def __init__(self, name):
        """
        Args:
            name (str): the name of this registry (used in error messages).
        """
        self._name = name

        self._obj_map = {}

    def _do_register(self, name, obj):
        """Store ``obj`` under the upper-cased ``name``.

        Raises:
            AssertionError: if ``name`` is already registered. Raised
                explicitly rather than via ``assert`` so the check is not
                stripped under ``python -O``; the exception type is kept for
                backward compatibility.
        """
        upper_name = name.upper()
        if upper_name in self._obj_map:
            raise AssertionError(
                "An object named '{}' was already registered in '{}' registry!".format(upper_name, self._name))
        self._obj_map[upper_name] = obj

    def register(self, module_name=None, obj=None):
        """Register an object, either as a decorator or via a direct call.

        Supported forms::

            @registry.register()             # name = obj.__name__
            @registry.register('alias')      # name = 'alias'
            @registry.register               # bare decorator, name = obj.__name__
            registry.register(MyBackbone)    # direct call, name = obj.__name__
            registry.register('alias', MyBackbone)
            registry.register(obj=MyBackbone)

        Args:
            module_name (str, optional): name to register under. Defaults to
                ``obj.__name__``. If a callable is passed here while ``obj``
                is None, it is treated as the object to register.
            obj (obj, optional): the object to register. When None (and
                ``module_name`` is not callable) a decorator is returned.

        Returns:
            The decorator in decorator form; otherwise ``obj`` itself.
        """
        # `register(MyBackbone)` and bare `@register` put the object into
        # `module_name` positionally. A str is never callable, so a callable
        # here must be the object itself. Without this, the call would
        # silently return an unused decorator and register nothing.
        if obj is None and callable(module_name):
            module_name, obj = None, module_name

        if obj is None:
            # used as a decorator
            def deco(func_or_class):
                name = module_name if module_name is not None else func_or_class.__name__
                self._do_register(name, func_or_class)
                return func_or_class

            return deco

        # used as a function call
        name = module_name if module_name is not None else obj.__name__
        self._do_register(name, obj)
        return obj

    def get(self, name):
        """Get object with name (case-insensitive).

        Args:
            name (str): registered object name.

        Returns:
            obj: The object.

        Raises:
            KeyError: if no object is registered under ``name``.
        """
        ret = self._obj_map.get(name.upper())
        if ret is None:
            raise KeyError("No object named '{}' found in '{}' registry[{}]!".format(
                name.upper(), self._name, self._obj_map.keys()))
        return ret

    def __getitem__(self, name):
        """Get object with name; same as :meth:`get`.

        Args:
            name (str): registered object name.

        Returns:
            obj: The object.
        """
        return self.get(name)

    def __contains__(self, name):
        """Return True if ``name`` (case-insensitive) is registered.

        Without this method, ``in`` would fall back to calling
        ``__getitem__`` with integer indices and crash on ``int.upper``.
        """
        return isinstance(name, str) and name.upper() in self._obj_map

    def __str__(self):
        """Format to string representation."""
        s = self._name + ':'
        s += str(self._obj_map)
        return s

BACKBONES = Registry('BACKBONES')


@BACKBONES.register('resnet18')
class resnet18:
    """Placeholder backbone registered in BACKBONES as 'resnet18'."""

    def __init__(self):
        """Build an empty placeholder; no arguments are accepted."""
# The decorator above is shorthand for: resnet18 = BACKBONES.register('resnet18')(resnet18)

@BACKBONES.register('resnet34')
class resnet34:
    """Placeholder backbone registered in BACKBONES as 'resnet34'."""

    def __init__(self):
        """Build an empty placeholder; no arguments are accepted."""
# The decorator above is shorthand for: resnet34 = BACKBONES.register('resnet34')(resnet34)

@BACKBONES.register('resnet50')
class resnet50:
    """Placeholder backbone registered in BACKBONES as 'resnet50'."""

    def __init__(self):
        """Build an empty placeholder; no arguments are accepted."""
# The decorator above is shorthand for: resnet50 = BACKBONES.register('resnet50')(resnet50)

def make_backbone(name, *args, **kwargs):
    """Look up a backbone class in ``BACKBONES`` by name and instantiate it.

    Args:
        name (str): registered backbone name. The lookup is case-insensitive
            because the registry upper-cases names.
        *args, **kwargs: forwarded to the backbone's constructor, so that
            backbones taking configuration can also be built through this
            factory. Calling with just ``name`` behaves as before.

    Returns:
        A new instance of the registered backbone class.

    Raises:
        KeyError: if nothing is registered under ``name``.
    """
    backbone_cls = BACKBONES.get(name)
    return backbone_cls(*args, **kwargs)


resnet = make_backbone('resnet18')


## 另一种实现：`Register` 类

In [None]:
import logging

class Register:
    """A simple name -> callable registry that can be used as a decorator.

    Supported decorator forms::

        @reg.register             # key = the function/class __name__
        @reg.register()           # same as above
        @reg.register('alias')    # key = 'alias'

    Re-registering an existing key logs a warning and overwrites the entry.
    """

    def __init__(self, registry_name):
        self._dict = {}
        self._name = registry_name

    def __setitem__(self, key, value):
        """Register ``value`` under ``key``, or under ``value.__name__`` if key is None.

        Raises:
            TypeError: if ``value`` is not callable. This is a subclass of
                ``Exception``, so existing ``except Exception`` handlers still
                catch it.
        """
        if not callable(value):
            raise TypeError(f"Value of a Registry must be a callable!\nValue: {value}")
        if key is None:
            key = value.__name__
        if key in self._dict:
            # Lazy %-args: the message is only formatted if it is emitted.
            logging.warning("Key %s already in registry %s.", key, self._name)
        self._dict[key] = value

    def register(self, target=None):
        """Decorator to register a function or class.

        Args:
            target: the callable itself (bare ``@reg.register``), an alias
                string (``@reg.register('alias')``), or None
                (``@reg.register()``).

        Returns:
            The registered callable in the bare form; otherwise a decorator
            that registers and returns its argument.
        """

        def add(key, value):
            self[key] = value
            return value

        if callable(target):  # functions and classes are all callable
            # @reg.register
            return add(None, target)
        # @reg.register('alias') or @reg.register(); a None key falls back to __name__
        return lambda fn: add(target, fn)

    def __getitem__(self, key):
        return self._dict[key]

    def __contains__(self, key):
        return key in self._dict

    def keys(self):
        return self._dict.keys()


## 实际案例
在 DeiT 代码中（其注册函数 `register_model` 来自 timm 库），注册器的实现如下：

In [None]:
_module_to_models = defaultdict(set)  # module name -> set of model names registered from it
_model_to_module = {}  # model name -> (last dotted component of) its module name
_model_entrypoints = {}  # model name -> entrypoint function
_model_has_pretrained = set()  # model names whose default cfg has an http(s) weight url


def register_model(fn):
    """Register a model entrypoint function and export it from its module.

    Side effects:
      * appends ``fn.__name__`` to the defining module's ``__all__`` (created
        if missing, never duplicated);
      * records ``fn`` in the module-level registry dicts/sets above;
      * marks the model as pretrained when the module's ``default_cfgs`` has
        an entry for it whose ``'url'`` contains ``'http'``.

    Args:
        fn: the entrypoint function; its ``__module__`` must be in ``sys.modules``.

    Returns:
        ``fn`` unchanged, so this can be used as a decorator.
    """
    # Look up the module that defined ``fn``.
    mod = sys.modules[fn.__module__]
    # Only the last dotted component is used as the module key. str.split
    # always returns at least one element, so no empty-list guard is needed.
    module_name = fn.__module__.split('.')[-1]

    # Add the model to the module's __all__ (without duplicating entries).
    model_name = fn.__name__
    exported = getattr(mod, '__all__', None)
    if exported is None:
        mod.__all__ = [model_name]
    elif model_name not in exported:
        if isinstance(exported, list):
            exported.append(model_name)
        else:
            # e.g. a tuple __all__: rebuild as a list rather than crash on .append
            mod.__all__ = list(exported) + [model_name]

    # Add entries to the registry dicts/sets.
    _model_entrypoints[model_name] = fn
    _model_to_module[model_name] = module_name
    _module_to_models[module_name].add(model_name)

    # Check whether the model has a pretrained url, to allow filtering on it.
    # This catches every model whose entrypoint name matches a cfg key, but
    # misses aliasing entrypoints and non-matching combos.
    has_pretrained = False
    default_cfgs = getattr(mod, 'default_cfgs', None)
    if default_cfgs is not None and model_name in default_cfgs:
        cfg = default_cfgs[model_name]
        url = cfg['url'] if 'url' in cfg else None
        # A None/empty url means "no weights"; guard it so `in` does not raise TypeError.
        has_pretrained = bool(url) and 'http' in url
    if has_pretrained:
        _model_has_pretrained.add(model_name)
    return fn

In [None]:
@register_model
def deit_tiny_patch16_224(pretrained=False, **kwargs):
    """Build the DeiT-Tiny VisionTransformer variant.

    Architecture: patch_size=16, embed_dim=192, depth=12, num_heads=3,
    mlp_ratio=4, qkv_bias=True, LayerNorm(eps=1e-6).

    Args:
        pretrained (bool): if True, download the checkpoint from the
            hard-coded URL (hash-checked) and load its ``"model"`` entry.
        **kwargs: extra keyword arguments forwarded to ``VisionTransformer``.
    """
    architecture = dict(
        patch_size=16,
        embed_dim=192,
        depth=12,
        num_heads=3,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    net = VisionTransformer(**architecture, **kwargs)
    net.default_cfg = _cfg()
    if not pretrained:
        return net

    weights_url = "https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth"
    state = torch.hub.load_state_dict_from_url(url=weights_url, map_location="cpu", check_hash=True)
    net.load_state_dict(state["model"])
    return net

自定义新模型：只需用 `@register_model` 装饰模型的构造函数，即可将其注册到上面的字典中，例如：

In [None]:
@register_model
def deit_small_patch16_224(pretrained=False, **kwargs):
    """Build the DeiT-Small VisionTransformer variant.

    Architecture: patch_size=16, embed_dim=384, depth=12, num_heads=6,
    mlp_ratio=4, qkv_bias=True, LayerNorm(eps=1e-6).

    Args:
        pretrained (bool): if True, download the checkpoint from the
            hard-coded URL (hash-checked) and load its ``"model"`` entry.
        **kwargs: extra keyword arguments forwarded to ``VisionTransformer``.
    """
    vit_config = {
        "patch_size": 16,
        "embed_dim": 384,
        "depth": 12,
        "num_heads": 6,
        "mlp_ratio": 4,
        "qkv_bias": True,
        "norm_layer": partial(nn.LayerNorm, eps=1e-6),
    }
    model = VisionTransformer(**vit_config, **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        ckpt_url = "https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth"
        ckpt = torch.hub.load_state_dict_from_url(url=ckpt_url, map_location="cpu", check_hash=True)
        model.load_state_dict(ckpt["model"])
    return model
