In [1]:
import sys
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import backend as K
import lottery_ticket_pruner
from lottery_ticket_pruner import LotteryTicketPruner, PrunerCallback
from mine import MINE
import pickle




# In[2]: loading MNIST data for training


# Load the MNIST dataset using TensorFlow
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# reshape data as 2D numpy arrays
# convert to float32 and normalize grayscale for better num. representation
x_train = x_train.reshape(60000, 784).astype("float32") / 255
x_test = x_test.reshape(10000, 784).astype("float32") / 255

y_train = y_train.astype("float32")
y_test = y_test.astype("float32")

# The tutorial reserved 10.000 training samples for validation, we change to 5.000 
# as that is what Frankle and Carbin did in their paper
x_val = x_train[-5000:]
y_val = y_train[-5000:]
x_train = x_train[:-5000]
y_train = y_train[:-5000]
y_train_1hot = keras.utils.to_categorical(y_train, num_classes=10) # need y_train in a 1-hot encoded array for mine


# In[3]: Hyperparameters for the experiment

epochs_LT = 6 # epochs for the tickets, 5.45 epochs for about 5000 iterations, which is early-stop iteration in Frankle et al. paper
batch_size_LT = 60 # mini-batch size for the tickets
batch_size_mine = 100 # batch size for MINE algorithm
epochs_mine = 100 # epochs for MINE algorithm
validation_split = 1/11 # 5000 val 55000 train data
input_dim = 784 # input_distribution dim. for MINE, also dim. size of MNIST input
d1_dim = 300 # first hidden layer size for lottery ticket model, also first hidden layer activation distribution dim. for MINE
d2_dim = 100  # second hidden layer size for lottery ticket model, also second hidden layer activation distribution dim. for MINE
o_dim = 10 # output layer size for lottery ticket model, also output layer distribution dim. for MINE
pruning_rate = 0.2 # pruning rate for LTH iterative Pruning -> removes pruning_rate% of lowest magnitude weights in an iteration
pruning_iterations = 15  # number of iterations for applying the pruning rate iteratively -> 1 time : 20% sparse, 24 times : ~99.5% sparse
averaging_iterations = 1 # Frankle et al. usually use average of 5 trials
# we train this script as singular runs

In [2]:
tf.keras.backend.clear_session() # clearing backend right at start, just in case

inputs = keras.Input(shape=(input_dim,), name="digits") # Functional build of a 2-hidden layer fully connected MLP
x = layers.Dense(d1_dim, activation="ReLU", name="dense_1")(inputs) # methods made no mention of the activaton function specifically
x = layers.Dense(d2_dim, activation="ReLU", name="dense_2")(x) # ReLU is standard, as all available implementations seem to use it too
outputs = layers.Dense(o_dim, activation="softmax", name="predictions")(x)  # softmax activation for multi-class classification

base_model = keras.Model(inputs=inputs, outputs=outputs)
base_model.summary()


# loading the saved initialization
base_model.load_weights("init_weights_fs.h5")
init_model = keras.models.clone_model(base_model)
init_weights = init_model.get_weights() # init weights for Lotter Ticket reset to initial weights

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 digits (InputLayer)         [(None, 784)]             0         
                                                                 
 dense_1 (Dense)             (None, 300)               235500    
                                                                 
 dense_2 (Dense)             (None, 100)               30100     
                                                                 
 predictions (Dense)         (None, 10)                1010      
                                                                 
Total params: 266,610
Trainable params: 266,610
Non-trainable params: 0
_________________________________________________________________


In [3]:
base_model.layers[2].get_weights()

[array([[ 0.11849876,  0.08326044,  0.03046269, ...,  0.00832217,
         -0.03234859,  0.02240524],
        [-0.05200388, -0.11659102,  0.11483195, ...,  0.04681451,
         -0.09298775,  0.04070682],
        [ 0.07204082,  0.11276603,  0.06366249, ..., -0.03666466,
         -0.03787875,  0.02010769],
        ...,
        [ 0.06125017, -0.11605567, -0.05846643, ...,  0.10483941,
         -0.04802044, -0.09744417],
        [ 0.0922941 ,  0.08039788, -0.10164817, ..., -0.0538088 ,
          0.04122659, -0.08141087],
        [ 0.06768062, -0.10923558, -0.01338455, ..., -0.03052533,
          0.0835809 , -0.06167452]], dtype=float32),
 array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.

In [4]:
init_model.layers[2].get_weights()

[array([[-0.09970823, -0.02305449,  0.11002179, ..., -0.11474014,
         -0.05269753,  0.03375185],
        [ 0.06214058, -0.02567535,  0.12043392, ...,  0.09188355,
         -0.10343234,  0.00736142],
        [-0.04803406, -0.02286296,  0.0575523 , ...,  0.10050128,
          0.0285947 ,  0.12061263],
        ...,
        [ 0.10425439, -0.05459931, -0.03847858, ..., -0.04969621,
         -0.09415815, -0.05277631],
        [-0.10485386,  0.03297601, -0.05529372, ..., -0.11328057,
         -0.08493343,  0.11632317],
        [ 0.02431818,  0.01961897, -0.05147316, ..., -0.03757758,
          0.10760636, -0.04880322]], dtype=float32),
 array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.