In [1]:
import os

# Work from the project root (one level above the notebook) so relative paths
# like "./data/..." resolve. Resolve the root only once: a bare os.chdir('..')
# walks one more level up every time this cell is re-run.
if "PROJECT_ROOT" not in globals():
    PROJECT_ROOT = os.path.abspath("..")
os.chdir(PROJECT_ROOT)

In [2]:
from itertools import chain, combinations

In [3]:
import numpy as np
import pandas as pd
from tqdm.autonotebook import tqdm
import networkx as nx
from rdkit import Chem
from rdkit.Chem import AllChem, Descriptors
from sklearn.metrics import classification_report
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.ensemble import RandomForestClassifier

  from tqdm.autonotebook import tqdm


In [4]:
from espora import fragmenter

In [5]:
def get_mw(smi):
    """Return the molecular weight of a SMILES string via ``Descriptors.MolWt``.

    Parameters
    ----------
    smi : str
        SMILES of a molecule or fragment.

    Returns
    -------
    float
        ``Descriptors.MolWt`` of the parsed molecule.

    Raises
    ------
    ValueError
        If RDKit cannot parse ``smi``.
    """
    mol = Chem.MolFromSmiles(smi)
    # MolFromSmiles signals a parse failure by returning None; passing that on
    # to MolWt would fail with an opaque error that hides the offending SMILES.
    if mol is None:
        raise ValueError(f"RDKit could not parse SMILES: {smi!r}")
    return Descriptors.MolWt(mol)

In [6]:
def run_fragrank_frag(smis, threhsold_mw=250):
    """Score fragments of ``smis`` by PageRank on their co-occurrence graph.

    Each SMILES is fragmented with ``fragmenter.frag_rec``. Fragments whose
    molecular weight (``get_mw``) is not strictly below ``threhsold_mw`` are
    dropped. Every surviving fragment becomes a node, and every pair of
    fragments coming from the same molecule adds 1 to the ``weight`` of the
    edge between them.

    Parameters
    ----------
    smis : iterable of str
        SMILES strings. Pass only training molecules so the fragment
        vocabulary never sees the evaluation split.
    threhsold_mw : float, default 250
        Strict upper bound on fragment molecular weight. (The misspelt name is
        kept so existing keyword callers keep working.)

    Returns
    -------
    dict
        Fragment -> score, as returned by ``nx.pagerank(G)`` (which reads the
        ``weight`` edge attribute by default).
    """
    # Materialise each result so it can safely be iterated more than once below.
    raw_fragments = [list(fragmenter.frag_rec(smi)) for smi in tqdm(smis, desc="Fragmenting")]

    # Fragments recur across many molecules: compute each distinct MW once.
    # dict.fromkeys keeps first-seen order (unlike a set of str, whose order
    # depends on PYTHONHASHSEED and so changes between kernel restarts).
    distinct = dict.fromkeys(chain.from_iterable(raw_fragments))
    mw = {frag: get_mw(frag) for frag in distinct}
    fragments = [
        [frag for frag in mol_frags if mw[frag] < threhsold_mw]
        for mol_frags in raw_fragments
    ]

    G = nx.Graph()
    # Deterministic node order => deterministic key order of the pagerank dict
    # and float summation order, so callers that sort by score (stable sort)
    # break ties identically on every run.
    G.add_nodes_from(dict.fromkeys(chain.from_iterable(fragments)))

    for mol_frags in fragments:
        # Each unordered pair (j < k) of fragments from one molecule.
        for sub1, sub2 in combinations(mol_frags, 2):
            if G.has_edge(sub1, sub2):
                G[sub1][sub2]['weight'] += 1
            else:
                G.add_edge(sub1, sub2, weight=1)

    return nx.pagerank(G)

In [7]:
def featurize(smis, top_fragments_mols):
    """Encode molecules as 0/1 substructure-presence vectors.

    Parameters
    ----------
    smis : iterable of str
        SMILES strings to featurize.
    top_fragments_mols : iterable of RDKit Mol
        Query patterns passed to ``Mol.HasSubstructMatch``.

    Returns
    -------
    numpy.ndarray of int, shape (len(smis), len(top_fragments_mols))
        Entry ``[r, c]`` is 1 iff molecule ``r`` matches pattern ``c``.

    Raises
    ------
    ValueError
        If a pattern is ``None`` or a SMILES cannot be parsed.
    """
    smis = list(smis)
    patterns = list(top_fragments_mols)
    # A None pattern (e.g. from a failed MolFromSmarts) would otherwise crash
    # inside HasSubstructMatch with an opaque error.
    if any(p is None for p in patterns):
        raise ValueError("top_fragments_mols contains None (unparsed pattern)")

    # Preallocate so empty input still gives a 2-D (0, n_patterns) array;
    # np.array([]) would give the 1-D shape (0,) that estimators reject.
    X_vec = np.zeros((len(smis), len(patterns)), dtype=int)
    for row, smi in enumerate(tqdm(smis, desc="Featurizing")):
        mol = Chem.MolFromSmiles(smi)
        if mol is None:
            raise ValueError(f"RDKit could not parse SMILES: {smi!r}")
        # bool -> int cast happens on assignment into the int array.
        X_vec[row] = [mol.HasSubstructMatch(p) for p in patterns]
    return X_vec

In [8]:
def model(clf, X_train, X_test, y_train, y_test, top_fragments_mols):
    """Featurize both splits, fit ``clf`` on train and report on test.

    Parameters
    ----------
    clf : scikit-learn style classifier
        Estimator with ``fit``/``predict``; it is fitted in place.
    X_train, X_test : iterable of str
        SMILES strings for the train and test splits.
    y_train, y_test : array-like
        Labels aligned with ``X_train`` / ``X_test``.
    top_fragments_mols : iterable of RDKit Mol
        Substructure patterns used by ``featurize``.

    Returns
    -------
    numpy.ndarray
        Test-set predictions, so callers can aggregate metrics across folds.
        The per-fold ``classification_report`` is still printed.
    """
    # Separate names: the parameters hold SMILES, these hold feature matrices.
    features_train = featurize(X_train, top_fragments_mols)
    features_test = featurize(X_test, top_fragments_mols)
    clf.fit(features_train, y_train)
    y_pred = clf.predict(features_test)
    print(classification_report(y_test, y_pred))
    return y_pred

In [17]:
# Relative to the working directory chosen by the os.chdir cell at the top.
DATA_PATH = "./data/Q16602.csv"
df = pd.read_csv(DATA_PATH)

In [18]:
# Inputs are raw SMILES strings; the target is the "100 uM" column.
X = df["canonical_smiles"].to_numpy()
y = df["100 uM"].to_numpy()

In [19]:
# Three stratified random train/test resamples (~33% test), seeded so the
# same splits are drawn on every run.
sss = StratifiedShuffleSplit(
    n_splits=3,
    test_size=0.33,
    random_state=42,
)

In [20]:
# Number of folds the evaluation loop will run (X and y are ignored by this splitter).
sss.get_n_splits()

3

In [21]:
# Per fold: rank fragments on the training split only, keep the highest-scoring
# ones as substructure queries, then fit/evaluate a random forest on
# presence/absence features.
N_TOP_FRAGMENTS = 500
n_folds = sss.get_n_splits(X, y)
for i, (train_index, test_index) in enumerate(sss.split(X, y)):
    print(f"Fold {i+1}/{n_folds}:")
    # Seeded so fold reports are reproducible across Restart & Run All.
    clf = RandomForestClassifier(random_state=42)
    pr = run_fragrank_frag(X[train_index])
    top_fragments = sorted(pr, key=pr.get, reverse=True)[:N_TOP_FRAGMENTS]
    # MolFromSmarts returns None for strings it cannot parse; skip those rather
    # than letting HasSubstructMatch fail during featurization.
    top_fragments_mols = [
        mol for mol in (Chem.MolFromSmarts(frag) for frag in top_fragments)
        if mol is not None
    ]
    n_skipped = len(top_fragments) - len(top_fragments_mols)
    if n_skipped:
        print(f"  skipped {n_skipped} fragment(s) that failed SMARTS parsing")
    model(clf, X[train_index], X[test_index], y[train_index], y[test_index], top_fragments_mols)

Fold 1/3:


Fragmenting::   0%|          | 0/67 [00:00<?, ?it/s]

featurizing:   0%|          | 0/67 [00:00<?, ?it/s]

featurizing:   0%|          | 0/33 [00:00<?, ?it/s]

              precision    recall  f1-score   support

         0.0       0.89      1.00      0.94        24
         1.0       1.00      0.67      0.80         9

    accuracy                           0.91        33
   macro avg       0.94      0.83      0.87        33
weighted avg       0.92      0.91      0.90        33

Fold 2/3:


Fragmenting::   0%|          | 0/67 [00:00<?, ?it/s]

featurizing:   0%|          | 0/67 [00:00<?, ?it/s]

featurizing:   0%|          | 0/33 [00:00<?, ?it/s]

              precision    recall  f1-score   support

         0.0       0.96      1.00      0.98        24
         1.0       1.00      0.89      0.94         9

    accuracy                           0.97        33
   macro avg       0.98      0.94      0.96        33
weighted avg       0.97      0.97      0.97        33

Fold 3/3:


Fragmenting::   0%|          | 0/67 [00:00<?, ?it/s]

featurizing:   0%|          | 0/67 [00:00<?, ?it/s]

featurizing:   0%|          | 0/33 [00:00<?, ?it/s]

              precision    recall  f1-score   support

         0.0       0.89      1.00      0.94        24
         1.0       1.00      0.67      0.80         9

    accuracy                           0.91        33
   macro avg       0.94      0.83      0.87        33
weighted avg       0.92      0.91      0.90        33

