#### TEST SPSA IMPLEMENTATION IN TENSORFLOW

Objective : Test the SPSA alorithm method for a gradient descent of the loss function, within a complex model like a Convolution Neural Network.
We use the Model keras class object in Tensorflow.

In [144]:
import tensorflow as tf
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds
import numpy as np
import matplotlib.pyplot as plt
import copy

from IPython.display import clear_output

In [3]:
import os, sys

class HiddenPrints:
    def __enter__(self):
        self._original_stdout = sys.stdout
        sys.stdout = open(os.devnull, 'w')

    def __exit__(self, exc_type, exc_val, exc_tb):
        sys.stdout.close()
        sys.stdout = self._original_stdout

In [4]:
with HiddenPrints():
    print("This will not be printed")

In [5]:
LIMIT = 100

In [151]:
class Model_for_SPSA(tf.keras.Model):

    def __init__(self) -> None:
        super(Model_for_SPSA, self).__init__()

        # Define the CONV model
        self.inputs_shape = (28, 28, 1, )
        self.conv_1 = tf.keras.layers.Conv2D(64, 2, input_shape=self.inputs_shape)
        self.conv_2 = tf.keras.layers.Conv2D(16, 2)
        self.conv_3 = tf.keras.layers.Conv2D(8, 2)
        self.flatten = tf.keras.layers.Flatten()
        self.dense_final = tf.keras.layers.Dense(10, "softmax")

        # And to automaticly active the model :
        data_activate = np.random.random(self.inputs_shape)
        self(np.expand_dims(data_activate, axis=0))

        self.metric_list = []
        self.nb_layers = len(self.weights)


    def call(self, inputs):
        x = self.conv_1(inputs)
        x = self.conv_2(x)
        x = self.conv_3(x)
        x = self.flatten(x)
        return self.dense_final(x)
    
    def compile_metrics(self, metric_list): # a special method to compile any metric
        for metric in metric_list:
            self.metric_list.append(metric)

    def compile_SPSA_parameters(self, c, alpha, gamma): # a special method to compile SPSA optimizer (hyperparameters)
        self.spsa_alpha = alpha
        self.spsa_c = c
        self.spsa_gamma = gamma


    def grad_loss_spsa(self, c_k, batch_data, batch_label): #to compute the gradient of loss function with the SPSA algorithm

        with HiddenPrints():
                    model_plus = copy.copy(self)
                    model_minus = copy.copy(self)

        Delta_list = [] # we have to stock the Delta vectors for each layer

        for i in range(self.nb_layers):

            dim_base_weight = np.shape(np.array(self.trainable_weights[i])) # we take the dimension of the keras layer

            D = c_k*(2*np.round(np.random.random(dim_base_weight)) - 1) # we compute the random vector perturbation
            
            # Two model randomly perturbed with a gap = 2 * Delta*c_k:
            model_plus.trainable_weights[i].assign_add(model_plus.trainable_weights[i] + D) 
            model_minus.trainable_weights[i].assign_add(model_minus.trainable_weights[i] - D)
            
            # we stock Delta vector for compute the final update for each layer later
            Delta_list.append(D)

        # return the list of Delta vector, and the global approximation SPSA derivate loss scalar
        return Delta_list , (self.loss(batch_label, model_plus(batch_data)) - self.loss(batch_label, model_minus(batch_data))) / (2*c_k)
        
        
    
    def fit_SPSA(self, data_train, validation_data, epochs): # special method of fit with SPSA method
        # data_train : a dataset object
        # validation_data : a dataset object

        # --- Define Variables ---
        
        #Computing the targets of the train and test dataset
        labels_train = np.concatenate([y for x, y in data_train], axis=0)
        labels_test = np.concatenate([y for x, y in validation_data], axis=0)

        # we also have to find our batch_size

        for data in data_train.take(1):
            batch_size = np.shape(data[0])[0]

        steps_per_epochs = len(data_train)
        steps_per_epochs_test = len(validation_data)

        #--- Define SPSA additionnal parameters ---

        # to define the magnitude_g0, we take all the dataset divised by 10... (not just a little batch)... to have a better approximation
        batch_data_big_sample = next(iter(data_train.rebatch(int(steps_per_epochs * batch_size / 10))))[0]
        batch_label_big_sample = next(iter(data_train.rebatch(int(steps_per_epochs * batch_size / 10))))[1]

        _ , derivate_loss_init = self.grad_loss_spsa(self.spsa_c, batch_data_big_sample, batch_label_big_sample)

        magnitude_g0 = np.abs(derivate_loss_init)

        A = 0.1*epochs

        a = 0.1*((A+1)**self.spsa_alpha)/magnitude_g0

        print('Hyperparameters initialized for the beginning of the fit :')
        print(f'- derivate_loss_init : {derivate_loss_init}')
        print(f'- magnitude_g0 : {magnitude_g0}')
        print(f'- a : {a}')
        print(f'- c : {self.spsa_c}')

        # We also need a big batch to compute the loss of the test set (for final results of each epoch)

        batch_data_big_sample_test = next(iter(validation_data.rebatch(int(steps_per_epochs_test * batch_size / 10))))[0]
        batch_label_big_sample_test = next(iter(validation_data.rebatch(int(steps_per_epochs_test * batch_size / 10))))[1]

        # --- fit ----

        for epoch in range(epochs):

            print(f'--- epoch : {epoch} / {epochs} ---')

            for step in range(steps_per_epochs):
                clear_output(wait=True)
                print(f'--- epoch : {epoch} / {epochs} ---')
                print(f'step : {step} / {steps_per_epochs}')


                # update a_k and c_k ---
                a_k = a / (epoch + 1 + A)**self.spsa_alpha
                c_k = self.spsa_c / (epoch + 1)**self.spsa_gamma

                print(f'a_k : {a_k} - c_k : {c_k}')

                # Compute the SPSA gradient of the loss function for a sample batch ---

                batch_data = next(iter(data_train))[0]
                batch_label = next(iter(data_train))[1]

                Delta_list, derivate_loss = self.grad_loss_spsa(c_k, batch_data, batch_label)

                # Update condition ---

                update=True # variable of updating condition (transformed in "False" if one of all the layers become too big !...)

                for i in range(self.nb_layers): # for each layer we compare the norm of the tensor of all the weights (divised by the number of weights) with our LIMIt parameter

                    # Uncomment for a look on the vectors... and its norm...
                    #print(f'for variation {i} : ')
                    #print(f'norme variation vecteur poids à venir_{i} : {np.linalg.norm(- a_k * c_k * (derivate_loss / Delta_list[i]))}')
                    #print(f'norme variation vecteur poids à venir (normé)_{i} : {np.linalg.norm(- a_k * c_k * (derivate_loss / Delta_list[i])) / np.product(np.shape(Delta_list[i]))}')
                    #print(a_k * c_k * derivate_loss / Delta_list[i])
                    if np.linalg.norm(- a_k * c_k * (derivate_loss / Delta_list[i])) / np.product(np.shape(Delta_list[i])) > LIMIT:
                        #print(f'limits for {i}...no update')
                        update=False # all the layers not will be update
                        break

                if update:
                    for i in range(self.nb_layers):
                        self.trainable_weights[i].assign_add( self.trainable_weights[i] - a_k * c_k * (derivate_loss / Delta_list[i]) )


                        #print(f'result vecteur poids_{i}')
                        #print(self.trainable_weights[i])
                        #print(f'result variation_{i} on norm : {np.linalg.norm(- a_k * c_k * (derivate_loss / Delta_list[i]))}')

            # Display the results of the loss and metric at the end of the epoch :

            print(f'results epochs {epoch + 1}/{epochs} : ')

            print('--Train_set--')

            print(f'Loss_function : {self.loss(batch_label_big_sample, self(batch_data_big_sample))}')
            y_pred = self.predict(data_train)
            self.metric_list[0].update_state(labels_train, y_pred)
            print(f'metric : {self.metric_list[0].result().numpy()}')

            print('--Test_set--')

            print(f'Loss_function : {self.loss(batch_label_big_sample_test, self(batch_data_big_sample_test))}')

            y_pred = self.predict(validation_data)
            self.metric_list[0].update_state(labels_test, y_pred)
            print(f'metric : {self.metric_list[0].result().numpy()}')


---------

IMPORT DATA (MNIST) :

In [152]:
BATCH_SIZE = 64

In [153]:
(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)



In [154]:
def normalize_img(image, label):
  """Normalizes images: `uint8` -> `float32`."""
  return tf.cast(image, tf.float32) / 255., label

ds_train = ds_train.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.cache()
ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
ds_train = ds_train.batch(BATCH_SIZE)
ds_train = ds_train.prefetch(tf.data.AUTOTUNE)

In [155]:
ds_test = ds_test.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_test = ds_test.batch(BATCH_SIZE)
ds_test = ds_test.cache()
ds_test = ds_test.prefetch(tf.data.AUTOTUNE)

----------

TEST WITH A COMMON FIT :

In [164]:
model_common = Model_for_SPSA()

In [165]:
model_common.compile(optimizer=tf.keras.optimizers.Adam(),
                   loss=tf.keras.losses.SparseCategoricalCrossentropy(),
                   metrics=tf.keras.metrics.SparseCategoricalAccuracy())

In [166]:
model_common.fit(ds_train, validation_data=ds_test, epochs=2)

Epoch 1/2
Epoch 2/2


<keras.callbacks.History at 0x7fc118b61ea0>

----------

TEST WITH THE SPSA FIT :

In [167]:
SPSA_model = Model_for_SPSA()

In [168]:
SPSA_model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy())

In [169]:
SPSA_model.compile_metrics([tf.keras.metrics.SparseCategoricalAccuracy()])

In [170]:
SPSA_model.compile_SPSA_parameters(0.001, 0.01, 0.01)

In [171]:
SPSA_model.fit_SPSA(ds_train, ds_test, epochs=2)

--- epoch : 1 / 2 ---
step : 937 / 938
a_k : 0.014062226559774724 - c_k : 0.000993092495437036
results epochs 2/2 : 
--Train_set--
Loss_function : nan
metric : 0.09866154193878174
--Test_set--
Loss_function : nan
metric : 0.09861428290605545
