# BERT Model - Configuration and Parameter Count

In [1]:
from transformers import AutoModel, AutoConfig

## 1. Check configuration for Bert

In [2]:
# Load the configuration (hyper-parameters such as hidden size, number of
# layers and attention heads) of the pretrained checkpoint named below.
# Set model_name to 'bert-large-uncased' to inspect the larger variant instead.
model_name = 'bert-base-uncased'

config = AutoConfig.from_pretrained(model_name)

# Last expression in the cell: rendered as the cell output
config

BertConfig {
  "_name_or_path": "bert-base-uncased",
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.35.2",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}

## 2. Print model info & number of parameters

* (embeddings) : the input embedding layers.
  * All of these layers have weights that are learned during training.
  * (word_embeddings) : a vector of size 768 for each word (token) in the vocabulary.
  * (LayerNorm) : the output from the embedding layer has a dimension of 768.

* (encoder) : there are 12 encoder blocks (encoders), indexed 0 to 11.
    * (attention) : each encoder block has 12 "attention heads".
    *               Within each attention head, weights are learned for the query, key and value vectors.
    *               The outputs from the attention heads are concatenated.
    * (intermediate) : represents the FFNN, using GELU as the activation function. Output dimension is 3072.
* (output) : projects the FFNN output back down to a dimension of 768; the output of the last encoder block is the last hidden state.
  

In [3]:
# Load the pretrained model that matches `model_name`
# (the printed structure below shows it is a BertModel).
model = AutoModel.from_pretrained(pretrained_model_name_or_path=model_name)

# Show the module hierarchy: embeddings followed by the stacked encoder layers
print(model)

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
  

## 3. Parameter count: iterate through the parameter list

The `PreTrainedModel` class in Transformers is a subclass of `torch.nn.Module`.

`torch.nn.Module.parameters()` returns an iterator over the module's parameters.

Use the shape of each parameter tensor (the product of its dimensions) to calculate the number of weights learned in each layer or network.

https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module

In [4]:
# Total number of parameters, as reported by the model itself
num_parameters = model.num_parameters()

# Express the total in millions for readability
params_in_millions = num_parameters / 1e6
print("num_parameters (millions) : ", params_in_millions)

num_parameters (millions) :  109.48224


In [6]:
# Collect every learnable parameter tensor of the model (weights and biases)
LP = list(model.parameters())

lp_count = len(LP)

# Print each tensor's shape and accumulate its element count.
# Tensor.numel() is the product of ALL dimensions, so the count is correct for
# tensors of any rank (multiplying only shape[0] * shape[1] would undercount
# tensors with more than two dimensions and fail on 0-dim tensors).
num_params_calc = 0
for p, param in enumerate(LP):
    print(p, " : ", param.shape)
    num_params_calc += param.numel()

print("Exact count of parameters : ", num_params_calc)

# Sanity check: the manual sum must agree with the model's own total
assert num_params_calc == model.num_parameters(), \
    "manual parameter count disagrees with model.num_parameters()"

0  :  torch.Size([30522, 768])
1  :  torch.Size([512, 768])
2  :  torch.Size([2, 768])
3  :  torch.Size([768])
4  :  torch.Size([768])
5  :  torch.Size([768, 768])
6  :  torch.Size([768])
7  :  torch.Size([768, 768])
8  :  torch.Size([768])
9  :  torch.Size([768, 768])
10  :  torch.Size([768])
11  :  torch.Size([768, 768])
12  :  torch.Size([768])
13  :  torch.Size([768])
14  :  torch.Size([768])
15  :  torch.Size([3072, 768])
16  :  torch.Size([3072])
17  :  torch.Size([768, 3072])
18  :  torch.Size([768])
19  :  torch.Size([768])
20  :  torch.Size([768])
21  :  torch.Size([768, 768])
22  :  torch.Size([768])
23  :  torch.Size([768, 768])
24  :  torch.Size([768])
25  :  torch.Size([768, 768])
26  :  torch.Size([768])
27  :  torch.Size([768, 768])
28  :  torch.Size([768])
29  :  torch.Size([768])
30  :  torch.Size([768])
31  :  torch.Size([3072, 768])
32  :  torch.Size([3072])
33  :  torch.Size([768, 3072])
34  :  torch.Size([768])
35  :  torch.Size([768])
36  :  torch.Size([768])
37  