In [1]:
from openprompt.data_utils import InputExample
# Sentiment analysis label set: one negative class and one positive class.
classes = ["negative", "positive"]

# Toy dataset of two single-sentence examples. text_a carries the input text
# (other datasets may have several input sentences per example); guid is the
# running index of the sentence.
dataset = [
    InputExample(guid=idx, text_a=sentence)
    for idx, sentence in enumerate((
        "Albert Einstein was one of the greatest intellects of his time.",
        "The film was badly made.",
    ))
]

In [2]:
from openprompt.plms import load_plm
# Load the "bert-base-cased" checkpoint through OpenPrompt's loader. The call
# yields four objects, unpacked here as the model, its tokenizer, a config and
# a wrapper class (roles inferred from the target names — confirm against the
# load_plm documentation).
plm, tokenizer, model_config, WrapperClass = load_plm("bert", "bert-base-cased")

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [3]:
from openprompt.prompts import ManualTemplate, SoftVerbalizer
# Hand-written prompt: the example's text_a, followed by "It was" and a mask slot.
promptTemplate = ManualTemplate(tokenizer=tokenizer, text='{"placeholder":"text_a"} It was {"mask"}')

In [4]:
# Soft verbalizer built from the tokenizer and the PLM, with two output
# classes (matching the two entries of `classes`).
promptVerbalizer = SoftVerbalizer(tokenizer, plm, num_classes=2)

In [5]:
# Toggle: when True, gradients are switched off for every parameter listed in
# the verbalizer's group_parameters_1; when False, nothing is touched.
freeze_verb_plm = True
for param in (promptVerbalizer.group_parameters_1 if freeze_verb_plm else ()):
    param.requires_grad = False

In [6]:
from openprompt import PromptForClassification

# Combine template, PLM and verbalizer into one classification model.
# freeze_plm=True requests a frozen PLM (the report further down shows
# 0 trainable PLM parameters).
promptModel = PromptForClassification(
    plm=plm,
    template=promptTemplate,
    verbalizer=promptVerbalizer,
    freeze_plm=True,
)

In [7]:
# check number of params that require_grad

def get_n_trainable_params(model):
    """Print the number of trainable parameters in each part of a prompt model.

    A parameter counts as trainable when its ``requires_grad`` flag is set.
    Counts are printed for ``model.plm``, ``model.template.soft_embedding``,
    ``model.verbalizer`` and the whole model, followed by the
    ``requires_grad`` flag of the first entry of
    ``model.verbalizer.group_parameters_1``.

    Args:
        model: object exposing ``parameters()`` as well as ``plm``,
            ``template`` and ``verbalizer`` attributes (for example the
            ``PromptForClassification`` instances built in this notebook).

    Returns:
        None. The report is written to stdout.

    Raises:
        AssertionError: if the PLM, template and verbalizer counts do not add
            up to the model-wide count.
    """

    def _count_trainable(module):
        # Total number of elements over parameters that still require gradients.
        return sum(p.numel() for p in module.parameters() if p.requires_grad)

    # all trainable
    num_total_trainable = _count_trainable(model)

    # PLM share
    num_plm_trainable = _count_trainable(model.plm)

    # Template share: only the soft embedding is counted. A template without
    # one raises AttributeError on lookup and contributes 0; any other error
    # is a real problem and is allowed to propagate.
    try:
        num_template_trainable = _count_trainable(model.template.soft_embedding)
    except AttributeError:
        num_template_trainable = 0

    # verbalizer share
    num_verbalizer_trainable = _count_trainable(model.verbalizer)

    # The three components must account for every trainable parameter.
    component_sum = num_plm_trainable + num_template_trainable + num_verbalizer_trainable
    assert component_sum == num_total_trainable, (
        f"component counts ({num_plm_trainable} + {num_template_trainable} + "
        f"{num_verbalizer_trainable}) do not add up to the total ({num_total_trainable})"
    )

    print(f"Number of trainable parameters of PLM: {num_plm_trainable}\n")
    print('#'*50)
    print(f"Number of trainable parameters of template: {num_template_trainable}\n")
    print('#'*50)
    print(f"Number of trainable parameters of verbalizer: {num_verbalizer_trainable}\n")
    print('#'*50)
    print(f"Total number of trainable parameters of whole model: {num_total_trainable}")

    # Not every verbalizer exposes group_parameters_1 (and the list may be
    # empty); report "n/a" in that case instead of failing after the counts
    # have already been printed.
    group_1 = getattr(model.verbalizer, "group_parameters_1", None)
    group_1_flag = group_1[0].requires_grad if group_1 else "n/a"
    print(f"Verbalizer grouped_parameters_1 require_grad: {group_1_flag}")

In [8]:

# Report the trainable-parameter counts of the first model.
get_n_trainable_params(promptModel)

Number of trainable parameters of PLM: 0

##################################################
Number of trainable parameters of template: 0

##################################################
Number of trainable parameters of verbalizer: 1536

##################################################
Total number of trainable parameters of whole model: 1536
Verbalizer grouped_parameters_1 require_grad: False


In [9]:
# Switch off gradients for the verbalizer's group_parameters_1, this time
# reached through the model.
# NOTE(review): In [5] already did this on promptVerbalizer and the report
# above already shows requires_grad False, so this looks like a re-check —
# confirm whether it is still needed.
for param in promptModel.verbalizer.group_parameters_1:
    param.requires_grad = False

In [10]:

# Print the report for the first model again to compare with the previous one.
get_n_trainable_params(promptModel)

Number of trainable parameters of PLM: 0

##################################################
Number of trainable parameters of template: 0

##################################################
Number of trainable parameters of verbalizer: 1536

##################################################
Total number of trainable parameters of whole model: 1536
Verbalizer grouped_parameters_1 require_grad: False


In [11]:
# By now the first PromptForClassification has been instantiated with
# freeze_plm=True, so the PLM itself is frozen. The first verbalizer, however,
# was instantiated before that freezing took place. A second verbalizer and a
# second model are built here from that same, already frozen PLM; the second
# verbalizer then also has properly frozen PLM parameters.
# TODO: the original note ends with "as described in ?" — add the missing reference.

promptVerbalizer_2 = SoftVerbalizer(tokenizer, plm, num_classes=2)
promptModel_2 = PromptForClassification(
    plm=plm,
    template=promptTemplate,
    verbalizer=promptVerbalizer_2,
    freeze_plm=True,
)

get_n_trainable_params(promptModel_2)

Number of trainable parameters of PLM: 0

##################################################
Number of trainable parameters of template: 0

##################################################
Number of trainable parameters of verbalizer: 1536

##################################################
Total number of trainable parameters of whole model: 1536
Verbalizer grouped_parameters_1 require_grad: False


In [12]:
# Show whether the first entry of the second verbalizer's group_parameters_2 is still trainable.
promptModel_2.verbalizer.group_parameters_2[0].requires_grad

True

In [13]:
# Display the contents of group_parameters_2 for the second verbalizer.
promptModel_2.verbalizer.group_parameters_2

[Parameter containing:
 tensor([[ 0.0088,  0.0201,  0.0091,  ...,  0.0042, -0.0165,  0.0024],
         [ 0.0237, -0.0224, -0.0142,  ...,  0.0080, -0.0297, -0.0175]],
        requires_grad=True)]

In [14]:
# Display the module structure of the second verbalizer.
promptModel_2.verbalizer

SoftVerbalizer(
  (head): BertOnlyMLMHead(
    (predictions): BertLMPredictionHead(
      (transform): BertPredictionHeadTransform(
        (dense): Linear(in_features=768, out_features=768, bias=True)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      )
      (decoder): Linear(in_features=768, out_features=2, bias=False)
    )
  )
)

In [16]:
# Display the PLM's own `cls` head for comparison with the verbalizer's head.
plm.cls

BertOnlyMLMHead(
  (predictions): BertLMPredictionHead(
    (transform): BertPredictionHeadTransform(
      (dense): Linear(in_features=768, out_features=768, bias=True)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    )
    (decoder): Linear(in_features=768, out_features=28996, bias=True)
  )
)