In [None]:
# install optimization toolkit
pip install -q tensorflow-model-optimization --yes

In [None]:
import tempfile
import os

import tensorflow as tf
import numpy as np

from tensorflow import keras

%load_ext tensorboard

In [None]:
# Load the trained baseline Keras model (HDF5); it is wrapped for pruning in the next cell.
model = keras.models.load_model(
    "./Trained_Model/resnet-Xinghuo-driving-1603168181.665796.h5"
)

In [None]:
# Wrap the baseline model for magnitude-based pruning and compile it for fine-tuning.

import tensorflow_model_optimization as tfmot

# Target fraction of weights zeroed by the ConstantSparsity schedule, applied
# from the first training step (begin_step=0).
TARGET_SPARSITY = 0.5

prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

pruning_params = {
    'pruning_schedule': tfmot.sparsity.keras.ConstantSparsity(
        target_sparsity=TARGET_SPARSITY, begin_step=0),
}

model_for_pruning = prune_low_magnitude(model, **pruning_params)

# The loss is mean squared error (a regression objective), so 'accuracy' is not
# a meaningful metric; track MSE instead, consistent with the evaluation cell.
model_for_pruning.compile(optimizer='adam',
                          loss=tf.keras.losses.MeanSquaredError(),
                          metrics=[tf.keras.metrics.MeanSquaredError()])

model_for_pruning.summary()

In [None]:
# training and pruning the model

from tfrecord_utility import get_parsed_dataset

BATCH_SIZE = 32
EPOCHS = 2

train_dataset = get_parsed_dataset(record_file='。/tfrecord/train.record',
                                           batch_size=BATCH_SIZE,
                                           epochs=EPOCHS,
                                           shuffle=True)


logdir="./pruning_log"
callbacks = [
  tfmot.sparsity.keras.UpdatePruningStep(),
  tfmot.sparsity.keras.PruningSummaries(log_dir=logdir),
]

model_for_pruning.fit(train_dataset,
                  epochs=EPOCHS,
                  callbacks=callbacks)

%tensorboard --logdir={logdir}

In [None]:
# Export the pruned model before evaluating it.
pruned_keras_file = "./Trained_Model/pruned_resnet50_model1.h5"

# strip_pruning removes the pruning wrappers, leaving a plain Keras model that
# keeps the pruned weights; it is saved as HDF5 without optimizer state.
model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)
model_for_export.save(pruned_keras_file, include_optimizer=False)

print("Saved pruned Keras model to:", pruned_keras_file)

In [None]:
# Evaluate the baseline and the pruned model on the validation split.
test_dataset = get_parsed_dataset(record_file="./tfrecord/validation.record",
                                  batch_size=BATCH_SIZE,
                                  epochs=1,
                                  shuffle=False)


def _load_for_eval(path):
    """Load an HDF5 Keras model and compile it with MSE loss and metric.

    Each model gets its own optimizer instance: evaluate() does not use it, but
    a Keras optimizer should not be shared between models.

    Args:
        path: Path to the saved ``.h5`` model.

    Returns:
        The loaded and compiled Keras model.
    """
    loaded = tf.keras.models.load_model(path)
    # `learning_rate`, not the deprecated `lr` (rejected by newer Keras optimizers).
    loaded.compile(optimizer=keras.optimizers.Adam(learning_rate=0.001),
                   metrics=[keras.metrics.mean_squared_error],
                   loss=keras.losses.mean_squared_error)
    return loaded


# Reload the baseline from disk instead of reusing the in-memory `model`, which
# may share layers with `model_for_pruning` and so may have been changed by the
# pruning fine-tune. This is the same file the pruning cell started from.
# NOTE(review): the baseline was previously read from "./saved_resnet.h5"; switch
# back only if that is a deliberately different checkpoint.
model = _load_for_eval("./Trained_Model/resnet-Xinghuo-driving-1603168181.665796.h5")

# Load the pruned model from the exact path the export cell wrote to.
prun_model = _load_for_eval(pruned_keras_file)

# return_dict=True labels each value (loss and MSE), so the output is self-describing.
base_eval = model.evaluate(test_dataset, return_dict=True)
pruned_eval = prun_model.evaluate(test_dataset, return_dict=True)
print("baseline test metrics:", base_eval)
print("pruned test metrics:", pruned_eval)

In [None]:
# Re-display TensorBoard for the pruning logs that the PruningSummaries callback
# wrote to `logdir` in the training cell; useful if the inline view launched at
# the end of that cell did not render.
%tensorboard --logdir={logdir}