In [1]:
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Flatten, Input, Average
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import SparseCategoricalCrossentropy as scc
from tensorflow.keras.datasets import mnist
from spectraltools import Spectral, spectral_pretrain

# Example of spectral pretraining with `spectral_pretrain`

Below, a branched functional model is created. The function `spectral_pretrain` then trains only the eigenvalues and returns a smaller model. This identifies the subnetwork that had the "luckiest" initialization.

In [2]:
# Dataset and model creation
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Scale pixel intensities from [0, 255] to [0, 1].
x_train = x_train / 255.
x_test = x_test / 255.


def spectral_branch(tensor, names):
    """Stack three ReLU ``Spectral`` layers (200 -> 300 -> 300 units) on ``tensor``.

    The first two layers are created with ``use_bias=False``; the second one
    also sets ``is_diag_start_trainable=True``. The third layer keeps the
    ``Spectral`` default for ``use_bias``.

    Parameters
    ----------
    tensor : Keras tensor
        Input of the branch.
    names : sequence of three str
        Names of the three layers, in creation order (must be unique in the model).

    Returns
    -------
    Keras tensor
        Output of the third ``Spectral`` layer.
    """
    first_name, second_name, third_name = names
    tensor = Spectral(200, activation='relu', name=first_name, use_bias=False)(tensor)
    tensor = Spectral(300, activation='relu', is_diag_start_trainable=True,
                      use_bias=False, name=second_name)(tensor)
    return Spectral(300, activation='relu', name=third_name)(tensor)


inputs = Input(shape=(28, 28))
flat = Flatten()(inputs)

# Two branches with identical architecture; each builds its own layer
# instances, so their initial weights are drawn independently.
# The branch built first is created first, as in the original cell.
# Its last layer is named 'Spec6' (previously 'Dense1', which was misleading
# because it is a Spectral layer, not a Dense one).
y = spectral_branch(flat, ('Spec1', 'Spec2', 'Spec6'))
x = spectral_branch(flat, ('Spec3', 'Spec4', 'Spec5'))

z = Average()([x, y])
outputs = Dense(10, activation="softmax")(z)

model = Model(inputs=inputs, outputs=outputs, name="branched")
# Softmax output layer, so the loss receives probabilities (from_logits=False).
model.compile(optimizer=Adam(1E-3), loss=scc(from_logits=False), metrics=["accuracy"])


In [5]:
# Run `spectral_pretrain` on the branched model and keep the smaller model it returns.
# NOTE(review): the meaning of `max_delta` and `compare_with` is defined inside
# spectraltools and is not visible here — confirm against its documentation.
# NOTE(review): the model was compiled with metrics=["accuracy"]; confirm that
# compare_with='acc' matches the metric name spectral_pretrain expects.
BATCH_SIZE = 300
EPOCHS = 10
fit_dict = dict(x=x_train, y=y_train, batch_size=BATCH_SIZE, epochs=EPOCHS, verbose=0)
eval_dict = dict(x=x_test, y=y_test, batch_size=BATCH_SIZE)
pruned = spectral_pretrain(model,
                           fit_dictionary=fit_dict,
                           eval_dictionary=eval_dict,
                           max_delta=10,
                           compare_with='acc')
pruned.summary()

0.00033209714996125864
0.00022139809997417242
0.0005534292791070717
0.0008854604582399708
0.006308790316010163
0.009407901919850062
0.01582739128584731
0.027227480280494768
0.040730521370755064
0.08677367354929323
0.12473718046616966
Model: "branched"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 28, 28)]     0                                            
__________________________________________________________________________________________________
flatten (Flatten)               (None, 784)          0           input_1[0][0]                    
__________________________________________________________________________________________________
Spec3 (Spectral)                (None, 80)           63584       flatten[0][0]                    
_______________________________________________________