# Import packages and functions

In [1]:
import sys
# Force the notebook to look for files in the upper level directory.
# The membership check keeps the cell idempotent: re-running it does not
# stack duplicate '../' entries onto sys.path.
if '../' not in sys.path:
    sys.path.insert(1, '../')

In [2]:
import os
import pprint
import time

import pandas as pd
import xgboost as xgb

from model.model_building import load_data, tune_hyperparam, eval_xgb_model

# Set up constants

In [3]:
# Path (relative to this notebook) of the processed dataset spreadsheet.
PROCESSED_PATH = (
    "../data/processed/"
    "IMT_Classification_Dataset_Processed_v3.xlsx"
)
# Seed handed to the tuning and evaluation helpers.
RANDOM_SEED = 31_415_926

# Read in the processed data

In [4]:
# Load the processed spreadsheet into a DataFrame.
# NOTE(review): the recorded output shows an "Unnamed: 0" column, which looks
# like a row index that was written into the file -- confirm whether
# load_data relies on it before switching to index_col=0.
df = pd.read_excel(io=PROCESSED_PATH)
# Bare last expression so the notebook renders the frame.
df

Unnamed: 0,Compound,Label,struct_file_path,range MendeleevNumber,avg_dev MendeleevNumber,range AtomicWeight,mean AtomicWeight,avg_dev AtomicWeight,range MeltingT,mean MeltingT,...,avg_mx_dists,max_xx_dists,min_xx_dists,avg_xx_dists,v_x,iv,iv_p1,est_hubbard_u,est_charge_trans,volumn_per_sites
0,SrRuO3,0,../data/Structures/Metals/SrRuO3_75561.cif,79,26.400000,85.07060,47.337640,37.605888,2552.20,764.280000,...,1.983938,3.579973,2.760023,2.947568,23.738172,45.000000,59.00000,10.330721,8.527722,12.089967
1,OsO2,0,../data/Structures/Metals/OsO2_15070.cif,30,13.333333,174.23060,74.076267,77.435822,3251.20,1138.533333,...,1.983671,2.805520,2.442651,2.684563,25.269881,41.000000,55.00000,9.953087,13.687053,10.747095
2,SrLaCuO4,0,../data/Structures/Metals/LaSrCuO4_10252.cif,79,29.346939,122.90607,50.581296,39.522167,1302.97,545.710000,...,2.062565,3.421841,2.662257,2.966018,23.905465,36.841000,57.38000,18.524833,-7.676852,13.436088
3,SrCrO3,0,../data/Structures/Metals/SrCrO3_245834.cif,79,28.080000,71.62060,37.522860,25.828152,2125.20,678.880000,...,1.909900,2.701006,2.701006,2.701006,24.337085,49.160000,69.46000,16.530261,6.586603,11.146843
4,CrO2,0,../data/Structures/Metals/CrO2_202836.cif,38,16.888889,35.99670,27.998300,15.998533,2125.20,763.200000,...,1.901255,2.688819,2.471404,2.616347,26.561430,49.160000,69.46000,16.126339,8.219417,9.504907
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
223,YbFe4(CuO4)3,2,../data/Structures/MIT_materials/HighT/YbCu3Fe...,48,14.700000,157.05460,38.953240,27.544608,1756.20,653.345500,...,2.365615,2.924013,2.552849,2.744642,24.360992,36.841000,57.38000,16.597335,7.361181,9.750947
224,NiSeS,2,../data/Structures/MIT_materials/HighT/NiS(2-x...,28,12.222222,46.89500,56.572800,16.338533,1339.64,870.120000,...,2.424039,3.287898,2.376963,3.060164,9.778249,18.168838,35.18700,13.516153,8.891048,16.385810
225,Ti2O3,2,../data/Structures/MIT_materials/HighT/Ti2O3_H...,44,21.120000,31.86760,28.746440,15.296448,1886.20,809.280000,...,2.048209,2.900002,2.771288,2.844355,24.648770,27.491710,43.26717,11.068473,16.169806,10.490597
226,Ca1.2La2.8Mn4O12,2,../data/Structures/MIT_materials/HighT/La0.7Ca...,80,26.592000,122.90607,42.438695,32.010437,1464.20,570.600000,...,1.962424,3.516261,2.747250,2.906344,22.934073,38.930600,57.57000,14.915598,-7.594869,11.576092


# Tune the hyperparameters

In [5]:
# Run the hyperparameter search once per one-vs-rest target and collect the
# result of each search under its target name.
best_params = dict(
    (target, tune_hyperparam(df, target, RANDOM_SEED))
    for target in ("Metal", "Insulator", "MIT")
)
# Pause the execution for 1 second so the summary below is printed cleanly
# after the search's own log output.
time.sleep(1)
pprint.pprint(best_params)


Tuning for Metal vs. non-Metal binary classifier
Fitting 5 folds for each of 432 candidates, totalling 2160 fits


[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=-1)]: Done  34 tasks      | elapsed:    1.7s
[Parallel(n_jobs=-1)]: Done 304 tasks      | elapsed:    6.2s
[Parallel(n_jobs=-1)]: Done 1068 tasks      | elapsed:   15.7s
[Parallel(n_jobs=-1)]: Done 2145 out of 2160 | elapsed:   27.2s remaining:    0.2s
[Parallel(n_jobs=-1)]: Done 2160 out of 2160 | elapsed:   27.3s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.



Tuning for Insulator vs. non-Insulator binary classifier
Fitting 5 folds for each of 432 candidates, totalling 2160 fits


[Parallel(n_jobs=-1)]: Done  56 tasks      | elapsed:    0.8s
[Parallel(n_jobs=-1)]: Done 656 tasks      | elapsed:    9.1s
[Parallel(n_jobs=-1)]: Done 1656 tasks      | elapsed:   23.1s
[Parallel(n_jobs=-1)]: Done 2160 out of 2160 | elapsed:   28.9s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.



Tuning for MIT vs. non-MIT binary classifier
Fitting 5 folds for each of 432 candidates, totalling 2160 fits


[Parallel(n_jobs=-1)]: Done  56 tasks      | elapsed:    0.8s
[Parallel(n_jobs=-1)]: Done 656 tasks      | elapsed:    9.2s
[Parallel(n_jobs=-1)]: Done 1656 tasks      | elapsed:   24.1s
[Parallel(n_jobs=-1)]: Done 2160 out of 2160 | elapsed:   30.2s finished


{'Insulator': {'base_score': 0.3,
               'learning_rate': 1.0,
               'max_depth': 2,
               'n_estimators': 80,
               'scale_pos_weight': 1.6823529411764706},
 'MIT': {'base_score': 0.7,
         'learning_rate': 0.1,
         'max_depth': 2,
         'n_estimators': 150,
         'scale_pos_weight': 1.0727272727272728},
 'Metal': {'base_score': 0.7,
           'learning_rate': 1.0,
           'max_depth': 3,
           'n_estimators': 20,
           'scale_pos_weight': 5.909090909090909}}


# Evaluate the tuned models with 10-fold stratified cross-validation

In [6]:
# Evaluate every one-vs-rest classifier with its tuned parameters.
# With eval_method="robust" the recorded output reports the median and IQR
# of each metric across the folds.
for choice in ("Metal", "Insulator", "MIT"):
    eval_xgb_model(
        df,
        choice,
        best_params,
        RANDOM_SEED,
        eval_method="robust",
    )


Evaluating the Metal vs. non-Metal binary classifier
For 10 folds
Median precision_macro: 0.77 w/ IQR: 0.13
Median recall_macro: 0.70 w/ IQR: 0.17
Median roc_auc: 0.92 w/ IQR: 0.15
Median f1_macro: 0.73 w/ IQR: 0.14

Evaluating the Insulator vs. non-Insulator binary classifier
For 10 folds
Median precision_macro: 0.88 w/ IQR: 0.06
Median recall_macro: 0.87 w/ IQR: 0.08
Median roc_auc: 0.94 w/ IQR: 0.07
Median f1_macro: 0.86 w/ IQR: 0.06

Evaluating the MIT vs. non-MIT binary classifier
For 10 folds
Median precision_macro: 0.92 w/ IQR: 0.06
Median recall_macro: 0.91 w/ IQR: 0.07
Median roc_auc: 0.96 w/ IQR: 0.06
Median f1_macro: 0.91 w/ IQR: 0.08


# Train on the entire dataset and save the models

In [7]:
# Refit each one-vs-rest classifier on the full dataset with its tuned
# hyperparameters and write it to disk as <choice>.model (lower-cased).
save_dir = "../model/saved_models/new_models"
# Create the output directory up front: save_model fails if the parent
# directory is missing, and failing before the fits is cheaper than
# failing after them.
os.makedirs(save_dir, exist_ok=True)
for choice in ["Metal", "Insulator", "MIT"]:
    X, y = load_data(df, choice)
    xgb_tuned_model = xgb.XGBClassifier(**best_params[choice])
    xgb_tuned_model.fit(X, y)
    # Same path string as before: "../model/saved_models/new_models/<choice>.model"
    xgb_tuned_model.save_model("{}/{}.model".format(save_dir, choice.lower()))