In [8]:
from collections import OrderedDict
def get_model_params(cell, sb, tb, sn, tn, se, te, h):
    """Tabulate parameter counts per group and the shape of every parameter.

    Args:
        cell: RNN cell type. 'lstm' uses 4*h rows for the recurrent matrices
            and two init transforms per decoder layer; any other value uses
            3*h rows and one init transform (the notebook passes 'gru').
        sb: number of rows of the source embedding matrix.
        tb: number of rows of the target embedding and output matrices.
        sn: number of encoder layers (one bidirectional layer plus sn-1
            stacked layers).
        tn: number of decoder layers.
        se: source embedding width.
        te: target embedding width.
        h: hidden size. Should be even, because each direction of the
            bidirectional layer is sized with h//2.

    Returns:
        (group_dict, param_dict): group_dict is an OrderedDict mapping group
        name to parameter count, in the order enc2decinit, hidden, decoder_lx,
        birnn, encoder_lx, io. param_dict maps parameter name to its shape
        tuple.
    """
    group_dict = OrderedDict()  # {name of parameter group: number of parameters in the group}
    param_dict = {}  # {name of parameter: shape of param matrix}

    is_lstm = cell == 'lstm'

    # enc2decinit parameters: one (h,) bias and one (h, h) weight per transform.
    num_init = 2*tn if is_lstm else tn
    group_dict["enc2decinit"] = num_init*h*(h+1)
    for i in range(num_init):
        param_dict['decoder_rnn_enc2decinit_{0}_bias'.format(i)] = (h,)
        param_dict['decoder_rnn_enc2decinit_{0}_weight'.format(i)] = (h, h)

    # hidden
    param_dict['decoder_rnn_hidden_bias'] = (h,)
    param_dict['decoder_rnn_hidden_weight'] = (h, 2*h)
    group_dict["hidden"] = h*(2*h+1)

    # decoder_lx
    # NOTE(review): the first decoder layer's extra input width is taken from
    # `se`, both here and in the group total; if the decoder actually consumes
    # the target embedding this should presumably be `te` - confirm.
    rows = 4*h if is_lstm else 3*h
    group_dict["decoder_lx"] = rows*(se+2*tn*(h+1))
    for i in range(tn):
        in_width = h+se if i == 0 else h
        param_dict['decoder_rnn_l{0}_h2h_bias'.format(i)] = (rows,)
        param_dict['decoder_rnn_l{0}_h2h_weight'.format(i)] = (rows, h)
        param_dict['decoder_rnn_l{0}_i2h_bias'.format(i)] = (rows,)
        param_dict['decoder_rnn_l{0}_i2h_weight'.format(i)] = (rows, in_width)

    # birnn: each direction has hidden size h//2.
    if is_lstm:
        rows = 2*h
        group_dict["birnn"] = 2*h*(4+h+2*se)
    else:
        rows = int(1.5 * h)
        group_dict["birnn"] = int(1.5*h*(4+h+2*se))
    # Floor division keeps the shape entries integral on Python 3
    # (h/2 would produce a float such as 128.0).
    half_h = h//2
    for direction in ('forward', 'reverse'):
        prefix = 'encoder_birnn_{0}_l0_'.format(direction)
        param_dict[prefix + 'h2h_bias'] = (rows,)
        param_dict[prefix + 'h2h_weight'] = (rows, half_h)
        param_dict[prefix + 'i2h_bias'] = (rows,)
        param_dict[prefix + 'i2h_weight'] = (rows, se)

    # encoder_lx: the sn-1 layers stacked on top of the bidirectional layer.
    if sn > 1:
        rows = 4*h if is_lstm else 3*h
        group_dict["encoder_lx"] = rows*(sn-1)*(2+2*h)
        for i in range(sn-1):
            param_dict['encoder_rnn_l{0}_h2h_bias'.format(i)] = (rows,)
            param_dict['encoder_rnn_l{0}_h2h_weight'.format(i)] = (rows, h)
            param_dict['encoder_rnn_l{0}_i2h_bias'.format(i)] = (rows,)
            param_dict['encoder_rnn_l{0}_i2h_weight'.format(i)] = (rows, h)
    else:
        group_dict["encoder_lx"] = 0

    # io
    param_dict['source_embed_weight'] = (sb, se)
    # The target embedding is (tb, te), matching the tb*te term in the io
    # total below (the shape previously used se).
    param_dict['target_embed_weight'] = (tb, te)
    param_dict['target_output_bias'] = (tb,)
    param_dict['target_output_weight'] = (tb, h)

    group_dict["io"] = sb*se+tb*(1+te+h)

    return group_dict, param_dict

In [9]:
def get_num_params(cell, sb, tb, sn, tn, se, te, h):
    """Return the total parameter count of the model from a closed-form expression.

    Args:
        cell: 'lstm' selects the LSTM coefficients; any other value uses the
            alternative set (the notebook passes 'gru').
        sb, tb: row counts of the source embedding and of the target
            embedding/output matrices.
        sn, tn: encoder and decoder layer counts.
        se, te: source and target embedding widths.
        h: hidden size.

    Returns:
        The total number of parameters.
    """
    # Embedding and output-layer parameters do not depend on the cell type.
    embed_and_output = sb*se + tb*(1+te+h)

    # Everything else scales with h; only the coefficients differ per cell type.
    if cell == 'lstm':
        per_hidden_unit = -4*h + 8*se + (8*sn+10*tn)*(1+h) + 1
    else:
        per_hidden_unit = -int(2.5*h) + 6*se + (6*sn+7*tn)*(1+h) + 1

    return h*per_hidden_unit + embed_and_output

In [10]:
# Sweep every hyper-parameter combination and collect one table row per
# configuration: [num_layers, cell, num_embed, num_hidden, bpe_symbols,
# total parameter count, then one count per parameter group].

# bpe_symbols -> (sb, tb) sizes passed to the parameter counters.
bpe_to_vocab = {
    10000: (10004, 10004),
    30000: (30004, 28244),
    50000: (50004, 41355),
}

results = []
for n in (1, 2, 4):  # num_layers
    for cell in ("lstm", "gru"):  # rnn_cell_type
        for e in (256, 512, 1024):  # num_embed
            for h in (256, 512, 1024):  # rnn_num_hidden
                for b in (10000, 30000, 50000):  # bpe_symbols
                    sb, tb = bpe_to_vocab[b]
                    # Same layer count and embedding width on both sides.
                    group_dict, param_dict = get_model_params(cell, sb, tb, n, n, e, e, h)
                    nparam = get_num_params(cell, sb, tb, n, n, e, e, h)
                    res = [n, cell, e, h, b, nparam]
                    res.extend(group_dict.values())
                    results.append(res)

In [11]:
# Render every result row as one pipe-delimited table line and join them
# into a single string.
table_lines = []
for r in results:
    cells = " | ".join(map(str, r))
    table_lines.append("| " + cells + " |\n")
res = "".join(table_lines)

In [12]:
# Emit the assembled pipe-delimited table; the rows that follow are this cell's output.
print(res)

| 1 | lstm | 256 | 256 | 10000 | 9139732 | 131584 | 131328 | 788480 | 395264 | 0 | 7693076 |
| 1 | lstm | 256 | 256 | 30000 | 23616852 | 131584 | 131328 | 788480 | 395264 | 0 | 22170196 |
| 1 | lstm | 256 | 256 | 50000 | 35462795 | 131584 | 131328 | 788480 | 395264 | 0 | 34016139 |
| 1 | lstm | 256 | 512 | 10000 | 14982420 | 525312 | 524800 | 2625536 | 1052672 | 0 | 10254100 |
| 1 | lstm | 256 | 512 | 30000 | 34128980 | 525312 | 524800 | 2625536 | 1052672 | 0 | 29400660 |
| 1 | lstm | 256 | 512 | 50000 | 49331339 | 525312 | 524800 | 2625536 | 1052672 | 0 | 44603019 |
| 1 | lstm | 256 | 1024 | 10000 | 32172820 | 2099200 | 2098176 | 9445376 | 3153920 | 0 | 15376148 |
| 1 | lstm | 256 | 1024 | 30000 | 60658260 | 2099200 | 2098176 | 9445376 | 3153920 | 0 | 43861588 |
| 1 | lstm | 256 | 1024 | 50000 | 82573451 | 2099200 | 2098176 | 9445376 | 3153920 | 0 | 65776779 |
| 1 | lstm | 512 | 256 | 10000 | 14786068 | 131584 | 131328 | 1050624 | 657408 | 0 | 12815124 |
| 1 | lstm | 512 | 256 | 30000