`torch.nn` is the neural network API in PyTorch, and `torch.nn.Module` is the base class for all neural network modules in PyTorch.

Here is the inheritance structure for the `torch.nn.Module` class:
![image](./noteimg/torch_nn_Module_inheritance_structure.jpeg)

# `nn.Module` Implementation

## 1. `nn.Module.__init__`
`nn.Module.__init__` basically does:
1. call `torch._C._log_api_usage_once("python.nn_module")` to **monitor and record the usage of APIs**.
2. **initialize important member variables**.

Here is a detailed list of the member variables it initializes:
```python
self.training = True                     # True in training mode, False in evaluation mode
self._parameters = OrderedDict()         # Parameters updated by back-propagation / the optimizer
self._buffers = OrderedDict()            # state tensors not updated by back-propagation (e.g. BatchNorm running stats)
self._non_persistent_buffers_set = set() # names of buffers excluded from state_dict
self._backward_hooks = OrderedDict()     # hooks called during the backward pass
self._forward_hooks = OrderedDict()      # hooks called after forward
self._forward_pre_hooks = OrderedDict()  # hooks called before forward
self._state_dict_hooks = OrderedDict()   # hooks called after computing state_dict
self._load_state_dict_pre_hooks = OrderedDict()  # hooks called before loading a state_dict
self._modules = OrderedDict()            # submodules
```
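A quick way to see these containers is to inspect a built-in layer. The sketch below assumes a recent PyTorch release and uses `nn.BatchNorm1d` only as an example of a module that owns both parameters and buffers; the underscore-prefixed attributes are internal, so reading them directly is for illustration only.

```python
import torch.nn as nn

bn = nn.BatchNorm1d(4)

# Trainable tensors live in _parameters ...
print(list(bn._parameters.keys()))  # ['weight', 'bias']
# ... while the running statistics live in _buffers.
print(list(bn._buffers.keys()))     # ['running_mean', 'running_var', 'num_batches_tracked']
# A freshly constructed module starts in training mode.
print(bn.training)                  # True
```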

Note:
- When defining a custom module, call `super().__init__()` at the start of its `__init__`. Otherwise the member variables above are never created, and assigning a `Parameter` or a submodule to `self` raises an `AttributeError` such as `cannot assign module before Module.__init__() call`.
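The two hypothetical classes below (`Good` and `Bad` are made-up names) sketch what happens with and without the `super().__init__()` call, assuming a recent PyTorch release.

```python
import torch.nn as nn

class Good(nn.Module):
    def __init__(self):
        super().__init__()          # creates _parameters, _modules, ...
        self.fc = nn.Linear(2, 2)   # registered in self._modules

class Bad(nn.Module):
    def __init__(self):
        # super().__init__() is missing, so there is no _modules dict yet
        self.fc = nn.Linear(2, 2)

good = Good()
print(list(good._modules.keys()))   # ['fc']

caught = None
try:
    Bad()
except AttributeError as e:
    caught = e
    print(e)  # cannot assign module before Module.__init__() call
```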

## 2. `nn.Module.train` and `nn.Module.eval`

`nn.Module` uses its member variable `self.training` to determine whether it is in training mode or evaluation mode; layers such as `Dropout` and `BatchNorm` behave differently in the two modes.
The method `nn.Module.train(mode)` sets `self.training` to `mode` for this module and, recursively, for all its submodules.
```python
def train(self: T, mode: bool = True) -> T:
    self.training = mode
    for module in self.children():
        module.train(mode)
    return self
```

`nn.Module.eval` simply calls `self.train(False)`:
```python
def eval(self: T) -> T:
    return self.train(False)
```
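A small sketch of the recursion, assuming a recent PyTorch release: the model layout is arbitrary, and `Dropout` is used because its eval-mode behavior (identity) is easy to check.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))

model.eval()                                   # same as model.train(False)
eval_flags = [m.training for m in model.modules()]
print(eval_flags)                              # [False, False, False]

x = torch.ones(1, 4)
with torch.no_grad():
    # Dropout is the identity in eval mode, so repeated passes agree exactly.
    same_output = torch.equal(model(x), model(x))

model.train()                                  # mode=True by default
train_flags = [m.training for m in model.modules()]
print(train_flags)                             # [True, True, True]
```

Because `train` and `eval` return `self`, they can also be chained, e.g. `model.eval()(x)`.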

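## 3. `nn.Module.requires_grad_` and `nn.Module.zero_grad`
`requires_grad_` switches gradient tracking on or off for all parameters in place, and `zero_grad` clears their accumulated gradients, either by zeroing the `.grad` tensors or, with `set_to_none=True`, by setting them to `None`. The excerpt below is from an older release; since PyTorch 2.0, `set_to_none` defaults to `True`.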
```python
def requires_grad_(self: T, requires_grad: bool = True) -> T:
    r"""
    This method sets the parameters' :attr:`requires_grad` attributes
    in-place.
    """
    for p in self.parameters():
        p.requires_grad_(requires_grad)
    return self
```

```python
def zero_grad(self, set_to_none: bool = False) -> None:
    if getattr(self, '_is_replica', False):
        warnings.warn(
            "Calling .zero_grad() from a module created with nn.DataParallel() has no effect. "
            "The parameters are copied (in a differentiable manner) from the original module. "
            "This means they are not leaf nodes in autograd and so don't accumulate gradients. "
            "If you need gradients in your forward method, consider using autograd.grad instead.")

    for p in self.parameters():
        if p.grad is not None:
            if set_to_none:
                p.grad = None
            else:
                if p.grad.grad_fn is not None:
                    p.grad.detach_()
                else:
                    p.grad.requires_grad_(False)
                p.grad.zero_()
```
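A minimal usage sketch of both methods, assuming a recent PyTorch release; passing `set_to_none` explicitly keeps the example independent of the version-dependent default.

```python
import torch
import torch.nn as nn

model = nn.Linear(3, 1)

# requires_grad_ returns the module itself, so it can be chained.
frozen = model.requires_grad_(False)
frozen_flags = [p.requires_grad for p in model.parameters()]
model.requires_grad_(True)

# Accumulate a gradient with one backward pass.
model(torch.randn(5, 3)).sum().backward()

model.zero_grad(set_to_none=False)   # keep the .grad tensors, fill them with zeros
zeroed = bool(torch.all(model.weight.grad == 0))

model.zero_grad(set_to_none=True)    # drop the .grad tensors entirely
print(frozen is model, frozen_flags, zeroed, model.weight.grad)
# True [False, False] True None
```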

## 4. Parameter Conversion / Transfer
We can convert all the **parameters** and **buffers** of a module to another data type or transfer them to another device.
`nn.Module` provides 8 such methods, which are:
1. `nn.Module.cpu`: transfer all parameters and buffers to CPU.
2. `nn.Module.type`: convert all parameters and buffers to a certain type.
3. `nn.Module.cuda`: transfer all parameters and buffers to GPU.
4. `nn.Module.float`: convert all **floating point** parameters and buffers to `float32`.
5. `nn.Module.double`: convert all **floating point** parameters and buffers to `double`.
6. `nn.Module.half`: convert all **floating point** parameters and buffers to `float16`.
7. `nn.Module.bfloat16`: convert all **floating point** parameters and buffers to `bfloat16`.
8. `nn.Module.to`: move and/or cast the parameters and buffers; when a `dtype` is given, only floating point and complex tensors are cast.
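The sketch below illustrates the "floating point only" rule, assuming a recent PyTorch release. `WithStep` and its integer buffer `step` are hypothetical names made up for the example; the CUDA branch runs only when a GPU is present.

```python
import torch
import torch.nn as nn

class WithStep(nn.Module):  # hypothetical module for illustration
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(2, 2)
        # Integer buffer: floating-point casts should leave it alone.
        self.register_buffer("step", torch.zeros((), dtype=torch.long))

m = WithStep()

m.double()                               # only floating point tensors become float64
after_double = (m.fc.weight.dtype, m.step.dtype)
print(after_double)                      # (torch.float64, torch.int64)

m.to(torch.float16)                      # .to(dtype) also casts only floating/complex tensors
print(m.fc.weight.dtype, m.step.dtype)   # torch.float16 torch.int64

if torch.cuda.is_available():
    m.cuda()                             # device moves apply to every tensor, integer ones too
    print(m.step.device)
```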

### 4.1 `nn.Module._apply`
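All of these methods are thin wrappers around `nn.Module._apply(fn)`, where `fn` converts a single tensor; for example, `cpu()` calls `self._apply(lambda t: t.cpu())`.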
```python
def _apply(self, fn):
    # Recursively apply fn to all submodules first
    for module in self.children():
        module._apply(fn)

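    # Tensor type check added for a BC-breaking change: decides whether to
    # update a tensor in place via `.data` or replace it with a new object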
    def compute_should_use_set_data(tensor, tensor_applied):
        if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):
            return not torch.__future__.get_overwrite_module_params_on_conversion()
        else:
            return False

    # Process parameters and their gradients
    for key, param in self._parameters.items():
        if param is not None:
            # Tensors stored in modules are graph leaves, and we don't want to
            # track autograd history of `param_applied`, so we have to use
            # `with torch.no_grad():`
            with torch.no_grad():
                param_applied = fn(param)
            should_use_set_data = compute_should_use_set_data(param, param_applied)
            if should_use_set_data:
                param.data = param_applied
            else:
                assert isinstance(param, Parameter)
                assert param.is_leaf
                self._parameters[key] = Parameter(param_applied, param.requires_grad)
            if param.grad is not None:
                with torch.no_grad():
                    grad_applied = fn(param.grad)
                should_use_set_data = compute_should_use_set_data(param.grad, grad_applied)
                if should_use_set_data:
                    param.grad.data = grad_applied
                else:
                    assert param.grad.is_leaf
                    self._parameters[key].grad = grad_applied.requires_grad_(param.grad.requires_grad)

    # Process buffers
    for key, buf in self._buffers.items():
        if buf is not None:
            self._buffers[key] = fn(buf)
    return self
```
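The `should_use_set_data` branch has a practical consequence: under the default `torch.__future__` settings, a conversion swaps the *data* of each `Parameter` but keeps the `Parameter` object itself, so references held elsewhere (for example by an optimizer) stay valid. The sketch below assumes a recent PyTorch release and calls the private `_apply` directly only to illustrate the mechanism, with a lambda written in the same style as `double()`.

```python
import torch
import torch.nn as nn

m = nn.Linear(2, 2)
w_before = m.weight    # keep a reference to the Parameter object

# Default: do not overwrite module params on conversion (use `param.data = ...`).
overwrite = torch.__future__.get_overwrite_module_params_on_conversion()

# Private API, used here only for illustration; same style of fn as double().
m._apply(lambda t: t.double() if t.is_floating_point() else t)

same_object = m.weight is w_before
print(overwrite, m.weight.dtype, same_object)  # False torch.float64 True
```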

### 4.2 `nn.Module.apply`
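`nn.Module.apply` is public, but despite the similar name it does **not** call `nn.Module._apply`. `_apply(fn)` applies `fn` to tensors (parameters, their gradients and buffers), whereas `apply(fn)` calls `fn` on every module in the tree, children first and then `self`. It is commonly used to initialize weights.

Its definition is: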
```python
def apply(self: T, fn: Callable[['Module'], None]) -> T:
    for module in self.children():
        module.apply(fn)
    fn(self)
    return self
```
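A minimal sketch of both the visiting order and the typical weight-initialization use, assuming a recent PyTorch release; `init_weights` is a hypothetical helper written for this example.

```python
import torch
import torch.nn as nn

def init_weights(m: nn.Module) -> None:  # hypothetical helper
    # apply() hands fn whole modules, not tensors
    if isinstance(m, nn.Linear):
        nn.init.constant_(m.weight, 0.5)
        nn.init.zeros_(m.bias)

model = nn.Sequential(nn.Linear(3, 3), nn.ReLU(), nn.Linear(3, 1))

visited = []
model.apply(lambda m: visited.append(type(m).__name__))
print(visited)  # ['Linear', 'ReLU', 'Linear', 'Sequential']

model.apply(init_weights)
```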

## 5. Methods to modify attributes
There are 3 methods to register members of a module explicitly:
1. `nn.Module.add_module`: adds a submodule to `nn.Module._modules`;
2. `nn.Module.register_parameter`: adds a parameter to `nn.Module._parameters`, so it is returned by `parameters()` and updated during training;
3. `nn.Module.register_buffer`: adds a tensor to `nn.Module._buffers`. If the buffer is not persistent (`persistent=False`), its name is also added to `self._non_persistent_buffers_set`, which excludes it from `state_dict`.
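The three methods can be sketched together as follows, assuming a recent PyTorch release (the `persistent` argument of `register_buffer` is available from PyTorch 1.6). `Demo`, `scale`, `running` and `cache` are hypothetical names.

```python
import torch
import torch.nn as nn

class Demo(nn.Module):  # hypothetical module for illustration
    def __init__(self):
        super().__init__()
        self.add_module("fc", nn.Linear(2, 2))                          # -> _modules
        self.register_parameter("scale", nn.Parameter(torch.ones(1)))   # -> _parameters
        self.register_buffer("running", torch.zeros(2))                 # -> _buffers
        self.register_buffer("cache", torch.zeros(2), persistent=False) # -> _buffers only

d = Demo()
keys = sorted(d.state_dict().keys())
print(keys)  # ['fc.bias', 'fc.weight', 'running', 'scale']
```

The non-persistent `cache` is still a buffer (it moves with `.to()` and appears in `named_buffers()`), it just is not saved in the `state_dict`.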


Usually we set attributes of a module with the expression `self.attr = val`, which invokes `nn.Module.__setattr__`, an override of `object.__setattr__`.
Its implementation is:
```python
def __setattr__(self, name: str, value: Union[Tensor, 'Module']):
    def remove_from(*dicts_or_sets):
        for d in dicts_or_sets:
            if name in d:
                if isinstance(d, dict):
                    del d[name]
                else:
                    d.discard(name)

    params = self.__dict__.get('_parameters')
    # if the value is a Parameter
    if isinstance(value, Parameter):
        if params is None:
            raise AttributeError(
                "cannot assign parameters before Module.__init__() call")
        remove_from(self.__dict__, self._buffers, self._modules, self._non_persistent_buffers_set)
        self.register_parameter(name, value)
    elif params is not None and name in params:  # name is already a registered parameter
        if value is not None:
            raise TypeError("cannot assign '{}' as parameter '{}' "
                            "(torch.nn.Parameter or None expected)"
                            .format(torch.typename(value), name))
        self.register_parameter(name, value)
    else:
        modules = self.__dict__.get('_modules')
        if isinstance(value, Module):
            if modules is None:
                raise AttributeError(
                    "cannot assign module before Module.__init__() call")
            remove_from(self.__dict__, self._parameters, self._buffers, self._non_persistent_buffers_set)
            modules[name] = value
        elif modules is not None and name in modules:
            if value is not None:
                raise TypeError("cannot assign '{}' as child module '{}' "
                                "(torch.nn.Module or None expected)"
                                .format(torch.typename(value), name))
            modules[name] = value
        else:
            buffers = self.__dict__.get('_buffers')
            if buffers is not None and name in buffers:
                if value is not None and not isinstance(value, torch.Tensor):
                    raise TypeError("cannot assign '{}' as buffer '{}' "
                                    "(torch.Tensor or None expected)"
                                    .format(torch.typename(value), name))
                buffers[name] = value
            else:
                object.__setattr__(self, name, value)
```
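Although the type hint only names `Tensor` and `Module`, it accepts any value and dispatches on the type of `value` and on where `name` is already registered:
1. if the value is a `Parameter`, it will:
    1. raise an `AttributeError` if `self._parameters` does not exist yet, i.e. `Module.__init__()` has not been called;
    2. remove `name` from `self.__dict__`, `self._buffers`, `self._modules` and `self._non_persistent_buffers_set`;
    3. register the parameter by calling `self.register_parameter`.
2. else, if `name` is already a registered parameter, only `None` may be assigned (anything else raises a `TypeError`), and it is stored through `self.register_parameter`.
3. else, if the value is a `Module`, it will:
    1. raise an `AttributeError` if `self._modules` does not exist yet;
    2. remove `name` from `self.__dict__`, `self._parameters`, `self._buffers` and `self._non_persistent_buffers_set`;
    3. store the submodule in `self._modules`.
4. else, if `name` is already a registered submodule, only `None` may be assigned.
5. else, if `name` is already a registered buffer, the value must be a `torch.Tensor` or `None`, and it replaces the buffer.
6. otherwise, the value is stored as an ordinary attribute via `object.__setattr__`. A plain `Tensor` assigned this way is **not** registered as a buffer.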
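A sketch that exercises these branches, assuming a recent PyTorch release; `Net` and its attribute names are hypothetical.

```python
import torch
import torch.nn as nn

class Net(nn.Module):  # hypothetical module for illustration
    def __init__(self):
        super().__init__()
        self.w = nn.Parameter(torch.zeros(2))        # Parameter -> _parameters
        self.sub = nn.Linear(2, 2)                    # Module -> _modules
        self.register_buffer("buf", torch.zeros(2))   # explicit buffer -> _buffers
        self.plain = torch.zeros(2)                   # plain Tensor -> __dict__
        self.lr = 0.1                                 # non-tensor -> __dict__

net = Net()
net.buf = torch.ones(2)          # name is a buffer: a Tensor replaces it

type_error = None
try:
    net.w = torch.ones(2)        # name is a parameter: only Parameter or None allowed
except TypeError as e:
    type_error = e
print(type_error)
```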
