In [1]:
import sys
# Put the parent directory at the front of the module search path so it takes
# precedence over other entries. Presumably this lets the notebook import the
# project's `model` and `config` modules from one level up — confirm against
# the repository layout. Note that '..' is relative to the current working
# directory, not to this file.
sys.path.insert(0,'..')

In [2]:
from model import EfConfClassifier as Model
from config import *

from torchsummary import summary
from prettytable import PrettyTable

In [3]:
def count_parameters(model):
    """Print a table of trainable parameter sizes and return their total.

    Walks ``model.named_parameters()`` and, for every parameter whose
    ``requires_grad`` flag is truthy, records its name and element count
    (``numel()``). The per-parameter table and the grand total are printed
    to stdout.

    Args:
        model: Any object exposing ``named_parameters()`` that yields
            ``(name, parameter)`` pairs, where each parameter provides
            ``requires_grad`` and ``numel()``.

    Returns:
        The sum of ``numel()`` over all parameters that require gradients.
    """
    table = PrettyTable(["Modules", "Parameters"])

    # Collect (name, element count) for trainable parameters only; frozen
    # parameters are skipped and appear in neither the table nor the total.
    trainable = [
        (name, tensor.numel())
        for name, tensor in model.named_parameters()
        if tensor.requires_grad
    ]

    for name, count in trainable:
        table.add_row([name, count])
    total_params = sum(count for _, count in trainable)

    print(table)
    print(f"Total Trainable Params: {total_params}")
    return total_params

# Build the classifier with the encoder/decoder counts taken from CONFIG,
# then print its per-parameter table and capture the trainable total.
model = Model(
    n_encoders=CONFIG["n_encoders"],
    n_decoders=CONFIG["n_decoders"],
)
n_params = count_parameters(model)

# Report the total in millions of parameters, to three decimal places.
print(f"Model has {n_params/1e+6:.3f}M params")

+----------------------------------------------------+------------+
|                      Modules                       | Parameters |
+----------------------------------------------------+------------+
|              enc_proc.convs.0.weight               |   344064   |
|               enc_proc.convs.0.bias                |     64     |
|               dec_proc.lin.0.weight                |    2432    |
|                dec_proc.lin.0.bias                 |     64     |
|           encoder.lacs.0.lffn1.E1.weight           |   16384    |
|           encoder.lacs.0.lffn1.D1.weight           |   262144   |
|           encoder.lacs.0.lffn1.E2.weight           |   262144   |
|           encoder.lacs.0.lffn1.D2.weight           |   16384    |
|          encoder.lacs.0.mhlsa.W_O.weight           |    4096    |
|       encoder.lacs.0.conv_module.ln1.weight        |     64     |
|        encoder.lacs.0.conv_module.ln1.bias         |     64     |
| encoder.lacs.0.conv_module.pw_conv1.pw_conv.we