# Pruning
imports

In [20]:
import os
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_hub as hub
import tensorflow_model_optimization as tfmot
import numpy as np
import tempfile
from huggingface_hub import from_pretrained_keras

from tensorflow.python.framework import graph_util
from tensorflow.python.framework import graph_io
from tensorflow.python.tools import optimize_for_inference_lib
from tensorflow.python.platform import gfile

from src.models_download.model_download import download_tf_od_zoo
from src.models_download.model_download import download_tf_classification_zoo

paths

In [None]:
root_path = os.path.abspath(os.path.join(os.getcwd(), os.pardir))

In [None]:
object_detection_models = os.path.join(root_path, "pretrained_models", "object_detection")
classification_models = os.path.join(root_path, "pretrained_models", "classification")

download models

In [None]:
os.makedirs(object_detection_models)
download_tf_od_zoo(object_detection_models)

In [None]:
os.makedirs(classification_models)
download_tf_classification_zoo(classification_models)

prune test

In [None]:
frozen_graph_path = os.path.join(object_detection_models, "ssd_mobilenet_v2_320x320_coco17_tpu-8", "saved_model", "saved_model.pb")
saved_model_path = os.path.join(object_detection_models, "ssd_mobilenet_v2_320x320_coco17_tpu-8", "saved_model")

In [64]:
# model = tf.keras.applications.MobileNet(input_shape=(32,32,3), include_top=False)
model = tf.keras.applications.ResNet50(input_shape=(32,32,3), include_top=False)


cifar100 = tf.keras.datasets.cifar100
(train_images, train_labels), (test_images, test_labels) = cifar100.load_data()

train_images = tf.keras.applications.resnet50.preprocess_input(train_images)
test_images = tf.keras.applications.resnet50.preprocess_input(test_images)

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5


In [75]:
dataset, info = tfds.load('imagenette', split='train', with_info=True)
print(type(dataset))
train_dataset = dataset['train']
test_dataset = dataset['train']

def preprocess(image, label):
    image = tf.cast(image, tf.float32) / 255.0
    label = tf.one_hot(label, 1000)
    return image, label

train_dataset = train_dataset.map(preprocess)
test_dataset = test_dataset.map(preprocess)

train_dataset = train_dataset.shuffle(10000).batch(32)
test_dataset = test_dataset.batch(32)

train_images, train_labels = next(iter(train_dataset))
test_images, test_labels = next(iter(test_dataset))


train_images = tf.keras.applications.resnet50.preprocess_input(train_images)
test_images = tf.keras.applications.resnet50.preprocess_input(test_images)

<class 'tensorflow.python.data.ops.dataset_ops.PrefetchDataset'>


TypeError: 'PrefetchDataset' object is not subscriptable

In [65]:
# Train the digit classification model
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.fit(
  train_images,
  train_labels,
  epochs=4,
  validation_split=0.1,
)

Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4


<keras.callbacks.History at 0x7fbde06ca5b0>

In [66]:
_, baseline_model_accuracy = model.evaluate(
    test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)

Baseline test accuracy: 0.011537354439496994


In [67]:
# Compute end step to finish pruning after 2 epochs.
batch_size = 128
epochs = 2
validation_split = 0.1 # 10% of training set will be used for validation set.

num_images = train_images.shape[0] * (1 - validation_split)
end_step = np.ceil(num_images / batch_size).astype(np.int32) * epochs

# Define model for pruning.
pruning_params = {
      'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.50,
                                                               final_sparsity=0.80,
                                                               begin_step=0,
                                                               end_step=end_step)
}

model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(model, **pruning_params)

model_for_pruning.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model_for_pruning.summary()

Model: "resnet50"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_10 (InputLayer)          [(None, 32, 32, 3)]  0           []                               
                                                                                                  
 prune_low_magnitude_conv1_pad   (None, 38, 38, 3)   1           ['input_10[0][0]']               
 (PruneLowMagnitude)                                                                              
                                                                                                  
 prune_low_magnitude_conv1_conv  (None, 16, 16, 64)  18882       ['prune_low_magnitude_conv1_pad[0
  (PruneLowMagnitude)                                            ][0]']                           
                                                                                           

In [None]:
# os.environ["TF_DISABLE_JIT"] = "1"

logdir = tempfile.mkdtemp()

callbacks = [
  tfmot.sparsity.keras.UpdatePruningStep(),
  tfmot.sparsity.keras.PruningSummaries(log_dir=logdir),
]

model_for_pruning.fit(train_images, train_labels,
                  batch_size=batch_size, epochs=epochs, validation_split=validation_split,
                  callbacks=callbacks)

In [27]:
model = tf.keras.applications.MobileNetV2()

In [28]:
(train_ds, val_ds), info = tfds.load('imagenet2012', split=['train', 'validation'], with_info=True)

In [29]:
def preprocess(data):
    image = data['image']
    label = data['label']
    # Resize image to 224x224
    image = tf.image.resize(image, (224, 224))
    # Convert image to float32
    image = tf.cast(image, tf.float32)
    # Normalize image
    image = tf.keras.applications.mobilenet_v2.preprocess_input(image)
    return image, label


# train_ds = train_ds.map(preprocess).shuffle(10000).batch(32)
train_ds = train_ds.map(preprocess).batch(32)
val_ds = val_ds.map(preprocess).batch(32)

In [30]:
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

_, baseline_model_accuracy = model.evaluate(val_ds, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)

  output, from_logits = _get_logits(


Baseline test accuracy: 0.7064200043678284


In [None]:
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

with open('model.tflite', 'wb') as f:
    f.write(tflite_model)

In [31]:
# Compute end step to finish pruning after 2 epochs.
batch_size = 128
epochs = 2
validation_split = 0.1 # 10% of training set will be used for validation set.

num_images = train_ds.__len__().numpy() * (1 - validation_split)
end_step = np.ceil(num_images / batch_size).astype(np.int32) * epochs

# Define model for pruning.
pruning_params = {
      'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.50,
                                                               final_sparsity=0.80,
                                                               begin_step=0,
                                                               end_step=end_step)
}

model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(model, **pruning_params)

model_for_pruning.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy']
)

model_for_pruning.summary()

Model: "mobilenetv2_1.00_224"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_3 (InputLayer)           [(None, 224, 224, 3  0           []                               
                                )]                                                                
                                                                                                  
 prune_low_magnitude_Conv1 (Pru  (None, 112, 112, 32  1730       ['input_3[0][0]']                
 neLowMagnitude)                )                                                                 
                                                                                                  
 prune_low_magnitude_bn_Conv1 (  (None, 112, 112, 32  129        ['prune_low_magnitude_Conv1[0][0]
 PruneLowMagnitude)             )                                ']            

In [34]:
logdir = tempfile.mkdtemp()

callbacks = [
  tfmot.sparsity.keras.UpdatePruningStep(),
  tfmot.sparsity.keras.PruningSummaries(log_dir=logdir),
]

model_for_pruning.fit(train_ds, validation_data=val_ds,
                  batch_size=batch_size, epochs=epochs,
                  callbacks=callbacks)

Epoch 1/2


  output, from_logits = _get_logits(
error: error: Can't find libdevice directory ${CUDA_DIR}/nvvm/libdevice
Can't find libdevice directory ${CUDA_DIR}/nvvm/libdeviceerror: 
Can't find libdevice directory ${CUDA_DIR}/nvvm/libdevice
error: Can't find libdevice directory ${CUDA_DIR}/nvvm/libdevice
error: error: Can't find libdevice directory ${CUDA_DIR}/nvvm/libdeviceCan't find libdevice directory ${CUDA_DIR}/nvvm/libdevice

error: error: Can't find libdevice directory ${CUDA_DIR}/nvvm/libdeviceCan't find libdevice directory ${CUDA_DIR}/nvvm/libdevice

error: error: Can't find libdevice directory ${CUDA_DIR}/nvvm/libdeviceCan't find libdevice directory ${CUDA_DIR}/nvvm/libdevice

error: error: Can't find libdevice directory ${CUDA_DIR}/nvvm/libdeviceerror: Can't find libdevice directory ${CUDA_DIR}/nvvm/libdevice
Can't find libdevice directory ${CUDA_DIR}/nvvm/libdevice
error: 
error: Can't find libdevice directory ${CUDA_DIR}/nvvm/libdeviceerror: 
Can't find libdevice directory ${CUDA_D

UnknownError: Graph execution error:

Detected at node 'mobilenetv2_1.00_224/prune_low_magnitude_Conv1/FloorMod' defined at (most recent call last):
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/runpy.py", line 197, in _run_module_as_main
      return _run_code(code, main_globals, None,
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/runpy.py", line 87, in _run_code
      exec(code, run_globals)
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/ipykernel_launcher.py", line 17, in <module>
      app.launch_new_instance()
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/traitlets/config/application.py", line 1043, in launch_instance
      app.start()
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/ipykernel/kernelapp.py", line 725, in start
      self.io_loop.start()
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/tornado/platform/asyncio.py", line 215, in start
      self.asyncio_loop.run_forever()
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/asyncio/base_events.py", line 601, in run_forever
      self._run_once()
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/asyncio/base_events.py", line 1905, in _run_once
      handle._run()
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/asyncio/events.py", line 80, in _run
      self._context.run(self._callback, *self._args)
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 513, in dispatch_queue
      await self.process_one()
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 502, in process_one
      await dispatch(*args)
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 409, in dispatch_shell
      await result
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 729, in execute_request
      reply_content = await reply_content
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/ipykernel/ipkernel.py", line 422, in do_execute
      res = shell.run_cell(
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/ipykernel/zmqshell.py", line 540, in run_cell
      return super().run_cell(*args, **kwargs)
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 2961, in run_cell
      result = self._run_cell(
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3016, in _run_cell
      result = runner(coro)
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
      coro.send(None)
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3221, in run_cell_async
      has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3400, in run_ast_nodes
      if await self.run_code(code, result, async_=asy):
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3460, in run_code
      exec(code_obj, self.user_global_ns, self.user_ns)
    File "/tmp/ipykernel_513/3588940814.py", line 8, in <module>
      model_for_pruning.fit(train_ds, validation_data=val_ds,
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/keras/engine/training.py", line 1564, in fit
      tmp_logs = self.train_function(iterator)
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/keras/engine/training.py", line 1160, in train_function
      return step_function(self, iterator)
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/keras/engine/training.py", line 1146, in step_function
      outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/keras/engine/training.py", line 1135, in run_step
      outputs = model.train_step(data)
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/keras/engine/training.py", line 993, in train_step
      y_pred = self(x, training=True)
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/keras/engine/training.py", line 557, in __call__
      return super().__call__(*args, **kwargs)
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/keras/engine/base_layer.py", line 1097, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 96, in error_handler
      return fn(*args, **kwargs)
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/keras/engine/functional.py", line 510, in call
      return self._run_internal_graph(inputs, training=training, mask=mask)
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/keras/engine/functional.py", line 667, in _run_internal_graph
      outputs = node.layer(*args, **kwargs)
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/keras/engine/base_layer.py", line 1097, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 96, in error_handler
      return fn(*args, **kwargs)
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py", line 280, in call
      update_mask = utils.smart_cond(training, add_update, no_op)
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/tensorflow_model_optimization/python/core/keras/utils.py", line 50, in smart_cond
      if isinstance(pred, variables.Variable):
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/tensorflow_model_optimization/python/core/keras/utils.py", line 54, in smart_cond
      pred, true_fn=true_fn, false_fn=false_fn, name=name)
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py", line 268, in add_update
      with tf.control_dependencies(
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_impl.py", line 310, in conditional_mask_update
      return tf.distribute.get_replica_context().merge_call(
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_impl.py", line 307, in mask_update_distributed
      return tf.cond(maybe_update_masks(), update_distributed, no_update)
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_impl.py", line 260, in maybe_update_masks
      if self._sparsity_m_by_n:
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_impl.py", line 264, in maybe_update_masks
      return self._pruning_schedule(self._step_fn())[0]
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_schedule.py", line 246, in __call__
      sparsity)
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_schedule.py", line 61, in _should_prune_in_step
      is_pruning_turn = tf.math.equal(
Node: 'mobilenetv2_1.00_224/prune_low_magnitude_Conv1/FloorMod'
Detected at node 'mobilenetv2_1.00_224/prune_low_magnitude_Conv1/FloorMod' defined at (most recent call last):
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/runpy.py", line 197, in _run_module_as_main
      return _run_code(code, main_globals, None,
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/runpy.py", line 87, in _run_code
      exec(code, run_globals)
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/ipykernel_launcher.py", line 17, in <module>
      app.launch_new_instance()
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/traitlets/config/application.py", line 1043, in launch_instance
      app.start()
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/ipykernel/kernelapp.py", line 725, in start
      self.io_loop.start()
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/tornado/platform/asyncio.py", line 215, in start
      self.asyncio_loop.run_forever()
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/asyncio/base_events.py", line 601, in run_forever
      self._run_once()
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/asyncio/base_events.py", line 1905, in _run_once
      handle._run()
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/asyncio/events.py", line 80, in _run
      self._context.run(self._callback, *self._args)
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 513, in dispatch_queue
      await self.process_one()
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 502, in process_one
      await dispatch(*args)
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 409, in dispatch_shell
      await result
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 729, in execute_request
      reply_content = await reply_content
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/ipykernel/ipkernel.py", line 422, in do_execute
      res = shell.run_cell(
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/ipykernel/zmqshell.py", line 540, in run_cell
      return super().run_cell(*args, **kwargs)
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 2961, in run_cell
      result = self._run_cell(
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3016, in _run_cell
      result = runner(coro)
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
      coro.send(None)
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3221, in run_cell_async
      has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3400, in run_ast_nodes
      if await self.run_code(code, result, async_=asy):
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3460, in run_code
      exec(code_obj, self.user_global_ns, self.user_ns)
    File "/tmp/ipykernel_513/3588940814.py", line 8, in <module>
      model_for_pruning.fit(train_ds, validation_data=val_ds,
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/keras/engine/training.py", line 1564, in fit
      tmp_logs = self.train_function(iterator)
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/keras/engine/training.py", line 1160, in train_function
      return step_function(self, iterator)
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/keras/engine/training.py", line 1146, in step_function
      outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/keras/engine/training.py", line 1135, in run_step
      outputs = model.train_step(data)
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/keras/engine/training.py", line 993, in train_step
      y_pred = self(x, training=True)
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/keras/engine/training.py", line 557, in __call__
      return super().__call__(*args, **kwargs)
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/keras/engine/base_layer.py", line 1097, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 96, in error_handler
      return fn(*args, **kwargs)
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/keras/engine/functional.py", line 510, in call
      return self._run_internal_graph(inputs, training=training, mask=mask)
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/keras/engine/functional.py", line 667, in _run_internal_graph
      outputs = node.layer(*args, **kwargs)
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/keras/engine/base_layer.py", line 1097, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 96, in error_handler
      return fn(*args, **kwargs)
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py", line 280, in call
      update_mask = utils.smart_cond(training, add_update, no_op)
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/tensorflow_model_optimization/python/core/keras/utils.py", line 50, in smart_cond
      if isinstance(pred, variables.Variable):
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/tensorflow_model_optimization/python/core/keras/utils.py", line 54, in smart_cond
      pred, true_fn=true_fn, false_fn=false_fn, name=name)
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py", line 268, in add_update
      with tf.control_dependencies(
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_impl.py", line 310, in conditional_mask_update
      return tf.distribute.get_replica_context().merge_call(
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_impl.py", line 307, in mask_update_distributed
      return tf.cond(maybe_update_masks(), update_distributed, no_update)
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_impl.py", line 260, in maybe_update_masks
      if self._sparsity_m_by_n:
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_impl.py", line 264, in maybe_update_masks
      return self._pruning_schedule(self._step_fn())[0]
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_schedule.py", line 246, in __call__
      sparsity)
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_schedule.py", line 61, in _should_prune_in_step
      is_pruning_turn = tf.math.equal(
Node: 'mobilenetv2_1.00_224/prune_low_magnitude_Conv1/FloorMod'
2 root error(s) found.
  (0) UNKNOWN:  JIT compilation failed.
	 [[{{node mobilenetv2_1.00_224/prune_low_magnitude_Conv1/FloorMod}}]]
	 [[mobilenetv2_1.00_224/prune_low_magnitude_block_13_expand_BN/assert_greater_equal/Assert/AssertGuard/pivot_f/_2199/_2563]]
  (1) UNKNOWN:  JIT compilation failed.
	 [[{{node mobilenetv2_1.00_224/prune_low_magnitude_Conv1/FloorMod}}]]
0 successful operations.
0 derived errors ignored. [Op:__inference_train_function_948660]

In [None]:
model = from_pretrained_keras("keras-io/Object-Detection-RetinaNet")

In [None]:
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

with open('model.tflite', 'wb') as f:
    f.write(tflite_model)

In [29]:
data, info = tfds.load('coco/2017', split='train', with_info=True)

# Prepare the dataset for training
# def prepare_example(example):
#     # Normalize the image
#     image = tf.cast(example['image'], tf.float32) / 255.
#     # Resize the image to (256, 256)
#     image = tf.image.resize(image, (256, 256))
#     # Convert the labels and bboxes to dense tensors
#     label = tf.sparse.to_dense(example['objects']['label'])
#     bbox = tf.sparse.to_dense(example['objects']['bbox'])
#     return image, {'class_output': label, 'bbox_output': bbox}

def prepare_example(example):
    # Normalize the image
    image = tf.cast(example['image'], tf.float32) / 255.
    # Resize the image to (256, 256)
    image = tf.image.resize(image, (800, 800))
    # Convert the labels and bboxes to dense tensors
    label = example['objects']['label']
    bbox = example['objects']['bbox']
    label = tf.one_hot(label, depth=info.features['objects']['label'].num_classes)
    label = tf.reduce_max(label, axis=0)
    bbox = tf.reshape(bbox, [-1, 4])
    return image, {'class_output': label, 'bbox_output': bbox}

train_data = data.map(prepare_example).batch(32)

pruned_layers = []
for layer in model.layers:
    if isinstance(layer, tf.keras.layers.Conv2D) or isinstance(layer, tf.keras.layers.Dense):
        # Apply pruning to the layer if it is a supported layer type
        pruning_params = {
            'pruning_schedule': tfmot.sparsity.keras.ConstantSparsity(0.5, begin_step=0, end_step=100, frequency=10),
            'block_size': (1, 1),
            'block_pooling_type': 'AVG'
        }
        pruned_layer = tfmot.sparsity.keras.prune_low_magnitude(layer, **pruning_params)
        pruned_layers.append(pruned_layer)
    else:
        # Add the layer to the list of unpruned layers if it is not a supported layer type
        pruned_layers.append(layer)

pruned_model = tf.keras.models.Sequential(pruned_layers)

# Compile the pruned model
pruned_model.compile(loss={'class_output': 'binary_crossentropy', 'bbox_output': 'mse'},
                      optimizer='adam', metrics=['accuracy'])

pruned_model.fit(train_data, epochs=10)

converter = tf.lite.TFLiteConverter.from_keras_model(pruned_model)
tflite_model = converter.convert()

with open('pruned_model.tflite', 'wb') as f:
    f.write(tflite_model)

Epoch 1/10


ValueError: in user code:

    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/keras/engine/training.py", line 1160, in train_function  *
        return step_function(self, iterator)
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/keras/engine/training.py", line 1146, in step_function  **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/keras/engine/training.py", line 1135, in run_step  **
        outputs = model.train_step(data)
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/keras/engine/training.py", line 993, in train_step
        y_pred = self(x, training=True)
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 70, in error_handler
        raise e.with_traceback(filtered_tb) from None
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/keras/engine/sequential.py", line 349, in _build_graph_network_for_inferred_shape
        raise ValueError(SINGLE_LAYER_OUTPUT_ERROR_MSG)

    ValueError: Exception encountered when calling layer "sequential_3" "                 f"(type Sequential).
    
    All layers in a Sequential model should have a single output tensor. For multi-output layers, use the functional API.
    
    Call arguments received by layer "sequential_3" "                 f"(type Sequential):
      • inputs=tf.Tensor(shape=(None, 800, 800, 3), dtype=float32)
      • training=True
      • mask=None


In [30]:
# Define the pruning parameters
pruning_params = {
    'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.50, final_sparsity=0.90, begin_step=0, end_step=1000)
}

pruned_model = model

# Prune only the supported layers of the model
for layer in pruned_model.layers:
    if isinstance(layer, (tf.keras.layers.Conv2D, tf.keras.layers.Dense)):
        tfmot.sparsity.keras.prune_low_magnitude(layer, **pruning_params)

# Compile and train the pruned model
pruned_model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
pruned_model.fit(train_data, epochs=10)

Epoch 1/10


ValueError: in user code:

    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/keras/engine/training.py", line 1160, in train_function  *
        return step_function(self, iterator)
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/keras/engine/training.py", line 1146, in step_function  **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/keras/engine/training.py", line 1135, in run_step  **
        outputs = model.train_step(data)
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/keras/engine/training.py", line 994, in train_step
        loss = self.compute_loss(x, y, y_pred, sample_weight)
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/keras/engine/training.py", line 1052, in compute_loss
        return self.compiled_loss(
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/keras/engine/compile_utils.py", line 236, in __call__
        y_true = self._conform_to_outputs(y_pred, y_true)
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/keras/engine/compile_utils.py", line 60, in _conform_to_outputs
        struct = map_to_output_names(outputs, self._output_names, struct)
    File "/home/marcinwsl/miniconda3/envs/tf/lib/python3.9/site-packages/keras/engine/compile_utils.py", line 805, in map_to_output_names
        raise ValueError(

    ValueError: Found unexpected losses or metrics that do not correspond to any Model output: dict_keys(['class_output', 'bbox_output']). Valid mode output names: ['output_1']. Received struct is: {'class_output': <tf.Tensor 'IteratorGetNext:2' shape=(None, 80) dtype=float32>, 'bbox_output': <tf.Tensor 'IteratorGetNext:1' shape=(None, None, 4) dtype=float32>}.
