In [1]:
import gc
import os
import time

# TF_CPP_MIN_LOG_LEVEL must already be in the environment when TensorFlow is
# imported: the start-up log lines are emitted during the import itself, so
# setting the variable afterwards does not silence them.  setdefault() keeps
# any value that was exported before the notebook was started.
os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "2")

import numpy as np
import tensorflow as tf

# Ask TensorFlow to allocate GPU memory on demand instead of all at once.
# This stays best-effort (a failure must not abort the notebook), but the
# reason is reported instead of being swallowed silently.
gpus = tf.config.list_physical_devices("GPU")
for gpu in gpus:
    try:
        tf.config.experimental.set_memory_growth(gpu, True)
    except Exception as err:
        print(f"[warn] could not enable memory growth on {gpu.name}: {err}")

2025-12-09 13:41:37.986623: I tensorflow/core/util/port.cc:111] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-12-09 13:41:38.023001: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-12-09 13:41:38.023025: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-12-09 13:41:38.023063: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-12-09 13:41:38.033373: I tensorflow/core/platform/cpu_feature_g

In [2]:
class SReLU(tf.keras.layers.Layer):
    """S-shaped rectified linear unit with four learnable scalar parameters.

    The activation is piecewise linear:

        f(x) = tl + al * (x - tl)   if x <= tl
        f(x) = tr + ar * (x - tr)   if x >= tr
        f(x) = x                    otherwise

    ``tl``/``tr`` are the left/right thresholds and ``al``/``ar`` the slopes
    used outside them.  Each parameter is a single scalar (shape ``()``)
    shared by every element of the input, initialised to -1.0, 1.0, 0.2 and
    0.2 respectively.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        """Create the four scalar weights; ``input_shape`` is only forwarded."""

        def scalar_weight(name, value):
            # The weight name is passed by keyword: the positional order of
            # add_weight() differs between Keras versions (name first vs.
            # shape first), the keyword form is valid for both.
            return self.add_weight(
                name=name,
                shape=(),
                initializer=tf.keras.initializers.Constant(value),
            )

        self.tl = scalar_weight("tl", -1.0)
        self.tr = scalar_weight("tr", 1.0)
        self.al = scalar_weight("al", 0.2)
        self.ar = scalar_weight("ar", 0.2)
        super().build(input_shape)

    def call(self, x):
        """Apply the piecewise-linear activation element-wise to ``x``."""
        tl = self.tl
        tr = self.tr
        al = self.al
        ar = self.ar

        left = tl + al * (x - tl)
        right = tr + ar * (x - tr)
        mid = x

        # The left branch takes precedence when both conditions hold.
        return tf.where(x <= tl, left, tf.where(x >= tr, right, mid))

In [3]:
class MaskedDense(tf.keras.layers.Layer):
    """Dense layer whose kernel is gated element-wise by a binary mask.

    The mask is drawn once in ``build`` with connection probability
    ``p = eps * (n_in + n_out) / (n_in * n_out)`` (capped at 1.0) and stored
    as a non-trainable variable.  ``prune_and_regrow`` rewires the mask by
    removing the smallest-magnitude active weights and activating the same
    number of currently inactive positions.

    Args:
        units: number of output units.
        eps: sparsity control used in the connection probability above.
        use_bias: whether a bias vector is added to the output.
        **kwargs: forwarded to ``tf.keras.layers.Layer``.

    All random draws use the global ``np.random`` state, so they follow
    ``np.random.seed``.
    """

    def __init__(self, units, eps=20, use_bias=True, **kwargs):
        super().__init__(**kwargs)
        self.units = int(units)
        self.eps = float(eps)
        self.use_bias = use_bias

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super().get_config()
        config.update(
            {"units": self.units, "eps": self.eps, "use_bias": self.use_bias}
        )
        return config

    def build(self, input_shape):
        """Create kernel, optional bias and the random binary mask."""
        n_in = int(input_shape[-1])
        n_out = self.units
        self.n_in = n_in
        self.n_out = n_out

        self.kernel = self.add_weight(
            name="kernel",
            shape=(n_in, n_out),
            initializer="he_uniform",
            trainable=True,
        )

        if self.use_bias:
            self.bias = self.add_weight(
                name="bias",
                shape=(n_out,),
                initializer="zeros",
                trainable=True,
            )
        else:
            self.bias = None

        # Connection probability, capped so it stays a valid probability.
        p = self.eps * (n_in + n_out) / (n_in * n_out)
        p_eff = min(p, 1.0)

        M = (np.random.rand(n_in, n_out) < p_eff).astype(np.float32)

        # Every output unit gets at least one incoming connection.
        for j in np.flatnonzero(M.sum(axis=0) == 0):
            M[np.random.randint(0, n_in), j] = 1.0

        self.mask = tf.Variable(
            initial_value=M,
            trainable=False,
            dtype=tf.float32,
            name="mask",
        )

        # Zero the kernel entries that the initial mask disables.
        self.apply_mask()
        super().build(input_shape)

    def call(self, inputs):
        """Return ``inputs @ (kernel * mask)`` plus the bias, if any."""
        w_eff = self.kernel * self.mask
        out = tf.linalg.matmul(inputs, w_eff)
        if self.bias is not None:
            out = out + self.bias
        return out

    def apply_mask(self):
        """Overwrite the kernel with ``kernel * mask`` (masked entries become 0)."""
        self.kernel.assign(self.kernel * self.mask)

    def prune_and_regrow(self, zeta=0.3, strategy="random", alpha=0.5):
        """Remove the weakest active connections and activate as many new ones.

        Args:
            zeta: fraction of the currently active connections to prune;
                ``int(zeta * n_active)`` connections are replaced.
            strategy: ``"importance"`` samples new positions with probability
                proportional to an importance score; any other value samples
                them uniformly at random.
            alpha: mixing factor of the importance score (used only by the
                ``"importance"`` strategy).

        Nothing happens when there are no active connections or when
        ``int(zeta * n_active)`` is not positive.  Positions that were just
        pruned are eligible for regrowth.
        """
        k = self.kernel.numpy().copy()
        m = self.mask.numpy().copy()

        active = np.argwhere(m > 0)
        n_active = active.shape[0]
        if n_active == 0:
            return

        # Clamp so that zeta > 1 prunes everything instead of indexing out
        # of range.
        n_prune = min(int(zeta * n_active), n_active)
        if n_prune <= 0:
            return

        # Prune exactly the n_prune smallest |w|.  argpartition guarantees
        # that strictly smaller weights are never kept in favour of weights
        # tied at the threshold.
        magnitudes = np.abs(k[active[:, 0], active[:, 1]])
        weakest = np.argpartition(magnitudes, n_prune - 1)[:n_prune]
        pruned = active[weakest]
        m[pruned[:, 0], pruned[:, 1]] = 0.0
        k[pruned[:, 0], pruned[:, 1]] = 0.0

        # At least n_prune positions are inactive now, so exactly n_prune
        # connections can always be regrown.
        zeros = np.argwhere(m == 0)
        n_grow = n_prune

        if strategy == "importance":
            grow_pick = self._pick_by_importance(k, m, zeros, n_grow, alpha)
        else:
            grow_pick = np.random.choice(len(zeros), n_grow, replace=False)

        grown = zeros[grow_pick]

        # New edges get a He-uniform draw, matching the kernel initializer.
        limit = np.sqrt(6.0 / self.n_in)
        m[grown[:, 0], grown[:, 1]] = 1.0
        k[grown[:, 0], grown[:, 1]] = np.random.uniform(
            -limit, limit, size=grown.shape[0]
        )

        self.mask.assign(m.astype(np.float32))
        self.kernel.assign(k.astype(np.float32))

    @staticmethod
    def _pick_by_importance(k, m, zeros, n_grow, alpha):
        """Choose ``n_grow`` row indices into ``zeros`` by importance sampling.

        The score of an inactive position (i, j) is
        ``alpha * I_in[i] * I_out[j] + (1 - alpha) * mean``, where ``I_in`` /
        ``I_out`` are the summed |weight| of row i / column j of the masked
        kernel and ``mean`` is the average of that product over all inactive
        positions.
        """
        w_abs = np.abs(k * m)
        strength_in = w_abs.sum(axis=1)
        strength_out = w_abs.sum(axis=0)

        base = strength_in[zeros[:, 0]] * strength_out[zeros[:, 1]]
        scores = alpha * base + (1.0 - alpha) * base.mean()

        is_positive = scores > 0
        positive = np.flatnonzero(is_positive)

        if positive.size >= n_grow:
            # Normalise in float64 so the probabilities sum to 1 within the
            # tolerance np.random.choice enforces.
            weights = scores[positive].astype(np.float64)
            return np.random.choice(
                positive, n_grow, replace=False, p=weights / weights.sum()
            )

        # Not enough positively scored candidates: take all of them and fill
        # the remainder uniformly from the other inactive positions.
        others = np.flatnonzero(~is_positive)
        filler = np.random.choice(others, n_grow - positive.size, replace=False)
        return np.concatenate([positive, filler]).astype(int)

In [4]:
class SETModel(tf.keras.Model):
    """Keras model with a custom training step for ``MaskedDense`` layers.

    Args:
        weight_decay: coefficient of the L2 penalty applied to the masked
            kernels (``kernel * mask``) of every ``MaskedDense`` layer.
            ``0.0`` disables the penalty.
        *args, **kwargs: forwarded to ``tf.keras.Model``.
    """

    def __init__(self, *args, weight_decay=0.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.weight_decay = float(weight_decay)

    def train_step(self, data):
        """Run one optimisation step and return the running metric values.

        ``data`` may be ``(x, y)`` or ``(x, y, sample_weight)``.  Gradients of
        masked kernels are multiplied by their mask and the mask is re-applied
        to the kernels after the optimizer update.
        """
        x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data)
        y = tf.cast(y, tf.int32)

        masked_layers = [
            layer for layer in self.layers if isinstance(layer, MaskedDense)
        ]

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)

            # L2 penalty on active edges only.  It is handed to compiled_loss
            # as a regularization loss so that the "loss" metric tracks the
            # running mean of the total objective instead of the raw value of
            # the last batch.
            reg_losses = None
            if self.weight_decay > 0 and masked_layers:
                penalty = tf.add_n(
                    [
                        tf.reduce_sum(tf.square(layer.kernel * layer.mask))
                        for layer in masked_layers
                    ]
                )
                reg_losses = [self.weight_decay * penalty]

            loss = self.compiled_loss(
                y,
                y_pred,
                sample_weight=sample_weight,
                regularization_losses=reg_losses,
            )

        variables = self.trainable_variables
        grads = list(tape.gradient(loss, variables))

        # Mask the gradient of every MaskedDense kernel (identity lookup
        # instead of scanning all variables once per layer).
        mask_by_kernel = {id(layer.kernel): layer.mask for layer in masked_layers}
        for idx, var in enumerate(variables):
            mask = mask_by_kernel.get(id(var))
            if mask is not None and grads[idx] is not None:
                grads[idx] = grads[idx] * mask

        self.optimizer.apply_gradients(zip(grads, variables))

        # Hard-enforce zeros on masked entries after the update (the
        # optimizer may carry momentum).
        for layer in masked_layers:
            layer.apply_mask()

        self.compiled_metrics.update_state(y, y_pred, sample_weight=sample_weight)
        return {m.name: m.result() for m in self.metrics}

In [5]:
class SETCallback(tf.keras.callbacks.Callback):
    """Rewires every ``MaskedDense`` layer of the model at the end of an epoch.

    Args:
        zeta: fraction of active connections replaced per epoch.
        strategy: regrowth strategy passed to ``prune_and_regrow``.
        alpha: importance mixing factor passed to ``prune_and_regrow``.
    """

    def __init__(self, zeta=0.3, strategy="random", alpha=0.5):
        super().__init__()
        self.zeta = float(zeta)
        self.strategy = strategy
        self.alpha = float(alpha)

    def on_epoch_end(self, epoch, logs=None):
        """Prune and regrow all ``MaskedDense`` layers, except after the last epoch."""
        # Rewiring after the final scheduled epoch would leave freshly
        # initialised, never-trained weights in the model that is evaluated
        # once fit() returns, so that epoch is skipped.
        total_epochs = (self.params or {}).get("epochs")
        if total_epochs is not None and epoch + 1 >= total_epochs:
            return

        for layer in self.model.layers:
            if isinstance(layer, MaskedDense):
                layer.prune_and_regrow(
                    zeta=self.zeta,
                    strategy=self.strategy,
                    alpha=self.alpha,
                )

In [6]:
# IN PAPER: 784-1000-1000-1000-10
def build_set_mlp(
    eps=20,
    weight_decay=2e-4,
    dropout=0.3,
    hidden_units=(1000, 1000, 1000),
    input_dim=784,
    num_classes=10,
):
    """Build a sparse MLP made of ``MaskedDense`` + ``SReLU`` + ``Dropout`` stages.

    Args:
        eps: sparsity control forwarded to every ``MaskedDense`` layer.
        weight_decay: L2 coefficient forwarded to ``SETModel``.
        dropout: dropout rate applied after every hidden activation.
        hidden_units: width of each hidden layer, in order.
        input_dim: number of input features.
        num_classes: number of output units.

    Returns:
        An uncompiled ``SETModel``.  With the defaults the layout is
        784-1000-1000-1000-10 and the masked layers are named ``md1`` ...
        ``md4``, the activations ``srelu1`` ... ``srelu3``.
    """
    hidden_units = tuple(hidden_units)

    inputs = tf.keras.Input(shape=(input_dim,))

    x = inputs
    for stage, width in enumerate(hidden_units, start=1):
        x = MaskedDense(width, eps=eps, name=f"md{stage}")(x)
        x = SReLU(name=f"srelu{stage}")(x)
        x = tf.keras.layers.Dropout(dropout)(x)

    # The output layer continues the "md<N>" numbering of the hidden layers.
    outputs = MaskedDense(
        num_classes, eps=eps, name=f"md{len(hidden_units) + 1}"
    )(x)

    return SETModel(inputs, outputs, weight_decay=weight_decay)

In [7]:
def build_dense_mlp(
    weight_decay=2e-4,
    dropout=0.3,
    hidden_units=(1000, 1000, 1000),
    input_dim=784,
    num_classes=10,
):
    """Build the fully connected baseline MLP (``Dense`` + ``SReLU`` + ``Dropout``).

    Args:
        weight_decay: coefficient of the L2 kernel regularizer used by every
            ``Dense`` layer.
        dropout: dropout rate applied after every hidden activation.
        hidden_units: width of each hidden layer, in order.
        input_dim: number of input features.
        num_classes: number of output units.

    Returns:
        An uncompiled ``tf.keras.Model``.  With the defaults the layout is
        784-1000-1000-1000-10 and the activations are named
        ``srelu1_dense`` ... ``srelu3_dense``.
    """
    reg = tf.keras.regularizers.l2(weight_decay)

    inputs = tf.keras.Input(shape=(input_dim,))

    x = inputs
    for stage, width in enumerate(tuple(hidden_units), start=1):
        x = tf.keras.layers.Dense(width, kernel_regularizer=reg)(x)
        x = SReLU(name=f"srelu{stage}_dense")(x)
        x = tf.keras.layers.Dropout(dropout)(x)

    outputs = tf.keras.layers.Dense(num_classes, kernel_regularizer=reg)(x)

    return tf.keras.Model(inputs, outputs)


In [8]:
def count_active_weights(model):
    """Count the parameters in use by the dense-type layers of ``model``.

    ``MaskedDense`` layers contribute the sum of their mask plus the bias
    length; ``tf.keras.layers.Dense`` layers contribute the full kernel size
    plus the bias length.  Layers of any other type are not counted.
    """

    def bias_length(layer):
        # Layers built without a bias expose ``bias`` as None.
        if layer.bias is None:
            return 0
        return int(layer.bias.shape[0])

    n_params = 0
    for layer in model.layers:
        if isinstance(layer, MaskedDense):
            n_kernel = int(tf.reduce_sum(layer.mask).numpy())
        elif isinstance(layer, tf.keras.layers.Dense):
            n_kernel = int(np.prod(layer.kernel.shape))
        else:
            continue
        n_params += n_kernel + bias_length(layer)
    return n_params


In [9]:
class TestEvalCallback(tf.keras.callbacks.Callback):
    """Evaluate the model on a held-out set every ``every`` epochs and print it.

    Args:
        x_test, y_test: evaluation inputs and labels, batched internally.
        every: evaluation period in epochs; values <= 0 disable evaluation.
        batch_size: batch size of the evaluation dataset.
    """

    def __init__(self, x_test, y_test, every=5, batch_size=256):
        super().__init__()
        self.every = int(every)
        samples = tf.data.Dataset.from_tensor_slices((x_test, y_test))
        self.test_ds = samples.batch(batch_size)

    def on_epoch_end(self, epoch, logs=None):
        """Print test loss/accuracy when the 1-based epoch is a multiple of ``every``."""
        # The period check is short-circuited so a non-positive period never
        # reaches the modulo.
        is_due = self.every > 0 and (epoch + 1) % self.every == 0
        if not is_due:
            return
        test_loss, test_acc = self.model.evaluate(self.test_ds, verbose=0)
        print(f"  >> Test@{epoch+1:02d} - loss: {test_loss:.4f} - acc: {test_acc:.4f}")


In [16]:
def run_and_report(
    name,
    model,
    x_train, y_train, x_test, y_test,
    callbacks=None,
    epochs=20,
    batch_size=128,
    test_every=5,
    test_batch_size=256,
    target_test_acc=None,
    min_epoch=5,
    learning_rate=0.01,
    momentum=0.9,
):
    """Compile, train and evaluate ``model`` and return a summary record.

    Args:
        name: label used in the printed output and in the returned record.
        model: uncompiled Keras model; it is compiled here with SGD and
            sparse categorical cross-entropy on logits.
        x_train, y_train: training data; 10% is held out for validation.
        x_test, y_test: test data used for periodic and final evaluation.
        callbacks: extra callbacks; they run before the test callbacks
            added by this function.
        epochs: maximum number of epochs.
        batch_size: training batch size.
        test_every: period (in epochs) of the printed test evaluation.
        test_batch_size: batch size of the test dataset.
        target_test_acc: if given, training stops once the test accuracy
            reaches this value (checked every epoch from ``min_epoch`` on).
        min_epoch: first 1-based epoch at which early stopping may trigger.
        learning_rate: SGD learning rate.
        momentum: SGD momentum.

    Returns:
        dict with keys ``"Model"``, ``"nW"``, ``"Test Acc"``,
        ``"Best Val Acc"`` and ``"Train Time (s)"``.

    Note:
        Early stopping is driven by the test set, so the reported test
        accuracy is selected on the data it is measured on.
    """
    callbacks = callbacks or []

    test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(test_batch_size)

    class TestEvalCallback(tf.keras.callbacks.Callback):
        """Prints test loss/accuracy every ``every`` epochs."""

        def __init__(self, test_ds, every=1):
            super().__init__()
            self.test_ds = test_ds
            self.every = int(every)

        def on_epoch_end(self, epoch, logs=None):
            e = epoch + 1
            if self.every > 1 and (e % self.every != 0):
                return
            test_loss, test_acc = self.model.evaluate(self.test_ds, verbose=0)
            print(f"  >> Test@{e:03d} - loss: {test_loss:.4f} - acc: {test_acc:.4f}")

    class StopOnTestAccuracy(tf.keras.callbacks.Callback):
        """Stops training once the test accuracy reaches ``target``."""

        def __init__(self, test_ds, target, min_epoch=1, every=1):
            super().__init__()
            self.test_ds = test_ds
            self.target = target
            self.min_epoch = int(min_epoch)
            self.every = int(every)

        def on_epoch_end(self, epoch, logs=None):
            if self.target is None:
                return
            e = epoch + 1
            if e < self.min_epoch:
                return
            if self.every > 1 and (e % self.every != 0):
                return

            _, test_acc = self.model.evaluate(self.test_ds, verbose=0)
            if test_acc >= self.target:
                print(f"  >> Reached target test acc {self.target:.4f} at epoch {e}. Stopping.")
                self.model.stop_training = True

    # A new list is built so the caller's list is never mutated.
    all_callbacks = list(callbacks) + [TestEvalCallback(test_ds, every=test_every)]
    if target_test_acc is not None:
        # Without a target the stop callback would be a no-op, so it is only
        # registered when it can actually trigger.
        all_callbacks.append(
            StopOnTestAccuracy(test_ds, target_test_acc, min_epoch=min_epoch, every=1)
        )

    model.compile(
        optimizer=tf.keras.optimizers.SGD(learning_rate=learning_rate, momentum=momentum),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=["accuracy"],
    )

    print(f"\n\n Training: {name}")

    # perf_counter is monotonic, so the measured duration is not affected by
    # system clock adjustments during a long run.
    t0 = time.perf_counter()
    hist = model.fit(
        x_train, y_train,
        batch_size=batch_size,
        epochs=epochs,
        validation_split=0.1,
        callbacks=all_callbacks,
        verbose=2
    )
    train_time = time.perf_counter() - t0

    final_test_loss, final_test_acc = model.evaluate(test_ds, verbose=0)

    nW = count_active_weights(model)
    best_val = max(hist.history.get("val_accuracy", [float("nan")]))

    print(f"Done: {name} | Final TestAcc={final_test_acc:.4f} | nW={nW}")

    return {
        "Model": name,
        "nW": nW,
        "Test Acc": float(final_test_acc),
        "Best Val Acc": float(best_val),
        "Train Time (s)": float(train_time),
    }


In [19]:
def compare_all_models(
    eps=20,
    zeta=0.3,
    alpha=0.5,
    epochs=100,
    seed=42,
    weight_decay=2e-4,
    dropout=0.3,
    target_test_acc=0.985,
    min_epoch=10,
):
    """Train the four model variants on MNIST and print a comparison table.

    The variants, trained in this order, are: the dense baseline, the sparse
    network with a fixed random mask ("FixProb"), the sparse network with
    random regrowth, and the sparse network with importance-based regrowth.

    Args:
        eps: sparsity control of the sparse models.
        zeta: fraction of connections rewired per epoch by ``SETCallback``.
        alpha: importance mixing factor of the importance-based variant.
        epochs: maximum number of epochs per model.
        seed: seed for the TensorFlow and NumPy global generators (set once,
            before the first model is built).
        weight_decay: L2 coefficient used by every model.
        dropout: dropout rate used by every model.
        target_test_acc: test accuracy at which training of a model stops
            early; ``None`` disables early stopping.
        min_epoch: first 1-based epoch at which early stopping may trigger.

    Returns:
        List of the per-model result dicts produced by ``run_and_report``.
    """
    tf.random.set_seed(seed)
    np.random.seed(seed)

    # MNIST preprocessing: flatten to 784 features and scale to [0, 1].
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
    x_train = x_train.reshape(-1, 784).astype("float32") / 255.0
    x_test = x_test.reshape(-1, 784).astype("float32") / 255.0
    y_train = y_train.astype("int32")
    y_test = y_test.astype("int32")

    def make_dense():
        return build_dense_mlp(weight_decay=weight_decay, dropout=dropout)

    def make_sparse():
        return build_set_mlp(eps=eps, weight_decay=weight_decay, dropout=dropout)

    # (label, model factory, callback factory or None).  Factories are used
    # so each model is only created after the previous session was cleared.
    experiments = [
        ("Dense MLP", make_dense, None),
        ("FixProb", make_sparse, None),  # fixed mask, no rewiring
        (
            "SET (RAND)",
            make_sparse,
            lambda: SETCallback(zeta=zeta, strategy="random", alpha=0.0),
        ),
        (
            f"SET (IMP, α={alpha})",
            make_sparse,
            lambda: SETCallback(zeta=zeta, strategy="importance", alpha=alpha),
        ),
    ]

    results = []
    for label, make_model, make_callback in experiments:
        tf.keras.backend.clear_session()
        model = make_model()
        extra_callbacks = [] if make_callback is None else [make_callback()]
        results.append(run_and_report(
            label,
            model, x_train, y_train, x_test, y_test,
            callbacks=extra_callbacks,
            epochs=epochs,
            target_test_acc=target_test_acc,
            min_epoch=min_epoch,
        ))
        # Drop the model before the next one is built to free its memory.
        del model
        gc.collect()

    print("\n Final Table")
    for r in results:
        print(
            f"{r['Model']:<32} | "
            f"nW={r['nW']:<10} | "
            f"TestAcc={r['Test Acc']:.4f} | "
            f"BestVal={r['Best Val Acc']:.4f} | "
            f"Time={r['Train Time (s)']:.1f}s"
        )

    return results


In [20]:
if __name__ == "__main__":
    # Settings of the full comparison run; every key is a keyword argument
    # of compare_all_models.
    run_settings = {
        "eps": 20,
        "zeta": 0.3,
        "alpha": 0.75,
        "epochs": 200,
        "seed": 42,
        "weight_decay": 2e-4,
        "dropout": 0.3,
    }
    compare_all_models(**run_settings)




 Training: Dense MLP
Epoch 1/200
422/422 - 8s - loss: 1.0117 - accuracy: 0.8671 - val_loss: 0.7980 - val_accuracy: 0.9337 - 8s/epoch - 18ms/step
Epoch 2/200
422/422 - 6s - loss: 0.8443 - accuracy: 0.9136 - val_loss: 0.7124 - val_accuracy: 0.9525 - 6s/epoch - 13ms/step
Epoch 3/200
422/422 - 6s - loss: 0.7389 - accuracy: 0.9405 - val_loss: 0.6514 - val_accuracy: 0.9668 - 6s/epoch - 13ms/step
Epoch 4/200
422/422 - 5s - loss: 0.6757 - accuracy: 0.9544 - val_loss: 0.6138 - val_accuracy: 0.9733 - 5s/epoch - 13ms/step
Epoch 5/200
  >> Test@005 - loss: 0.5963 - acc: 0.9698
422/422 - 6s - loss: 0.6307 - accuracy: 0.9638 - val_loss: 0.5870 - val_accuracy: 0.9755 - 6s/epoch - 14ms/step
Epoch 6/200
422/422 - 6s - loss: 0.5980 - accuracy: 0.9692 - val_loss: 0.5687 - val_accuracy: 0.9760 - 6s/epoch - 13ms/step
Epoch 7/200
422/422 - 6s - loss: 0.5685 - accuracy: 0.9734 - val_loss: 0.5527 - val_accuracy: 0.9768 - 6s/epoch - 13ms/step
Epoch 8/200
422/422 - 6s - loss: 0.5470 - accuracy: 0.9761 - val_l