Ensemble Learning

In [None]:
import numpy as np
from sklearn.datasets import fetch_openml
from sklearn.ensemble import RandomForestClassifier, VotingClassifier
from sklearn.metrics import accuracy_score, f1_score
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical

# Load MNIST dataset
(X_train, y_train), (X_test, y_test) = mnist.load_data()

# Binary task: images of `positive_digit` are the positive ("diseased") class,
# every other digit is negative. With ten digit classes the positive class is a
# small minority, so plain accuracy is inflated by always guessing "negative";
# a majority-class baseline and positive-class F1 are reported below for context.
positive_digit = 5
y_train_binary = np.where(y_train == positive_digit, 1, 0)
y_test_binary = np.where(y_test == positive_digit, 1, 0)

# Flatten each 28x28 image into a 784-feature row and scale 8-bit pixel
# values into [0, 1].
X_train = X_train.reshape(-1, 28 * 28) / 255.0
X_test = X_test.reshape(-1, 28 * 28) / 255.0

# Hold out 20% of the training data for validation.
# - stratify keeps the rare positive class at the same share in both splits.
# - the multiclass labels are split alongside so `y_train` stays row-aligned
#   with the reduced `X_train` for any later use.
X_train, X_val, y_train, y_val, y_train_binary, y_val_binary = train_test_split(
    X_train,
    y_train,
    y_train_binary,
    test_size=0.2,
    random_state=42,
    stratify=y_train_binary,
)

# Base classifiers. Fixed seeds make results reproducible; for SVC the seed
# controls the randomness of the probability estimates that soft voting needs
# (probability=True).
svm_clf = SVC(kernel='rbf', probability=True, random_state=42)
rf_clf = RandomForestClassifier(n_estimators=100, random_state=42, n_jobs=-1)

# Soft-voting ensemble: predicts the argmax of the summed class probabilities.
# VotingClassifier.fit clones and trains its member estimators itself, so the
# base models are NOT fitted separately -- doing so would train the expensive
# SVM twice and score different model instances than the ones being combined.
voting_clf = VotingClassifier(
    estimators=[('svm', svm_clf), ('random_forest', rf_clf)],
    voting='soft',
)
voting_clf.fit(X_train, y_train_binary)

# Rebind the names to the fitted ensemble members so the individual scores
# below describe exactly the models inside the ensemble.
svm_clf = voting_clf.named_estimators_['svm']
rf_clf = voting_clf.named_estimators_['random_forest']


def _report(label, y_true, y_pred):
    """Print accuracy and positive-class F1 for one model on one split; return the accuracy."""
    accuracy = accuracy_score(y_true, y_pred)
    print(
        f"{label} Accuracy: {accuracy:.4f} "
        f"(F1 for digit {positive_digit}: {f1_score(y_true, y_pred):.4f})"
    )
    return accuracy


# Validation set: base models and ensemble scored on the same data so the
# comparison is like-for-like.
svm_val_pred = svm_clf.predict(X_val)
rf_val_pred = rf_clf.predict(X_val)
ensemble_val_pred = voting_clf.predict(X_val)
_report("SVM Validation", y_val_binary, svm_val_pred)
_report("Random Forest Validation", y_val_binary, rf_val_pred)
_report("Ensemble Validation", y_val_binary, ensemble_val_pred)

# Test set: accuracy of always predicting the more frequent class is the bar
# any useful model has to clear on this imbalanced task.
positive_rate = y_test_binary.mean()
baseline_accuracy = max(positive_rate, 1 - positive_rate)
print(f"Majority-class Baseline Test Accuracy: {baseline_accuracy:.4f}")

ensemble_test_pred = voting_clf.predict(X_test)
ensemble_accuracy = _report("Ensemble Test", y_test_binary, ensemble_test_pred)


SVM Validation Accuracy: 0.9937
Random Forest Validation Accuracy: 0.9887
Ensemble Test Accuracy: 0.9953
