In [34]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import tensorflow as tf
import time
from tqdm import tqdm

sns.set()
tf.enable_eager_execution()
tf.set_random_seed(1867)


In [35]:
#Prepare MNIST Data
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

#将数据切片，缓冲区1000，batch64
dataset_train = tf.data.Dataset.from_tensor_slices((
    tf.cast(x_train/255, tf.float32),
    tf.cast(y_train, tf.int64)
)).shuffle(1000).batch(64)

dateset_test = tf.data.Dataset.from_tensor_slices((
    tf.cast(x_test/255, tf.float32),
    tf.cast(y_test, tf.int64)
)).batch(64)


In [36]:
#Model Definition
model_orig = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(1000, activation=tf.nn.relu, use_bias=False),
    tf.keras.layers.Dense(1000, activation=tf.nn.relu, use_bias=False),
    tf.keras.layers.Dense(500, activation=tf.nn.relu, use_bias=False),
    tf.keras.layers.Dense(200, activation=tf.nn.relu, use_bias=False),
    tf.keras.layers.Dense(10, use_bias=False),
])

In [37]:
def train_model(model):
    optimizer = tf.train.AdamOptimizer()
    global_step = tf.train.get_or_create_global_step()

    training_losses = []
    training_accuracy = []

    for epoch in range(10):
        epoch_loss_avg = tf.contrib.eager.metrics.Mean()
        epoch_accuracy = tf.contrib.eager.metrics.Accuracy()
        for x, y in tqdm(dataset_train, total=round(len(x_train)/64)):
            with tf.GradientTape() as tape:
                outputs = model(x)
                loss = tf.losses.softmax_cross_entropy(tf.one_hot(y, 10), outputs)
            grads = tape.gradient(loss, model.trainable_weights)
            optimizer.apply_gradients(zip(grads, model.trainable_weights), global_step)
            epoch_loss_avg(loss)
            epoch_accuracy(tf.argmax(outputs, axis=1, output_type=tf.int64), y)
        training_losses.append(epoch_loss_avg.result())
        training_accuracy.append(epoch_accuracy.result())

    return training_losses, training_accuracy

In [43]:
def re_train(pruned_model):
    optimizer = tf.train.AdamOptimizer()
    global_step = tf.train.get_or_create_global_step()

    training_losses = []
    training_accuracy = []

    for epoch in range(1):
        epoch_loss_avg = tf.contrib.eager.metrics.Mean()
        epoch_accuracy = tf.contrib.eager.metrics.Accuracy()
        for x, y in tqdm(dataset_train, total=round(len(x_train)/64)):
            with tf.GradientTape() as tape:
                outputs = pruned_model(x)
                loss = tf.losses.softmax_cross_entropy(tf.one_hot(y, 10), outputs)
            grads = tape.gradient(loss, pruned_model.trainable_weights)
            optimizer.apply_gradients(zip(grads, pruned_model.trainable_weights), global_step)
            epoch_loss_avg(loss)
            epoch_accuracy(tf.argmax(outputs, axis=1, output_type=tf.int64), y)
        training_losses.append(epoch_loss_avg.result())
        training_accuracy.append(epoch_accuracy.result())

    return training_losses, training_accuracy

In [8]:
def test(model, dataset):
    epoch_loss_avg = tf.contrib.eager.metrics.Mean()
    epoch_accuracy = tf.contrib.eager.metrics.Accuracy()
    for x, y in dataset:
        outputs = model(x)
        loss = tf.losses.softmax_cross_entropy(tf.one_hot(y, 10), outputs)
        epoch_loss_avg(loss)
        epoch_accuracy(tf.argmax(outputs, axis=1, output_type=tf.int64), y)
    return epoch_loss_avg.result().numpy(), epoch_accuracy.result().numpy()



[<tf.Tensor: id=179433, shape=(), dtype=float64, numpy=0.9418833333333333>,
 <tf.Tensor: id=358650, shape=(), dtype=float64, numpy=0.97235>,
 <tf.Tensor: id=537867, shape=(), dtype=float64, numpy=0.9797166666666667>,
 <tf.Tensor: id=717084, shape=(), dtype=float64, numpy=0.985>,
 <tf.Tensor: id=896301, shape=(), dtype=float64, numpy=0.9876333333333334>,
 <tf.Tensor: id=1075518, shape=(), dtype=float64, numpy=0.98935>,
 <tf.Tensor: id=1254735, shape=(), dtype=float64, numpy=0.9901333333333333>,
 <tf.Tensor: id=1433952, shape=(), dtype=float64, numpy=0.9914833333333334>,
 <tf.Tensor: id=1613169, shape=(), dtype=float64, numpy=0.9926>,
 <tf.Tensor: id=1792386, shape=(), dtype=float64, numpy=0.9929833333333333>]

In [21]:
def unit_prune(dense_model, percentile):
    prev_kept_columns = None
    pruned_model = tf.keras.models.Sequential()
    pruned_model.add(tf.keras.layers.Flatten(input_shape=(28, 28)))
    num_layers = len(dense_model.trainable_weights)

    for i_layer, weights in enumerate(dense_model.trainable_weights):
        weights_np = weights.numpy()

        # Remove pruned columns
        if i_layer < num_layers - 1:  # Do not prune last layer
            column_norms = np.linalg.norm(weights_np, ord=2, axis=0)
            critical_value = np.percentile(column_norms, percentile)
            keep_mask = column_norms >= critical_value
            weights_np = weights_np[:, keep_mask]

        # Remove rows corresponding to previous layer's pruned columns
        if prev_kept_columns is not None:
            weights_np = weights_np[prev_kept_columns, :]

        # Record which columns were kept
        if i_layer < num_layers - 1:  # No pruned columns in last layer
            prev_kept_columns = np.argwhere(keep_mask).reshape(-1)

        # Add new layer to sparse model
        new_layer = tf.keras.layers.Dense(weights_np.shape[1], activation=tf.nn.relu, use_bias=False)
        pruned_model.add(new_layer)
        new_layer.set_weights([weights_np])

    return pruned_model

(0.10114042613807853, 0.9795)

In [38]:
losses, accuracies = train_model(model_orig)
print(losses)
print(accuracies)

100%|████████████████████████████████████████████████████████████████████████████████| 938/938 [00:24<00:00, 37.90it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 938/938 [00:22<00:00, 42.37it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 938/938 [00:21<00:00, 44.35it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 938/938 [00:21<00:00, 44.01it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 938/938 [00:21<00:00, 44.40it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 938/938 [00:20<00:00, 45.11it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 938/938 [00:20<00:00, 45.59it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 938/938 [00:21<00:00, 44.57it/s]
100%|███████████████████████████████████

[<tf.Tensor: id=4405096, shape=(), dtype=float64, numpy=0.19671311230956715>, <tf.Tensor: id=4584313, shape=(), dtype=float64, numpy=0.09328903942903839>, <tf.Tensor: id=4763530, shape=(), dtype=float64, numpy=0.06762130413989303>, <tf.Tensor: id=4942747, shape=(), dtype=float64, numpy=0.05142297952891147>, <tf.Tensor: id=5121964, shape=(), dtype=float64, numpy=0.039816764571163406>, <tf.Tensor: id=5301181, shape=(), dtype=float64, numpy=0.037109097587299036>, <tf.Tensor: id=5480398, shape=(), dtype=float64, numpy=0.03431974684613929>, <tf.Tensor: id=5659615, shape=(), dtype=float64, numpy=0.029289227981580998>, <tf.Tensor: id=5838832, shape=(), dtype=float64, numpy=0.025262298133229184>, <tf.Tensor: id=6018049, shape=(), dtype=float64, numpy=0.024636659453119374>]
[<tf.Tensor: id=4405102, shape=(), dtype=float64, numpy=0.9418833333333333>, <tf.Tensor: id=4584319, shape=(), dtype=float64, numpy=0.97235>, <tf.Tensor: id=4763536, shape=(), dtype=float64, numpy=0.9797166666666667>, <tf.Te

In [39]:
pruned_model_ = unit_prune(model_orig, 90)

In [41]:
losses_pruned, accuracies_pruned = test(pruned_model_, dateset_test)

In [42]:
print(losses_pruned)
print(accuracies_pruned)

2.2603668680616247
0.5594


In [51]:
losses_retrained, accuracies_retrained = re_train(pruned_model_)
print(losses_retrained)
print(accuracies_retrained)

100%|████████████████████████████████████████████████████████████████████████████████| 938/938 [00:14<00:00, 63.02it/s]


[<tf.Tensor: id=6994803, shape=(), dtype=float64, numpy=0.01900134881276799>]
[<tf.Tensor: id=6994809, shape=(), dtype=float64, numpy=0.9944666666666667>]


In [52]:
losses1, accuracies1 = test(pruned_model_, dateset_test)
print(losses1)
print(accuracies1)

0.09755854798758537
0.9777
