In [1]:
import tensorflow as tf
from tensorflow.keras.datasets import fashion_mnist

import numpy as np

In [3]:
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz


In [4]:
x_train.shape

(60000, 28, 28)

In [5]:
x_train = x_train.reshape(-1, 784) / 255.0  
x_test = x_test.reshape(-1, 784) / 255.0

In [6]:
model = tf.keras.models.Sequential([
    tf.keras.layers.Dense(256, activation='relu', input_shape=(784,)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(32, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])

In [7]:
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

In [8]:
model.fit(x_train, y_train, epochs=5, batch_size=64, validation_data=(x_test, y_test))

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.callbacks.History at 0x21421699720>

In [9]:
max(model.get_weights()[0][0])

0.19471411

In [10]:
def prune_weights(model, threshold):
    pruned_model = tf.keras.models.clone_model(model)
    pruned_model.set_weights(model.get_weights())  # Copy the original weights

    # Prune the weights below the threshold
    for layer in pruned_model.layers:
        if isinstance(layer, tf.keras.layers.Dense):
            weights = layer.get_weights()
            weights[0][weights[0] < threshold] = 0.0  # Prune weights below threshold
            layer.set_weights(weights)
    
    return pruned_model


In [11]:

thresholds = [0.0, 0.1, 0.2, 0.3, 1, 2,5,10,20,50]
performance = []

In [12]:
for threshold in thresholds:
    pruned_model = prune_weights(model, threshold)
    pruned_model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    loss, accuracy = pruned_model.evaluate(x_test, y_test)
    performance.append(accuracy)



In [13]:
acc_on_t = []

for threshold, acc in zip(thresholds, performance):
    ind = []
    ind.append(f'Threshold: {threshold}, Accuracy: {acc}')
    acc_on_t.append(ind)
print(acc_on_t)

[['Threshold: 0.0, Accuracy: 0.10000000149011612'], ['Threshold: 0.1, Accuracy: 0.10000000149011612'], ['Threshold: 0.2, Accuracy: 0.10000000149011612'], ['Threshold: 0.3, Accuracy: 0.10000000149011612'], ['Threshold: 1, Accuracy: 0.10000000149011612'], ['Threshold: 2, Accuracy: 0.10000000149011612'], ['Threshold: 5, Accuracy: 0.10000000149011612'], ['Threshold: 10, Accuracy: 0.10000000149011612'], ['Threshold: 20, Accuracy: 0.10000000149011612'], ['Threshold: 50, Accuracy: 0.10000000149011612']]
