In [3]:
import tensorflow as tf
from tensorflow.keras import layers, models
import numpy as np
import os
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

# ---------------------------
# Output Directory Setup
# ---------------------------
output_dir = "/content/drive/MyDrive/Edge-Pop_0.5_2/"
os.makedirs(output_dir, exist_ok=True)
log_file = os.path.join(output_dir, "active_learning_log.csv")
weights_file = os.path.join(output_dir, "global_lenet5.weights.h5")
final_subnet_weights = os.path.join(output_dir, "final_subnet.weights.h5")
final_global_weights = os.path.join(output_dir, "final_global_lenet5_full_training.weights.h5")

# ---------------------------
# Edge-Popup Components
# ---------------------------

class GetSubnet(tf.keras.layers.Layer):
    def __init__(self, k):
        super(GetSubnet, self).__init__()
        self.k = k

    def call(self, scores):
        scores_flat = tf.reshape(scores, [-1])
        k_val = tf.cast(tf.size(scores_flat), tf.float32) * self.k
        k_val = tf.cast(k_val, tf.int32)
        k_val = tf.maximum(k_val, 1)

        topk_values, topk_indices = tf.math.top_k(scores_flat, k=k_val, sorted=False)
        mask_flat = tf.zeros_like(scores_flat)
        mask_flat = tf.tensor_scatter_nd_update(
            mask_flat, tf.expand_dims(topk_indices, 1), tf.ones_like(topk_values)
        )
        return tf.reshape(mask_flat, tf.shape(scores))

class SubnetConv2D(tf.keras.layers.Layer):
    def __init__(self, base_weights, base_bias, filters, kernel_size, strides=1, padding="same", k=0.5, use_bias=True):
        super(SubnetConv2D, self).__init__()
        self.k = k
        self.use_bias = use_bias
        self.strides = strides
        self.padding = padding.upper()

        self.base_weights = base_weights
        self.base_bias = base_bias
        self.get_subnet = GetSubnet(k)

        self.popup_scores = self.add_weight(
            name="popup_scores",
            shape=base_weights.shape,
            initializer=tf.keras.initializers.RandomNormal(stddev=0.1),
            trainable=True,
        )

    def call(self, inputs):
        mask = self.get_subnet(tf.abs(self.popup_scores))
        masked_weights = self.base_weights * mask
        x = tf.nn.conv2d(inputs, masked_weights, strides=[1, self.strides, self.strides, 1], padding=self.padding)
        if self.use_bias and self.base_bias is not None:
            x = tf.nn.bias_add(x, self.base_bias)
        return x

def build_global_lenet5(input_shape=(32, 32, 3), num_classes=10):
    inputs = tf.keras.Input(shape=input_shape)
    x = layers.Conv2D(6, kernel_size=5, padding="same", activation='tanh')(inputs)
    x = layers.AveragePooling2D(pool_size=2)(x)
    x = layers.Conv2D(16, kernel_size=5, activation='tanh')(x)
    x = layers.AveragePooling2D(pool_size=2)(x)
    x = layers.Flatten()(x)
    x = layers.Dense(120, activation='tanh')(x)
    x = layers.Dense(84, activation='tanh')(x)
    outputs = layers.Dense(num_classes, activation='softmax')(x)
    model = models.Model(inputs, outputs)
    return model

def build_edgepopup_subnet(global_model, k=0.5):
    inputs = tf.keras.Input(shape=(32, 32, 3))

    w1, b1 = global_model.layers[1].get_weights()
    w2, b2 = global_model.layers[3].get_weights()

    x = SubnetConv2D(w1, b1, 6, 5, k=k, padding='same')(inputs)
    x = layers.Activation('tanh')(x)
    x = layers.AveragePooling2D(pool_size=2)(x)

    x = SubnetConv2D(w2, b2, 16, 5, k=k, padding='valid')(x)
    x = layers.Activation('tanh')(x)
    x = layers.AveragePooling2D(pool_size=2)(x)

    x = layers.Flatten()(x)
    x = layers.Dense(120, activation='tanh')(x)
    x = layers.Dense(84, activation='tanh')(x)
    outputs = layers.Dense(10, activation='softmax')(x)
    return tf.keras.Model(inputs, outputs)

# ---------------------------
# Sampling Strategy
# ---------------------------

def least_confidence_sampling(model, unlabeled_pool, n_samples):
    probs = model.predict(unlabeled_pool, verbose=0)
    confidence = np.max(probs, axis=1)
    return np.argsort(confidence)[:n_samples]

# ---------------------------
# Plotting
# ---------------------------

def plot_progress(accs, sizes):
    plt.figure(figsize=(8, 5))
    plt.plot(sizes, accs, 'o-')
    plt.xlabel("Labeled Samples")
    plt.ylabel("Test Accuracy")
    plt.title("Active Learning Progress")
    plt.grid(True)
    plt.savefig(os.path.join(output_dir, "progress.png"))
    plt.close()

def plot_comparison(global_acc, subnet_acc):
    plt.figure(figsize=(8, 5))
    models = ['Global LeNet-5', 'Final Subnet']
    accuracies = [global_acc, subnet_acc]
    plt.bar(models, accuracies)
    plt.ylim([0, 1.0])
    plt.ylabel("Test Accuracy")
    plt.title("Model Comparison after Full Training")
    plt.grid(True, axis='y')
    for i, acc in enumerate(accuracies):
        plt.text(i, acc + 0.01, f"{acc:.4f}", ha='center')
    plt.savefig(os.path.join(output_dir, "model_comparison.png"))
    plt.close()

# ---------------------------
# Active Learning Loop
# ---------------------------

def active_learning(global_model, x_train, y_train, x_test, y_test,
                    k=0.5, init_size=1000, query_size=1000, iterations=10, epochs=10):

    indices = np.arange(len(x_train))
    labeled = np.random.choice(indices, size=init_size, replace=False)
    unlabeled = np.setdiff1d(indices, labeled)

    acc_hist = []
    size_hist = []

    popup_log = []
    final_subnet_model = None  # To store the final subnet model

    for i in range(iterations):
        print(f"\n=== Iteration {i+1} ===")
        x_labeled, y_labeled = x_train[labeled], y_train[labeled]

        model = build_edgepopup_subnet(global_model, k)
        model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])

        model.fit(x_labeled, y_labeled, validation_split=0.1,
                  epochs=epochs, batch_size=128, verbose=1)

        loss, acc = model.evaluate(x_test, y_test, verbose=0)
        print(f"Test Accuracy: {acc:.4f}")

        # Collect popup score statistics
        score_stats = {}
        for layer in model.layers:
            if isinstance(layer, SubnetConv2D):
                popup_scores = layer.popup_scores.numpy()
                stats = {
                    'min': np.min(popup_scores),
                    'max': np.max(popup_scores),
                    'mean': np.mean(popup_scores),
                    'std': np.std(popup_scores)
                }
                score_stats[layer.name] = stats
                print(f"[{layer.name}] Popup Scores - min: {stats['min']:.4f}, max: {stats['max']:.4f}, mean: {stats['mean']:.4f}, std: {stats['std']:.4f}")

                # Optional: save popup scores for debugging
                np.save(os.path.join(output_dir, f"popup_scores_iter{i+1}_{layer.name}.npy"), popup_scores)

        acc_hist.append(acc)
        size_hist.append(len(labeled))
        popup_log.append(score_stats)

        # Save the model from the final iteration
        if i == iterations - 1:
            final_subnet_model = model
            model.save_weights(final_subnet_weights)
            print(f"Final subnet model saved to {final_subnet_weights}")

        if i < iterations - 1:
            x_pool = x_train[unlabeled]
            indices_new = least_confidence_sampling(model, x_pool, query_size)
            new_samples = unlabeled[indices_new]
            labeled = np.concatenate([labeled, new_samples])
            unlabeled = np.setdiff1d(unlabeled, new_samples)

    # Save accuracy log
    pd.DataFrame({
        "samples": size_hist,
        "accuracy": acc_hist
    }).to_csv(log_file, index=False)

    # Save popup score stats
    pd.DataFrame([
        {"iteration": i+1, "layer": layer, **popup_log[i][layer]}
        for i in range(len(popup_log)) for layer in popup_log[i]
    ]).to_csv(os.path.join(output_dir, "popup_score_stats.csv"), index=False)

    plot_progress(acc_hist, size_hist)
    return acc_hist, size_hist, final_subnet_model


# ---------------------------
# Full Dataset Training
# ---------------------------

def train_models_on_full_dataset(global_model, final_subnet_model, x_train, y_train, x_test, y_test, epochs=20):
    """Train both the global model and final subnet model on the full dataset for comparison"""
    print("\n=== Training Global LeNet-5 on Full Dataset ===")

    # Create a fresh copy of the global model for full training
    global_model_full = build_global_lenet5()
    global_model_full.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    global_model_full.fit(x_train, y_train, epochs=epochs, batch_size=128, validation_split=0.1, verbose=1)
    global_model_full.save_weights(final_global_weights)

    global_loss, global_acc = global_model_full.evaluate(x_test, y_test, verbose=0)
    print(f"Global LeNet-5 Test Accuracy after full training: {global_acc:.4f}")

    print("\n=== Training Final Subnet Model on Full Dataset ===")
    # Reset the final subnet model and train on full dataset
    final_subnet_model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    final_subnet_model.fit(x_train, y_train, epochs=epochs, batch_size=128, validation_split=0.1, verbose=1)

    subnet_loss, subnet_acc = final_subnet_model.evaluate(x_test, y_test, verbose=0)
    print(f"Final Subnet Test Accuracy after full training: {subnet_acc:.4f}")

    # Compare the models
    print("\n=== Model Comparison ===")
    print(f"Global LeNet-5: {global_acc:.4f}")
    print(f"Final Subnet: {subnet_acc:.4f}")
    print(f"Difference: {subnet_acc - global_acc:.4f}")

    # Save comparison results
    comparison_df = pd.DataFrame({
        "model": ["Global LeNet-5", "Final Subnet"],
        "accuracy": [global_acc, subnet_acc]
    })
    comparison_df.to_csv(os.path.join(output_dir, "model_comparison.csv"), index=False)

    # Plot comparison
    plot_comparison(global_acc, subnet_acc)

    return global_acc, subnet_acc


# ---------------------------
# Run the Setup
# ---------------------------

# Load CIFAR-10 dataset
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0
y_train = tf.keras.utils.to_categorical(y_train, 10)
y_test = tf.keras.utils.to_categorical(y_test, 10)

# Subset for speed
num_train_samples = 30000
x_train = x_train[:num_train_samples]
y_train = y_train[:num_train_samples]

# Check if pretrained weights exist
if not os.path.exists(weights_file):
    print("Training and saving global LeNet-5 model...")
    global_lenet5 = build_global_lenet5()
    global_lenet5.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    global_lenet5.fit(x_train, y_train, epochs=10, batch_size=128, validation_split=0.1)
    global_lenet5.save_weights(weights_file)
else:
    print("Loading pretrained LeNet-5 weights...")
    global_lenet5 = build_global_lenet5()
    global_lenet5.load_weights(weights_file)

# Run Active Learning with pretrained global model
acc_history, size_history, final_subnet_model = active_learning(
    global_lenet5, x_train, y_train, x_test, y_test,
    k=0.5, init_size=2000, query_size=2000, iterations=15, epochs=10
)

for i, (n, acc) in enumerate(zip(size_history, acc_history)):
    print(f"Iteration {i+1}: {n} samples - Accuracy: {acc:.4f}")

# Train both models on the full dataset for final comparison
global_acc, subnet_acc = train_models_on_full_dataset(
    global_lenet5, final_subnet_model, x_train, y_train, x_test, y_test, epochs=20
)

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
[1m170498071/170498071[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 0us/step
Training and saving global LeNet-5 model...
Epoch 1/10
[1m211/211[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 22ms/step - accuracy: 0.2846 - loss: 1.9842 - val_accuracy: 0.3563 - val_loss: 1.7957
Epoch 2/10
[1m211/211[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 4ms/step - accuracy: 0.3901 - loss: 1.7349 - val_accuracy: 0.3867 - val_loss: 1.6915
Epoch 3/10
[1m211/211[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 4ms/step - accuracy: 0.4236 - loss: 1.6512 - val_accuracy: 0.4313 - val_loss: 1.6048
Epoch 4/10
[1m211/211[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 4ms/step - accuracy: 0.4483 - loss: 1.5586 - val_accuracy: 0.4423 - val_loss: 1.5394
Epoch 5/10
[1m211/211[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 4ms/step - accuracy: 0.4773 - loss: 1.4786 - val_accuracy: 0.476



[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 316ms/step - accuracy: 0.1441 - loss: 2.3294 - val_accuracy: 0.2100 - val_loss: 2.1190
Epoch 2/10
[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 9ms/step - accuracy: 0.2486 - loss: 2.0487 - val_accuracy: 0.2700 - val_loss: 2.0331
Epoch 3/10
[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 6ms/step - accuracy: 0.3121 - loss: 1.9363 - val_accuracy: 0.2650 - val_loss: 1.9786
Epoch 4/10
[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 6ms/step - accuracy: 0.3512 - loss: 1.8527 - val_accuracy: 0.3000 - val_loss: 1.9361
Epoch 5/10
[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 6ms/step - accuracy: 0.3612 - loss: 1.7947 - val_accuracy: 0.3050 - val_loss: 1.9608
Epoch 6/10
[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 7ms/step - accuracy: 0.3855 - loss: 1.7800 - val_accuracy: 0.3200 - val_loss: 1.9879
Epoch 7/10
[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━



[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 156ms/step - accuracy: 0.1428 - loss: 2.2771 - val_accuracy: 0.1550 - val_loss: 2.2041
Epoch 2/10
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 5ms/step - accuracy: 0.2455 - loss: 2.0546 - val_accuracy: 0.2050 - val_loss: 2.1285
Epoch 3/10
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 5ms/step - accuracy: 0.2605 - loss: 2.0187 - val_accuracy: 0.1950 - val_loss: 2.0969
Epoch 4/10
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 5ms/step - accuracy: 0.3011 - loss: 1.9367 - val_accuracy: 0.2350 - val_loss: 2.0680
Epoch 5/10
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 5ms/step - accuracy: 0.3041 - loss: 1.9191 - val_accuracy: 0.2425 - val_loss: 2.0664
Epoch 6/10
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 5ms/step - accuracy: 0.2988 - loss: 1.9112 - val_accuracy: 0.1900 - val_loss: 2.0703
Epoch 7/10
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━



[1m43/43[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 109ms/step - accuracy: 0.1748 - loss: 2.2149 - val_accuracy: 0.2033 - val_loss: 2.1273
Epoch 2/10
[1m43/43[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 4ms/step - accuracy: 0.2508 - loss: 2.0382 - val_accuracy: 0.2233 - val_loss: 2.0902
Epoch 3/10
[1m43/43[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 4ms/step - accuracy: 0.2960 - loss: 1.9649 - val_accuracy: 0.2183 - val_loss: 2.0672
Epoch 4/10
[1m43/43[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 4ms/step - accuracy: 0.2967 - loss: 1.9543 - val_accuracy: 0.2333 - val_loss: 2.0548
Epoch 5/10
[1m43/43[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 4ms/step - accuracy: 0.3135 - loss: 1.9124 - val_accuracy: 0.2383 - val_loss: 2.0443
Epoch 6/10
[1m43/43[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 4ms/step - accuracy: 0.3244 - loss: 1.8964 - val_accuracy: 0.2300 - val_loss: 2.0267
Epoch 7/10
[1m43/43[0m [32m━━━━━━━━━━━━━━━━━━━



[1m57/57[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 83ms/step - accuracy: 0.1602 - loss: 2.2535 - val_accuracy: 0.2325 - val_loss: 2.0663
Epoch 2/10
[1m57/57[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 4ms/step - accuracy: 0.2323 - loss: 2.0780 - val_accuracy: 0.2387 - val_loss: 2.0928
Epoch 3/10
[1m57/57[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 4ms/step - accuracy: 0.2548 - loss: 2.0429 - val_accuracy: 0.2000 - val_loss: 2.0578
Epoch 4/10
[1m57/57[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 4ms/step - accuracy: 0.2590 - loss: 2.0169 - val_accuracy: 0.2812 - val_loss: 2.0266
Epoch 5/10
[1m57/57[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 4ms/step - accuracy: 0.2810 - loss: 1.9726 - val_accuracy: 0.2625 - val_loss: 2.0197
Epoch 6/10
[1m57/57[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 4ms/step - accuracy: 0.2788 - loss: 1.9698 - val_accuracy: 0.2325 - val_loss: 2.0372
Epoch 7/10
[1m57/57[0m [32m━━━━━━━━━━━━━━━━━━━━



[1m71/71[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 80ms/step - accuracy: 0.1542 - loss: 2.2503 - val_accuracy: 0.1930 - val_loss: 2.1257
Epoch 2/10
[1m71/71[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 4ms/step - accuracy: 0.2178 - loss: 2.0977 - val_accuracy: 0.2310 - val_loss: 2.1019
Epoch 3/10
[1m71/71[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 4ms/step - accuracy: 0.2435 - loss: 2.0499 - val_accuracy: 0.2140 - val_loss: 2.0749
Epoch 4/10
[1m71/71[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 4ms/step - accuracy: 0.2433 - loss: 2.0347 - val_accuracy: 0.2400 - val_loss: 2.0388
Epoch 5/10
[1m71/71[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 4ms/step - accuracy: 0.2790 - loss: 1.9842 - val_accuracy: 0.2840 - val_loss: 1.9863
Epoch 6/10
[1m71/71[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 4ms/step - accuracy: 0.2853 - loss: 1.9451 - val_accuracy: 0.2430 - val_loss: 1.9855
Epoch 7/10
[1m71/71[0m [32m━━━━━━━━━━━━━━━━━━━━



[1m85/85[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 46ms/step - accuracy: 0.1792 - loss: 2.1961 - val_accuracy: 0.2617 - val_loss: 2.0195
Epoch 2/10
[1m85/85[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 4ms/step - accuracy: 0.2538 - loss: 2.0389 - val_accuracy: 0.2675 - val_loss: 1.9775
Epoch 3/10
[1m85/85[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 4ms/step - accuracy: 0.2751 - loss: 1.9814 - val_accuracy: 0.3000 - val_loss: 1.8946
Epoch 4/10
[1m85/85[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 4ms/step - accuracy: 0.2904 - loss: 1.9618 - val_accuracy: 0.3258 - val_loss: 1.8780
Epoch 5/10
[1m85/85[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 4ms/step - accuracy: 0.3185 - loss: 1.9105 - val_accuracy: 0.3183 - val_loss: 1.8388
Epoch 6/10
[1m85/85[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 4ms/step - accuracy: 0.3246 - loss: 1.8723 - val_accuracy: 0.3417 - val_loss: 1.8101
Epoch 7/10
[1m85/85[0m [32m━━━━━━━━━━━━━━━━━━━━



[1m99/99[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 54ms/step - accuracy: 0.1668 - loss: 2.2106 - val_accuracy: 0.2057 - val_loss: 2.0127
Epoch 2/10
[1m99/99[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 4ms/step - accuracy: 0.2343 - loss: 2.0643 - val_accuracy: 0.2314 - val_loss: 1.9903
Epoch 3/10
[1m99/99[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 4ms/step - accuracy: 0.2573 - loss: 2.0149 - val_accuracy: 0.2157 - val_loss: 1.9777
Epoch 4/10
[1m99/99[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 4ms/step - accuracy: 0.2620 - loss: 1.9799 - val_accuracy: 0.2536 - val_loss: 1.9179
Epoch 5/10
[1m99/99[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 4ms/step - accuracy: 0.2915 - loss: 1.9439 - val_accuracy: 0.2986 - val_loss: 1.8860
Epoch 6/10
[1m99/99[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 4ms/step - accuracy: 0.2978 - loss: 1.9156 - val_accuracy: 0.3079 - val_loss: 1.8713
Epoch 7/10
[1m99/99[0m [32m━━━━━━━━━━━━━━━━━━━━



[1m113/113[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 46ms/step - accuracy: 0.1671 - loss: 2.2148 - val_accuracy: 0.2537 - val_loss: 1.9811
Epoch 2/10
[1m113/113[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 4ms/step - accuracy: 0.2493 - loss: 2.0404 - val_accuracy: 0.3031 - val_loss: 1.9209
Epoch 3/10
[1m113/113[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 4ms/step - accuracy: 0.2669 - loss: 1.9946 - val_accuracy: 0.3050 - val_loss: 1.8912
Epoch 4/10
[1m113/113[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 4ms/step - accuracy: 0.2901 - loss: 1.9369 - val_accuracy: 0.3175 - val_loss: 1.8586
Epoch 5/10
[1m113/113[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 4ms/step - accuracy: 0.3119 - loss: 1.8978 - val_accuracy: 0.3294 - val_loss: 1.8155
Epoch 6/10
[1m113/113[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 4ms/step - accuracy: 0.3290 - loss: 1.8449 - val_accuracy: 0.3481 - val_loss: 1.7899
Epoch 7/10
[1m113/113[0m [32m━━━━━━



[1m127/127[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 44ms/step - accuracy: 0.1725 - loss: 2.1852 - val_accuracy: 0.3006 - val_loss: 1.8452
Epoch 2/10
[1m127/127[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 4ms/step - accuracy: 0.2590 - loss: 2.0077 - val_accuracy: 0.3439 - val_loss: 1.7704
Epoch 3/10
[1m127/127[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 4ms/step - accuracy: 0.2961 - loss: 1.9417 - val_accuracy: 0.3439 - val_loss: 1.7264
Epoch 4/10
[1m127/127[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 4ms/step - accuracy: 0.3078 - loss: 1.8961 - val_accuracy: 0.3578 - val_loss: 1.6753
Epoch 5/10
[1m127/127[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 4ms/step - accuracy: 0.3411 - loss: 1.8212 - val_accuracy: 0.3544 - val_loss: 1.6558
Epoch 6/10
[1m127/127[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 4ms/step - accuracy: 0.3572 - loss: 1.7819 - val_accuracy: 0.3650 - val_loss: 1.6336
Epoch 7/10
[1m127/127[0m [32m━━━━━━



[1m141/141[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 35ms/step - accuracy: 0.2028 - loss: 2.1445 - val_accuracy: 0.3205 - val_loss: 1.8290
Epoch 2/10
[1m141/141[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 4ms/step - accuracy: 0.2813 - loss: 1.9650 - val_accuracy: 0.3225 - val_loss: 1.8066
Epoch 3/10
[1m141/141[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 4ms/step - accuracy: 0.2970 - loss: 1.9168 - val_accuracy: 0.3530 - val_loss: 1.7807
Epoch 4/10
[1m141/141[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 4ms/step - accuracy: 0.3112 - loss: 1.8754 - val_accuracy: 0.3575 - val_loss: 1.7254
Epoch 5/10
[1m141/141[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 4ms/step - accuracy: 0.3309 - loss: 1.8296 - val_accuracy: 0.3825 - val_loss: 1.6792
Epoch 6/10
[1m141/141[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 4ms/step - accuracy: 0.3552 - loss: 1.7860 - val_accuracy: 0.4000 - val_loss: 1.6464
Epoch 7/10
[1m141/141[0m [32m━━━━━━



[1m155/155[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 31ms/step - accuracy: 0.1997 - loss: 2.1388 - val_accuracy: 0.3809 - val_loss: 1.7183
Epoch 2/10
[1m155/155[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 4ms/step - accuracy: 0.2821 - loss: 1.9616 - val_accuracy: 0.3977 - val_loss: 1.6773
Epoch 3/10
[1m155/155[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 4ms/step - accuracy: 0.2933 - loss: 1.9194 - val_accuracy: 0.4100 - val_loss: 1.6413
Epoch 4/10
[1m155/155[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 4ms/step - accuracy: 0.3195 - loss: 1.8625 - val_accuracy: 0.4118 - val_loss: 1.5987
Epoch 5/10
[1m155/155[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 4ms/step - accuracy: 0.3400 - loss: 1.8096 - val_accuracy: 0.4100 - val_loss: 1.5909
Epoch 6/10
[1m155/155[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 4ms/step - accuracy: 0.3495 - loss: 1.7747 - val_accuracy: 0.4427 - val_loss: 1.5210
Epoch 7/10
[1m155/155[0m [32m━━━━━━



[1m169/169[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 29ms/step - accuracy: 0.2213 - loss: 2.0805 - val_accuracy: 0.4304 - val_loss: 1.6074
Epoch 2/10
[1m169/169[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 4ms/step - accuracy: 0.3057 - loss: 1.9000 - val_accuracy: 0.4779 - val_loss: 1.4967
Epoch 3/10
[1m169/169[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 4ms/step - accuracy: 0.3277 - loss: 1.8279 - val_accuracy: 0.4921 - val_loss: 1.4397
Epoch 4/10
[1m169/169[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 4ms/step - accuracy: 0.3529 - loss: 1.7594 - val_accuracy: 0.4971 - val_loss: 1.3853
Epoch 5/10
[1m169/169[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 4ms/step - accuracy: 0.3688 - loss: 1.7174 - val_accuracy: 0.5108 - val_loss: 1.3517
Epoch 6/10
[1m169/169[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 4ms/step - accuracy: 0.3875 - loss: 1.6754 - val_accuracy: 0.5029 - val_loss: 1.3323
Epoch 7/10
[1m169/169[0m [32m━━━━━━



[1m183/183[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 26ms/step - accuracy: 0.2323 - loss: 2.0590 - val_accuracy: 0.5027 - val_loss: 1.4224
Epoch 2/10
[1m183/183[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 5ms/step - accuracy: 0.3179 - loss: 1.8709 - val_accuracy: 0.5235 - val_loss: 1.3523
Epoch 3/10
[1m183/183[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 4ms/step - accuracy: 0.3464 - loss: 1.8016 - val_accuracy: 0.5538 - val_loss: 1.2532
Epoch 4/10
[1m183/183[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 4ms/step - accuracy: 0.3663 - loss: 1.7396 - val_accuracy: 0.5735 - val_loss: 1.2299
Epoch 5/10
[1m183/183[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 4ms/step - accuracy: 0.3812 - loss: 1.6940 - val_accuracy: 0.5808 - val_loss: 1.1747
Epoch 6/10
[1m183/183[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 4ms/step - accuracy: 0.4073 - loss: 1.6335 - val_accuracy: 0.5900 - val_loss: 1.1530
Epoch 7/10
[1m183/183[0m [32m━━━━━━



[1m197/197[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 23ms/step - accuracy: 0.2467 - loss: 2.0379 - val_accuracy: 0.5989 - val_loss: 1.2656
Epoch 2/10
[1m197/197[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 4ms/step - accuracy: 0.3273 - loss: 1.8527 - val_accuracy: 0.6754 - val_loss: 1.1634
Epoch 3/10
[1m197/197[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 4ms/step - accuracy: 0.3583 - loss: 1.7644 - val_accuracy: 0.7118 - val_loss: 1.0774
Epoch 4/10
[1m197/197[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 5ms/step - accuracy: 0.3869 - loss: 1.7084 - val_accuracy: 0.7143 - val_loss: 1.0210
Epoch 5/10
[1m197/197[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 5ms/step - accuracy: 0.3942 - loss: 1.6584 - val_accuracy: 0.7236 - val_loss: 0.9711
Epoch 6/10
[1m197/197[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 5ms/step - accuracy: 0.4160 - loss: 1.6158 - val_accuracy: 0.7100 - val_loss: 0.9834
Epoch 7/10
[1m197/197[0m [32m━━━━━━



[1m211/211[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 23ms/step - accuracy: 0.2633 - loss: 2.0103 - val_accuracy: 0.7787 - val_loss: 1.0322
Epoch 2/10
[1m211/211[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 4ms/step - accuracy: 0.3476 - loss: 1.8142 - val_accuracy: 0.8210 - val_loss: 0.8891
Epoch 3/10
[1m211/211[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 4ms/step - accuracy: 0.3673 - loss: 1.7452 - val_accuracy: 0.8297 - val_loss: 0.8615
Epoch 4/10
[1m211/211[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 4ms/step - accuracy: 0.3961 - loss: 1.6757 - val_accuracy: 0.8243 - val_loss: 0.7857
Epoch 5/10
[1m211/211[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 4ms/step - accuracy: 0.4161 - loss: 1.6082 - val_accuracy: 0.8507 - val_loss: 0.7123
Epoch 6/10
[1m211/211[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 4ms/step - accuracy: 0.4332 - loss: 1.5728 - val_accuracy: 0.8477 - val_loss: 0.7088
Epoch 7/10
[1m211/211[0m [32m━━━━━━

In [4]:
print("Global LeNet-5 Trainable Parameters:", global_lenet5.count_params())
print("Final Subnet Trainable Parameters:", final_subnet_model.count_params())

Global LeNet-5 Trainable Parameters: 83126
Final Subnet Trainable Parameters: 83104


In [6]:
sample = np.expand_dims(x_test[0], axis=0)  # One sample from the test set

global_pred = global_lenet5.predict(sample, verbose=0)
subnet_pred = final_subnet_model.predict(sample, verbose=0)

print("Global LeNet-5 Prediction:", np.argmax(global_pred))
print("Final Subnet Prediction:", np.argmax(subnet_pred))

print("Prediction Difference (L2 Norm):", np.linalg.norm(global_pred - subnet_pred))


Global LeNet-5 Prediction: 3
Final Subnet Prediction: 8
Prediction Difference (L2 Norm): 0.6391015


In [7]:
def analyze_conv_layer_sparsity(model, model_name="Model"):
    total_params = 0
    zero_params = 0

    print(f"\n=== {model_name} Convolutional Layer Sparsity ===")
    for layer in model.layers:
        if isinstance(layer, SubnetConv2D):  # For subnet model
            weights = layer.base_weights
            mask = layer.get_subnet(tf.abs(layer.popup_scores)).numpy()
            num_total = weights.size
            num_masked = np.sum(mask == 0)

            print(f"[{layer.name}] Total: {num_total}, Masked: {num_masked}, Kept: {num_total - num_masked}")
            total_params += num_total
            zero_params += num_masked

        elif isinstance(layer, tf.keras.layers.Conv2D):  # For global model
            weights = layer.get_weights()
            if weights:
                w = weights[0]
                num_total = w.size
                num_zeros = np.sum(w == 0)

                print(f"[{layer.name}] Total: {num_total}, Zeros: {num_zeros}, Non-Zero: {num_total - num_zeros}")
                total_params += num_total
                zero_params += num_zeros

    print(f"\n=== {model_name} Summary ===")
    print(f"Total Conv Params : {total_params}")
    print(f"Zero/Masked Params: {zero_params}")
    print(f"Non-zero/Kept     : {total_params - zero_params}")
    print(f"Sparsity          : {zero_params / total_params:.2%}")


In [8]:
analyze_conv_layer_sparsity(final_subnet_model, model_name="Final Subnet Model")
analyze_conv_layer_sparsity(global_lenet5, model_name="Global LeNet-5")


=== Final Subnet Model Convolutional Layer Sparsity ===
[subnet_conv2d_28] Total: 450, Masked: 225, Kept: 225
[subnet_conv2d_29] Total: 2400, Masked: 1200, Kept: 1200

=== Final Subnet Model Summary ===
Total Conv Params : 2850
Zero/Masked Params: 1425
Non-zero/Kept     : 1425
Sparsity          : 50.00%

=== Global LeNet-5 Convolutional Layer Sparsity ===
[conv2d] Total: 450, Zeros: 0, Non-Zero: 450
[conv2d_1] Total: 2400, Zeros: 0, Non-Zero: 2400

=== Global LeNet-5 Summary ===
Total Conv Params : 2850
Zero/Masked Params: 0
Non-zero/Kept     : 2850
Sparsity          : 0.00%


In [42]:
for layer in final_subnet_model.layers:
    print(f"{layer.name}: {layer}")


input_layer_15: <InputLayer name=input_layer_15, built=True>
subnet_conv2d_28: <SubnetConv2D name=subnet_conv2d_28, built=True>
activation_28: <Activation name=activation_28, built=True>
average_pooling2d_30: <AveragePooling2D name=average_pooling2d_30, built=True>
subnet_conv2d_29: <SubnetConv2D name=subnet_conv2d_29, built=True>
activation_29: <Activation name=activation_29, built=True>
average_pooling2d_31: <AveragePooling2D name=average_pooling2d_31, built=True>
flatten_15: <Flatten name=flatten_15, built=True>
dense_45: <Dense name=dense_45, built=True>
dense_46: <Dense name=dense_46, built=True>
dense_47: <Dense name=dense_47, built=True>


In [73]:
final_subnet_model.summary()