In [9]:
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Flatten, Input, Average
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import SparseCategoricalCrossentropy as scc
from tensorflow.keras.datasets import mnist

from spectraltools import Spectral, spectral_pruning

# Example of spectral pruning

In the following, a branched functional model is created and trained; the function `spectral_pruning` then returns a pruned version of it.

In [3]:
# Dataset and model creation
# Weight initialisation and batch shuffling are random; fixing the TF seed lets the
# reported [loss, accuracy] be reproduced on Restart & Run All (GPU kernels may still
# introduce small nondeterminism).
SEED = 42
BATCH_SIZE = 300
EPOCHS = 1
LEARNING_RATE = 1e-3

tf.random.set_seed(SEED)

(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Scale pixel intensities to [0, 1]. Casting to float32 first avoids materialising the
# whole dataset as float64 (the model consumes float32 anyway).
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0

inputs = Input(shape=(28, 28))
flat = Flatten()(inputs)

# Branch A: two bias-free spectral layers, then one with the layer's default `use_bias`.
# NOTE(review): 'Dense1' is a Spectral layer despite its name (cf. 'Spec5' in branch B);
# the name is kept unchanged in case later cells look the layer up by name.
branch_a = Spectral(200, activation='relu', name='Spec1', use_bias=False)(flat)
branch_a = Spectral(300, activation='relu', is_diag_start_trainable=True, use_bias=False, name='Spec2')(branch_a)
branch_a = Spectral(300, activation='relu', name='Dense1')(branch_a)

# Branch B: same topology as branch A, fed from the same flattened input.
branch_b = Spectral(200, activation='relu', name='Spec3', use_bias=False)(flat)
branch_b = Spectral(300, activation='relu', is_diag_start_trainable=True, use_bias=False, name='Spec4')(branch_b)
branch_b = Spectral(300, activation='relu', name='Spec5')(branch_b)

# Merge the two branches element-wise and classify into the 10 digit classes.
merged = Average()([branch_a, branch_b])
outputs = Dense(10, activation="softmax")(merged)

model = Model(inputs=inputs, outputs=outputs, name="branched")

# The output layer already applies softmax, so the loss receives probabilities, not logits.
model.compile(optimizer=Adam(LEARNING_RATE), loss=scc(from_logits=False), metrics=["accuracy"])
model.fit(x_train, y_train, validation_split=0.2, batch_size=BATCH_SIZE, epochs=EPOCHS, verbose=0)
# Last expression: the [loss, accuracy] pair on the test set is displayed as the cell output.
model.evaluate(x_test, y_test, batch_size=BATCH_SIZE)



[0.41358765959739685, 0.8770999908447266]

In [6]:
# Prune the least relevant nodes of the spectral layers.
# PRUNE_PERCENT is passed straight to `spectral_pruning`. The recorded summary shows
# per-layer reductions that are not exactly 30% (Spec1 200 -> 146, Spec3 200 -> 143), so
# the cut is presumably a relevance threshold shared across layers -- TODO confirm
# against the spectraltools documentation.
PRUNE_PERCENT = 30
pruned = spectral_pruning(model, PRUNE_PERCENT)
pruned.summary()
# Evaluate last so the pruned model's [loss, accuracy] is displayed as the cell output,
# mirroring the unpruned model's evaluation (same batch size of 300).
pruned.evaluate(x_test, y_test, batch_size=300)


Model: "branched"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 28, 28)]     0                                            
__________________________________________________________________________________________________
flatten (Flatten)               (None, 784)          0           input_1[0][0]                    
__________________________________________________________________________________________________
Spec3 (Spectral)                (None, 143)          113039      flatten[0][0]                    
__________________________________________________________________________________________________
Spec1 (Spectral)                (None, 146)          115394      flatten[0][0]                    
___________________________________________________________________________________________