In [1]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import tensorflow as tf
import time
import os
from tqdm import tqdm

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
#tf.device('/gpu:1')

sns.set()
tf.enable_eager_execution()
tf.set_random_seed(1867)

In [2]:
#Prepare MNIST Data
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

#将数据切片，缓冲区1000，batch64
dataset_train = tf.data.Dataset.from_tensor_slices((
    tf.cast(x_train/255, tf.float32),
    tf.cast(y_train, tf.int64)
)).shuffle(2000).batch(64)

dataset_test = tf.data.Dataset.from_tensor_slices((
    tf.cast(x_test/255, tf.float32),
    tf.cast(y_test, tf.int64)
)).batch(64)

In [3]:
#Model Definition
model_orig = tf.keras.Sequential([
    tf.keras.layers.Conv2D(64, 3, activation=tf.nn.relu, input_shape=(32, 32, 3), padding='same', kernel_regularizer=tf.keras.regularizers.l1(0.001)),
    tf.keras.layers.Conv2D(64, 3, activation=tf.nn.relu, padding='same', kernel_regularizer=tf.keras.regularizers.l1(0.001)),
    tf.keras.layers.MaxPool2D(2, 2),
    tf.keras.layers.Dropout(0.25),
    tf.keras.layers.Conv2D(128, 3, activation=tf.nn.relu, padding='same', kernel_regularizer=tf.keras.regularizers.l1(0.001)),
    tf.keras.layers.Conv2D(128, 3, activation=tf.nn.relu, padding='same', kernel_regularizer=tf.keras.regularizers.l1(0.001)),
    tf.keras.layers.MaxPool2D(2, 2),
    tf.keras.layers.Dropout(0.25),
    tf.keras.layers.Conv2D(256, 3, activation=tf.nn.relu, padding='same', kernel_regularizer=tf.keras.regularizers.l1(0.001)),
    tf.keras.layers.Conv2D(256, 3, activation=tf.nn.relu, padding='same', kernel_regularizer=tf.keras.regularizers.l1(0.001)),
    tf.keras.layers.MaxPool2D(2, 2),
    tf.keras.layers.Dropout(0.25),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(4096, activation=tf.nn.relu),
    tf.keras.layers.Dense(4096, activation=tf.nn.relu),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(10, use_bias=False),
])



In [4]:
def train_model(model):
    optimizer = tf.train.AdamOptimizer()
    global_step = tf.train.get_or_create_global_step()

    training_losses = []
    training_accuracy = []

    for epoch in range(20):
        epoch_loss_avg = tf.contrib.eager.metrics.Mean()
        epoch_accuracy = tf.contrib.eager.metrics.Accuracy()
        for x, y in tqdm(dataset_train, total=round(len(x_train)/64)):
            x = tf.reshape(x, [-1, 32, 32, 3])
            y = y[:, 0]
            with tf.GradientTape() as tape:
                outputs = model(x)
                loss = tf.losses.softmax_cross_entropy(tf.one_hot(y, 10), outputs)
            grads = tape.gradient(loss, model.trainable_weights)
            optimizer.apply_gradients(zip(grads, model.trainable_weights), global_step)
            epoch_loss_avg(loss)
            epoch_accuracy(tf.argmax(outputs, axis=1, output_type=tf.int64), y)
        training_losses.append(epoch_loss_avg.result())
        training_accuracy.append(epoch_accuracy.result())

    return training_losses, training_accuracy


In [5]:
def test(model, dataset):
    epoch_loss_avg = tf.contrib.eager.metrics.Mean()
    epoch_accuracy = tf.contrib.eager.metrics.Accuracy()
    for x, y in dataset:
        x = tf.reshape(x, [-1, 32, 32, 3])
        y = y[:, 0]
        outputs = model(x)
        loss = tf.losses.softmax_cross_entropy(tf.one_hot(y, 10), outputs)
        epoch_loss_avg(loss)
        epoch_accuracy(tf.argmax(outputs, axis=1, output_type=tf.int64), y)
    return epoch_loss_avg.result().numpy(), epoch_accuracy.result().numpy()


In [6]:
#Train Model
def unit_prune(dense_model, percentile):
    conv_layer_cnt = 6
    pruned_model = tf.keras.models.Sequential()
    filter_norms = []
    weights_thislayer_np = dense_model.trainable_weights[0].numpy()

    # conv layer
    for i_conv_layer in range(conv_layer_cnt):
        weights_nextlayer_np = dense_model.trainable_weights[i_conv_layer * 2 + 2].numpy()
        for i_filter in range(weights_thislayer_np.shape[3]):
            filter_norms.append(np.mean(np.fabs(weights_thislayer_np[:, :, :, i_filter])))
        critical_value = np.percentile(filter_norms, percentile)
        keep_mask = filter_norms >= critical_value
        weights_thislayer_np = weights_thislayer_np[:, :, :, keep_mask]
        if i_conv_layer < conv_layer_cnt - 1:
            weights_nextlayer_np = weights_nextlayer_np[:, :, keep_mask, :]
        else:
            flatten_mask = np.zeros(shape=(len(keep_mask), len(weights_nextlayer_np) // len(keep_mask)))
            flatten_mask[:] = np.array(keep_mask).reshape(-1, 1)
            flatten_mask = flatten_mask.reshape(1, -1)
            flatten_mask = flatten_mask[0]
            weights_nextlayer_np = weights_nextlayer_np[np.argwhere(flatten_mask)[:, 0], :]
        bias = dense_model.trainable_weights[i_conv_layer * 2 + 1].numpy()[keep_mask]
        if i_conv_layer == 0:
            new_layer = tf.keras.layers.Conv2D(weights_thislayer_np.shape[3], 3, activation=tf.nn.relu, padding='same',
                                               input_shape=(32, 32, 3))
        else:
            new_layer = tf.keras.layers.Conv2D(weights_thislayer_np.shape[3], 3, activation=tf.nn.relu, padding='same')
        pruned_model.add(new_layer)
        new_layer.set_weights([weights_thislayer_np, bias])
        weights_thislayer_np = weights_nextlayer_np
        filter_norms.clear()
        if i_conv_layer%2 == 1:
            pruned_model.add(tf.keras.layers.MaxPool2D(2, 2))
            pruned_model.add(tf.keras.layers.Dropout(0.25))

    # flatten_layer
    new_layer = tf.keras.layers.Flatten()
    pruned_model.add(new_layer)

    new_layer = tf.keras.layers.Dense(4096, activation=tf.nn.relu)
    pruned_model.add(new_layer)
    new_layer.set_weights([weights_thislayer_np, dense_model.trainable_weights[13].numpy()])
    
    new_layer = tf.keras.layers.Dense(4096, activation=tf.nn.relu)
    pruned_model.add(new_layer)
    new_layer.set_weights([dense_model.trainable_weights[14].numpy(), dense_model.trainable_weights[15].numpy()])
    # fc layer
    pruned_model.add(tf.keras.layers.Dropout(0.5))
    new_layer = tf.keras.layers.Dense(10, use_bias=False)
    pruned_model.add(new_layer)
    new_layer.set_weights(dense_model.layers[16].get_weights())

    return pruned_model

In [7]:
def fine_tuning(pruned_model, epoch_num):
    optimizer = tf.train.AdamOptimizer()
    global_step = tf.train.get_or_create_global_step()

    training_losses = []
    training_accuracy = []

    for epoch in range(epoch_num):
        epoch_loss_avg = tf.contrib.eager.metrics.Mean()
        epoch_accuracy = tf.contrib.eager.metrics.Accuracy()
        for x, y in tqdm(dataset_train, total=round(len(x_train)/64)):
            with tf.GradientTape() as tape:
                x = tf.reshape(x, [-1, 32, 32, 3])
                y = y[:, 0]
                outputs = pruned_model(x)
                loss = tf.losses.softmax_cross_entropy(tf.one_hot(y, 10), outputs)
            grads = tape.gradient(loss, pruned_model.trainable_weights)
            optimizer.apply_gradients(zip(grads, pruned_model.trainable_weights), global_step)
            epoch_loss_avg(loss)
            epoch_accuracy(tf.argmax(outputs, axis=1, output_type=tf.int64), y)
        training_losses.append(epoch_loss_avg.result())
        training_accuracy.append(epoch_accuracy.result())

    return pruned_model

In [8]:
train_model(model_orig)
model_orig.save('mdoel_orig.h5')

782it [00:33, 23.20it/s]                         
782it [00:29, 26.45it/s]                         
782it [00:29, 26.60it/s]                         
782it [00:29, 26.46it/s]                         
782it [00:29, 26.46it/s]                         
782it [00:29, 26.44it/s]                         
782it [00:29, 26.26it/s]                         
782it [00:29, 26.20it/s]                         
782it [00:29, 26.12it/s]                         
782it [00:30, 26.02it/s]                         
782it [00:31, 25.18it/s]                         
782it [00:32, 24.11it/s]                         
782it [00:29, 26.26it/s]                         
782it [00:29, 26.32it/s]                         
782it [00:29, 26.40it/s]                         
782it [00:30, 25.95it/s]                         
782it [00:30, 25.82it/s]                         
782it [00:30, 25.90it/s]                         
782it [00:30, 25.89it/s]                         
782it [00:30, 25.77it/s]                         


In [13]:
model_pruned = unit_prune(model_orig, 75)

In [10]:
model_ft = model_pruned

In [14]:
model_ft = fine_tuning(model_pruned, 1)
print(test(model_ft, dataset_test))
for i in range(9):
    model_ft = fine_tuning(model_ft, 1)
    print(test(model_ft, dataset_test))

782it [00:17, 44.21it/s]                         
  0%|          | 0/781 [00:00<?, ?it/s]

(0.9695024600454197, 0.6608)


782it [00:17, 44.51it/s]                         
  0%|          | 0/781 [00:00<?, ?it/s]

(0.8523989029371055, 0.7057)


782it [00:17, 44.49it/s]                         
  0%|          | 0/781 [00:00<?, ?it/s]

(0.8754185764652909, 0.6975)


782it [00:17, 44.35it/s]                         
  0%|          | 0/781 [00:00<?, ?it/s]

(0.9700069533791512, 0.7015)


782it [00:17, 44.42it/s]                         
  0%|          | 0/781 [00:00<?, ?it/s]

(1.0234562838153474, 0.7052)


782it [00:17, 44.49it/s]                         
  0%|          | 0/781 [00:00<?, ?it/s]

(1.067143757252177, 0.7037)


782it [00:17, 44.60it/s]                         
  0%|          | 0/781 [00:00<?, ?it/s]

(1.188450474268312, 0.7046)


782it [00:17, 44.47it/s]                         
  0%|          | 0/781 [00:00<?, ?it/s]

(1.193146335471208, 0.7071)


782it [00:17, 44.63it/s]                         
  0%|          | 0/781 [00:00<?, ?it/s]

(1.5332773999803384, 0.7022)


782it [00:17, 44.60it/s]                         


(1.4204845117155913, 0.6995)


In [15]:
model_pruned_ = unit_prune(model_orig, 3)
model_ft_ = fine_tuning(model_pruned_, 1)
print(test(model_ft_, dataset_test))

for i in range(29):
    model_pruned_ = unit_prune(model_ft_, 3//np.power(0.97, i+1))
    model_ft_ = fine_tuning(model_pruned_, 1)
    print(test(model_ft_, dataset_test))



782it [00:29, 26.17it/s]                         


(0.8037976108159229, 0.7386)


782it [00:29, 26.67it/s]                         


(0.7309401130220693, 0.765)


782it [00:29, 26.76it/s]                         


(0.7650851362450107, 0.7527)


782it [00:28, 27.01it/s]                         


(0.810775093971544, 0.75)


782it [00:28, 27.73it/s]                         


(0.7606347003939805, 0.7573)


782it [00:27, 27.96it/s]                         


(0.9036116677864342, 0.7448)


782it [00:27, 28.70it/s]                         


(0.8937102776424141, 0.7519)


782it [00:27, 26.20it/s]                         


(0.9062892209952045, 0.7464)


782it [00:26, 25.84it/s]                         


(0.8908408950468537, 0.751)


782it [00:26, 29.76it/s]                         


(0.9724811845144649, 0.7432)


782it [00:25, 31.22it/s]                         


(0.9642699296307412, 0.7526)


782it [00:24, 31.53it/s]                         


(0.9426583099137446, 0.7573)


782it [00:24, 32.07it/s]                         


(0.9314216014685904, 0.7494)


782it [00:24, 32.50it/s]                         


(0.8936986051926947, 0.7502)


782it [00:23, 33.26it/s]                         


(0.8492751153791027, 0.7503)


782it [00:22, 34.46it/s]                         


(0.7932155722645438, 0.7581)


782it [00:22, 35.18it/s]                         


(0.8141366095299933, 0.752)


782it [00:21, 36.07it/s]                         


(0.8129005576394925, 0.7467)


782it [00:21, 36.76it/s]                         


(0.7694999588902589, 0.7538)


782it [00:20, 37.95it/s]                         


(0.7996066413867245, 0.7392)


782it [00:19, 39.24it/s]                         


(0.7565303416388809, 0.7499)


782it [00:19, 36.09it/s]                         


(0.7573758348537858, 0.7496)


782it [00:19, 41.14it/s]                         


(0.8227326215072802, 0.7247)


782it [00:18, 42.05it/s]                         


(0.8016463351097836, 0.7272)


782it [00:18, 42.65it/s]                         


(0.8007880415126776, 0.7333)


782it [00:18, 42.43it/s]                         


(0.8253344411303283, 0.721)


782it [00:17, 43.69it/s]                         


(0.8255365149230715, 0.7226)


782it [00:18, 44.94it/s]                         


(0.816090854679703, 0.7247)


782it [00:17, 44.99it/s]                         


(0.8061904159321147, 0.7229)


782it [00:17, 45.53it/s]                         


(0.826677365667501, 0.7202)


In [29]:
for i in range(20):
    print(model_ft_5.trainable_weights[i].shape)

(3, 3, 3, 19)
(19,)
(3, 3, 19, 19)
(19,)
(3, 3, 19, 41)
(41,)
(3, 3, 41, 41)
(41,)
(3, 3, 41, 84)
(84,)
(3, 3, 84, 84)
(84,)
(1344, 4096)
(4096,)
(4096, 4096)
(4096,)
(4096, 10)


IndexError: list index out of range

In [9]:
print(test(model_orig, dataset_test))

(1.6923481115869656, 0.735)
