In [None]:
import argparse
import importlib
import types

import yaml
import torch
from torchsummary import summary
from thop import profile, clever_format

def import_class(name):
    """Resolve a dotted path such as ``'pkg.module.ClassName'`` to an object.

    The first component is imported as a module and every following
    component is looked up as an attribute of the previous result. If a
    module lacks the attribute, the component is imported as a submodule
    instead, so the lookup also works when the parent package's
    ``__init__`` does not import that submodule itself.

    Args:
        name: Dotted import path. A single component returns the module.

    Returns:
        The module, class, function or other attribute named by ``name``.

    Raises:
        ImportError: If the first component cannot be imported, or a
            submodule exists but fails while being imported.
        AttributeError: If a later component is neither an attribute nor
            an importable submodule of the object preceding it.
        ValueError: If ``name`` is empty.
    """
    components = name.split('.')
    dotted = components[0]
    resolved = importlib.import_module(dotted)
    for comp in components[1:]:
        dotted = f'{dotted}.{comp}'
        try:
            resolved = getattr(resolved, comp)
        except AttributeError:
            # Only modules can have not-yet-loaded submodules; for any other
            # object a missing attribute is a genuine error.
            if not isinstance(resolved, types.ModuleType):
                raise
            try:
                resolved = importlib.import_module(dotted)
            except ModuleNotFoundError as exc:
                # A failure for a *different* module means the submodule
                # exists but one of its own imports is broken: surface that.
                if exc.name != dotted:
                    raise
                raise AttributeError(
                    f'module {resolved.__name__!r} has no attribute or '
                    f'submodule {comp!r}'
                ) from exc
    return resolved

In [None]:
# Change the working directory to its parent (presumably the project root, so
# that project modules can be imported -- confirm).
# NOTE: the path is relative, so re-running this cell climbs one more level;
# restart the kernel before running the notebook again.
%cd ..

## Count Parameters and MACs (multiply-accumulate operations)

In [None]:
# Load the experiment configuration and adapt it for profiling.
# NOTE(review): CONFIG_PATH is an absolute, machine-specific location; point it
# at your own checkout before running this cell.
CONFIG_PATH = r'D:\DATN\project\Pose-based-WLASL\configs\ctr-gcn\config.yaml'

# Decode explicitly as UTF-8 so parsing does not depend on the platform's
# locale encoding (the default when no encoding is given to open()).
with open(CONFIG_PATH, encoding='utf-8') as f:
    # NOTE(review): FullLoader can build more than plain data types; switch to
    # yaml.safe_load if the config may ever come from an untrusted source.
    config_dict = yaml.load(f, Loader=yaml.FullLoader)
# Expose the top-level config keys as attributes (arg.model, arg.model_args).
arg = argparse.Namespace(**config_dict)

num_keypoint = 27  # keypoints per skeleton (V)
T_frame = 150      # frames per sample (T)
# Override the model settings used for this measurement.
arg.model_args['num_class'] = 2000
arg.model_args['graph_args']['layout'] = f'keypoint-{num_keypoint}'
# Last expression: display the resulting configuration.
vars(arg)

In [None]:
# Build the model class named in the config with the configured arguments.
Model = import_class(arg.model)
model = Model(**arg.model_args)

# torchsummary places its dummy input on CUDA by default whenever a GPU is
# visible, which fails if the weights are still on the CPU. Pass the device the
# model actually lives on (falls back to 'cpu' for a parameter-free model).
summary_device = next((p.device.type for p in model.parameters()), 'cpu')

# input_size describes ONE sample, (C, T, V, M); torchsummary prepends its own
# batch dimension N, so no batch size is chosen here.
#   C = 3            channels per keypoint (presumably x, y, confidence -- confirm)
#   T = T_frame      number of frames
#   V = num_keypoint number of keypoints
#   M = 1            number of persons
summary(model, input_size=(3, T_frame, num_keypoint, 1), device=summary_device)

In [None]:
# Count MACs and parameters with thop for a single sample of shape
# (N=1, C=3, T, V, M=1).
# The tensor is named `dummy_input` because a global called `input` would
# shadow the builtin input() for every later cell in this kernel.
model_device = next((p.device for p in model.parameters()), torch.device('cpu'))
dummy_input = torch.randn(1, 3, T_frame, num_keypoint, 1, device=model_device)

# Keep the raw numbers under their own names; `macs` / `params` hold the
# human-readable strings produced by clever_format.
macs_raw, params_raw = profile(model, inputs=(dummy_input, ))
macs, params = clever_format([macs_raw, params_raw], "%.3f")
print("MACs: {}, Params: {}".format(macs, params))