In [2]:
# Author: Decebal Constantin Mocanu et al. (legacy SET-MLP CIFAR-10)
# TensorFlow 2 / Keras 3 + Neuron-level EMA growth bias
#
# Method change:
# - Pruning: same as legacy SET (weight-based thresholds near 0)
# - Growth: biased toward neurons that are persistently active
#          using EMA of mean absolute activations (neuron-level importance)

import os
import numpy as np
import tensorflow as tf

from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Activation, Flatten
from tensorflow.keras import optimizers
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.utils import to_categorical

In [4]:
# ----------------------------
# SReLU
# ----------------------------
class SReLU(tf.keras.layers.Layer):
    """S-shaped rectified linear unit with four trainable scalar parameters.

    The activation is piecewise linear:

    * ``x < tl``  ->  ``tl + al * (x - tl)``
    * ``x > tr``  ->  ``tr + ar * (x - tr)``
    * otherwise   ->  ``x``

    Every parameter has shape ``()``, i.e. one value shared by all units
    of the layer.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        """Create the weights ``tl``, ``tr``, ``al`` and ``ar`` (in that order)."""
        # (weight name, initial value): thresholds first, then slopes.
        initial_values = (("tl", -1.0), ("tr", 1.0), ("al", 0.2), ("ar", 0.2))
        for weight_name, start in initial_values:
            variable = self.add_weight(
                name=weight_name,
                shape=(),
                initializer=tf.keras.initializers.Constant(start),
            )
            setattr(self, weight_name, variable)
        super().build(input_shape)

    def call(self, x):
        """Apply the three-segment piecewise-linear activation to ``x``."""
        left_segment = self.tl + self.al * (x - self.tl)
        right_segment = self.tr + self.ar * (x - self.tr)
        # Resolve the right/middle choice first, then override on the left side.
        middle_or_right = tf.where(x > self.tr, right_segment, x)
        return tf.where(x < self.tl, left_segment, middle_or_right)

In [5]:
# ----------------------------
# Utility helpers (legacy logic)
# ----------------------------
def find_first_pos(array, value):
    """Return the index of the first element of ``array`` closest to ``value``.

    Ties are resolved in favour of the lowest index.
    """
    distances = np.abs(array - value)
    return distances.argmin()


def find_last_pos(array, value):
    """Return one past the index of the last element closest to ``value``.

    The array is scanned from the end; the result is
    ``len(array) - offset_from_end``, so for a closest element at index ``k``
    the function returns ``k + 1`` (callers clip it to a valid index).
    """
    distances = np.abs(array - value)
    offset_from_end = distances[::-1].argmin()
    return array.shape[0] - offset_from_end


def createWeightsMask(epsilon, noRows, noCols):
    """
    Generate an Erdos-Renyi sparse mask.

    Each entry is kept (1.0) when its uniform draw is at least
    ``1 - epsilon * (noRows + noCols) / (noRows * noCols)`` and dropped (0.0)
    otherwise.

    Returns:
        (noParameters, mask_weights): the number of kept entries and the
        float32 mask of shape ``(noRows, noCols)``.
    """
    uniform_draws = np.random.rand(noRows, noCols).astype(np.float32)
    # Probability that a given connection is absent.
    prob = 1 - (epsilon * (noRows + noCols)) / (noRows * noCols)
    mask_weights = np.where(uniform_draws < prob, 0.0, 1.0).astype(np.float32)
    noParameters = int(np.sum(mask_weights))
    print("Create Sparse Matrix: No parameters, NoRows, NoCols ", noParameters, noRows, noCols)
    return noParameters, mask_weights


def safe_probs(scores: np.ndarray) -> np.ndarray:
    """
    Turn a score vector into a probability distribution.

    Negative scores are clipped to zero before normalising. When the clipped
    total is non-finite or not larger than 1e-12, a uniform distribution of
    the same shape is returned instead.
    """
    clipped = np.maximum(np.asarray(scores, dtype=np.float64), 0.0)
    total = float(clipped.sum())
    if np.isfinite(total) and total > 1e-12:
        return clipped / total
    # Degenerate scores: fall back to equal weight for every entry.
    return np.ones_like(clipped, dtype=np.float64) / float(clipped.size)

In [6]:
# ----------------------------
# Batch-tracking generator
# ----------------------------
class TrackingFlow:
    """
    Wraps a Keras/Numpy iterator (e.g., datagen.flow) and stores last batch X
    so the callback can compute activations.

    Args:
        iterator: object supporting ``next()`` that yields batches whose first
            element is the input array, e.g. ``(x, y)`` or
            ``(x, y, sample_weight)``.
        parent: object that receives the most recently produced inputs in its
            ``last_batch_x`` attribute.
    """
    def __init__(self, iterator, parent):
        self.it = iterator
        self.p = parent

    def __iter__(self):
        """Yield batches from the wrapped iterator, recording each batch's inputs.

        The generator ends cleanly when the wrapped iterator is exhausted.
        """
        while True:
            try:
                batch = next(self.it)
            except StopIteration:
                # A StopIteration escaping a generator body is converted to
                # RuntimeError (PEP 479), so finish explicitly instead.
                return
            x, *rest = batch
            self.p.last_batch_x = x
            yield (x, *rest)

In [7]:
# ----------------------------
# Callback implementing:
# - mask enforcement per batch
# - neuron activation EMA updates (every N batches)
# - rewiring at epoch end (growth biased by EMA)
# ----------------------------
class SETCallback(tf.keras.callbacks.Callback):
    """Keras callback that runs the SET bookkeeping on a parent trainer object.

    * after every training batch: re-applies the parent's masks ``wm1``/``wm2``/
      ``wm3`` to the kernels of ``sparse_1``/``sparse_2``/``sparse_3``;
    * every ``parent.act_update_every`` batches: updates the parent's
      activation EMAs from ``parent.last_batch_x``;
    * at the end of every epoch: records validation accuracy and calls
      ``parent.weightsEvolution()``.

    Mask tensors are cached per layer and rebuilt only when the parent exposes
    a *different* mask array object, so the dense masks are not converted to
    tensors on every batch. In-place edits of a mask array are therefore not
    picked up; replace the array instead (as ``weightsEvolution`` does).
    """

    # (Dense layer name, parent attribute holding that layer's numpy mask)
    _SPARSE_LAYERS = (
        ("sparse_1", "wm1"),
        ("sparse_2", "wm2"),
        ("sparse_3", "wm3"),
    )

    def __init__(self, parent):
        super().__init__()
        self.p = parent
        # layer name -> (numpy mask object, tensor built from that object)
        self._mask_cache = {}

    @staticmethod
    def _apply_mask_to_layer(layer: tf.keras.layers.Dense, mask_np: np.ndarray):
        """Multiply ``layer.kernel`` element-wise by ``mask_np`` (uncached)."""
        mask_tf = tf.convert_to_tensor(mask_np, dtype=layer.kernel.dtype)
        layer.kernel.assign(layer.kernel * mask_tf)

    def _mask_tensor(self, layer, mask_np):
        """Return a tensor for ``mask_np``, reusing the cached one when possible."""
        cached = self._mask_cache.get(layer.name)
        # Identity check: holding a reference to the array keeps it alive, so
        # "same object" reliably means "same mask as last time".
        if cached is None or cached[0] is not mask_np:
            cached = (mask_np, tf.convert_to_tensor(mask_np, dtype=layer.kernel.dtype))
            self._mask_cache[layer.name] = cached
        return cached[1]

    def on_epoch_begin(self, epoch, logs=None):
        """Expose the current epoch index to the parent."""
        self.p.current_epoch = int(epoch)

    def on_train_batch_end(self, batch, logs=None):
        """Re-apply the masks, then (periodically) update the activation EMAs."""
        # 1) Enforce full masks (keeps forbidden connections at exactly 0)
        for layer_name, mask_attr in self._SPARSE_LAYERS:
            layer = self.model.get_layer(layer_name)
            mask_tf = self._mask_tensor(layer, getattr(self.p, mask_attr))
            layer.kernel.assign(layer.kernel * mask_tf)

        # 2) Update neuron-activity EMA every N batches (performance)
        if (batch % self.p.act_update_every) != 0:
            return

        x = self.p.last_batch_x
        if x is None:
            return

        # Input "neurons" (flattened pixels/features)
        xf = x.reshape((x.shape[0], -1)).astype(np.float32)  # (B, features)
        inp_batch = np.mean(np.abs(xf), axis=0)              # (features,)

        # Hidden neuron activations (after SReLU, before dropout)
        a1, a2, a3 = self.p.act_model(x, training=False)
        a1 = tf.reduce_mean(tf.abs(a1), axis=0).numpy()
        a2 = tf.reduce_mean(tf.abs(a2), axis=0).numpy()
        a3 = tf.reduce_mean(tf.abs(a3), axis=0).numpy()

        # Exponential moving average: new = b * old + (1 - b) * batch statistic
        b = self.p.beta_act
        self.p.inp_ema = b * self.p.inp_ema + (1 - b) * inp_batch
        self.p.act1_ema = b * self.p.act1_ema + (1 - b) * a1
        self.p.act2_ema = b * self.p.act2_ema + (1 - b) * a2
        self.p.act3_ema = b * self.p.act3_ema + (1 - b) * a3

    def on_epoch_end(self, epoch, logs=None):
        """Record validation accuracy (if reported) and rewire the masks."""
        # Store validation accuracy
        if logs is not None:
            va = logs.get("val_accuracy", None)
            if va is None:
                va = logs.get("val_acc", None)
            if va is not None:
                self.p.accuracies_per_epoch.append(float(va))

        # Rewire masks at epoch end
        self.p.weightsEvolution()

In [8]:
# ----------------------------
# Main model class
# ----------------------------
class SET_MLP_CIFAR10:
    """SET-MLP for CIFAR-10 with neuron-activity-biased connection growth.

    The network is Flatten -> Dense(4000) -> Dense(1000) -> Dense(4000) ->
    Dense(num_classes); the first three Dense layers ("sparse_1..3") are kept
    sparse through binary masks ``wm1``/``wm2``/``wm3``. After each epoch the
    masks are rewired: weights closest to zero are pruned and the same number
    of connections is regrown, uniformly during burn-in and afterwards biased
    by the EMA of mean absolute neuron activations.

    Constructing an instance builds the model and immediately runs training.

    Args:
        epsilon: sparsity control passed to ``createWeightsMask``.
        zeta: fraction used to position the pruning thresholds in ``rewireMask``.
        batch_size: training batch size.
        maxepoches: number of training epochs.
        learning_rate: SGD learning rate.
        momentum: SGD momentum.
        beta_act: decay factor of the activation EMAs.
        burn_in_epochs: growth is uniform while ``current_epoch`` is below this.
        act_update_every: the activation EMAs are updated every N batches.
    """

    # Upper bound on rejection-sampling rounds during growth; whatever is still
    # missing afterwards is filled uniformly so rewiring always terminates.
    _MAX_GROWTH_ROUNDS = 64

    def __init__(self, epsilon=20, zeta=0.3, batch_size=100, maxepoches=1000,
                 learning_rate=0.01, momentum=0.9, beta_act=0.99,
                 burn_in_epochs=10, act_update_every=10):
        # SET parameters
        self.epsilon = epsilon
        self.zeta = zeta
        self.batch_size = batch_size
        self.maxepoches = maxepoches
        self.learning_rate = learning_rate
        self.num_classes = 10
        self.momentum = momentum

        # Neuron-level EMA parameters
        self.beta_act = beta_act
        self.burn_in_epochs = burn_in_epochs      # uniform growth before this epoch
        self.act_update_every = act_update_every  # update activation EMA every N batches
        self.current_epoch = 0
        self.last_batch_x = None

        # ER masks
        self.noPar1, self.wm1 = createWeightsMask(self.epsilon, 32 * 32 * 3, 4000)
        self.noPar2, self.wm2 = createWeightsMask(self.epsilon, 4000, 1000)
        self.noPar3, self.wm3 = createWeightsMask(self.epsilon, 1000, 4000)

        # Neuron activity EMA buffers
        self.inp_ema = np.zeros(3072, dtype=np.float32)
        self.act1_ema = np.zeros(4000, dtype=np.float32)
        self.act2_ema = np.zeros(1000, dtype=np.float32)
        self.act3_ema = np.zeros(4000, dtype=np.float32)

        # Acc tracking
        self.accuracies_per_epoch = []

        # Build model + activation probe
        self.create_model()
        self._apply_initial_masks()

        # Train
        self.train()

    def create_model(self):
        """Build and compile ``self.model`` and the activation probe ``self.act_model``."""
        self.model = Sequential()
        self.model.add(Flatten(input_shape=(32, 32, 3)))

        self.model.add(Dense(4000, name="sparse_1"))
        self.model.add(SReLU(name="srelu1"))
        self.model.add(Dropout(0.3))

        self.model.add(Dense(1000, name="sparse_2"))
        self.model.add(SReLU(name="srelu2"))
        self.model.add(Dropout(0.3))

        self.model.add(Dense(4000, name="sparse_3"))
        self.model.add(SReLU(name="srelu3"))
        self.model.add(Dropout(0.3))

        self.model.add(Dense(self.num_classes, name="dense_4"))
        self.model.add(Activation("softmax"))

        # Build weights (needed before referencing layer outputs)
        dummy = tf.zeros((1, 32, 32, 3), dtype=tf.float32)
        _ = self.model(dummy, training=False)

        # Activation probe model (after SReLU layers)
        self.act_model = tf.keras.Model(
            inputs=self.model.input,
            outputs=[
                self.model.get_layer("srelu1").output,
                self.model.get_layer("srelu2").output,
                self.model.get_layer("srelu3").output,
            ]
        )

        # Compile once
        sgd = optimizers.SGD(learning_rate=self.learning_rate, momentum=self.momentum)
        self.model.compile(loss="categorical_crossentropy", optimizer=sgd, metrics=["accuracy"])

    @staticmethod
    def _mask_kernel(layer, mask):
        """Multiply ``layer.kernel`` element-wise by the numpy ``mask``, in place."""
        layer.kernel.assign(layer.kernel * tf.convert_to_tensor(mask, dtype=layer.kernel.dtype))

    def _apply_initial_masks(self):
        """Zero every kernel entry that the initial masks do not allow."""
        # Ensure sparse connectivity at initialization
        self._mask_kernel(self.model.get_layer("sparse_1"), self.wm1)
        self._mask_kernel(self.model.get_layer("sparse_2"), self.wm2)
        self._mask_kernel(self.model.get_layer("sparse_3"), self.wm3)

    def _grow_connections(self, mask, n_new, p_pre=None, p_post=None):
        """Set ``n_new`` currently-empty entries of ``mask`` to 1, in place.

        Candidate positions ``(i, j)`` are drawn in batches: rows from
        ``p_pre`` and columns from ``p_post`` when both are given, uniformly
        otherwise. Occupied and repeated candidates are rejected in draw
        order, which gives the same acceptance rule as drawing one candidate
        at a time. If the request is not met after ``_MAX_GROWTH_ROUNDS``
        batches (e.g. the biased distribution cannot reach enough empty
        slots), the remainder is drawn uniformly from the empty slots.

        Args:
            mask: 2-D 0/1 array, modified in place.
            n_new: number of entries to add; capped at the number of empty slots.
            p_pre: probabilities over rows, or None for uniform growth.
            p_post: probabilities over columns, or None for uniform growth.
        """
        rows, cols = mask.shape
        n_empty = mask.size - int(np.count_nonzero(mask))
        remaining = min(int(n_new), n_empty)
        biased = p_pre is not None and p_post is not None

        for _ in range(self._MAX_GROWTH_ROUNDS):
            if remaining <= 0:
                return
            # Oversample so that one round usually suffices.
            n_draws = max(2 * remaining, 1024)
            if biased:
                i = np.random.choice(rows, size=n_draws, p=p_pre)
                j = np.random.choice(cols, size=n_draws, p=p_post)
            else:
                i = np.random.randint(0, rows, size=n_draws)
                j = np.random.randint(0, cols, size=n_draws)

            free = mask[i, j] == 0
            flat = i[free].astype(np.int64) * cols + j[free]
            # Drop repeated candidates while keeping the order they were drawn in.
            _, first_seen = np.unique(flat, return_index=True)
            chosen = flat[np.sort(first_seen)][:remaining]
            mask[chosen // cols, chosen % cols] = 1
            remaining -= int(chosen.size)

        if remaining > 0:
            # Rejection sampling stalled: fill the rest uniformly over empty slots.
            empty_slots = np.flatnonzero(mask == 0)
            chosen = np.random.choice(empty_slots, size=remaining, replace=False)
            mask[chosen // cols, chosen % cols] = 1

    def rewireMask(self, weights, noWeights, pre_scores=None, post_scores=None):
        """
        Prune weights near 0 (legacy SET-style thresholds),
        then add new edges. Growth is biased by neuron activity EMA after burn-in.

        Args:
            weights: 2-D kernel values of one sparse layer.
            noWeights: target number of non-zero mask entries after growth.
            pre_scores: non-negative score per input neuron (row), or None.
            post_scores: non-negative score per output neuron (column), or None.

        Returns:
            (new_mask, core_mask): float32 arrays shaped like ``weights``.
            ``core_mask`` marks the connections that survived pruning;
            ``new_mask`` additionally contains the newly grown connections.
        """
        # ---- Pruning (legacy logic) ----
        values = np.sort(weights.ravel())
        last_index = values.shape[0] - 1
        firstZeroPos = int(np.clip(find_first_pos(values, 0), 0, last_index))
        lastZeroPos = int(np.clip(find_last_pos(values, 0), 0, last_index))

        largestNegative = values[int((1 - self.zeta) * firstZeroPos)]
        smallestPositive = values[int(
            min(last_index, lastZeroPos + self.zeta * (values.shape[0] - lastZeroPos))
        )]

        # Boolean keep-mask instead of writing a sentinel value into a copy of
        # the weights, so a weight that happens to equal 1.0 gets no special
        # treatment.
        keep = (weights > smallestPositive) | (weights < largestNegative)
        core_mask = keep.astype(np.float32)
        rewired = core_mask.copy()

        # ---- Growth (biased by neuron activity EMA) ----
        noRewires = int(noWeights) - int(np.count_nonzero(keep))

        # If something goes odd, protect
        if noRewires <= 0:
            return rewired, core_mask

        use_bias = (
            self.current_epoch >= self.burn_in_epochs
            and pre_scores is not None
            and post_scores is not None
        )
        p_pre = safe_probs(pre_scores) if use_bias else None
        p_post = safe_probs(post_scores) if use_bias else None

        # Add edges until we hit the original parameter budget
        self._grow_connections(rewired, noRewires, p_pre, p_post)

        return rewired, core_mask

    def weightsEvolution(self):
        """
        Rewire each sparse layer:
        - compute new masks using pruning + biased growth
        - apply core mask to live weights (prunes + sets new edges to 0 initially)
        """
        l1 = self.model.get_layer("sparse_1")
        l2 = self.model.get_layer("sparse_2")
        l3 = self.model.get_layer("sparse_3")

        w1 = l1.kernel.numpy()
        w2 = l2.kernel.numpy()
        w3 = l3.kernel.numpy()

        # Layer1: pre=input neurons, post=hidden1 neurons
        self.wm1, wm1Core = self.rewireMask(
            w1, self.noPar1,
            pre_scores=self.inp_ema,
            post_scores=self.act1_ema
        )

        # Layer2: pre=hidden1, post=hidden2
        self.wm2, wm2Core = self.rewireMask(
            w2, self.noPar2,
            pre_scores=self.act1_ema,
            post_scores=self.act2_ema
        )

        # Layer3: pre=hidden2, post=hidden3
        self.wm3, wm3Core = self.rewireMask(
            w3, self.noPar3,
            pre_scores=self.act2_ema,
            post_scores=self.act3_ema
        )

        # Apply core masks to weights:
        # - removes pruned edges immediately
        # - new edges start at 0 (since core mask has 0 there)
        self._mask_kernel(l1, wm1Core)
        self._mask_kernel(l2, wm2Core)
        self._mask_kernel(l3, wm3Core)

    def read_data(self):
        """Load CIFAR-10, one-hot encode the labels and standardise the images.

        Returns:
            (x_train, x_test, y_train, y_test); both image sets are normalised
            with the per-position mean and std of the training images.
        """
        (x_train, y_train), (x_test, y_test) = cifar10.load_data()
        y_train = to_categorical(y_train, self.num_classes)
        y_test = to_categorical(y_test, self.num_classes)

        x_train = x_train.astype("float32")
        x_test = x_test.astype("float32")

        # normalize (same as legacy)
        mean = np.mean(x_train, axis=0)
        std = np.std(x_train, axis=0) + 1e-8
        x_train = (x_train - mean) / std
        x_test = (x_test - mean) / std

        return x_train, x_test, y_train, y_test

    def train(self):
        """Train with augmentation and the SET callback; collect accuracies.

        On completion ``self.accuracies_per_epoch`` is a float32 array holding
        the validation accuracies recorded by the callback.
        """
        x_train, x_test, y_train, y_test = self.read_data()

        # No feature-wise statistics or ZCA whitening are enabled, so the
        # generator does not need to be fitted on the training data.
        datagen = ImageDataGenerator(
            featurewise_center=False,
            samplewise_center=False,
            featurewise_std_normalization=False,
            samplewise_std_normalization=False,
            zca_whitening=False,
            rotation_range=10,
            width_shift_range=0.1,
            height_shift_range=0.1,
            horizontal_flip=True,
            vertical_flip=False
        )

        self.model.summary()

        steps_per_epoch = x_train.shape[0] // self.batch_size

        base_it = datagen.flow(x_train, y_train, batch_size=self.batch_size)
        tracked = TrackingFlow(base_it, self)

        set_cb = SETCallback(self)

        self.model.fit(
            iter(tracked),
            steps_per_epoch=steps_per_epoch,
            epochs=self.maxepoches,
            validation_data=(x_test, y_test),
            callbacks=[set_cb],
            verbose=1
        )

        self.accuracies_per_epoch = np.asarray(self.accuracies_per_epoch, dtype=np.float32)

In [9]:
if __name__ == "__main__":
    # Make sure the output directory exists before the training run starts.
    os.makedirs("results", exist_ok=True)

    # Constructing the model builds it and runs the whole training loop.
    model = SET_MLP_CIFAR10()

    # Persist the per-epoch validation accuracies as plain text.
    accuracy_history = np.asarray(model.accuracies_per_epoch)
    np.savetxt("results/set_mlp_neuronEMA_growth_cifar10_acc.txt", accuracy_history)

Create Sparse Matrix: No parameters, NoRows, NoCols  141345 3072 4000
Create Sparse Matrix: No parameters, NoRows, NoCols  99914 4000 1000
Create Sparse Matrix: No parameters, NoRows, NoCols  99612 1000 4000
Model: "sequential_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 flatten_1 (Flatten)         (None, 3072)              0         
                                                                 
 sparse_1 (Dense)            (None, 4000)              12292000  
                                                                 
 srelu1 (SReLU)              (None, 4000)              4         
                                                                 
 dropout_3 (Dropout)         (None, 4000)              0         
                                                                 
 sparse_2 (Dense)            (None, 1000)              4001000   
                                            

KeyboardInterrupt: 