In [None]:
pip install scikeras

In [None]:

import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from sklearn.model_selection import train_test_split, GridSearchCV
from scikeras.wrappers import KerasClassifier
from sklearn.datasets import load_iris


# Load the Iris dataset bundled with scikit-learn.
data = load_iris()
X = data.data  # Feature matrix, one row per sample
y = data.target  # Integer class labels


# Hold out 20% of the samples for the final test. Stratifying on `y` keeps the
# three classes in the same proportions in both splits; with only 30 test
# samples, an unstratified split can noticeably skew the class balance and
# therefore the reported test accuracy.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)


# Define a function to create your neural network model
def create_model(learning_rate=0.01, num_units=64, meta=None):
    """Build and compile a small fully connected multi-class classifier.

    Args:
        learning_rate: Step size for the Adam optimizer.
        num_units: Width of each of the two hidden ReLU layers.
        meta: Dataset metadata that scikeras passes in automatically because
            this function declares a ``meta`` parameter. ``n_features_in_``
            sizes the input layer and ``n_classes_`` sizes the softmax output.
            When ``None`` (e.g. when calling this function directly), the
            input width falls back to the module-level ``X_train`` and the
            output layer to 3 classes.

    Returns:
        A compiled ``keras.Sequential`` model trained with sparse categorical
        cross-entropy, i.e. it expects integer class labels.
    """
    if meta is None:
        n_features = X_train.shape[1]
        n_classes = 3
    else:
        n_features = meta["n_features_in_"]
        n_classes = meta["n_classes_"]

    model = keras.Sequential([
        # An explicit Input replaces the `input_shape=` layer kwarg, which
        # Keras 3 deprecates for Sequential models.
        keras.Input(shape=(n_features,)),
        keras.layers.Dense(units=num_units, activation='relu'),
        keras.layers.Dense(units=num_units, activation='relu'),
        keras.layers.Dense(units=n_classes, activation='softmax'),  # Multi-class classification
    ])
    optimizer = keras.optimizers.Adam(learning_rate=learning_rate)
    model.compile(optimizer=optimizer, loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    return model


# Wrap the model-building function as a scikit-learn estimator.
# `model=` is the scikeras argument for the build function (`build_fn` is a
# deprecated alias). `learning_rate` and `num_units` are declared here so that
# GridSearchCV can tune them via set_params; scikeras forwards them to
# `create_model` because they match its parameter names.
# epochs/batch_size use the same budget as the final training run (30 epochs,
# batch size 32) so the grid search compares hyperparameters under the
# conditions the chosen model is actually trained with.
model = KerasClassifier(
    model=create_model,
    learning_rate=0.01,
    num_units=64,
    epochs=30,
    batch_size=32,
)

# Candidate values for each tunable build-function argument; the search
# evaluates every combination (3 x 3 = 9 candidates).
param_grid = dict(
    learning_rate=[0.001, 0.01, 0.1],
    num_units=[32, 64, 128],
)

# Exhaustive search with 3-fold cross-validation on the training split only;
# verbose=1 makes GridSearchCV report its progress.
grid_search = GridSearchCV(
    estimator=model,
    param_grid=param_grid,
    cv=3,
    verbose=1,
)
grid_result = grid_search.fit(X_train, y_train)

# Report the winning hyperparameter combination and its mean cross-validated
# score (no `scoring` was given, so this is the estimator's own score).
best_params, best_cv_score = grid_result.best_params_, grid_result.best_score_
print(f"Best Parameters: {best_params}")
print(f"Best Accuracy: {best_cv_score}")

# Train the final model with the best hyperparameters for a longer budget.
# GridSearchCV (refit=True by default) has already refit `best_estimator_` on
# all of X_train; here it is fit again with 30 epochs and batch size 32.
# The training settings go through `set_params`, the scikit-learn API for
# changing estimator hyperparameters, rather than extra `fit` kwargs, so the
# estimator's reported parameters match how it was actually trained.
best_model = grid_result.best_estimator_
best_model.set_params(epochs=30, batch_size=32)
best_model.fit(X_train, y_train)


# Held-out evaluation: the test split was used neither in the search nor in
# the final fit. `score` on the classifier wrapper reports accuracy.
test_accuracy = best_model.score(X=X_test, y=y_test)
print(f"Test Accuracy: {test_accuracy}")
