In [1]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow import keras
from tensorflow.keras import layers

In [2]:
# Load the MNIST dataset using TensorFlow
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
# Display the shapes of the training and test datasets
print("Training data shape:", x_train.shape, y_train.shape)
print("Test data shape:", x_test.shape, y_test.shape)

# Flatten every image into one row of a 2D numpy array (28 * 28 = 784 columns),
# convert to float32 and scale the grayscale values by 1/255 for a better
# numerical representation.
# The number of samples is read from the arrays instead of hard-coding
# 60000 / 10000, so the cell also works on a (non-empty) subset of the data.
x_train = x_train.reshape(len(x_train), -1).astype("float32") / 255
x_test = x_test.reshape(len(x_test), -1).astype("float32") / 255

# Labels are class indices; SparseCategoricalCrossentropy accepts them as float32.
y_train = y_train.astype("float32")
y_test = y_test.astype("float32")

# The tutorial reserved 10,000 training samples for validation; we use 5,000 instead,
# as that is what Frankle and Carbin did in their paper.
# The last num_val training samples become the validation set and are removed from training.
num_val = 5000
x_val = x_train[-num_val:]
y_val = y_train[-num_val:]
x_train = x_train[:-num_val]
y_train = y_train[:-num_val]

Training data shape: (60000, 28, 28) (60000,)
Test data shape: (10000, 28, 28) (10000,)


In [3]:
# Hyperparams
batch_size = 60 # batch size, 60 images per weight update
epochs = 10 # nr. of epochs we train our models
# NOTE(review): the model.fit call for the fully-connected model passes the explicit
# 5,000-sample hold-out via validation_data, so this value is not used there.
# 1/11 of the original 60,000 samples would be ~5,455 (5,000 val / 55,000 train needs 1/12);
# 1/11 of the remaining 55,000 training samples gives 5,000 val / 50,000 train.
validation_split = 1/11
input_dim = 784 # input_distribution size for MINE (28 * 28 flattened pixels)
d1_dim = 100 # first hidden layer distribution size for MINE
d2_dim = 30 # second hidden layer distribution size for MINE
output_dim = 10 # output_distribution dim for MINE (one unit per digit class)
# Pruning rate for LTH iterative pruning: each iteration removes this fraction (40%) of the
# lowest-magnitude weights. Assuming the fraction applies to the weights still remaining,
# sparsity after k iterations is 1 - (1 - pruning_rate)**k = 1 - 0.6**k.
pruning_rate = 0.4
pruning_iterations = 7 # nr. of iterative pruning rounds -> 1 time: 40% sparse, 7 times: ~97% sparse
averaging_iterations = 5 # number of total experimental runs to average for graph representations

In [4]:
tf.keras.backend.clear_session() # clearing backend right at start, just in case

# Seed NumPy and TensorFlow so the weight initialisation below (and the shuffling
# done by model.fit) is reproducible on Restart & Run All. Seeding happens after
# clear_session so the backend reset comes first.
random_seed = 42
np.random.seed(random_seed)
tf.random.set_seed(random_seed)

# Functional build of a fully connected MLP with 2 hidden layers:
# input_dim -> d1_dim -> d2_dim -> output_dim
inputs = keras.Input(shape=(input_dim,), name="digits")
# The paper's methods section made no specific mention of the activation function;
# ReLU is standard, and all available implementations seem to use it too.
x = layers.Dense(d1_dim, activation="relu", name="dense_1")(inputs)
x = layers.Dense(d2_dim, activation="relu", name="dense_2")(x)
# softmax activation for multi-class classification
outputs = layers.Dense(output_dim, activation="softmax", name="predictions")(x)

base_model = keras.Model(inputs=inputs, outputs=outputs)
base_model.summary()

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 digits (InputLayer)         [(None, 784)]             0         
                                                                 
 dense_1 (Dense)             (None, 100)               78500     
                                                                 
 dense_2 (Dense)             (None, 30)                3030      
                                                                 
 predictions (Dense)         (None, 10)                310       
                                                                 
Total params: 81,840
Trainable params: 81,840
Non-trainable params: 0
_________________________________________________________________


In [5]:
# Keep the model's initial (untrained) weights for later use, presumably for
# rewinding pruned networks to their original initialisation (LTH).
# keras.models.clone_model builds NEW layers with freshly initialised weights, so the
# clone must be given base_model's weights explicitly to really hold the initial weights.
init_weights = base_model.get_weights()
init_model = keras.models.clone_model(base_model)
init_model.set_weights(init_weights)
base_model.save_weights('init_weights.h5') # saving initial weights for later use

In [6]:
# Fully-connected model, trained: train it and keep it for later use.
print("Fit model on training data")
# clone_model re-initialises all weights, so copy base_model's (untrained) weights in.
# This makes training start from the same initialisation that is saved in
# init_weights / 'init_weights.h5'.
model = keras.models.clone_model(base_model)
model.set_weights(base_model.get_weights())
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1.2e-3),  # Adam optimizer, lr=0.0012
    # Loss function to minimize: multi-class classification loss for class-index labels
    loss=keras.losses.SparseCategoricalCrossentropy(),
    # List of metrics to monitor
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

history = model.fit(
    x_train,
    y_train,
    batch_size=batch_size,
    epochs=epochs,
    verbose=0,
    # explicit 5,000-sample hold-out split off the training data earlier
    validation_data=(x_val, y_val),
)
print("\n")  # two blank lines, same as two print("") calls
trained_loss, trained_accuracy = model.evaluate(x_test, y_test)
print(f"fully connected model, trained: loss: {trained_loss} acc: {trained_accuracy}")

Fit model on training data


fully connected model, trained: loss: 0.0904313251376152 acc: 0.9757000207901001


In [7]:
# Persist the weights of the trained fully-connected model for later use.
model.save_weights(filepath="trained_weights.h5")