In [1]:
import tensorflow.keras
from tensorflow.keras.datasets import fashion_mnist
from matplotlib import pyplot as plt
import numpy as np

# Load the Fashion-MNIST train/test splits.
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()

# Inspect the shape of the training images.
print("Shape:", x_train.shape)  # 60,000 images of 28x28

# Render one arbitrary training sample (index 125) as a grayscale image.
image = x_train[125].astype('float')
plt.imshow(image, cmap='gray')
plt.show()

# Integer class id of the sample shown above.
print("Label:", y_train[125])

In [2]:
# Pixel range of the sample before any scaling is applied.
print("Max value:", x_train[125].max())
print("Min value:", x_train[125].min())

In [3]:
# Guard against re-running this cell: the in-place scale/shift below is not
# idempotent, so a second run would silently rescale already-normalized data.
# `fashion_mnist.load_data()` returns uint8 arrays.
assert x_train.dtype == np.uint8 and x_test.dtype == np.uint8, (
    "x_train/x_test are already preprocessed; re-run the data-loading cell first")

x_train = x_train.astype('float32')
x_test = x_test.astype('float32')

x_train /= 255  # scale pixel values from [0, 255] to [0, 1]
x_test /= 255

x_train -= 0.5  # shift the range to [-0.5, 0.5]
x_test -= 0.5

print("Max value:", x_train[125].max())
print("Min value:", x_train[125].min())

# Flatten each 28x28 image into a 784-element vector. The sample count is
# taken from the arrays instead of hard-coding 60000/10000, so this also
# works on subsets of the data.
x_train = x_train.reshape(len(x_train), -1)
x_test = x_test.reshape(len(x_test), -1)

In [4]:
# One-hot encode the integer labels into 10-element vectors (10 classes).
y_train, y_test = [
    tensorflow.keras.utils.to_categorical(labels, 10)
    for labels in (y_train, y_test)
]

# Now one-hot; the hot position should match the integer label printed
# for this same sample (index 125) when the data was first loaded.
print("Label:", y_train[125])

In [5]:
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model

# Fully connected network: 784 inputs -> three sigmoid hidden layers
# (32, 64 and 25 units) -> 10-way softmax output.
inputs = Input(shape=(784,))                                    # input layer
output_h = Dense(32, activation='sigmoid')(inputs)              # hidden layer 1
output_h2 = Dense(64, activation='sigmoid')(output_h)           # hidden layer 2
output_h3 = Dense(25, activation='sigmoid')(output_h2)          # hidden layer 3
predictions = Dense(units=10, activation='softmax')(output_h3)  # output layer

model = Model(inputs=inputs, outputs=predictions)

In [6]:
# Mean squared error between the one-hot targets and the softmax output,
# optimized with plain SGD at learning rate 1, tracking accuracy.
# `learning_rate` replaces the legacy `lr` alias, which current
# TensorFlow/Keras optimizers no longer accept.
# NOTE(review): categorical_crossentropy is the usual loss for a softmax
# classifier; MSE is kept here to preserve the original experiment.
model.compile(loss='mse',
              optimizer=tensorflow.keras.optimizers.SGD(learning_rate=1),
              metrics=['accuracy'])

In [None]:
# Train for 50 epochs with mini-batches of 20 samples; per-epoch metrics
# are kept in `history.history` for plotting.
# NOTE(review): the test split doubles as validation data here, so the
# validation curve is not an unbiased held-out estimate.
history = model.fit(
    x_train,
    y_train,
    batch_size=20,
    epochs=50,
    validation_data=(x_test, y_test),
)

In [None]:
from matplotlib import pyplot as plt 

# Training vs. validation accuracy per epoch. Epochs are numbered from 1
# so the x-axis matches the "Épocas" label (plain list indices start at 0).
epoch_numbers = range(1, len(history.history['accuracy']) + 1)

fig, ax = plt.subplots()
ax.plot(epoch_numbers, history.history['accuracy'], label='accuracy')
ax.plot(epoch_numbers, history.history['val_accuracy'], label='validation accuracy')

# The dataset is Fashion-MNIST, not MNIST.
ax.set_title('Entrenamiento Fashion-MNIST')
ax.set_xlabel('Épocas')
ax.set_ylabel('Accuracy')
ax.legend(loc="lower right")

plt.show()