# Tensorflow Decision Forests

Link: https://www.tensorflow.org/decision_forests?hl=pt-br

In [26]:
import numpy as np
import pandas as pd

import keras
import tensorflow_decision_forests as tfdf

import tensorflow_decision_forests as tfdf


from IPython.core.magic import register_line_magic
from IPython.display import Javascript

from imblearn.combine import SMOTEENN

import warnings
warnings.filterwarnings("ignore")

In [16]:
# Check the version of TensorFlow Decision Forests
print("Found TensorFlow Decision Forests v" + tfdf.__version__)

Found TensorFlow Decision Forests v1.2.0


In [36]:
METRICS = [ 
      keras.metrics.BinaryAccuracy(name='accuracy'),
      keras.metrics.Precision(name='precision'),
      keras.metrics.Recall(name='recall'),
      keras.metrics.AUC(name='auc')
]

tfdf.keras.get_all_models()

[tensorflow_decision_forests.keras.RandomForestModel,
 tensorflow_decision_forests.keras.GradientBoostedTreesModel,
 tensorflow_decision_forests.keras.CartModel,
 tensorflow_decision_forests.keras.DistributedGradientBoostedTreesModel]

In [8]:
df_ppi_1 = pd.read_csv('data/base_cerevisiae.csv')
df_ppi_2 = pd.read_csv('data/base_elegans.csv')
df_ppi_3 = pd.read_csv('data/base_drosophila.csv')

In [9]:
df = pd.concat([df_ppi_1, df_ppi_2, df_ppi_3], ignore_index=True)
df

Unnamed: 0,Locus,Sequence,Sequence_Length,Aromaticity,Sec_Struct_Helix,Sec_Struct_Turn,Sec_Struct_Sheet,Percent_A,Percent_C,Percent_D,...,Percent_V,Percent_W,Percent_Y,Protein_key,DegreeCentrality,EigenvectorCentrality,BetweennessCentrality,ClosenessCentrality,Clustering,is_essential
0,YPL071C,MSSRFARSNGNPNHIRKRNHSPDPIGIDNYKRKRLIIDLENLSLND...,156,0.096154,0.262821,0.435897,0.301282,0.044872,0.006410,0.128205,...,0.038462,0.032051,0.038462,4932.YPL071C,0.000986,0.000512,3.315435e-06,0.426787,0.266667,0
1,YLL050C,MSRSGVAVADESLTAFNDLKLGKKYKFILFGLNDAKTEIVVKETST...,143,0.111888,0.293706,0.405594,0.300699,0.076923,0.006993,0.083916,...,0.083916,0.006993,0.048951,4932.YLL050C,0.053392,0.017135,2.003725e-04,0.497262,0.302355,1
2,YMR172W,MSGMGIAILCIVRTKIYRITISFDYSTLMSPFFLFLMMPTTLKDGY...,719,0.043115,0.314325,0.442281,0.243394,0.055633,0.002782,0.058414,...,0.030598,0.004172,0.018081,4932.YMR172W,0.007557,0.002314,3.144740e-06,0.438952,0.289855,0
3,YOR185C,MSAPAQNNAEVPTFKLVLVGDGGTGKTTFVKRHLTGEFEKKYIATI...,220,0.109091,0.322727,0.340909,0.336364,0.077273,0.013636,0.063636,...,0.081818,0.013636,0.040909,4932.YOR185C,0.046164,0.017683,1.040158e-04,0.491754,0.307905,0
4,YLL032C,MDNFKIYSTVITTAFLQVPHLYTTNRLWKPIEAPFLVEFLQKRISS...,825,0.100606,0.306667,0.358788,0.334545,0.042424,0.010909,0.042424,...,0.043636,0.002424,0.043636,4932.YLL032C,0.021028,0.006141,1.265410e-04,0.473128,0.236713,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
30506,7227.FBpp0306211,MARLISGVRNLFHRYPFVTNSAIYGSLYVGAEYSQQFASKRWLATA...,204,0.147059,0.348039,0.269608,0.382353,0.102941,0.014706,0.029412,...,0.073529,0.029412,0.058824,7227.FBpp0306211,0.000177,0.000006,1.872049e-08,0.290879,0.000000,0
30507,7227.FBpp0306213,MVKILQAYNFARQQTYALNGDILAASLIGNNRIAISSAEQFIEIYD...,1536,0.083984,0.372396,0.333333,0.294271,0.074870,0.019531,0.052734,...,0.065104,0.005208,0.035807,7227.FBpp0306213,0.001237,0.000074,6.935191e-06,0.324301,0.648352,0
30508,7227.FBpp0306214,MSGGDYDSGDYFMRSRKQRDKPSLWDSFQDPPSKKTSGSDADWKKL...,1393,0.117014,0.384063,0.284996,0.330940,0.083274,0.022254,0.055994,...,0.071070,0.023690,0.035176,7227.FBpp0306214,0.014759,0.005067,1.719578e-04,0.421110,0.268595,0
30509,7227.FBpp0306223,MEREIAHSLAGGEERSSDVAPGQVKTFEELRLYRNLLNGLKRNNFV...,1028,0.071012,0.351167,0.381323,0.267510,0.057393,0.007782,0.052529,...,0.057393,0.003891,0.034047,7227.FBpp0306223,0.055590,0.016066,4.047679e-04,0.451477,0.227401,1


In [45]:
def split_dataset(dataset, test_ratio=0.30):
    """Splits a panda dataframe in two."""
    test_indices = np.random.rand(len(dataset)) < test_ratio
    return dataset[~test_indices], dataset[test_indices]

# Name of the label column.
label = "is_essential"

train_ds_pd, test_ds_pd = split_dataset(df)


print("{} examples in training, {} examples for testing.".format(
    len(train_ds_pd), len(test_ds_pd)))


21305 examples in training, 9206 examples for testing.


Unnamed: 0,Locus,Sequence,Sequence_Length,Aromaticity,Sec_Struct_Helix,Sec_Struct_Turn,Sec_Struct_Sheet,Percent_A,Percent_C,Percent_D,...,Percent_V,Percent_W,Percent_Y,Protein_key,DegreeCentrality,EigenvectorCentrality,BetweennessCentrality,ClosenessCentrality,Clustering,is_essential
1,YLL050C,MSRSGVAVADESLTAFNDLKLGKKYKFILFGLNDAKTEIVVKETST...,143,0.111888,0.293706,0.405594,0.300699,0.076923,0.006993,0.083916,...,0.083916,0.006993,0.048951,4932.YLL050C,0.053392,0.017135,2.003725e-04,0.497262,0.302355,1
3,YOR185C,MSAPAQNNAEVPTFKLVLVGDGGTGKTTFVKRHLTGEFEKKYIATI...,220,0.109091,0.322727,0.340909,0.336364,0.077273,0.013636,0.063636,...,0.081818,0.013636,0.040909,4932.YOR185C,0.046164,0.017683,1.040158e-04,0.491754,0.307905,0
5,YBR225W,MGSNKEAKNIDSKNDRGLTSITSNKISNLKAHDNHTSSMITEHKNA...,900,0.084444,0.283333,0.444444,0.272222,0.054444,0.008889,0.064444,...,0.032222,0.011111,0.028889,4932.YBR225W,0.003450,0.001264,2.723482e-06,0.442561,0.376190,0
6,YEL041W,MKTDRLLINASPETCTKGDAEMDTMDTIDRMTSVKVLAEGKVLSNF...,495,0.082828,0.296970,0.365657,0.337374,0.046465,0.020202,0.070707,...,0.074747,0.010101,0.026263,4932.YEL041W,0.005421,0.001134,6.316723e-06,0.430685,0.153409,0
7,YOR237W,MSQHASSSSWTSFLKSISSFNGDLSSLSAPPFILSPTSLTEFSQYW...,434,0.099078,0.311060,0.382488,0.306452,0.041475,0.006912,0.041475,...,0.059908,0.023041,0.025346,4932.YOR237W,0.009693,0.002169,1.555220e-05,0.429439,0.255406,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
30495,7227.FBpp0306150,MSHFTCLNCDARFASADVQRNHYKTDWHRYNLKRRVAQLPPVTAEE...,409,0.085575,0.415648,0.298289,0.286064,0.097800,0.029340,0.075795,...,0.063570,0.009780,0.044010,7227.FBpp0306150,0.020857,0.006465,3.776616e-05,0.413002,0.418897,0
30496,7227.FBpp0306158,MYPFGSGMPSHPPTSTNHHEPPRAPFGAGWVPPMQQNSPYPPPSQP...,511,0.086106,0.338552,0.383562,0.277886,0.080235,0.011742,0.041096,...,0.031311,0.003914,0.056751,7227.FBpp0306158,0.010517,0.004368,7.096667e-05,0.423143,0.274890,0
30499,7227.FBpp0306188,MLSLLTRPFLPIFCFLYGPQSEGSTRIQCLRRFVTFLLGLVLGFLL...,715,0.120280,0.356643,0.268531,0.374825,0.046154,0.030769,0.039161,...,0.079720,0.006993,0.043357,7227.FBpp0306188,0.000354,0.000057,3.462324e-08,0.324264,0.333333,0
30506,7227.FBpp0306211,MARLISGVRNLFHRYPFVTNSAIYGSLYVGAEYSQQFASKRWLATA...,204,0.147059,0.348039,0.269608,0.382353,0.102941,0.014706,0.029412,...,0.073529,0.029412,0.058824,7227.FBpp0306211,0.000177,0.000006,1.872049e-08,0.290879,0.000000,0


In [21]:
train_ds = tfdf.keras.pd_dataframe_to_tf_dataset(train_ds_pd, label=label, task=tfdf.keras.Task.CLASSIFICATION)
test_ds = tfdf.keras.pd_dataframe_to_tf_dataset(test_ds_pd, label=label, task=tfdf.keras.Task.CLASSIFICATION)

### Treinamento

In [37]:
# Specify the model.
model_1 = tfdf.keras.RandomForestModel(task=tfdf.keras.Task.CLASSIFICATION, max_depth=6)

# Optionally, add evaluation metrics.
model_1.compile(
    metrics=[METRICS])

# Train the model.
# "sys_pipes" is optional. It enables the display of the training logs.

model_1.fit(x=train_ds)

Use /tmp/tmpt5ha4vgo as temporary training directory
Reading training dataset...
Training dataset read in 0:00:00.621373. Found 21449 examples.
Training model...
Model trained in 0:00:01.980760
Compiling model...


[INFO 2023-02-20T16:54:09.486107131-03:00 kernel.cc:1214] Loading model from path /tmp/tmpt5ha4vgo/model/ with prefix 5a872bd9e23c4cbf
[INFO 2023-02-20T16:54:09.517853696-03:00 decision_forest.cc:661] Model loaded with 300 root(s), 16714 node(s), and 31 input feature(s).
[INFO 2023-02-20T16:54:09.517896546-03:00 abstract_model.cc:1311] Engine "RandomForestOptPred" built
[INFO 2023-02-20T16:54:09.517913114-03:00 kernel.cc:1046] Use fast generic engine


Model compiled.


<keras.callbacks.History at 0x7f36291d3e50>

In [38]:
model_1.summary()

Model: "random_forest_model_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
Total params: 1
Trainable params: 0
Non-trainable params: 1
_________________________________________________________________
Type: "RANDOM_FOREST"
Task: CLASSIFICATION
Label: "__LABEL"

Input Features (33):
	Aromaticity
	BetweennessCentrality
	ClosenessCentrality
	Clustering
	DegreeCentrality
	EigenvectorCentrality
	Locus
	Percent_A
	Percent_C
	Percent_D
	Percent_E
	Percent_F
	Percent_G
	Percent_H
	Percent_I
	Percent_K
	Percent_L
	Percent_M
	Percent_N
	Percent_P
	Percent_Q
	Percent_R
	Percent_S
	Percent_T
	Percent_V
	Percent_W
	Percent_Y
	Protein_key
	Sec_Struct_Helix
	Sec_Struct_Sheet
	Sec_Struct_Turn
	Sequence
	Sequence_Length

No weights

Variable Importance: INV_MEAN_MIN_DEPTH:
    1.   "ClosenessCentrality"  0.296903 ################
    2.      "DegreeCentrality"  0.238419 ########
    3. "EigenvectorCentrality"  0.223

In [39]:
results = model_1.evaluate(test_ds, return_dict=True, verbose=0)
print("model_1 Evaluation: \n")
for name, value in results.items():
    print(f"{name}: {value:.4f}")

model_1 Evaluation: 

loss: 0.0000
accuracy: 0.9493
precision: 0.0000
recall: 0.0000
auc: 0.5997


In [40]:
#tfdf.model_plotter.plot_model(model_1, tree_idx=0, max_depth=3)


with open("plot.html", "w") as f: 
    f.write(tfdf.model_plotter.plot_model(model_1, tree_idx=0, max_depth=5))


In [41]:
model_1.make_inspector().evaluation()

Evaluation(num_examples=21449, accuracy=0.9436337358385006, loss=1.6238716572216656, rmse=None, ndcg=None, aucs=None, auuc=None, qini=None)

In [42]:
# Export the meta-data to tensorboard.
model_1.make_inspector().export_to_tensorboard("/tmp/tensorboard_logs")

In [43]:
# docs_infra: no_execute
# Start a tensorboard instance.

%load_ext tensorboard

%tensorboard --logdir "/tmp/tensorboard_logs"

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


Reusing TensorBoard on port 6006 (pid 6321), started 0:15:48 ago. (Use '!kill 6321' to kill it.)