In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import wandb
# Train a feed-forward network on Fashion-MNIST inside a W&B run, then report
# overall test accuracy, per-class accuracy and a row-normalized confusion
# matrix (logged to W&B and drawn as a heatmap).
#
# The run is used as a context manager so that it is always closed -- and
# marked as failed -- if training or evaluation raises, instead of being left
# open for the next execution of this cell.
with wandb.init(
    project="fashion-mnist-ffnn",
    config={
        "hidden_layers": [128, 64],
        "optimizer": "adam",
        "learning_rate": 0.001,
        "epochs": 20,
        "batch_size": 64,
    },
) as run:
    # Read hyperparameters back from the run config rather than the literals
    # above, so values overridden by W&B (e.g. a sweep) are the ones used.
    config = run.config

    # Single source of truth for the number of classes used below.
    num_classes = 10

    # NOTE(review): get_fa_mnist is defined elsewhere; the argmax(axis=1)
    # calls below presume the labels come back as 2-D (one-hot) -- confirm.
    (X_train, Y_train), (X_test, Y_test) = get_fa_mnist(flatten=True)

    model = NeuralNet(
        input_dim=784,
        hidden_layers=config.hidden_layers,
        output_dim=num_classes,
        optimizer=config.optimizer,
        lr=config.learning_rate,
    )
    model.train(
        X_train,
        Y_train,
        X_test,
        Y_test,
        epochs=config.epochs,
        batch_size=config.batch_size,
    )

    # forward() returns a pair; only the first element (per-class scores) is
    # needed here.
    Y_pred_test, _ = model.forward(X_test)
    y_true = np.argmax(Y_test, axis=1)
    y_pred = np.argmax(Y_pred_test, axis=1)

    test_accuracy = np.mean(y_true == y_pred)
    wandb.log({"test_accuracy": test_accuracy})

    # Confusion matrix: rows are true labels, columns are predicted labels.
    # np.add.at performs an unbuffered scatter-add, so repeated (true, pred)
    # pairs are each counted, exactly like an explicit "+= 1" loop.
    cm = np.zeros((num_classes, num_classes), dtype=int)
    np.add.at(cm, (y_true, y_pred), 1)

    # Normalize each row by its number of true samples. Rows with no samples
    # stay at 0.0 instead of turning into NaN through a division by zero.
    row_totals = cm.sum(axis=1, keepdims=True)
    cm_norm = np.divide(
        cm,
        row_totals,
        out=np.zeros(cm.shape, dtype=float),
        where=row_totals > 0,
    )

    wandb.log({
        "confusion_matrix": wandb.plot.confusion_matrix(
            probs=None,
            y_true=y_true,
            preds=y_pred,
            class_names=[str(i) for i in range(num_classes)],
        )
    })

    # Per-class accuracy is the diagonal of the row-normalized matrix
    # (correct predictions / true samples of that class). All classes are
    # logged in one call so they share a single W&B step.
    class_accuracy = np.diag(cm_norm)
    wandb.log({
        f"class_{i}_accuracy": acc for i, acc in enumerate(class_accuracy)
    })

    plt.figure(figsize=(8, 6))
    sns.heatmap(
        cm_norm,
        annot=True,
        fmt=".2f",
        cmap="YlGnBu",
        xticklabels=range(num_classes),
        yticklabels=range(num_classes),
    )
    plt.xlabel("Predicted Label")
    plt.ylabel("True Label")
    # The en dash is written as an escape so it cannot be corrupted again by a
    # file re-encoding (the original literal had degraded to mojibake).
    plt.title("Normalized Confusion Matrix \u2013 Best Model (Test Set)")
    plt.tight_layout()
    plt.show()