In [1]:
import tensorflow as tf
from tensorflow.keras.datasets import mnist, fashion_mnist

import numpy as np

In [2]:
# Keras ships MNIST pre-split as ((train_images, train_labels), (test_images, test_labels)).
((x_train, y_train),
 (x_test, y_test)) = mnist.load_data()

In [3]:
# Inspect the raw image tensor: (num_images, height, width) before flattening.
np.shape(x_train)

(60000, 28, 28)

In [4]:
# Flatten each 28x28 image into a 784-long vector and divide by 255 so the
# 8-bit pixel values land in [0, 1].
# The ndim guard makes this cell safe to re-run: without it, a second run would
# divide the already-scaled data by 255 again (and the reshape would be a no-op).
# float32 matches Keras' default float type and halves memory versus float64.
if x_train.ndim == 3:
    x_train = x_train.reshape(-1, 784).astype('float32') / 255.0
if x_test.ndim == 3:
    x_test = x_test.reshape(-1, 784).astype('float32') / 255.0

In [5]:
# Fix the global RNG seeds so weight initialisation and fit()'s shuffling are
# repeatable across Restart & Run All -- the pruning results at the end depend
# on the exact trained weights. (Some GPU kernels may still add small
# nondeterminism.)
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Fully connected classifier: 784 flattened pixels -> 10 softmax class
# probabilities, with ReLU hidden layers narrowing 256 -> 128 -> 64 -> 32.
# The input is declared with tf.keras.Input rather than input_shape= on the
# first Dense; Sequential.layers omits the InputLayer, so model.layers[0] is
# still the 256-unit Dense layer.
model = tf.keras.models.Sequential([
    tf.keras.Input(shape=(784,)),
    tf.keras.layers.Dense(256, activation='relu'),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(32, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax'),
])

In [6]:
# sparse_categorical_crossentropy takes integer class ids rather than one-hot
# vectors, matching the labels returned by mnist.load_data().
model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)

In [7]:
# A Dense kernel has shape (inputs, units); its second dimension is the number
# of units in the first hidden layer.
layer = model.layers[0]
layer.get_weights()[0].shape[1]

256

In [8]:
# Train for 5 epochs in mini-batches of 64, reporting loss/accuracy on the
# test split after every epoch.
# NOTE(review): the test split doubles as validation data here, so it is not a
# fully held-out set for the later pruning evaluation.
model.fit(
    x_train,
    y_train,
    batch_size=64,
    epochs=5,
    validation_data=(x_test, y_test),
)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.callbacks.History at 0x2002d5a36d0>

In [9]:
# model.get_weights()[0] is the first Dense kernel; its column count (units)
# is fixed by the architecture, so training does not change it. This is not a
# sparsity measure -- np.count_nonzero on the kernel would be.
np.shape(model.get_weights()[0])[1]

256

In [10]:
def prune_weights(model, threshold):
    """Return a copy of ``model`` with small Dense-layer kernel weights zeroed.

    Magnitude pruning: every entry of a Dense layer's kernel (the first array
    returned by ``get_weights()``) whose absolute value is strictly below
    ``threshold`` is set to 0.0. Biases and non-Dense layers keep their values,
    and the original ``model`` is left untouched.

    Args:
        model: Keras model to prune.
        threshold: absolute-magnitude cut-off, applied identically to every
            Dense layer.

    Returns:
        A new model built with ``tf.keras.models.clone_model`` that holds the
        pruned weights.
    """
    clone = tf.keras.models.clone_model(model)
    # clone_model creates new layers with freshly initialised weights, so the
    # trained values must be copied across before pruning.
    clone.set_weights(model.get_weights())

    dense_layers = (lyr for lyr in clone.layers
                    if isinstance(lyr, tf.keras.layers.Dense))
    for dense in dense_layers:
        params = dense.get_weights()
        kernel = params[0]  # params[1], if present, is the bias and is not pruned
        kernel[np.abs(kernel) < threshold] = 0.0
        dense.set_weights(params)

    return clone


In [11]:

# Absolute-magnitude cut-offs to sweep; kernel weights with |w| below each
# value get zeroed by prune_weights.
thresholds = [
    0.1, 0.2, 0.3,
    0.5, 0.7, 1,
]
# Test accuracy per threshold, filled by the evaluation loop.
performance = list()

In [12]:
# Evaluate a pruned copy of the trained model at every threshold.
# Reset the results here so re-running this cell on its own does not append
# duplicate entries to a list created in an earlier cell.
performance = []
for threshold in thresholds:
    pruned_model = prune_weights(model, threshold)
    # compile() is required before evaluate(); there is no fine-tuning, so the
    # optimizer is never stepped.
    # NOTE: a plain fit() here would also update the zeroed weights (no pruning
    # mask is kept), so fine-tuning a pruned model needs explicit masking.
    pruned_model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    # verbose=0 suppresses one progress bar per threshold; results are printed below.
    loss, accuracy = pruned_model.evaluate(x_test, y_test, verbose=0)
    performance.append(accuracy)



In [13]:
# Pair each threshold with its test accuracy as a formatted label and print them.
acc_on_t = list()
for threshold, acc in zip(thresholds, performance):
    # each entry is a one-element list wrapping the formatted label
    ind = [f'Threshold: {threshold}, Accuracy: {acc}']
    acc_on_t += [ind]
print(acc_on_t)

[['Threshold: 0.1, Accuracy: 0.9413999915122986'], ['Threshold: 0.2, Accuracy: 0.11739999800920486'], ['Threshold: 0.3, Accuracy: 0.10320000350475311'], ['Threshold: 0.5, Accuracy: 0.0982000008225441'], ['Threshold: 0.7, Accuracy: 0.0982000008225441'], ['Threshold: 1, Accuracy: 0.0982000008225441']]
