# Methods for freezing parameters in PyTorch

A discussion of how to train with frozen BatchNorm layers:  
<https://discuss.pytorch.org/t/how-to-train-with-frozen-batchnorm/12106>

A good example of defining a PyTorch model as blocks of modules so that the last N layers can be frozen:  
<https://discuss.pytorch.org/t/train-nn-by-freezing-last-n-layers/9432/3>


### 1. Use parameter groups.

This [sample script](https://gist.github.com/L0SG/2f6d81e4ad119c4f798ab81fa8d62d3f) demonstrates the approach:

```python
# Unfreeze the fc2 layer for an extra round of fine-tuning.
fc2 = net.fc2
fc2.weight.requires_grad = True
fc2.bias.requires_grad = True

# Register fc2's parameters with the existing optimizer as a new group.
# Presumably the optimizer was built without them: add_param_group rejects
# parameters that already belong to another group.
optimizer.add_param_group({'params': fc2.parameters()})
```

### 2. Rely on the default `requires_grad` of inputs and reset the `requires_grad` flag on parameters.

**Mind how `requires_grad` propagates:** the output of an operation requires gradients if any of its inputs does.
```python
x = torch.randn(5, 5)  # requires_grad=False by default
y = torch.randn(5, 5)  # requires_grad=False by default
z = torch.randn((5, 5), requires_grad=True)
a = x + y
assert not a.requires_grad  # False: neither x nor y requires grad
b = a + z
assert b.requires_grad  # True: z requires grad, so b does too
```

**A straightforward example:**
```python
model = torchvision.models.resnet18(pretrained=True)

# Freeze every pretrained parameter.
for p in model.parameters():
    p.requires_grad = False

# Swap in a new classifier head. Freshly constructed modules default to
# requires_grad=True, so this is the only trainable layer.
model.fc = nn.Linear(in_features=512, out_features=100)

# Hand only the classifier's parameters to the optimizer.
optimizer = optim.SGD(
    model.fc.parameters(),
    lr=1e-2,
    momentum=0.9,
)
```

### 3. Filter `parameters()` by `requires_grad`.

```python
# Only parameters whose requires_grad flag is set are passed to the optimizer.
optimizer = optim.SGD((p for p in net.parameters() if p.requires_grad), lr=0.1)
```
Note that the snippet above assumes the common “train => save => load => freeze parts” scenario, i.e. the `requires_grad` flags are already set before the optimizer is constructed.

### 4. Use the `children()` interface.
```python
model_ft = models.resnet50(pretrained=True)

# Freeze the first 6 top-level children (ct counts from 1, so ct < 7 covers
# children 1..6); every later child stays trainable.
for ct, child in enumerate(model_ft.children(), start=1):
    if ct < 7:
        for param in child.parameters():
            param.requires_grad = False
```

In [3]:
import torch
import torchvision

# Build ResNet-101 with pretrained weights (the original positional `True`
# is the `pretrained` argument).
# NOTE(review): the recorded outputs in this notebook show BasicBlock layers,
# which torchvision uses for resnet18/34 rather than resnet101 (Bottleneck);
# they may have been produced by a different model — confirm.
current_model = torchvision.models.resnet101(pretrained=True)

In [57]:
# Print the repr of every top-level child module, each followed by a blank line.
for submodule in current_model.children():
    print(submodule, end='\n\n')

Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

ReLU(inplace)

MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)

Sequential(
  (0): BasicBlock(
    (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (1): BasicBlock(
    (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn

In [90]:
def parse(current_model, depth=0):
    """Print the submodule tree of ``current_model`` with each parameter's ``requires_grad``.

    Each submodule name is printed indented by four spaces per nesting level,
    followed by one ``|-`` line per parameter registered directly on that
    submodule. Every submodule is recursed into, so modules that own
    parameters *and* have children are listed completely.

    Parameters registered directly on ``current_model`` itself are not
    printed; only those of its submodules are.

    Args:
        current_model: a module exposing ``_modules`` and ``_parameters``
            dicts (e.g. a ``torch.nn.Module``).
        depth: current nesting level, used for indentation; callers normally
            leave it at 0.
    """
    indent = '    ' * depth
    for module_name, current_module in current_model._modules.items():
        print(indent + module_name)
        if current_module is None:
            # A submodule slot may be registered as None; nothing to inspect.
            continue
        for param_name, param in current_module._parameters.items():
            if param is None:
                # e.g. Conv2d(..., bias=False) registers `bias` as None.
                print(indent + '|-' + param_name, 'is not exist(NoneType).')
            else:
                print(indent + '|-' + param_name, 'requries_grad is', str(param.requires_grad) + '.')
        # Always recurse: a module with its own parameters may still have children.
        parse(current_module, depth + 1)
# Walk the whole model starting from the top level (no indentation).
parse(current_model, depth=0)

conv1
|-weight requries_grad is True.
|-bias is not exist(NoneType).
bn1
|-weight requries_grad is True.
|-bias requries_grad is True.
relu
maxpool
layer1
    0
        conv1
        |-weight requries_grad is True.
        |-bias is not exist(NoneType).
        bn1
        |-weight requries_grad is True.
        |-bias requries_grad is True.
        relu
        conv2
        |-weight requries_grad is True.
        |-bias is not exist(NoneType).
        bn2
        |-weight requries_grad is True.
        |-bias requries_grad is True.
    1
        conv1
        |-weight requries_grad is True.
        |-bias is not exist(NoneType).
        bn1
        |-weight requries_grad is True.
        |-bias requries_grad is True.
        relu
        conv2
        |-weight requries_grad is True.
        |-bias is not exist(NoneType).
        bn2
        |-weight requries_grad is True.
        |-bias requries_grad is True.
layer2
    0
        conv1
        |-weight requries_grad is True.
        