In [1]:
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Flatten, Input, Average
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import SparseCategoricalCrossentropy as scc
from tensorflow.keras.datasets import mnist
from tensorflow.keras.regularizers import l2
from tensorflow.config import experimental
# Ask TensorFlow to enable memory growth on every visible GPU
# (the flag passed to set_memory_growth is True for each device).
physical_devices = experimental.list_physical_devices('GPU')
for gpu in physical_devices:
    experimental.set_memory_growth(gpu, True)
    
from spectraltools import Spectral


# Example of spectral training

In the following, a functional model is created using two Spectral layers. An L2 regularization is also applied, as we would like to prune the model later on. The model is trained for 1 epoch and then evaluated on the test set.

In [2]:
# --- Data -------------------------------------------------------------------
# Load MNIST and divide the pixel values by 255.
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255., x_test / 255.

# --- Model ------------------------------------------------------------------
# Keyword arguments shared by every Spectral layer. Two L2 regularizers are
# passed (base_regularizer / diag_regularizer) because the model is pruned
# later on.
spectral_configuration = dict(activation='relu',
                              use_bias=True,
                              base_regularizer=l2(1E-3),
                              diag_regularizer=l2(5E-3))

# Single chain: Flatten -> Spec1 (200) -> Spec2 (300) -> softmax Dense (10).
# NOTE(review): the model is named "branched" although the graph built here
# has no branches; the name is kept unchanged.
inputs = Input(shape=(28, 28,))
x = Flatten()(inputs)
y = Spectral(200, name='Spec1', **spectral_configuration)(x)
y = Spectral(300, name='Spec2', **spectral_configuration)(y)
outputs = Dense(10, activation="softmax", name='LastDense')(y)

model = Model(inputs=inputs, outputs=outputs, name="branched")

# --- Training ---------------------------------------------------------------
# The compile arguments live in a dict so that the pruning helpers can be
# handed the very same settings through their compile_dictionary argument.
compile_dict = {
    'optimizer': Adam(1E-3),
    'loss': scc(from_logits=False),
    'metrics': ["accuracy"],
}
model.compile(**compile_dict)

model.fit(
    x_train,
    y_train,
    validation_split=0.2,
    batch_size=300,
    epochs=1,
    verbose=1,
)

# Last expression of the cell: the test-set evaluation is shown as cell output.
model.evaluate(x_test, y_test, batch_size=300)



[2.307720184326172, 0.9397000074386597]

# Example of spectral pruning
Now that the model has been trained, we can prune it. In the following we will prune 50% of the spectral layers' nodes according to their relevance. The model is then evaluated on the test set.

In [3]:
from spectraltools import prune_percentile, metric_based_pruning
from spectraltools.spectralprune import original_model

In [4]:
# Prune 50% of the spectral layers' nodes according to their relevance.
# As described for prune_percentile (TODO confirm against spectraltools): the
# eigenvalues whose magnitude is below the requested percentile are set to
# zero by masking the corresponding weights, and the matching bias entries are
# masked as well.
pruned_model = prune_percentile(model, 50, compile_dictionary=compile_dict)

# evaluate() returns the loss followed by the compiled metrics; index 1 is the accuracy.
pruned_metrics = pruned_model.evaluate(x_test, y_test, batch_size=300)
print(f'Pruned accuracy: {pruned_metrics[1]:.3f}')

Number of nodes masked: 250 out of 500 (50.00%)
Pruned accuracy: 0.651


As the output above shows, masking 50% of the eigenvalues lowers the test accuracy of this model from about 0.940 to 0.651, so this percentile is too aggressive here. Pruning has basically no impact on the accuracy only when the pruned eigenvalues are very small, so that their contribution to the model is negligible and the whole feature is not relevant.


# Example of metric based spectral pruning 
In the following code we will prune the model according to the metric-based approach: nodes are pruned progressively until a given drop in accuracy is reached. Here the pruning stops once the accuracy drops by more than 3%.

In [5]:
import numpy as np

# Reference point: test accuracy of the unpruned model.
baseline_metrics = model.evaluate(x_test, y_test, batch_size=300)
print(f'Baseline accuracy: {baseline_metrics[1]:.3f}')

# Walk through the layers of the current pruned_model (the percentile-pruned
# one at this point) and report, for every layer carrying a diag_end_mask,
# how many mask entries are non-zero.
for lay in pruned_model.layers:
    if not hasattr(lay, 'diag_end_mask'):
        continue
    active_nodes = np.count_nonzero(lay.diag_end_mask)
    print(f'Layer {lay.name} has {active_nodes} active nodes')

# Metric-based pruning, scored on the training data with the 'accuracy' metric.
# NOTE(review): max_delta_percent=3 is presumably the largest tolerated
# accuracy drop, in percent - confirm against the spectraltools documentation.
eval_dictionary = dict(x=x_train, y=y_train, batch_size=200)
pruned_model = metric_based_pruning(
    model,
    eval_dictionary=eval_dictionary,
    compile_dictionary=compile_dict,
    compare_metric='accuracy',
    max_delta_percent=3,
)

Baseline accuracy: 0.940
Layer Spec1 has 140 active nodes
Layer Spec2 has 110 active nodes
Number of nodes masked: 0 out of 500 (0.00%)
Pruning with 0% of eigenvalues removed. Delta in accuracy: 0.0000%
Number of nodes masked: 25 out of 500 (5.00%)
Pruning with 5% of eigenvalues removed. Delta in accuracy: 0.1026%
Number of nodes masked: 50 out of 500 (10.00%)
Pruning with 10% of eigenvalues removed. Delta in accuracy: 0.2388%
Number of nodes masked: 75 out of 500 (15.00%)
Pruning with 15% of eigenvalues removed. Delta in accuracy: 0.8174%
Number of nodes masked: 100 out of 500 (20.00%)
Pruning with 20% of eigenvalues removed. Delta in accuracy: 1.6489%
Number of nodes masked: 125 out of 500 (25.00%)
Pruning with 25% of eigenvalues removed. Delta in accuracy: 4.5237%
Number of nodes masked: 100 out of 500 (20.00%)


In [6]:
import numpy as np
# Test accuracy of the metric-pruned model (index 1 of evaluate()'s result).
final_metrics = pruned_model.evaluate(x_test, y_test, batch_size=300)
print(f'Pruned accuracy: {final_metrics[1]:.3f}')

# Count the non-zero mask entries of every layer that exposes a diag_end_mask.
for lay in pruned_model.layers:
    if not hasattr(lay, 'diag_end_mask'):
        continue
    active_nodes = np.count_nonzero(lay.diag_end_mask)
    print(f'Layer {lay.name} has {active_nodes} active nodes')

Pruned accuracy: 0.925
Layer Spec1 has 186 active nodes
Layer Spec2 has 214 active nodes
