In [1]:
from collections import OrderedDict
def _num_elements(shape):
    """Return the number of scalar entries in a parameter with the given shape tuple."""
    count = 1
    for dim in shape:
        count *= dim
    return count


def _layer_params(prefix, num_layers, suffix_shapes):
    """Expand per-layer templates into concrete '{prefix}_{i}_{suffix}' -> shape entries.

    Layers are the outer loop, so entries come out grouped by layer index,
    in the same order as the (suffix, shape) template list.
    """
    params = OrderedDict()
    for i in range(num_layers):
        for suffix, shape in suffix_shapes:
            params['{0}_{1}_{2}'.format(prefix, i, suffix)] = shape
    return params


def _add_group(group_dict, param_dict, group_name, params):
    """Register `params` (name -> shape) in `param_dict` and record their total size.

    The group size is derived from the shapes themselves rather than from a
    separately written formula, so `group_dict` and `param_dict` cannot drift apart.
    """
    param_dict.update(params)
    group_dict[group_name] = sum(_num_elements(shape) for shape in params.values())


def get_model_params(sb, tb, sn, tn, e, f):
    """Enumerate the parameters of a Transformer encoder-decoder and their group sizes.

    Args:
        sb: number of rows of `source_embed_weight` (source vocabulary size).
        tb: number of rows of `target_embed_weight` / `target_output_weight`
            and length of `target_output_bias` (target vocabulary size).
        sn: number of encoder layers.
        tn: number of decoder layers.
        e: embedding / model dimension (num_embed).
        f: hidden size of the feed-forward sublayer.

    Returns:
        (group_dict, param_dict):
        - group_dict: OrderedDict {group name: number of parameters in the group},
          in the order decoder_att, decoder_ff, decoder_final, encoder_att,
          encoder_ff, encoder_final, io.
        - param_dict: dict {parameter name: shape tuple}, inserted group by group
          in the same order.
    """
    group_dict = OrderedDict()  # {name of parameter group: number of parameters in the group}
    param_dict = {}  # {name of parameter: shape of param matrix}

    # Decoder attention: encoder-decoder attention (q2h, k2h, v2h, h2o) and
    # self-attention (fused i2h of height 3*e, h2o), each with a pre-norm.
    decoder_att = [
        ('att_enc_h2o_weight', (e, e)),
        ('att_enc_k2h_weight', (e, e)),
        ('att_enc_pre_norm_beta', (e,)),
        ('att_enc_pre_norm_gamma', (e,)),
        ('att_enc_q2h_weight', (e, e)),
        ('att_enc_v2h_weight', (e, e)),
        ('att_self_h2o_weight', (e, e)),
        ('att_self_i2h_weight', (3*e, e)),
        ('att_self_pre_norm_beta', (e,)),
        ('att_self_pre_norm_gamma', (e,)),
    ]
    # Encoder attention: self-attention only.
    encoder_att = [
        ('att_self_h2o_weight', (e, e)),
        ('att_self_i2h_weight', (3*e, e)),
        ('att_self_pre_norm_beta', (e,)),
        ('att_self_pre_norm_gamma', (e,)),
    ]
    # Feed-forward sublayer; identical layout in encoder and decoder.
    feed_forward = [
        ('ff_h2o_bias', (e,)),
        ('ff_h2o_weight', (e, f)),
        ('ff_i2h_bias', (f,)),
        ('ff_i2h_weight', (f, e)),
        ('ff_pre_norm_beta', (e,)),
        ('ff_pre_norm_gamma', (e,)),
    ]
    # Final layer norm applied after the last layer of each stack.
    final_norm = [
        ('final_process_norm_beta', (e,)),
        ('final_process_norm_gamma', (e,)),
    ]

    def _final(prefix):
        # Final-norm names carry no layer index: '{prefix}_final_process_norm_*'.
        return OrderedDict(('{0}_{1}'.format(prefix, suffix), shape) for suffix, shape in final_norm)

    _add_group(group_dict, param_dict, 'decoder_att',
               _layer_params('decoder_transformer', tn, decoder_att))
    _add_group(group_dict, param_dict, 'decoder_ff',
               _layer_params('decoder_transformer', tn, feed_forward))
    _add_group(group_dict, param_dict, 'decoder_final', _final('decoder_transformer'))

    _add_group(group_dict, param_dict, 'encoder_att',
               _layer_params('encoder_transformer', sn, encoder_att))
    _add_group(group_dict, param_dict, 'encoder_ff',
               _layer_params('encoder_transformer', sn, feed_forward))
    _add_group(group_dict, param_dict, 'encoder_final', _final('encoder_transformer'))

    # Embeddings and output projection.
    _add_group(group_dict, param_dict, 'io', OrderedDict([
        ('source_embed_weight', (sb, e)),
        ('target_embed_weight', (tb, e)),
        ('target_output_bias', (tb,)),
        ('target_output_weight', (tb, e)),
    ]))

    return group_dict, param_dict

In [2]:
def get_num_params(sb, tb, sn, tn, e, f):
    """Closed-form total parameter count of the Transformer encoder-decoder.

    Args:
        sb: source vocabulary size (rows of the source embedding).
        tb: target vocabulary size (rows of the target embedding and output layer).
        sn: number of encoder layers.
        tn: number of decoder layers.
        e: embedding / model dimension.
        f: hidden size of the feed-forward sublayer.

    Returns:
        Total number of scalar parameters.
    """
    # Feed-forward sublayer: i2h (f x e) + h2o (e x f) weights, biases f and e,
    # plus a pre-norm beta/gamma pair of size e each.
    ff_layer = 2*e*f + f + 3*e
    # Decoder layer: self-attention (i2h 3e x e, h2o e x e) and encoder attention
    # (q2h, k2h, v2h, h2o: four e x e), two pre-norms, plus the feed-forward block.
    decoder_layer = 8*e*e + 4*e + ff_layer
    # Encoder layer: self-attention (four e x e in total), one pre-norm, plus feed-forward.
    encoder_layer = 4*e*e + 2*e + ff_layer
    # One final norm (beta + gamma) for each of the two stacks.
    final_norms = 2*(2*e)
    # Source embedding, target embedding, target output weight and bias.
    io_nparam = sb*e + tb*(2*e + 1)
    return tn*decoder_layer + sn*encoder_layer + final_norms + io_nparam

In [4]:
# Sweep the architecture grid and collect one table row per configuration:
# [num_layers, num_embed, ff_hidden, bpe_symbols, total_params, <per-group counts in get_model_params order>]

# Hard-coded (source, target) vocabulary sizes for each bpe_symbols setting.
# Unknown settings raise KeyError instead of silently reusing another setting's sizes.
VOCAB_SIZES = {
    10000: (10004, 10004),
    30000: (30004, 28244),
    50000: (50004, 41355),
}

results = []
for n in [1, 2, 4]:  # num_layers (same depth for encoder and decoder)
    for e in [256, 512, 1024]:  # num_embed
        for f in [1024, 2048]:  # transformer_feed_forward_num_hidden
            for b in [10000, 30000, 50000]:  # bpe_symbols
                sb, tb = VOCAB_SIZES[b]
                group_dict, param_dict = get_model_params(sb, tb, n, n, e, f)
                nparam = get_num_params(sb, tb, n, n, e, f)
                # The closed-form total and the per-group breakdown are written
                # independently; make sure they agree before tabulating.
                assert nparam == sum(group_dict.values()), (n, e, f, b)
                results.append([n, e, f, b, nparam] + list(group_dict.values()))

In [7]:
# Render every result row as a pipe-delimited (markdown-style) table line.
res = "".join(
    "| " + " | ".join(str(value) for value in row) + " |\n"
    for row in results
)

In [9]:
# Show the assembled table; rows already end in "\n", so print's own newline leaves one blank line at the end.
print(res, end="\n")

| 1 | 256 | 1024 | 10000 | 9534228 | 525312 | 526080 | 512 | 262656 | 526080 | 512 | 7693076 |
| 1 | 256 | 1024 | 30000 | 24011348 | 525312 | 526080 | 512 | 262656 | 526080 | 512 | 22170196 |
| 1 | 256 | 1024 | 50000 | 35857291 | 525312 | 526080 | 512 | 262656 | 526080 | 512 | 34016139 |
| 1 | 256 | 2048 | 10000 | 10584852 | 525312 | 1051392 | 512 | 262656 | 1051392 | 512 | 7693076 |
| 1 | 256 | 2048 | 30000 | 25061972 | 525312 | 1051392 | 512 | 262656 | 1051392 | 512 | 22170196 |
| 1 | 256 | 2048 | 50000 | 36907915 | 525312 | 1051392 | 512 | 262656 | 1051392 | 512 | 34016139 |
| 1 | 512 | 1024 | 10000 | 20629268 | 2099200 | 1051136 | 1024 | 1049600 | 1051136 | 1024 | 15376148 |
| 1 | 512 | 1024 | 30000 | 49565268 | 2099200 | 1051136 | 1024 | 1049600 | 1051136 | 1024 | 44312148 |
| 1 | 512 | 1024 | 50000 | 73244043 | 2099200 | 1051136 | 1024 | 1049600 | 1051136 | 1024 | 67990923 |
| 1 | 512 | 2048 | 10000 | 22728468 | 2099200 | 2100736 | 1024 | 1049600 | 2100736 | 1024 | 15376148 |
| 1