# Importação de pacotes

In [None]:
from tensorflow import keras
import matplotlib.pyplot as plt
import numpy as np

# Carregamento e preparação da base de dados

In [None]:
# Load CIFAR-10 through the Keras datasets API; load_data() returns a
# (train, test) pair, each an (images, labels) pair.
cifar = keras.datasets.cifar10
train_data, test_data = cifar.load_data()
x_train, y_train = train_data
x_test, y_test = test_data

In [None]:
# Scale pixel values into [0, 1]. Cast to float32 first: dividing the raw
# integer arrays by 255.0 directly yields float64, doubling memory use for
# no benefit, since Keras trains in float32 by default.
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0

# Criação do modelo

In [None]:
# Small CNN: two Conv2D + MaxPooling stages extract spatial features from the
# 32x32x3 input, then a dense head maps them to softmax scores over 10 classes.
model = keras.models.Sequential()
model.add(keras.layers.Conv2D(64, (3, 3), activation='relu', input_shape=(32, 32, 3)))
model.add(keras.layers.MaxPooling2D(2, 2))
model.add(keras.layers.Conv2D(64, (3, 3), activation='relu'))
model.add(keras.layers.MaxPooling2D(2, 2))
model.add(keras.layers.Flatten())
model.add(keras.layers.Dense(128, activation='relu'))
model.add(keras.layers.Dense(10, activation='softmax'))

In [None]:
# Sparse categorical cross-entropy takes integer class labels directly
# (no one-hot encoding); Adam uses its default hyperparameters.
model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)
model.summary()

# Treinamento e validação

In [None]:
# Train for 5 epochs, holding out 20% of the training data as a validation
# set; the returned History object records per-epoch metrics.
hist = model.fit(
    x_train,
    y_train,
    epochs=5,
    validation_split=0.2,
)

In [None]:
# Plot training vs. validation curves for accuracy and loss.
# The second curve comes from `validation_split` in model.fit, i.e. a slice
# held out of the *training* data, so it is labelled "Validation", not "Test";
# the real test-set score is computed separately with model.evaluate.
for metric in ('accuracy', 'loss'):
    train_values = hist.history[metric]
    val_values = hist.history['val_' + metric]
    # One recorded value per epoch; number the x-axis from 1 to match epochs.
    epoch_numbers = range(1, len(train_values) + 1)
    plt.plot(epoch_numbers, train_values)
    plt.plot(epoch_numbers, val_values)
    plt.title(f'Model {metric}')
    plt.ylabel(metric.capitalize())
    plt.xlabel('Epoch')
    # 'best' avoids covering the curves (loss falls from the upper left).
    plt.legend(['Train', 'Validation'], loc='best')
    plt.show()

In [None]:
# Score the trained model on the held-out test split.
test_loss, test_accuracy = model.evaluate(x_test, y_test)
# evaluate() returns accuracy as a fraction; report it as a percentage with
# its unit so the printed number is unambiguous.
print(f'Test loss: {test_loss:.4f}, Test accuracy: {test_accuracy * 100:.2f}%')

# Análise de feature map

In [None]:
def _channel_to_uint8(channel):
    """Rescale one 2-D activation map to a displayable uint8 image.

    The map is standardized (mean 0, std 1), shifted to mean 128 with 64 grey
    levels per standard deviation, and clipped to [0, 255]. It works on a copy,
    so the caller's array (a view into the model's output) is never modified.
    A constant map (std == 0, e.g. a channel the ReLU zeroed everywhere) is
    rendered as uniform mid-grey instead of producing NaNs.
    """
    img = channel.astype(np.float32)  # astype copies: never write into the view
    img -= img.mean()
    std = img.std()
    if std > 0:
        img /= std
    img = img * 64 + 128
    return np.clip(img, 0, 255).astype('uint8')


# Show the test image whose activations will be inspected.
image = x_test[48]
plt.imshow(image)
plt.show()

# Auxiliary model that exposes the output of every layer of `model`.
successive_outputs = [layer.output for layer in model.layers]
visualization_model = keras.models.Model(inputs=model.input, outputs=successive_outputs)

# Add a leading batch axis of size 1: the model consumes batches of images.
x = image.reshape((1,) + image.shape)
successive_feature_maps = visualization_model.predict(x)

# Activations of the first layer (the first Conv2D in the model).
feature_map = successive_feature_maps[0]

# Number of channels to display, clamped to the channels actually available
# (to show all of them use feature_map.shape[-1]).
n_features = min(10, feature_map.shape[-1])

for channel_index in range(n_features):
    plt.imshow(_channel_to_uint8(feature_map[0, :, :, channel_index]))
    plt.show()