In [None]:
import tensorflow as tf
from matplotlib import pyplot
from tensorflow.keras.datasets import fashion_mnist
from emnist import extract_training_samples
from tensorflow.keras import *
from tensorflow.keras.layers import *
from sklearn.model_selection import train_test_split
import numpy as np
import matplotlib.pyplot as plt
from pprint import pprint
import sys
import os
import zipfile
import tempfile
import tensorflow_model_optimization as tfmot
import json

# Short alias for the tensorflow_model_optimization magnitude-pruning wrapper.
prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude


# Add the directory three levels above sys.path[0] to the module search path
# (at position 1, i.e. right after the first entry).
# NOTE(review): presumably this is what lets the `from utils import *` that
# follows find a project-local `utils` module — confirm the directory layout.
sys.path.insert(1, os.path.join(sys.path[0], '../../..'))

from utils import *

%matplotlib inline
%config Completer.use_jedi = False

# Query the GPUs TensorFlow can see. The returned list is neither stored nor
# the last expression of the cell, so it is not displayed.
tf.config.list_physical_devices('GPU')
# Request eager execution through the TF1 compatibility API.
# NOTE(review): presumably eager mode is already the default on TF 2.x —
# confirm whether this call is still required.
tf.compat.v1.enable_eager_execution()

In [None]:
def get_gzipped_model_size(m):
    """Return the size in bytes of model ``m`` after saving and zipping it.

    The model is written to a temporary ``.h5`` file (optimizer state
    included), that file is stored in a DEFLATE-compressed zip archive, and
    the size of the archive is returned.

    Args:
        m: Any object exposing ``save(path, include_optimizer=...)``.

    Returns:
        int: Size of the zip archive in bytes.

    Both temporary files are removed before returning, including when saving
    or zipping raises.
    """
    # mkstemp hands back an *open* OS-level descriptor; close it right away so
    # repeated calls do not leak descriptors (and so the path can be reopened
    # for writing on platforms that lock open files).
    keras_fd, keras_file = tempfile.mkstemp('.h5')
    os.close(keras_fd)
    zipped_fd, zipped_file = tempfile.mkstemp('.zip')
    os.close(zipped_fd)
    try:
        m.save(keras_file, include_optimizer=True)
        with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
            # Archive member name is left as the temp path, as before, so the
            # reported size stays comparable with earlier measurements.
            f.write(keras_file)
        return os.path.getsize(zipped_file)
    finally:
        # Never leave the scratch files behind in the temp directory.
        for path in (keras_file, zipped_file):
            if os.path.exists(path):
                os.remove(path)

In [None]:
# Load the Fashion-MNIST train/test split through the Keras dataset helper.
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()
# Preview the training images. show_dataset is not defined in this notebook;
# presumably it comes from the star-imported project `utils` module — confirm.
show_dataset(x_train)

In [None]:
# Number of label classes used for the one-hot encoding below.
num_classes = 10

# Standardise the images with a single scalar mean/std computed over every
# axis of the *training* set; the same statistics are reused for the test set.
# NOTE: this cell overwrites x_*/y_* in place, so running it twice would
# normalise and encode already-processed arrays. Re-run the loading cell first.
mean = np.mean(x_train, axis=(0, 1, 2))
std = np.std(x_train, axis=(0, 1, 2))

# The small epsilon guards against division by zero.
x_train = (x_train - mean) / (std + 1e-7)
x_test = (x_test - mean) / (std + 1e-7)

# Use the fully-qualified Keras helper: the bare name `utils` is only bound by
# a star import and is easy to confuse with the project-local `utils` module
# that is star-imported as well.
y_train = tf.keras.utils.to_categorical(y_train, num_classes)
y_test = tf.keras.utils.to_categorical(y_test, num_classes)

In [None]:
# Give both image arrays an explicit trailing channel axis: (N, 28, 28, 1).
x_train, x_test = (
    images.reshape(-1, 28, 28, 1) for images in (x_train, x_test)
)

In [None]:
# Load the saved Keras model from the working directory. Later cells evaluate
# it as the 'base' entry of `results` and clone it for node pruning.
model = tf.keras.models.load_model("vanilla_fashion_mnist.h5")

In [None]:
# Pretty-print the metrics full_evaluate reports for the loaded model on the
# test set. full_evaluate is not defined in this notebook (presumably from the
# project `utils` module); later cells treat its result as a dict.
pprint(full_evaluate(model, x_test, y_test))

In [None]:
# Print the layer-by-layer summary of the loaded model.
model.summary()

# Weight pruning

In [None]:
# Accumulator: experiment label -> metrics dict. Filled by the cells below and
# dumped to results_pruning.json by the node-pruning loop.
results = {}

In [None]:
# Baseline entry: test-set metrics of the loaded model plus the on-disk size
# of its .h5 file in whole KiB.
# NOTE(review): this is the raw file size, whereas the pruned entries report a
# zipped size via get_gzipped_model_size — confirm the comparison is intended.
results['base'] = full_evaluate(model, x_test, y_test)
results['base']['size (kb)'] = os.path.getsize('vanilla_fashion_mnist.h5') // 1024

In [None]:
# Training hyper-parameters. batch_size is read as a global by train_sparse;
# epochs is meant to be passed to it by the caller.
batch_size = 256
epochs = 50

In [None]:
def train_sparse(model_for_pruning, epochs):
    """Fine-tune a pruning-wrapped model and return it with the wrappers removed.

    Reads the notebook globals ``x_train``, ``y_train``, ``x_test``,
    ``y_test`` and ``batch_size``. The test split is used as validation data.

    Args:
        model_for_pruning: Model to train; presumably already wrapped with
            ``prune_low_magnitude`` by the caller (TODO confirm).
        epochs: Number of training epochs.

    Returns:
        The result of ``tfmot.sparsity.keras.strip_pruning`` applied to the
        trained model.
    """
    model_for_pruning.compile(
        optimizer='adam',
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )

    # Pruning summaries are written to a fresh temporary directory.
    summaries_dir = tempfile.mkdtemp()
    pruning_callbacks = [
        # Advances the pruning schedule at each training step.
        tfmot.sparsity.keras.UpdatePruningStep(),
        tfmot.sparsity.keras.PruningSummaries(log_dir=summaries_dir),
    ]

    model_for_pruning.fit(
        x_train,
        y_train,
        batch_size=batch_size,
        epochs=epochs,
        callbacks=pruning_callbacks,
        validation_data=(x_test, y_test),
    )

    return tfmot.sparsity.keras.strip_pruning(model_for_pruning)

In [None]:
# for p in [0.5, 0.7, 0.8, 0.85, 0.9, 0.925, 0.95, 0.975, 0.99]:
#     pruning_params = {
#           'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.2,
#                             final_sparsity=p,
#                             begin_step=0,
#                             end_step=np.ceil(x_train.shape[0] / batch_size).astype(np.int32) * epochs)
#     }

#     model_for_pruning = prune_low_magnitude(tf.keras.models.clone_model(model), **pruning_params)
#     model_for_export = train_sparse(model_for_pruning, int(epochs*p))
#     results['weight_pruning_{}'.format(p)] = full_evaluate(model_for_export, x_test, y_test)
#     results['weight_pruning_{}'.format(p)]['size (kb)'] = get_gzipped_model_size(model_for_export)//1024

In [None]:
# pprint(results)

In [None]:
# with open('results_pruning.json', 'w', encoding ='utf8') as json_file:
#     json.dump(results, json_file, ensure_ascii = True)

In [None]:
# Switch TensorFlow to graph (non-eager) mode before the node-pruning section.
# NOTE(review): presumably prune_model requires graph mode — confirm. Objects
# created earlier in eager mode (e.g. `model`) may not be usable afterwards.
tf.compat.v1.disable_eager_execution()

# Node pruning

In [None]:
# Iterative node pruning: prune, fine-tune, evaluate, checkpoint the results.
# NOTE(review): per the Keras docs clone_model creates newly initialised
# weights instead of copying them — confirm whether the trained weights of
# `model` should be transferred (e.g. set_weights) before pruning.
model_tmp = tf.keras.models.clone_model(model)
# acc = []
# par_count =[]
models = [model_tmp]  # not read again in the visible cells
model_tmp.compile(optimizer='adam',
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
# Initial, larger pruning step (0.3) before the repeated 0.1 steps below.
model_tmp = prune_model(model_tmp, 0.3, opt='adam', method='l1')

for i in range(100):
    try:
        model_tmp = prune_model(model_tmp, 0.1, opt='adam', method='l1')
    except Exception as exc:
        # Best-effort stop when the model cannot be pruned any further, but
        # say why, and let KeyboardInterrupt/SystemExit propagate instead of
        # being swallowed by a bare except.
        print('Stopping node pruning at iteration {}: {!r}'.format(i, exc))
        break
    model_tmp.fit(x_train, y_train, batch_size=256, epochs=15, validation_data=(x_test, y_test), verbose=1)

    # NOTE(review): the label presumably approximates the remaining fraction
    # (0.9**3 ~= 0.73 standing in for the initial 0.3 step) — confirm against
    # prune_model's semantics.
    key = 'node_pruning_{}'.format((0.9)**(i+4))
    results[key] = full_evaluate(model_tmp, x_test, y_test)
    results[key]['params'] = model_tmp.count_params()
    results[key]['size (kb)'] = get_gzipped_model_size(model_tmp)//1024

    # Rewrite the results file after every iteration so that an interrupted
    # run keeps everything measured so far.
    with open('results_pruning.json', 'w', encoding ='utf8') as json_file:
        json.dump(results, json_file, ensure_ascii = True)

In [None]:
# Replace the in-memory results with whatever was last written to disk.
with open('results_pruning.json', encoding='utf8') as json_file:
    results = json.loads(json_file.read())

In [None]:
# for i in range(100):
#     results['node_pruning_{}'.format((0.9)**(i+1))] = results['node_pruning_{}'.format((0.1)**(i+1))]

In [None]:
# for i in range(100):
#     del results['node_pruning_{}'.format((0.1)**(i+1))]

In [None]:
# with open('results_pruning.json', 'w', encoding ='utf8') as json_file:
#         json.dump(results, json_file, ensure_ascii = True)