In [6]:
import pandas as pd
import numpy as np
import json

from sklearn.ensemble import RandomForestClassifier
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from sklearn.preprocessing import StandardScaler
import torch

import utils

from rdkit import rdBase
# Silence RDKit's error-level log channel ('rdApp.error') so messages it emits
# (e.g. while SMILES are validated/featurized later) don't flood the cell output.
rdBase.DisableLog('rdApp.error') 

from dotenv import load_dotenv
# Load environment variables (e.g. from a .env file); later cells read
# DATA_PATH and MODELS_PATH via os.getenv.
load_dotenv()

import sys
import os
# Make the directory two levels above the notebook's working directory importable,
# so that the project package `model.model` below can be resolved.
parent_dir = os.path.abspath(os.path.join(os.getcwd(), "../.."))
sys.path.append(parent_dir)

from model.model import GNNFingerprint3D

In [7]:
# Root directories for the datasets and the trained model weights, taken from the
# environment (load_dotenv() in the first cell may populate them). Either is None if unset.
data_path, models_path = (os.getenv(var) for var in ("DATA_PATH", "MODELS_PATH"))

In [8]:
def read_data(path, ki_threshold, test_size=0.2, random_state=42, stratify=False):
    """Load a ';'-separated activity CSV and split it into train/test sets.

    A compound is labelled active (1) when its 'Standard Value' is strictly below
    `ki_threshold`, and inactive (0) otherwise. Rows missing either the SMILES or
    the 'Standard Value' are dropped *before* labelling.

    Args:
        path: CSV file with at least 'Smiles' and 'Standard Value' columns.
        ki_threshold: activity cut-off, in the same units as 'Standard Value'.
        test_size: fraction of rows held out for the test split.
        random_state: seed for the split, so partitions are reproducible.
        stratify: if True, keep the active/inactive ratio equal in both splits.

    Returns:
        (X_train, X_test, y_train, y_test): SMILES and 'Activity' label Series,
        each re-indexed from 0.
    """
    df = pd.read_csv(path, sep=";")
    # Drop incomplete rows before thresholding: `NaN < ki_threshold` is False, so
    # labelling first would silently turn missing measurements into inactives.
    df = df[['Smiles', 'Standard Value']].dropna()
    activity = (df['Standard Value'] < ki_threshold).astype(int).rename('Activity')

    splits = train_test_split(
        df['Smiles'],
        activity,
        test_size=test_size,
        random_state=random_state,
        stratify=activity if stratify else None,
    )
    X_train, X_test, y_train, y_test = (part.reset_index(drop=True) for part in splits)
    return X_train, X_test, y_train, y_test

In [9]:
def get_score(X_train, y_train, X_test, y_test, n_components=167, random_state=42):
    """Evaluate a feature representation with a scaler -> PCA -> random-forest pipeline.

    The scaler and PCA are fitted on the training split only and then applied to
    the test split. A classification report for the test split is printed;
    nothing is returned.

    Args:
        X_train, X_test: 2-D feature matrices, one row per molecule.
        y_train, y_test: class labels aligned with the rows of X_train / X_test.
        n_components: requested PCA dimensionality (default 167, the value used
            for every fingerprint so they are compared at equal size). It is capped
            at min(n_samples, n_features) of the training matrix, the largest value
            PCA accepts, so small datasets or narrow fingerprints no longer raise.
        random_state: seed for PCA and the random forest, making reruns reproducible.
    """
    s_scaler = StandardScaler()
    X_train = s_scaler.fit_transform(X_train)
    X_test = s_scaler.transform(X_test)

    # PCA to (at most) n_components dims; sklearn rejects n_components > min(n_samples, n_features).
    n_components = min(n_components, *X_train.shape)
    pca = PCA(n_components=n_components, random_state=random_state)
    X_train = pca.fit_transform(X_train)
    X_test = pca.transform(X_test)

    model = RandomForestClassifier(random_state=random_state)
    model.fit(X_train, y_train)
    y_pred = model.predict(X_test)

    print("Classification Report:")
    # zero_division=0 reports 0.0 for a class that is never predicted (the value the
    # default already substitutes) without emitting UndefinedMetricWarning.
    print(classification_report(y_test, y_pred, zero_division=0))

In [10]:
# Run the GNN encoders on the GPU when one is available; fall back to CPU otherwise.
device = "cuda" if torch.cuda.is_available() else "cpu"


def _load_fingerprint_model(weights_file):
    """Return a GNNFingerprint3D(13, 5) with weights read from `models_path/weights_file`,
    moved to `device` and switched to eval mode."""
    model = GNNFingerprint3D(13, 5)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine;
    # weights_only=True avoids arbitrary unpickling - the file only needs to hold the
    # state_dict that load_state_dict consumes.
    state_dict = torch.load(
        os.path.join(models_path, weights_file), map_location=device, weights_only=True
    )
    model.load_state_dict(state_dict)
    return model.to(device).eval()


# NOTE(review): the "2D" model is read from a file named *_3D.pth - confirm this is the
# intended checkpoint for the 2-D fingerprint.
fingerprint_model_2D = _load_fingerprint_model("GNN_MORE_WEIGHT_3D.pth")
fingerprint_model_3D = _load_fingerprint_model("FINAL_GNN.pth")

# Normalisation statistics (presumably per-feature means/stds) passed to
# utils.smiles_to_3D together with each model.
with open(os.path.join(data_path, "means_and_stds.json")) as f:
    scaler = json.load(f)


In [13]:
folder = os.path.join(data_path, "CHEMBL")

# Activity cut-off applied to 'Standard Value' by read_data (same units as that column).
KI_THRESHOLD = 100
FINGERPRINTS = ("ecfp", "maccs", "rdkit", "rdf", "random", "gnn_fp_2d", "gnn_fp_3d")
# GNN-based fingerprints are computed by utils.smiles_to_3D with the matching model.
GNN_MODELS = {"gnn_fp_2d": fingerprint_model_2D, "gnn_fp_3d": fingerprint_model_3D}


def _keep_valid_smiles(smiles, labels):
    """Return (smiles, labels) as lists, keeping only pairs whose SMILES passes utils.is_valid_smiles."""
    pairs = [(s, y) for s, y in zip(smiles, labels) if utils.is_valid_smiles(s)]
    return [s for s, _ in pairs], [y for _, y in pairs]


def _featurize(fingerprint, smiles):
    """Encode every SMILES with the named fingerprint and stack the results into one array."""
    if fingerprint in GNN_MODELS:
        model = GNN_MODELS[fingerprint]
        # Detach from autograd and move to host memory before handing the rows to numpy.
        return np.array([
            utils.smiles_to_3D(s, model, scaler, False).detach().cpu().numpy() for s in smiles
        ])
    encode = getattr(utils, "smiles_to_" + fingerprint)
    return np.array([encode(s) for s in smiles])


def _drop_nan_rows(X, y):
    """Drop rows of X that contain any NaN, together with the matching labels."""
    keep = ~np.isnan(X).any(axis=1)
    return X[keep], np.asarray(y)[keep]


# sorted() makes the dataset (and printed report) order reproducible across machines.
for db in sorted(os.listdir(folder)):
    if not db.endswith(".csv"):
        continue

    print("============================================")
    print(db)
    print("============================================")

    X_train, X_test, y_train, y_test = read_data(os.path.join(folder, db), KI_THRESHOLD)
    X_train, y_train = _keep_valid_smiles(X_train, y_train)
    X_test, y_test = _keep_valid_smiles(X_test, y_test)
    if not X_train or not X_test:
        print("Skipping: no valid SMILES left in the train or test split.")
        continue

    for fingerprint in FINGERPRINTS:
        X_train_prep = _featurize(fingerprint, X_train)
        X_test_prep = _featurize(fingerprint, X_test)
        y_train_prep, y_test_prep = y_train, y_test

        if fingerprint in GNN_MODELS:
            # NaN rows (presumably molecules the GNN pipeline could not embed - TODO
            # confirm) would make PCA fail, so drop them along with their labels.
            X_train_prep, y_train_prep = _drop_nan_rows(X_train_prep, y_train)
            X_test_prep, y_test_prep = _drop_nan_rows(X_test_prep, y_test)

        print(fingerprint.upper())
        get_score(X_train_prep, y_train_prep, X_test_prep, y_test_prep)

CHEMBL1833_5HT2B.csv
ECFP
Classification Report:
              precision    recall  f1-score   support

           0       0.86      0.98      0.91       408
           1       0.87      0.51      0.64       134

    accuracy                           0.86       542
   macro avg       0.86      0.74      0.78       542
weighted avg       0.86      0.86      0.85       542

MACCS
Classification Report:
              precision    recall  f1-score   support

           0       0.87      0.97      0.92       408
           1       0.84      0.56      0.67       134

    accuracy                           0.87       542
   macro avg       0.86      0.76      0.79       542
weighted avg       0.86      0.87      0.86       542

RDKIT
Classification Report:
              precision    recall  f1-score   support

           0       0.88      0.97      0.92       408
           1       0.85      0.58      0.69       134

    accuracy                           0.87       542
   macro avg       0.

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


GNN_FP_2D
Classification Report:
              precision    recall  f1-score   support

           0       0.84      0.98      0.90       408
           1       0.86      0.45      0.59       134

    accuracy                           0.85       542
   macro avg       0.85      0.71      0.75       542
weighted avg       0.85      0.85      0.83       542

GNN_FP_3D
Classification Report:
              precision    recall  f1-score   support

           0       0.81      0.99      0.89       408
           1       0.90      0.28      0.43       134

    accuracy                           0.82       542
   macro avg       0.86      0.64      0.66       542
weighted avg       0.83      0.82      0.78       542

CHEMBL214_5HT1A.csv




ECFP
Classification Report:
              precision    recall  f1-score   support

           0       0.80      0.76      0.78       549
           1       0.81      0.85      0.83       674

    accuracy                           0.81      1223
   macro avg       0.81      0.80      0.80      1223
weighted avg       0.81      0.81      0.81      1223

MACCS
Classification Report:
              precision    recall  f1-score   support

           0       0.79      0.77      0.78       549
           1       0.82      0.83      0.82       674

    accuracy                           0.80      1223
   macro avg       0.80      0.80      0.80      1223
weighted avg       0.80      0.80      0.80      1223

RDKIT
Classification Report:
              precision    recall  f1-score   support

           0       0.80      0.76      0.78       549
           1       0.81      0.85      0.83       674

    accuracy                           0.81      1223
   macro avg       0.81      0.80      0.8



RDF
Classification Report:
              precision    recall  f1-score   support

           0       0.75      0.53      0.62       549
           1       0.69      0.85      0.76       674

    accuracy                           0.71      1223
   macro avg       0.72      0.69      0.69      1223
weighted avg       0.72      0.71      0.70      1223

RANDOM
Classification Report:
              precision    recall  f1-score   support

           0       0.54      0.07      0.12       549
           1       0.56      0.95      0.70       674

    accuracy                           0.56      1223
   macro avg       0.55      0.51      0.41      1223
weighted avg       0.55      0.56      0.44      1223





GNN_FP_2D
Classification Report:
              precision    recall  f1-score   support

           0       0.74      0.72      0.73       549
           1       0.77      0.79      0.78       673

    accuracy                           0.76      1222
   macro avg       0.76      0.75      0.76      1222
weighted avg       0.76      0.76      0.76      1222





GNN_FP_3D
Classification Report:
              precision    recall  f1-score   support

           0       0.71      0.62      0.66       549
           1       0.72      0.79      0.75       673

    accuracy                           0.71      1222
   macro avg       0.71      0.71      0.71      1222
weighted avg       0.71      0.71      0.71      1222

CHEMBL224_5HT2A.csv
ECFP
Classification Report:
              precision    recall  f1-score   support

           0       0.86      0.95      0.90       704
           1       0.92      0.80      0.86       532

    accuracy                           0.89      1236
   macro avg       0.89      0.87      0.88      1236
weighted avg       0.89      0.89      0.88      1236

MACCS
Classification Report:
              precision    recall  f1-score   support

           0       0.84      0.94      0.89       704
           1       0.90      0.77      0.83       532

    accuracy                           0.86      1236
   macro avg      