In [11]:
# set up working catalog
import sys
from pathlib import Path
project_path = str(Path().cwd().parent.parent.resolve())
if project_path not in sys.path:
    sys.path.append(project_path)

# imports
from common.utils import get_datasets, X_TRAIN, Y_TRAIN, X_TEST, Y_TEST
from common.mixture_of_experts import MixtureOfExperts

from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score

In [12]:
datasets = get_datasets(n_splits=5, should_label_encode=True)

In [13]:
for dataset in datasets:
    mixtureOfExperts = MixtureOfExperts(
        experts=[LogisticRegression(max_iter=300), SVC(probability=True), RandomForestClassifier()]
    )
    
    mixtureOfExperts.fit(dataset[X_TRAIN], dataset[Y_TRAIN])
    y_pred = mixtureOfExperts.predict(dataset[X_TEST])
    
    print(f"Accuracy: {accuracy_score(dataset[Y_TEST], y_pred)}")

Accuracy: 0.7751412429378531
Accuracy: 0.768361581920904
Accuracy: 0.768361581920904
Accuracy: 0.7536723163841808
Accuracy: 0.7805429864253394


In [14]:
for dataset in datasets:
    mixtureOfExperts = MixtureOfExperts(
        experts=[LogisticRegression(max_iter=250) for _ in range(5)]
    )
    
    mixtureOfExperts.fit(dataset[X_TRAIN], dataset[Y_TRAIN])
    y_pred = mixtureOfExperts.predict(dataset[X_TEST])
    
    print(f"Accuracy: {accuracy_score(dataset[Y_TEST], y_pred)}")

Accuracy: 0.7661016949152543
Accuracy: 0.7796610169491526
Accuracy: 0.7672316384180791
Accuracy: 0.7480225988700565
Accuracy: 0.7613122171945701


In [15]:
for dataset in datasets:
    mixtureOfExperts = MixtureOfExperts(
        experts=[RandomForestClassifier() for _ in range(5)]
    )
    
    mixtureOfExperts.fit(dataset[X_TRAIN], dataset[Y_TRAIN])
    y_pred = mixtureOfExperts.predict(dataset[X_TEST])
    
    print(f"Accuracy: {accuracy_score(dataset[Y_TEST], y_pred)}")

Accuracy: 0.7638418079096045
Accuracy: 0.7774011299435029
Accuracy: 0.7774011299435029
Accuracy: 0.7333333333333333
Accuracy: 0.7669683257918553
