# Calculate the FLOPs and number of parameters

This notebook contains the implementation to calculate the FLOPs and number of parameters reported in our PFM experiments.

In [1]:
from graphs import my_vgg_bn_graph, my_resnet_graph, resnet_graph
from models import my_vgg_bn, my_vgg, my_resnet
from graphs.base_graph import NodeType
import numpy as np
import pandas as pd
import torchvision.models as models

# modify some functions in torchstat to get more detailed information
from torchstat.statistics import ModelHook, convert_leaf_modules_to_stat_tree
from torchstat.reporter import round_value
import torch.nn as nn

def report_format(collected_nodes):
    """Render per-module statistics as a printable summary and a DataFrame.

    Args:
        collected_nodes: iterable of stat nodes. Each node is read for
            ``name``, ``input_shape``, ``output_shape``,
            ``parameter_quantity``, ``inference_memory``, ``MAdd``,
            ``Flops``, ``Memory`` (a two-item read/write pair) and
            ``duration``.

    Returns:
        tuple: ``(summary, df)``. ``df`` holds one row per node plus a final
        row labelled ``'total'``; its 'memory(MB)', 'duration[%]', 'MAdd' and
        'Flops' columns are formatted strings. ``summary`` is ``str(df)``
        followed by the aggregate totals.
    """
    columns = ['module name', 'input shape', 'output shape',
               'params', 'memory(MB)',
               'MAdd', 'duration', 'Flops', 'MemRead(B)', 'MemWrite(B)']
    rows = []
    for node in collected_nodes:
        input_shape = ' '.join('{:>3d}'.format(e) for e in node.input_shape)
        output_shape = ' '.join('{:>3d}'.format(e) for e in node.output_shape)
        mread, mwrite = node.Memory
        rows.append([node.name, input_shape, output_shape,
                     node.parameter_quantity, node.inference_memory,
                     node.MAdd, node.duration, node.Flops, mread, mwrite])
    # Naming the columns at construction keeps the frame well-formed even
    # when no nodes were collected.
    df = pd.DataFrame(rows, columns=columns)

    df['duration[%]'] = df['duration'] / (df['duration'].sum() + 1e-7)
    df['MemR+W(B)'] = df['MemRead(B)'] + df['MemWrite(B)']
    total_parameters_quantity = df['params'].sum()
    total_memory = df['memory(MB)'].sum()
    total_operation_quantity = df['MAdd'].sum()
    total_flops = df['Flops'].sum()
    total_duration = df['duration[%]'].sum()
    total_mread = df['MemRead(B)'].sum()
    total_mwrite = df['MemWrite(B)'].sum()
    total_memrw = df['MemR+W(B)'].sum()
    del df['duration']

    # Add Total row. The memory read/write cells must hold the column sums,
    # not the values left over from the last node visited by the loop above.
    total_row = pd.Series([total_parameters_quantity, total_memory,
                           total_operation_quantity, total_flops,
                           total_duration, total_mread, total_mwrite,
                           total_memrw],
                          index=['params', 'memory(MB)', 'MAdd', 'Flops',
                                 'duration[%]', 'MemRead(B)', 'MemWrite(B)',
                                 'MemR+W(B)'],
                          name='total')
    # DataFrame.append was removed in pandas 2.0; concatenating the Series as
    # a one-row frame produces the same row labelled 'total'.
    df = pd.concat([df, total_row.to_frame().T])

    # The total row has no name/shape entries; show them as blanks.
    df = df.fillna(' ')
    df['memory(MB)'] = df['memory(MB)'].apply(
        lambda x: '{:.2f}'.format(x))
    df['duration[%]'] = df['duration[%]'].apply(lambda x: '{:.2%}'.format(x))
    df['MAdd'] = df['MAdd'].apply(lambda x: '{:,}'.format(x))
    df['Flops'] = df['Flops'].apply(lambda x: '{:,}'.format(x))

    table = str(df)
    rule_width = len(table.split('\n')[0])

    summary = table + '\n'
    summary += "=" * rule_width
    summary += '\n'
    summary += "Total params: {:,}\n".format(total_parameters_quantity)

    summary += "-" * rule_width
    summary += '\n'
    summary += "Total memory: {:.2f}MB\n".format(total_memory)
    summary += "Total MAdd: {}MAdd\n".format(round_value(total_operation_quantity))
    summary += "Total Flops: {}Flops\n".format(round_value(total_flops))
    summary += "Total MemR+W: {}B\n".format(round_value(total_memrw, True))
    return summary, df

class ModelStat(object):
    """Gathers per-module statistics for a model at a fixed input size.

    Args:
        model: an ``nn.Module`` instance.
        input_size: tuple or list with exactly three entries.
        query_granularity: forwarded to the stat tree's
            ``get_collected_stat_nodes`` call.
    """

    def __init__(self, model, input_size, query_granularity=1):
        # Reject unusable arguments up front, before any analysis is attempted.
        assert isinstance(model, nn.Module)
        assert isinstance(input_size, (tuple, list))
        assert len(input_size) == 3
        self._model, self._input_size = model, input_size
        self._query_granularity = query_granularity

    def _analyze_model(self):
        """Return the stat nodes collected at the configured granularity."""
        leaves = ModelHook(self._model, self._input_size).retrieve_leaf_modules()
        tree = convert_leaf_modules_to_stat_tree(leaves)
        return tree.get_collected_stat_nodes(self._query_granularity)

    def show_report(self, print_report=True):
        """Build the report, optionally print it, and return ``(report, df)``."""
        report, df = report_format(self._analyze_model())
        if print_report:
            print(report)
        return report, df


def stat(model, input_size, query_granularity=1, print_report=False):
    """Convenience wrapper: build a ModelStat and return its ``show_report`` result.

    Unlike ``ModelStat.show_report``, printing is off by default here.
    """
    return ModelStat(model, input_size, query_granularity).show_report(print_report)

  from .autonotebook import tqdm as notebook_tqdm


#### VGG16

In [2]:
def doubling_shape(shape_str):
    """Double every integer token in a space-separated shape string.

    The string is split on single spaces, so the empty tokens produced by
    padding or repeated spaces are kept and the original spacing survives.
    Tokens that do not parse as integers are passed through unchanged.

    Args:
        shape_str: shape text such as ``'64 32 32'``.

    Returns:
        str: the same layout with each integer token multiplied by two.
    """
    doubled = []
    for token in shape_str.split(' '):
        try:
            doubled.append(str(int(token) * 2))
        except ValueError:
            # Only a failed integer parse is expected here; anything else
            # (e.g. KeyboardInterrupt) should propagate instead of being hidden.
            doubled.append(token)
    return ' '.join(doubled)

def report_zipit_at_layer_flops(report_df, layer, flops_multiply_factor=2):
    """Estimate params/FLOPs when the first ``layer`` rows are duplicated and then merged.

    Works on a reduced copy of ``report_df`` (module name, shapes, params and
    Flops only). For ``layer > 0`` the first ``layer`` rows get their Flops
    scaled by ``flops_multiply_factor``, their params doubled and their shape
    strings doubled; a 'merge acts' row is inserted after them and the last
    row's params/Flops are recomputed as the sum of all rows above it.

    Args:
        report_df: report table whose last row holds the totals and whose
            'Flops' column may contain comma-formatted strings.
        layer: number of leading rows (by position) to treat as duplicated;
            ``0`` leaves the table values unchanged.
        flops_multiply_factor: multiplier applied to the Flops of the
            duplicated rows. Params are always doubled.

    Returns:
        pandas.DataFrame: the adjusted table (index reset when ``layer > 0``).

    Raises:
        AssertionError: if the merged row's leading input dimension is not
            twice the leading input dimension of the following row.
    """
    report_df = report_df.drop(
        ['memory(MB)', 'MAdd', 'MemRead(B)', 'MemWrite(B)', 'duration[%]', 'MemR+W(B)'],
        axis=1).copy()

    report_df['Flops'] = report_df['Flops'].replace({',': ''}, regex=True).astype(float)

    if layer > 0:
        # Select the duplicated rows by position so the selection does not
        # rely on label slicing over a mixed integer/'total' index.
        duplicated = report_df.index[:layer]
        report_df.loc[duplicated, 'Flops'] *= flops_multiply_factor
        report_df.loc[duplicated, 'params'] *= 2
        # Each shape column is doubled from its own values; previously the
        # input shapes were overwritten with the doubled output shapes.
        report_df.loc[duplicated, 'input shape'] = report_df.loc[duplicated, 'input shape'].apply(doubling_shape)
        report_df.loc[duplicated, 'output shape'] = report_df.loc[duplicated, 'output shape'].apply(doubling_shape)

        # The merge row consumes the (doubled) output of the last duplicated row.
        zip_row = report_df.iloc[layer-1].copy()
        zip_row['module name'] = 'merge acts'
        zip_row['input shape'] = report_df['output shape'].iloc[layer-1]
        input_shape = int(zip_row['input shape'].strip().split(' ')[0])
        if layer < len(report_df) - 1:
            # ...and must hand the next (non-duplicated) row its original input.
            zip_row['output shape'] = report_df['input shape'].iloc[layer]
            output_shape = int(zip_row['output shape'].strip().split(' ')[0])
            assert(input_shape == output_shape * 2)

        zip_row['params'] = 0
        zip_row['Flops'] = input_shape
        report_df = pd.concat([report_df.iloc[:layer], pd.DataFrame([zip_row]), report_df.iloc[layer:]]).reset_index(drop=True)

        # Refresh the totals row from everything above it.
        report_df.loc[report_df.index[-1], 'params'] = report_df['params'].iloc[:-1].sum()
        report_df.loc[report_df.index[-1], 'Flops'] = report_df['Flops'].iloc[:-1].sum()

    return report_df


# Baseline (unmerged) VGG16 statistics for a 3x32x32 input.
model_new = my_vgg.my_vgg16().to('cpu')
num_features = len(model_new.features)

report, report_df = stat(model_new, (3, 32, 32))

# Candidate merge points: every Conv2d position inside `features`,
# then the classifier row, then the full ensemble (nothing merged).
merge_act_idx_s = [
    idx for idx in range(num_features)
    if 'Conv2d' in model_new.features[idx].__class__.__name__
]
merge_act_idx_s += [len(report_df) - 2, len(report_df) - 1]

report_df_zip_at_layer_s = [
    report_zipit_at_layer_flops(report_df, layer) for layer in merge_act_idx_s
]

# The last row of each adjusted table carries the totals.
param_s = np.array([df.iloc[-1]['params'] for df in report_df_zip_at_layer_s])
flops_s = np.array([df.iloc[-1]['Flops'] for df in report_df_zip_at_layer_s])

# Overhead relative to the first merge point. The variable spelling is kept
# because the following cell reads these exact names.
additioanl_param_pct = (param_s - param_s[0]) / param_s[0]
additioanl_flops_pct = (flops_s - flops_s[0]) / flops_s[0]

  df = df.append(total_df)
  report_df.loc[:layer-1, 'Flops'] *= flops_multiply_factor
  report_df.loc[:layer-1, 'params'] *= 2
  report_df.loc[:layer-1, 'input shape'] = report_df.loc[:layer-1, 'output shape'].apply(doubling_shape)
  report_df.loc[:layer-1, 'output shape'] = report_df.loc[:layer-1, 'output shape'].apply(doubling_shape)
  report_df.loc[:layer-1, 'Flops'] *= flops_multiply_factor
  report_df.loc[:layer-1, 'params'] *= 2
  report_df.loc[:layer-1, 'input shape'] = report_df.loc[:layer-1, 'output shape'].apply(doubling_shape)
  report_df.loc[:layer-1, 'output shape'] = report_df.loc[:layer-1, 'output shape'].apply(doubling_shape)
  report_df.loc[:layer-1, 'Flops'] *= flops_multiply_factor
  report_df.loc[:layer-1, 'params'] *= 2
  report_df.loc[:layer-1, 'input shape'] = report_df.loc[:layer-1, 'output shape'].apply(doubling_shape)
  report_df.loc[:layer-1, 'output shape'] = report_df.loc[:layer-1, 'output shape'].apply(doubling_shape)
  report_df.loc[:layer-1, 'Flops'] *=

In [3]:
# Relative parameter / FLOP overhead per merge point, each measured against
# the first entry of param_s / flops_s.
print(f"additional params: {additioanl_param_pct}")
print(f"additional flops: {additioanl_flops_pct}")

additional params: [0.00000000e+00 1.21740636e-04 2.63046731e-03 7.64792065e-03
 1.76741316e-02 3.77265534e-02 7.78140056e-02 1.17901458e-01
 1.98076362e-01 3.58391388e-01 5.18706413e-01 6.79021439e-01
 8.39336465e-01 9.99651490e-01 1.00000000e+00]
additional flops: [0.         0.00605542 0.12694684 0.18728856 0.30786678 0.36810451
 0.48847395 0.60889558 0.66908193 0.78939916 0.9097425  0.93982181
 0.96990112 0.99998695 1.00000006]


In [4]:
# Display the VGG16 per-module report table.
report_df

Unnamed: 0,module name,input shape,output shape,params,memory(MB),MAdd,Flops,MemRead(B),MemWrite(B),duration[%],MemR+W(B)
0,features.0,3 32 32,64 32 32,1792.0,0.25,3538944.0,1835008.0,19456.0,262144.0,74.08%,281600.0
1,features.1,64 32 32,64 32 32,0.0,0.25,65536.0,65536.0,262144.0,262144.0,0.55%,524288.0
2,features.2,64 32 32,64 32 32,36928.0,0.25,75497472.0,37814272.0,409856.0,262144.0,6.74%,672000.0
3,features.3,64 32 32,64 32 32,0.0,0.25,65536.0,65536.0,262144.0,262144.0,0.18%,524288.0
4,features.4,64 32 32,64 16 16,0.0,0.06,49152.0,65536.0,262144.0,65536.0,0.47%,327680.0
5,features.5,64 16 16,128 16 16,73856.0,0.12,37748736.0,18907136.0,360960.0,131072.0,0.98%,492032.0
6,features.6,128 16 16,128 16 16,0.0,0.12,32768.0,32768.0,131072.0,131072.0,0.13%,262144.0
7,features.7,128 16 16,128 16 16,147584.0,0.12,75497472.0,37781504.0,721408.0,131072.0,2.64%,852480.0
8,features.8,128 16 16,128 16 16,0.0,0.12,32768.0,32768.0,131072.0,131072.0,0.13%,262144.0
9,features.9,128 16 16,128 8 8,0.0,0.03,24576.0,32768.0,131072.0,32768.0,0.17%,163840.0


#### VGG16-BN

In [5]:
# for vgg16 bn
def calculate_param_num(params, M, num_layers=14, factor=2):
    """Parameter count when the first ``M`` layers are kept duplicated.

    Args:
        params: per-entry parameter counts supporting slicing, ``.sum()`` and
            multiplication (e.g. a pandas Series or numpy array).
        M: number of leading layers whose parameters are counted twice.
        num_layers: value of ``M`` that means "everything is duplicated";
            defaults to 14, the value previously hard-coded here.
        factor: number of consecutive ``params`` entries that belong to one
            layer; defaults to 2, the value previously hard-coded here.

    Returns:
        tuple: ``(count, ratio)`` where ``ratio`` is ``count`` divided by the
        plain sum of ``params``.
    """
    total = params.sum()
    if M == num_layers:
        cur = 2 * total
    elif M == 0:
        cur = total
    else:
        # The first `factor * M` entries are doubled, the rest counted once.
        split = factor * M
        cur = (params[:split] * 2).sum() + params[split:].sum()
    return cur, cur / total

def calculate_flop_num(flops, output_shape_s, preserve_until, M, num_layers=14):
    """FLOP count when modules up to merge point ``M`` are kept duplicated.

    Args:
        flops: per-module FLOPs supporting slicing, ``.sum()`` and
            multiplication (e.g. a pandas Series or numpy array).
        output_shape_s: per-module leading output dimension, indexable by
            module position.
        preserve_until: maps each merge point ``M`` to the position of the
            last duplicated module.
        M: merge point index into ``preserve_until``; ``0`` means no
            duplication.
        num_layers: value of ``M`` that means "everything is duplicated";
            defaults to 14, the value previously hard-coded here.

    Returns:
        tuple: ``(count / 1e9, ratio)`` where ``ratio`` is ``count`` divided
        by the plain sum of ``flops``.
    """
    total = flops.sum()

    idx = preserve_until[M]
    # Cost charged for the merge itself: twice the leading output dimension
    # of the last duplicated module.
    merge_flops = output_shape_s[idx] * 2
    if M == num_layers:
        cur = 2 * total + merge_flops
    elif M == 0:
        cur = total
    else:
        # Modules up to and including `idx` are doubled, the rest counted once.
        cut = idx + 1
        cur = (flops[:cut] * 2).sum() + merge_flops + flops[cut:].sum()
    return cur / 1e9, cur / total

Number of parameters

In [6]:
# VGG16-BN statistics for a 3x32x32 input.
model_new = my_vgg_bn.my_vgg16_bn().to('cpu')
num_features = len(model_new.features)
report, report_df = stat(model_new, (3, 32, 32))

# Per-module parameter counts without the trailing 'total' row; modules with
# no parameters are dropped and the index renumbered from zero.
params = report_df['params'][:-1]
params = params[params != 0].reset_index(drop=True)

for merged_layers in (0, 3, 7, 10, 12, 14):
    print(calculate_param_num(params, merged_layers))

(14728266.0, 1.0)
(14841354.0, 1.0076782969563423)
(16466058.0, 1.1179902644343875)
(22368906.0, 1.5187739004713794)
(27090570.0, 1.8393590935959467)
(29456532.0, 2.0)


  df = df.append(total_df)


FLOPs

In [7]:
# Positions inside `features` whose module class name contains 'ReLU'.
relu_idx = [
    i for i, m in enumerate(model_new.features)
    if 'ReLU' in m.__class__.__name__
]
relu_idx

[2, 5, 9, 12, 16, 19, 22, 26, 29, 32, 36, 39, 42]

In [8]:
# Per-module FLOPs as floats: drop the trailing 'total' row and strip the
# thousands separators added by the report formatter.
flops = report_df['Flops'][:-1].replace({',': ''}, regex=True).astype(float)

# First entry of every output-shape string; the blank entry coming from the
# 'total' row is filtered out before the integer cast.
output_shape_s = report_df['output shape'].apply(lambda x: x.strip().split(' ')[0]).to_numpy()
output_shape_s = output_shape_s[output_shape_s != ''].astype(int)

# Merge point M -> position of the last duplicated module.
preserve_until = [0] + relu_idx + [44]

for merged_layers in (0, 3, 7, 10, 12, 14):
    print(calculate_flop_num(flops, output_shape_s, preserve_until, merged_layers))

(0.314432512, 1.0)
(0.37354624, 1.1880013222042383)
(0.50602752, 1.6093358691864537)
(0.60051456, 1.9098360922677105)
(0.619413504, 1.9699410218749898)
(0.628865044, 2.000000063606654)


In [9]:
# Display the VGG16-BN per-module report table.
report_df

Unnamed: 0,module name,input shape,output shape,params,memory(MB),MAdd,Flops,MemRead(B),MemWrite(B),duration[%],MemR+W(B)
0,features.0,3 32 32,64 32 32,1792.0,0.25,3538944.0,1835008.0,19456.0,262144.0,9.14%,281600.0
1,features.1,64 32 32,64 32 32,128.0,0.25,262144.0,131072.0,262656.0,262144.0,2.76%,524800.0
2,features.2,64 32 32,64 32 32,0.0,0.25,65536.0,65536.0,262144.0,262144.0,1.43%,524288.0
3,features.3,64 32 32,64 32 32,36928.0,0.25,75497472.0,37814272.0,409856.0,262144.0,8.29%,672000.0
4,features.4,64 32 32,64 32 32,128.0,0.25,262144.0,131072.0,262656.0,262144.0,0.88%,524800.0
5,features.5,64 32 32,64 32 32,0.0,0.25,65536.0,65536.0,262144.0,262144.0,1.19%,524288.0
6,features.6,64 32 32,64 16 16,0.0,0.06,49152.0,65536.0,262144.0,65536.0,1.86%,327680.0
7,features.7,64 16 16,128 16 16,73856.0,0.12,37748736.0,18907136.0,360960.0,131072.0,3.56%,492032.0
8,features.8,128 16 16,128 16 16,256.0,0.12,131072.0,65536.0,132096.0,131072.0,0.43%,263168.0
9,features.9,128 16 16,128 16 16,0.0,0.12,32768.0,32768.0,131072.0,131072.0,1.01%,262144.0


#### ResNet20

In [10]:
def calculate_num_flops(graph, report_df, prefix_idx_s, M,
                        skip_keywords=('norm', 'shortcut'), flops_unit=1e6):
    """Params/FLOPs when every graph node up to the M-th prefix node is duplicated.

    Args:
        graph: object exposing ``graph.G.nodes``; iterating it yields node ids
            and ``graph.G.nodes[node]['layer']`` gives the module name (or a
            value absent from the report for nodes without a module).
        report_df: report table with 'module name', 'params' and 'Flops'
            columns and a row labelled ``'total'``.
        prefix_idx_s: node ids of the prefix nodes, in order.
        M: 1-based position into ``prefix_idx_s``.
        skip_keywords: substrings of module names that are not counted as
            layers; defaults to the values previously hard-coded here.
        flops_unit: divisor for the returned FLOP count; defaults to 1e6.

    Returns:
        tuple: ``(params / 1e6, params / total_params, flops / flops_unit,
        flops / total_flops, duplicated_layer_count, total_layer_count)``.
    """
    prefix_id = prefix_idx_s[M - 1]

    total_params = report_df['params']['total']
    total_flops = report_df['Flops']['total']

    # One pass over the report instead of several boolean-mask scans per
    # node. setdefault keeps the first row for a repeated module name, which
    # is what the previous `.values[0]` lookup returned.
    stats_by_name = {}
    for name, n_params, n_flops in zip(report_df['module name'].to_numpy(),
                                       report_df['params'].to_numpy(),
                                       report_df['Flops'].to_numpy()):
        stats_by_name.setdefault(name, (n_params, n_flops))

    sum_param = 0
    sum_flops = 0
    total_layer = 0
    num_layers = 0
    for node in graph.G.nodes:
        node_name = graph.G.nodes[node]['layer']
        if node_name not in stats_by_name:
            continue
        n_params, n_flops = stats_by_name[node_name]

        # A "layer" is a parameterised module that is not on the skip list.
        counts_as_layer = n_params > 0 and not any(
            keyword in node_name for keyword in skip_keywords)
        if counts_as_layer:
            total_layer += 1

        if node <= prefix_id:
            # Nodes up to the prefix node exist once per model, so count twice.
            sum_param += n_params * 2
            sum_flops += n_flops * 2
            if counts_as_layer:
                num_layers += 1
        else:
            sum_param += n_params
            sum_flops += n_flops
    return (sum_param / 1e6, sum_param / total_params,
            sum_flops / flops_unit, sum_flops / total_flops,
            num_layers, total_layer)

In [11]:
# Instantiate the CIFAR ResNet20 model and build its graph representation.
model_new = my_resnet.ResNet.get_model_from_name('cifar_resnet20')
graph = my_resnet_graph.my_resnet20(model_new).graphify()

# Collect per-module stats for a 3x32x32 input, then strip the thousands
# separators from the formatted 'Flops' strings and cast them to float.
report, report_df = stat(model_new, (3, 32, 32))
report_df['Flops'] = report_df['Flops'].replace({',': ''}, regex=True).astype(float)

[MAdd]: Sequential is not supported!
[Flops]: Sequential is not supported!
[Memory]: Sequential is not supported!
[MAdd]: Sequential is not supported!
[Flops]: Sequential is not supported!
[Memory]: Sequential is not supported!
[MAdd]: Sequential is not supported!
[Flops]: Sequential is not supported!
[Memory]: Sequential is not supported!
[MAdd]: Sequential is not supported!
[Flops]: Sequential is not supported!
[Memory]: Sequential is not supported!
[MAdd]: Sequential is not supported!
[Flops]: Sequential is not supported!
[Memory]: Sequential is not supported!
[MAdd]: Sequential is not supported!
[Flops]: Sequential is not supported!
[Memory]: Sequential is not supported!
[MAdd]: Sequential is not supported!
[Flops]: Sequential is not supported!
[Memory]: Sequential is not supported!
[MAdd]: Sequential is not supported!
[Flops]: Sequential is not supported!
[Memory]: Sequential is not supported!
[MAdd]: Sequential is not supported!
[Flops]: Sequential is not supported!
[Memory]: Seq

  df = df.append(total_df)


In [12]:
# Display the ResNet20 per-module report table.
report_df

Unnamed: 0,module name,input shape,output shape,params,memory(MB),MAdd,Flops,MemRead(B),MemWrite(B),duration[%],MemR+W(B)
0,conv1,3 32 32,16 32 32,432.0,0.06,868352.0,442368.0,14016.0,65536.0,4.87%,79552.0
1,norm1,16 32 32,16 32 32,32.0,0.06,65536.0,32768.0,65664.0,65536.0,1.38%,131200.0
2,avg_pool,64 8 8,64 1 1,0.0,0.0,0.0,0.0,0.0,0.0,1.31%,0.0
3,segments.0.0.conv1,16 32 32,16 32 32,2304.0,0.06,4702208.0,2359296.0,74752.0,65536.0,2.09%,140288.0
4,segments.0.0.norm1,16 32 32,16 32 32,32.0,0.06,65536.0,32768.0,65664.0,65536.0,0.99%,131200.0
5,segments.0.0.conv2,16 32 32,16 32 32,2304.0,0.06,4702208.0,2359296.0,74752.0,65536.0,1.75%,140288.0
6,segments.0.0.norm2,16 32 32,16 32 32,32.0,0.06,65536.0,32768.0,65664.0,65536.0,0.91%,131200.0
7,segments.0.0.shortcut,16 32 32,16 32 32,0.0,0.06,0.0,0.0,0.0,0.0,0.03%,0.0
8,segments.0.1.conv1,16 32 32,16 32 32,2304.0,0.06,4702208.0,2359296.0,74752.0,65536.0,2.00%,140288.0
9,segments.0.1.norm1,16 32 32,16 32 32,32.0,0.06,65536.0,32768.0,65664.0,65536.0,0.97%,131200.0


In [13]:
# Collect every PREFIX node id together with the name of the layer that
# precedes it in the graph.
prefix_idx_s = []
util_layer_name_s = []
for node in graph.G.nodes:
    if graph.G.nodes[node]['type'] != NodeType.PREFIX:
        continue
    prefix_idx_s.append(node)
    # Step back one node; if that node carries no layer, step back one more.
    preceding_layer = graph.G.nodes[node-1]['layer']
    if preceding_layer is None:
        preceding_layer = graph.G.nodes[node-2]['layer']
    util_layer_name_s.append(preceding_layer)
print(prefix_idx_s)
print(util_layer_name_s)

[5, 11, 17, 21, 26, 32, 38, 42, 47, 53, 59, 63]
['segments.0.0.norm1', 'segments.0.1.norm1', 'segments.0.2.norm1', 'segments.0.2.norm2', 'segments.1.0.norm1', 'segments.1.1.norm1', 'segments.1.2.norm1', 'segments.1.2.norm2', 'segments.2.0.norm1', 'segments.2.1.norm1', 'segments.2.2.norm1', 'segments.2.2.norm2']


In [14]:
# Report params/FLOPs for a few prefix positions (1-based into prefix_idx_s).
for prefix_position in (3, 6, 8, 12):
    print(calculate_num_flops(graph, report_df, prefix_idx_s, prefix_position))

(0.284618, 1.044569390106946, 53.650048, 1.301724593076161, 6, 20)
(0.310762, 1.1405198294149166, 62.13696, 1.5076446710912486, 10, 20)
(0.338602, 1.2426947158260973, 69.264, 1.6805698331309455, 13, 20)
(0.544298, 1.9976144512870952, 82.428544, 1.9999844715192134, 19, 20)


In [15]:
# Baseline totals read from the report's 'total' row, both printed in millions.
total_params = report_df['params']['total']
total_flops = report_df['Flops']['total']
print(total_params/1e6, total_flops/1e6)    

0.272474 41.214592


#### ResNet50

In [16]:
def calculate_num_flops(graph, report_df, prefix_idx_s, M,
                        skip_keywords=('bn', 'downsample'), flops_unit=1e9):
    """Params/FLOPs when every graph node up to the M-th prefix node is duplicated.

    Args:
        graph: object exposing ``graph.G.nodes``; iterating it yields node ids
            and ``graph.G.nodes[node]['layer']`` gives the module name (or a
            value absent from the report for nodes without a module).
        report_df: report table with 'module name', 'params' and 'Flops'
            columns and a row labelled ``'total'``.
        prefix_idx_s: node ids of the prefix nodes, in order.
        M: 1-based position into ``prefix_idx_s``.
        skip_keywords: substrings of module names that are not counted as
            layers; defaults to the values previously hard-coded here.
        flops_unit: divisor for the returned FLOP count; defaults to 1e9.

    Returns:
        tuple: ``(params / 1e6, params / total_params, flops / flops_unit,
        flops / total_flops, duplicated_layer_count, total_layer_count)``.
    """
    prefix_id = prefix_idx_s[M - 1]

    total_params = report_df['params']['total']
    total_flops = report_df['Flops']['total']

    # One pass over the report instead of several boolean-mask scans per
    # node. setdefault keeps the first row for a repeated module name, which
    # is what the previous `.values[0]` lookup returned.
    stats_by_name = {}
    for name, n_params, n_flops in zip(report_df['module name'].to_numpy(),
                                       report_df['params'].to_numpy(),
                                       report_df['Flops'].to_numpy()):
        stats_by_name.setdefault(name, (n_params, n_flops))

    sum_param = 0
    sum_flops = 0
    total_layer = 0
    num_layers = 0
    for node in graph.G.nodes:
        node_name = graph.G.nodes[node]['layer']
        if node_name not in stats_by_name:
            continue
        n_params, n_flops = stats_by_name[node_name]

        # A "layer" is a parameterised module that is not on the skip list.
        counts_as_layer = n_params > 0 and not any(
            keyword in node_name for keyword in skip_keywords)
        if counts_as_layer:
            total_layer += 1

        if node <= prefix_id:
            # Nodes up to the prefix node exist once per model, so count twice.
            sum_param += n_params * 2
            sum_flops += n_flops * 2
            if counts_as_layer:
                num_layers += 1
        else:
            sum_param += n_params
            sum_flops += n_flops
    return (sum_param / 1e6, sum_param / total_params,
            sum_flops / flops_unit, sum_flops / total_flops,
            num_layers, total_layer)

In [17]:
# Instantiate torchvision's ResNet50 and build its graph representation.
model_new = models.resnet50()
graph = resnet_graph.resnet50(model_new).graphify()

# Collect per-module stats for a 3x32x32 input, then strip the thousands
# separators from the formatted 'Flops' strings and cast them to float.
report, report_df = stat(model_new, (3, 32, 32))
report_df['Flops'] = report_df['Flops'].replace({',': ''}, regex=True).astype(float)

[MAdd]: Sequential is not supported!
[Flops]: Sequential is not supported!
[Memory]: Sequential is not supported!
[MAdd]: Sequential is not supported!
[Flops]: Sequential is not supported!
[Memory]: Sequential is not supported!
[MAdd]: Sequential is not supported!
[Flops]: Sequential is not supported!
[Memory]: Sequential is not supported!
[MAdd]: Sequential is not supported!
[Flops]: Sequential is not supported!
[Memory]: Sequential is not supported!
[MAdd]: Sequential is not supported!
[Flops]: Sequential is not supported!
[Memory]: Sequential is not supported!
[MAdd]: Sequential is not supported!
[Flops]: Sequential is not supported!
[Memory]: Sequential is not supported!
[MAdd]: Sequential is not supported!
[Flops]: Sequential is not supported!
[Memory]: Sequential is not supported!
[MAdd]: Sequential is not supported!
[Flops]: Sequential is not supported!
[Memory]: Sequential is not supported!
[MAdd]: AdaptiveAvgPool2d is not supported!
[Flops]: AdaptiveAvgPool2d is not supported!

  df = df.append(total_df)


In [18]:
# Collect every PREFIX node id together with the name of the layer that
# precedes it in the graph.
prefix_idx_s = []
util_layer_name_s = []
for node in graph.G.nodes:
    if graph.G.nodes[node]['type'] != NodeType.PREFIX:
        continue
    prefix_idx_s.append(node)
    # Step back one node; if that node carries no layer, step back one more.
    preceding_layer = graph.G.nodes[node-1]['layer']
    if preceding_layer is None:
        preceding_layer = graph.G.nodes[node-2]['layer']
    util_layer_name_s.append(preceding_layer)
print(prefix_idx_s)
print(util_layer_name_s)

[3, 8, 11, 17, 20, 26, 29, 33, 38, 41, 47, 50, 56, 59, 65, 68, 72, 77, 80, 86, 89, 95, 98, 104, 107, 113, 116, 122, 125, 129, 134, 137, 143, 146, 152, 155, 159]
['bn1', 'layer1.0.bn1', 'layer1.0.bn2', 'layer1.1.bn1', 'layer1.1.bn2', 'layer1.2.bn1', 'layer1.2.bn2', 'layer1.2.bn3', 'layer2.0.bn1', 'layer2.0.bn2', 'layer2.1.bn1', 'layer2.1.bn2', 'layer2.2.bn1', 'layer2.2.bn2', 'layer2.3.bn1', 'layer2.3.bn2', 'layer2.3.bn3', 'layer3.0.bn1', 'layer3.0.bn2', 'layer3.1.bn1', 'layer3.1.bn2', 'layer3.2.bn1', 'layer3.2.bn2', 'layer3.3.bn1', 'layer3.3.bn2', 'layer3.4.bn1', 'layer3.4.bn2', 'layer3.5.bn1', 'layer3.5.bn2', 'layer3.5.bn3', 'layer4.0.bn1', 'layer4.0.bn2', 'layer4.1.bn1', 'layer4.1.bn2', 'layer4.2.bn1', 'layer4.2.bn2', 'layer4.2.bn3']


In [19]:
# Display the ResNet50 per-module report table.
report_df

Unnamed: 0,module name,input shape,output shape,params,memory(MB),MAdd,Flops,MemRead(B),MemWrite(B),duration[%],MemR+W(B)
0,conv1,3 32 32,64 16 16,9408.0,0.06,4800512.0,2408448.0,49920.0,65536.0,4.60%,115456.0
1,bn1,64 16 16,64 16 16,128.0,0.06,65536.0,32768.0,66048.0,65536.0,0.97%,131584.0
2,relu,64 16 16,64 16 16,0.0,0.06,16384.0,16384.0,65536.0,65536.0,0.47%,131072.0
3,maxpool,64 16 16,64 8 8,0.0,0.02,32768.0,16384.0,65536.0,16384.0,0.63%,81920.0
4,layer1.0.conv1,64 8 8,64 8 8,4096.0,0.02,520192.0,262144.0,32768.0,16384.0,0.93%,49152.0
5,layer1.0.bn1,64 8 8,64 8 8,128.0,0.02,16384.0,8192.0,16896.0,16384.0,0.61%,33280.0
6,layer1.0.conv2,64 8 8,64 8 8,36864.0,0.02,4714496.0,2359296.0,163840.0,16384.0,0.92%,180224.0
7,layer1.0.bn2,64 8 8,64 8 8,128.0,0.02,16384.0,8192.0,16896.0,16384.0,0.52%,33280.0
8,layer1.0.conv3,64 8 8,256 8 8,16384.0,0.06,2080768.0,1048576.0,81920.0,65536.0,0.73%,147456.0
9,layer1.0.bn3,256 8 8,256 8 8,512.0,0.06,65536.0,32768.0,67584.0,65536.0,0.59%,133120.0


In [20]:
import numpy as np
import torch

# Index 0 is the unmerged baseline: ratio 1 and no duplicated layers.
preserverd_num = [0]
param_ratio = [1]
flops_ratio = [1]
for i in range(1, 37):
    arr = np.array(calculate_num_flops(graph, report_df, prefix_idx_s, i))
    # Tuple positions: 1 = param ratio, 3 = FLOP ratio, 4 = duplicated layers.
    ratio_params, ratio_flops, kept_layers = arr[1], arr[3], arr[4]
    param_ratio.append(ratio_params)
    flops_ratio.append(ratio_flops)
    preserverd_num.append(kept_layers)
    print(f"i: {i}", arr[[1, 3, 4]])

# Closing entry: the full two-model ensemble (all 50 layers, exactly 2x cost).
preserverd_num.append(50)
param_ratio.append(2)
flops_ratio.append(2)

param_ratio = np.array(param_ratio)
flops_ratio = np.array(flops_ratio)
preserverd_num = np.array(preserverd_num)
res_dict = {
    'param_ratio': param_ratio,
    'flops_ratio': flops_ratio,
    'preserverd_num': preserverd_num,
}
# torch.save(res_dict, 'resnet50_pfm_details.pt')

i: 1 [1.00037313 1.02667746 1.        ]
i: 2 [1.00119951 1.04238407 2.        ]
i: 3 [1.00264694 1.06989446 3.        ]
i: 4 [1.00395414 1.09473947 5.        ]
i: 5 [1.00540157 1.12224985 6.        ]
i: 6 [1.00670876 1.14709487 8.        ]
i: 7 [1.00815619 1.17460525 9.        ]
i: 8 [ 1.0088173   1.18717055 10.        ]
i: 9 [ 1.01527814  1.23628943 11.        ]
i: 10 [ 1.02105784  1.26375222 12.        ]
i: 11 [ 1.02623654  1.28835925 14.        ]
i: 12 [ 1.03201624  1.31582204 15.        ]
i: 13 [ 1.03719493  1.34042908 17.        ]
i: 14 [ 1.04297463  1.36789186 18.        ]
i: 15 [ 1.04815332  1.3924989  20.        ]
i: 16 [ 1.05393302  1.41996169 21.        ]
i: 17 [ 1.05653739  1.4323366  22.        ]
i: 18 [ 1.0822806  1.4812651 23.       ]
i: 19 [ 1.10537937  1.50870408 24.        ]
i: 20 [ 1.12599397  1.53319213 26.        ]
i: 21 [ 1.14909274  1.56063112 27.        ]
i: 22 [ 1.16970734  1.58511917 29.        ]
i: 23 [ 1.19280611  1.61255816 30.        ]
i: 24 [ 1.21342071  1

In [21]:
# params, 40, 80
# flops 40, 80 

In [22]:
# Baseline totals read from the report's 'total' row: parameters printed in
# millions, FLOPs in billions.
total_params = report_df['params']['total']
total_flops = report_df['Flops']['total']
print(total_params/1e6, total_flops/1e9)

25.557032 0.086057984
