## 1. Prerequisites: [Save and Load the Model](https://pytorch.org/tutorials/beginner/basics/saveloadrun_tutorial.html)
In this section we will look at how to persist model state with saving, loading and running model predictions.

In [1]:
import torch
import torchvision.models as models

### Saving and Loading Model Weights
PyTorch models store the learned parameters in an internal state dictionary, called **state_dict**. These can be persisted via the **torch.save** method:

In [None]:
model = models.vgg16(pretrained=True)
torch.save(model.state_dict(), 'model_weights.pth')

To load model weights, you need to create an instance of the same model first, and then load the parameters using the **load_state_dict()** method.

In [None]:
model = models.vgg16() # we do not specify pretrained=True, i.e. do not load default weights
model.load_state_dict(torch.load('model_weights.pth'))
model.eval() # put the model in evaluation mode; this matters mainly for dropout and batch normalization (BN) layers

- note: Be sure to call the **model.eval()** method before running inference to set the dropout and batch normalization layers to evaluation mode. Failing to do this will yield inconsistent inference results.
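
The save, load and `eval()` steps above can be sketched end to end on a tiny model. This is a minimal sketch, not the tutorial's VGG16 code: `make_model` is a hypothetical helper, and an in-memory `io.BytesIO` buffer stands in for the `'model_weights.pth'` file so nothing is written to disk.

```python
import io

import torch
from torch import nn


def make_model():
    # Hypothetical stand-in for the VGG16 model used in the text.
    return nn.Sequential(nn.Linear(4, 8), nn.Dropout(p=0.5), nn.Linear(8, 2))


source = make_model()

# Save only the state_dict (the learned parameters), here into memory.
buffer = io.BytesIO()
torch.save(source.state_dict(), buffer)
buffer.seek(0)

# Re-create the same architecture first, then load the saved weights into it.
restored = make_model()
restored.load_state_dict(torch.load(buffer))
restored.eval()  # dropout is disabled in evaluation mode

x = torch.randn(3, 4)
with torch.no_grad():
    first = restored(x)
    second = restored(x)

weights_match = all(
    torch.equal(a, b)
    for a, b in zip(source.state_dict().values(), restored.state_dict().values())
)
print(weights_match)                 # every restored tensor equals the saved one
print(torch.equal(first, second))    # repeated inference is deterministic in eval mode
```

Without the call to `eval()`, the dropout layer would keep zeroing random activations, so the two forward passes would generally disagree.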

### Saving and Loading Models with Shapes
When loading model weights, we needed to instantiate the model class first, because the class defines the structure of a network. We might want to save the structure of this class together with the model, in which case we can pass **model** (and not **model.state_dict()**) to the saving function:

In [None]:
torch.save(model, 'model.pth')

We can then load the model like this:

In [None]:
model = torch.load('model.pth')

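- note: This approach uses the Python [pickle](https://docs.python.org/3/library/pickle.html) module when serializing the model, so the actual class definition must be available when loading the model. In recent PyTorch releases (2.6 and later) **torch.load** defaults to `weights_only=True`, so loading a whole pickled model additionally requires passing `weights_only=False`, which should only be done for files you trust.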

### [Saving and Loading a General Checkpoint](https://pytorch.org/tutorials/recipes/recipes/saving_and_loading_a_general_checkpoint.html)

#### Save the general checkpoint
Collect all relevant information and build your dictionary.

In [None]:
# Additional information
EPOCH = 5
PATH = "model.pt" # a .pth or .pkl extension would also work for the save path
LOSS = 0.4

torch.save({
            'epoch': EPOCH,
            'model_state_dict': net.state_dict(),   # net is an instance of the network class defined earlier
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': LOSS,
            }, PATH)

#### Load the general checkpoint
Remember to first initialize the model and optimizer, then load the dictionary locally.

In [None]:
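model = Net()  # Net is the network class defined earlier
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)  # optim is torch.optim; the optimizer must track the new model's parameters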

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.eval()
# - or -
model.train()

You must call **model.eval()** to set dropout and batch normalization layers to evaluation mode before running inference. Failing to do this will yield inconsistent inference results.

If you wish to resume training, call **model.train()** to ensure these layers are in training mode.
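
The checkpoint workflow can be exercised as one self-contained round trip. The following is a sketch under stated assumptions: `Net` is a hypothetical one-layer network, the checkpoint is written to an in-memory buffer rather than to `PATH`, and a single dummy training step is run so that the optimizer has momentum state worth saving.

```python
import io

import torch
from torch import nn, optim


class Net(nn.Module):  # hypothetical minimal network
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


net = Net()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

# One dummy training step so the optimizer holds momentum buffers.
loss = net(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()

buffer = io.BytesIO()
torch.save({
    'epoch': 5,
    'model_state_dict': net.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss.item(),
}, buffer)
buffer.seek(0)

# Resume: build fresh objects first, then restore their state.
model = Net()
resumed_optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
checkpoint = torch.load(buffer)
model.load_state_dict(checkpoint['model_state_dict'])
resumed_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch'] + 1  # continue with the next epoch
model.train()                          # training mode, because we intend to resume
```

Saving the optimizer state alongside the model matters for optimizers with internal state (momentum here): restoring only the weights would restart that state from zero.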

## 2. Introduction to the nn.Module source code (Part II)
### Part of the source code of the [nn.Module](https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module) class is quoted below

#### The to() function

In [None]:
    def to(self, *args, **kwargs):
        r"""Moves and/or casts the parameters and buffers.

        This can be called as

        .. function:: to(device=None, dtype=None, non_blocking=False)  # to() has several call signatures
           :noindex:

        .. function:: to(dtype, non_blocking=False)
           :noindex:

        .. function:: to(tensor, non_blocking=False)
           :noindex:

        .. function:: to(memory_format=torch.channels_last)
           :noindex:
           
        ...... # remaining content omitted
        
        .. note::
            This method modifies the module in-place.

        Examples::

            >>> linear = nn.Linear(2, 2)
            >>> linear.weight
            Parameter containing:
            tensor([[ 0.1913, -0.3420],
                    [-0.5113, -0.2325]])
            >>> linear.to(torch.double)
            Linear(in_features=2, out_features=2, bias=True)
            >>> linear.weight
            Parameter containing:
            tensor([[ 0.1913, -0.3420],
                    [-0.5113, -0.2325]], dtype=torch.float64)
            >>> gpu1 = torch.device("cuda:1")
            >>> linear.to(gpu1, dtype=torch.half, non_blocking=True)
            Linear(in_features=2, out_features=2, bias=True)
            >>> linear.weight
            Parameter containing:
            tensor([[ 0.1914, -0.3420],
                    [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1')
            >>> cpu = torch.device("cpu")
            >>> linear.to(cpu)
            Linear(in_features=2, out_features=2, bias=True)
            >>> linear.weight
            Parameter containing:
            tensor([[ 0.1914, -0.3420],
                    [-0.5112, -0.2324]], dtype=torch.float16)

            >>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble)
            >>> linear.weight
            Parameter containing:
            tensor([[ 0.3741+0.j,  0.2382+0.j],
                    [ 0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128)
            >>> linear(torch.ones(3, 2, dtype=torch.cdouble))
            tensor([[0.6122+0.j, 0.1150+0.j],
                    [0.6122+0.j, 0.1150+0.j],
                    [0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128)

        """
        ...... # remaining content omitted

Example of the to() function

In [2]:
class Test(torch.nn.Module): # define a custom module
    def __init__(self):
        super(Test, self).__init__() # a subclass's __init__ should normally call the parent class's __init__
        self.linear1 = torch.nn.Linear(2, 3)
        self.linear2 = torch.nn.Linear(3, 4)
        self.batch_norm = torch.nn.BatchNorm2d(4)

test_module = Test()

In [3]:
test_module._modules # the _modules attribute holds an ordered dict

OrderedDict([('linear1', Linear(in_features=2, out_features=3, bias=True)),
             ('linear2', Linear(in_features=3, out_features=4, bias=True)),
             ('batch_norm',
              BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))])

In [4]:
test_module._modules['linear1'] # get the linear1 submodule

Linear(in_features=2, out_features=3, bias=True)

In [5]:
test_module._modules['linear1'].weight

Parameter containing:
tensor([[ 0.4898,  0.2476],
        [ 0.4431, -0.5865],
        [ 0.4365,  0.5435]], requires_grad=True)

In [6]:
test_module._modules['linear1'].weight.dtype

torch.float32

In [7]:
test_module.to(torch.double)  # cast all floating-point parameters and buffers to double (float64)

Test(
  (linear1): Linear(in_features=2, out_features=3, bias=True)
  (linear2): Linear(in_features=3, out_features=4, bias=True)
  (batch_norm): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)

In [8]:
test_module._modules['linear1'].weight.dtype

torch.float64

In [9]:
test_module.to(torch.float32) # cast back to the default dtype

Test(
  (linear1): Linear(in_features=2, out_features=3, bias=True)
  (linear2): Linear(in_features=3, out_features=4, bias=True)
  (batch_norm): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
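
One detail of `to(dtype)` is easy to miss: only floating-point (and complex) parameters and buffers are cast, while integer buffers keep their dtype. A minimal sketch on a standalone `BatchNorm2d` layer, which holds both kinds of tensors:

```python
import torch
from torch import nn

bn = nn.BatchNorm2d(4)
returned = bn.to(torch.double)

print(returned is bn)                # to() modifies the module in place and returns it
print(bn.weight.dtype)               # floating-point parameter: cast to float64
print(bn.running_mean.dtype)         # floating-point buffer: cast to float64
print(bn.num_batches_tracked.dtype)  # integer buffer: left as int64
```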

#### The \_\_getattr__() function

In [None]:
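    def __getattr__(self, name: str) -> Union[Tensor, 'Module']:  # a Python magic method, invoked only when normal attribute lookup fails
        if '_parameters' in self.__dict__:  # three dicts are searched in turn: _parameters, _buffers, _modules
            _parameters = self.__dict__['_parameters']  # self is the current module; this dict holds only the parameters defined directly on it,
            if name in _parameters:                     # so parameters of its submodules cannot be returned here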
                return _parameters[name]  
        if '_buffers' in self.__dict__:
            _buffers = self.__dict__['_buffers']
            if name in _buffers:
                return _buffers[name]
        if '_modules' in self.__dict__:
            modules = self.__dict__['_modules']
            if name in modules:
                return modules[name]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, name))


In [10]:
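# Why test_module._modules works: _modules is an ordinary member variable created in nn.Module.__init__, so normal attribute lookup finds it.
# __getattr__ is what resolves names such as test_module.linear1, by searching _parameters, _buffers and _modules.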
test_module._modules

OrderedDict([('linear1', Linear(in_features=2, out_features=3, bias=True)),
             ('linear2', Linear(in_features=3, out_features=4, bias=True)),
             ('batch_norm',
              BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))])

In [11]:
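test_module._parameters  # Why is this an empty dict? Accessing _parameters does not traverse the submodules of test_module (linear1, linear2);
                         # it only looks at the current object itself. The Test class defines no nn.Parameter directly, although its submodules (the linear layers) do.
                         # So an empty dict here does not mean the module has no parameters.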

OrderedDict()

In [12]:
test_module._buffers # empty for the same reason

OrderedDict()
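
Where each kind of attribute is actually stored can be checked directly. In this sketch (`Tiny` is a hypothetical module), plain attributes and the three registry dicts live in the instance `__dict__`, while parameters, buffers and submodules do not, so reading them falls through to `__getattr__`:

```python
import torch
from torch import nn


class Tiny(nn.Module):  # hypothetical module
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(2, 3)                   # stored in _modules
        self.scale = nn.Parameter(torch.ones(1))         # stored in _parameters
        self.register_buffer('offset', torch.zeros(1))   # stored in _buffers
        self.note = 'plain attribute'                    # stored in __dict__


m = Tiny()

in_instance_dict = {
    name: name in m.__dict__
    for name in ('_modules', 'linear1', 'scale', 'offset', 'note')
}
print(in_instance_dict)

# These lookups miss __dict__ and are served by __getattr__.
print(m.linear1 is m._modules['linear1'])
print(m.scale is m._parameters['scale'])
print(m.offset is m._buffers['offset'])

try:
    m.does_not_exist
except AttributeError as err:
    message = str(err)
    print(message)
```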

#### The _save_to_state_dict() function

In [None]:
    def _save_to_state_dict(self, destination, prefix, keep_vars): # stores this module's own parameters and buffers in the destination dict
        r"""Saves module state to `destination` dictionary, containing a state
        of the module, but not its descendants. This is called on every
        submodule in :meth:`~torch.nn.Module.state_dict`.

        In rare cases, subclasses can achieve class-specific behavior by
        overriding this method with custom logic.

        Args:
            destination (dict): a dict where state will be stored
            prefix (str): the prefix for parameters and buffers used in this
                module
        """
        for name, param in self._parameters.items(): # iterate over this module's own parameters (submodules excluded)
            if param is not None:
                destination[prefix + name] = param if keep_vars else param.detach() # store in the dict
        for name, buf in self._buffers.items():     # iterate over this module's own buffers (submodules excluded)
            if buf is not None and name not in self._non_persistent_buffers_set:
                destination[prefix + name] = buf if keep_vars else buf.detach()
        extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
        if getattr(self.__class__, "get_extra_state", Module.get_extra_state) is not Module.get_extra_state:
            destination[extra_state_key] = self.get_extra_state()

    # The user can pass an optional arbitrary mappable object to `state_dict`, in which case `state_dict` returns
    # back that same object. But if they pass nothing, an `OrderedDict` is created and returned.
    
# _save_to_state_dict() is called from state_dict()
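
The two filters in `_save_to_state_dict()` (skip entries that are `None`, skip non-persistent buffers) can be observed from the outside. A small sketch with a hypothetical `WithBuffers` module:

```python
import torch
from torch import nn


class WithBuffers(nn.Module):  # hypothetical module
    def __init__(self):
        super().__init__()
        self.register_buffer('kept', torch.zeros(2))
        # persistent=False adds the name to _non_persistent_buffers_set
        self.register_buffer('scratch', torch.zeros(2), persistent=False)
        # a parameter slot that is registered but set to None
        self.register_parameter('unused', None)


m = WithBuffers()
state_keys = list(m.state_dict().keys())
buffer_names = list(m._buffers.keys())

print(state_keys)    # -> ['kept']
print(buffer_names)  # -> ['kept', 'scratch']
```

Both buffers exist on the module, but only the persistent one reaches the state dict, and the `None` parameter is skipped.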

#### The state_dict() function

In [None]:
    def state_dict(self, destination=None, prefix='', keep_vars=False):
        r"""Returns a dictionary containing a whole state of the module.

        Both parameters and persistent buffers (e.g. running averages) are
        included. Keys are corresponding parameter and buffer names.
        Parameters and buffers set to ``None`` are not included.

        Returns:
            dict:
                a dictionary containing a whole state of the module

        Example::

            >>> module.state_dict().keys()
            ['bias', 'weight']

        """
        if destination is None:
            destination = OrderedDict()
            destination._metadata = OrderedDict()
        destination._metadata[prefix[:-1]] = local_metadata = dict(version=self._version)
        self._save_to_state_dict(destination, prefix, keep_vars) # put this module's own parameters and buffers into the dict
        for name, module in self._modules.items(): # iterate over this module's direct submodules
            if module is not None: # skip submodule slots that are set to None
                module.state_dict(destination, prefix + name + '.', keep_vars=keep_vars) # recurse: the submodule adds its parameters and buffers to the same dict
        for hook in self._state_dict_hooks.values():
            hook_result = hook(self, destination, prefix, local_metadata)
            if hook_result is not None:
                destination = hook_result
        return destination  # return the destination dict


Example of the state_dict() function

In [13]:
test_module.state_dict() # the last three entries of the output are buffers; the rest are parameters

OrderedDict([('linear1.weight',
              tensor([[ 0.4898,  0.2476],
                      [ 0.4431, -0.5865],
                      [ 0.4365,  0.5435]])),
             ('linear1.bias', tensor([ 0.0715, -0.1448, -0.1111])),
             ('linear2.weight',
              tensor([[-0.1690,  0.2995,  0.2924],
                      [-0.2389,  0.4674, -0.3988],
                      [-0.5677,  0.5025, -0.4594],
                      [-0.1568, -0.1094, -0.3007]])),
             ('linear2.bias', tensor([-0.2783, -0.2626,  0.3051, -0.0977])),
             ('batch_norm.weight', tensor([1., 1., 1., 1.])),
             ('batch_norm.bias', tensor([0., 0., 0., 0.])),
             ('batch_norm.running_mean', tensor([0., 0., 0., 0.])),
             ('batch_norm.running_var', tensor([1., 1., 1., 1.])),
             ('batch_norm.num_batches_tracked', tensor(0))])

In [14]:
test_module.state_dict()['linear1.weight']

tensor([[ 0.4898,  0.2476],
        [ 0.4431, -0.5865],
        [ 0.4365,  0.5435]])
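
The `keep_vars` flag decides whether the dict holds the live `Parameter` objects or detached tensors. The sketch below also illustrates a consequence of `detach()`: the detached tensors still share storage with the module, so an independent snapshot needs an explicit `clone()`.

```python
import torch
from torch import nn

layer = nn.Linear(2, 2)

detached = layer.state_dict()             # default: param.detach()
live = layer.state_dict(keep_vars=True)   # the Parameter objects themselves

print(detached['weight'].requires_grad)   # detached tensors do not track gradients
print(live['weight'] is layer.weight)     # keep_vars=True returns the same object
print(detached['weight'].data_ptr() == layer.weight.data_ptr())  # same memory

# Take an independent copy before modifying the layer in place.
snapshot = {k: v.clone() for k, v in layer.state_dict().items()}
with torch.no_grad():
    layer.weight.zero_()

print(detached['weight'])   # follows the in-place change (all zeros now)
print(snapshot['weight'])   # still holds the original values
```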

#### The _load_from_state_dict() function

In [None]:
    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, # reads parameter and buffer values from state_dict
                              missing_keys, unexpected_keys, error_msgs):       # and copies them into this module's own variables
        r"""Copies parameters and buffers from :attr:`state_dict` into only
        this module, but not its descendants. This is called on every submodule
        in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this
        module in input :attr:`state_dict` is provided as :attr:`local_metadata`.
        For state dicts without metadata, :attr:`local_metadata` is empty.
        Subclasses can achieve class-specific backward compatible loading using
        the version number at `local_metadata.get("version", None)`.

        .. note::
            :attr:`state_dict` is not the same object as the input
            :attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So
            it can be modified.

        Args:
            state_dict (dict): a dict containing parameters and
                persistent buffers.
            prefix (str): the prefix for parameters and buffers used in this
                module
            local_metadata (dict): a dict containing the metadata for this module.
                See
            strict (bool): whether to strictly enforce that the keys in
                :attr:`state_dict` with :attr:`prefix` match the names of
                parameters and buffers in this module
            missing_keys (list of str): if ``strict=True``, add missing keys to
                this list
            unexpected_keys (list of str): if ``strict=True``, add unexpected
                keys to this list
            error_msgs (list of str): error messages should be added to this
                list, and will be reported together in
                :meth:`~torch.nn.Module.load_state_dict`
        """
        for hook in self._load_state_dict_pre_hooks.values():
            hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

        persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
        local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
        local_state = {k: v for k, v in local_name_params if v is not None} # collect this module's non-None parameters and persistent buffers into local_state

        for name, param in local_state.items(): # iterate over local_state
            key = prefix + name
            if key in state_dict: # the prefixed key is present in state_dict
                input_param = state_dict[key] 
                if not torch.overrides.is_tensor_like(input_param):
                    error_msgs.append('While copying the parameter named "{}", '
                                      'expected torch.Tensor or Tensor-like object from checkpoint but '
                                      'received {}'
                                      .format(key, type(input_param)))
                    continue

                # This is used to avoid copying uninitialized parameters into
                # non-lazy modules, since they dont have the hook to do the checks
                # in such case, it will error when accessing the .shape attribute.
                is_param_lazy = torch.nn.parameter.is_lazy(param)
                # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
                if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1:
                    input_param = input_param[0]

                if not is_param_lazy and input_param.shape != param.shape:
                    # local shape should match the one in checkpoint
                    error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
                                      'the shape in current model is {}.'
                                      .format(key, input_param.shape, param.shape))
                    continue
                try:
                    with torch.no_grad():
                        param.copy_(input_param)  # the actual assignment; input_param comes from state_dict
                except Exception as ex:
                    error_msgs.append('While copying the parameter named "{}", '
                                      'whose dimensions in the model are {} and '
                                      'whose dimensions in the checkpoint are {}, '
                                      'an exception occurred : {}.'
                                      .format(key, param.size(), input_param.size(), ex.args))
            elif strict:
                missing_keys.append(key)

        extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
        if getattr(self.__class__, "set_extra_state", Module.set_extra_state) is not Module.set_extra_state:
            if extra_state_key in state_dict:
                self.set_extra_state(state_dict[extra_state_key])
            elif strict:
                missing_keys.append(extra_state_key)
        elif strict and (extra_state_key in state_dict):
            unexpected_keys.append(extra_state_key)

        if strict:
            for key in state_dict.keys():
                if key.startswith(prefix) and key != extra_state_key:
                    input_name = key[len(prefix):]
                    input_name = input_name.split('.', 1)[0]  # get the name of param/buffer/child
                    if input_name not in self._modules and input_name not in local_state:
                        unexpected_keys.append(key)


#### The load_state_dict() function

In [None]:
    def load_state_dict(self, state_dict: 'OrderedDict[str, Tensor]',    
                        strict: bool = True):
        r"""Copies parameters and buffers from :attr:`state_dict` into
        this module and its descendants. If :attr:`strict` is ``True``, then
        the keys of :attr:`state_dict` must exactly match the keys returned
        by this module's :meth:`~torch.nn.Module.state_dict` function.

        Args:
            state_dict (dict): a dict containing parameters and
                persistent buffers.
            strict (bool, optional): whether to strictly enforce that the keys
                in :attr:`state_dict` match the keys returned by this module's
                :meth:`~torch.nn.Module.state_dict` function. Default: ``True``

        Returns:
            ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
                * **missing_keys** is a list of str containing the missing keys
                * **unexpected_keys** is a list of str containing the unexpected keys

        Note:
            If a parameter or buffer is registered as ``None`` and its corresponding key
            exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a
            ``RuntimeError``.
        """
        missing_keys: List[str] = []
        unexpected_keys: List[str] = []
        error_msgs: List[str] = []

        # copy state_dict so _load_from_state_dict can modify it
        metadata = getattr(state_dict, '_metadata', None)
        state_dict = state_dict.copy()
        if metadata is not None:
            # mypy isn't aware that "_metadata" exists in state_dict
            state_dict._metadata = metadata  # type: ignore[attr-defined]

        def load(module, prefix=''):   # load() calls _load_from_state_dict() on the module, then recurses into its children
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            module._load_from_state_dict(
                state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + '.')

        load(self)
        del load

        if strict:
            if len(unexpected_keys) > 0:
                error_msgs.insert(
                    0, 'Unexpected key(s) in state_dict: {}. '.format(
                        ', '.join('"{}"'.format(k) for k in unexpected_keys)))
            if len(missing_keys) > 0:
                error_msgs.insert(
                    0, 'Missing key(s) in state_dict: {}. '.format(
                        ', '.join('"{}"'.format(k) for k in missing_keys)))

        if len(error_msgs) > 0:
            raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
                               self.__class__.__name__, "\n\t".join(error_msgs)))
        return _IncompatibleKeys(missing_keys, unexpected_keys)
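
The effect of `strict` and the returned `missing_keys` / `unexpected_keys` fields can be seen with a deliberately mismatched dict. This is a small sketch; the key `'extra'` is made up for illustration.

```python
import torch
from torch import nn

layer = nn.Linear(2, 2)

# 'bias' is absent and 'extra' does not belong to the layer.
partial = {'weight': torch.zeros(2, 2), 'extra': torch.ones(1)}

result = layer.load_state_dict(partial, strict=False)
print(result.missing_keys)     # -> ['bias']
print(result.unexpected_keys)  # -> ['extra']

# With the default strict=True the same mismatch raises a RuntimeError.
try:
    layer.load_state_dict(partial)
    strict_failed = False
except RuntimeError:
    strict_failed = True
print(strict_failed)
```

With `strict=False` the matching entries are still copied (the weight is all zeros afterwards), which is the usual way to load a partially compatible checkpoint.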

#### The parameters() function (buffers() is analogous, so only parameters() is shown)

In [None]:
    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
        r"""Returns an iterator over module parameters.

        This is typically passed to an optimizer.

        Args:
            recurse (bool): if True, then yields parameters of this module
                and all submodules. Otherwise, yields only parameters that
                are direct members of this module.

        Yields:
            Parameter: module parameter

        Example::

            >>> for param in model.parameters():
            >>>     print(type(param), param.size())
            <class 'torch.Tensor'> (20L,)
            <class 'torch.Tensor'> (20L, 1L, 5L, 5L)

        """
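        for name, param in self.named_parameters(recurse=recurse): # iterate over the parameters via named_parameters()
            yield param    # yield makes this method a generator, so calling it returns an iterator
            
# Do not confuse this with _parameters: _parameters is an attribute (a member variable), whereas parameters() is a method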

Example of the parameters() function

In [15]:
for p in test_module.parameters():  # it returns an iterator, so we loop over it
    print(p)

Parameter containing:
tensor([[ 0.4898,  0.2476],
        [ 0.4431, -0.5865],
        [ 0.4365,  0.5435]], requires_grad=True)
Parameter containing:
tensor([ 0.0715, -0.1448, -0.1111], requires_grad=True)
Parameter containing:
tensor([[-0.1690,  0.2995,  0.2924],
        [-0.2389,  0.4674, -0.3988],
        [-0.5677,  0.5025, -0.4594],
        [-0.1568, -0.1094, -0.3007]], requires_grad=True)
Parameter containing:
tensor([-0.2783, -0.2626,  0.3051, -0.0977], requires_grad=True)
Parameter containing:
tensor([1., 1., 1., 1.], requires_grad=True)
Parameter containing:
tensor([0., 0., 0., 0.], requires_grad=True)
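
Two common uses of `parameters()` are counting the trainable values and handing them to an optimizer. The sketch below redefines a module with the same layout as `Test` so that it runs on its own, and also shows the effect of `recurse=False`.

```python
import torch
from torch import nn


class Test(nn.Module):  # same layout as the Test module above
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(2, 3)
        self.linear2 = nn.Linear(3, 4)
        self.batch_norm = nn.BatchNorm2d(4)


test_module = Test()

# (2*3 + 3) + (3*4 + 4) + (4 + 4) = 33 values across the six parameters
total = sum(p.numel() for p in test_module.parameters())
print(total)  # -> 33

# recurse=False only looks at the module's own _parameters, which is empty here.
direct = list(test_module.parameters(recurse=False))
print(direct)  # -> []

# The iterator is what an optimizer typically receives.
optimizer = torch.optim.SGD(test_module.parameters(), lr=0.01)
print(len(optimizer.param_groups[0]['params']))  # -> 6
```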


#### The named_parameters() function

In [None]:
    # Called by parameters(); iterates over the parameters of the module and of its submodules
    def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Parameter]]:
        r"""Returns an iterator over module parameters, yielding both the
        name of the parameter as well as the parameter itself.

        Args:
            prefix (str): prefix to prepend to all parameter names.
            recurse (bool): if True, then yields parameters of this module
                and all submodules. Otherwise, yields only parameters that
                are direct members of this module.

        Yields:
            (string, Parameter): Tuple containing the name and parameter

        Example::

            >>> for name, param in self.named_parameters():
            >>>    if name in ['bias']:
            >>>        print(param.size())

        """
        gen = self._named_members(      # calls _named_members()
            lambda module: module._parameters.items(), # the lambda returns the given module's own parameters (names and values)
            prefix=prefix, recurse=recurse)
        for elem in gen:
            yield elem     # yields (name, parameter) pairs

#### The _named_members() function

In [None]:
    def _named_members(self, get_members_fn, prefix='', recurse=True): # when called from named_parameters(), get_members_fn is the lambda above
        r"""Helper method for yielding various names + members of modules."""
        memo = set()
        modules = self.named_modules(prefix=prefix) if recurse else [(prefix, self)] # named_modules() yields this module and all of its submodules
        for module_prefix, module in modules:
            members = get_members_fn(module)
            for k, v in members:
                if v is None or v in memo:
                    continue
                memo.add(v)
                name = module_prefix + ('.' if module_prefix else '') + k
                yield name, v

Example of the named_parameters() function

In [16]:
for p in test_module.named_parameters():
    print(p) # each item is a (name, parameter) tuple

('linear1.weight', Parameter containing:
tensor([[ 0.4898,  0.2476],
        [ 0.4431, -0.5865],
        [ 0.4365,  0.5435]], requires_grad=True))
('linear1.bias', Parameter containing:
tensor([ 0.0715, -0.1448, -0.1111], requires_grad=True))
('linear2.weight', Parameter containing:
tensor([[-0.1690,  0.2995,  0.2924],
        [-0.2389,  0.4674, -0.3988],
        [-0.5677,  0.5025, -0.4594],
        [-0.1568, -0.1094, -0.3007]], requires_grad=True))
('linear2.bias', Parameter containing:
tensor([-0.2783, -0.2626,  0.3051, -0.0977], requires_grad=True))
('batch_norm.weight', Parameter containing:
tensor([1., 1., 1., 1.], requires_grad=True))
('batch_norm.bias', Parameter containing:
tensor([0., 0., 0., 0.], requires_grad=True))
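
The `memo` set used during traversal means that a parameter shared by several submodules is yielded only once. A short sketch, reusing the shared-layer pattern from the `modules()` docstring, contrasts this with `state_dict()`, which walks `_modules` without a memo and therefore lists the shared tensors under every name:

```python
from torch import nn

shared = nn.Linear(2, 2)
net = nn.Sequential(shared, shared)  # the same layer registered twice

param_names = [name for name, _ in net.named_parameters()]
state_keys = list(net.state_dict().keys())

print(param_names)  # -> ['0.weight', '0.bias']
print(state_keys)   # -> ['0.weight', '0.bias', '1.weight', '1.bias']
```

The deduplication is what keeps an optimizer from receiving the same tied weight twice.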


#### The children() and named_children() functions

In [None]:
    def children(self) -> Iterator['Module']:
        r"""Returns an iterator over immediate children modules.

        Yields:
            Module: a child module
        """
        for name, module in self.named_children(): # calls named_children()
            yield module # yields only the child module itself, without its name

    def named_children(self) -> Iterator[Tuple[str, 'Module']]: # returns an iterator
        r"""Returns an iterator over immediate children modules, yielding both
        the name of the module as well as the module itself.

        Yields:
            (string, Module): Tuple containing a name and child module

        Example::

            >>> for name, module in model.named_children():
            >>>     if name in ['conv4', 'conv5']:
            >>>         print(module)

        """
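        memo = set() # memo of modules already yielded
        for name, module in self._modules.items(): # iterate over all direct children
            if module is not None and module not in memo: # the child exists and has not been yielded yet
                memo.add(module) # record the child in the memo
                yield name, module # yield the name together with the child module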

Example of the children() and named_children() functions

In [17]:
for p in test_module.named_children():
    print(p) # each item is a (name, child module) tuple

('linear1', Linear(in_features=2, out_features=3, bias=True))
('linear2', Linear(in_features=3, out_features=4, bias=True))
('batch_norm', BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))


In [18]:
test_module._modules # the _modules attribute is an ordered dict, whereas named_children() returns an iterator of tuples

OrderedDict([('linear1', Linear(in_features=2, out_features=3, bias=True)),
             ('linear2', Linear(in_features=3, out_features=4, bias=True)),
             ('batch_norm',
              BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))])

In [19]:
for p in test_module.children():
    print(p) # each item is a child module

Linear(in_features=2, out_features=3, bias=True)
Linear(in_features=3, out_features=4, bias=True)
BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
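
For a flat module such as `Test`, the children and the non-root modules coincide, which hides the difference between the two traversals. A nested sketch (the layer sizes are arbitrary) makes it visible: `named_children()` stops at the first level, while `named_modules()` descends recursively and builds dotted names.

```python
from torch import nn

inner = nn.Sequential(nn.Linear(2, 2), nn.ReLU())
net = nn.Sequential(inner, nn.Linear(2, 1))

child_names = [name for name, _ in net.named_children()]
module_names = [name for name, _ in net.named_modules()]

print(child_names)   # -> ['0', '1']
print(module_names)  # -> ['', '0', '0.0', '0.1', '1']
```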


#### The modules() and named_modules() functions

In [None]:
    def modules(self) -> Iterator['Module']:
        r"""Returns an iterator over all modules in the network.

        Yields:
            Module: a module in the network

        Note:
            Duplicate modules are returned only once. In the following
            example, ``l`` will be returned only once.

        Example::

            >>> l = nn.Linear(2, 2)
            >>> net = nn.Sequential(l, l)
            >>> for idx, m in enumerate(net.modules()):
                    print(idx, '->', m)

            0 -> Sequential(
              (0): Linear(in_features=2, out_features=2, bias=True)
              (1): Linear(in_features=2, out_features=2, bias=True)
            )
            1 -> Linear(in_features=2, out_features=2, bias=True)

        """
        for _, module in self.named_modules(): # calls named_modules()
            yield module  # yields only the module, without its name

    def named_modules(self, memo: Optional[Set['Module']] = None, prefix: str = '', remove_duplicate: bool = True):
        r"""Returns an iterator over all modules in the network, yielding
        both the name of the module as well as the module itself.

        Args:
            memo: a memo to store the set of modules already added to the result
            prefix: a prefix that will be added to the name of the module
            remove_duplicate: whether to remove the duplicated module instances in the result
                or not

        Yields:
            (string, Module): Tuple of name and module

        Note:
            Duplicate modules are returned only once. In the following
            example, ``l`` will be returned only once.

        Example::

            >>> l = nn.Linear(2, 2)
            >>> net = nn.Sequential(l, l)
            >>> for idx, m in enumerate(net.named_modules()):
                    print(idx, '->', m)

            0 -> ('', Sequential(
              (0): Linear(in_features=2, out_features=2, bias=True)
              (1): Linear(in_features=2, out_features=2, bias=True)
            ))
            1 -> ('0', Linear(in_features=2, out_features=2, bias=True))

        """

        if memo is None:
            memo = set() # create the memo
        if self not in memo: # this module has not been yielded yet
            if remove_duplicate:
                memo.add(self)
            yield prefix, self # yield the prefix together with this module itself
            for name, module in self._modules.items(): # after yielding itself, iterate over the children
                if module is None:
                    continue
                submodule_prefix = prefix + ('.' if prefix else '') + name
                for m in module.named_modules(memo, submodule_prefix, remove_duplicate):
                    yield m # yield each (name, module) tuple produced by the child


Example of the modules() and named_modules() functions

In [20]:
for p in test_module.named_modules():
    print(p) # yields 4 modules: the module itself and its 3 submodules

('', Test(
  (linear1): Linear(in_features=2, out_features=3, bias=True)
  (linear2): Linear(in_features=3, out_features=4, bias=True)
  (batch_norm): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
))
('linear1', Linear(in_features=2, out_features=3, bias=True))
('linear2', Linear(in_features=3, out_features=4, bias=True))
('batch_norm', BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))


In [21]:
test_module._modules # the _modules attribute holds only the direct submodules, as an ordered dict

OrderedDict([('linear1', Linear(in_features=2, out_features=3, bias=True)),
             ('linear2', Linear(in_features=3, out_features=4, bias=True)),
             ('batch_norm',
              BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))])

In [22]:
for p in test_module.modules():
    print(p) # yields the same 4 modules, but without the prefix of the module itself or the names of the submodules
    print('\n') 

Test(
  (linear1): Linear(in_features=2, out_features=3, bias=True)
  (linear2): Linear(in_features=3, out_features=4, bias=True)
  (batch_norm): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)


Linear(in_features=2, out_features=3, bias=True)


Linear(in_features=3, out_features=4, bias=True)


BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
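
Because `modules()` reaches every layer regardless of nesting depth, it is a convenient way to select layers by type. A minimal sketch, assuming we want to zero the bias of every `nn.Linear` in a small nested network (the network itself is made up for illustration):

```python
import torch
from torch import nn

net = nn.Sequential(
    nn.Sequential(nn.Linear(2, 4), nn.ReLU()),
    nn.Linear(4, 1),
)

# modules() descends into the inner Sequential, so both Linear layers are found.
linear_layers = [m for m in net.modules() if isinstance(m, nn.Linear)]
print(len(linear_layers))  # -> 2

with torch.no_grad():
    for layer in linear_layers:
        layer.bias.zero_()
```

Using `children()` here would find only the outer `nn.Linear(4, 1)`, since the other one sits one level deeper.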


