In [23]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style('white')
sns.color_palette("muted")
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.model_selection import GridSearchCV
from sklearn.svm import LinearSVC
from sklearn.svm import SVC
from sklearn.metrics import roc_curve, auc
from sklearn.pipeline import make_pipeline
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning) 
warnings.filterwarnings('ignore')
import pickle as pkl
from collections import defaultdict
#%matplotlib inline

In [24]:
# number of variables to be excluded for logistic regression
# these must be the first two variables in the feature matrix

# global inputs
logit_num = 4
model_names = ['l1','l2']

'''
Model Building
'''

# Pipeline dictionary
pipelines = {
    'l1' : make_pipeline(StandardScaler(), LogisticRegression( penalty = 'l1', random_state=125)),
    'l2' : make_pipeline(StandardScaler(), LogisticRegression( penalty = 'l2', random_state=125)),
    'rf' : make_pipeline(StandardScaler(), RandomForestClassifier(random_state=125)),
    'gb' : make_pipeline(StandardScaler(), GradientBoostingClassifier(random_state=125)),
    'linsvc' : make_pipeline(StandardScaler(), SVC(random_state=125,probability=True)),
    'rbfsvc' : make_pipeline(StandardScaler(), SVC(random_state=125,probability=True))
}

# Logistic Regression hyperparameters
l1_hyperparameters = {
    'logisticregression__C' : np.linspace(1e-2, 1e1, 500)
}

l2_hyperparameters = {
    'logisticregression__C' : np.linspace(1e-2, 1e1, 500)
}

# Random Forest hyperparameters
rf_hyperparameters = {
    'randomforestclassifier__n_estimators': [100, 300, 500],
    'randomforestclassifier__max_features': ['auto', 'sqrt', 0.33],
    'randomforestclassifier__max_depth': [1, 2, 3, 4, 5]
}

# Boosted Tree hyperparameters
gb_hyperparameters = {
    'gradientboostingclassifier__n_estimators': [100, 300, 500],
    'gradientboostingclassifier__learning_rate': [0.01, 0.1, 0.5, 1],
    'gradientboostingclassifier__max_depth': [1, 2, 3, 4, 5]
}

linsvc_hyperparameters = {
    'svc__C' : [1e-5, 1e-3, 1e-1, 1e1],
    'svc__kernel' : ['linear']
}

rbfsvc_hyperparameters = {
    'svc__C': [1e-5, 1e-3, 1e-1, 1e1],
    'svc__gamma' : [1e-5, 1e-3, 1e-1, 1e1],
    'svc__kernel' : ['rbf']
}
# Create hyperparameters dictionary
hyperparameters = {
    'l1' : l1_hyperparameters, 
    'l2' : l2_hyperparameters,
    'rf' : rf_hyperparameters,
    'gb' : gb_hyperparameters,
    'linsvc' : linsvc_hyperparameters,
    'rbfsvc' : rbfsvc_hyperparameters
}
# Create data pointing dictionary
datapointers = {
    'l1' : 'logistic',
    'l2' : 'logistic',
    'rf' : 'not_logistic',
    'gb' : 'not_logistic',
    'linsvc' : 'not logistic',
    'rbfsvc' : 'not logistic'
}

def model_scoring_auc(X_in, y_in, model, datapointer):
    if datapointer == 'logistic':
        pred = model.predict_proba(X_in[:,logit_num:])
    else:
        pred = model.predict_proba(X_in)
    # Get just the prediction for the positive class (1)
    pred = [p[1] for p in pred]
    # Calculate ROC curve
    fpr, tpr, thresholds = roc_curve(y_in, pred)
    # Calculate AUROC
    return auc(fpr, tpr)


def model_fitting(X, y, logit_num, model_names, pipelines, hyperparameters, datapointers, randstate):
    # Create empty dictionary called fitted_models
    fitted_models = {}
    fitted_scores = {}
    
    # split data for CV testing
    
    #this works:
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=randstate,stratify=X[:,9])




    # Loop through model pipelines, tuning each one and saving it to fitted_models
    for name in model_names:
        # Create cross-validation object from pipeline and hyperparameters
        model = GridSearchCV(pipelines[name], hyperparameters[name], scoring = 'neg_log_loss', cv=10, refit=True)

        # Fit model on X_train, y_train
        if datapointers[name] == 'logistic':
            model.fit(X_train[:,logit_num:], y_train)  
        else:
            model.fit(X_train, y_train)
        # Store model in fitted_models[name] 
        fitted_models[name] = model
        
        # store scores in fitted_scores[name]
        train_score = model_scoring_auc(X_train, y_train, model, datapointers[name])
        test_score = model_scoring_auc(X_test, y_test, model, datapointers[name])
        fitted_scores[name] = [train_score,test_score]
            
    return fitted_models, fitted_scores

        


In [25]:

df = pd.read_csv('../data/fitting_data.csv')
df.drop(['Unnamed: 0'],axis=1,inplace=True)
df = df.fillna(0)
y = df.pop('music').values
stratification_columns = df.pop('stratification_column').values
X = df.values


models_iterate = {}
scores_iterate = {}
for i in range(1000):
    models_iterate[i], scores_iterate[i] = model_fitting(X,y,logit_num,model_names,pipelines,hyperparameters,datapointers,i+10000)
    if i%10:
        print(scores_iterate[i])





{'l1': [0.809831570701136, 0.6897435897435898], 'l2': [0.8093093093093092, 0.6917378917378917]}
{'l1': [0.810349305285028, 0.7097750865051903], 'l2': [0.8085962861965977, 0.7083333333333333]}
{'l1': [0.8055555555555554, 0.7102102102102102], 'l2': [0.8046548285401789, 0.7141141141141142]}
{'l1': [0.8030483870967743, 0.7387706855791962], 'l2': [0.8011451612903226, 0.7375886524822695]}
{'l1': [0.7845653992150807, 0.7816816816816817], 'l2': [0.7829087048832273, 0.7777777777777778]}
{'l1': [0.775370654396728, 0.7996794871794872], 'l2': [0.7755304192229038, 0.8086538461538462]}
{'l1': [0.79622776262823, 0.7465397923875432], 'l2': [0.7946045968056096, 0.745674740484429]}
{'l1': [0.8097510914185183, 0.7166952544311036], 'l2': [0.807535674724702, 0.7184105202973128]}
{'l1': [0.8030641299424061, 0.7198250728862974], 'l2': [0.7994725943182555, 0.7355685131195335]}
{'l1': [0.8071354369562395, 0.7151095732410611], 'l2': [0.806307622386703, 0.7217416378316032]}
{'l1': [0.8077262044653349, 0.72279202

{'l1': [0.7968169674855021, 0.7384219554030875], 'l2': [0.7952531439369258, 0.7478559176672386]}
{'l1': [0.7801828668807761, 0.7990362811791383], 'l2': [0.7796093596381988, 0.8001700680272108]}
{'l1': [0.7991301231511044, 0.7402801600914809], 'l2': [0.7978595165178862, 0.7359919954259577]}
{'l1': [0.7989926739926739, 0.744034090909091], 'l2': [0.7975536368393511, 0.7474431818181818]}
{'l1': [0.779508833236265, 0.7970845481049563], 'l2': [0.7772115446838801, 0.796793002915452]}
{'l1': [0.8087857784623284, 0.704406364749082], 'l2': [0.8073418046463869, 0.7114443084455324]}
{'l1': [0.8076847290640394, 0.7235427277872101], 'l2': [0.8045977011494252, 0.7297679683078665]}
{'l1': [0.7938536585365854, 0.7677956371986224], 'l2': [0.7928455284552844, 0.7769804822043628]}
{'l1': [0.7905269639616281, 0.7779710144927536], 'l2': [0.7890037593984962, 0.769855072463768]}
{'l1': [0.8053064516129032, 0.7192671394799053], 'l2': [0.8051129032258065, 0.7213356973995272]}
{'l1': [0.8023530182866881, 0.69812

{'l1': [0.770546617561543, 0.8424295774647887], 'l2': [0.7690605414486013, 0.8365610328638496]}
{'l1': [0.7787267580261213, 0.8036036036036036], 'l2': [0.7785176606832658, 0.8063063063063063]}
{'l1': [0.7838500784929356, 0.8045454545454545], 'l2': [0.7841444270015698, 0.8042613636363636]}
{'l1': [0.7720864661654137, 0.8295652173913043], 'l2': [0.7708873476795437, 0.8359420289855073]}
{'l1': [0.798749836793315, 0.7290598290598291], 'l2': [0.7966934325629977, 0.7336182336182336]}
{'l1': [0.7876548469554958, 0.7876984126984126], 'l2': [0.7869666382644033, 0.7820294784580499]}
{'l1': [0.8055070212903643, 0.726530612244898], 'l2': [0.8040833495114216, 0.7276967930029156]}
{'l1': [0.7857535753575358, 0.7769679300291545], 'l2': [0.7856888630039475, 0.7839650145772595]}
{'l1': [0.7753440589261485, 0.8166079812206574], 'l2': [0.7747302448794986, 0.818075117370892]}
{'l1': [0.7679025216654721, 0.8250428816466553], 'l2': [0.7687007232683912, 0.8201829616923957]}
{'l1': [0.801778989741592, 0.71237

{'l1': [0.7999935073367095, 0.7280853517877739], 'l2': [0.799311777691209, 0.7309688581314878]}
{'l1': [0.7838728435743362, 0.783744131455399], 'l2': [0.782096013439297, 0.7943075117370892]}
{'l1': [0.7989020655502704, 0.7535734705546026], 'l2': [0.7984948198344954, 0.757004002287021]}
{'l1': [0.7933087426389698, 0.7536443148688047], 'l2': [0.792499838219116, 0.7574344023323615]}
{'l1': [0.779479674796748, 0.7990815154994259], 'l2': [0.7766341463414634, 0.8053960964408726]}
{'l1': [0.7928377592556698, 0.7608568075117371], 'l2': [0.7921270272016541, 0.7634976525821596]}
{'l1': [0.7864977674238012, 0.772594752186589], 'l2': [0.7856241506503592, 0.7673469387755103]}
{'l1': [0.7946307423919363, 0.7573356807511737], 'l2': [0.7931123602765393, 0.7632042253521126]}
{'l1': [0.7926459067763415, 0.764102564102564], 'l2': [0.7916993080036558, 0.7646723646723647]}
{'l1': [0.7916693976535361, 0.7712585034013605], 'l2': [0.7900144196106704, 0.7743764172335601]}
{'l1': [0.7783816425120772, 0.79344729

{'l1': [0.7910829478073891, 0.7632933104631218], 'l2': [0.7897634716882778, 0.7632933104631218]}
{'l1': [0.7987880367498534, 0.7598627787307033], 'l2': [0.7980712842900893, 0.7578616352201258]}
{'l1': [0.8001847290640395, 0.7437681159420291], 'l2': [0.7978837179154785, 0.7510144927536232]}
{'l1': [0.7866494743106527, 0.7816384180790961], 'l2': [0.7854922965020169, 0.7742937853107346]}
{'l1': [0.7891182963251526, 0.7546136101499423], 'l2': [0.786845864173484, 0.7560553633217992]}
{'l1': [0.7772580645161291, 0.819887706855792], 'l2': [0.7755000000000001, 0.8277186761229315]}
{'l1': [0.7691048742546022, 0.8162318840579711], 'l2': [0.7699474980554837, 0.8252173913043478]}
{'l1': [0.8009463313456056, 0.7285507246376811], 'l2': [0.8009139227378792, 0.7297101449275362]}
{'l1': [0.8001726721834886, 0.7367066895368782], 'l2': [0.795709259138594, 0.7452830188679245]}
{'l1': [0.7932258064516129, 0.7491134751773051], 'l2': [0.7918709677419354, 0.7624113475177305]}
{'l1': [0.7652158413274566, 0.844

{'l1': [0.787366318773354, 0.7899047051816557], 'l2': [0.7872535755701585, 0.7890113162596784]}
{'l1': [0.7869134042001555, 0.7727536231884057], 'l2': [0.7858115115374644, 0.7779710144927537]}
{'l1': [0.813479674796748, 0.7069460390355913], 'l2': [0.8121463414634146, 0.7069460390355913]}
{'l1': [0.7795238095238095, 0.7829654782116582], 'l2': [0.7768144499178983, 0.7976796830786643]}
{'l1': [0.7637736582836401, 0.8015942028985508], 'l2': [0.7642111744879441, 0.8104347826086956]}
{'l1': [0.7902220490845345, 0.7745098039215687], 'l2': [0.7883067134138424, 0.7739331026528259]}
{'l1': [0.8033003192806412, 0.7358490566037736], 'l2': [0.8000749332117026, 0.7375643224699828]}
{'l1': [0.7767658963311137, 0.7880341880341881], 'l2': [0.7746605300953127, 0.7988603988603988]}
{'l1': [0.8041138211382115, 0.7408151549942595], 'l2': [0.8029430894308943, 0.7431113662456946]}
{'l1': [0.8030946406742585, 0.7456456456456456], 'l2': [0.801421861931416, 0.7492492492492494]}
{'l1': [0.7963925954997384, 0.690

{'l1': [0.7843946159321815, 0.7623906705539358], 'l2': [0.784297547401799, 0.7839650145772595]}
{'l1': [0.7705088942464325, 0.8194682675814752], 'l2': [0.7705740535609567, 0.8197541452258433]}
{'l1': [0.7806882875967505, 0.7920886075949366], 'l2': [0.7780336467728524, 0.8006329113924051]}
{'l1': [0.8061453058044411, 0.6965109573241062], 'l2': [0.8061777691208935, 0.7112168396770473]}
{'l1': [0.8082419354838709, 0.7307919621749408], 'l2': [0.8062419354838709, 0.7304964539007092]}
{'l1': [0.7985910055973748, 0.7394894894894894], 'l2': [0.7984944991314418, 0.7475975975975977]}
{'l1': [0.7903912957986498, 0.7396541950113378], 'l2': [0.789195123549846, 0.750141723356009]}
{'l1': [0.7941436717572261, 0.7467403628117915], 'l2': [0.7928327980599068, 0.7464569160997732]}
{'l1': [0.7971545242416386, 0.7501449275362319], 'l2': [0.7962470832253048, 0.7489855072463769]}
{'l1': [0.770417393551722, 0.8129401408450704], 'l2': [0.7681236673773988, 0.8108861502347418]}
{'l1': [0.7928780487804878, 0.7680

{'l1': [0.7982813206693804, 0.7362089201877934], 'l2': [0.7969244685662596, 0.7420774647887324]}
{'l1': [0.7876672856502529, 0.8070669168230143], 'l2': [0.7850419414740347, 0.8055034396497811]}
{'l1': [0.780851614573222, 0.7967930029154519], 'l2': [0.7797191483854269, 0.7988338192419825]}
{'l1': [0.7812520325203252, 0.7987944890929966], 'l2': [0.7794634146341463, 0.8010907003444316]}
{'l1': [0.8014540307533938, 0.727027027027027], 'l2': [0.7992665508589076, 0.7303303303303302]}
{'l1': [0.8136937467294612, 0.7056818181818182], 'l2': [0.8135956305599162, 0.7090909090909091]}
{'l1': [0.7830194784184301, 0.7731778425655977], 'l2': [0.7824694234129295, 0.7749271137026239]}
{'l1': [0.7961884423736492, 0.751603498542274], 'l2': [0.7959619491360901, 0.7533527696793002]}
{'l1': [0.8137425207488902, 0.7040540540540542], 'l2': [0.8124557678697806, 0.6947447447447448]}
{'l1': [0.7955447154471544, 0.7692307692307693], 'l2': [0.7920325203252033, 0.7695177956371986]}
{'l1': [0.7844341457091003, 0.786

{'l1': [0.7747370471367355, 0.8056516724336794], 'l2': [0.7726593948837814, 0.8085351787773933]}
{'l1': [0.7844065040650405, 0.7850172215843857], 'l2': [0.7837723577235772, 0.7841561423650976]}
{'l1': [0.7761840149899852, 0.7929870892018779], 'l2': [0.774536408864767, 0.7947476525821596]}
{'l1': [0.7769770159719517, 0.8203575547866206], 'l2': [0.7763602129593559, 0.8243944636678201]}
{'l1': [0.7694379161770466, 0.8284900284900285], 'l2': [0.7661574618096358, 0.8518518518518519]}
{'l1': [0.7703743074346089, 0.8396366885050626], 'l2': [0.7695367864965856, 0.8351697438951756]}
{'l1': [0.7933333333333334, 0.7705149971703452], 'l2': [0.7908866995073892, 0.7826825127334466]}
{'l1': [0.7941014154005973, 0.755767012687428], 'l2': [0.7931762108817035, 0.7649942329873126]}
{'l1': [0.7830856334041048, 0.803003003003003], 'l2': [0.7822974972656502, 0.8004504504504505]}
{'l1': [0.7999667774086379, 0.7440577249575552], 'l2': [0.8002657807308969, 0.7457555178268251]}
{'l1': [0.7962364538451494, 0.748

{'l1': [0.7880727762803235, 0.7587209302325582], 'l2': [0.7856340649467335, 0.7584149326805385]}
{'l1': [0.7971921182266009, 0.7572156196943972], 'l2': [0.795254515599343, 0.752405206564799]}
{'l1': [0.7829521225043613, 0.7824237089201878], 'l2': [0.7814337403889642, 0.792106807511737]}
{'l1': [0.7939376369024609, 0.7572960095294818], 'l2': [0.7921337456513335, 0.7602739726027397]}
{'l1': [0.7932348111658456, 0.7724957555178268], 'l2': [0.7911494252873563, 0.7722127900396152]}
{'l1': [0.8032308980801444, 0.732876712328767], 'l2': [0.8020068290168793, 0.7260273972602739]}
{'l1': [0.8096910569105691, 0.7151262916188289], 'l2': [0.8083252032520325, 0.7165614236509759]}
{'l1': [0.7979480737018425, 0.7364502680166766], 'l2': [0.7971749774513593, 0.7370458606313282]}
{'l1': [0.7938548387096774, 0.7461583924349883], 'l2': [0.7940483870967743, 0.7458628841607564]}
{'l1': [0.8191671512567034, 0.664612676056338], 'l2': [0.8192317632616141, 0.664612676056338]}
{'l1': [0.7890804597701151, 0.788341

{'l1': [0.7901970443349754, 0.7604697226938314], 'l2': [0.7875041050903121, 0.7737691001697793]}
{'l1': [0.793274193548387, 0.750886524822695], 'l2': [0.7907903225806452, 0.7621158392434988]}
{'l1': [0.785405872193437, 0.7683544303797468], 'l2': [0.7835188383547623, 0.7772151898734178]}
{'l1': [0.7863036303630363, 0.7889212827988339], 'l2': [0.7835533553355335, 0.79067055393586]}
{'l1': [0.7806568873447854, 0.7966966966966966], 'l2': [0.7765392781316349, 0.7963963963963965]}
{'l1': [0.7874722548635592, 0.7726495726495727], 'l2': [0.7851873612743179, 0.7840455840455841]}
{'l1': [0.7844529311920617, 0.7715099715099716], 'l2': [0.7838327457892674, 0.7780626780626781]}
{'l1': [0.8031290322580644, 0.7182328605200946], 'l2': [0.8016774193548387, 0.717937352245863]}
{'l1': [0.8075473553123157, 0.719671201814059], 'l2': [0.806891918463656, 0.7185374149659863]}
{'l1': [0.781711058884972, 0.7945868945868946], 'l2': [0.7805686120903512, 0.7962962962962963]}
{'l1': [0.7950397980974568, 0.727551020

In [26]:
import pickle as pkl
with open('../data/1000_models.pkl', 'wb') as picklefile:
    pkl.dump(models_iterate, picklefile)
with open('../data/1000_models_scores.pkl', 'wb') as picklefile:
    pkl.dump(scores_iterate, picklefile)