# Ensemble of predictive models

This notebook introduces the ensemble of T2DM predictive models and demonstrates its code implementation.

This research involves two cohorts with 16S rRNA gut microbiome sequencing data: the Guangdong Gut Microbiome Project (GGMP) and the Shandong Gut Microbiome Project (SGMP). To address bias and inconsistency across cohorts, and even across regions within a cohort, we developed T2DM prediction models using an ensemble learning strategy. We trained an ensemble of models, each on a random subsample (drawn without replacement) containing equal numbers of healthy and T2DM samples. Each T2DM prediction model was trained with XGBoost, and SHAP values were computed to quantify how microbiome features contributed to each model. To filter out poorly performing models and verify the significance of consistently contributing features, 100 null models (trained on the same subsample with shuffled labels) were developed for each sub-model. By aggregating the SHAP values across the retained T2DM prediction models, microbiome signatures were extracted for each sample.

![ensemble model schema](images/ensemble_of_predictive_models.png)
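The final aggregation step can be pictured as averaging, for each sample and each genus, the SHAP values from every retained sub-model that scored that sample. The sketch below assumes a simple mean (the aggregation function is not specified here) and uses small hypothetical in-memory tables in place of the `subsample-*_shap_*.tsv` files written later in this notebook; `aggregate_shap` is an illustrative helper, not part of the project's code.

```python
import pandas as pd


def aggregate_shap(shap_tables):
    """Average SHAP values per sample and feature across sub-models (assumed: plain mean).

    shap_tables: list of DataFrames (rows = samples, columns = features),
    one per retained sub-model. A sample missing from a table simply does
    not contribute to the average for that model.
    """
    stacked = pd.concat(shap_tables, axis=0, sort=False)
    # group rows by sample ID (the index) and average over the models that scored them
    return stacked.groupby(level=0).mean()


# Hypothetical SHAP tables from two sub-models (sample IDs and genera are made up)
model_a = pd.DataFrame({'g__A': [0.2, -0.1], 'g__B': [0.0, 0.3]}, index=['s1', 's2'])
model_b = pd.DataFrame({'g__A': [0.4, 0.1], 'g__B': [0.2, 0.1]}, index=['s1', 's3'])

signatures = aggregate_shap([model_a, model_b])
print(signatures)
```

Sample `s1` appears in both tables, so its signature is the mean of two models; `s2` and `s3` each come from a single model.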

In [1]:
import numpy as np
import os
import pandas as pd
import pickle
import sys
from functools import reduce
from scipy.stats import wilcoxon
from tqdm import tqdm

sys.path.insert(0, '../../t2d_gut_microbiome_signatures')
import dataset.GGMP.load_data as ggmp_data
import dataset.SGMP.load_data as sgmp_data
from model import learner, brewer

### Load datasets

Load the GGMP and SGMP microbiome datasets for modeling.

In [2]:
ggmp_genus_table_file = 'data/GGMP/table-filtered-feature-rarefied5k-L6.tsv'
ggmp_metadata_file = 'data/GGMP/sample-metadata.tsv'
sgmp_genus_table_file = 'data/SGMP/table-filtered-feature-rarefied5k-L6.tsv'
sgmp_metadata_file = 'data/SGMP/sample-metadata.tsv'

ggmp_genus_table = ggmp_data.get_genus_table_for_xgb(ggmp_genus_table_file)
ggmp_metadata = ggmp_data.get_disease_and_healthy(ggmp_metadata_file)
sgmp_genus_table = sgmp_data.get_genus_table_for_xgb(sgmp_genus_table_file)
sgmp_metadata = sgmp_data.get_disease_and_healthy(sgmp_metadata_file)

In [3]:
print(f'GGMP genus table: {ggmp_genus_table.shape}')
ggmp_genus_table.head()

GGMP genus table: (6998, 341)


SampleID,k__Archaea;p__Euryarchaeota;c__Methanobacteria;o__Methanobacteriales;f__Methanobacteriaceae;g__Methanobrevibacter,k__Archaea;p__Euryarchaeota;c__Methanobacteria;o__Methanobacteriales;f__Methanobacteriaceae;g__Methanosphaera,k__Archaea;p__Euryarchaeota;c__Thermoplasmata;o__E2;f__ Methanomassiliicoccaceae ;g__vadinCA11,k__Bacteria;p__Acidobacteria;c__ Chloracidobacteria ;o__RB41;f__Ellin6075;g__,k__Bacteria;p__Actinobacteria;c__Actinobacteria;o__Actinomycetales;f__Actinomycetaceae;g__,k__Bacteria;p__Actinobacteria;c__Actinobacteria;o__Actinomycetales;f__Actinomycetaceae;g__Actinomyces,k__Bacteria;p__Actinobacteria;c__Actinobacteria;o__Actinomycetales;f__Actinomycetaceae;g__Mobiluncus,k__Bacteria;p__Actinobacteria;c__Actinobacteria;o__Actinomycetales;f__Actinomycetaceae;g__Varibaculum,k__Bacteria;p__Actinobacteria;c__Actinobacteria;o__Actinomycetales;f__Brevibacteriaceae;g__Brevibacterium,k__Bacteria;p__Actinobacteria;c__Actinobacteria;o__Actinomycetales;f__Corynebacteriaceae;g__Corynebacterium,...,k__Bacteria;p__Synergistetes;c__Synergistia;o__Synergistales;f__Synergistaceae;g__Synergistes,k__Bacteria;p__TM7;c__TM7-3;o__;f__;g__,k__Bacteria;p__Tenericutes;c__Mollicutes;o__Mycoplasmatales;f__Mycoplasmataceae;g__Mycoplasma,k__Bacteria;p__Tenericutes;c__Mollicutes;o__RF39;f__;g__,k__Bacteria;p__Tenericutes;c__RF3;o__ML615J-28;f__;g__,k__Bacteria;p__Verrucomicrobia;c__Opitutae;o__ Cerasicoccales ;f__ Cerasicoccaceae ;g__,k__Bacteria;p__Verrucomicrobia;c__Verrucomicrobiae;o__Verrucomicrobiales;f__Verrucomicrobiaceae;g__Akkermansia,k__Bacteria;p__WPS-2;c__;o__;f__;g__,k__Bacteria;p__ Thermi ;c__Deinococci;o__Deinococcales;f__Deinococcaceae;g__Deinococcus,k__Bacteria;p__ Thermi ;c__Deinococci;o__Thermales;f__Thermaceae;g__Thermus
G440606620,0.001,0.0,0.0,0.0,0.0,0.001,0.0,0.0,0.0014,0.0,...,0.0002,0.0004,0.0,0.0018,0.0,0.0,0.0056,0.0,0.0,0.0
G445224053,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0002,...,0.0008,0.0,0.0,0.0026,0.0,0.0,0.0,0.0,0.0,0.0
G440305296,0.0022,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0014,0.0,0.0,0.0004,0.0,0.0,0.0
G445302608,0.0002,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0008,0.0002,0.0,0.0
G441502283,0.0002,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0466,0.0,0.0,0.0002,0.0,0.0,0.0


In [4]:
print(f'SGMP genus table: {sgmp_genus_table.shape}')
sgmp_genus_table.head()

SGMP genus table: (2012, 174)


SampleID,k__Bacteria;p__Actinobacteria;c__Actinobacteria;o__Actinomycetales;f__Actinomycetaceae;g__Actinomyces,k__Bacteria;p__Actinobacteria;c__Actinobacteria;o__Actinomycetales;f__Corynebacteriaceae;g__Corynebacterium,k__Bacteria;p__Actinobacteria;c__Actinobacteria;o__Actinomycetales;f__Microbacteriaceae;g__Microbacterium,k__Bacteria;p__Actinobacteria;c__Actinobacteria;o__Actinomycetales;f__Micrococcaceae;g__Kocuria,k__Bacteria;p__Actinobacteria;c__Actinobacteria;o__Actinomycetales;f__Micrococcaceae;g__Rothia,k__Bacteria;p__Actinobacteria;c__Actinobacteria;o__Actinomycetales;f__Mycobacteriaceae;g__Mycobacterium,k__Bacteria;p__Actinobacteria;c__Actinobacteria;o__Actinomycetales;f__Propionibacteriaceae;g__Propionibacterium,k__Bacteria;p__Actinobacteria;c__Actinobacteria;o__Actinomycetales;f__Williamsiaceae;g__Williamsia,k__Bacteria;p__Actinobacteria;c__Actinobacteria;o__Bifidobacteriales;f__Bifidobacteriaceae;g__Bifidobacterium,k__Bacteria;p__Actinobacteria;c__Coriobacteriia;o__Coriobacteriales;f__Coriobacteriaceae;g__,...,k__Bacteria;p__Proteobacteria;c__Gammaproteobacteria;o__Xanthomonadales;f__Xanthomonadaceae;g__Stenotrophomonas,k__Bacteria;p__Synergistetes;c__Synergistia;o__Synergistales;f__Dethiosulfovibrionaceae;g__Pyramidobacter,k__Bacteria;p__Synergistetes;c__Synergistia;o__Synergistales;f__Synergistaceae;g__Cloacibacillus,k__Bacteria;p__Synergistetes;c__Synergistia;o__Synergistales;f__Synergistaceae;g__Synergistes,k__Bacteria;p__TM7;c__TM7-3;o__;f__;g__,k__Bacteria;p__Tenericutes;c__Mollicutes;o__;f__;g__,k__Bacteria;p__Tenericutes;c__Mollicutes;o__Anaeroplasmatales;f__Anaeroplasmataceae;g__,k__Bacteria;p__Tenericutes;c__Mollicutes;o__RF39;f__;g__,k__Bacteria;p__Tenericutes;c__RF3;o__ML615J-28;f__;g__,k__Bacteria;p__Verrucomicrobia;c__Verrucomicrobiae;o__Verrucomicrobiales;f__Verrucomicrobiaceae;g__Akkermansia
sam-570,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0026,0.0,0.0002
sam-477,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0014,0.0,0.0016
JN-MF2-15520,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0002,0.001,...,0.0,0.0,0.0004,0.0,0.0,0.0,0.0,0.001,0.0006,0.0
JN-MF2-15218,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
JN-TMF2-SD0168,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0008,0.0,...,0.0,0.0002,0.0,0.001,0.0,0.0,0.0,0.0004,0.0,0.0


Merge the GGMP and SGMP datasets; genera missing from one cohort are filled with zero abundance.

In [5]:
genus_table = pd.concat([ggmp_genus_table, sgmp_genus_table], axis=0, sort=False).fillna(0)
print(f'Genus table shape: {genus_table.shape}')
genus_table.head()

Genus table shape: (9010, 361)


SampleID,k__Archaea;p__Euryarchaeota;c__Methanobacteria;o__Methanobacteriales;f__Methanobacteriaceae;g__Methanobrevibacter,k__Archaea;p__Euryarchaeota;c__Methanobacteria;o__Methanobacteriales;f__Methanobacteriaceae;g__Methanosphaera,k__Archaea;p__Euryarchaeota;c__Thermoplasmata;o__E2;f__ Methanomassiliicoccaceae ;g__vadinCA11,k__Bacteria;p__Acidobacteria;c__ Chloracidobacteria ;o__RB41;f__Ellin6075;g__,k__Bacteria;p__Actinobacteria;c__Actinobacteria;o__Actinomycetales;f__Actinomycetaceae;g__,k__Bacteria;p__Actinobacteria;c__Actinobacteria;o__Actinomycetales;f__Actinomycetaceae;g__Actinomyces,k__Bacteria;p__Actinobacteria;c__Actinobacteria;o__Actinomycetales;f__Actinomycetaceae;g__Mobiluncus,k__Bacteria;p__Actinobacteria;c__Actinobacteria;o__Actinomycetales;f__Actinomycetaceae;g__Varibaculum,k__Bacteria;p__Actinobacteria;c__Actinobacteria;o__Actinomycetales;f__Brevibacteriaceae;g__Brevibacterium,k__Bacteria;p__Actinobacteria;c__Actinobacteria;o__Actinomycetales;f__Corynebacteriaceae;g__Corynebacterium,...,k__Bacteria;p__Proteobacteria;c__Betaproteobacteria;o__Rhodocyclales;f__Rhodocyclaceae;g__Methyloversatilis,k__Bacteria;p__Proteobacteria;c__Gammaproteobacteria;o__Aeromonadales;f__Aeromonadaceae;g__Aeromonas,k__Bacteria;p__Proteobacteria;c__Gammaproteobacteria;o__Cardiobacteriales;f__Cardiobacteriaceae;g__Cardiobacterium,k__Bacteria;p__Proteobacteria;c__Gammaproteobacteria;o__Enterobacteriales;f__Enterobacteriaceae;g__Enterobacter,k__Bacteria;p__Proteobacteria;c__Gammaproteobacteria;o__Enterobacteriales;f__Enterobacteriaceae;g__Erwinia,k__Bacteria;p__Proteobacteria;c__Gammaproteobacteria;o__Enterobacteriales;f__Enterobacteriaceae;g__Klebsiella,k__Bacteria;p__Proteobacteria;c__Gammaproteobacteria;o__Enterobacteriales;f__Enterobacteriaceae;g__Kluyvera,k__Bacteria;p__Proteobacteria;c__Gammaproteobacteria;o__Xanthomonadales;f__Xanthomonadaceae;g__Achromobacter,k__Bacteria;p__Tenericutes;c__Mollicutes;o__;f__;g__,k__Bacteria;p__Tenericutes;c__Mollicutes;o__Anaeroplasmatales;f__Anaeroplasmataceae;g__
G440606620,0.001,0.0,0.0,0.0,0.0,0.001,0.0,0.0,0.0014,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
G445224053,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0002,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
G440305296,0.0022,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
G445302608,0.0002,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
G441502283,0.0002,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
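The merge relies on `pd.concat` taking the union of genus columns and on `fillna(0)` treating genera never observed in one cohort as zero abundance. In the real data, 341 GGMP genera and 174 SGMP genera yield 361 merged columns, so 154 genera are shared by both cohorts. A toy illustration with made-up sample IDs and genus names:

```python
import pandas as pd

# Two toy cohorts sharing one genus (g__A), each with one private genus
cohort_1 = pd.DataFrame({'g__A': [0.1, 0.2], 'g__B': [0.3, 0.0]}, index=['x1', 'x2'])
cohort_2 = pd.DataFrame({'g__A': [0.5], 'g__C': [0.4]}, index=['y1'])

# Rows are stacked; columns become the union of genera; gaps are filled with 0
merged = pd.concat([cohort_1, cohort_2], axis=0, sort=False).fillna(0)
print(merged.shape)
print(merged)
```

With `sort=False`, the columns keep their order of first appearance instead of being sorted alphabetically.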


GGMP is a community-based cohort consisting of healthy individuals and individuals with T2DM or other diseases. Only healthy individuals and those with T2DM were included in our analysis. SGMP is a cohort of healthy and T2DM samples, so no sample selection is required.
The healthy and T2DM samples selected from GGMP were combined with two libraries from SGMP (each containing both healthy and T2DM samples).
The sample metadata are summarized as follows.

In [6]:
columns = ['T2DM', 'Health', 'Districts']
samples = pd.concat([ggmp_metadata[columns], sgmp_metadata[columns]], axis=0, sort=False)
print(f'Metadata table: {samples.shape}')
pd.pivot_table(samples.assign(sample_id=ggmp_metadata.index.tolist() + sgmp_metadata.sample_ID.tolist()),
               index='Districts',
               columns='T2DM',
               values='sample_id',
               aggfunc='count')

Metadata table: (3571, 3)


Districts,T2DM=False,T2DM=True
G440104,138,43
G440205,108,50
G440282,163,42
G440305,127,52
G440606,94,21
G440883,220,31
G440981,129,55
G441284,161,29
G441303,138,45
G441424,111,69
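Selecting healthy and T2DM individuals from a community cohort, as described above, amounts to a boolean filter on the metadata. A minimal sketch, assuming the metadata carries boolean `Health` and `T2DM` columns (as the `columns` list in the cell above suggests); the rows are hypothetical, and the project's own `get_disease_and_healthy` may work differently:

```python
import pandas as pd

# Hypothetical community-cohort metadata: healthy, T2DM and "other disease" individuals
metadata = pd.DataFrame({
    'Health': [True, False, False, True],
    'T2DM': [False, True, False, False],
    'Districts': ['D1', 'D1', 'D2', 'D2'],
}, index=['p1', 'p2', 'p3', 'p4'])

# Keep only individuals who are either healthy or have T2DM (p3 is dropped)
selected = metadata[metadata['Health'] | metadata['T2DM']]

# Per-district counts of non-T2DM / T2DM samples, analogous to the pivot table above
counts = pd.crosstab(selected['Districts'], selected['T2DM'])
print(counts)
```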


### Random subsampling and developing sub-models 

Each sub-model was trained on its own random subsample. The sampling pool was the merged dataset of GGMP (T2DM and healthy samples) and SGMP. The training set was drawn without replacement from healthy (100 samples) and T2DM (100 samples) individuals in the sampling pool. The validation set was drawn the same way, except that the subsample size was 50 per class. The test set (400 samples by default in `fit_model`) was drawn from the remaining samples without constraining the proportion of healthy and T2DM samples.

![method of random subsets](images/random_subsets_of_samples.png)
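A toy version of the per-sub-model split: draw equal numbers of healthy and T2DM samples without replacement, split each class into training and validation parts, and draw the test set from whatever remains. `balanced_split` is a hypothetical helper; the sizes and `seed` are illustrative, and, for simplicity, healthy samples are taken as those with `T2DM == False`, whereas `fit_model` below uses the separate `Health` column.

```python
import pandas as pd


def balanced_split(samples, label='T2DM', train_n=3, valid_n=2, test_n=4, seed=0):
    """Return (train, valid, test) index lists for one sub-model."""
    train_idx, valid_idx = [], []
    for is_case in (False, True):
        pool = samples[samples[label] == is_case]
        # draw train_n + valid_n samples of this class without replacement
        drawn = pool.sample(train_n + valid_n, replace=False, random_state=seed)
        # drawn is already in random order, so slicing gives a random train/valid split
        train_idx += list(drawn.index[:train_n])
        valid_idx += list(drawn.index[train_n:])
    # test set: drawn from the remaining samples, class proportions not constrained
    rest = samples.drop(index=train_idx + valid_idx)
    test_idx = list(rest.sample(test_n, replace=False, random_state=seed).index)
    return train_idx, valid_idx, test_idx


# Hypothetical pool: 8 T2DM and 12 healthy samples
toy = pd.DataFrame({'T2DM': [True] * 8 + [False] * 12},
                   index=[f's{i}' for i in range(20)])
train, valid, test = balanced_split(toy)
print(len(train), len(valid), len(test))  # 6 4 4
```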

The function below fits multiple sub-models, saves each model, and saves the feature contributions (SHAP values) for subsequent analyses. Each sub-model is an XGBoost classifier from which SHAP values are computed. Null models are trained on the same subsample with shuffled T2DM/healthy labels.

![xgboost model](images/predictive_model.png)
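Permuting the labels keeps the class balance of a subsample but breaks any link between microbiome profile and label, which is what makes the resulting models "null". One pandas-friendly way to do this while keeping the sample index aligned with the feature table is sketched below; `permuted_labels` is a hypothetical helper, and the toy labels are made up.

```python
import numpy as np
import pandas as pd


def permuted_labels(y, rng):
    """Return a copy of label Series y with its values permuted and its index unchanged."""
    return pd.Series(rng.permutation(y.to_numpy()), index=y.index, name=y.name)


rng = np.random.default_rng(42)
train_y = pd.Series([0, 0, 0, 1, 1, 1], index=list('abcdef'), name='T2DM')

# One permuted label vector per null model
null_labels = [permuted_labels(train_y, rng) for _ in range(100)]
```

Building a new Series avoids shuffling a pandas object in place with `np.random.shuffle`, which relies on positional item assignment on a label-indexed Series.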

In [7]:
def fit_model(tax_profile, samples, outpath, 
              disease='T2DM', train_sub_size=100, valid_sub_size=50,
              test_sub_size=400, meta_n=1000, permutation=100, thread=20):
    """
    Fitting base learners of an ensemble.
    :param tax_profile: pd.DataFrame, table of microbiome taxa profile for classifier training;
    :param samples: pd.DataFrame, table of metadata, from which generate training label;
    :param outpath: str, file path of output directory;
    :param disease: str, column name from metadata, used for training label, default "T2DM";
    :param train_sub_size: int, size of random sampling for training data, default 100;
    :param valid_sub_size: int, size of random sampling for validation data, default 50;
    :param test_sub_size: int, size of random sampling for testing data, default 400;
    :param meta_n: int, number of base learners, default 1000;
    :param permutation: int, number of null models corresponded to each base learner, default 100;
    :param thread: int, number of cores for xgboost, default 20;
    :return: None; models, SHAP values and AUC summaries are written under `outpath`.
    """
    outpath_model = f'{outpath}/model'
    outpath_null = f'{outpath}/null'
    os.makedirs(outpath, exist_ok=True)
    os.makedirs(outpath_model, exist_ok=True)
    os.makedirs(outpath_null, exist_ok=True)
    
    ####################################################################################################
    # setup learner for brewing...
    ####################################################################################################
    learner._default_xgb_params.update({'nthread': thread})
    xgb_learner = learner.XGB(params=learner._default_xgb_params,
                              num_rounds=500,
                              early_stopping_rounds=50)
    
    ####################################################################################################
    # iterate through data buckets, learn and extract signatures
    ####################################################################################################
    # results of fitting models
    bst = dict()
    results = []
    results_null = []

    samples_health = samples[samples['Health']]
    samples_disease = samples[samples[disease]]
    sub_size = train_sub_size + valid_sub_size
    
    for _iter in range(1, meta_n + 1):
        print(f'subsample-{_iter}')
        # prepare random sampling data for base learner
        samples_health_sub = samples_health.sample(sub_size, replace=False)
        train_health_sub = samples_health_sub.sample(train_sub_size, replace=False)
        valid_health_sub = samples_health_sub.loc[samples_health_sub.index.difference(train_health_sub.index)]

        samples_disease_sub = samples_disease.sample(sub_size, replace=False)
        train_disease_sub = samples_disease_sub.sample(train_sub_size, replace=False)
        valid_disease_sub = samples_disease_sub.loc[samples_disease_sub.index.difference(train_disease_sub.index)]

        train_sub = pd.concat([train_health_sub, train_disease_sub])
        valid_sub = pd.concat([valid_health_sub, valid_disease_sub])
        test_sub = samples.loc[samples.index.difference(list(train_sub.index) + list(valid_sub.index))]
        test_sub = test_sub.sample(test_sub_size, replace=False)
        # the order of case/control samples is shuffled by the learner's fit (shuffle=True)

        train_X = tax_profile.loc[train_sub.index]
        train_y = train_sub[disease].astype(int)
        valid_X = tax_profile.loc[valid_sub.index]
        valid_y = valid_sub[disease].astype(int)
        test_X = tax_profile.loc[test_sub.index]
        test_y = test_sub[disease].astype(int)
        
        bucket = brewer.Bucket(train_data=(train_X, train_y),
                               valid_data=(valid_X, valid_y),
                               test_data=(test_X, test_y),
                               meta_data={'District': f'subsample-{_iter}'})
        
        _model = bucket.fit(learner=xgb_learner, shuffle=True)
        
        bst.setdefault(_iter, _model['model'])
        result = pd.DataFrame([bucket.get_best_result(_model, with_meta=True)])
        results.append(result)
        signature_train = bucket.transform(_model, data='train', with_meta=True)
        signature_valid = bucket.transform(_model, data='valid', with_meta=True)
        signature_test = bucket.transform(_model, data='test', with_meta=True)
        
        # write the current model out, closing the file handle afterwards
        with open(f'{outpath_model}/subsample-{_iter}.pkl', 'wb') as fh:
            pickle.dump(bst[_iter], fh)
        # write signature out
        ofile = f'{outpath}/subsample-{_iter}_shap_train.tsv'
        signature_train.to_csv(ofile, sep="\t", index_label='SampleID')
        signature_valid.to_csv(ofile.replace('train', 'valid'), sep="\t", index_label='SampleID')
        signature_test.to_csv(ofile.replace('train', 'test'), sep="\t", index_label='SampleID')
        
        ####################################################################################################
        # shuffle for null models
        ####################################################################################################
        for _shuffle in tqdm(range(1, permutation + 1)):
            # permute the training labels; the sample index stays aligned with train_X
            _train_y = pd.Series(np.random.permutation(train_y.values),
                                 index=train_y.index, name=train_y.name)
            bucket_null = brewer.Bucket(train_data=(train_X, _train_y),
                                        valid_data=(valid_X, valid_y),
                                        test_data=(test_X, test_y),
                                        meta_data={'District': f'subsample-{_iter}',
                                                   'shuffle': _shuffle})
            _model_null = bucket_null.fit(learner=xgb_learner)
            result_null = pd.DataFrame([bucket_null.get_best_result(_model_null, with_meta=True)])
            results_null.append(result_null)
            signature_train_null = bucket_null.transform(_model_null, data='train', with_meta=True)
            signature_valid_null = bucket_null.transform(_model_null, data='valid', with_meta=True)
            signature_test_null = bucket_null.transform(_model_null, data='test', with_meta=True)

            # write null signature out
            ofile_null = f'{outpath_null}/subsample-{_iter}_p_{_shuffle}_shap_train.tsv'
            # signature_train_null.to_csv(ofile_null, sep="\t", index_label='SampleID')
            signature_valid_null.to_csv(ofile_null.replace('train', 'valid'), sep="\t", index_label="SampleID")
            signature_test_null.to_csv(ofile_null.replace('train', 'test'), sep="\t", index_label="SampleID")

    ####################################################################################################
    # calculate signatures and prediction results from the models
    ####################################################################################################
    results = pd.concat(results, axis=0)
    results.to_csv(os.path.join(outpath, 'xgb_auc_results.tsv'),
                   sep="\t",
                   index=False)
    results_null = pd.concat(results_null, axis=0)
    results_null.to_csv(os.path.join(outpath, 'xgb_auc_results_null.tsv'),
                        sep="\t",
                        index=False)

Call the function on the merged dataset. The full analysis uses 1000 random subsamples; to save time and space, this demonstration uses a smaller `meta_n`. The function can also be called on the GGMP or SGMP dataset separately.

In [8]:
fit_model(tax_profile=genus_table, samples=samples, outpath='output/ensemble_of_predictive_models', 
          disease='T2DM', train_sub_size=100, valid_sub_size=50,
          test_sub_size=400, meta_n=10, permutation=100, thread=20)
genus_table.loc[samples.index].to_csv('output/ensemble_of_predictive_models/data_X.tsv',
                                      sep="\t",
                                      index_label='SampleID')
samples['T2DM'].to_csv('output/ensemble_of_predictive_models/data_y.tsv', 
                       sep="\t",
                       index_label="SampleID")

Filter out sub-models whose AUC is not significantly better than the AUCs of their null models.
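The filtering rule can be tried on toy numbers first: for each sub-model, a one-sample Wilcoxon signed-rank test is applied to the differences between its null-model AUCs and its own AUC, with `alternative='less'`. The null AUCs below are simulated around chance level (an assumption for illustration only), and `beats_null` is a hypothetical helper.

```python
import numpy as np
from scipy.stats import wilcoxon

rng = np.random.default_rng(0)
# Hypothetical AUCs of 100 null models (shuffled labels), centred near chance level
null_auc = 0.5 + 0.05 * rng.standard_normal(100)


def beats_null(model_auc, null_auc, cutoff_p=0.05):
    """One-sample Wilcoxon signed-rank test on (null AUC - model AUC).

    alternative='less' asks whether these differences are shifted below zero,
    i.e. whether the real model's AUC exceeds the AUCs of its null models.
    """
    result = wilcoxon(null_auc - model_auc, alternative='less')
    return result.pvalue < cutoff_p, result.pvalue


print(beats_null(0.80, null_auc))  # clearly above its nulls: retained
print(beats_null(0.45, null_auc))  # not above its nulls: filtered out
```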

In [11]:
# load AUC results of base learner and null learners
model_results = pd.read_table('output/ensemble_of_predictive_models/xgb_auc_results.tsv')
null_results = pd.read_table('output/ensemble_of_predictive_models/xgb_auc_results_null.tsv')

model_results.set_index('District', inplace=True)
null_results['District'] = null_results['District'].astype('category')
# restrict and order the District categories to the names of the base learners
null_results['District'] = null_results['District'].cat.set_categories(model_results.index.values)

In [13]:
# One-sample Wilcoxon signed-rank test on (null AUC - model AUC): a base learner performs well
# if its AUC is significantly higher than the AUCs of its null models
cutoff_p = 0.05

_wilcox_valid = {_idx: wilcoxon(_da - model_results.loc[_idx, 'valid_auc'], alternative='less')
                 for _idx, _da in null_results.groupby('District', observed=True).valid_auc}
_wilcox_test = {_idx: wilcoxon(_da - model_results.loc[_idx, 'test_auc'], alternative='less')
                for _idx, _da in null_results.groupby('District', observed=True).test_auc}
_model_valid = {_idx: res for _idx, res in _wilcox_valid.items() if res.pvalue < cutoff_p}
_model_test = {_idx: res for _idx, res in _wilcox_test.items() if res.pvalue < cutoff_p}

# retain base learners with both good performance on valid data and test data
_model = list(set(_model_valid.keys()).intersection(set(_model_test.keys())))
print(f'retained models: {len(_model)}')

retained models: 1000


In this demonstration, all sub-models performed significantly better than their null models on both validation and test data, so all of them were retained.

In [14]:
# save the names of the retained base learners for follow-up analyses
with open('output/ensemble_of_predictive_models/model_retained_idx.pkl', 'wb') as fh:
    pickle.dump(_model, fh)