## 3. Model Training

#### A360 MDK interface

In [1]:
a360ai

<A360 AI Interface for project: Cancer Treatment>

In [2]:
# Get default data repo
DATAREPO_LIST = a360ai.list_datarepos()
DATAREPO = DATAREPO_LIST['name'][0]
DATAREPO

'Cancer Treatment'

In [3]:
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix
from matplotlib import pyplot

In [4]:
import warnings
warnings.filterwarnings('ignore')

### 1. Set Default Data Repo and load data

In [5]:
a360ai.set_default_datarepo(DATAREPO)

In [6]:
a360ai.list_datasets()

Unnamed: 0,base_name,extension,size
1,X.csv,csv,33663549.0
2,X.parquet,parquet,4275639.0
3,X_drift.csv,csv,511237.0
4,X_feature.csv,csv,1736963.0
5,X_test.csv,csv,10048776.0
6,X_test.parquet,parquet,3424298.0
7,X_test_f.csv,csv,517542.0
8,rf_model.pkl,pkl,14530126.0
9,y.csv,csv,13290.0
10,y.parquet,parquet,3312.0


#### Load dataset

In [7]:
X = a360ai.load_dataset("X.parquet")
y = a360ai.load_dataset("y.parquet")

#### Train test split

In [8]:
X_train, X_test, y_train, y_test = train_test_split(X,y,test_size=0.15,random_state=5)

#### Random Forest model

In [9]:
# Random Forest
rf = RandomForestClassifier()
# fit RF to training set
rf.fit(X_train, y_train)

In [10]:
y_pred = rf.predict(X_test)

In [11]:
accuracy = accuracy_score(y_test, y_pred)
accuracy 

0.6733466933867736

In [12]:
matrix = confusion_matrix(y_test, y_pred)
print(matrix)

[[ 60   0   0  18   6   2   3   0   0]
 [  1  33   0   1   0   0  33   0   0]
 [  2   0   2   2   1   0   4   0   0]
 [ 16   0   1  80   0   0  11   0   0]
 [  9   1   3   8  13   0   4   0   0]
 [  2   1   0   1   1  27   8   0   0]
 [  3   9   3   1   1   1 120   0   0]
 [  1   1   0   0   0   0   1   0   0]
 [  1   0   0   2   0   0   0   0   1]]


#### Feature Importance

In [13]:
# Top 25 features
features = pd.Series(rf.feature_importances_, index=X_train.columns)
features.sort_values(ascending=False, inplace=True)
print(features.head(25))

component №2     0.023291
component №3     0.016754
component №18    0.013818
component №4     0.012896
component №16    0.012890
component №19    0.012763
component №25    0.012756
component №7     0.012621
component №14    0.012589
component №12    0.012456
component №35    0.012402
component №11    0.012398
component №5     0.012339
component №9     0.012147
component №13    0.012017
component №22    0.011938
component №33    0.011684
component №27    0.011683
component №1     0.011673
component №6     0.011664
component №21    0.011654
component №34    0.011606
component №23    0.011525
component №36    0.011434
component №29    0.011270
dtype: float64


In [14]:
feature_list = []
for i in range(0, 25):
    feature_list.append(features.index[i])

In [15]:
X_feature = X[feature_list].copy()

In [16]:
X_feature

Unnamed: 0_level_0,component №2,component №3,component №18,component №4,component №16,component №19,component №25,component №7,component №14,component №12,...,component №22,component №33,component №27,component №1,component №6,component №21,component №34,component №23,component №36,component №29
ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
0,-0.052967,-0.009110,-0.019615,-0.048373,0.053959,-0.053679,-0.050326,-0.042512,0.028114,-0.039624,...,0.031583,0.015592,0.171112,0.798709,0.018489,-0.014169,-0.002331,-0.033686,0.015169,-0.157155
1,-0.107192,-0.067519,0.043535,0.096123,-0.013960,-0.006857,-0.003936,0.065608,0.017955,-0.019943,...,0.032290,0.023506,0.012060,0.922718,-0.048336,0.027495,-0.006764,-0.032788,0.018553,-0.041335
2,-0.107192,-0.067519,0.043535,0.096123,-0.013960,-0.006857,-0.003936,0.065608,0.017955,-0.019943,...,0.032290,0.023506,0.012060,0.922718,-0.048336,0.027495,-0.006764,-0.032788,0.018553,-0.041335
3,-0.078922,-0.050858,0.014758,0.058433,-0.035953,0.014738,-0.045953,-0.076216,0.016455,0.132207,...,0.050667,-0.009628,0.020542,0.919961,0.022259,0.014230,0.017680,-0.039442,-0.018075,0.030128
4,-0.027762,0.017160,0.015645,0.046482,-0.004018,0.044382,-0.033890,0.031766,-0.027954,-0.028391,...,-0.021795,0.031928,0.000205,0.934481,-0.080120,0.050684,-0.034162,-0.082065,0.009030,-0.040416
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3316,-0.056060,-0.011623,-0.000163,-0.006953,0.013618,-0.038551,-0.067528,-0.054964,0.003825,-0.020178,...,0.044246,-0.034576,0.018353,0.930693,0.001847,0.002245,-0.025954,0.011243,-0.047815,0.061524
3317,-0.050790,-0.020150,0.002886,-0.025461,0.048647,-0.059347,-0.060133,-0.061264,0.013372,-0.017241,...,0.039759,-0.053346,-0.008706,0.854699,0.016953,-0.019003,-0.049222,0.046656,-0.035728,0.077412
3318,0.027652,0.049176,0.039897,-0.077737,-0.013794,-0.068692,0.110985,0.006026,-0.053569,-0.049573,...,-0.015082,-0.052938,-0.001354,0.809202,-0.137128,-0.020978,0.103359,0.077186,-0.055269,-0.027130
3319,-0.038513,0.008694,0.004224,-0.079274,-0.104432,0.038788,0.064069,0.032354,-0.073494,-0.016747,...,-0.029581,0.028334,0.057448,0.903297,-0.053917,0.069744,0.037401,-0.082039,-0.010206,-0.019661


In [17]:
X_feature.columns=X_feature.columns.str.replace('component №','')

In [18]:
X_feature

Unnamed: 0_level_0,2,3,18,4,16,19,25,7,14,12,...,22,33,27,1,6,21,34,23,36,29
ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
0,-0.052967,-0.009110,-0.019615,-0.048373,0.053959,-0.053679,-0.050326,-0.042512,0.028114,-0.039624,...,0.031583,0.015592,0.171112,0.798709,0.018489,-0.014169,-0.002331,-0.033686,0.015169,-0.157155
1,-0.107192,-0.067519,0.043535,0.096123,-0.013960,-0.006857,-0.003936,0.065608,0.017955,-0.019943,...,0.032290,0.023506,0.012060,0.922718,-0.048336,0.027495,-0.006764,-0.032788,0.018553,-0.041335
2,-0.107192,-0.067519,0.043535,0.096123,-0.013960,-0.006857,-0.003936,0.065608,0.017955,-0.019943,...,0.032290,0.023506,0.012060,0.922718,-0.048336,0.027495,-0.006764,-0.032788,0.018553,-0.041335
3,-0.078922,-0.050858,0.014758,0.058433,-0.035953,0.014738,-0.045953,-0.076216,0.016455,0.132207,...,0.050667,-0.009628,0.020542,0.919961,0.022259,0.014230,0.017680,-0.039442,-0.018075,0.030128
4,-0.027762,0.017160,0.015645,0.046482,-0.004018,0.044382,-0.033890,0.031766,-0.027954,-0.028391,...,-0.021795,0.031928,0.000205,0.934481,-0.080120,0.050684,-0.034162,-0.082065,0.009030,-0.040416
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3316,-0.056060,-0.011623,-0.000163,-0.006953,0.013618,-0.038551,-0.067528,-0.054964,0.003825,-0.020178,...,0.044246,-0.034576,0.018353,0.930693,0.001847,0.002245,-0.025954,0.011243,-0.047815,0.061524
3317,-0.050790,-0.020150,0.002886,-0.025461,0.048647,-0.059347,-0.060133,-0.061264,0.013372,-0.017241,...,0.039759,-0.053346,-0.008706,0.854699,0.016953,-0.019003,-0.049222,0.046656,-0.035728,0.077412
3318,0.027652,0.049176,0.039897,-0.077737,-0.013794,-0.068692,0.110985,0.006026,-0.053569,-0.049573,...,-0.015082,-0.052938,-0.001354,0.809202,-0.137128,-0.020978,0.103359,0.077186,-0.055269,-0.027130
3319,-0.038513,0.008694,0.004224,-0.079274,-0.104432,0.038788,0.064069,0.032354,-0.073494,-0.016747,...,-0.029581,0.028334,0.057448,0.903297,-0.053917,0.069744,0.037401,-0.082039,-0.010206,-0.019661


In [19]:
a360ai.write_dataset(X_feature,"X_feature", overwrite=True)

True

### 2. Model training

#### Random Forest Model- top 25 features

In [20]:
X_train, X_test, y_train, y_test = train_test_split(X_feature,y,test_size=0.15,random_state=5)

#### Use a360ai MDK to track model experiments

In [21]:
model = a360ai.get_or_create_model(model_name=f"cancer_pred")

In [22]:
experiment = model.get_or_create_experiment(
    experiment_name = f"cancer_pred_RF",
    model_flavor = "sklearn",
    enable_drift_monitoring=True,
    train_features=X_train,
    train_target=y_train,
    feature_names=list(X.columns),
    data_exploration_file="/home/jovyan/01_exploratory-data-analysis.ipynb",
    data_preparation_file="/home/jovyan/02_data-preprocessing.ipynb",
    model_training_file="/home/jovyan/03_model-training.ipynb",
)


//-- Experiment Loaded --//
Model Name: cancer_pred
Experiment Name: cancer_pred_RF
Final Run Id: None
Model Flavor: sklearn
Input Signature: 2 float, 3 float, 18 float, 4 float, 16 float, 19 float, 25 float, 7 float, 14 float, 12 float, 35 float, 11 float, 5 float, 9 float, 13 float, 22 float, 33 float, 27 float, 1 float, 6 float, 21 float, 34 float, 23 float, 36 float, 29 float
Output Signature: Class float
Data Exploration File: /home/jovyan/01_exploratory-data-analysis.ipynb
Data Preparation File: /home/jovyan/02_data-preprocessing.ipynb
Model Training File: /home/jovyan/03_model-training.ipynb
Drift Monitoring Enabled: True



In [23]:
experiments = model.list_experiments()
experiments

Unnamed: 0,id,bestRun,model_id,experiment_name,best_run_id,model_flavor,input_signature,output_signature,data_exploration_file,data_preparation_file,model_training_file,baseline,train_shape,model_name,updated_at
0,2f9176c3-7691-49f5-8cdd-39739c805bad,,aba91d17-0a8e-4856-bdbd-d026539fdf09,cancer_pred_RF,,sklearn,"2 float, 3 float, 18 float, 4 float, 16 float,...",Class float,/home/jovyan/01_exploratory-data-analysis.ipynb,/home/jovyan/02_data-preprocessing.ipynb,/home/jovyan/03_model-training.ipynb,"{'2': {'mean': 0.0031845134698064415, 'std': 0...",[25],cancer_pred,2022-06-30 03:34


In [24]:
count = 1
for n in [25,50,75,100,125,150]:
    with experiment.run_experiment() as run:
    
        hyperparams = {
            "n_estimators": n,
            "max_depth": 6            
        }
    
        rf = RandomForestClassifier(**hyperparams)
        rf.fit(X_train, y_train)
        train_score = rf.score(X_train, y_train)
        test_score = rf.score(X_test, y_test)

        metrics = {
            "train_score": train_score,
            "test_score": test_score,
        }

        run.log_metrics(metrics)
        run.log_hyperparameters(hyperparams)
        run.log_model(rf)
        
        print("Run " + str(count) + " Complete!")
        count = count + 1 

Run 1 Complete!
Run 2 Complete!
Run 3 Complete!
Run 4 Complete!
Run 5 Complete!
Run 6 Complete!


In [25]:
runs = experiment.list_runs()
runs

Unnamed: 0,id,dataset,artifact_paths,metric_test_score,metric_train_score,hyperparameter_max_depth,hyperparameter_n_estimators,metadata_run_time
0,8f9e9085-c498-4bad-9a79-f3cdbf6b1b59,2d806210-8b38-48b0-b654-678bb7a13e85,[{'id': 'd86dae37-b829-4cd0-9827-a9cb226dee6e'...,0.521042,0.624734,6,25,0.169557
1,8159562c-0a5f-4a05-be48-789ab3ff6a39,2d806210-8b38-48b0-b654-678bb7a13e85,[{'id': '22e0d7ac-c8a4-4d1a-b16a-d50beeb60488'...,0.523046,0.61056,6,50,0.326773
2,ad79dd41-6967-4fd2-ae40-02d68df521df,2d806210-8b38-48b0-b654-678bb7a13e85,[{'id': 'd3bc8731-dbb8-4fd3-8649-176dc4688e3c'...,0.529058,0.631821,6,75,0.505113
3,54ea4459-9b2c-43dc-a065-7c510e962b60,2d806210-8b38-48b0-b654-678bb7a13e85,[{'id': 'b60c5f0d-411d-4f14-96e1-06ee9d6f02ac'...,0.547094,0.621899,6,100,0.665627
4,9e38b92e-774b-427e-bb40-d65973d8c1a2,2d806210-8b38-48b0-b654-678bb7a13e85,[{'id': 'd1498bf8-7809-4ce7-b0bc-3be080071e37'...,0.547094,0.629341,6,125,0.840809
5,376a3f9a-2595-429b-8036-18ffe4e7f3f6,2d806210-8b38-48b0-b654-678bb7a13e85,[{'id': '0a0eb92d-28a4-44c4-8f36-df68d04cf414'...,0.539078,0.621545,6,150,0.977326


In [26]:
runs.sort_values('metric_test_score',ascending=False)

Unnamed: 0,id,dataset,artifact_paths,metric_test_score,metric_train_score,hyperparameter_max_depth,hyperparameter_n_estimators,metadata_run_time
3,54ea4459-9b2c-43dc-a065-7c510e962b60,2d806210-8b38-48b0-b654-678bb7a13e85,[{'id': 'b60c5f0d-411d-4f14-96e1-06ee9d6f02ac'...,0.547094,0.621899,6,100,0.665627
4,9e38b92e-774b-427e-bb40-d65973d8c1a2,2d806210-8b38-48b0-b654-678bb7a13e85,[{'id': 'd1498bf8-7809-4ce7-b0bc-3be080071e37'...,0.547094,0.629341,6,125,0.840809
5,376a3f9a-2595-429b-8036-18ffe4e7f3f6,2d806210-8b38-48b0-b654-678bb7a13e85,[{'id': '0a0eb92d-28a4-44c4-8f36-df68d04cf414'...,0.539078,0.621545,6,150,0.977326
2,ad79dd41-6967-4fd2-ae40-02d68df521df,2d806210-8b38-48b0-b654-678bb7a13e85,[{'id': 'd3bc8731-dbb8-4fd3-8649-176dc4688e3c'...,0.529058,0.631821,6,75,0.505113
1,8159562c-0a5f-4a05-be48-789ab3ff6a39,2d806210-8b38-48b0-b654-678bb7a13e85,[{'id': '22e0d7ac-c8a4-4d1a-b16a-d50beeb60488'...,0.523046,0.61056,6,50,0.326773
0,8f9e9085-c498-4bad-9a79-f3cdbf6b1b59,2d806210-8b38-48b0-b654-678bb7a13e85,[{'id': 'd86dae37-b829-4cd0-9827-a9cb226dee6e'...,0.521042,0.624734,6,25,0.169557


#### Get the best run from model experiments

In [27]:
best_run_id = runs.sort_values("metric_test_score", ascending=False).id.values[0]
best_run_id

'54ea4459-9b2c-43dc-a065-7c510e962b60'

In [28]:
model.set_final_run(experiment, best_run_id)

#### Export data column/ feature names

In [29]:
feature_name = X_feature.columns.values.tolist()

In [30]:
import json
with open('feature_name.json', 'w') as f:
    json.dump(feature_name, f)

#### Prep. test set for top 25 features

In [32]:
X_test_f = X_test[feature_name].copy()

In [33]:
X_test_f

Unnamed: 0_level_0,2,3,18,4,16,19,25,7,14,12,...,22,33,27,1,6,21,34,23,36,29
ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
1081,-0.038644,-0.007942,-0.003117,-0.021259,0.040945,-0.034344,-0.032219,-0.046097,0.001527,-0.036388,...,0.029083,0.013886,-0.008478,0.968383,0.008494,-0.010140,-0.023208,0.018020,-0.026078,0.055434
814,-0.047010,-0.032945,0.001002,0.042494,-0.049584,0.013782,-0.002617,-0.033661,0.011207,-0.044795,...,0.049164,0.041273,0.084019,0.925470,-0.004836,0.001740,-0.008649,-0.043752,0.039161,-0.027539
175,-0.146258,-0.196055,0.024924,0.338671,-0.001363,-0.059053,-0.005375,0.264207,-0.010807,0.002508,...,0.008221,0.033980,-0.013623,0.781727,-0.183682,-0.005067,0.011660,0.018533,0.014183,0.010950
1517,-0.117408,-0.124794,0.041319,-0.065290,0.006479,0.013326,-0.080263,0.407026,-0.003694,0.018367,...,0.023630,0.013892,-0.042821,0.669175,0.282862,0.021972,-0.024505,-0.025697,0.012304,-0.010314
1340,-0.018787,0.003293,0.034071,0.012444,0.009850,0.011467,0.011464,-0.047272,0.003134,-0.035320,...,0.035362,0.150737,0.066636,0.824034,0.041278,-0.024611,-0.069028,0.005105,0.053895,-0.102668
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2914,-0.066129,-0.042282,-0.004838,-0.191736,-0.031909,-0.002117,0.054479,0.075838,-0.032985,-0.023819,...,0.006690,0.039938,-0.040117,0.891856,0.024237,-0.006611,-0.049375,-0.013790,-0.018335,0.012907
1938,-0.108526,-0.056962,-0.016646,-0.125269,0.075754,-0.049787,-0.106561,0.013493,0.051718,0.001041,...,0.044867,-0.040556,0.004243,0.880509,0.046299,-0.029294,-0.043435,-0.024885,-0.039994,-0.000940
1155,-0.071574,-0.024697,0.036130,0.029547,-0.066270,0.024862,0.042050,-0.060431,-0.004445,-0.062243,...,0.010360,-0.019928,-0.033784,0.952807,0.054482,0.036576,-0.003301,-0.011613,0.007776,-0.031828
523,0.020749,0.125544,-0.010074,-0.268142,0.004517,0.013216,0.044019,0.083037,-0.020768,0.033174,...,-0.051542,0.022461,-0.004672,0.831405,0.120840,0.020137,0.017678,-0.018940,0.040894,0.003091


In [34]:
a360ai.write_dataset(X_test_f,"X_test_f", overwrite=True)

True