In [1]:
import argparse
import sys; sys.path.append('..')
from dual_attention_transformer import DualAttnTransformerLM
from language_models import TransformerLM
import torch
import torchinfo

In [125]:
# Create an argparse.Namespace object with the specified defaults
args = argparse.Namespace(
    vocab_size=50304,
    d_model=768,
    n_layers=12,
    sa=6,
    ra=6,
    n_kv_heads=1,
    n_relations=None,
    rel_activation='identity',
    symbol_type='symbolic_attention',
    # symbol_type='position_relative',
    sym_attn_n_symbols=512,
    trainable_symbols=1,
    symmetric_rels=0,
    dff=None,
    activation='gelu',
    dropout_rate=0.0,
    norm_first=1,
    norm_type='layernorm',
    max_block_size=1024,
    bias=0,
    pos_enc_type='RoPE',
    max_seq_len=1024,
    shared_symbol_retriever=0,
    weight_tie_symbol_library=1,
)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [126]:
# Model configuration
max_seq_len = args.max_seq_len

vocab_size = args.vocab_size
d_model = args.d_model
n_layers = args.n_layers
sa, ra = args.sa, args.ra
dff = args.dff
ra_type = 'relational_attention'
symmetric_rels = bool(args.symmetric_rels) if args.symmetric_rels in (0,1) else None
n_relations = args.n_relations
rel_proj_dim = None if n_relations is None else int((d_model / (sa+ra)) * (ra / n_relations))
rel_activation = args.rel_activation
symbol_type = args.symbol_type
trainable_symbols = bool(args.trainable_symbols)
sym_attn_n_symbols = args.sym_attn_n_symbols # args.max_block_size # only applicable for symbol_type=sym_attn
activation = args.activation
dropout_rate = args.dropout_rate
norm_first = bool(args.norm_first)
norm_type = args.norm_type
max_block_size = args.max_block_size
bias = bool(args.bias)
pos_enc_type = args.pos_enc_type
n_kv_heads = args.n_kv_heads

ra_kwargs = dict(n_relations=n_relations, rel_activation=rel_activation, rel_proj_dim=rel_proj_dim, n_kv_heads=n_kv_heads) # FIXME
sa_kwargs = dict(n_kv_heads=n_kv_heads) # FIXME NOTE: only used for DAT-LM
if symbol_type == 'symbolic_attention':
    # NOTE: n_heads, n_symbols fixed for now
    symbol_retrieval_kwargs = dict(d_model=d_model, n_symbols=sym_attn_n_symbols, n_heads=4, trainable_symbols=trainable_symbols)
elif symbol_type == 'position_relative':
    symbol_retrieval_kwargs = dict(symbol_dim=d_model, max_rel_pos=max_seq_len)
    ra_kwargs['use_relative_positional_symbols'] = True # if using position-relative symbols, need to tell RA module
elif ra != 0:
    raise ValueError(f'`symbol_type` {symbol_type} not valid')

if ra_type == 'relational_attention':
    ra_kwargs['symmetric_rels'] = symmetric_rels

symbol_retriever_config = dict(shared_symbol_retriever=bool(args.shared_symbol_retriever), weight_tie_symbol_library=bool(args.weight_tie_symbol_library))

# if ra=0, use TransformerLM
if ra == 0:
    model_config = dict(
        vocab_size=vocab_size, d_model=d_model, n_layers=n_layers, n_heads=sa, dff=dff,
        pos_enc_type=pos_enc_type, dropout_rate=dropout_rate, activation=activation, norm_first=norm_first,
        max_block_size=max_seq_len, bias=bias, use_flash_attention=True)

    # model = TransformerLM(**model_args).to(device)
# otherwise, use DualAttnTransformerLM
else:
    model_config = dict(
        vocab_size=vocab_size, d_model=d_model, n_layers=n_layers, n_heads_sa=sa, n_heads_ra=ra, dff=dff,
        sa_kwargs=sa_kwargs, ra_kwargs=ra_kwargs, ra_type=ra_type, pos_enc_type=pos_enc_type, activation=activation,
        symbol_retrieval=symbol_type, symbol_retrieval_kwargs=symbol_retrieval_kwargs, symbol_retriever_config=symbol_retriever_config,
        dropout_rate=dropout_rate, norm_first=norm_first, max_block_size=max_seq_len, bias=bias)


In [127]:
if 'n_heads_ra' in model_config:
    model = DualAttnTransformerLM(**model_config)
else:
    model = TransformerLM(**model_config)

model = model.to(device)

In [128]:
torchinfo.summary(model)

Layer (type:depth-idx)                                  Param #
DualAttnTransformerLM                                   --
├─ModuleDict: 1-1                                       --
│    └─Embedding: 2-1                                   38,633,472
│    └─Dropout: 2-2                                     --
│    └─ModuleList: 2-3                                  --
│    │    └─SymbolicAttention: 3-1                      1,377,024
│    │    └─SymbolicAttention: 3-2                      1,377,024
│    │    └─SymbolicAttention: 3-3                      1,377,024
│    │    └─SymbolicAttention: 3-4                      1,377,024
│    │    └─SymbolicAttention: 3-5                      1,377,024
│    │    └─SymbolicAttention: 3-6                      1,377,024
│    │    └─SymbolicAttention: 3-7                      1,377,024
│    │    └─SymbolicAttention: 3-8                      1,377,024
│    │    └─SymbolicAttention: 3-9                      1,377,024
│    │    └─SymbolicAttention: 3-10    

In [129]:
print(f'{model.get_num_params():,}')

127,621,632


In [130]:
print(f'{sum(p.numel() for p in model.layers.symbol_retrievers.parameters() if p.requires_grad):,}')

12,198,912


In [131]:
print(f'{sum(p.numel() for p in model.parameters() if p.requires_grad):,}')

127,621,632


In [6]:
type(model_summary)

torchinfo.model_statistics.ModelStatistics

In [7]:
model_summary_dict = {
    'Input size (MB)': model_summary.to_megabytes(model_summary.total_input),
    'Params size (MB)': model_summary.to_megabytes(model_summary.total_param_bytes),
    'Forward/backward pass size  (MB)': model_summary.to_megabytes(model_summary.total_output_bytes),
    'Estimated total size (MB)': model_summary.to_megabytes(model_summary.total_output_bytes + model_summary.total_param_bytes + model_summary.total_input),
    'Total Mult-Adds': model_summary.total_mult_adds,

    'trainable_params': model_summary.trainable_params,
    'total_params': model_summary.total_params,
}

for k,v in model_summary_dict.items():
    print(f'{k}: {v}')

Input size (MB): 0.004176
Params size (MB): 665.662464
Forward/backward pass size  (MB): 1406.140416
Estimated total size (MB): 2071.807056
Total Mult-Adds: 172912128
trainable_params: 167033088
total_params: 167622912


In [8]:
print(dir(model_summary))

['__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', 'float_to_megabytes', 'format_output_num', 'formatting', 'input_size', 'summary_list', 'to_megabytes', 'to_readable', 'total_input', 'total_mult_adds', 'total_output_bytes', 'total_param_bytes', 'total_params', 'trainable_params']


In [9]:
model_summary.total_input

4176