# torch.no_grad() vs. param.requires_grad
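- torch.no_grad()
    - A context manager. Inside the block, autograd does not record operations, so no gradients are computed for anything executed there. It does not change any parameter's requires_grad flag.
    - Suited to the evaluation stage. It also fits the parts of a model's forward pass that should not be updated: those modules only extract features (forward computation) and receive no gradient updates in the backward pass.
- param.requires_grad
    - Explicitly freezes the parameters of selected modules (layers).
    - Set per parameter, so it works at layer/module granularity.
    - More fine-grained: frozen and trainable parameters can coexist in the same forward/backward pass.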
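- My take: torch.no_grad() >> param.requires_grad=False in scope. no_grad switches off gradient tracking for every operation inside the block, while requires_grad=False only affects the parameters it is set on.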
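The "feature extraction inside forward" case from the list can be sketched with a toy model. This is a minimal illustration and not code from the original notebook. `FrozenBackboneClassifier`, `backbone` and `head` are hypothetical names, and two small `nn.Linear` layers stand in for a real encoder. The backbone runs under `torch.no_grad()`, so after `backward()` it has no gradients, even though its `requires_grad` flags are still `True`.

```python
import torch
from torch import nn

torch.manual_seed(0)


class FrozenBackboneClassifier(nn.Module):
    """Hypothetical toy model: a feature-extraction 'backbone' plus a trainable 'head'."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 4)
        self.head = nn.Linear(4, 2)

    def forward(self, x):
        # The backbone runs forward only: autograd records nothing inside this block
        with torch.no_grad():
            feats = self.backbone(x)
        return self.head(feats)


model = FrozenBackboneClassifier()
loss = model(torch.randn(3, 8)).sum()
loss.backward()

print(all(p.requires_grad for p in model.backbone.parameters()))  # True: flags untouched
print(model.backbone.weight.grad is None)                         # True: no gradient reached it
print(model.head.weight.grad is not None)                         # True: the head still trains
```

The backbone is effectively frozen here only because of where `no_grad` sits in `forward`. The parameters themselves are not marked as frozen.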

In [1]:
from transformers import BertModel
import torch
from torch import nn

In [2]:
model_name = 'bert-base-uncased'
model = BertModel.from_pretrained(model_name)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [3]:
def calc_learnable_params(model):
    # Count the elements of every parameter that is still trainable (requires_grad=True)
    total_params = 0
    for param in model.parameters():
        if param.requires_grad:
            total_params += param.numel()
    return total_params

In [4]:
calc_learnable_params(model)

109482240

In [5]:
with torch.no_grad():
    print(calc_learnable_params(model))

109482240
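The count does not change because `torch.no_grad()` never modifies `requires_grad`. What it changes is whether results computed inside the block are tracked by autograd. The small sketch below makes this visible. It uses a hypothetical standalone `nn.Linear` layer rather than the BERT model.

```python
import torch
from torch import nn

layer = nn.Linear(4, 2)
x = torch.randn(1, 4)

with torch.no_grad():
    y = layer(x)

print(layer.weight.requires_grad)  # True: the parameter flag is unchanged
print(y.requires_grad, y.grad_fn)  # False None: the output was not recorded by autograd
print(torch.is_grad_enabled())     # True: tracking is back on once the block exits

# The output has no grad_fn, so there is nothing to backpropagate through
try:
    y.sum().backward()
except RuntimeError as e:
    print("backward failed:", type(e).__name__)
```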


In [7]:
# Explicitly freeze every parameter; unlike no_grad, this flag persists after the cell
for param in model.parameters():
    param.requires_grad = False
calc_learnable_params(model)

0
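Setting `requires_grad = False` on every parameter, as above, freezes the whole model. Because the flag is per parameter, it can also freeze just part of a model. The sketch below assumes a hypothetical three-layer `nn.Sequential` in which `model[0]` stands in for a pretrained layer and `model[2]` stands in for a task head. Only the still-trainable parameters are passed to the optimizer. This is one common pattern, not the only option.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical toy model: model[0] plays the frozen "pretrained" layer, model[2] the head
model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))

# Freeze only the first Linear layer (module-level freezing)
for p in model[0].parameters():
    p.requires_grad_(False)

# Hand only the still-trainable parameters to the optimizer
trainable = [p for p in model.parameters() if p.requires_grad]
print(sum(p.numel() for p in trainable))  # 18: the head's 8*2 weights + 2 biases

optimizer = torch.optim.SGD(trainable, lr=0.1)
frozen_weight_before = model[0].weight.detach().clone()
head_bias_before = model[2].bias.detach().clone()

loss = model(torch.randn(5, 8)).pow(2).mean()
loss.backward()
optimizer.step()

print(torch.equal(model[0].weight, frozen_weight_before))  # True: frozen layer unchanged
print(model[0].weight.grad is None)                        # True: autograd never computed it
print(torch.equal(model[2].bias, head_bias_before))        # the head was updated, so expect False
```

Unlike the `no_grad` version, the frozen state here belongs to the parameters themselves. A counter such as `calc_learnable_params` would therefore report 18 for this model.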