In [1]:
import tensorflow as tf
from tensorflow.keras import layers, models
import numpy as np

In [2]:
# Load the Fashion-MNIST train/test splits through the Keras datasets API.
# NOTE: despite the `mnist_*` variable names (kept because later cells use them),
# this is Fashion-MNIST, not the handwritten-digit MNIST dataset.
fashion_mnist_splits = tf.keras.datasets.fashion_mnist.load_data()
(mnist_train_images, mnist_train_labels), (mnist_test_images, mnist_test_labels) = fashion_mnist_splits

In [3]:
def _to_unit_nhwc(images):
    """Reshape to (N, 28, 28, 1) and divide by 255 as float32.

    Assumes raw 0-255 pixel intensities (TODO confirm dtype of the loaded arrays),
    so the result lies in [0, 1] with a trailing single-channel axis for Conv layers.
    """
    with_channel = images.reshape(-1, 28, 28, 1)
    return with_channel.astype('float32') / 255


# NOTE: this cell overwrites its inputs; running it twice would divide by 255 again.
mnist_train_images = _to_unit_nhwc(mnist_train_images)
mnist_test_images = _to_unit_nhwc(mnist_test_images)


In [4]:
class SparseConv2D(layers.Layer):
    """2-D convolution that randomly drops whole output filters while training.

    Training: each of the ``filters`` output channels is kept independently with
    probability ``p``. A fresh random mask is drawn on every call, and dropped
    channels have their kernel slice zeroed.
    Inference: the kernel is instead scaled by ``p``, which matches the expected
    value of the training-time mask.
    The convolution always uses stride 1 and 'SAME' padding, followed by a bias add.

    Args:
        filters: Number of output channels.
        kernel_size: Int for a square kernel, or a ``(height, width)`` pair.
        p: Keep probability in [0, 1]. It is stored in a non-trainable variable so
            ``update_p`` can change it between epochs without rebuilding the layer.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).

    Raises:
        ValueError: If ``p`` is outside [0, 1].
    """

    def __init__(self, filters, kernel_size, p, **kwargs):
        super(SparseConv2D, self).__init__(**kwargs)
        self.filters = filters
        self.kernel_size = kernel_size
        # Keep probability as a float variable so it can be updated in place.
        self.p = tf.Variable(self._validate_p(p), trainable=False)
        # Counter variable. It is currently never updated, but it is kept because
        # it is a tracked layer weight: removing it would change the saved weight
        # layout and the parameter count reported by summary().
        self.counter = tf.Variable(0, trainable=False, dtype=tf.int32)

    @staticmethod
    def _validate_p(p):
        """Convert ``p`` to a Python float and check that it is a probability."""
        p = float(p)
        if not 0.0 <= p <= 1.0:  # also rejects NaN
            raise ValueError(f"p must be in [0, 1], got {p}")
        return p

    def build(self, input_shape):
        """Create the kernel (kh, kw, in_channels, filters) and the bias (filters,)."""
        if isinstance(self.kernel_size, int):
            kernel_h = kernel_w = self.kernel_size
        else:
            kernel_h, kernel_w = self.kernel_size
        self.kernel = self.add_weight(name='kernel',
                                      shape=(kernel_h, kernel_w, input_shape[-1], self.filters),
                                      initializer='glorot_uniform',
                                      trainable=True)
        self.bias = self.add_weight(name='bias',
                                    shape=(self.filters,),
                                    initializer='zeros',
                                    trainable=True)

    @tf.function
    def call(self, inputs, training=None):
        """Convolve with a randomly masked (training) or p-scaled (inference) kernel."""
        if training:
            # Keep each output filter with probability p. The mask broadcasts over
            # the kernel's (height, width, in_channels) axes.
            mask = tf.random.uniform(shape=(self.filters,), minval=0, maxval=1)
            mask = tf.cast(mask < self.p, dtype=tf.float32)
            mask = tf.reshape(mask, [1, 1, 1, self.filters])
        else:
            # Deterministic inference: scale every filter by the keep probability.
            mask = tf.ones([1, 1, 1, self.filters], dtype=tf.float32) * self.p

        sparse_kernel = self.kernel * mask
        conv = tf.nn.conv2d(inputs, sparse_kernel, strides=[1, 1, 1, 1], padding='SAME')
        return tf.nn.bias_add(conv, self.bias)

    def update_p(self, new_p):
        """Set the keep probability used by subsequent calls.

        This runs eagerly on purpose. Under ``@tf.function``, every new Python
        float value triggered a retrace, and ``float()`` cannot convert a symbolic
        tensor during tracing. A plain variable ``assign`` needs no compilation.

        Raises:
            ValueError: If ``new_p`` is outside [0, 1].
        """
        self.p.assign(self._validate_p(new_p))

    def get_config(self):
        """Return the constructor arguments so the layer can be saved and re-created.

        ``p`` is the variable's *current* value, so a re-created layer resumes
        from the latest scheduled keep probability.
        """
        config = super(SparseConv2D, self).get_config()
        config.update({
            'filters': self.filters,
            'kernel_size': self.kernel_size,
            'p': float(tf.keras.backend.get_value(self.p)),
        })
        return config

In [5]:
# Functional model: three SparseConv2D -> ReLU stages (the first two followed by
# 2x2 max-pooling) feeding a dense head with a 10-way softmax. The explicit conv
# layer names are the keys the sparsity schedule uses to look layers up.
conv_stages = (
    # (layer name, filters, followed by max-pooling?)
    ('sparse_conv2d_1', 32, True),
    ('sparse_conv2d_2', 64, True),
    ('sparse_conv2d_3', 128, False),
)

inputs = tf.keras.Input(shape=(28, 28, 1))  # 28x28 single-channel images
x = inputs
for conv_name, conv_filters, pool_after in conv_stages:
    x = SparseConv2D(filters=conv_filters, kernel_size=3, p=1, name=conv_name)(x)
    x = layers.Activation('relu')(x)
    if pool_after:
        x = layers.MaxPooling2D(pool_size=(2, 2))(x)
x = layers.Flatten()(x)

for hidden_units in (128, 64):
    x = layers.Dense(hidden_units, activation='relu')(x)
outputs = layers.Dense(10, activation='softmax')(x)  # one probability per class

model = models.Model(inputs=inputs, outputs=outputs)
model.summary()

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 28, 28, 1)]       0         
_________________________________________________________________
sparse_conv2d_1 (SparseConv2 (None, 28, 28, 32)        322       
_________________________________________________________________
activation (Activation)      (None, 28, 28, 32)        0         
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 14, 14, 32)        0         
_________________________________________________________________
sparse_conv2d_2 (SparseConv2 (None, 14, 14, 64)        18498     
_________________________________________________________________
activation_1 (Activation)    (None, 14, 14, 64)        0         
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 7, 7, 64)          0     

In [6]:
class UpdatePSparsity(tf.keras.callbacks.Callback):
    """Apply a per-epoch keep-probability schedule to ``SparseConv2D`` layers.

    Args:
        model: Model containing the layers named in ``sparsity_schedule``.
        sparsity_schedule: Mapping ``layer name -> list of p values``. Entry ``i``
            is the value used while epoch ``i`` (0-based) runs, for both training
            and that epoch's validation. Epochs past the end of a list reuse its
            last value. An empty list leaves that layer's ``p`` unchanged.
    """

    def __init__(self, model, sparsity_schedule):
        super(UpdatePSparsity, self).__init__()
        # set_model works whether ``model`` is a plain attribute (tf.keras 2.x) or
        # a read-only property (Keras 3). fit() later calls it with the same model.
        self.set_model(model)
        self.sparsity_schedule = sparsity_schedule

    def on_epoch_begin(self, epoch, logs=None):
        """Set each scheduled layer's ``p`` before the epoch starts.

        The update used to happen in ``on_epoch_end``, which shifted the schedule
        by one epoch: entry ``i`` only took effect in epoch ``i + 1``, and epoch 0
        ran with the layer's constructor value.
        """
        for layer_name, p_values in self.sparsity_schedule.items():
            if not p_values:
                continue  # nothing scheduled for this layer
            layer = self.model.get_layer(name=layer_name)
            # Clamp to the last entry for epochs beyond the predefined ones.
            layer.update_p(p_values[min(epoch, len(p_values) - 1)])

# Keep probability p per epoch for each SparseConv2D layer (keys are layer names).
# The first two layers stay dense (p = 1.0). The third is thinned from 0.9 down to
# 0.3 in steps of 0.1, and the callback reuses the last value once a list runs out.
sparsity_schedule = dict(
    sparse_conv2d_1=[1.0],
    sparse_conv2d_2=[1.0],
    sparse_conv2d_3=[0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3],
)

In [7]:
# Train with Adam on integer class labels, updating SparseConv2D keep
# probabilities from the schedule via the callback.
# NOTE(review): the test split doubles as validation data here, so the reported
# val metrics are not a held-out estimate if they are used for model selection.
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)
sparsity_callback = UpdatePSparsity(model, sparsity_schedule)
model.fit(
    mnist_train_images,
    mnist_train_labels,
    batch_size=128,
    epochs=40,
    validation_data=(mnist_test_images, mnist_test_labels),
    callbacks=[sparsity_callback],
)

Epoch 1/40
Epoch 2/40
Epoch 3/40
Epoch 4/40
Epoch 5/40
Epoch 6/40
Epoch 7/40
Epoch 8/40
Epoch 9/40
Epoch 10/40
Epoch 11/40
Epoch 12/40
Epoch 13/40
Epoch 14/40
Epoch 15/40
Epoch 16/40
Epoch 17/40
Epoch 18/40
Epoch 19/40
Epoch 20/40
Epoch 21/40
Epoch 22/40
Epoch 23/40
Epoch 24/40
Epoch 25/40
Epoch 26/40
Epoch 27/40
Epoch 28/40
Epoch 29/40
Epoch 30/40
Epoch 31/40
Epoch 32/40
Epoch 33/40
Epoch 34/40
Epoch 35/40
Epoch 36/40
Epoch 37/40
Epoch 38/40
Epoch 39/40
Epoch 40/40


<tensorflow.python.keras.callbacks.History at 0x17783fbcf08>