# Deep-CBN Demo
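This notebook demonstrates how to use the helper functions in `train_deep_cbn_fnx.py` to train Deep-CBN models on the Tox21 dataset (`../Data/tox21.csv`), evaluate them, and run predictions with the trained models. Every model here is trained for a single epoch to keep the demo fast, so the reported metrics illustrate the workflow only; they are not meaningful performance estimates.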

In [15]:
from train_deep_cbn_fnx import train_deep_cbn, calculate_roc_auc, plot_confusion_matrix, process_multiple_targets,  predict_with_models, label_smiles

## Single Target

In [16]:
# Settings for the single-target run
smiles_col = 'smiles'
target_col = 'NR-PPAR-gamma'
dataset_path = '../Data/tox21.csv'

# A single epoch keeps the demo quick; the resulting model is not expected to be good.
model, train_eval, test_eval, X_test, y_test_cat = train_deep_cbn(
    dataset_path,
    target_col,
    smiles_col,
    n_epochs=1,
)

# Score the model on the X_test / y_test_cat arrays returned by train_deep_cbn
# (presumably its held-out split — TODO confirm in train_deep_cbn_fnx.py).
roc_auc = calculate_roc_auc(model, X_test, y_test_cat)
print(f'ROC-AUC: {roc_auc:.4f}')
plot_confusion_matrix(
    model,
    X_test,
    y_test_cat,
    class_names=['Negative', 'Positive'],
)





[1m 7/21[0m [32m━━━━━━[0m[37m━━━━━━━━━━━━━━[0m [1m5s[0m 398ms/step - accuracy: 0.2318 - auc: 0.1604 - f1_score: 0.1912 - loss: 0.7234 - precision: 0.2318 - recall: 0.2318

KeyboardInterrupt: 

## Multiple Targets

In [17]:
# Target columns to process
target_cols = ['NR-PPAR-gamma', 'NR-AhR', 'SR-p53']
dataset_path = '../Data/tox21.csv'
smiles_col = 'smiles'
n_epochs = 1  # demo setting only: far too few epochs for a useful model

# Run the multi-target pipeline from train_deep_cbn_fnx. It returns a summary
# table and a dict of trained models (presumably keyed by target column — TODO
# confirm). models_dict is reused by predict_with_models in the next cell.
results_df, models_dict = process_multiple_targets(dataset_path, target_cols, smiles_col, n_epochs)

# Last expression -> rendered as a rich table instead of plain print() text.
results_df

Processing target column: NR-PPAR-gamma






[1m21/21[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 391ms/step - accuracy: 0.5071 - auc: 0.5221 - f1_score: 0.3562 - loss: 0.6965 - precision: 0.5071 - recall: 0.5071
Restoring model weights from the end of the best epoch: 1.
[1m41/41[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 16ms/step - accuracy: 0.9673 - auc: 0.9448 - f1_score: 0.4917 - loss: 0.6876 - precision: 0.9673 - recall: 0.9673
[1m162/162[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 4ms/step
[1m41/41[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step
[1m21/21[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 580ms/step - accuracy: 0.4291 - auc: 0.4968 - loss: 0.9017
Restoring model weights from the end of the best epoch: 1.
[1m21/21[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 284ms/step - accuracy: 0.8203 - auc: 0.8145 - f1_score: 0.4746 - loss: 1.4888 - precision: 0.8203 - recall: 0.8203
Restoring model weights from the end of the best epoch: 1.
[1m162/162[0





[1m21/21[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 395ms/step - accuracy: 0.4328 - auc: 0.3969 - f1_score: 0.3740 - loss: 0.7021 - precision: 0.4328 - recall: 0.4328
Restoring model weights from the end of the best epoch: 1.
[1m41/41[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 16ms/step - accuracy: 0.7977 - auc: 0.5074 - f1_score: 0.6046 - loss: 0.6920 - precision: 0.7977 - recall: 0.7977
[1m164/164[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 4ms/step
[1m41/41[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 4ms/step
[1m21/21[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 576ms/step - accuracy: 0.4930 - auc: 0.5418 - loss: 0.9264
Restoring model weights from the end of the best epoch: 1.
[1m21/21[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 283ms/step - accuracy: 0.6056 - auc: 0.4690 - f1_score: 0.5083 - loss: 0.7990 - precision: 0.6056 - recall: 0.6056
Restoring model weights from the end of the best epoch: 1.
[1m164/164[0





[1m22/22[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 370ms/step - accuracy: 0.3892 - auc: 0.3494 - f1_score: 0.3115 - loss: 0.7377 - precision: 0.3892 - recall: 0.3892
Restoring model weights from the end of the best epoch: 1.
[1m43/43[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 18ms/step - accuracy: 0.9245 - auc: 0.9298 - f1_score: 0.4803 - loss: 0.6834 - precision: 0.9245 - recall: 0.9245
[1m170/170[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 5ms/step
[1m43/43[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 4ms/step
[1m22/22[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 554ms/step - accuracy: 0.5067 - auc: 0.5483 - loss: 0.9715
Restoring model weights from the end of the best epoch: 1.
[1m22/22[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 280ms/step - accuracy: 0.7448 - auc: 0.7088 - f1_score: 0.4875 - loss: 0.7276 - precision: 0.7448 - recall: 0.7448
Restoring model weights from the end of the best epoch: 1.
[1m170/170[0

In [18]:
import pandas as pd
import numpy as np
from tensorflow.keras.utils import to_categorical

# Evaluation data: the first 100 rows of tox21.csv that have a SMILES string and
# labels for all three targets.
# NOTE(review): this is the same file the models above were trained on, so these
# rows may overlap the training split and the metrics below are likely
# optimistic. Use a genuinely held-out set for a real evaluation.
new_data = (
    pd.read_csv('../Data/tox21.csv')
    .dropna(subset=['NR-PPAR-gamma', 'NR-AhR', 'SR-p53', 'smiles'])
    .head(100)
)
smiles_list = new_data['smiles']

# Character -> integer code table for SMILES strings, listed in code order (1..70).
# Code 0 is deliberately unassigned: label_smiles leaves padding positions and
# characters missing from this table as 0.
smiles_dict = {
    "(": 1, ".": 2, "0": 3, "2": 4, "4": 5, "6": 6, "8": 7, "@": 8, "B": 9, "D": 10,
    "F": 11, "H": 12, "L": 13, "N": 14, "P": 15, "R": 16, "T": 17, "V": 18, "Z": 19, "\\": 20,
    "b": 21, "d": 22, "f": 23, "h": 24, "l": 25, "n": 26, "r": 27, "t": 28, "#": 29, "%": 30,
    ")": 31, "+": 32, "-": 33, "/": 34, "1": 35, "3": 36, "5": 37, "7": 38, "9": 39, "=": 40,
    "A": 41, "C": 42, "E": 43, "G": 44, "I": 45, "K": 46, "M": 47, "O": 48, "S": 49, "U": 50,
    "W": 51, "Y": 52, "[": 53, "]": 54, "a": 55, "c": 56, "e": 57, "g": 58, "i": 59, "m": 60,
    "o": 61, "s": 62, "u": 63, "y": 64, " ": 65, ":": 66, ",": 67, "p": 68, "j": 69, "*": 70,
}


def label_smiles(line, MAX_SMI_LEN, smi_ch_ind):
    """Integer-encode a SMILES string into a fixed-length vector.

    Each of the first ``MAX_SMI_LEN`` characters of ``line`` is replaced by its
    code from ``smi_ch_ind``. Characters missing from the table, and every
    position past the end of ``line``, are left as 0.

    NOTE(review): this definition shadows the ``label_smiles`` imported from
    ``train_deep_cbn_fnx`` in the first code cell. Its encoding should match the
    one used during training — TODO confirm they agree or reuse the import.

    Args:
        line: SMILES string (any sliceable sequence of characters).
        MAX_SMI_LEN: Output length; longer inputs are truncated.
        smi_ch_ind: Mapping from character to integer code.

    Returns:
        ``np.ndarray`` of ints with shape ``(MAX_SMI_LEN,)``.
    """
    codes = [smi_ch_ind.get(ch, 0) for ch in line[:MAX_SMI_LEN]]
    encoded = np.zeros(MAX_SMI_LEN, dtype=int)
    encoded[:len(codes)] = codes
    return encoded

# Encoding settings for the new molecules.
# NOTE(review): the model input shape was fixed at training time, so these
# presumably must match the values used inside train_deep_cbn_fnx — TODO confirm.
SMILES_MAX_LEN = 100
# One one-hot class per code in smiles_dict plus code 0 (padding / unknown
# characters in label_smiles). This equals 71 for the table above and stays in
# sync if the table changes.
SMILES_N_CLASSES = max(smiles_dict.values()) + 1

# Integer-encode each SMILES string, then one-hot encode the codes.
X_new = to_categorical(
    np.array([label_smiles(str(s), SMILES_MAX_LEN, smiles_dict) for s in smiles_list]),
    num_classes=SMILES_N_CLASSES,
)

# Every target is evaluated on the same molecules; labels are one-hot with 2 classes.
target_cols = ['NR-PPAR-gamma', 'NR-AhR', 'SR-p53']
X_test_dict = {t: X_new for t in target_cols}
y_test_cat_dict = {
    t: to_categorical(new_data[t].astype(int), num_classes=2) for t in target_cols
}

# Per-target metrics. The models were trained for only n_epochs=1, and these rows
# come from the same tox21.csv used for training, so the numbers only illustrate
# the workflow (e.g. a model may predict a single class for every molecule).
predictions_df = predict_with_models(models_dict, X_test_dict, y_test_cat_dict)
predictions_df

Predicting for target column: NR-PPAR-gamma
[1m4/4[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step
Accuracy for NR-PPAR-gamma: 0.9900
AUC for NR-PPAR-gamma: 0.8283
Confusion Matrix for NR-PPAR-gamma:
[[99  0]
 [ 1  0]]

Predicting for target column: NR-AhR
[1m4/4[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 25ms/step
Accuracy for NR-AhR: 0.1200
AUC for NR-AhR: 0.4299
Confusion Matrix for NR-AhR:
[[ 0 88]
 [ 0 12]]

Predicting for target column: SR-p53
[1m4/4[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 25ms/step
Accuracy for SR-p53: 0.9700
AUC for SR-p53: 0.7320
Confusion Matrix for SR-p53:
[[97  0]
 [ 3  0]]

      target_col  accuracy       auc    confusion_matrix
0  NR-PPAR-gamma      0.99  0.828283   [[99, 0], [1, 0]]
1         NR-AhR      0.12  0.429924  [[0, 88], [0, 12]]
2         SR-p53      0.97  0.731959   [[97, 0], [3, 0]]
