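# TensorFlow Decision Forests: predicting essential proteins

Trains a TensorFlow Decision Forests random forest to predict `is_essential` from sequence-derived and protein-protein interaction (PPI) network features of yeast (*S. cerevisiae*), worm (*C. elegans*) and fly (*Drosophila*) proteins. The training split is rebalanced with SMOTEENN; the test split is left untouched.

Reference: [TensorFlow Decision Forests documentation](https://www.tensorflow.org/decision_forests?hl=pt-br)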

In [6]:
# Silence warnings before the heavy TF / TF-DF imports so that any warnings
# they emit at import time are hidden as well.
# NOTE(review): a blanket "ignore" also hides useful warnings (e.g. pandas or
# imblearn deprecations); consider filtering by category/module instead.
import warnings
warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd

import keras
import tensorflow_decision_forests as tfdf

from imblearn.combine import SMOTEENN

# Not used in the cells visible here; kept in case later cells rely on them.
from IPython.core.magic import register_line_magic
from IPython.display import Javascript

In [7]:
# Report which TensorFlow Decision Forests release is installed.
print(f"Found TensorFlow Decision Forests v{tfdf.__version__}")

Found TensorFlow Decision Forests v1.2.0


In [8]:
# Keras metrics attached to the model at compile time and reported by
# `evaluate`: confusion-matrix counts first, then threshold- and
# ranking-based scores. Each entry is (metric class, display name).
_METRIC_SPECS = (
    (keras.metrics.TruePositives, "tp"),
    (keras.metrics.FalsePositives, "fp"),
    (keras.metrics.TrueNegatives, "tn"),
    (keras.metrics.FalseNegatives, "fn"),
    (keras.metrics.BinaryAccuracy, "accuracy"),
    (keras.metrics.Precision, "precision"),
    (keras.metrics.Recall, "recall"),
    (keras.metrics.AUC, "auc"),
)
METRICS = [metric_cls(name=metric_name) for metric_cls, metric_name in _METRIC_SPECS]

# List the model classes available in this TF-DF build (shown as cell output).
tfdf.keras.get_all_models()

[tensorflow_decision_forests.keras.RandomForestModel,
 tensorflow_decision_forests.keras.GradientBoostedTreesModel,
 tensorflow_decision_forests.keras.CartModel,
 tensorflow_decision_forests.keras.DistributedGradientBoostedTreesModel]

In [9]:
# Load one feature table per organism (yeast, worm, fly), in that order.
_PPI_CSV_PATHS = (
    'data/base_cerevisiae.csv',
    'data/base_elegans.csv',
    'data/base_drosophila.csv',
)
df_ppi_1, df_ppi_2, df_ppi_3 = (pd.read_csv(csv_path) for csv_path in _PPI_CSV_PATHS)

In [29]:
# Stack the per-organism tables into one frame with a fresh 0..n-1 index.
df = pd.concat(objs=(df_ppi_1, df_ppi_2, df_ppi_3), ignore_index=True)
df

Unnamed: 0,Locus,Sequence,Sequence_Length,Aromaticity,Sec_Struct_Helix,Sec_Struct_Turn,Sec_Struct_Sheet,Percent_A,Percent_C,Percent_D,...,Percent_V,Percent_W,Percent_Y,Protein_key,DegreeCentrality,EigenvectorCentrality,BetweennessCentrality,ClosenessCentrality,Clustering,is_essential
0,YPL071C,MSSRFARSNGNPNHIRKRNHSPDPIGIDNYKRKRLIIDLENLSLND...,156,0.096154,0.262821,0.435897,0.301282,0.044872,0.006410,0.128205,...,0.038462,0.032051,0.038462,4932.YPL071C,0.000986,0.000512,3.315435e-06,0.426787,0.266667,0
1,YLL050C,MSRSGVAVADESLTAFNDLKLGKKYKFILFGLNDAKTEIVVKETST...,143,0.111888,0.293706,0.405594,0.300699,0.076923,0.006993,0.083916,...,0.083916,0.006993,0.048951,4932.YLL050C,0.053392,0.017135,2.003725e-04,0.497262,0.302355,1
2,YMR172W,MSGMGIAILCIVRTKIYRITISFDYSTLMSPFFLFLMMPTTLKDGY...,719,0.043115,0.314325,0.442281,0.243394,0.055633,0.002782,0.058414,...,0.030598,0.004172,0.018081,4932.YMR172W,0.007557,0.002314,3.144740e-06,0.438952,0.289855,0
3,YOR185C,MSAPAQNNAEVPTFKLVLVGDGGTGKTTFVKRHLTGEFEKKYIATI...,220,0.109091,0.322727,0.340909,0.336364,0.077273,0.013636,0.063636,...,0.081818,0.013636,0.040909,4932.YOR185C,0.046164,0.017683,1.040158e-04,0.491754,0.307905,0
4,YLL032C,MDNFKIYSTVITTAFLQVPHLYTTNRLWKPIEAPFLVEFLQKRISS...,825,0.100606,0.306667,0.358788,0.334545,0.042424,0.010909,0.042424,...,0.043636,0.002424,0.043636,4932.YLL032C,0.021028,0.006141,1.265410e-04,0.473128,0.236713,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
30506,7227.FBpp0306211,MARLISGVRNLFHRYPFVTNSAIYGSLYVGAEYSQQFASKRWLATA...,204,0.147059,0.348039,0.269608,0.382353,0.102941,0.014706,0.029412,...,0.073529,0.029412,0.058824,7227.FBpp0306211,0.000177,0.000006,1.872049e-08,0.290879,0.000000,0
30507,7227.FBpp0306213,MVKILQAYNFARQQTYALNGDILAASLIGNNRIAISSAEQFIEIYD...,1536,0.083984,0.372396,0.333333,0.294271,0.074870,0.019531,0.052734,...,0.065104,0.005208,0.035807,7227.FBpp0306213,0.001237,0.000074,6.935191e-06,0.324301,0.648352,0
30508,7227.FBpp0306214,MSGGDYDSGDYFMRSRKQRDKPSLWDSFQDPPSKKTSGSDADWKKL...,1393,0.117014,0.384063,0.284996,0.330940,0.083274,0.022254,0.055994,...,0.071070,0.023690,0.035176,7227.FBpp0306214,0.014759,0.005067,1.719578e-04,0.421110,0.268595,0
30509,7227.FBpp0306223,MEREIAHSLAGGEERSSDVAPGQVKTFEELRLYRNLLNGLKRNNFV...,1028,0.071012,0.351167,0.381323,0.267510,0.057393,0.007782,0.052529,...,0.057393,0.003891,0.034047,7227.FBpp0306223,0.055590,0.016066,4.047679e-04,0.451477,0.227401,1


In [33]:
# Drop identifier columns (protein key, locus, raw amino-acid sequence) so
# that only feature columns and the label remain for modelling.
_ID_COLUMNS = ['Protein_key', 'Locus', 'Sequence']
df_tmp = df.drop(columns=_ID_COLUMNS)
df_tmp

Unnamed: 0,Sequence_Length,Aromaticity,Sec_Struct_Helix,Sec_Struct_Turn,Sec_Struct_Sheet,Percent_A,Percent_C,Percent_D,Percent_E,Percent_F,...,Percent_T,Percent_V,Percent_W,Percent_Y,DegreeCentrality,EigenvectorCentrality,BetweennessCentrality,ClosenessCentrality,Clustering,is_essential
0,156,0.096154,0.262821,0.435897,0.301282,0.044872,0.006410,0.128205,0.044872,0.025641,...,0.038462,0.038462,0.032051,0.038462,0.000986,0.000512,3.315435e-06,0.426787,0.266667,0
1,143,0.111888,0.293706,0.405594,0.300699,0.076923,0.006993,0.083916,0.069930,0.055944,...,0.055944,0.083916,0.006993,0.048951,0.053392,0.017135,2.003725e-04,0.497262,0.302355,1
2,719,0.043115,0.314325,0.442281,0.243394,0.055633,0.002782,0.058414,0.047288,0.020862,...,0.083449,0.030598,0.004172,0.018081,0.007557,0.002314,3.144740e-06,0.438952,0.289855,0
3,220,0.109091,0.322727,0.340909,0.336364,0.077273,0.013636,0.063636,0.063636,0.054545,...,0.059091,0.081818,0.013636,0.040909,0.046164,0.017683,1.040158e-04,0.491754,0.307905,0
4,825,0.100606,0.306667,0.358788,0.334545,0.042424,0.010909,0.042424,0.066667,0.054545,...,0.059394,0.043636,0.002424,0.043636,0.021028,0.006141,1.265410e-04,0.473128,0.236713,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
30506,204,0.147059,0.348039,0.269608,0.382353,0.102941,0.014706,0.029412,0.034314,0.058824,...,0.068627,0.073529,0.029412,0.058824,0.000177,0.000006,1.872049e-08,0.290879,0.000000,0
30507,1536,0.083984,0.372396,0.333333,0.294271,0.074870,0.019531,0.052734,0.065104,0.042969,...,0.046875,0.065104,0.005208,0.035807,0.001237,0.000074,6.935191e-06,0.324301,0.648352,0
30508,1393,0.117014,0.384063,0.284996,0.330940,0.083274,0.022254,0.055994,0.048816,0.058148,...,0.045944,0.071070,0.023690,0.035176,0.014759,0.005067,1.719578e-04,0.421110,0.268595,0
30509,1028,0.071012,0.351167,0.381323,0.267510,0.057393,0.007782,0.052529,0.081712,0.033074,...,0.053502,0.057393,0.003891,0.034047,0.055590,0.016066,4.047679e-04,0.451477,0.227401,1


In [44]:
def split_dataset(dataset, test_ratio=0.30, seed=None):
    """Randomly split a pandas DataFrame into train and test partitions.

    Each row independently goes to the test partition with probability
    ``test_ratio``. Partition sizes are therefore only approximately
    proportional, and class balance is not preserved (the split is not
    stratified).

    Args:
        dataset: DataFrame to split.
        test_ratio: Probability in [0, 1] that a row is assigned to the test set.
        seed: Optional seed for a private ``numpy.random.Generator``. It makes
            the split reproducible without touching global NumPy state. When
            ``None`` (the default), the legacy global RNG (``np.random``) is
            used, so ``np.random.seed(...)`` governs the split.

    Returns:
        A ``(train, test)`` tuple of DataFrames. Rows keep their original index.
    """
    n_rows = len(dataset)
    if seed is None:
        draws = np.random.rand(n_rows)
    else:
        draws = np.random.default_rng(seed).random(n_rows)
    test_mask = draws < test_ratio
    return dataset[~test_mask], dataset[test_mask]

# Name of the label column.
label = "is_essential"

# Seed NumPy's global RNG so the random train/test split below is reproducible
# across kernel restarts (split_dataset draws from np.random when it is not
# given its own seed). Uses the same value as the SMOTEENN random_state.
SPLIT_SEED = 7
np.random.seed(SPLIT_SEED)

train_ds_pd, test_ds_pd = split_dataset(df_tmp)

# Sanity check: every row lands in exactly one partition.
assert len(train_ds_pd) + len(test_ds_pd) == len(df_tmp)

print("{} examples in training, {} examples for testing.".format(
    len(train_ds_pd), len(test_ds_pd)))


21312 examples in training, 9199 examples for testing.


In [58]:
## Combine: rebalance the training split with SMOTEENN (SMOTE over-sampling
## followed by Edited Nearest Neighbours cleaning). Only the training data is
## resampled; the test split keeps its natural class distribution.
## NOTE(review): the next cell overwrites train_ds_pd with the resampled frame,
## so re-running this cell afterwards would resample already-resampled data.

X_train = train_ds_pd.drop(columns=[label])

# Select the label directly as a Series instead of dropping every feature column.
y_train = train_ds_pd[label]

sample = SMOTEENN(random_state=7)

X_train_sample, y_train_sample = sample.fit_resample(X_train, y_train)

y_train_sample.value_counts()

is_essential
1               17332
0               14046
dtype: int64

In [59]:
# Reassemble a single training frame (features + label column) from the
# resampled outputs so it can be converted to a TF dataset.
# NOTE(review): this overwrites the raw training split held in train_ds_pd.
train_ds_pd = pd.concat(objs=[X_train_sample, y_train_sample], axis="columns")
train_ds_pd

Unnamed: 0,Sequence_Length,Aromaticity,Sec_Struct_Helix,Sec_Struct_Turn,Sec_Struct_Sheet,Percent_A,Percent_C,Percent_D,Percent_E,Percent_F,...,Percent_T,Percent_V,Percent_W,Percent_Y,DegreeCentrality,EigenvectorCentrality,BetweennessCentrality,ClosenessCentrality,Clustering,is_essential
0,156,0.096154,0.262821,0.435897,0.301282,0.044872,0.006410,0.128205,0.044872,0.025641,...,0.038462,0.038462,0.032051,0.038462,0.000986,0.000512,0.000003,0.426787,0.266667,0
1,719,0.043115,0.314325,0.442281,0.243394,0.055633,0.002782,0.058414,0.047288,0.020862,...,0.083449,0.030598,0.004172,0.018081,0.007557,0.002314,0.000003,0.438952,0.289855,0
2,900,0.084444,0.283333,0.444444,0.272222,0.054444,0.008889,0.064444,0.054444,0.044444,...,0.054444,0.032222,0.011111,0.028889,0.003450,0.001264,0.000003,0.442561,0.376190,0
3,434,0.099078,0.311060,0.382488,0.306452,0.041475,0.006912,0.041475,0.085253,0.050691,...,0.052995,0.059908,0.023041,0.025346,0.009693,0.002169,0.000016,0.429439,0.255406,0
4,470,0.110638,0.351064,0.338298,0.310638,0.063830,0.014894,0.061702,0.091489,0.057447,...,0.055319,0.051064,0.029787,0.023404,0.005750,0.001404,0.000006,0.448268,0.242017,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
31373,971,0.064505,0.372617,0.338145,0.289238,0.079948,0.017969,0.061448,0.067479,0.035679,...,0.055947,0.072450,0.004295,0.024531,0.039976,0.007221,0.000322,0.456425,0.217309,1
31374,553,0.088573,0.362571,0.336926,0.300503,0.073442,0.015903,0.067311,0.080063,0.040003,...,0.064473,0.063127,0.007221,0.041349,0.045710,0.014645,0.000118,0.475481,0.253883,1
31375,466,0.094288,0.272108,0.411346,0.316546,0.056670,0.004574,0.101457,0.049489,0.051281,...,0.080973,0.066389,0.014758,0.028248,0.005173,0.001323,0.000014,0.442248,0.187598,1
31376,667,0.080608,0.343425,0.377563,0.279012,0.064521,0.007943,0.068693,0.082494,0.042644,...,0.050145,0.064388,0.010484,0.027480,0.040579,0.011765,0.000554,0.481757,0.256227,1


In [60]:
def _to_classification_ds(frame):
    """Convert a pandas frame to a TF dataset with `label` as the classification target."""
    return tfdf.keras.pd_dataframe_to_tf_dataset(
        frame, label=label, task=tfdf.keras.Task.CLASSIFICATION)


# Convert the (resampled) training frame first, then the untouched test frame.
train_ds, test_ds = map(_to_classification_ds, (train_ds_pd, test_ds_pd))

### Training

In [62]:
# Random-forest hyper-parameters and output location.
MAX_DEPTH = 6
MODEL_DIR = "project/model"

# Specify the model.
model_1 = tfdf.keras.RandomForestModel(
    task=tfdf.keras.Task.CLASSIFICATION, max_depth=MAX_DEPTH)

# Attach evaluation metrics. METRICS is already a list, so pass it directly:
# wrapping it in another list is Keras' "list of lists" form, which means
# one metric list per model output.
model_1.compile(metrics=METRICS)

# Train the model.
model_1.fit(x=train_ds)

# Persist the trained model to disk.
model_1.save(MODEL_DIR)

Use /tmp/tmp_9cbvnst as temporary training directory
Reading training dataset...
Training dataset read in 0:00:00.692238. Found 31378 examples.
Training model...
Model trained in 0:00:02.500460
Compiling model...


[INFO 2023-02-21T20:17:06.800025987-03:00 kernel.cc:1214] Loading model from path /tmp/tmp_9cbvnst/model/ with prefix fa24c16d61744ac0
[INFO 2023-02-21T20:17:06.833877741-03:00 decision_forest.cc:661] Model loaded with 300 root(s), 17452 node(s), and 30 input feature(s).
[INFO 2023-02-21T20:17:06.833902821-03:00 abstract_model.cc:1311] Engine "RandomForestOptPred" built
[INFO 2023-02-21T20:17:06.833917626-03:00 kernel.cc:1046] Use fast generic engine


Model compiled.


<keras.callbacks.History at 0x7f7ee501ee00>

In [63]:
# Print the Keras layer summary followed by TF-DF's model description
# (task, input features, variable importances) for the trained forest.
model_1.summary()

Model: "random_forest_model_3"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
Total params: 1
Trainable params: 0
Non-trainable params: 1
_________________________________________________________________
Type: "RANDOM_FOREST"
Task: CLASSIFICATION
Label: "__LABEL"

Input Features (30):
	Aromaticity
	BetweennessCentrality
	ClosenessCentrality
	Clustering
	DegreeCentrality
	EigenvectorCentrality
	Percent_A
	Percent_C
	Percent_D
	Percent_E
	Percent_F
	Percent_G
	Percent_H
	Percent_I
	Percent_K
	Percent_L
	Percent_M
	Percent_N
	Percent_P
	Percent_Q
	Percent_R
	Percent_S
	Percent_T
	Percent_V
	Percent_W
	Percent_Y
	Sec_Struct_Helix
	Sec_Struct_Sheet
	Sec_Struct_Turn
	Sequence_Length

No weights

Variable Importance: INV_MEAN_MIN_DEPTH:
    1.   "ClosenessCentrality"  0.261428 ################
    2. "EigenvectorCentrality"  0.244163 #############
    3.      "DegreeCentrality"  0.237647 ###########
    4. "B

### Evaluation

In [64]:
# Score the held-out test split (not resampled) with the compiled metrics.
# NOTE(review): the test split keeps its class imbalance, so accuracy alone is
# likely optimistic; precision/recall/AUC are more informative here.
results = model_1.evaluate(test_ds, verbose=0, return_dict=True)
print("model_1 Evaluation: \n")
for metric_name, metric_value in results.items():
    print(f"{metric_name}: {metric_value:.4f}")

model_1 Evaluation: 

loss: 0.0000
accuracy: 0.7788
precision: 0.1069
recall: 0.4381
auc: 0.6409


In [65]:
# Render the first tree (index 0), down to depth 5, as HTML and save it to a file.
PLOT_PATH = "plot.html"
tree_html = tfdf.model_plotter.plot_model(model_1, tree_idx=0, max_depth=5)

# Write with an explicit encoding so the file does not depend on the
# platform's default locale encoding.
with open(PLOT_PATH, "w", encoding="utf-8") as f:
    f.write(tree_html)


In [66]:
# Self-evaluation reported by the trained forest. Its num_examples equals the
# resampled training-set size, so this is NOT a held-out test score
# (presumably out-of-bag evaluation; confirm against the TF-DF docs).
inspector_1 = model_1.make_inspector()
inspector_1.evaluation()

Evaluation(num_examples=31378, accuracy=0.8033972847217796, loss=1.0746463852586325, rmse=None, ndcg=None, aucs=None, auuc=None, qini=None)

In [67]:
# Export the model meta-data to a directory TensorBoard can read.
tensorboard_logdir = "/tmp/tensorboard_logs"
model_1.make_inspector().export_to_tensorboard(tensorboard_logdir)

In [68]:
# docs_infra: no_execute
# Start a tensorboard instance.

%load_ext tensorboard

%tensorboard --logdir "/tmp/tensorboard_logs"

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


Reusing TensorBoard on port 6006 (pid 8753), started 1:29:25 ago. (Use '!kill 8753' to kill it.)