In [4]:
import sys
sys.path.append("../")
from sklearn.datasets import load_digits, load_wine
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sklearn.neural_network import MLPClassifier

# Import the custom MLP class
from neural_networks.MLP import MLP  # Adjust the import according to your file structure

def test_mlp_on_dataset(X, y, *, test_size=0.2, random_state=42, hidden_units=64, epochs=200):
    """Train the custom MLP and sklearn's MLPClassifier on one dataset and compare them.

    The data is split into train/test sets and standardized. Both models are then
    fitted with the same hidden-layer width and the same training budget, so the
    two accuracies are directly comparable. The defaults reproduce the original
    hard-coded settings.

    NOTE: the ``test_`` prefix means pytest would try to collect this function if
    the code is moved into a test module. It would then fail, because ``X``/``y``
    are not fixtures.

    Args:
        X: Feature matrix accepted by ``train_test_split`` / ``StandardScaler``.
        y: Class labels, one per row of ``X``.
        test_size: Fraction of samples held out for evaluation.
        random_state: Seed for the train/test split and for ``MLPClassifier``.
            The custom ``MLP`` is not seeded here. NOTE(review): its accuracy may
            vary between runs unless ``MLP`` seeds itself internally; confirm.
        hidden_units: Width of the single hidden layer used by both models.
        epochs: Training budget. It is passed as ``epochs`` to the custom MLP and
            as ``max_iter`` to ``MLPClassifier`` (``max_iter`` counts epochs for
            sklearn's default ``adam`` solver).

    Returns:
        Tuple ``(accuracy_custom, accuracy_sklearn)`` of test-set accuracies.
    """
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=random_state
    )

    # Fit the scaler on the training split only, so no test-set statistics leak
    # into training. Then apply the same transform to the test split.
    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train)
    X_test = scaler.transform(X_test)

    # Output-layer width = number of distinct labels in the full dataset.
    # NOTE(review): presumably MLP expects integer labels 0..k-1 (true for the
    # digits/wine datasets used below); confirm for other datasets.
    n_classes = len(set(y))

    # Custom MLP: input -> hidden_units -> n_classes.
    mlp_custom = MLP(layer_sizes=[X_train.shape[1], hidden_units, n_classes], epochs=epochs)
    mlp_custom.fit(X_train, y_train)

    # Reference sklearn MLP with the same single hidden layer and training budget.
    mlp_sklearn = MLPClassifier(
        hidden_layer_sizes=(hidden_units,),
        max_iter=epochs,
        random_state=random_state,
    )
    mlp_sklearn.fit(X_train, y_train)

    # Score both models on the same held-out split.
    predictions_custom = mlp_custom.predict(X_test)
    accuracy_custom = accuracy_score(y_test, predictions_custom)

    predictions_sklearn = mlp_sklearn.predict(X_test)
    accuracy_sklearn = accuracy_score(y_test, predictions_sklearn)

    return accuracy_custom, accuracy_sklearn

def run_tests():
    """Benchmark the custom MLP against sklearn on the digits and wine datasets.

    For each dataset, in order, this loads the data, runs the comparison and
    prints one accuracy line per model.
    """
    # (label used in the printed line, sklearn dataset loader)
    benchmarks = (
        ("Digits", load_digits),
        ("Wine", load_wine),
    )
    for label, loader in benchmarks:
        bunch = loader()
        acc_custom, acc_sklearn = test_mlp_on_dataset(bunch.data, bunch.target)
        print(f"{label} dataset - Custom MLP Accuracy: {acc_custom}")
        print(f"{label} dataset - Sklearn MLP Accuracy: {acc_sklearn}")


run_tests()

Digits dataset - Custom MLP Accuracy: 0.975
Digits dataset - Sklearn MLP Accuracy: 0.9833333333333333
Wine dataset - Custom MLP Accuracy: 1.0
Wine dataset - Sklearn MLP Accuracy: 1.0


