In [1]:
import tensorflow as tf
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Dense, Input, Layer
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils import to_categorical, plot_model
from tensorflow.keras.initializers import Constant
from matplotlib import pyplot as plt

#### Importo il dataset

In [2]:
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data(path="mnist.npz")

#### Pre-processing

In [3]:
train_images = (train_images / 255)
test_images = (test_images / 255)

train_images = train_images.reshape((-1, 784))
test_images = test_images.reshape((-1, 784))

In [4]:
train_labels = to_categorical(train_labels)
test_labels = to_categorical(test_labels)

#### Creo una funzione per generare un blocco con architettura highway che mi consente di implementare una grande quantita di layers e evitare esplosione o svanimento del gradiente 

In [5]:
def get_higway_block(input):
    H = Dense(8, activation = 'swish')(input)
    T = Dense(8, activation = 'sigmoid', bias_initializer = Constant(-10))(input)
    return H * T + input * (1 - T)

input = Input(shape=(train_images.shape[1:]))

dense = Dense(8, activation = 'swish')(input)
highway = get_higway_block(dense)
for i in range(200):
    highway = get_higway_block(highway)

output = Dense(10, activation = 'softmax')(highway)

model = Model(input, output)
model.compile(loss = 'categorical_crossentropy', optimizer= Adam(), metrics = 'accuracy')
model.summary()

print(len(model.layers), 'layers')

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_1 (InputLayer)        [(None, 784)]                0         []                            
                                                                                                  
 dense (Dense)               (None, 8)                    6280      ['input_1[0][0]']             
                                                                                                  
 dense_2 (Dense)             (None, 8)                    72        ['dense[0][0]']               
                                                                                                  
 dense_1 (Dense)             (None, 8)                    72        ['dense[0][0]']               
                                                                                              

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



 dense_400 (Dense)           (None, 8)                    72        ['tf.__operators__.add_198[0][
                                                                    0]']                          
                                                                                                  
 dense_399 (Dense)           (None, 8)                    72        ['tf.__operators__.add_198[0][
                                                                    0]']                          
                                                                                                  
 tf.math.subtract_199 (TFOp  (None, 8)                    0         ['dense_400[0][0]']           
 Lambda)                                                                                          
                                                                                                  
 tf.math.multiply_398 (TFOp  (None, 8)                    0         ['dense_399[0][0]',           
 Lambda)  

#### Addestramento del modello 

In [6]:
history = model.fit(
    x=train_images,
    y=train_labels,
    epochs=10,
    batch_size=1024,
    validation_data=(
        test_images,
        test_labels
    )
)

Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
