# In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.initializers import GlorotUniform, HeNormal, RandomNormal
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.losses import MeanSquaredError

# Load Fashion-MNIST; only the training split is used in this experiment.
(x_train, y_train), (_, _) = tf.keras.datasets.fashion_mnist.load_data()

# Use only the first 10,000 examples to speed up the experiments.
# Slice *before* normalizing: dividing the full uint8 array by 255.0 would
# allocate a float64 copy of all 60,000 images, and slicing that result returns
# a view that keeps the whole float64 buffer alive. Normalizing the subset in
# float32 (the dtype the model trains in) keeps memory small.
N_SAMPLES = 10_000
x_train = x_train[:N_SAMPLES].astype("float32") / 255.0  # scale pixels to [0, 1]
y_train = y_train[:N_SAMPLES]

# One-hot encode the integer labels into 10 classes.
y_train = tf.keras.utils.to_categorical(y_train, 10)

# Hidden-layer activation functions to compare.
activations = ['sigmoid', 'tanh', 'relu']

# Hidden-layer kernel initializers; every activation is trained once per entry.
# GlorotUniform and HeNormal each appear twice, with different seeds.
initializers = [
    GlorotUniform(seed=0),
    HeNormal(seed=1),
    RandomNormal(mean=0.0, stddev=0.05, seed=2),
    GlorotUniform(seed=3),
    HeNormal(seed=4)
]

def build_model(activation, initializer, *, hidden_units=64, input_shape=(28, 28),
                num_classes=10, output_initializer="glorot_uniform"):
    """Build an uncompiled MLP classifier for the initialization experiment.

    Architecture: ``Input -> Flatten -> Dense(hidden_units, activation)
    -> Dense(num_classes, softmax)``. Only the hidden layer's kernel uses
    ``initializer``; the output layer's kernel uses ``output_initializer``.

    Args:
        activation: Keras activation (name or callable) for the hidden layer.
        initializer: Keras initializer for the hidden layer's kernel.
        hidden_units: Width of the hidden layer.
        input_shape: Shape of one input example, without the batch axis.
        num_classes: Number of output units (width of the softmax layer).
        output_initializer: Kernel initializer for the output layer. The
            default, ``"glorot_uniform"``, is Keras' own ``Dense`` default, so
            by default the output layer is initialized as before: with no
            explicit seed. Pass a seeded initializer for reproducible runs.

    Returns:
        An uncompiled ``Sequential`` model.
    """
    return Sequential([
        # Declare the input with an explicit Input object; passing
        # ``input_shape`` to the first layer is deprecated in Keras 3.
        tf.keras.Input(shape=input_shape),
        Flatten(),
        Dense(hidden_units, activation=activation, kernel_initializer=initializer),
        Dense(num_classes, activation='softmax', kernel_initializer=output_initializer),
    ])

loss_fn = MeanSquaredError()

# Training hyperparameters shared by every run.
EPOCHS = 20
LEARNING_RATE = 0.01

# Maps activation name -> list of per-epoch training-loss curves,
# one curve per entry of `initializers` (same order).
results = {}

for activation in activations:
    print(f"\n🔍 Testando função de ativação: {activation}")
    histories = []

    for i, init in enumerate(initializers):
        # Report progress against the actual number of initializers
        # rather than a hard-coded count.
        print(f" Treinamento {i+1}/{len(initializers)} com inicialização diferente...")
        # Discard global Keras state from previously built models; creating
        # many models in a loop otherwise keeps accumulating memory.
        tf.keras.backend.clear_session()
        model = build_model(activation, init)
        # A new optimizer instance per model rather than sharing one.
        # The 'mse' metric duplicates the MSE loss; only history['loss'] is plotted.
        model.compile(optimizer=SGD(learning_rate=LEARNING_RATE),
                      loss=loss_fn,
                      metrics=['mse'])

        history = model.fit(x_train, y_train, epochs=EPOCHS, verbose=0)
        histories.append(history.history['loss'])

    # One figure per activation, one loss curve per initializer; the legend
    # names the initializer class so the curves can be told apart.
    plt.figure(figsize=(10, 5))
    for i, (init, loss_curve) in enumerate(zip(initializers, histories), start=1):
        plt.plot(loss_curve, label=f'Init {i} ({type(init).__name__})')
    plt.title(f'Curvas de perda (Loss) – Ativação: {activation}')
    plt.xlabel('Época')
    plt.ylabel('Loss (MSE)')
    plt.legend()
    plt.grid(True)
    plt.show()

    results[activation] = histories
