# generate_molecule_dataset_from_csv.py
"""
Generates a molecule visual graph dataset from a CSV source dataset.
Given a molecular prediction dataset in the format of a CSV file containing the SMILES representations of the
molecules and the target value annotations, this experiment will convert it into a visual graph dataset folder.
**LOGGING**
The experiment will output log messages about the current progress in periodic intervals. The length of
these intervals can be set using an experiment parameter. The log messages contain some basic profiling
information, such as the number of completed and remaining elements. They also contain an estimate of the
average creation time and, derived from that, estimates for the remaining time and the time of completion.
**CHANGELOG**
0.1.0 - 10.12.2022 - Initial version
0.2.0 - 22.01.2023 - (1) Added the optional feature to extract the 3D atomic coordinates for each atom of
each molecule as well and save them as an additional property of the resulting graph representation.
(2) Added the possibility to extract known train-test splits from the CSV file as well.
(3) It is now possible to use a CSV file from the local system as well, instead of downloading it from
the remote file share provider.
0.3.0 - 24.02.2023 - (1) Made the printed logs more detailed and added an additional artifact which
contains all the omitted elements and the reason why they were omitted.
(2) Added support for classification datasets via an additional flag parameter.
0.4.0 - 21.03.2023 - Big preprocessing update: The processing of the SMILES strings into valid VGD elements
is no longer hard-coded in this experiment, but has been moved into the generic functionality of the
MoleculeProcessing class, which is now used here.
Most importantly: This class also enables the dynamic generation of a standalone python module which will be
shipped together with the dataset and which allows the streamlined conversion of any SMILES string into
the exact format of the VGD.
0.5.0 - 05.05.2023 - (1) Added tracking for the average write speed for each chunk of data to keep track
of an issue.
(2) Switched to the pycomex.functional.api
(3) There are some backwards incompatible changes in the processing api in the new version which have
been implemented here.
(4) optimized the process of saving the metadata which got rid of 2 unnecessary file operations.
0.6.0 - 05.06.2023 - (1) Added more default atom features from derived properties that are calculated by
RDKit such as the partial Gasteiger charges of the atoms or the crippen logP contributions
(2) Added some more runtime profiling to better keep track of the performance problem which causes this
script to become very slow after a certain number of elements have been processed.
(3) Some more performance optimizations in this script and under the hood for the VGD element creation
process which will hopefully solve the performance problem.
(4) Using a VisualGraphDatasetWriter instance now to actually write the data to the disk which enables
the usage of dataset chunking.
0.7.0 - 19.10.2023 - (1) Fixed the memory accumulation problem that was caused by storing the Mol objects
for each of the SMILES strings.
(2) Added some more documentation of the experiment parameters
"""
import os
import gc
import time
import csv
import yaml
import datetime
import multiprocessing
import typing as t
from collections import defaultdict
import rdkit.Chem.Descriptors
from rdkit.rdBase import BlockLogs
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
from pycomex.functional.experiment import Experiment
from pycomex.utils import folder_path, file_namespace
from rdkit import Chem
import visual_graph_datasets.typing as tc
from visual_graph_datasets.config import Config
from visual_graph_datasets.web import AbstractFileShare, get_file_share
from visual_graph_datasets.processing.base import identity, list_identity
from visual_graph_datasets.processing.base import create_processing_module
from visual_graph_datasets.processing.base import ProcessingError
from visual_graph_datasets.processing.molecules import chem_prop, chem_descriptor
from visual_graph_datasets.processing.molecules import apply_atom_callbacks, apply_bond_callbacks
from visual_graph_datasets.processing.molecules import mol_from_smiles
from visual_graph_datasets.processing.molecules import OneHotEncoder
from visual_graph_datasets.processing.molecules import MoleculeProcessing
from visual_graph_datasets.processing.molecules import crippen_contrib, lasa_contrib, tpsa_contrib
from visual_graph_datasets.processing.molecules import gasteiger_charges, estate_indices
from visual_graph_datasets.data import load_visual_graph_dataset
from visual_graph_datasets.data import generate_visual_graph_dataset_metadata
from visual_graph_datasets.data import VisualGraphDatasetWriter
mpl.use('Agg')
plt.ioff()
# This will block the annoying warning generated by RDKit
block = BlockLogs()
# == SOURCE PARAMETERS ==
# These parameters determine how to handle the source CSV file of the dataset. There exists the possibility
# to define a file from the local system or to download a file from the VGD remote file share location.
# In this section one also has to determine, for example, the type of the source dataset (regression,
# classification) and provide the names of the relevant columns in the CSV file.
# :param FILE_SHARE_PROVIDER:
# The vgd file share provider from which to download the CSV file to be used as the source for the VGD
# conversion.
FILE_SHARE_PROVIDER: str = 'main'
# :param CSV_FILE_NAME:
# The name of the CSV file to be used as the source for the dataset conversion.
# This may be one of the following two things:
# 1. A valid absolute file path on the local system pointing to a CSV file to be used as the source for
# the VGD conversion
# 2. A valid relative path to a CSV file stashed on the given vgd file share provider which will be
# downloaded first and then processed.
CSV_FILE_NAME: str = 'source/benzene_solubility.csv'
# :param INDEX_COLUMN_NAME:
# (Optional) this may define the string name of the CSV column which contains the integer index
# associated with each dataset element. If this is not given, then integer indices will be randomly
# generated for each element in the final VGD
INDEX_COLUMN_NAME: t.Optional[str] = None
# :param INDICES_BLACKLIST_PATH:
# Optionally it is possible to define the path to a file which defines the blacklisted indices for the
# dataset. This file should contain a list of integers, where each integer represents the index of an
# element which should be excluded from the final dataset. The file should be a normal TXT file where each
# integer is on a new line.
# The indices listed in that file will be immediately skipped during processing without even loading
# the molecule.
INDICES_BLACKLIST_PATH: t.Optional[str] = None
# :param SMILES_COLUMN_NAME:
# This has to be the string name of the CSV column which contains the SMILES string representation of
# the molecule.
SMILES_COLUMN_NAME: str = 'SMILES'
# :param TARGET_TYPE:
# This has to be the string name of the type of dataset that the source file represents. The valid
# options here are "regression" and "classification"
TARGET_TYPE: str = 'regression' # 'classification'
# :param TARGET_COLUMN_NAMES:
# This has to be a list of string column names within the source CSV file, where each name defines
# one column that contains a target value for each row. In the regression case, this may be multiple
# different regression targets for each element and in the classification case there has to be one
# column per class.
TARGET_COLUMN_NAMES: t.List[str] = ['LogS']
# :param SPLIT_COLUMN_NAMES:
# The keys of this dictionary are integers which represent the indices of various train test splits. The
# values are the string names of the columns which define those corresponding splits. It is expected that
# these CSV columns contain a "1" if that corresponding element is considered as part of the training set
# of that split and "0" if it is part of the test set.
# This dictionary may be empty and then no information about splits will be added to the dataset at all.
SPLIT_COLUMN_NAMES: t.Dict[int, str] = {
}
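For illustration, assuming a single hypothetical split column named `'split_0'`, the following self-contained sketch mirrors how the script later derives the train and test split indices from one CSV row:

```python
# Hypothetical example: one train-test split (index 0) defined by a CSV
# column named "split_0". A "1" marks a training element, a "0" a test element.
split_column_names = {0: 'split_0'}

# A single CSV row as csv.DictReader would return it
row = {'SMILES': 'CCO', 'LogS': '-0.77', 'split_0': '1'}

train_indices = [index for index, name in split_column_names.items()
                 if name in row and int(row[name]) == 1]
test_indices = [index for index, name in split_column_names.items()
                if name in row and int(row[name]) == 0]
# train_indices == [0], test_indices == []
```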
# :param SUBSET:
# Optional. This can be used to set a number of elements after which to terminate the processing procedure.
# If this is None, the whole dataset will be processed. This feature can be useful if only a certain
# part of the dataset should be processed, or for testing purposes, for example.
SUBSET: t.Optional[int] = None
# == PROCESSING PARAMETERS ==
# These parameters control the processing of the raw SMILES into the molecule representations with RDKit
# and then finally the conversion into the graph dict representation.
class VgdMoleculeProcessing(MoleculeProcessing):
node_attribute_map = {
'symbol': {
'callback': chem_prop('GetSymbol', OneHotEncoder(
['H', 'C', 'N', 'O', 'B', 'F', 'Si', 'P', 'S', 'Cl', 'As', 'Se', 'Br', 'Te', 'I', 'At'],
add_unknown=True,
dtype=str
)),
'description': 'one-hot encoding of atom type',
'is_type': True,
'encodes_symbol': True,
},
'hybridization': {
'callback': chem_prop('GetHybridization', OneHotEncoder(
[2, 3, 4, 5, 6],
add_unknown=True,
dtype=int,
)),
'description': 'one-hot encoding of atom hybridization',
},
'total_degree': {
'callback': chem_prop('GetTotalDegree', OneHotEncoder(
[0, 1, 2, 3, 4, 5],
add_unknown=False,
dtype=int
)),
'description': 'one-hot encoding of the degree of the atom'
},
'num_hydrogen_atoms': {
'callback': chem_prop('GetTotalNumHs', OneHotEncoder(
[0, 1, 2, 3, 4],
add_unknown=False,
dtype=int
)),
'description': 'one-hot encoding of the total number of attached hydrogen atoms'
},
'mass': {
'callback': chem_prop('GetMass', list_identity),
'description': 'The mass of the atom'
},
'charge': {
'callback': chem_prop('GetFormalCharge', list_identity),
'description': 'The charge of the atom',
},
'is_aromatic': {
'callback': chem_prop('GetIsAromatic', list_identity),
'description': 'Boolean flag of whether the atom is aromatic',
},
'is_in_ring': {
'callback': chem_prop('IsInRing', list_identity),
'description': 'Boolean flag of whether atom is part of a ring'
},
'crippen_contributions': {
'callback': crippen_contrib(),
'description': 'The crippen logP contributions of the atom as computed by RDKit'
},
'tpsa_contribution': {
'callback': tpsa_contrib(),
'description': 'Contribution to TPSA as computed by RDKit',
},
'lasa_contribution': {
'callback': lasa_contrib(),
'description': 'Contribution to ASA as computed by RDKit'
},
'gasteiger_charge': {
'callback': gasteiger_charges(),
'description': 'The partial gasteiger charge attributed to atom as computed by RDKit'
},
'estate_indices': {
'callback': estate_indices(),
'description': 'EState index as computed by RDKit'
}
}
edge_attribute_map = {
'bond_type': {
'callback': chem_prop('GetBondType', OneHotEncoder(
[1, 2, 3, 12],
add_unknown=False,
dtype=int,
)),
'description': 'one-hot encoding of the bond type',
'is_type': True,
'encodes_bond': True,
},
'stereo': {
'callback': chem_prop('GetStereo', OneHotEncoder(
[0, 1, 2, 3],
add_unknown=False,
dtype=int,
)),
'description': 'one-hot encoding of the stereo property'
},
'is_aromatic': {
'callback': chem_prop('GetIsAromatic', list_identity),
'description': 'boolean flag of whether bond is aromatic',
},
'is_in_ring': {
'callback': chem_prop('IsInRing', list_identity),
'description': 'boolean flag of whether bond is part of ring',
},
'is_conjugated': {
'callback': chem_prop('GetIsConjugated', list_identity),
'description': 'boolean flag of whether bond is conjugated'
}
}
graph_attribute_map = {
'molecular_weight': {
'callback': chem_descriptor(Chem.Descriptors.ExactMolWt, list_identity),
'description': 'the exact molecular weight of the molecule',
},
'num_radical_electrons': {
'callback': chem_descriptor(Chem.Descriptors.NumRadicalElectrons, list_identity),
'description': 'the total number of radical electrons in the molecule',
},
'num_valence_electrons': {
'callback': chem_descriptor(Chem.Descriptors.NumValenceElectrons, list_identity),
'description': 'the total number of valence electrons in the molecule'
}
}
# :param PROCESSING:
# A MoleculeProcessing instance which will be used to convert the SMILES representations of the molecules
# into the graph representations and visualizations of the dataset.
PROCESSING = VgdMoleculeProcessing()
# :param UNDIRECTED_EDGES_AS_TWO:
# If this flag is True, the undirected edges which make up molecular graph will be converted into two
# opposing directed edges. Depends on the downstream ML framework to be used.
UNDIRECTED_EDGES_AS_TWO: bool = True
# :param USE_NODE_COORDINATES:
# If this flag is True, the coordinates of each atom will be calculated for each molecule and the resulting
# 3D coordinate vector will be added as a separate property of the resulting graph dict.
USE_NODE_COORDINATES: bool = True
# :param GRAPH_METADATA_CALLBACKS:
# This is a dictionary that can be used to define additional information that should be extracted from
# the csv file and transferred to the metadata dictionary of the visual graph dataset elements.
# The keys of this dict should be the string names that the properties will then have in the final metadata
# dictionary. The values should be callback functions with two parameters: "mol" is the rdkit molecule object
# representation of each dataset element and "data" is the corresponding dictionary containing all the
# values from the csv file indexed by the names of the columns. The function itself should return the actual
# data to be used for the corresponding custom property.
GRAPH_METADATA_CALLBACKS = {
'name': lambda mol, data: data['smiles'],
'smiles': lambda mol, data: data['smiles'],
}
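To make the callback contract concrete, here is a minimal, self-contained sketch of how such a dictionary is evaluated per element. The `measurement_id` column is a hypothetical example and not part of the benzene dataset; `None` stands in for the RDKit Mol object since these particular callbacks do not use it.

```python
# Hypothetical extension of GRAPH_METADATA_CALLBACKS: transfer an extra CSV
# column ("measurement_id" is an assumed name, for illustration only) into
# the per-element metadata dictionary.
graph_metadata_callbacks = {
    'name': lambda mol, data: data['smiles'],
    'smiles': lambda mol, data: data['smiles'],
    'measurement_id': lambda mol, data: data.get('measurement_id'),
}

# One element's merged data dict, as the main loop would assemble it
data = {'smiles': 'CCO', 'measurement_id': 'M-001'}

# Each callback is applied with the mol object and the data dict
metadata = {name: cb(None, data) for name, cb in graph_metadata_callbacks.items()}
# metadata == {'name': 'CCO', 'smiles': 'CCO', 'measurement_id': 'M-001'}
```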
# Filters: This list is supposed to contain callback functions which take the molecule object and the data
# dictionary belonging to that element; if even one of those callbacks returns True, the corresponding
# element will be skipped and NOT included in the final VGD. In that sense, these callbacks define the
# rules by which to filter the original dataset.
def no_bonds(mol, data):
# All molecules which don't have any bonds do not qualify as graphs and will be skipped
return len(mol.GetBonds()) == 0 or len(mol.GetAtoms()) == 1
# :param FILTER_CALLBACKS:
# This should be a list consisting of callable callback methods. Each of these callables should accept
# two parameters: "mol" which is the rdkit molecule representation of a dataset element and "data" which
# is a dict containing all the columns of the corresponding element from the csv source file.
# One such callable should return True if a condition is met that leads to the exclusion of the given
# element and return False otherwise.
FILTER_CALLBACKS: t.List[t.Callable] = [
no_bonds
]
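As an illustration of how a custom filter could look, here is a minimal sketch of an additional callback (not enabled by default) which excludes rows whose target column is missing or empty. The `'LogS'` column name is the one assumed by TARGET_COLUMN_NAMES above.

```python
# Hypothetical additional filter: exclude rows without a usable target value,
# since they cannot be converted into a labeled dataset element.
def missing_target(mol, data):
    # A filter returning True means the element is excluded from the VGD.
    return any(data.get(name) in (None, '') for name in ['LogS'])

excluded = missing_target(None, {'LogS': ''})       # True  -> element skipped
kept = missing_target(None, {'LogS': '-0.77'})      # False -> element kept
```

To enable such a filter, it would simply be appended to the FILTER_CALLBACKS list.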
# == DATASET PARAMETERS ==
# These parameters control aspects of the visual graph dataset creation process. This for example includes
# the dimensions of the graph visualization images to be created or the name of the visual graph dataset
# that should be given to the dataset folder.
# :param DATASET_CHUNK_SIZE:
# This number will determine the chunking of the dataset. Dataset chunking will split the dataset
# elements into multiple sub folders within the main VGD folder. Especially for larger datasets
# this should increase the efficiency of subsequent IO operations.
# If this is None then no chunking will be applied at all and everything will be placed into the
# top level folder.
DATASET_CHUNK_SIZE: t.Optional[int] = 10_000
# :param DATASET_NAME:
# The name given to the visual graph dataset folder which will be created.
DATASET_NAME: str = 'benzene'
# :param IMAGE_WIDTH:
# The width of the molecule visualization PNG image
IMAGE_WIDTH: int = 1000
# :param IMAGE_HEIGHT:
# The height of the molecule visualization PNG image
IMAGE_HEIGHT: int = 1000
# :param DATASET_META:
# This dict will be converted into the .meta.yml file which will be added to the final visual graph dataset
# folder. This is an optional file, which can add additional meta information about the entire dataset
# itself. Such as documentation in the form of a description of the dataset etc.
DATASET_META: t.Optional[dict] = {
'version': '0.1.0',
# A list of strings where each element is a description about the changes introduced in a newer
# version of the dataset.
'changelog': [
'0.1.0 - 29.01.2023 - initial version'
],
# A general description about the dataset, which gives a general overview about where the data was
# sampled from, what the input features look like, what the prediction target is etc...
'description': (
'Small dataset consisting of molecular graphs, where the target is the measured logS value of '
'the molecules solubility in Benzene.'
),
# A list of informative strings (best case containing URLS) which are used as references for the
# dataset. This could for example be a reference to a paper where the dataset was first introduced
# or a link to site where the raw data can be downloaded etc.
'references': [
'Library used for the processing and visualization of molecules. https://www.rdkit.org/',
],
# A small description about how to interpret the visualizations which were created by this dataset.
'visualization_description': (
'Molecular graphs generated by RDKit based on the SMILES representation of the molecule.'
),
# A dictionary, where the keys should be the integer indices of the target value vector for the dataset
# and the values should be string descriptions of what the corresponding target value is about.
'target_descriptions': {
0: 'measured logS values of the molecules solubility in Benzene. (unprocessed)'
}
}
# == EVALUATION PARAMETERS ==
# These parameters control the evaluation process, which includes, for example, the plotting of the
# dataset statistics after the dataset has been completed.
# :param EVAL_LOG_STEP:
# The number of iterations after which to print a log message
EVAL_LOG_STEP = 100
# :param NUM_BINS:
# The number of bins to use for the histogram
NUM_BINS = 10
# :param PLOT_COLOR:
# The color to be used for the plots
PLOT_COLOR = 'gray'
# == EXPERIMENT PARAMETERS ==
__DEBUG__ = True
experiment = Experiment(
base_path=folder_path(__file__),
namespace=file_namespace(__file__),
glob=globals()
)
@experiment.hook('load_blacklist')
def load_blacklist(e: Experiment) -> set[int]:
"""
This hook loads the blacklist of indices from the given file path and returns it as a set of integers.
This set of integers defines all the indices of the target dataset that should be skipped during processing.
"""
if e.INDICES_BLACKLIST_PATH is not None:
with open(e.INDICES_BLACKLIST_PATH, 'r') as file:
return {int(line) for line in file}
else:
return set()
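The expected blacklist file format is one integer index per line. The set comprehension above works because `int()` tolerates the surrounding whitespace of each line; a minimal sketch using an in-memory file in place of a real blacklist file:

```python
import io

# Contents of a hypothetical blacklist file: one integer index per line
blacklist_file = io.StringIO("3\n17\n42\n")

# int() ignores leading/trailing whitespace, so the newline of each line
# is handled without an explicit strip()
blacklist = {int(line) for line in blacklist_file}
# blacklist == {3, 17, 42}
```

Note that an empty trailing line would raise a ValueError, matching the behavior of the hook above.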
@experiment
def experiment(e: Experiment):
# Here we provide the possibility for sub experiments to add more specific filters for their purposes
# by using a hook.
FILTER_CALLBACKS = e.apply_hook(
'modify_filter_callbacks',
filter_callbacks=e.FILTER_CALLBACKS,
default=e.FILTER_CALLBACKS
)
e.log('generating a molecule visual graph dataset from CSV source file...')
config = Config()
config.load()
@e.hook('load_data')
def load_data(e, config):
# -- get source dataset --
# If the given path is not a valid file on the local system, we will interpret it as a relative
# path to download from the file share provider.
if os.path.exists(e.CSV_FILE_NAME):
file_path = e.CSV_FILE_NAME
e.log(f'CSV found on local system: {file_path}')
else:
e.log('downloading from remote file share...')
file_share: AbstractFileShare = get_file_share(config, e.FILE_SHARE_PROVIDER)
file_path = file_share.download_file(e.CSV_FILE_NAME, e.path)
e.log('CSV downloaded from file share')
# -- Load that data into the required format --
raw_data_list = []
with open(file_path) as file:
dict_reader = csv.DictReader(file)
for c, row in enumerate(dict_reader):
smiles = row[e.SMILES_COLUMN_NAME]
# 23.02.2023
# To extend the experiment to also work with classification datasets, I had to change
# the parameter with the column names to be a list instead of just a single value.
targets = []
for column_name in e.TARGET_COLUMN_NAMES:
value = row[column_name]
if e.TARGET_TYPE == 'classification':
# For a classification dataset, we expect each of the given columns to represent one
# class that is present in the dataset. The value of that column for every element
# of the dataset should be some sort of boolean indication of whether that element
# belongs to that class.
value = int(value)
targets.append(value)
data = {
'smiles': smiles,
'target': np.array(targets),
'data': row,
}
raw_data_list.append(data)
return raw_data_list
# The end result of the "data loading" process should be a list of dictionaries, where each dictionary
# represents one item in the dataset. The content of these dicts may be arbitrary to some point,
# depending on how sub-experiments may extend the functionality, but they should at least contain
# the following elements:
# - smiles: the SMILES string representation
# - target: the target value for that element
# - data: a dictionary containing additional data loaded from the CSV.
e.log(f'loading data...')
raw_data_list: t.List[dict] = e.apply_hook(
'load_data',
config=config
)
dataset_length = len(raw_data_list)
e.log(f'loaded data with {dataset_length} elements')
if e.SUBSET is not None and e.SUBSET < dataset_length:
dataset_length = e.SUBSET
# -- Load the blacklist of indices --
# :hook load_blacklist:
# This hook is supposed to load the blacklist of indices from the given file path and return it as
# a set of integers. This set of integers defines all the indices of the target dataset that should
# be skipped during processing.
blacklist_indices: set[int] = e.apply_hook('load_blacklist')
e.log(f'loaded blacklist consisting of {len(blacklist_indices)} indices')
# -- Processing the dataset into visual graph dataset --
e.log('creating the dataset folder...')
dataset_path = os.path.join(e.path, DATASET_NAME)
os.mkdir(dataset_path)
e['dataset_path'] = dataset_path
writer = VisualGraphDatasetWriter(
path=dataset_path,
chunk_size=e.DATASET_CHUNK_SIZE,
)
# 21.03.2023 - This was essentially the whole point of the preprocessing update: This function will
# automatically generate the code for a python standalone python module which contains all the
# functionality to convert any given SMILES (domain-specific representation of a molecule) into a
# valid visual graph dataset element representation (graph representation + visualization).
e.log(f'generating pre-processing python module for {PROCESSING.__class__}...')
module_code = create_processing_module(PROCESSING)
module_path = os.path.join(dataset_path, 'process.py')
with open(module_path, mode='w') as file:
file.write(module_code)
e.log('creating visual graph dataset...')
# 24.02.2023
# In this dictionary we want to use the smiles identifiers of the elements as the keys and the values
# should be short reasons why these elements were omitted from the final dataset. In the end we then
# want to save this information as a json file into the archive folder as an artifact.
omitted_elements: t.Dict[str, str] = {}
profiling = defaultdict(float)
start_time: float = time.time()
time_previous: float = time.time() # The time when the previous chunk ended
bytes_written: int = 0 # How many bytes were written since the last chunk was processed
index: int = 0
for d in raw_data_list:
# 05.02.24
# If the index is part of the index blacklist, then we will skip the current element and continue
# with the next one.
if index in blacklist_indices:
e.log(f' * skipping {index} due to blacklist')
index += 1
continue
smiles = d['smiles']
# ~ Convert the smiles string into a molecule
try:
# Internally, this function will use the RDKit MolFromSmiles function for the conversion.
# This will raise an exception if the string could not be converted into a proper molecule
mol = mol_from_smiles(smiles)
d['mol'] = mol
except Exception as exc:
e.log(f' ! Error converting smiles "{smiles}" to mol: {exc}, skipping...')
continue
# ~ Convert the Mol object into a GraphDict
data: dict = {**d['data'], **d}
# As the first thing we are going to apply the filters to check if the current element is even
# a valid element according to the rules defined by those callbacks
skip = False
for cb in FILTER_CALLBACKS:
if cb(mol, data):
message = f' ! skipping "{data[SMILES_COLUMN_NAME]}" due to filter "{cb.__name__}"'
omitted_elements[smiles] = message
e.log(message)
skip = True
break
if skip:
continue
target = [float(v) for v in d['target']]
# 13.04.2023 - I needed a method to attach additional data from the CSV to the graph itself and not
# the metadata. This is why this new optional hook exists which generates that dictionary that will
# be added to the graph dict.
# :hook additional_graph_data:
# This hook receives the molecule representation of the current element and the data dict from
# the original csv file and is supposed to output a dictionary which will be used to extend
# the GraphDict representation of that molecule.
additional_graph_data = {'graph_labels': target}
additional_graph_data = e.apply_hook(
'additional_graph_data',
default=additional_graph_data,
additional_graph_data=additional_graph_data,
mol=mol,
data=data,
)
# 04.12.23
# This might seem redundant since we could just use the target value from above and in most cases
# this will evaluate to the same value anyway, but doing it like this enables the possibility to
# modify the target value within the additional_graph_data hook and have those changes be
# reflected here as well
target = additional_graph_data['graph_labels']
additional_metadata = {
'target': target
}
# But there can also be custom entries which are defined as callbacks in this dictionary. These
# values will be associated with the same string keys, which are also used in the callbacks dict
for name, callback in GRAPH_METADATA_CALLBACKS.items():
additional_metadata[name] = callback(mol, data)
# Optionally, if defined, we also add the information about the train test splits to the metadata
# of the element.
if len(SPLIT_COLUMN_NAMES) != 0:
# If the element defined in the column of the given name is a 1 then this indicates that it is
# supposed to be a training element. If it is 0 then it is supposed to be a test element.
additional_metadata['train_indices'] = [index
for index, name in SPLIT_COLUMN_NAMES.items()
if name in data and int(data[name]) == 1]
additional_metadata['test_indices'] = [index
for index, name in SPLIT_COLUMN_NAMES.items()
if name in data and int(data[name]) == 0]
# 23.03.2023 - This is an instance of a ProcessingBase subclass. This class is specifically designed
# to wrap all the functionality which is needed to create a valid VGD element representation given
# only the SMILES string representation of a molecule.
# This method will already create the two required files: The visualization PNG and the metadata
# JSON file.
try:
time_start_create = time.time()
PROCESSING.create(
value=smiles,
index=str(index),
name=smiles,
double_edges_undirected=UNDIRECTED_EDGES_AS_TWO,
use_node_coordinates=USE_NODE_COORDINATES,
additional_graph_data=additional_graph_data,
additional_metadata=additional_metadata,
width=IMAGE_WIDTH,
height=IMAGE_HEIGHT,
create_svg=False,
output_path=dataset_path,
# By specifying a Writer instance, the creation process will use that writer to actually
# save the data to the disk, using various IO optimizations such as folder chunking.
writer=writer,
)
time_end_create = time.time()
profiling['create_time'] += time_end_create - time_start_create
profiling['graph_size'] += len(mol.GetAtoms())
except (ProcessingError, ValueError) as exc:
e.log(f' * error: {smiles} ({exc})')
continue
# In regular intervals we will print how it's currently going aka how many elements have already
# been processed in what time and how much more time is going to be approximately needed.
if index % EVAL_LOG_STEP == 0:
time_elapsed = time.time() - start_time
time_per_element = time_elapsed / (index+1)
time_remaining = time_per_element * (dataset_length - index)
eta = datetime.datetime.now() + datetime.timedelta(seconds=time_remaining)
# 05.05.23 - In addition to an overall prediction of the remaining time, we also want to keep track
# of the average write speed for each chunk of data here as well. We do that because there seems
# to be a problem where the write speed continuously goes down as time goes on, and I would like to
# confirm / keep track of this issue.
average_write = bytes_written / (time.time() - time_previous)
e.log(f' * {index}/{dataset_length} elements created'
f' - elapsed time: {time_elapsed:.2f}s'
f' - remaining time: {time_remaining:.2f}s'
f' - eta: {eta:%A %d.%m %H:%M}'
f' - step time: {time.time() - time_previous:.2f}s'
f' - avg graph size: {profiling["graph_size"] / e.EVAL_LOG_STEP:.0f}'
f' - avg create time: {profiling["create_time"] / e.EVAL_LOG_STEP:.3f}s'
f' - ctime/gsize: {profiling["create_time"]/profiling["graph_size"]:.4f}s')
# We need to reset the tracking variables for the chunk
profiling['create_time'] = 0
profiling['graph_size'] = 0
time_previous = time.time()
gc.collect()
# 19.10.23 - This is another potential source of memory issues. Further above we are creating
# a mol object for every one of the smiles strings and they were never cleared. Even though a
# single mol object isn't big, this will still cause some memory build-up that could be problematic
# for systems with low memory.
del additional_graph_data, additional_metadata, data, mol, target
del d['mol'], d['smiles'], d['target'], d['data']
index += 1
# 24.10.23 - Added the option to terminate the loop after a certain number of elements have been
# processed already.
if e.SUBSET is not None and index > e.SUBSET:
break
e.commit_json('omitted_elements.json', omitted_elements)
e.log(f'created {index} out of {dataset_length} original elements. '
f'The rest was skipped either due to errors or filter exclusions.')
e.log('generating dataset metadata...')
metadata_map = {}
# metadata_map.update(generate_visual_graph_dataset_metadata(index_data_map))
metadata_map.update(DATASET_META)
# 21.03.2023 - Another cool feature of the Processing class: This method will automatically create a
# dictionary that contains natural language descriptions for all the elements of the node, edge and
# graph attribute vectors.
#metadata_map.update(PROCESSING.get_description_map())
yaml_path = os.path.join(dataset_path, '.meta.yml')
with open(yaml_path, mode='w') as file:
yaml.dump(metadata_map, file)
e.log(f'metadata written to: {yaml_path}')
e.apply_hook('after_experiment')
e.log(f'loading dataset...')
metadata_map, index_data_map = load_visual_graph_dataset(
e['dataset_path'],
logger=e.logger,
log_step=EVAL_LOG_STEP,
metadata_contains_index=True
)
@experiment.analysis
def analysis(e: Experiment):
e.log('attempting to load visual graph dataset...')
metadata_map, index_data_map = load_visual_graph_dataset(
e['dataset_path'],
logger=e.logger,
log_step=EVAL_LOG_STEP,
metadata_contains_index=True
)
e.log(f'loaded visual graph dataset with {len(index_data_map)} elements')
# -- Plotting information about the dataset --
@e.hook('dataset_info')
def dataset_info(e, index_data_map):
pdf_path = os.path.join(e.path, 'dataset_info.pdf')
with PdfPages(pdf_path) as pdf:
e.log(f'target value distribution...')
fig, ax = plt.subplots(ncols=1, nrows=1, figsize=(12, 12))
ax.set_title('Target Value Distribution')
if e.TARGET_TYPE == 'classification':
    targets = [np.argmax(d['metadata']['target']) for d in index_data_map.values()]
else:
    # For regression, np.argmax over a single-element target vector would always
    # be 0, so the raw target value is used instead.
    targets = [d['metadata']['target'][0] for d in index_data_map.values()]
e.log(f'number of targets: {len(targets)}')
e.log(f' * min: {np.min(targets):.2f} - mean: {np.mean(targets):.2f} - max: {np.max(targets):.2f}')
n, bins, edges = ax.hist(
targets,
bins=e.NUM_BINS,
color=e.PLOT_COLOR,
)
ax.set_xticks(bins)
ax.set_xticklabels([round(v, 2) for v in bins])
pdf.savefig(fig)
e.log('graph size distribution...')
fig, ax = plt.subplots(ncols=1, nrows=1, figsize=(12, 12))
ax.set_title('Graph Size Distribution')
sizes = [len(d['metadata']['graph']['node_indices']) for d in index_data_map.values()]
n, bins, edges = ax.hist(
sizes,
bins=e.NUM_BINS,
color=e.PLOT_COLOR,
)
ax.set_xticks(bins)
ax.set_xticklabels([int(v) for v in bins])
pdf.savefig(fig)
e.log(f'plotting dataset analyses...')
e.apply_hook(
'dataset_info',
index_data_map=index_data_map,
)
experiment.run_if_main()