In [1]:
from sklearn.ensemble import AdaBoostClassifier, GradientBoostingClassifier, RandomForestClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import GaussianNB
from sklearn.svm import SVC
from sklearn.neural_network import MLPClassifier
from xgboost import XGBClassifier
from eipy.ei import EnsembleIntegration
import eipy.utils as ut
from eipy.additional_ensembles import MeanAggregation, CES
from sklearn.model_selection import train_test_split
import numpy as np
import pandas as pd
from sklearn import datasets
pd.set_option('display.max_columns', None)

In [2]:
# If data is multi-class, run a check on the allowable base and meta models.

# Base predictors keyed by the short names that label the meta-data columns.
base_predictors = {
    'ADAB': AdaBoostClassifier(),
    'XGB': XGBClassifier(),
    'DT': DecisionTreeClassifier(),
    'RF': RandomForestClassifier(),
    'GB': GradientBoostingClassifier(),
    'KNN': KNeighborsClassifier(),
    # The default lbfgs solver already fits a multinomial model on multi-class
    # targets. The explicit multi_class argument is deprecated from
    # scikit-learn 1.5 on, so it is no longer passed.
    'LR': LogisticRegression(),
    'NB': GaussianNB(),
    'MLP': MLPClassifier(),
    # probability=True enables predict_proba on the SVC.
    'SVM': SVC(probability=True),
}

In [3]:

"""
For filtering base predictors by whether or not they rely on heuristics for multiclass extension

natively_multi_class_predictors = ["XGBClassifier",
"BernoulliNB",
"DecisionTreeClassifier",
"ExtraTreeClassifier",
"GaussianNB",
"KNeighborsClassifier",
"LabelPropagation",
"LabelSpreading",
"LinearDiscriminantAnalysis",
"LinearSVC", #(setting multi_class=”crammer_singer”)
"LogisticRegression", #(setting multi_class=”multinomial”)
"LogisticRegressionCV", #(setting multi_class=”multinomial”)
"MLPClassifier",
"NearestCentroid",
"QuadraticDiscriminantAnalysis",
"RadiusNeighborsClassifier",
"RandomForestClassifier",
"RidgeClassifier",
"RidgeClassifierCV"]

base_predictors = {k : v for k,v in base_predictors.items() if str(v).split("(")[0] in natively_multi_class_predictors}
"""

'\nFor filtering base predictors by whether or not they rely on heursitics for multiclass extension\n\nnatively_multi_class_predictors = ["XGBClassifier",\n"BernoulliNB",\n"DecisionTreeClassifier",\n"ExtraTreeClassifier",\n"GaussianNB",\n"KNeighborsClassifier",\n"LabelPropagation",\n"LabelSpreading",\n"LinearDiscriminantAnalysis",\n"LinearSVC", #(setting multi_class=”crammer_singer”)\n"LogisticRegression", #(setting multi_class=”multinomial”)\n"LogisticRegressionCV", #(setting multi_class=”multinomial”)\n"MLPClassifier",\n"NearestCentroid",\n"QuadraticDiscriminantAnalysis",\n"RadiusNeighborsClassifier",\n"RandomForestClassifier",\n"RidgeClassifier",\n"RidgeClassifierCV"]\n\nbase_predictors = {k : v for k,v in base_predictors.items() if str(v).split("(")[0] in natively_multi_class_predictors}\n'

In [4]:
"""https://dev.pages.lis-lab.fr/scikit-multimodallearn/tutorial/auto_examples/combo/plot_combo_3_views_3_classes.html#
multi modal multi-class toy data generation"""

def generate_data(n_samples, lim, rng=None):
    """Draw uniform random points inside an axis-aligned box.

    Parameters
    ----------
    n_samples : int
        Number of points (rows) to draw.
    lim : array-like of shape (n_features, 2)
        One ``[low, high]`` pair per feature.
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Source of randomness. When omitted, the global NumPy generator is
        used, so ``np.random.seed`` keeps controlling the output exactly as
        before.

    Returns
    -------
    numpy.ndarray of shape (n_samples, n_features)
        Column ``j`` is ``(high_j - low_j) * u + low_j`` with ``u`` drawn
        uniformly from ``[0, 1)``.
    """
    lim = np.array(lim)
    n_features = lim.shape[0]
    lower, upper = lim[:, 0], lim[:, 1]
    # Fall back to the legacy global generator to stay backward compatible.
    draw = np.random.random if rng is None else rng.random
    unit = draw((n_samples, n_features))
    # Broadcasting rescales each column to its own [low, high) interval.
    return (upper - lower) * unit + lower
# Fix the global NumPy seed so the toy data set is reproducible.
seed = 12
np.random.seed(seed)

n_samples = 300

# Each modality stacks three blocks of n_samples points, one block per class.
# Only the range of the first feature changes between blocks; the second
# feature always spans [0, 1]. The generate_data calls keep their original
# order because every call advances the global random state.
modality_0 = np.vstack([
    generate_data(n_samples, [[0., 1.], [0., 1.]]),
    generate_data(n_samples, [[1., 2.], [0., 1.]]),
    generate_data(n_samples, [[0., 2.], [0., 1.]]),
])

modality_1 = np.vstack([
    generate_data(n_samples, [[1., 2.], [0., 1.]]),
    generate_data(n_samples, [[0., 2.], [0., 1.]]),
    generate_data(n_samples, [[0., 1.], [0., 1.]]),
])

modality_2 = np.vstack([
    generate_data(n_samples, [[0., 2.], [0., 1.]]),
    generate_data(n_samples, [[0., 1.], [0., 1.]]),
    generate_data(n_samples, [[1., 2.], [0., 1.]]),
])

# Labels 0, 1, 2 for the three consecutive blocks of n_samples rows.
y = np.repeat(np.arange(3, dtype=np.int64), n_samples)


In [5]:
# Split every modality and the labels in one call so all arrays share the same
# stratified 80/20 partition. Results come back as (train, test) pairs in the
# order the arrays were passed.
(X_0_train, X_0_test,
 X_1_train, X_1_test,
 X_2_train, X_2_test,
 y_train, y_test) = train_test_split(
    modality_0, modality_1, modality_2, y,
    test_size=0.2, random_state=3, stratify=y,
)

In [6]:
# Training split of each modality, keyed by modality name.
data_train = dict(
    Modality_0=X_0_train,
    Modality_1=X_1_train,
    Modality_2=X_2_train,
)

# Held-out split of each modality, under the same keys.
data_test = dict(
    Modality_0=X_0_test,
    Modality_1=X_1_test,
    Modality_2=X_2_test,
)

In [7]:
# Configure the ensemble-integration run for the toy data with a fixed seed.
# NOTE(review): k_outer / k_inner are presumably the outer and inner fold
# counts of a nested cross-validation — confirm against the eipy docs.
EI = EnsembleIntegration(
    base_predictors=base_predictors,
    k_outer=5,
    k_inner=5,
    n_samples=1,
    sampling_strategy=None,
    n_jobs=-1,
    random_state=42,
    project_name="toy",
    model_building=True,
)



In [8]:
# Fit the base predictors on each modality in data_train. The log below shows
# one ensemble-analysis pass and one final-ensemble pass per modality.
EI.train_base(data_train, y_train)

Training base predictors on Modality_0...
        
... for ensemble performance analysis...


Generating meta training data: |██████████|100%
Generating meta test data: |██████████|100%



... for final ensemble...


Generating meta training data: |██████████|100%
Training final base predictors: |██████████|100%




Training base predictors on Modality_1...
        
... for ensemble performance analysis...


Generating meta training data: |██████████|100%
Generating meta test data: |██████████|100%



... for final ensemble...


Generating meta training data: |██████████|100%
Training final base predictors: |██████████|100%




Training base predictors on Modality_2...
        
... for ensemble performance analysis...


Generating meta training data: |██████████|100%
Generating meta test data: |██████████|100%



... for final ensemble...


Generating meta training data: |██████████|100%
Training final base predictors: |██████████|100%






In [9]:
# Inspect entry 0 of the meta-training data: one column per
# (modality, base predictor, sample, class) plus a trailing labels column.
# NOTE(review): entry 0 is presumably the first outer fold — confirm.
EI.meta_training_data[0]

modality,Modality_0,Modality_0,Modality_0,Modality_0,Modality_0,Modality_0,Modality_0,Modality_0,Modality_0,Modality_0,Modality_0,Modality_0,Modality_0,Modality_0,Modality_0,Modality_0,Modality_0,Modality_0,Modality_0,Modality_0,Modality_0,Modality_0,Modality_0,Modality_0,Modality_0,Modality_0,Modality_0,Modality_0,Modality_0,Modality_0,Modality_1,Modality_1,Modality_1,Modality_1,Modality_1,Modality_1,Modality_1,Modality_1,Modality_1,Modality_1,Modality_1,Modality_1,Modality_1,Modality_1,Modality_1,Modality_1,Modality_1,Modality_1,Modality_1,Modality_1,Modality_1,Modality_1,Modality_1,Modality_1,Modality_1,Modality_1,Modality_1,Modality_1,Modality_1,Modality_1,Modality_2,Modality_2,Modality_2,Modality_2,Modality_2,Modality_2,Modality_2,Modality_2,Modality_2,Modality_2,Modality_2,Modality_2,Modality_2,Modality_2,Modality_2,Modality_2,Modality_2,Modality_2,Modality_2,Modality_2,Modality_2,Modality_2,Modality_2,Modality_2,Modality_2,Modality_2,Modality_2,Modality_2,Modality_2,Modality_2,labels
base predictor,ADAB,ADAB,ADAB,XGB,XGB,XGB,DT,DT,DT,RF,RF,RF,GB,GB,GB,KNN,KNN,KNN,LR,LR,LR,NB,NB,NB,MLP,MLP,MLP,SVM,SVM,SVM,ADAB,ADAB,ADAB,XGB,XGB,XGB,DT,DT,DT,RF,RF,RF,GB,GB,GB,KNN,KNN,KNN,LR,LR,LR,NB,NB,NB,MLP,MLP,MLP,SVM,SVM,SVM,ADAB,ADAB,ADAB,XGB,XGB,XGB,DT,DT,DT,RF,RF,RF,GB,GB,GB,KNN,KNN,KNN,LR,LR,LR,NB,NB,NB,MLP,MLP,MLP,SVM,SVM,SVM,Unnamed: 91_level_1
sample,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,Unnamed: 91_level_2
class,0,1,2,0,1,2,0,1,2,0,1,2,0,1,2,0,1,2,0,1,2,0,1,2,0,1,2,0,1,2,0,1,2,0,1,2,0,1,2,0,1,2,0,1,2,0,1,2,0,1,2,0,1,2,0,1,2,0,1,2,0,1,2,0,1,2,0,1,2,0,1,2,0,1,2,0,1,2,0,1,2,0,1,2,0,1,2,0,1,2,Unnamed: 91_level_3
0,5.033375e-01,2.222770e-16,0.496663,0.990476,0.000190,0.009334,1.0,0.0,0.0,0.91,0.00,0.09,0.873335,0.006062,0.120604,1.0,0.0,0.0,0.741568,0.022422,0.236010,0.787548,0.000082,0.212370,0.734722,0.004587,0.260691,0.694402,0.002815,0.302783,5.038713e-01,0.496129,2.223630e-16,0.461601,0.537134,0.001266,0.0,1.0,0.0,0.51,0.49,0.00,0.446751,0.542124,0.011125,0.8,0.2,0.0,0.812660,0.179730,0.007610,0.705237,0.294760,3.893088e-06,0.757580,0.238677,0.003743,0.704742,0.286515,0.008744,0.496191,2.223908e-16,5.038085e-01,0.034906,0.000094,0.965000,0.0,0.0,1.0,0.19,0.00,0.81,0.190731,0.005388,0.803881,0.4,0.0,0.6,0.402319,0.113573,0.484108,0.293706,0.014156,0.692139,0.344770,0.058037,0.597193,0.235021,0.003676,0.761303,0
1,5.033375e-01,2.222770e-16,0.496663,0.082868,0.000607,0.916525,0.0,0.0,1.0,0.54,0.00,0.46,0.305751,0.005220,0.689029,0.4,0.0,0.6,0.688637,0.025537,0.285826,0.743124,0.000103,0.256774,0.661490,0.010317,0.328193,0.685259,0.006523,0.308217,5.038713e-01,0.496129,2.223630e-16,0.120333,0.877135,0.002532,0.0,1.0,0.0,0.35,0.65,0.00,0.205477,0.780552,0.013971,0.6,0.4,0.0,0.643465,0.320709,0.035826,0.728743,0.270921,3.361942e-04,0.651904,0.333255,0.014840,0.636935,0.355380,0.007685,0.496820,5.031799e-01,2.222557e-16,0.130016,0.868783,0.001201,0.0,1.0,0.0,0.27,0.73,0.00,0.616100,0.379187,0.004714,0.2,0.8,0.0,0.381799,0.527239,0.090961,0.271800,0.725416,0.002784,0.356892,0.614702,0.028406,0.314243,0.683171,0.002586,0
2,2.223397e-16,5.036501e-01,0.496350,0.008628,0.492612,0.498759,0.0,0.0,1.0,0.04,0.63,0.33,0.012749,0.491289,0.495962,0.2,0.4,0.4,0.260780,0.271721,0.467498,0.197136,0.270495,0.532369,0.315407,0.241999,0.442594,0.252198,0.188472,0.559330,2.222529e-16,0.496881,5.031186e-01,0.000192,0.005501,0.994307,0.0,0.0,1.0,0.00,0.04,0.96,0.004441,0.080650,0.914909,0.0,0.0,1.0,0.016252,0.182152,0.801596,0.000027,0.219851,7.801217e-01,0.002285,0.205562,0.792153,0.005052,0.318134,0.676815,0.496820,5.031799e-01,2.222557e-16,0.678152,0.319825,0.002023,0.0,1.0,0.0,0.56,0.44,0.00,0.427438,0.565553,0.007009,0.4,0.6,0.0,0.303977,0.650942,0.045081,0.253442,0.746303,0.000256,0.341900,0.651430,0.006671,0.335281,0.662332,0.002388,1
3,2.223397e-16,5.036501e-01,0.496350,0.000314,0.914500,0.085186,0.0,1.0,0.0,0.01,0.74,0.25,0.007984,0.813579,0.178437,0.4,0.4,0.2,0.217677,0.251799,0.530523,0.158370,0.199676,0.641954,0.171173,0.328954,0.499873,0.145179,0.396572,0.458250,2.222529e-16,0.496881,5.031186e-01,0.000097,0.001815,0.998088,0.0,0.0,1.0,0.00,0.01,0.99,0.003233,0.107689,0.889078,0.0,0.2,0.8,0.017392,0.188231,0.794377,0.000033,0.219310,7.806570e-01,0.002537,0.215564,0.781898,0.004159,0.319830,0.676011,0.496191,2.223908e-16,5.038085e-01,0.006253,0.000095,0.993652,0.0,0.0,1.0,0.06,0.00,0.94,0.076960,0.002392,0.920648,0.2,0.0,0.8,0.328162,0.063236,0.608602,0.230185,0.003556,0.766259,0.322704,0.029860,0.647435,0.191561,0.002338,0.806101,2
4,5.033375e-01,2.222770e-16,0.496663,0.205049,0.002299,0.792652,1.0,0.0,0.0,0.52,0.01,0.47,0.495245,0.010051,0.494704,0.4,0.0,0.6,0.389640,0.149040,0.461319,0.522100,0.040081,0.437819,0.513466,0.096828,0.389706,0.514961,0.019730,0.465309,5.038713e-01,0.496129,2.223630e-16,0.729893,0.269725,0.000383,1.0,0.0,0.0,0.89,0.11,0.00,0.656200,0.337153,0.006647,0.8,0.2,0.0,0.689093,0.285340,0.025567,0.737586,0.262294,1.199098e-04,0.676987,0.312717,0.010295,0.632066,0.361283,0.006652,0.496191,2.223908e-16,5.038085e-01,0.559073,0.002290,0.438637,1.0,0.0,0.0,0.60,0.00,0.40,0.256467,0.006856,0.736677,0.2,0.0,0.8,0.334040,0.064163,0.601797,0.238629,0.002247,0.759124,0.334455,0.027040,0.638505,0.235444,0.001804,0.762752,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
571,5.034012e-01,2.222858e-16,0.496599,0.730658,0.000891,0.268450,1.0,0.0,0.0,0.85,0.00,0.15,0.326098,0.010482,0.663420,0.8,0.0,0.2,0.813545,0.010634,0.175820,0.763151,0.000017,0.236831,0.737883,0.003810,0.258307,0.674201,0.010199,0.315600,5.038713e-01,0.496129,2.223630e-16,0.771925,0.227552,0.000523,1.0,0.0,0.0,0.71,0.29,0.00,0.832982,0.158148,0.008870,0.6,0.4,0.0,0.583089,0.367700,0.049211,0.751509,0.247555,9.359377e-04,0.629503,0.345457,0.025041,0.759947,0.236493,0.003560,0.496942,5.030579e-01,2.222392e-16,0.588948,0.409782,0.001270,1.0,0.0,0.0,0.54,0.46,0.00,0.523892,0.470108,0.005999,0.6,0.4,0.0,0.410048,0.507348,0.082605,0.319283,0.676623,0.004094,0.366665,0.608913,0.024422,0.286633,0.711375,0.001992,0
572,5.034012e-01,2.222858e-16,0.496599,0.741070,0.003562,0.255368,0.0,0.0,1.0,0.73,0.00,0.27,0.758719,0.006605,0.234676,0.8,0.0,0.2,0.491131,0.072969,0.435900,0.655889,0.004847,0.339265,0.592587,0.046026,0.361386,0.672380,0.005945,0.321674,5.038713e-01,0.496129,2.223630e-16,0.709526,0.287658,0.002815,1.0,0.0,0.0,0.69,0.31,0.00,0.758846,0.231045,0.010109,0.8,0.2,0.0,0.744517,0.242041,0.013442,0.733706,0.266279,1.542590e-05,0.670398,0.321818,0.007784,0.710672,0.284397,0.004932,0.496942,5.030579e-01,2.222392e-16,0.105804,0.893397,0.000799,1.0,0.0,0.0,0.44,0.56,0.00,0.243999,0.750289,0.005711,0.4,0.6,0.0,0.267849,0.711016,0.021135,0.278899,0.721085,0.000017,0.341974,0.653626,0.004401,0.324069,0.671813,0.004119,0
573,2.223192e-16,5.035152e-01,0.496485,0.001255,0.920323,0.078422,0.0,1.0,0.0,0.00,0.74,0.26,0.003857,0.919627,0.076515,0.0,0.8,0.2,0.025359,0.779247,0.195393,0.000097,0.819810,0.180093,0.009429,0.766037,0.224534,0.005261,0.701820,0.292919,2.222572e-16,0.496849,5.031511e-01,0.000191,0.012208,0.987600,0.0,0.0,1.0,0.00,0.10,0.90,0.003511,0.452824,0.543666,0.0,0.2,0.8,0.025723,0.207135,0.767142,0.000044,0.205749,7.942068e-01,0.002974,0.234306,0.762720,0.001675,0.305228,0.693098,0.496066,2.223978e-16,5.039340e-01,0.023767,0.001044,0.975189,0.0,0.0,1.0,0.25,0.00,0.75,0.119909,0.010660,0.869431,0.2,0.0,0.8,0.270401,0.039299,0.690299,0.222018,0.000400,0.777582,0.285722,0.017523,0.696755,0.276874,0.003653,0.719473,2
574,5.034012e-01,2.222858e-16,0.496599,0.988050,0.000777,0.011173,1.0,0.0,0.0,0.87,0.00,0.13,0.924327,0.001435,0.074238,0.8,0.0,0.2,0.761829,0.009976,0.228195,0.653560,0.000010,0.346429,0.676873,0.005897,0.317230,0.672044,0.010727,0.317228,5.038713e-01,0.496129,2.223630e-16,0.599464,0.391220,0.009316,1.0,0.0,0.0,0.44,0.56,0.00,0.798176,0.192122,0.009702,0.4,0.6,0.0,0.822668,0.171797,0.005535,0.525422,0.474577,4.761566e-07,0.672010,0.325162,0.002828,0.457654,0.519184,0.023162,0.496942,5.030579e-01,2.222392e-16,0.097810,0.902037,0.000153,0.0,1.0,0.0,0.19,0.81,0.00,0.220555,0.774808,0.004636,0.2,0.8,0.0,0.428953,0.470013,0.101034,0.356938,0.632894,0.010168,0.380057,0.586231,0.033712,0.292738,0.702961,0.004301,0


In [10]:
# Inspect entry 0 of the meta-test data; it has the same column layout as the
# meta-training table (modality, base predictor, sample, class, labels).
EI.meta_test_data[0]

modality,Modality_0,Modality_0,Modality_0,Modality_0,Modality_0,Modality_0,Modality_0,Modality_0,Modality_0,Modality_0,Modality_0,Modality_0,Modality_0,Modality_0,Modality_0,Modality_0,Modality_0,Modality_0,Modality_0,Modality_0,Modality_0,Modality_0,Modality_0,Modality_0,Modality_0,Modality_0,Modality_0,Modality_0,Modality_0,Modality_0,Modality_1,Modality_1,Modality_1,Modality_1,Modality_1,Modality_1,Modality_1,Modality_1,Modality_1,Modality_1,Modality_1,Modality_1,Modality_1,Modality_1,Modality_1,Modality_1,Modality_1,Modality_1,Modality_1,Modality_1,Modality_1,Modality_1,Modality_1,Modality_1,Modality_1,Modality_1,Modality_1,Modality_1,Modality_1,Modality_1,Modality_2,Modality_2,Modality_2,Modality_2,Modality_2,Modality_2,Modality_2,Modality_2,Modality_2,Modality_2,Modality_2,Modality_2,Modality_2,Modality_2,Modality_2,Modality_2,Modality_2,Modality_2,Modality_2,Modality_2,Modality_2,Modality_2,Modality_2,Modality_2,Modality_2,Modality_2,Modality_2,Modality_2,Modality_2,Modality_2,labels
base predictor,ADAB,ADAB,ADAB,XGB,XGB,XGB,DT,DT,DT,RF,RF,RF,GB,GB,GB,KNN,KNN,KNN,LR,LR,LR,NB,NB,NB,MLP,MLP,MLP,SVM,SVM,SVM,ADAB,ADAB,ADAB,XGB,XGB,XGB,DT,DT,DT,RF,RF,RF,GB,GB,GB,KNN,KNN,KNN,LR,LR,LR,NB,NB,NB,MLP,MLP,MLP,SVM,SVM,SVM,ADAB,ADAB,ADAB,XGB,XGB,XGB,DT,DT,DT,RF,RF,RF,GB,GB,GB,KNN,KNN,KNN,LR,LR,LR,NB,NB,NB,MLP,MLP,MLP,SVM,SVM,SVM,Unnamed: 91_level_1
sample,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,Unnamed: 91_level_2
class,0,1,2,0,1,2,0,1,2,0,1,2,0,1,2,0,1,2,0,1,2,0,1,2,0,1,2,0,1,2,0,1,2,0,1,2,0,1,2,0,1,2,0,1,2,0,1,2,0,1,2,0,1,2,0,1,2,0,1,2,0,1,2,0,1,2,0,1,2,0,1,2,0,1,2,0,1,2,0,1,2,0,1,2,0,1,2,0,1,2,Unnamed: 91_level_3
0,2.223189e-16,5.035294e-01,0.496471,0.000618,0.720623,0.278758,0.0,1.0,0.0,0.00,0.83,0.17,0.007180,0.731360,0.261460,0.0,0.8,0.2,0.015874,0.766577,0.217548,0.000027,0.769516,0.230457,0.003011,0.728071,0.268918,0.002542,0.689202,0.308256,2.222654e-16,0.496788,5.032118e-01,0.000556,0.030934,0.968510,0.0,0.0,1.0,0.00,0.05,0.95,0.006157,0.442907,0.550936,0.0,0.2,0.8,0.062645,0.332167,0.605188,0.001618,0.246503,0.751880,0.011936,0.312141,0.675923,0.002091,0.321329,0.676580,0.496301,2.223652e-16,5.036992e-01,0.908271,0.002862,0.088867,1.0,0.0,0.0,0.62,0.00,0.38,0.562974,0.010962,0.426065,0.6,0.0,0.4,0.443625,0.161026,0.395349,0.412818,0.065808,0.521374,0.349900,0.098474,0.551626,0.398528,0.018970,0.582502,2
1,5.034139e-01,2.222883e-16,0.496586,0.219468,0.005430,0.775102,0.0,0.0,1.0,0.24,0.00,0.76,0.534920,0.008563,0.456517,0.6,0.0,0.4,0.445201,0.090013,0.464786,0.622085,0.008826,0.369089,0.555652,0.052631,0.391717,0.636889,0.012816,0.350295,2.222654e-16,0.496788,5.032118e-01,0.000828,0.844419,0.154753,0.0,1.0,0.0,0.00,0.79,0.21,0.006322,0.525537,0.468141,0.0,0.6,0.4,0.061354,0.332486,0.606160,0.001445,0.240839,0.757716,0.011552,0.320727,0.667721,0.001249,0.304024,0.694727,0.496301,2.223652e-16,5.036992e-01,0.784510,0.001360,0.214130,1.0,0.0,0.0,0.50,0.01,0.49,0.381531,0.007041,0.611428,0.4,0.0,0.6,0.160084,0.012417,0.827499,0.254391,0.000053,0.745556,0.272021,0.003683,0.724295,0.283650,0.006901,0.709450,2
2,5.034139e-01,2.222883e-16,0.496586,0.732608,0.000160,0.267231,1.0,0.0,0.0,0.70,0.00,0.30,0.707335,0.006983,0.285682,0.8,0.0,0.2,0.801036,0.014812,0.184152,0.780021,0.000052,0.219927,0.726409,0.002624,0.270968,0.686473,0.006124,0.307403,5.037331e-01,0.496267,2.223411e-16,0.465257,0.533994,0.000749,1.0,0.0,0.0,0.68,0.32,0.00,0.757383,0.233014,0.009603,0.6,0.4,0.0,0.633054,0.328613,0.038332,0.760700,0.238771,0.000530,0.656826,0.328632,0.014543,0.705697,0.291707,0.002596,0.496301,2.223652e-16,5.036992e-01,0.765530,0.000351,0.234118,1.0,0.0,0.0,0.66,0.00,0.34,0.614759,0.003882,0.381359,0.6,0.0,0.4,0.191402,0.016820,0.791778,0.267024,0.000076,0.732901,0.293355,0.004577,0.702068,0.307633,0.001984,0.690382,0
3,5.034139e-01,2.222883e-16,0.496586,0.876348,0.000254,0.123398,1.0,0.0,0.0,0.76,0.00,0.24,0.850283,0.002924,0.146793,0.6,0.0,0.4,0.260913,0.198842,0.540244,0.301386,0.114217,0.584397,0.258367,0.243032,0.498601,0.265139,0.186146,0.548715,2.222654e-16,0.496788,5.032118e-01,0.770131,0.106752,0.123117,0.0,0.0,1.0,0.05,0.35,0.60,0.003264,0.150026,0.846710,0.0,0.2,0.8,0.280249,0.488608,0.231143,0.270400,0.539250,0.190350,0.347345,0.428086,0.224569,0.270439,0.465189,0.264373,0.496301,2.223652e-16,5.036992e-01,0.016474,0.000094,0.983433,1.0,0.0,0.0,0.30,0.00,0.70,0.120875,0.003040,0.876084,0.6,0.0,0.4,0.455666,0.198010,0.346325,0.468849,0.142448,0.388703,0.393194,0.156214,0.450592,0.555346,0.069046,0.375608,2
4,5.034139e-01,2.222883e-16,0.496586,0.492079,0.000627,0.507294,0.0,0.0,1.0,0.63,0.00,0.37,0.272582,0.009598,0.717821,0.8,0.0,0.2,0.805281,0.012175,0.182545,0.743119,0.000028,0.256853,0.714202,0.002355,0.283443,0.684181,0.006495,0.309324,5.037331e-01,0.496267,2.223411e-16,0.795016,0.204328,0.000656,1.0,0.0,0.0,0.64,0.36,0.00,0.727153,0.267495,0.005352,0.6,0.4,0.0,0.557519,0.379345,0.063136,0.719914,0.277717,0.002369,0.631407,0.343918,0.024675,0.732465,0.264654,0.002881,0.496301,2.223652e-16,5.036992e-01,0.008265,0.000167,0.991569,0.0,0.0,1.0,0.15,0.00,0.85,0.088668,0.004193,0.907138,0.0,0.0,1.0,0.263961,0.037628,0.698411,0.204523,0.001023,0.794454,0.285422,0.012241,0.702337,0.232020,0.002750,0.765230,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
139,5.034139e-01,2.222883e-16,0.496586,0.598946,0.001246,0.399808,1.0,0.0,0.0,0.65,0.01,0.34,0.562459,0.007929,0.429612,0.4,0.0,0.6,0.660310,0.035683,0.304007,0.757708,0.000547,0.241745,0.687454,0.009499,0.303047,0.691182,0.003897,0.304922,5.037331e-01,0.496267,2.223411e-16,0.380407,0.619274,0.000319,0.0,1.0,0.0,0.75,0.25,0.00,0.596115,0.400438,0.003447,0.8,0.2,0.0,0.434019,0.453015,0.112966,0.602076,0.376522,0.021402,0.598784,0.334281,0.066936,0.673849,0.311067,0.015083,0.496301,2.223652e-16,5.036992e-01,0.013676,0.000110,0.986214,0.0,0.0,1.0,0.11,0.00,0.89,0.246981,0.006279,0.746740,0.2,0.0,0.8,0.425577,0.141522,0.432901,0.377372,0.050351,0.572277,0.318369,0.071767,0.609864,0.376655,0.014596,0.608749,0
140,5.034139e-01,2.222883e-16,0.496586,0.895722,0.000379,0.103899,1.0,0.0,0.0,0.72,0.00,0.28,0.676814,0.001660,0.321527,0.6,0.0,0.4,0.721730,0.014693,0.263578,0.681658,0.000031,0.318310,0.672675,0.004744,0.322580,0.678148,0.013602,0.308251,5.037331e-01,0.496267,2.223411e-16,0.979047,0.020876,0.000077,1.0,0.0,0.0,0.92,0.08,0.00,0.789191,0.207425,0.003384,1.0,0.0,0.0,0.353012,0.472125,0.174863,0.442759,0.487746,0.069495,0.502833,0.383264,0.113902,0.629780,0.349123,0.021098,0.496301,2.223652e-16,5.036992e-01,0.067536,0.000492,0.931972,0.0,0.0,1.0,0.08,0.00,0.92,0.120218,0.007723,0.872058,0.0,0.0,1.0,0.239997,0.028569,0.731435,0.222622,0.000360,0.777019,0.293802,0.008447,0.697751,0.282284,0.001291,0.716424,0
141,2.223189e-16,5.035294e-01,0.496471,0.000606,0.628569,0.370825,0.0,1.0,0.0,0.00,0.74,0.26,0.005543,0.713472,0.280984,0.0,0.8,0.2,0.190736,0.273824,0.535441,0.136800,0.266462,0.596738,0.133570,0.382941,0.483489,0.057380,0.427876,0.514744,2.222654e-16,0.496788,5.032118e-01,0.001441,0.148884,0.849675,0.0,0.0,1.0,0.00,0.16,0.84,0.004522,0.111622,0.883856,0.0,0.0,1.0,0.009791,0.145438,0.844771,0.000004,0.266566,0.733430,0.000672,0.241926,0.757403,0.003978,0.306077,0.689945,0.496301,2.223652e-16,5.036992e-01,0.109366,0.000377,0.890257,0.0,0.0,1.0,0.29,0.00,0.71,0.359522,0.009978,0.630501,0.4,0.0,0.6,0.437364,0.150108,0.412529,0.392514,0.051587,0.555899,0.337723,0.083950,0.578327,0.363116,0.013678,0.623206,2
142,2.223189e-16,5.035294e-01,0.496471,0.000275,0.901953,0.097772,0.0,1.0,0.0,0.01,0.75,0.24,0.005624,0.654730,0.339645,0.0,0.8,0.2,0.007671,0.832865,0.159464,0.000003,0.721223,0.278774,0.001253,0.754727,0.244020,0.007412,0.689814,0.302774,2.222654e-16,0.496788,5.032118e-01,0.001173,0.881385,0.117442,0.0,1.0,0.0,0.01,0.78,0.21,0.008211,0.616826,0.374963,0.0,0.4,0.6,0.183800,0.461152,0.355048,0.074223,0.485933,0.439845,0.134681,0.462223,0.403096,0.070806,0.536763,0.392430,0.496688,5.033118e-01,2.222723e-16,0.089760,0.910099,0.000141,0.0,1.0,0.0,0.18,0.82,0.00,0.137506,0.854570,0.007924,0.2,0.8,0.0,0.308509,0.650478,0.041013,0.250095,0.749646,0.000259,0.341602,0.652026,0.006372,0.323515,0.675633,0.000853,1


In [11]:
# create a summary of base predictor performance
# NOTE(review): the saved output of this cell includes
# "AssertionError: axis must be a MultiIndex" and the printed columns are
# labelled with plain tuples — investigate before relying on this summary.
EI.base_summary = ut.create_base_summary(EI.meta_test_data)
EI.base_summary

     (Modality_0, ADAB, 0)  (Modality_0, DT, 0)  (Modality_0, GB, 0)  \
0                        1                    1                    1   
1                        0                    2                    0   
2                        0                    0                    0   
3                        0                    0                    0   
4                        0                    2                    2   
..                     ...                  ...                  ...   
139                      0                    0                    0   
140                      0                    0                    0   
141                      1                    1                    1   
142                      1                    1                    1   
143                      1                    1                    1   

     (Modality_0, KNN, 0)  (Modality_0, LR, 0)  (Modality_0, MLP, 0)  \
0                       1                    1                 

AssertionError: axis must be a MultiIndex

In [None]:
# Train the meta predictors, reusing the base-predictor dict as the meta-model set.
# NOTE(review): the saved output ends with "ValueError: Classification metrics
# can't handle a mix of multiclass and continuous-multioutput targets", and the
# later cells that read EI.meta_summary or call EI.predict have no recorded output.
EI.train_meta(meta_predictors=base_predictors)

Analyzing ensembles: |          |  0%

Analyzing ensembles: |          |  0%


ValueError: Classification metrics can't handle a mix of multiclass and continuous-multioutput targets

In [None]:
# Show the "metrics" entry of the meta-model summary.
# NOTE(review): presumably populated by train_meta — confirm.
EI.meta_summary["metrics"]

In [None]:
# Predict on the held-out modalities with the meta-model stored under "XGB",
# then round the result element-wise.
# NOTE(review): rounding only recovers class labels if predict returns one
# label-like score per sample — confirm for the multi-class case.
y_pred = EI.predict(X_dict=data_test, meta_model_key="XGB")
y_pred = np.round(y_pred)
y_pred

In [None]:
# True test labels, shown for comparison with y_pred.
y_test

In [None]:
# Fraction of test samples whose rounded prediction equals the true label.
accuracy = sum(
    1 * (truth == guess) + 0 * (truth != guess)
    for truth, guess in zip(y_test, y_pred)
) / len(y_test)
accuracy  # =179/180

In [None]:
# Look up the "XGB" entry of the thresholds stored in the meta summary.
EI.meta_summary['thresholds']["XGB"]

In [None]:
# Load iris. NOTE: this rebinds y (and, below, y_train / y_test), replacing
# the toy-data versions defined earlier in the notebook.
iris = datasets.load_iris()
X = iris.data
y = iris.target

# Use the first two feature columns as one modality and the last two as another.
Modality_a = X[:, 0:2]
Modality_b = X[:, 2:4]

# Split both modalities and the labels in one call so all arrays share the same
# stratified 80/20 partition. Results come back as (train, test) pairs in the
# order the arrays were passed.
(X_a_train, X_a_test,
 X_b_train, X_b_test,
 y_train, y_test) = train_test_split(
    Modality_a, Modality_b, y,
    test_size=0.2, random_state=3, stratify=y,
)

In [None]:
# Training split of each iris modality, keyed by modality name.
iris_data_train = dict(
    Modality_a=X_a_train,
    Modality_b=X_b_train,
)

# Held-out split of each iris modality, under the same keys.
iris_data_test = dict(
    Modality_a=X_a_test,
    Modality_b=X_b_test,
)

In [None]:
# Rebind base_predictors with fresh estimators for the iris run; every model
# uses library defaults apart from probability=True on the SVC.
base_predictors = dict(
    ADAB=AdaBoostClassifier(),
    XGB=XGBClassifier(),
    DT=DecisionTreeClassifier(),
    RF=RandomForestClassifier(),
    GB=GradientBoostingClassifier(),
    KNN=KNeighborsClassifier(),
    LR=LogisticRegression(),
    NB=GaussianNB(),
    MLP=MLPClassifier(),
    SVM=SVC(probability=True),
)

In [None]:
# Configure a separate ensemble-integration run for the iris modalities.
# The settings match the toy run except for random_state and project_name.
EI_iris = EnsembleIntegration(
    base_predictors=base_predictors,
    k_outer=5,
    k_inner=5,
    n_samples=1,
    sampling_strategy=None,
    n_jobs=-1,
    random_state=0,
    project_name="iris",
    model_building=True,
)


In [None]:
# Train the base predictors one modality at a time, passing the dict key as
# modality_name.
# NOTE(review): the toy-data section calls train_base(data_train, y_train) with
# the whole dict instead — confirm which call form the installed eipy expects.
for name, modality in iris_data_train.items():
    EI_iris.train_base(modality, y_train, modality_name=name)

In [None]:
# Inspect the meta-training data attribute of the iris run.
EI_iris.meta_training_data

In [None]:
# Train the meta predictors for iris, reusing base_predictors as the meta-model set.
EI_iris.train_meta(meta_predictors=base_predictors)

In [None]:
# Show the "metrics" entry of the iris meta-model summary.
EI_iris.meta_summary["metrics"]

In [None]:
# Predict on the held-out iris modalities with the meta-model stored under
# "SVM", then round the result element-wise.
# NOTE(review): rounding only recovers class labels if predict returns one
# label-like score per sample — confirm for the multi-class case.
y_pred_iris = EI_iris.predict(X_dict=iris_data_test, meta_model_key="SVM")
y_pred_iris = np.round(y_pred_iris)
y_pred_iris

In [None]:
# Fraction of iris test samples whose rounded prediction equals the true label.
accuracy = sum(
    1 * (truth == guess) + 0 * (truth != guess)
    for truth, guess in zip(y_test, y_pred_iris)
) / len(y_test)
accuracy

In [None]:
import pandas as pd
import numpy as np

# Toy input: each top-level column maps to one 1-D array per row.
data = {
    'Column1': [np.array([1, 2]), np.array([3, 4]), np.array([5, 6])],
    'Column2': [np.array([10, 20]), np.array([30, 40]), np.array([50, 60])]
}

# Spread position i of every row's array into its own sub-column, keyed by the
# (column, sub-column) pair.
sub_columns = {
    (key, f'SubColumn{i + 1}'): [row[i] for row in rows]
    for key, rows in data.items()
    for i in range(rows[0].shape[0])  # array length is taken from the first row
}

# Build the frame in one step and set the two-level column index explicitly.
# Assigning tuple keys one column at a time to an empty frame does not reliably
# produce a MultiIndex (pandas keeps a flat index of tuples).
df = pd.DataFrame(sub_columns)
df.columns = pd.MultiIndex.from_tuples(list(sub_columns))

# Display the DataFrame
df
