In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# Imports
import os
import json
import csv
import pandas as pd
from utils.download import download_3d_similar_molecules
from utils.chem import compute_3d_similarity, extract_features

In [6]:
# Paths
notebook = os.path.join(".")
temp = os.path.join(notebook, ".temp") # use to download temporary files (temporary downloads).
if not os.path.exists(temp):
    os.makedirs(temp)

molecule_name = "remdesivir"
temp = os.path.join(temp, molecule_name)

if not os.path.exists(temp):
    os.makedirs(temp)

reinvent_prior_path = os.path.join(notebook, '..', 'models', 'reinvent.prior')

In [7]:
# Download Smiles
input_smiles = "CC1(OC2C(OC(C2O1)(C#N)C3=CC=C4N3N=CN=C4N)CO)C"
similar_str_smiles_path = os.path.join(temp, "similar.json")

download_3d_similar_molecules(input_smiles, similar_str_smiles_path)

True

In [8]:
# Reading the downloaded 3d similar structures.

similar_str_smiles = {}
with open(similar_str_smiles_path) as reader:
    similar_str_smiles = json.load(reader)

df = pd.DataFrame(data=similar_str_smiles)
df

Unnamed: 0,smiles,identifier,similarity
0,CC1(C)O[C@H]2[C@H](n3c(Br)nc4c3ncnc4N)O[C@H](C...,ZINC15 : ZINC000017381098,0.369863
1,CC1(C)O[C@H]2[C@@H](CO)O[C@@H](n3c(Br)nc4c3ncn...,ZINC15 : ZINC000095949869,0.369863
2,CC1(C)O[C@@H]2[C@H](CO)O[C@@H](n3cnc4c3ncnc4N)...,ZINC15 : ZINC000100807906,0.366197
3,CC1(C)O[C@H]2[C@H](n3cnc4c3ncnc4N)O[C@@H](CO)[...,ZINC15 : ZINC000012958516,0.366197
4,CC1(C)O[C@H]2[C@H](n3cnc4c3ncnc4N)O[C@H](CO)[C...,ZINC15 : ZINC000004347645,0.366197
...,...,...,...
395,CC1(C)O[C@@H]2[C@@H](O1)[C@@H](CO)O[C@@H]2n1cn...,ZINC15 : ZINC000008955192,0.285714
396,CC1(C)O[C@@H]2[C@@H](CO)O[C@@H](n3cnc4c3nc[nH]...,ZINC15 : ZINC000101133086,0.285714
397,CC1(C)O[C@@H]2[C@@H](CO)O[C@H](n3cnc4c3nc[nH]c...,ZINC15 : ZINC000004538849,0.285714
398,CC1(C)O[C@@H]2[C@@H](O1)[C@H](CO)O[C@@H]2n1cnc...,ZINC15 : ZINC000004538848,0.285714


In [9]:
# Training (80%), Validation(10%), Test(10%)
df = df.sample(frac=1)

# Define your split sizes
train_size = int(0.8 * len(df))

# Split your DataFrame
train_df = df[:train_size]
valid_df = df[train_size:]

train_set_file = os.path.join(temp, 'training.smi')
valid_set_file = os.path.join(temp, 'validation.smi')


train_df.to_csv(train_set_file, sep="\t", index=False, header=False)
valid_df.to_csv(valid_set_file, sep="\t", index=False, header=False)

In [10]:
# Transfer learning config. (Ref: https://github.com/MolecularAI/REINVENT4/blob/main/notebooks/Reinvent_TLRL.py)

config_filename = os.path.join(temp, 'config.json')
temp_models = os.path.join(temp, 'checkpoints')

if not os.path.exists(temp_models):
    os.mkdir(temp_models)

new_model_path = os.path.join(temp_models, 'temp.model')

reinvet_transfer_learning_parameter = {
    "run_type": "transfer_learning",
    "device": "cpu",
    "tb_logdir": os.path.join(temp, 'tb_TL'),
    "parameters": {
        "num_epochs": 100,
        "save_every_n_epochs": 2,
        "batch_size": 50,
        "sample_batch_size": 100,
        "input_model_file": reinvent_prior_path,
        "output_model_file": new_model_path,
        "smiles_file": train_set_file,
        "validation_smiles_file": valid_set_file,
        "standardize_smiles": True,
        "randomize_smiles": False,
        "randomize_all_smiles": False,
        "internal_diversity": True,
    },
}

with open(config_filename, "w") as writer:
    json.dump(reinvet_transfer_learning_parameter, writer, indent=2)

In [11]:
# Transfer Learning.

!reinvent $config_filename -f json

19:38:37 <INFO> Started REINVENT 4.4.22 (C) AstraZeneca 2017, 2023 on 2024-07-22
19:38:37 <INFO> Command line: /root/miniconda3/envs/reinvent-transfer-learning/bin/reinvent ./.temp/remdesivir/config.json -f json
19:38:37 <INFO> User root on host Ank
19:38:37 <INFO> Python version 3.11.9
19:38:37 <INFO> PyTorch version 2.3.1+cu121, git d44533f9d073df13895333e70b66f81c513c1889
19:38:37 <INFO> PyTorch compiled with CUDA version 12.1
19:38:37 <INFO> RDKit version 2023.09.5
19:38:37 <INFO> Platform Linux-5.15.153.1-microsoft-standard-WSL2-x86_64-with-glibc2.35
19:38:37 <INFO> Number of PyTorch CUDA devices 1
19:38:37 <INFO> Using CPU x86_64
19:38:37 <INFO> Writing TensorBoard summary to /mnt/d/projects/github/reinvent-transfer-learning/notebooks/.temp/remdesivir/tb_TL
19:38:37 <INFO> Starting Transfer Learning
19:38:37 <INFO> /mnt/d/projects/github/reinvent-transfer-learning/models/reinvent.prior has valid hash:
{ 'comments': [],
  'creation_date': 0,
  'date_format': 'UNIX epoch',
  'hash_

In [12]:
# Running new model
new_model_config_path = os.path.join(temp, '_config.json')
output_smiles = os.path.join(temp, 'output.csv')
config = {
    "run_type": "sampling",
    "device": "cpu",
    "parameters": {
        "model_file": new_model_path,
        "output_file": output_smiles,
        "num_smiles": 500,
        "unique_molecules": True,
        "randomize_smiles": True,
    }
}

with open(new_model_config_path, "w") as writer:
    json.dump(config, writer, indent=2)

!reinvent $new_model_config_path -f json

19:41:31 <INFO> Started REINVENT 4.4.22 (C) AstraZeneca 2017, 2023 on 2024-07-22
19:41:31 <INFO> Command line: /root/miniconda3/envs/reinvent-transfer-learning/bin/reinvent ./.temp/remdesivir/_config.json -f json
19:41:31 <INFO> User root on host Ank
19:41:31 <INFO> Python version 3.11.9
19:41:31 <INFO> PyTorch version 2.3.1+cu121, git d44533f9d073df13895333e70b66f81c513c1889
19:41:31 <INFO> PyTorch compiled with CUDA version 12.1
19:41:31 <INFO> RDKit version 2023.09.5
19:41:31 <INFO> Platform Linux-5.15.153.1-microsoft-standard-WSL2-x86_64-with-glibc2.35
19:41:31 <INFO> Number of PyTorch CUDA devices 1
19:41:31 <INFO> Using CPU x86_64
19:41:31 <INFO> Starting Sampling
19:41:31 <INFO> /mnt/d/projects/github/reinvent-transfer-learning/notebooks/.temp/remdesivir/checkpoints/temp.model has valid hash:
{ 'comments': ['TL'],
  'creation_date': 0,
  'date_format': 'UNIX epoch',
  'hash_id': 'c5b16ad84d26ec1cad7daa01a42e793b',
  'hash_id_format': 'xxhash.xxh3_128_hex 3.4.1',
  'model_id': '5

In [26]:
%%time
# Attempting to filter out non-similar molecule without using ML
# In this apprach we will try to calculate RMSD of all the molecules generated
# by our new model. If RMSD is heigher than 2 then we will ignore the molecules.

entries = []

with open(output_smiles) as reader:
    rows = csv.reader(reader)
    next(rows, None) # Skipping header
    for row in rows:
        entries.append(row)

scores = []

for entry in entries:
    score = compute_3d_similarity(input_smiles, entry[0])
    if score[0] is False or score[1] > 2:
        continue
    scores.append({"smile": entry[0], "score": score[1]})

without_ml = os.path.join(temp, 'without_ml.csv')
df = pd.DataFrame(data=scores)

df.to_csv(without_ml, index=False)
df

CPU times: user 7.03 s, sys: 0 ns, total: 7.03 s
Wall time: 7.04 s


Unnamed: 0,smile,score
0,CC1(C)CN(c2cc(N)nc3ccnn23)CC(CO)O1,1.506041
1,CC1(C)OC2C(CO)OC(n3cnc(C(N)=O)c3N)C2O1,1.623593
2,CC1(C)CC(CN)C2CC(C1)C2(C)C,1.797067
3,CC1(C)CN(c2cc(N)n3ncnc3n2)CC(CO)O1,1.406245
4,CC(=O)OCC1OC(n2c(Br)nc3c(N)ncnc32)C2OC(C)(C)OC12,1.626120
...,...,...
129,COCC(O)Cn1cc(CN2CC(C)OC(C)C2)nn1,1.853355
130,CCOC1OC(COc2ncnc3ccnn23)C2OC(C)(C)OC12,1.461097
131,CC1(C)C(C(=O)Nc2ccccc2)N2C(=O)C(=Cc3ccccn3)C2S...,1.836153
132,CC1(C)CC(CC(N)C(=O)O)OC(C)(C)O1,1.370143


# Traning a new classifier model

In [159]:
# Generating Random molecules
# We will use reinent to generate random molecules.

# reinvent_path = os.path.join(notebook, '..', 'models', 'reinvent.prior')
# random_config_path = os.path.join(temp, 'random_config.json')
# out = os.path.join(temp, 'random.csv')
# config = {
#     "run_type": "sampling",
#     "device": "cpu",
#     "parameters": {
#         "model_file": reinvent_path,
#         "output_file": out,
#         "num_smiles": 5000,
#         "unique_molecules": True,
#         "randomize_smiles": True,
#     }
# }

# with open(random_config_path, "w") as writer:
#     json.dump(config, writer, indent=2)


# !reinvent $random_config_path -f json

20:07:22 <INFO> Started REINVENT 4.4.22 (C) AstraZeneca 2017, 2023 on 2024-07-15
20:07:22 <INFO> Command line: /root/miniconda3/envs/reinvent-transfer-learning/bin/reinvent ./.temp/cephalotaxin/random_config.json -f json
20:07:22 <INFO> User root on host Ank
20:07:22 <INFO> Python version 3.11.9
20:07:22 <INFO> PyTorch version 2.3.1+cu121, git d44533f9d073df13895333e70b66f81c513c1889
20:07:22 <INFO> PyTorch compiled with CUDA version 12.1
20:07:22 <INFO> RDKit version 2023.09.5
20:07:22 <INFO> Platform Linux-5.15.153.1-microsoft-standard-WSL2-x86_64-with-glibc2.35
20:07:22 <INFO> Number of PyTorch CUDA devices 1
20:07:22 <INFO> Using CPU x86_64
20:07:22 <INFO> Starting Sampling
20:07:22 <INFO> /mnt/d/projects/github/reinvent-transfer-learning/models/reinvent.prior has valid hash:
{ 'comments': [],
  'creation_date': 0,
  'date_format': 'UNIX epoch',
  'hash_id': '173568c36e1fc3d95cab289c7d31ce0b',
  'hash_id_format': 'xxhash.xxh3_128_hex 3.4.1',
  'model_id': '55d68f8a81c04f5a86304ebe1

In [27]:
%%time
negative_entries = []
# Make sure to download and keep the reference library in .temp folder.
# Download it from here: https://github.com/ersilia-os/groverfeat/blob/main/data/reference_library.csv
reference_smiles_path = os.path.join(temp, 'reference_library.csv')

with open(reference_smiles_path) as reader:
    rows = csv.reader(reader)
    for row in rows:
        negative_entries.append(row)

filtered = []


for idx, entry in enumerate(negative_entries):
    print(f"Current Idx: {idx}. Total len: {len(filtered)}")
    if len(filtered) > 400:
        break
    score = compute_3d_similarity(input_smiles, entry[0])
    if score[0] is False or score[1] < 2.5:
        continue
    filtered.append({"smile": entry[0], "score": score[1]})

negative_output = os.path.join(temp, 'negative.csv')
df = pd.DataFrame(data=filtered)
df.to_csv(negative_output, index=False)
df

Current Idx: 0. Total len: 0
Current Idx: 1. Total len: 0
Current Idx: 2. Total len: 0
Current Idx: 3. Total len: 1
Current Idx: 4. Total len: 2
Current Idx: 5. Total len: 2
Current Idx: 6. Total len: 3
Current Idx: 7. Total len: 3
Current Idx: 8. Total len: 4
Current Idx: 9. Total len: 4
Current Idx: 10. Total len: 5
Current Idx: 11. Total len: 5
Current Idx: 12. Total len: 5
Current Idx: 13. Total len: 5
Current Idx: 14. Total len: 5
Current Idx: 15. Total len: 5
Current Idx: 16. Total len: 6
Current Idx: 17. Total len: 6
Current Idx: 18. Total len: 7
Current Idx: 19. Total len: 7
Current Idx: 20. Total len: 7
Current Idx: 21. Total len: 7
Current Idx: 22. Total len: 7
Current Idx: 23. Total len: 7
Current Idx: 24. Total len: 8
Current Idx: 25. Total len: 8
Current Idx: 26. Total len: 8
Current Idx: 27. Total len: 8
Current Idx: 28. Total len: 8
Current Idx: 29. Total len: 8
Current Idx: 30. Total len: 8
Current Idx: 31. Total len: 9
Current Idx: 32. Total len: 9
Current Idx: 33. Tot

[19:57:50] UFFTYPER: Unrecognized charge state for atom: 25


Current Idx: 273. Total len: 84
Current Idx: 274. Total len: 84
Current Idx: 275. Total len: 84
Current Idx: 276. Total len: 84
Current Idx: 277. Total len: 84
Current Idx: 278. Total len: 84
Current Idx: 279. Total len: 84
Current Idx: 280. Total len: 84
Current Idx: 281. Total len: 85
Current Idx: 282. Total len: 85
Current Idx: 283. Total len: 86
Current Idx: 284. Total len: 87
Current Idx: 285. Total len: 87
Current Idx: 286. Total len: 87
Current Idx: 287. Total len: 87
Current Idx: 288. Total len: 88
Current Idx: 289. Total len: 88
Current Idx: 290. Total len: 88
Current Idx: 291. Total len: 88
Current Idx: 292. Total len: 88
Current Idx: 293. Total len: 89
Current Idx: 294. Total len: 90
Current Idx: 295. Total len: 90
Current Idx: 296. Total len: 90
Current Idx: 297. Total len: 90
Current Idx: 298. Total len: 91
Current Idx: 299. Total len: 91
Current Idx: 300. Total len: 91
Current Idx: 301. Total len: 91
Current Idx: 302. Total len: 91
Current Idx: 303. Total len: 91
Current 

[19:58:10] UFFTYPER: Unrecognized charge state for atom: 8


Current Idx: 587. Total len: 164
Current Idx: 588. Total len: 164
Current Idx: 589. Total len: 164
Current Idx: 590. Total len: 164
Current Idx: 591. Total len: 164
Current Idx: 592. Total len: 164
Current Idx: 593. Total len: 165
Current Idx: 594. Total len: 165
Current Idx: 595. Total len: 165
Current Idx: 596. Total len: 165
Current Idx: 597. Total len: 165
Current Idx: 598. Total len: 166
Current Idx: 599. Total len: 167
Current Idx: 600. Total len: 167
Current Idx: 601. Total len: 167
Current Idx: 602. Total len: 167
Current Idx: 603. Total len: 168
Current Idx: 604. Total len: 168
Current Idx: 605. Total len: 168
Current Idx: 606. Total len: 168
Current Idx: 607. Total len: 168
Current Idx: 608. Total len: 168
Current Idx: 609. Total len: 168
Current Idx: 610. Total len: 168


[19:58:12] UFFTYPER: Unrecognized charge state for atom: 5


Current Idx: 611. Total len: 169
Current Idx: 612. Total len: 170


[19:58:12] UFFTYPER: Unrecognized charge state for atom: 20


Current Idx: 613. Total len: 171
Current Idx: 614. Total len: 172
Current Idx: 615. Total len: 172
Current Idx: 616. Total len: 172
Current Idx: 617. Total len: 173
Current Idx: 618. Total len: 173
Current Idx: 619. Total len: 173
Current Idx: 620. Total len: 174
Current Idx: 621. Total len: 174
Current Idx: 622. Total len: 174
Current Idx: 623. Total len: 174
Current Idx: 624. Total len: 174
Current Idx: 625. Total len: 174
Current Idx: 626. Total len: 174
Current Idx: 627. Total len: 174
Current Idx: 628. Total len: 175
Current Idx: 629. Total len: 175
Current Idx: 630. Total len: 175
Current Idx: 631. Total len: 176
Current Idx: 632. Total len: 177
Current Idx: 633. Total len: 178
Current Idx: 634. Total len: 178
Current Idx: 635. Total len: 178
Current Idx: 636. Total len: 179
Current Idx: 637. Total len: 179
Current Idx: 638. Total len: 179
Current Idx: 639. Total len: 179
Current Idx: 640. Total len: 179
Current Idx: 641. Total len: 180
Current Idx: 642. Total len: 180
Current Id

[19:58:18] UFFTYPER: Unrecognized charge state for atom: 1


Current Idx: 701. Total len: 192
Current Idx: 702. Total len: 192
Current Idx: 703. Total len: 192
Current Idx: 704. Total len: 193
Current Idx: 705. Total len: 194
Current Idx: 706. Total len: 194
Current Idx: 707. Total len: 194
Current Idx: 708. Total len: 195
Current Idx: 709. Total len: 195
Current Idx: 710. Total len: 195
Current Idx: 711. Total len: 195
Current Idx: 712. Total len: 196
Current Idx: 713. Total len: 196
Current Idx: 714. Total len: 196
Current Idx: 715. Total len: 196
Current Idx: 716. Total len: 197
Current Idx: 717. Total len: 198
Current Idx: 718. Total len: 198
Current Idx: 719. Total len: 198
Current Idx: 720. Total len: 198
Current Idx: 721. Total len: 199
Current Idx: 722. Total len: 199
Current Idx: 723. Total len: 199
Current Idx: 724. Total len: 199
Current Idx: 725. Total len: 199
Current Idx: 726. Total len: 199
Current Idx: 727. Total len: 200
Current Idx: 728. Total len: 200
Current Idx: 729. Total len: 200
Current Idx: 730. Total len: 200
Current Id

[19:58:37] UFFTYPER: Unrecognized atom type: Se2+2 (5)
[19:58:37] UFFTYPER: Unrecognized atom type: Se2+2 (31)


Current Idx: 975. Total len: 265
Current Idx: 976. Total len: 265
Current Idx: 977. Total len: 266
Current Idx: 978. Total len: 267
Current Idx: 979. Total len: 267
Current Idx: 980. Total len: 267
Current Idx: 981. Total len: 267
Current Idx: 982. Total len: 267
Current Idx: 983. Total len: 267
Current Idx: 984. Total len: 267
Current Idx: 985. Total len: 268
Current Idx: 986. Total len: 269
Current Idx: 987. Total len: 269
Current Idx: 988. Total len: 269
Current Idx: 989. Total len: 270
Current Idx: 990. Total len: 270
Current Idx: 991. Total len: 270
Current Idx: 992. Total len: 270
Current Idx: 993. Total len: 270
Current Idx: 994. Total len: 270
Current Idx: 995. Total len: 270
Current Idx: 996. Total len: 270
Current Idx: 997. Total len: 270
Current Idx: 998. Total len: 270
Current Idx: 999. Total len: 270
Current Idx: 1000. Total len: 270
Current Idx: 1001. Total len: 271
Current Idx: 1002. Total len: 271
Current Idx: 1003. Total len: 272
Current Idx: 1004. Total len: 272
Curre

[19:59:06] UFFTYPER: Unrecognized charge state for atom: 12


Current Idx: 1200. Total len: 344
Current Idx: 1201. Total len: 344
Current Idx: 1202. Total len: 345
Current Idx: 1203. Total len: 345
Current Idx: 1204. Total len: 345
Current Idx: 1205. Total len: 345
Current Idx: 1206. Total len: 346
Current Idx: 1207. Total len: 346
Current Idx: 1208. Total len: 346
Current Idx: 1209. Total len: 347
Current Idx: 1210. Total len: 347
Current Idx: 1211. Total len: 348
Current Idx: 1212. Total len: 348
Current Idx: 1213. Total len: 348
Current Idx: 1214. Total len: 348
Current Idx: 1215. Total len: 349
Current Idx: 1216. Total len: 349
Current Idx: 1217. Total len: 349
Current Idx: 1218. Total len: 349
Current Idx: 1219. Total len: 350
Current Idx: 1220. Total len: 350
Current Idx: 1221. Total len: 350
Current Idx: 1222. Total len: 350
Current Idx: 1223. Total len: 350
Current Idx: 1224. Total len: 350
Current Idx: 1225. Total len: 350
Current Idx: 1226. Total len: 350
Current Idx: 1227. Total len: 351
Current Idx: 1228. Total len: 351
Current Idx: 1

[19:59:11] UFFTYPER: Unrecognized atom type: Se2+2 (26)


Current Idx: 1271. Total len: 360
Current Idx: 1272. Total len: 361
Current Idx: 1273. Total len: 362
Current Idx: 1274. Total len: 362
Current Idx: 1275. Total len: 362
Current Idx: 1276. Total len: 363
Current Idx: 1277. Total len: 364
Current Idx: 1278. Total len: 364
Current Idx: 1279. Total len: 364
Current Idx: 1280. Total len: 364
Current Idx: 1281. Total len: 364
Current Idx: 1282. Total len: 364
Current Idx: 1283. Total len: 364
Current Idx: 1284. Total len: 364
Current Idx: 1285. Total len: 365
Current Idx: 1286. Total len: 365
Current Idx: 1287. Total len: 365
Current Idx: 1288. Total len: 365
Current Idx: 1289. Total len: 366
Current Idx: 1290. Total len: 366
Current Idx: 1291. Total len: 367
Current Idx: 1292. Total len: 367
Current Idx: 1293. Total len: 367
Current Idx: 1294. Total len: 367
Current Idx: 1295. Total len: 368
Current Idx: 1296. Total len: 368
Current Idx: 1297. Total len: 368
Current Idx: 1298. Total len: 368
Current Idx: 1299. Total len: 368
Current Idx: 1

Unnamed: 0,smile,score
0,CC1=C(S(=O)(=O)N2CCCCC2)C2=C(S1)N=CN(CC(=O)N1C...,2.709123
1,CN(C)CCOC1=CC=C(C(=O)/C=C/C2=CC=C(OC3=CC=CC=C3...,3.172030
2,CC1=CC=C(COC2=NN(CN3CCOCC3)C(=S)N2/N=C/C2=CNN=...,2.610649
3,CC(C)(C)C(=O)NC1=NC=NC2=C1C=NN2CCCCN1CCCCCC1,2.562486
4,CC(=O)[C@H]1CC[C@H]2[C@@H]3CCC4=C[C@@H](OC(=O)...,2.509299
...,...,...
396,COC1=CC=CC(N2C=C(NC(=O)CN3C=CC(C4=CC=CC=N4)=N3...,2.899857
397,O=C(O)CC1CC(CNC(=O)CCCCNC2=CC=CC=N2)=CCC2=CC=C...,2.798341
398,NC1=CC(C(=O)N[C@@H](CC2=CNC3=CC=CC=C23)C(=O)O)...,3.522521
399,COC1=CC=C(NC(=O)C2=CC3=CC=C4C5=CC=CC=C5NC4=C3C...,2.623829


In [31]:
# Prepare data for classifier model
negatives = []
positives = []

# Positives are all the smiles downloaded using cheese api
with open(similar_str_smiles_path) as reader:
    data = json.load(reader)
    
    for entry in data:
        positives.append(entry['smiles'])

# Negatives are all the smiles generated using reinvent
with open(negative_output) as reader:
    data = csv.reader(reader)
    next(data, None) # skipping header
    for row in data:
        negatives.append(row[0])

labels = [1] * len(positives)
labels += [0] * len(negatives)

total = positives + negatives
features = extract_features(total)

array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 1., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 1., ..., 0., 0., 0.]])

In [29]:
# Training
from flaml import AutoML
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

X = pd.DataFrame(features)
y = pd.Series(labels)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

automl = AutoML()
automl_settings = {
    "time_budget": 90,  # time budget in seconds
    "metric": 'roc_auc',  # metric to optimize
    "task": 'classification',
    "estimator_list": ['rf'] # random forest
}

automl.fit(X_train, y_train, **automl_settings)

# Checking accuracy
y_pred = automl.predict(X_test)
print(f"Accuracy: {accuracy_score(y_test, y_pred)}")

[flaml.automl.logger: 07-22 20:00:58] {1680} INFO - task = classification
[flaml.automl.logger: 07-22 20:00:58] {1691} INFO - Evaluation method: holdout
[flaml.automl.logger: 07-22 20:00:59] {1789} INFO - Minimizing error metric: 1-roc_auc
[flaml.automl.logger: 07-22 20:00:59] {1901} INFO - List of ML learners in AutoML Run: ['rf']
[flaml.automl.logger: 07-22 20:00:59] {2219} INFO - iteration 0, current learner rf
[flaml.automl.logger: 07-22 20:00:59] {2345} INFO - Estimated sufficient time budget=1490s. Estimated necessary time budget=1s.
[flaml.automl.logger: 07-22 20:00:59] {2392} INFO -  at 2.6s,	estimator rf's best error=0.0046,	best estimator rf's best error=0.0046
[flaml.automl.logger: 07-22 20:00:59] {2219} INFO - iteration 1, current learner rf
[flaml.automl.logger: 07-22 20:00:59] {2392} INFO -  at 2.7s,	estimator rf's best error=0.0000,	best estimator rf's best error=0.0000
[flaml.automl.logger: 07-22 20:00:59] {2219} INFO - iteration 2, current learner rf
[flaml.automl.logg

In [33]:
# Testing generated output (the output generated by the new model.)

gen_out = []

with open(output_smiles) as reader:
    data = csv.reader(reader)
    next(data, None) # skipping header
    for row in data:
        gen_out.append(row[0])

gen_out_feature = extract_features(gen_out)
gen_out_feature = pd.DataFrame(gen_out_feature)

prediction = automl.predict_proba(gen_out_feature) # predict_proba
prediction

array([[0.03617992, 0.96382008],
       [0.4962926 , 0.5037074 ],
       [0.0465105 , 0.9534895 ],
       [0.46301664, 0.53698336],
       [0.45386519, 0.54613481],
       [0.59349425, 0.40650575],
       [0.03617992, 0.96382008],
       [0.18613957, 0.81386043],
       [0.45386519, 0.54613481],
       [0.68142389, 0.31857611],
       [0.45386519, 0.54613481],
       [0.25627592, 0.74372408],
       [0.56547009, 0.43452991],
       [0.87376158, 0.12623842],
       [0.24774555, 0.75225445],
       [0.0465105 , 0.9534895 ],
       [0.04317293, 0.95682707],
       [0.49928104, 0.50071896],
       [0.4962926 , 0.5037074 ],
       [0.24774555, 0.75225445],
       [0.03259569, 0.96740431],
       [0.182802  , 0.817198  ],
       [0.66328077, 0.33671923],
       [0.04317293, 0.95682707],
       [0.66328077, 0.33671923],
       [0.38376322, 0.61623678],
       [0.47200831, 0.52799169],
       [0.59322043, 0.40677957],
       [0.04317293, 0.95682707],
       [0.24904591, 0.75095409],
       [0.

In [34]:
ml_prediction_output = os.path.join(temp, 'ml_output.csv')

output = []
for idx, mol in enumerate(gen_out):
    if prediction[idx][1] > 0.8:
        continue
    output.append({"smile": gen_out[idx], "score": 0 })


with_ml_df = pd.DataFrame(data=output)
with_ml_df.to_csv(ml_prediction_output, index=False, header=False)
with_ml_df

Unnamed: 0,smile,score
0,C=CC1OC(Oc2cnc3ccccc3c2C#N)C2OC(C)(C)OC12,0
1,C=CC1OC(Oc2nccc3ncnn23)C2OC(C)(C)OC12,0
2,CC1(C)CC(CN)C2CC(C1)C2(C)C,0
3,COCC12CCN(Cc3ccoc3)CC1CN(Cc1ccncc1)C2,0
4,CC1(C)CC(CC(N)=O)N(c2cncc3ncnn23)C1,0
...,...,...
115,C=CC1OC(Oc2ncnc3c2ncn3C2OC(CO)C(O)C2O)C(O)C1O,0
116,CCOC1OC(COc2ncnc3ccnn23)C2OC(C)(C)OC12,0
117,CC1(C)C(C(=O)Nc2ccccc2)N2C(=O)C(=Cc3ccccn3)C2S...,0
118,CC1(C)CC(CC(N)C(=O)O)OC(C)(C)O1,0


In [35]:
# Overlapped
without_ml_df = pd.read_csv(without_ml)

with_ml_df['type'] = 'with_ml'
without_ml_df['type'] = 'without_ml'

merged = pd.concat([with_ml_df, without_ml_df])
overlapped = merged.duplicated(subset=['smile'])
                               
merged[overlapped]

Unnamed: 0,smile,score,type
2,CC1(C)CC(CN)C2CC(C1)C2(C)C,1.797067,without_ml
5,CC1(C)CC(CC(N)=O)N(c2cncc3ncnn23)C1,1.745158,without_ml
6,CC1(C)OC2C(=O)OC(COc3ccnc4c(F)cccc34)C2O1,1.767933,without_ml
7,CC1(C)OC(=O)C(CO)O1,1.832699,without_ml
8,Nc1ncnc2c1c1cncnc1n2C1OC(CO)C(O)C1O,1.577912,without_ml
...,...,...,...
129,COCC(O)Cn1cc(CN2CC(C)OC(C)C2)nn1,1.853355,without_ml
130,CCOC1OC(COc2ncnc3ccnn23)C2OC(C)(C)OC12,1.461097,without_ml
131,CC1(C)C(C(=O)Nc2ccccc2)N2C(=O)C(=Cc3ccccn3)C2S...,1.836153,without_ml
132,CC1(C)CC(CC(N)C(=O)O)OC(C)(C)O1,1.370143,without_ml
