### Loading the model

In [2]:
import sys
sys.path.insert(0, '../')

from transformers import BertConfig
from model.model import MidiBert
import pickle
import torch


print("Loading Dictionary")
with open('/Users/tobiaslandgraf/code/HugoA45/music_project/music_project/model/CP.pkl', 'rb') as f:
    e2w, w2e = pickle.load(f)

# Define the configuration for the BERT model
configuration = BertConfig(
    max_position_embeddings=512,
    position_embedding_type='relative_key_query',
    hidden_size=768,
    num_attention_heads=12,
    num_hidden_layers=12
)

# Initialize the model
midibert = MidiBert(bertConfig=configuration, e2w=e2w, w2e=w2e)

# Define the path to your checkpoint here
ckpt_path = '/Users/tobiaslandgraf/code/HugoA45/music_project/music_project/model/pretrain_model.ckpt'

# Load the checkpoint
checkpoint = torch.load(ckpt_path, map_location='cpu')

# Remove the unexpected key from the state dictionary
embeddings_position_ids =  checkpoint['state_dict']["bert.embeddings.position_ids"]
if "bert.embeddings.position_ids" in checkpoint['state_dict']:
    del checkpoint['state_dict']["bert.embeddings.position_ids"]

# Load the state dictionary from the checkpoint into the model
midibert.load_state_dict(checkpoint['state_dict'])


Loading Dictionary


<All keys matched successfully>

In [3]:
import numpy as np

In [4]:
data = np.load('/Users/tobiaslandgraf/code/HugoA45/music_project/music_project/data/pianist8/composer_cp_test.npy')

In [5]:
data.shape

(126, 512, 4)

In [6]:
import pandas as pd

data[:,:,0]

array([[0, 0, 1, ..., 1, 1, 1],
       [1, 1, 1, ..., 1, 0, 1],
       [0, 1, 1, ..., 1, 1, 1],
       ...,
       [1, 1, 1, ..., 1, 1, 1],
       [1, 1, 1, ..., 1, 1, 0],
       [1, 1, 1, ..., 2, 2, 2]])

In [7]:
midibert

MidiBert(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
              (distance_embedding): Embedding(1023, 64)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (Layer

In [8]:
data.shape

(126, 512, 4)

In [9]:
tensor = torch.from_numpy(data)

In [10]:
tensor

tensor([[[ 0,  6, 32, 23],
         [ 0,  1, 44,  5],
         [ 1,  3, 32,  5],
         ...,
         [ 1,  1, 37,  4],
         [ 1,  2, 25,  2],
         [ 1,  2, 13,  2]],

        [[ 1,  4, 25,  4],
         [ 1,  4, 49,  4],
         [ 1,  4, 37,  4],
         ...,
         [ 1, 14, 20,  5],
         [ 0,  0,  8, 45],
         [ 1,  1, 20, 61]],

        [[ 0,  1, 36, 63],
         [ 1,  1, 63, 56],
         [ 1,  1, 27,  0],
         ...,
         [ 1, 10, 56,  2],
         [ 1, 11, 56,  1],
         [ 1, 11, 68,  1]],

        ...,

        [[ 1, 11, 43,  7],
         [ 1, 12, 24, 12],
         [ 1, 12, 48,  4],
         ...,
         [ 1,  8, 36,  9],
         [ 1,  8, 45, 17],
         [ 1,  8, 17,  7]],

        [[ 1,  8, 29,  9],
         [ 1,  8, 33,  4],
         [ 1, 11, 33,  4],
         ...,
         [ 1, 14, 16, 20],
         [ 1, 15, 24, 14],
         [ 0,  1, 28,  7]],

        [[ 1,  2, 31,  1],
         [ 1,  3, 31, 10],
         [ 1,  3, 36, 10],
         ...,
 

In [11]:
embeddings_position_ids

tensor([[  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
          14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
          28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
          42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
          56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
          70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
          84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
          98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
         112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
         126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139,
         140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153,
         154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167,
         168, 169, 170, 171, 172, 173, 174, 175, 176

In [12]:
midibert.eval()


MidiBert(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
              (distance_embedding): Embedding(1023, 64)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (Layer

In [None]:
output = midibert(tensor)


### Bert model

In [50]:
midibert.bert


BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
            (distance_embedding): Embedding(1023, 64)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
 

In [51]:
embeddings_position_ids.shape


torch.Size([1, 512])

### Number of parameters

In [52]:
model = midibert.bert.from_pretrained('bert-base-uncased')
num_parameters = sum(p.numel() for p in model.parameters())

print(f'The model has {num_parameters} parameters.')


Downloading config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

The model has 109482240 parameters.


 In a typical BERT model, the state_dict keys for the transformer layers usually follow this pattern:

bert.encoder.layer.{i}.{sub_layer}.{param}
where:

{i} is the layer index (from 0 to 11 for a 12-layer BERT model),
{sub_layer} is the sub-layer type, which can be attention.self, attention.output, intermediate, or output,
{param} is the parameter type, which can be weight or bias for linear layers, and gamma or beta for layer normalization.
To count the number of unique transformer layers in the state_dict, you can extract the layer index from each key and add it to a set (which automatically removes duplicates), then count the number of elements in the set.

In [53]:
state_dict = checkpoint['state_dict']

# Initialize an empty set to store the layer indices
layer_indices = set()

# Iterate over the keys in the state dictionary
for key in state_dict.keys():
    # Split the key into parts
    parts = key.split('.')
    # If this key corresponds to a transformer layer
    if parts[0] == 'bert' and parts[1] == 'encoder' and parts[2] == 'layer':
        # Extract the layer index and add it to the set
        layer_indices.add(int(parts[3]))

# Print the number of unique layers
print(f"The model has {len(layer_indices)} unique transformer layers.")


The model has 12 unique transformer layers.


In [54]:
# Get the names of the layers
layer_names = state_dict.keys()

# Get the unique keys from the layers
unique_keys = set()
for name in layer_names:
    # Split the name by '.' and get the first part
    key = name.split('.')[0]
    unique_keys.add(key)

# Print the unique keys
for key in unique_keys:
    print(key)


in_linear
bert
word_emb


the state_dict keys for the transformer layers usually follow this pattern:

bert.encoder.layer.{i}.{sub_layer}.{param}
where:

{i} is the layer index (from 0 to 11 for a 12-layer BERT model),
{sub_layer} is the sub-layer type, which can be attention.self, attention.output, intermediate, or output,
{param} is the parameter type, which can be weight or bias for linear layers, and gamma or beta for layer normalization.
To count the number of unique transformer layers in the state_dict, you can extract the layer index from each key and add it to a set (which automatically removes duplicates), then count the number of elements in the set.