In [15]:
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Flatten, Input, Average
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import SparseCategoricalCrossentropy as scc
from tensorflow.keras.datasets import mnist
from spectraltools import Spectral, spectral_pruning

In [2]:
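## Creating the branched model: two parallel Spectral branches on MNIST, averaged into a softmax classifier, trained, then pruned with `spectral_pruning`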

In [16]:
# Load MNIST and scale pixel values from [0, 255] to [0, 1]. Casting to float32
# first matches Keras' default float dtype and avoids materialising float64 copies.
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0

# Seed TensorFlow's global RNG so weight initialisation and fit() shuffling are
# repeatable across Restart & Run All (some GPU kernels may still be nondeterministic).
SEED = 42
tf.random.set_seed(SEED)


def spectral_branch(tensor, names):
    """Stack three ReLU Spectral layers (200 -> 300 -> 300 units) on `tensor`.

    The first two layers are built with ``use_bias=False`` and the second one
    with ``is_diag_start_trainable=True``; the third uses Spectral's default
    ``use_bias``. ``names`` gives the three layer names, in order.
    Returns the output tensor of the third layer.
    """
    first, second, third = names
    hidden = Spectral(200, activation='relu', name=first, use_bias=False)(tensor)
    hidden = Spectral(300, activation='relu', is_diag_start_trainable=True,
                      use_bias=False, name=second)(hidden)
    return Spectral(300, activation='relu', name=third)(hidden)


inputs = Input(shape=(28, 28))
flat = Flatten()(inputs)

# Two structurally identical branches fed by the same flattened image.
# All six layers are Spectral, so they are named Spec1..Spec6 consistently.
branch_a = spectral_branch(flat, ('Spec1', 'Spec2', 'Spec6'))
branch_b = spectral_branch(flat, ('Spec3', 'Spec4', 'Spec5'))

merged = Average()([branch_b, branch_a])
outputs = Dense(10, activation="softmax")(merged)

model = Model(inputs=inputs, outputs=outputs, name="branched")

# The output layer already applies softmax, so the loss receives probabilities, not logits.
model.compile(optimizer=Adam(learning_rate=1e-3), loss=scc(from_logits=False), metrics=["accuracy"])

In [17]:
# Train the unpruned model briefly and keep its results as the baseline for pruning.
model.summary()
# verbose=0 hides per-epoch logs, so keep the History to retain the validation
# metrics that validation_split=0.2 computes (history.history['val_accuracy'], ...).
history = model.fit(x_train, y_train, validation_split=0.2, batch_size=300, epochs=1, verbose=0)
# evaluate() returns [loss, accuracy] given the metrics passed to compile().
baseline_metrics = model.evaluate(x_test, y_test, batch_size=300)
baseline_metrics

Model: "branched"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_5 (InputLayer)            [(None, 28, 28)]     0                                            
__________________________________________________________________________________________________
flatten_4 (Flatten)             (None, 784)          0           input_5[0][0]                    
__________________________________________________________________________________________________
Spec3 (Spectral)                (None, 200)          157784      flatten_4[0][0]                  
__________________________________________________________________________________________________
Spec1 (Spectral)                (None, 200)          157784      flatten_4[0][0]                  
___________________________________________________________________________________________

[0.45438194274902344, 0.8636999726295471]

In [19]:
# Second positional argument of spectral_pruning. NOTE(review): its exact meaning
# (e.g. a percentile threshold) is defined in spectraltools — TODO confirm.
PRUNE_AMOUNT = 80

# Prune the trained model, inspect the reduced layer widths, then measure test metrics.
new = spectral_pruning(model, PRUNE_AMOUNT)
new.summary()
# Make evaluate() the cell's last expression so its [loss, accuracy] is actually
# displayed; summary() returns None and previously swallowed this result.
pruned_metrics = new.evaluate(x_test, y_test, batch_size=300)
pruned_metrics


Model: "branched"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_5 (InputLayer)            [(None, 28, 28)]     0                                            
__________________________________________________________________________________________________
flatten_4 (Flatten)             (None, 784)          0           input_5[0][0]                    
__________________________________________________________________________________________________
Spec3 (Spectral)                (None, 72)           57304       flatten_4[0][0]                  
__________________________________________________________________________________________________
Spec1 (Spectral)                (None, 79)           62799       flatten_4[0][0]                  
___________________________________________________________________________________________