## Importação de pacotes

In [None]:
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt

## Carregamento da base de dados

In [None]:
# Keras CIFAR-10 dataset module; load_data() returns a (train, test) pair of
# (images, labels) tuples.
cifar = keras.datasets.cifar10
(x_train, y_train), (x_test, y_test) = cifar.load_data()

# Human-readable class names, indexed by the integer label in y_train / y_test.
labels = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
]

In [None]:
sampleID = 7

# Each label entry is a 1-element row, so pick out the scalar class index.
classe = y_train[sampleID, 0]
print('Classe:', labels[classe])

# Display the raw (not yet normalized) image, then dump its pixel array.
plt.imshow(x_train[sampleID])
plt.show()
print(x_train[sampleID])

### Normalização

In [None]:
# Scale pixel intensities from [0, 255] to [0, 1].
# Cast to float32 first: dividing the integer arrays by 255.0 directly yields
# float64, doubling memory (~1.2 GB for x_train) even though Keras' default
# float type is float32. The resulting values are identical.
x_train = x_train.astype(np.float32) / 255.0
x_test = x_test.astype(np.float32) / 255.0

In [None]:
# Show the first ten (already normalized) training images with their class.
for i in range(10):
    classe = y_train[i, 0]
    print('Classe:', labels[classe])
    plt.imshow(x_train[i])
    plt.show()

## Construção e treinamento do modelo

In [None]:
# Baseline fully-connected classifier: flatten each image, one hidden ReLU
# layer, and a 10-way softmax output (one unit per entry of `labels`).
model = keras.Sequential([
    keras.layers.Flatten(),
    keras.layers.Dense(128, activation=keras.activations.relu),
    keras.layers.Dense(10, activation=keras.activations.softmax),
])

# Sparse categorical cross-entropy takes integer class labels directly,
# so y_train does not need one-hot encoding.
model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.sparse_categorical_crossentropy,
    metrics=['accuracy'],
)

In [None]:
# Train for 15 epochs; Keras holds out 20% of the training data (taken from
# the end of the arrays, before shuffling) as the validation set.
hist = model.fit(
    x_train,
    y_train,
    validation_split=0.2,
    epochs=15,
)

## Avaliando os resultados

In [None]:
# Accuracy
plt.plot(hist.history['accuracy'])
plt.plot(hist.history['val_accuracy'])
plt.title('Model accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['Train', 'Test'], loc='upper left')
plt.show()

# Loss
plt.plot(hist.history['loss'])
plt.plot(hist.history['val_loss'])
plt.title('Model loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Train', 'Test'], loc='upper left')
plt.show()

In [None]:
# Loss and accuracy on the untouched test set (last expression: shown by the notebook).
model.evaluate(x=x_test, y=y_test)

In [None]:
# Softmax output for every test image: one row of 10 class scores each.
classifications = model.predict(x=x_test)

In [None]:
# Inspect the first 20 test images and display each one the model got wrong.
for i in range(20):
    classe = np.argmax(classifications[i])
    # Each y_test entry is a 1-element row; take the scalar label explicitly
    # rather than comparing with / int()-converting a 1-element array
    # (scalar conversion of ndim > 0 arrays is deprecated since NumPy 1.25).
    true_class = y_test[i, 0]
    if classe != true_class:
        print('Erro em', i)
        print('Vetor de resultados:', classifications[i])
        print('Era para ser', labels[true_class], 'mas foi classificado como', labels[classe])
        plt.imshow(x_test[i])
        plt.show()