In [1]:
from pathlib import Path
import sys

# Project root = parent of the kernel's working directory. Jupyter starts the
# kernel in the notebook's own folder by default, so this is one level up
# from the notebook.
ROOT = Path.cwd().resolve().parent

# Put the project root first on the import path so project packages (e.g.
# `src`) can be imported. Only insert when it is not already first, so that
# re-running this cell does not keep prepending duplicate entries.
if not sys.path or sys.path[0] != str(ROOT):
    sys.path.insert(0, str(ROOT))


In [2]:
from pathlib import Path

import torch
from src.textclf_transformer.training.training_loop import TrainingLoop
# NOTE(review): wildcard import — later cells may rely on names it brings in;
# replace with explicit imports once that usage has been audited.
from src.textclf_transformer import *

# Experiment to load: <ROOT>/experiments/pretraining/<name>
EXP_BASE = Path(ROOT) / "experiments" / "pretraining"
name = 'E1_pretraining_wikipedia_bertsmall_mha'

# Load the experiment config, then seed before building the model so that
# (assuming set_global_seed seeds torch) weight initialisation is repeatable.
exp_dir, cfg = read_experiment_config(EXP_BASE, name)
set_global_seed(cfg["experiment"].get("seed", 42))

# Tokenizer -> architecture kwargs (vocab_size, pad_token_id, ...) -> MLM model.
wrapper = load_tokenizer_wrapper_from_cfg(cfg["tokenizer"])
arch_kw = arch_kwargs_from_cfg(cfg, wrapper.tokenizer)
head = cfg["mlm_head"]
model = TransformerForMaskedLM(**arch_kw, tie_mlm_weights=head["tie_mlm_weights"])

In [3]:
from torchinfo import summary
import torch

model.eval()  # switch off dropout so the summary reflects inference mode

# Dummy batch: B sequences of N tokens, the last N_PAD positions are padding.
B, N = 3, 100
N_PAD = 4
vocab_size, pad_id = arch_kw['vocab_size'], arch_kw['pad_token_id']
device = "cpu"  # torchinfo runs the forward pass (and moves the model) here

# Random token ids that never equal pad_id, so the mask below marks exactly
# the trailing N_PAD positions. Sampling vocab_size - 1 values and shifting
# those >= pad_id up by one skips pad_id for any pad id (for pad_id == 0 this
# covers the same range as torch.randint(1, vocab_size, ...)).
input_ids = torch.randint(0, vocab_size - 1, (B, N), dtype=torch.long)
input_ids += (input_ids >= pad_id).long()
input_ids[:, N - N_PAD:] = pad_id  # N - N_PAD (not -N_PAD) so N_PAD = 0 pads nothing
assert (input_ids[:, : N - N_PAD] != pad_id).all(), "real tokens must not be pad_id"

# NOTE(review): True marks *padding* positions here. Hugging Face models use the
# opposite convention (1 = attend) — confirm this matches TransformerForMaskedLM.
inputs = {
    "input_ids": input_ids,
    "attention_mask": input_ids == pad_id,
}

info = summary(
    model,
    input_data=inputs,  # a dict is unpacked as keyword arguments to model.forward
    device=device,
    depth=3,
    col_names=("input_size", "output_size", "num_params", "mult_adds"),
    verbose=1,
    # Not a torchinfo option: torchinfo forwards extra kwargs to model.forward.
    return_sequence=False,
)

Layer (type:depth-idx)                             Input Shape               Output Shape              Param #                   Mult-Adds
TransformerForMaskedLM                             --                        [3, 100, 30522]           --                        --
├─TransformerTextEmbeddings: 1-1                   [3, 100]                  [3, 100, 512]             --                        --
│    └─Embedding: 2-1                              [3, 100]                  [3, 100, 512]             15,627,264                46,881,792
│    └─LayerNorm: 2-2                              [3, 100, 512]             [3, 100, 512]             1,024                     3,072
│    └─Dropout: 2-3                                [3, 100, 512]             [3, 100, 512]             --                        --
├─ModuleList: 1-2                                  --                        --                        --                        --
│    └─TransformerEncoderBlock: 2-4                [3, 100

In [6]:
# List the qualified name of every registered parameter, in registration order.
for param_name, _ in model.named_parameters():
    print(param_name)

embeddings.word_embeddings.weight
embeddings.layer_norm.weight
embeddings.layer_norm.bias
layers.0.attention_block.attention_mechanism.Uqkv.weight
layers.0.attention_block.attention_mechanism.Uqkv.bias
layers.0.attention_block.attention_mechanism.Uout.weight
layers.0.attention_block.attention_mechanism.Uout.bias
layers.0.attention_block.layer_norm.weight
layers.0.attention_block.layer_norm.bias
layers.0.mlp_block.mlp.0.weight
layers.0.mlp_block.mlp.0.bias
layers.0.mlp_block.mlp.2.weight
layers.0.mlp_block.mlp.2.bias
layers.0.mlp_block.layer_norm.weight
layers.0.mlp_block.layer_norm.bias
layers.1.attention_block.attention_mechanism.Uqkv.weight
layers.1.attention_block.attention_mechanism.Uqkv.bias
layers.1.attention_block.attention_mechanism.Uout.weight
layers.1.attention_block.attention_mechanism.Uout.bias
layers.1.attention_block.layer_norm.weight
layers.1.attention_block.layer_norm.bias
layers.1.mlp_block.mlp.0.weight
layers.1.mlp_block.mlp.0.bias
layers.1.mlp_block.mlp.2.weight
laye