In [1]:
import pandas as pd
import numpy as np
import json

from sklearn.ensemble import RandomForestClassifier
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from sklearn.preprocessing import StandardScaler
import torch

import utils

from rdkit import rdBase
rdBase.DisableLog('rdApp.error') 

from dotenv import load_dotenv
load_dotenv()

import sys
import os
parent_dir = os.path.abspath(os.path.join(os.getcwd(), "../.."))
sys.path.append(parent_dir)

from model.model import GNNFingerprint3D

In [2]:
data_path = os.getenv("DATA_PATH")
models_path = os.getenv("MODELS_PATH")

In [3]:
def read_data(path, ki_threshold):
    df = pd.read_csv(path, sep=";")
    df['Activity'] = df['Standard Value'].apply(lambda x: 1 if x < ki_threshold else 0)

    df = df[['Smiles', 'Activity']].dropna()

    X_train, X_test, y_train, y_test = train_test_split(
        df['Smiles'], df['Activity'], test_size=0.2, random_state=42
    )

    return X_train.reset_index(drop=True), X_test.reset_index(drop=True), y_train.reset_index(drop=True), y_test.reset_index(drop=True)

In [4]:
def get_score(X_train, y_train, X_test, y_test):
    s_scaler = StandardScaler()
    X_train = s_scaler.fit_transform(X_train)
    X_test = s_scaler.transform(X_test)

    # PCA to 167 dim
    pca = PCA(n_components=167)
    X_train = pca.fit_transform(X_train)
    X_test = pca.transform(X_test)

    model = RandomForestClassifier()
    model.fit(X_train, y_train)
    y_pred = model.predict(X_test)

    print("Classification Report:")
    print(classification_report(y_test, y_pred))

In [9]:
fingerprint_model = GNNFingerprint3D(13, 5)
fingerprint_model.load_state_dict(torch.load(os.path.join(os.getenv("MODELS_PATH"), "GNN_MUCH_MORE_WEIGHT_3D.pth")))
fingerprint_model = fingerprint_model.to("cuda")
fingerprint_model.eval()

with open(os.path.join(data_path, "means_and_stds.json")) as f:
    scaler = json.load(f)

  fingerprint_model.load_state_dict(torch.load(os.path.join(os.getenv("MODELS_PATH"), "GNN_MUCH_MORE_WEIGHT_3D.pth")))


In [10]:
folder = os.path.join(data_path, "CHEMBL")

for db in os.listdir(folder):
    print("============================================")
    print(db)
    print("============================================")
    
    data = os.path.join(folder, db)
    X_train, X_test, y_train, y_test = read_data(data, 100)

    train_filtered = [(s, y) for s, y in zip(X_train, y_train) if getattr(utils, "is_valid_smiles")(s)]
    test_filtered = [(s, y) for s, y in zip(X_test, y_test) if getattr(utils, "is_valid_smiles")(s)]
    X_train, y_train = zip(*train_filtered) if train_filtered else ([], [])
    X_test, y_test = zip(*test_filtered) if test_filtered else ([], [])

    X_train, y_train = list(X_train), list(y_train)
    X_test, y_test = list(X_test), list(y_test)

    for fingerprint in ("ecfp", "maccs", "rdkit", "rdf", "random", "3D"):
        func_name = "smiles_to_" + fingerprint

        if fingerprint == "3D":
            X_train_prep = np.array([getattr(utils, func_name)(smiles, fingerprint_model, scaler, False).detach().cpu() for smiles in X_train])
            X_test_prep = np.array([getattr(utils, func_name)(smiles, fingerprint_model, scaler, False).detach().cpu() for smiles in X_test])

            rows_all_nan_train = np.isnan(X_train_prep).all(axis=1)
            nan_indices_train = np.where(rows_all_nan_train)[0]

            rows_all_nan_test = np.isnan(X_test_prep).all(axis=1)
            nan_indices_test = np.where(rows_all_nan_test)[0]

            X_train_prep = np.delete(X_train_prep, nan_indices_train, axis=0)
            y_train = np.delete(y_train, nan_indices_train, axis=0)
            X_test_prep = np.delete(X_test_prep, nan_indices_test, axis=0)
            y_test = np.delete(y_test, nan_indices_test, axis=0)
        else:
            X_train_prep = np.array([getattr(utils, func_name)(smiles) for smiles in X_train])
            X_test_prep = np.array([getattr(utils, func_name)(smiles) for smiles in X_test])

        print(fingerprint.upper())
        get_score(X_train_prep, y_train, X_test_prep, y_test)

CHEMBL1833_5HT2B.csv
ECFP
Classification Report:
              precision    recall  f1-score   support

           0       0.87      0.98      0.92       408
           1       0.89      0.54      0.68       134

    accuracy                           0.87       542
   macro avg       0.88      0.76      0.80       542
weighted avg       0.87      0.87      0.86       542

MACCS
Classification Report:
              precision    recall  f1-score   support

           0       0.86      0.96      0.91       408
           1       0.83      0.54      0.65       134

    accuracy                           0.86       542
   macro avg       0.85      0.75      0.78       542
weighted avg       0.85      0.86      0.85       542

RDKIT
Classification Report:
              precision    recall  f1-score   support

           0       0.88      0.96      0.92       408
           1       0.84      0.59      0.69       134

    accuracy                           0.87       542
   macro avg       0.

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


3D
Classification Report:
              precision    recall  f1-score   support

           0       0.81      0.99      0.89       408
           1       0.90      0.28      0.43       134

    accuracy                           0.82       542
   macro avg       0.86      0.64      0.66       542
weighted avg       0.83      0.82      0.78       542

CHEMBL214_5HT1A.csv




ECFP
Classification Report:
              precision    recall  f1-score   support

           0       0.80      0.76      0.78       549
           1       0.81      0.84      0.83       674

    accuracy                           0.81      1223
   macro avg       0.81      0.80      0.80      1223
weighted avg       0.81      0.81      0.81      1223

MACCS
Classification Report:
              precision    recall  f1-score   support

           0       0.78      0.75      0.77       549
           1       0.80      0.83      0.82       674

    accuracy                           0.79      1223
   macro avg       0.79      0.79      0.79      1223
weighted avg       0.79      0.79      0.79      1223

RDKIT
Classification Report:
              precision    recall  f1-score   support

           0       0.81      0.74      0.78       549
           1       0.81      0.86      0.83       674

    accuracy                           0.81      1223
   macro avg       0.81      0.80      0.8



RDF
Classification Report:
              precision    recall  f1-score   support

           0       0.73      0.53      0.61       549
           1       0.69      0.84      0.76       674

    accuracy                           0.70      1223
   macro avg       0.71      0.69      0.69      1223
weighted avg       0.71      0.70      0.69      1223

RANDOM
Classification Report:
              precision    recall  f1-score   support

           0       0.44      0.03      0.06       549
           1       0.55      0.97      0.70       674

    accuracy                           0.55      1223
   macro avg       0.49      0.50      0.38      1223
weighted avg       0.50      0.55      0.41      1223





3D
Classification Report:
              precision    recall  f1-score   support

           0       0.69      0.63      0.65       549
           1       0.72      0.77      0.74       673

    accuracy                           0.70      1222
   macro avg       0.70      0.70      0.70      1222
weighted avg       0.70      0.70      0.70      1222

CHEMBL224_5HT2A.csv


KeyboardInterrupt: 