In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Flatten, Activation
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.layers import BatchNormalization
import os

class MLP_CIFAR10:
    """Dense (fully-connected) baseline MLP trained on CIFAR-10 with SGD + momentum.

    Constructing an instance builds the model and immediately trains it; the
    validation accuracy of every epoch is stored in ``self.accuracies_per_epoch``.

    Architecture:
        Input(32, 32, 3) -> Flatten
        -> 3 x [Dense(units) -> BatchNormalization -> ReLU -> Dropout(0.3)]
           with units = 4000, 1000, 4000
        -> Dense(num_classes, softmax)

    Args:
        seed: optional integer. When given, ``tf.keras.utils.set_random_seed``
            is called before the model is built so runs are repeatable
            (GPU kernels may still be non-deterministic). ``None`` (default)
            leaves the global RNG state untouched, as before.
    """

    def __init__(self, seed=None):
        self.batch_size = 100
        self.maxepoches = 100  # Set to your required value
        self.learning_rate = 0.01
        self.num_classes = 10
        self.momentum = 0.9
        self.seed = seed

        if seed is not None:
            # Seeds Python's `random`, NumPy and TensorFlow in one call.
            tf.keras.utils.set_random_seed(seed)

        self.create_model()
        self.train()

    def create_model(self):
        """Build the Sequential MLP and store it in ``self.model``."""
        self.model = Sequential()
        # Declare the input with an Input object rather than passing
        # `input_shape=` to Flatten, which Keras 3 warns about.
        self.model.add(tf.keras.Input(shape=(32, 32, 3)))
        self.model.add(Flatten())
        for units in (4000, 1000, 4000):
            self.model.add(Dense(units))
            self.model.add(BatchNormalization())
            self.model.add(Activation('relu'))  # ReLU used in place of SReLU
            self.model.add(Dropout(0.3))
        self.model.add(Dense(self.num_classes, activation='softmax'))

    def train(self):
        """Compile and train with on-the-fly augmentation.

        Sets ``self.accuracies_per_epoch`` to the list of per-epoch
        ``val_accuracy`` values reported by Keras.
        """
        x_train, y_train, x_test, y_test = self.read_data()

        # Only random geometric augmentation is configured, so
        # ImageDataGenerator.fit() is not called: it is documented as needed
        # only for featurewise centering/normalisation or ZCA whitening.
        datagen = ImageDataGenerator(
            rotation_range=10,
            width_shift_range=0.1,
            height_shift_range=0.1,
            horizontal_flip=True
        )

        self.model.summary()

        sgd = SGD(learning_rate=self.learning_rate, momentum=self.momentum)
        self.model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])

        # NOTE(review): the test split doubles as validation data, so the
        # recorded `val_accuracy` is test-set accuracy.
        history = self.model.fit(
            datagen.flow(x_train, y_train, batch_size=self.batch_size),
            epochs=self.maxepoches,
            validation_data=(x_test, y_test),
            steps_per_epoch=len(x_train) // self.batch_size,
            verbose=2
        )

        self.accuracies_per_epoch = history.history['val_accuracy']

    def read_data(self):
        """Load CIFAR-10, one-hot encode labels and scale pixels to [0, 1].

        Returns:
            (x_train, y_train, x_test, y_test) with float32 images divided by 255
            and labels passed through ``to_categorical``.
        """
        (x_train, y_train), (x_test, y_test) = cifar10.load_data()
        y_train = to_categorical(y_train, self.num_classes)
        y_test = to_categorical(y_test, self.num_classes)
        x_train = x_train.astype('float32') / 255.0
        x_test = x_test.astype('float32') / 255.0
        return x_train, y_train, x_test, y_test


if __name__ == '__main__':
    # Train the baseline MLP, then persist its per-epoch validation accuracy.
    model = MLP_CIFAR10()

    os.makedirs("results", exist_ok=True)
    baseline_val_acc = np.asarray(model.accuracies_per_epoch)
    np.savetxt("results/dense_mlp_relu_sgd_cifar10_acc.txt", baseline_val_acc)


  super().__init__(**kwargs)


Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
[1m170498071/170498071[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 0us/step


Epoch 1/100


  self._warn_if_super_not_called()


500/500 - 32s - 64ms/step - accuracy: 0.2861 - loss: 1.9593 - val_accuracy: 0.3734 - val_loss: 1.7593
Epoch 2/100
500/500 - 26s - 52ms/step - accuracy: 0.3451 - loss: 1.8165 - val_accuracy: 0.3878 - val_loss: 1.7002
Epoch 3/100
500/500 - 42s - 83ms/step - accuracy: 0.3658 - loss: 1.7675 - val_accuracy: 0.4180 - val_loss: 1.6308
Epoch 4/100
500/500 - 26s - 51ms/step - accuracy: 0.3779 - loss: 1.7388 - val_accuracy: 0.4320 - val_loss: 1.5822
Epoch 5/100
500/500 - 41s - 83ms/step - accuracy: 0.3867 - loss: 1.7090 - val_accuracy: 0.4166 - val_loss: 1.6128
Epoch 6/100
500/500 - 41s - 82ms/step - accuracy: 0.3913 - loss: 1.6910 - val_accuracy: 0.4520 - val_loss: 1.5389
Epoch 7/100
500/500 - 41s - 81ms/step - accuracy: 0.3989 - loss: 1.6662 - val_accuracy: 0.4575 - val_loss: 1.5303
Epoch 8/100
500/500 - 26s - 52ms/step - accuracy: 0.4070 - loss: 1.6431 - val_accuracy: 0.4430 - val_loss: 1.5440
Epoch 9/100
500/500 - 27s - 55ms/step - accuracy: 0.4151 - loss: 1.6315 - val_accuracy: 0.4738 - val

In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Flatten, Activation
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.utils import to_categorical
import os


class MLP_CIFAR10:
    """Dense MLP on CIFAR-10 with one prune-then-rewire step part-way through training.

    Constructing an instance builds the model and runs the whole experiment:

    1. ``initial_epochs`` epochs of ordinary training;
    2. per-layer magnitude pruning of every Dense kernel by ``prune_percent``;
    3. re-growth of as many connections as were pruned;
    4. training resumes until ``maxepoches`` total epochs.

    ``self.accuracies_per_epoch`` holds the validation accuracy of every epoch
    from both training phases (``maxepoches`` values in total).

    Args:
        seed: optional integer. When given, ``tf.keras.utils.set_random_seed``
            is called before the model is built (this also seeds NumPy, which
            pruning/rewiring use). ``None`` (default) leaves RNG state untouched.
    """

    def __init__(self, seed=None):
        self.batch_size = 100
        self.maxepoches = 100
        self.initial_epochs = 5
        self.learning_rate = 0.01
        self.num_classes = 10
        self.momentum = 0.9
        self.prune_percent = 0.2  # 20% pruning
        self.seed = seed

        if seed is not None:
            # Seeds Python's `random`, NumPy and TensorFlow in one call.
            tf.keras.utils.set_random_seed(seed)

        self.create_model()
        self.train()

    def create_model(self):
        """Build the Sequential MLP and store it in ``self.model``."""
        self.model = Sequential()
        # Declare the input with an Input object rather than passing
        # `input_shape=` to Flatten, which Keras 3 warns about.
        self.model.add(tf.keras.Input(shape=(32, 32, 3)))
        self.model.add(Flatten())
        # NOTE(review): unlike the baseline cell, these hidden stages have no
        # BatchNormalization, so the two saved curves differ in more than
        # pruning/rewiring — confirm this is intended.
        for units in (4000, 1000, 4000):  # hidden layers 1-3
            self.model.add(Dense(units))
            self.model.add(Activation('relu'))
            self.model.add(Dropout(0.3))
        self.model.add(Dense(self.num_classes, activation='softmax'))

    def train(self):
        """Run both training phases with the prune/rewire step in between."""
        x_train, y_train, x_test, y_test = self.read_data()

        # Only random geometric augmentation is configured, so
        # ImageDataGenerator.fit() is not called: it is documented as needed
        # only for featurewise centering/normalisation or ZCA whitening.
        datagen = ImageDataGenerator(
            rotation_range=10,
            width_shift_range=0.1,
            height_shift_range=0.1,
            horizontal_flip=True
        )

        self.model.summary()
        sgd = SGD(learning_rate=self.learning_rate, momentum=self.momentum)
        self.model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])

        steps_per_epoch = len(x_train) // self.batch_size

        # Phase 1: train initially.
        history_initial = self.model.fit(
            datagen.flow(x_train, y_train, batch_size=self.batch_size),
            epochs=self.initial_epochs,
            validation_data=(x_test, y_test),
            steps_per_epoch=steps_per_epoch,
            verbose=2
        )

        # Prune + Rewire. Trained dense weights are essentially never exactly
        # zero, so the zeros after pruning are (almost exactly) the pruned
        # entries and rewiring the same number re-initialises them.
        total_before = self.count_total_nonzero_weights()
        self.prune_weights(self.prune_percent)
        total_after_prune = self.count_total_nonzero_weights()
        rewiring_needed = total_before - total_after_prune
        self.rewire_model_balanced(rewiring_needed)
        total_after_rewire = self.count_total_nonzero_weights()

        print(f"[INFO] Non-zero weights: before={total_before}, after prune={total_after_prune}, after rewire={total_after_rewire}")

        # Phase 2: continue training; epoch numbering resumes after phase 1.
        history = self.model.fit(
            datagen.flow(x_train, y_train, batch_size=self.batch_size),
            epochs=self.maxepoches,
            initial_epoch=self.initial_epochs,
            validation_data=(x_test, y_test),
            steps_per_epoch=steps_per_epoch,
            verbose=2
        )

        # Keep both phases so the curve has one entry per epoch, like the baseline.
        self.accuracies_per_epoch = (
            list(history_initial.history['val_accuracy'])
            + list(history.history['val_accuracy'])
        )

    def read_data(self):
        """Load CIFAR-10, one-hot encode labels and scale pixels to [0, 1]."""
        (x_train, y_train), (x_test, y_test) = cifar10.load_data()
        y_train = to_categorical(y_train, self.num_classes)
        y_test = to_categorical(y_test, self.num_classes)
        x_train = x_train.astype('float32') / 255.0
        x_test = x_test.astype('float32') / 255.0
        return x_train, y_train, x_test, y_test

    def _dense_layers(self):
        """Return the model's Dense layers in model order."""
        return [layer for layer in self.model.layers if isinstance(layer, Dense)]

    def count_total_nonzero_weights(self):
        """Return the number of non-zero kernel entries over all Dense layers (biases excluded)."""
        return int(sum(np.count_nonzero(layer.get_weights()[0])
                       for layer in self._dense_layers()))

    @staticmethod
    def _prune_matrix(kernel, prune_percent):
        """Return a copy of ``kernel`` with entries below its ``prune_percent`` magnitude percentile zeroed."""
        magnitudes = np.abs(kernel)
        threshold = np.percentile(magnitudes, prune_percent * 100)
        pruned = kernel.copy()
        pruned[magnitudes < threshold] = 0.0
        return pruned

    def prune_weights(self, prune_percent=0.2):
        """Magnitude-prune every Dense kernel, including the output layer.

        Each layer uses its own threshold, which is the ``prune_percent``
        percentile of that kernel's absolute values. Entries strictly below it
        are set to zero, and biases are left unchanged. No mask is kept, so
        later training updates pruned entries like any other weight.
        """
        for layer in self._dense_layers():
            kernel, bias = layer.get_weights()
            layer.set_weights([self._prune_matrix(kernel, prune_percent), bias])

    @staticmethod
    def _regrow_zero_weights(kernels, n_new, std=0.05):
        """Re-activate ``n_new`` currently-zero entries spread over ``kernels``.

        All exact-zero entries of all matrices form one candidate pool, so each
        matrix receives new connections in proportion to how many zeros it
        holds. Positions are drawn without replacement with ``np.random``, and
        values are drawn from N(0, std).

        Args:
            kernels: list of numpy weight arrays (not modified).
            n_new: requested number of connections; clipped to [0, #zeros].
            std: standard deviation of a new connection's initial value.

        Returns:
            (new_kernels, n_added): modified copies and the number actually added.
        """
        zero_idx = [np.flatnonzero(k == 0) for k in kernels]
        # offsets[i] is where layer i's zeros start in the pooled index space.
        offsets = np.cumsum([0] + [len(z) for z in zero_idx])
        total_zeros = int(offsets[-1])
        n_added = int(min(max(int(n_new), 0), total_zeros))
        new_kernels = [k.copy() for k in kernels]
        if n_added == 0:
            return new_kernels, 0

        picks = np.random.choice(total_zeros, size=n_added, replace=False)
        owner = np.searchsorted(offsets, picks, side='right') - 1
        for i, (kernel, zeros) in enumerate(zip(new_kernels, zero_idx)):
            local = picks[owner == i] - offsets[i]
            if local.size:
                kernel.flat[zeros[local]] = np.random.normal(0.0, std, size=local.size)
        return new_kernels, n_added

    def rewire_model_balanced(self, rewiring_needed):
        """Re-grow ``rewiring_needed`` connections spread across all Dense kernels.

        A Sequential model cannot express connections between non-adjacent
        layers, so rewiring means re-activating zero entries inside the
        existing Dense kernels (see ``_regrow_zero_weights``). This is
        vectorised and always terminates. It adds
        ``min(rewiring_needed, number of zero entries)`` connections.
        """
        print(f"\n[INFO] Rewiring {rewiring_needed} connections...")

        dense_layers = self._dense_layers()
        layer_weights = [layer.get_weights() for layer in dense_layers]
        kernels = [weights[0] for weights in layer_weights]

        new_kernels, connections_added = self._regrow_zero_weights(kernels, rewiring_needed)

        for layer, weights, kernel in zip(dense_layers, layer_weights, new_kernels):
            layer.set_weights([kernel] + list(weights[1:]))

        print(f"[INFO] Rewiring complete. Connections added: {connections_added}\n")


if __name__ == '__main__':
    # Run the prune/rewire experiment, then save its per-epoch validation accuracy.
    model = MLP_CIFAR10()
    os.makedirs("results", exist_ok=True)
    rewired_val_acc = np.asarray(model.accuracies_per_epoch)
    np.savetxt("results/dense_mlp_balanced_pruned_rewired_acc.txt", rewired_val_acc)


  super().__init__(**kwargs)


Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
[1m170498071/170498071[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 0us/step


Epoch 1/5


  self._warn_if_super_not_called()


500/500 - 39s - 78ms/step - accuracy: 0.2796 - loss: 1.9689 - val_accuracy: 0.3596 - val_loss: 1.7801
Epoch 2/5
500/500 - 32s - 63ms/step - accuracy: 0.3427 - loss: 1.8204 - val_accuracy: 0.4113 - val_loss: 1.6585
Epoch 3/5
500/500 - 30s - 61ms/step - accuracy: 0.3624 - loss: 1.7738 - val_accuracy: 0.4234 - val_loss: 1.6137
Epoch 4/5
500/500 - 41s - 82ms/step - accuracy: 0.3727 - loss: 1.7399 - val_accuracy: 0.4233 - val_loss: 1.5999
Epoch 5/5
500/500 - 31s - 62ms/step - accuracy: 0.3862 - loss: 1.7115 - val_accuracy: 0.4277 - val_loss: 1.5883

[INFO] Rewiring 4065600 connections...
