In [1]:
import tensorflow as tf
from tensorflow.keras import models
from tensorflow.python.profiler.model_analyzer import profile
from tensorflow.python.profiler.option_builder import ProfileOptionBuilder
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2_as_graph

In [2]:
# Function to obtain the computational cost of a neural network.
def get_flops(model, batch_size=1):
    """Return the number of multiply-accumulate operations (MACs) of a Keras model.

    The model's forward pass is traced with a fixed batch dimension, its
    variables are frozen into constants, and the TF1 profiler's
    ``float_operation`` view counts floating-point operations on the resulting
    static graph.

    Note: despite the function name, the value returned is the profiler's FLOP
    count halved, because the profiler counts each multiply-accumulate as two
    floating-point operations; the result is therefore a MAC count.

    Args:
        model: A Keras model with symbolic ``model.inputs`` (e.g. a
            Sequential/functional model, such as one loaded from an .h5 file).
        batch_size: Batch dimension used for the traced input signature.
            Defaults to 1, the value previously hard-coded.

    Returns:
        int: Total multiply-accumulate operations for one forward pass over a
        batch of ``batch_size`` samples.

    Raises:
        ValueError: If ``batch_size`` is not positive, or if the model exposes
            no symbolic inputs from which an input signature can be built.
    """
    if batch_size < 1:
        raise ValueError('batch_size must be >= 1, got {}'.format(batch_size))
    model_inputs = getattr(model, 'inputs', None)
    if not model_inputs:
        raise ValueError('model has no symbolic inputs; build or call it before counting FLOPs')

    concrete = tf.function(lambda inputs: model(inputs))
    # One spec per model input: fixed batch dimension, remaining dims taken
    # from the model, and the input's own dtype (TensorSpec defaults to float32).
    input_specs = [
        tf.TensorSpec([batch_size, *inp.shape[1:]], dtype=inp.dtype)
        for inp in model_inputs
    ]
    concrete_func = concrete.get_concrete_function(input_specs)
    # Freeze variables so the profiler sees a self-contained static graph;
    # only the GraphDef is needed here.
    _, graph_def = convert_variables_to_constants_v2_as_graph(concrete_func)

    with tf.Graph().as_default() as graph:
        tf.graph_util.import_graph_def(graph_def, name='')
        run_meta = tf.compat.v1.RunMetadata()
        opts = tf.compat.v1.profiler.ProfileOptionBuilder.float_operation()
        flops = tf.compat.v1.profiler.profile(graph=graph, run_meta=run_meta, cmd="op", options=opts)

    # The profiler counts a multiply-accumulate as two FLOPs; halve it to
    # report the number of multiply-accumulate operations.
    return flops.total_float_ops // 2

In [3]:
# Paths to the saved models (edit here if the notebook is moved).
# NOTE(review): the original model is resolved against a shared models folder,
# while the pruned one is resolved against the notebook's working directory;
# confirm both locations are intended before a fresh Restart & Run All.
ORIGINAL_MODEL_PATH = '../../models/MNIST_model/mnistnetKeras.h5'
PRUNED_MODEL_PATH = 'mnistnetKerasPruned.h5'

# Load the original Keras model and inspect its architecture.
model = models.load_model(ORIGINAL_MODEL_PATH)
model.summary()

# Load the pruned model and inspect its architecture.
model_pruned = models.load_model(PRUNED_MODEL_PATH)
model_pruned.summary()

# The cost comparison below is only meaningful if both models consume the
# same input.
assert model.input_shape == model_pruned.input_shape, (
    'original and pruned models have different input shapes: '
    '{} vs {}'.format(model.input_shape, model_pruned.input_shape))



Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d (Conv2D)              (None, 24, 24, 20)        520       
_________________________________________________________________
activation (Activation)      (None, 24, 24, 20)        0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 20, 20, 50)        25050     
_________________________________________________________________
activation_1 (Activation)    (None, 20, 20, 50)        0         
_________________________________________________________________
flatten (Flatten)            (None, 20000)             0         
_________________________________________________________________
dense (Dense)                (None, 100)               2000100   
_________________________________________________________________
activation_2 (Activation)    (None, 100)               0

In [4]:
# Compare the computational cost of both models. get_flops() returns
# multiply-accumulate operations (MACs), i.e. half the profiler's FLOP count,
# so the values are labelled as MACs rather than FLOPs.
macs_original = get_flops(model)
macs_pruned = get_flops(model_pruned)
print('Nº de MACs del modelo original: {}'.format(macs_original))
print('Nº de MACs del modelo podado: {}'.format(macs_pruned))
# Guard against a zero count so the relative reduction never divides by zero.
if macs_original:
    print('Reducción: {:.2%}'.format(1 - macs_pruned / macs_original))

Instructions for updating:
Use `tf.compat.v1.graph_util.tensor_shape_from_node_def_name`
Nº de FLOPs del modelo original: 12304840
Nº de FLOPs del modelo podado: 10623440
