In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
import os
import sys

In [None]:
# Put the notebook's parent directory on the import path (presumably where the
# project-local `functions` and `models` modules imported below live — confirm).
dir_current = os.path.abspath('')
dir_parent = os.path.dirname(dir_current)
if dir_parent not in sys.path:
    sys.path.append(dir_parent)

In [None]:
# Fix the NumPy and TensorFlow global seeds so reruns are more repeatable.
SEED = 1234
np.random.seed(SEED)
tf.random.set_seed(SEED)

## Loading Fashion-MNIST Data

In [None]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()

# Normalize pixels to [0, 1]. Cast to float32 first: dividing the uint8 arrays by a
# Python float directly yields float64, doubling memory and mismatching the
# float32 default dtype of tf.keras.Input.
x_train = x_train.astype(np.float32) / 255.
x_test = x_test.astype(np.float32) / 255.

# Add a trailing channel axis: (28x28) -> (28x28x1).
x_train = tf.expand_dims(x_train, -1)
x_test = tf.expand_dims(x_test, -1)

# Validation split: a fixed (non-random) slice — the first N_TRAIN training images
# stay for training, the remaining ones become the validation set.
N_TRAIN = 50000
x_valid = x_train[N_TRAIN:]
y_valid = y_train[N_TRAIN:]

x_train = x_train[:N_TRAIN]
y_train = y_train[:N_TRAIN]

assert x_train.shape[0] == y_train.shape[0]
assert x_valid.shape[0] == y_valid.shape[0]


## Creating Dataset Objects

In [None]:
from functions import to_categorical

bs = 32


def _make_pipeline(x, y, batch_size, shuffle_buffer=None, repeat=False):
    """Build a tf.data pipeline: slice -> [shuffle] -> to_categorical -> batch -> [repeat].

    Args:
        x, y: features and labels, sliced together along their first axis.
        batch_size: examples per batch (the last batch may be partial).
        shuffle_buffer: when given, shuffle with this buffer size before encoding.
        repeat: when True, repeat the batched dataset indefinitely.
    """
    pipeline = tf.data.Dataset.from_tensor_slices((x, y))
    if shuffle_buffer is not None:
        pipeline = pipeline.shuffle(buffer_size=shuffle_buffer)
    # Label encoding via the project helper (presumably one-hot — confirm).
    pipeline = pipeline.map(to_categorical).batch(batch_size)
    return pipeline.repeat() if repeat else pipeline


# Training: shuffled with a buffer covering the whole split, encoded, batched, infinite.
train_dataset = _make_pipeline(x_train, y_train, bs,
                               shuffle_buffer=x_train.shape[0], repeat=True)
# Validation: same as training but unshuffled.
valid_dataset = _make_pipeline(x_valid, y_valid, bs, repeat=True)
# Testing: one example per batch, single pass.
test_dataset = _make_pipeline(x_test, y_test, 1)

## Building the LeNet-5 Model

In [None]:
from models import Lenet_body

# Functional-API model: a (28, 28, 1) input passed through the project-local
# `Lenet_body` builder (Quantization=False presumably selects the float variant).
input_layer = tf.keras.Input(shape=(28, 28, 1))
output_layer = Lenet_body(input_layer, Quantization=False)
Lenet = tf.keras.Model(inputs=input_layer, outputs=output_layer)

## Summary of the network

In [None]:
# Print a layer-by-layer summary (output shapes and parameter counts) of the model.
Lenet.summary()

## Training Options

In [None]:
# Optimization setup
# ------------------
lr = 1e-3  # Adam learning rate

# CategoricalCrossentropy expects one-hot targets — presumably what the
# `to_categorical` map in the input pipeline produces (confirm).
loss = tf.keras.losses.CategoricalCrossentropy()
optimizer = tf.keras.optimizers.Adam(learning_rate=lr)

# Metrics reported on the training and validation sets each epoch.
metrics = ['accuracy']

Lenet.compile(optimizer=optimizer, loss=loss, metrics=metrics)

## Callbacks

In [None]:
import os
from datetime import datetime  # NOTE(review): not used in the cells visible here

early_stop = True
cwd = os.getcwd()  # also used later as the base directory for saved weights

# Early stopping on validation loss: give up after 3 epochs without improvement
# and ask Keras to restore the best-epoch weights.
callbacks = []
if early_stop:
    es_callback = tf.keras.callbacks.EarlyStopping(
        monitor='val_loss',
        patience=3,
        restore_best_weights=True,
    )
    callbacks += [es_callback]

## Training

In [None]:
# Both datasets end in .repeat(), so they are infinite and the step counts
# define an epoch. ceil(n / bs) equals the number of batches in one pass
# (the last batch is partial), so each epoch covers every sample exactly once.
epochs = 100  # upper bound; early stopping may end training sooner
steps_per_epoch = int(np.ceil(x_train.shape[0] / bs))
validation_steps = int(np.ceil(x_valid.shape[0] / bs))

# Keep the returned History (per-epoch loss/metrics) instead of discarding it.
history = Lenet.fit(
    x=train_dataset,
    epochs=epochs,
    steps_per_epoch=steps_per_epoch,
    validation_data=valid_dataset,
    validation_steps=validation_steps,
    callbacks=callbacks,
)

## Saving Weights

In [None]:
# Save the trained weights under <cwd>/TrainedWeights/ with "Weights" as the filename.
weights_root = os.path.join(cwd, 'TrainedWeights')
# exist_ok avoids the separate exists() check (and its check-then-create race).
os.makedirs(weights_root, exist_ok=True)
Wgt_dir = os.path.join(weights_root, 'Weights')
# NOTE(review): with no extension, TF2-era Keras writes TF-checkpoint files using
# this path as a prefix; Keras 3 requires a name ending in '.weights.h5' — confirm TF version.
Lenet.save_weights(Wgt_dir)