In [8]:
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Activation, Flatten
from tensorflow.keras import optimizers

import numpy as np
from tensorflow.keras import backend as K
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.utils import to_categorical
from keras.constraints import Constraint

print(tf.config.list_physical_devices("GPU"))


[]


In [9]:
class SReLU(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        param_shape = (1,) * (len(input_shape) - 1) + (input_shape[-1],)

        self.t_l = self.add_weight(
            name="t_l", shape=param_shape,
            initializer=tf.keras.initializers.Zeros(),
            trainable=True
        )
        self.a_l = self.add_weight(
            name="a_l", shape=param_shape,
            initializer=tf.keras.initializers.Constant(0.2),
            trainable=True
        )
        self.t_r = self.add_weight(
            name="t_r", shape=param_shape,
            initializer=tf.keras.initializers.Constant(1.0),
            trainable=True
        )
        self.a_r = self.add_weight(
            name="a_r", shape=param_shape,
            initializer=tf.keras.initializers.Constant(0.2),
            trainable=True
        )

    def call(self, x):
        left = self.t_l + self.a_l * (x - self.t_l)
        right = self.t_r + self.a_r * (x - self.t_r)
        return tf.where(
            x <= self.t_l,
            left,
            tf.where(x >= self.t_r, right, x)
        )


In [10]:
class MaskWeights(Constraint):
    def __init__(self, mask_var: tf.Variable):
        super().__init__()
        self.mask_var = mask_var  # tf.Variable, updated in-place

    def __call__(self, w):
        return w * tf.cast(self.mask_var, w.dtype)

    def get_config(self):
        return {}


In [11]:
def createWeightsMask(epsilon,noRows, noCols):
    # generate an Erdos Renyi sparse weights mask
    mask_weights = np.random.rand(noRows, noCols)
    prob = 1 - (epsilon * (noRows + noCols)) / (noRows * noCols)  # normal tp have 8x connections
    mask_weights[mask_weights < prob] = 0
    mask_weights[mask_weights >= prob] = 1
    noParameters = np.sum(mask_weights)
    print ("Create Sparse Matrix: No parameters, NoRows, NoCols ",noParameters,noRows,noCols)
    print(mask_weights)
    return [noParameters,mask_weights]

In [12]:
class SET_MLP_CIFAR10:
    def __init__(self):
        # set model parameters
        self.epsilon = 20 # control the sparsity level as discussed in the paper
        self.zeta = 0.3 # the fraction of the weights removed
        self.batch_size = 150 # batch size
        self.maxepoches = 1000 # number of epochs
        self.learning_rate = 0.01 # SGD learning rate
        self.num_classes = 10 # number of classes
        self.momentum=0.9 # SGD momentum

        # generate an Erdos Renyi sparse weights mask for each layer
        [self.noPar1, self.wm1] = createWeightsMask(self.epsilon,32 * 32 *3, 4000)
        [self.noPar2, self.wm2] = createWeightsMask(self.epsilon,4000, 1000)
        [self.noPar3, self.wm3] = createWeightsMask(self.epsilon,1000, 4000)

        # initialize layers weights
        self.w1 = None
        self.w2 = None
        self.w3 = None
        self.w4 = None

        self.wm1_var = tf.Variable(self.wm1.astype("float32"), trainable=False)
        self.wm2_var = tf.Variable(self.wm2.astype("float32"), trainable=False)
        self.wm3_var = tf.Variable(self.wm3.astype("float32"), trainable=False)

        # initialize weights for SReLu activation function
        self.wSRelu1 = None
        self.wSRelu2 = None
        self.wSRelu3 = None

        # create a SET-MLP model
        self.create_model()

        # train the SET-MLP model
        self.train()


    def create_model(self):

        # create a SET-MLP model for CIFAR10 with 3 hidden layers
        self.model = Sequential()
        self.model.add(Flatten(input_shape=(32, 32, 3)))
    
        self.model.add(Dense(4000, name="sparse_1", kernel_constraint=MaskWeights(self.wm1_var)))
        self.model.add(SReLU(name="srelu1"))
        self.model.add(Dropout(0.3))
    
        self.model.add(Dense(1000, name="sparse_2", kernel_constraint=MaskWeights(self.wm2_var)))        
        self.model.add(SReLU(name="srelu2"))
        self.model.add(Dropout(0.3))
    
        self.model.add(Dense(4000, name="sparse_3", kernel_constraint=MaskWeights(self.wm3_var)))
        self.model.add(SReLU(name="srelu3"))
        self.model.add(Dropout(0.3))
    
        self.model.add(Dense(self.num_classes, name="dense_4"))
        self.model.add(Activation("softmax"))

    def rewireMask(self, weights, noWeights):
        # rewire weight matrix
        values = np.sort(weights.ravel())
        zero_idx = np.where(values == 0)[0]
        firstZeroPos = zero_idx[0]
        lastZeroPos = zero_idx[-1]
        largestNegative = values[int((1-self.zeta) * firstZeroPos)]
        smallestPositive = values[int(min(values.shape[0] - 1, lastZeroPos +self.zeta * (values.shape[0] - lastZeroPos)))]
        rewiredWeights = weights.copy();
        rewiredWeights[rewiredWeights > smallestPositive] = 1;
        rewiredWeights[rewiredWeights < largestNegative] = 1;
        rewiredWeights[rewiredWeights != 1] = 0;
        weightMaskCore = rewiredWeights.copy()

        # add zeta random weights
        nrAdd = 0
        noRewires = noWeights - np.sum(rewiredWeights)
        zeros = np.argwhere(rewiredWeights == 0)
        need = int(noRewires)
        if need > 0:
            pick = zeros[np.random.choice(len(zeros), size=need, replace=False)]
            rewiredWeights[pick[:, 0], pick[:, 1]] = 1

        return [rewiredWeights, weightMaskCore]

    def weightsEvolution(self):
        # this represents the core of the SET procedure. It removes the weights closest to zero in each layer and add new random weights
        self.w1 = self.model.get_layer("sparse_1").get_weights()
        self.w2 = self.model.get_layer("sparse_2").get_weights()
        self.w3 = self.model.get_layer("sparse_3").get_weights()
        self.w4 = self.model.get_layer("dense_4").get_weights()

        self.wSRelu1 = self.model.get_layer("srelu1").get_weights()
        self.wSRelu2 = self.model.get_layer("srelu2").get_weights()
        self.wSRelu3 = self.model.get_layer("srelu3").get_weights()

        [self.wm1, self.wm1Core] = self.rewireMask(self.w1[0], self.noPar1)
        [self.wm2, self.wm2Core] = self.rewireMask(self.w2[0], self.noPar2)
        [self.wm3, self.wm3Core] = self.rewireMask(self.w3[0], self.noPar3)

        self.w1[0] = self.w1[0] * self.wm1Core
        self.w2[0] = self.w2[0] * self.wm2Core
        self.w3[0] = self.w3[0] * self.wm3Core

        # --- push pruned weights back into the model ---
        l1 = self.model.get_layer("sparse_1")
        l2 = self.model.get_layer("sparse_2")
        l3 = self.model.get_layer("sparse_3")
        
        l1.set_weights(self.w1)
        l2.set_weights(self.w2)
        l3.set_weights(self.w3)
        
        # --- update constraint masks (this is the critical part) ---
        self.wm1_var.assign(self.wm1.astype("float32"))
        self.wm2_var.assign(self.wm2.astype("float32"))
        self.wm3_var.assign(self.wm3.astype("float32"))
        
        # Optional but recommended: enforce immediately right now (not just after next optimizer step)
        l1.kernel.assign(l1.kernel_constraint(l1.kernel))
        l2.kernel.assign(l2.kernel_constraint(l2.kernel))
        l3.kernel.assign(l3.kernel_constraint(l3.kernel))

    def train(self):

        # read CIFAR10 data
        [x_train,x_test,y_train,y_test]=self.read_data()

        #data augmentation
        datagen = ImageDataGenerator(
            featurewise_center=False,  # set input mean to 0 over the dataset
            samplewise_center=False,  # set each sample mean to 0
            featurewise_std_normalization=False,  # divide inputs by std of the dataset
            samplewise_std_normalization=False,  # divide each input by its std
            zca_whitening=False,  # apply ZCA whitening
            rotation_range=10,  # randomly rotate images in the range (degrees, 0 to 180)
            width_shift_range=0.1,  # randomly shift images horizontally (fraction of total width)
            height_shift_range=0.1,  # randomly shift images vertically (fraction of total height)
            horizontal_flip=True,  # randomly flip images
            vertical_flip=False)  # randomly flip images
        datagen.fit(x_train)

        self.model.summary()

        # training process in a for loop
        self.accuracies_per_epoch = []
        patience = 10
        best_val = -np.inf
        epochs_no_improve = 0

        sgd = optimizers.SGD(learning_rate=self.learning_rate, momentum=self.momentum)
        self.model.compile(loss="categorical_crossentropy", optimizer=sgd, metrics=["accuracy"], jit_compile=True)
        
        for epoch in range(self.maxepoches):
            historytemp = self.model.fit(
                datagen.flow(x_train, y_train, batch_size=self.batch_size),
                steps_per_epoch=x_train.shape[0] // self.batch_size,
                initial_epoch=epoch,
                epochs=epoch + 1,
                validation_data=(x_test, y_test),
                verbose=1,
            )
            val_acc = historytemp.history["val_accuracy"][-1]
            self.accuracies_per_epoch.append(val_acc)

            if val_acc > best_val:
                best_val = val_acc
                epochs_no_improve = 0
            else:
                epochs_no_improve += 1
        
            if epochs_no_improve >= patience:
                print(f"Stopping early at epoch {epoch}")
                break
            
            self.weightsEvolution()  # updates masks/weights for next epoch

        self.accuracies_per_epoch=np.asarray(self.accuracies_per_epoch)

    def read_data(self):

        #read CIFAR10 data
        (x_train, y_train), (x_test, y_test) = cifar10.load_data()
        y_train = to_categorical(y_train, self.num_classes)
        y_test = to_categorical(y_test, self.num_classes)
        x_train = x_train.astype('float32')
        x_test = x_test.astype('float32')

        #normalize data
        xTrainMean = np.mean(x_train, axis=0)
        xTtrainStd = np.std(x_train, axis=0)
        x_train = (x_train - xTrainMean) / xTtrainStd
        x_test = (x_test - xTrainMean) / xTtrainStd

        return [x_train, x_test, y_train, y_test]

In [14]:
class SET_IMPORTANCE_MLP_CIFAR10:
    def __init__(self):
        # set model parameters
        self.epsilon = 20 # control the sparsity level as discussed in the paper
        self.zeta = 0.3 # the fraction of the weights removed
        self.batch_size = 150 # batch size
        self.maxepoches = 1000 # number of epochs
        self.learning_rate = 0.01 # SGD learning rate
        self.num_classes = 10 # number of classes
        self.momentum=0.9 # SGD momentum

        # generate an Erdos Renyi sparse weights mask for each layer
        [self.noPar1, self.wm1] = createWeightsMask(self.epsilon,32 * 32 *3, 4000)
        [self.noPar2, self.wm2] = createWeightsMask(self.epsilon,4000, 1000)
        [self.noPar3, self.wm3] = createWeightsMask(self.epsilon,1000, 4000)

        # initialize layers weights
        self.w1 = None
        self.w2 = None
        self.w3 = None
        self.w4 = None

        self.wm1_var = tf.Variable(self.wm1.astype("float32"), trainable=False)
        self.wm2_var = tf.Variable(self.wm2.astype("float32"), trainable=False)
        self.wm3_var = tf.Variable(self.wm3.astype("float32"), trainable=False)

        # initialize weights for SReLu activation function
        self.wSRelu1 = None
        self.wSRelu2 = None
        self.wSRelu3 = None

        # create a SET-MLP model
        self.create_model()

        # train the SET-MLP model
        self.train()


    def create_model(self):

        # create a SET-MLP model for CIFAR10 with 3 hidden layers
        self.model = Sequential()
        self.model.add(Flatten(input_shape=(32, 32, 3)))
    
        self.model.add(Dense(4000, name="sparse_1", kernel_constraint=MaskWeights(self.wm1_var)))
        self.model.add(SReLU(name="srelu1"))
        self.model.add(Dropout(0.3))
    
        self.model.add(Dense(1000, name="sparse_2", kernel_constraint=MaskWeights(self.wm2_var)))        
        self.model.add(SReLU(name="srelu2"))
        self.model.add(Dropout(0.3))
    
        self.model.add(Dense(4000, name="sparse_3", kernel_constraint=MaskWeights(self.wm3_var)))
        self.model.add(SReLU(name="srelu3"))
        self.model.add(Dropout(0.3))
    
        self.model.add(Dense(self.num_classes, name="dense_4"))
        self.model.add(Activation("softmax"))

    def rewireMask(self, weights, noWeights, I_out, I_in, alpha=0.9, eps=1e-12):
        # rewire weight matrix
        values = np.sort(weights.ravel())
        zero_idx = np.where(values == 0)[0]
        firstZeroPos = zero_idx[0]
        lastZeroPos = zero_idx[-1]
        
        largestNegative = values[int((1-self.zeta) * firstZeroPos)]
        smallestPositive = values[int(min(values.shape[0] - 1, lastZeroPos +self.zeta * (values.shape[0] - lastZeroPos)))]
        
        rewiredWeights = weights.copy()
        rewiredWeights[rewiredWeights > smallestPositive] = 1
        rewiredWeights[rewiredWeights < largestNegative] = 1
        rewiredWeights[rewiredWeights != 1] = 0
        weightMaskCore = rewiredWeights.copy()

        noRewires = noWeights - np.sum(rewiredWeights)
        need = int(noRewires)
        
        if need <= 0:
            return [rewiredWeights, weightMaskCore]

        zeros = np.argwhere(rewiredWeights == 0)  # (N_zero, 2) with (i,j)
        N_zero = zeros.shape[0]
        need = min(need, N_zero)

        # Ensures non-negative importances
        I_out = np.asarray(I_out, dtype=np.float64)
        I_in  = np.asarray(I_in,  dtype=np.float64)
        I_out = np.clip(I_out, 0.0, None)
        I_in  = np.clip(I_in,  0.0, None)

        i = zeros[:, 0]
        j = zeros[:, 1]

        # importance product term: I_i * I_j
        imp_prod = I_out[i] * I_in[j]  # (N_zero,)

        # S_ij = alpha*(I_i I_j) + (1-alpha)*(1/N_zero)
        scores = alpha * imp_prod + (1.0 - alpha) * (1.0 / max(N_zero, 1))

        ssum = scores.sum()
        if not np.isfinite(ssum) or ssum <= eps:
            # fallback to uniform if scores are all ~0
            probs = None
        else:
            probs = scores / ssum
    
        chosen = np.random.choice(N_zero, size=need, replace=False, p=probs)
        pick = zeros[chosen]
        rewiredWeights[pick[:, 0], pick[:, 1]] = 1

        return [rewiredWeights, weightMaskCore]

    def neuron_importance_from_weights(W, mask=None, eps=1e-12):
        A = np.abs(W)
        if mask is not None:
            A = A * mask
    
        I_out = A.sum(axis=1)  # rows
        I_in  = A.sum(axis=0)  # cols
    
        I_out = I_out / (I_out.mean() + eps)
        I_in  = I_in  / (I_in.mean()  + eps)
        return I_out, I_in
    
    def weightsEvolution(self):
        # this represents the core of the SET procedure. It removes the weights closest to zero in each layer and add new random weights
        self.w1 = self.model.get_layer("sparse_1").get_weights()
        self.w2 = self.model.get_layer("sparse_2").get_weights()
        self.w3 = self.model.get_layer("sparse_3").get_weights()
        self.w4 = self.model.get_layer("dense_4").get_weights()

        self.wSRelu1 = self.model.get_layer("srelu1").get_weights()
        self.wSRelu2 = self.model.get_layer("srelu2").get_weights()
        self.wSRelu3 = self.model.get_layer("srelu3").get_weights()
        
        [self.wm1, self.wm1Core] = self.rewireMask(self.w1[0], self.noPar1, I_out, I_in)
        [self.wm2, self.wm2Core] = self.rewireMask(self.w2[0], self.noPar2, I_out, I_in)
        [self.wm3, self.wm3Core] = self.rewireMask(self.w3[0], self.noPar3, I_out, I_in)

        self.w1[0] = self.w1[0] * self.wm1Core
        self.w2[0] = self.w2[0] * self.wm2Core
        self.w3[0] = self.w3[0] * self.wm3Core

        # --- push pruned weights back into the model ---
        l1 = self.model.get_layer("sparse_1")
        l2 = self.model.get_layer("sparse_2")
        l3 = self.model.get_layer("sparse_3")
        
        l1.set_weights(self.w1)
        l2.set_weights(self.w2)
        l3.set_weights(self.w3)
        
        # --- update constraint masks ---
        self.wm1_var.assign(self.wm1.astype("float32"))
        self.wm2_var.assign(self.wm2.astype("float32"))
        self.wm3_var.assign(self.wm3.astype("float32"))
        
        l1.kernel.assign(l1.kernel_constraint(l1.kernel))
        l2.kernel.assign(l2.kernel_constraint(l2.kernel))
        l3.kernel.assign(l3.kernel_constraint(l3.kernel))
    
    def train(self):

        # read CIFAR10 data
        [x_train,x_test,y_train,y_test]=self.read_data()

        #data augmentation
        datagen = ImageDataGenerator(
            featurewise_center=False,  # set input mean to 0 over the dataset
            samplewise_center=False,  # set each sample mean to 0
            featurewise_std_normalization=False,  # divide inputs by std of the dataset
            samplewise_std_normalization=False,  # divide each input by its std
            zca_whitening=False,  # apply ZCA whitening
            rotation_range=10,  # randomly rotate images in the range (degrees, 0 to 180)
            width_shift_range=0.1,  # randomly shift images horizontally (fraction of total width)
            height_shift_range=0.1,  # randomly shift images vertically (fraction of total height)
            horizontal_flip=True,  # randomly flip images
            vertical_flip=False)  # randomly flip images
        datagen.fit(x_train)

        self.model.summary()

        # training process in a for loop
        self.accuracies_per_epoch = []
        patience = 10
        best_val = -np.inf
        epochs_no_improve = 0

        sgd = optimizers.SGD(learning_rate=self.learning_rate, momentum=self.momentum)
        self.model.compile(loss="categorical_crossentropy", optimizer=sgd, metrics=["accuracy"], jit_compile=True)
        
        for epoch in range(self.maxepoches):
            historytemp = self.model.fit(
                datagen.flow(x_train, y_train, batch_size=self.batch_size),
                steps_per_epoch=x_train.shape[0] // self.batch_size,
                initial_epoch=epoch,
                epochs=epoch + 1,
                validation_data=(x_test, y_test),
                verbose=1,
            )
            val_acc = historytemp.history["val_accuracy"][-1]
            self.accuracies_per_epoch.append(val_acc)

            if val_acc > best_val:
                best_val = val_acc
                epochs_no_improve = 0
            else:
                epochs_no_improve += 1
        
            if epochs_no_improve >= patience:
                print(f"Stopping early at epoch {epoch}")
                break

            
            
            self.weightsEvolution()  # updates masks/weights for next epoch

        self.accuracies_per_epoch=np.asarray(self.accuracies_per_epoch)

    def read_data(self):

        #read CIFAR10 data
        (x_train, y_train), (x_test, y_test) = cifar10.load_data()
        y_train = to_categorical(y_train, self.num_classes)
        y_test = to_categorical(y_test, self.num_classes)
        x_train = x_train.astype('float32')
        x_test = x_test.astype('float32')

        #normalize data
        xTrainMean = np.mean(x_train, axis=0)
        xTtrainStd = np.std(x_train, axis=0)
        x_train = (x_train - xTrainMean) / xTtrainStd
        x_test = (x_test - xTrainMean) / xTtrainStd

        return [x_train, x_test, y_train, y_test]

In [15]:
model=SET_MLP_CIFAR10()

# save accuracies over for all training epochs
# in "results" folder you can find the output of running this file
np.savetxt("results/set_mlp_srelu_sgd_cifar10_acc.txt", np.asarray(model.accuracies_per_epoch))

Create Sparse Matrix: No parameters, NoRows, NoCols  141852.0 3072 4000
[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]
Create Sparse Matrix: No parameters, NoRows, NoCols  100125.0 4000 1000
[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]
Create Sparse Matrix: No parameters, NoRows, NoCols  100182.0 1000 4000
[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]


  super().__init__(**kwargs)





[1m333/333[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m46s[0m 130ms/step - accuracy: 0.1111 - loss: 2.3011 - val_accuracy: 0.1469 - val_loss: 2.2985
Epoch 2/2
[1m122/333[0m [32m━━━━━━━[0m[37m━━━━━━━━━━━━━[0m [1m23s[0m 113ms/step - accuracy: 0.1287 - loss: 2.2988

KeyboardInterrupt: 