In [1]:
import json
import os
import re

import numpy as np
import pymongo
import torch
import torch_geometric as tg
from pymongo import InsertOne, MongoClient

from chespex.molecules import MoleculeGenerator

In [None]:
# Connection settings. The MongoDB URI can be overridden through the
# MONGODB_URI environment variable; by default a local server is used.
n_beads = 4  # molecules with this many beads are stored in their own database
MONGODB_URI = os.environ.get("MONGODB_URI", "mongodb://localhost:27017")
client = MongoClient(MONGODB_URI)
database = client.get_database(f"molecules-{n_beads}")

In [5]:
# The JSON file holds a list with one entry per level (indexed by the integer
# `level` below); keep only the bead types of the selected level.
level = 2
with open(file='../bead_types/bead-types.json', mode='r') as bead_types_file:
    all_bead_types = json.load(fp=bead_types_file)
bead_types = all_bead_types[level]

In [6]:
# A bead class is a bead name with the `S`/`T` prefix letters and the `+`/`-`
# charge signs stripped (presumably S/T mark the smaller bead sizes — confirm
# against the bead-type definitions). Classes are ordered by family Q, P, N, C, X
# and, within a family, by descending number (e.g. Q5 before Q1).
BEAD_CLASS_FAMILIES = ['Q', 'P', 'N', 'C', 'X']


def bead_class_sort_key(class_name):
    """Sort key for a bead class name such as 'Q5', 'C1' or a bare family letter.

    Returns a ``(family_rank, -number)`` tuple. Unlike a packed
    ``rank * 20 - number`` integer, the tuple stays correct for class numbers
    of 20 or more. Raises ValueError for an unknown family letter.
    """
    family_rank = BEAD_CLASS_FAMILIES.index(class_name[0])
    number = int(class_name[1:]) if len(class_name) > 1 else 0
    return (family_rank, -number)


bead_class_names = sorted(
    {re.sub(r'[ST+-]', '', bead['name']) for bead in bead_types},
    key=bead_class_sort_key,
)
print(bead_class_names)

['Q5', 'Q4', 'Q3', 'Q2', 'Q1', 'P6', 'P5', 'P4', 'P3', 'P2', 'P1', 'N6', 'N5', 'N4', 'N3', 'N2', 'N1', 'C6', 'C5', 'C4', 'C3', 'C2', 'C1', 'X4', 'X3', 'X2', 'X1']


In [9]:
# Bead-type bookkeeping. Each bead class comes in `n_bead_sizes` size variants.
# `charged_beads` classes presumably appear with both charge signs (e.g. the
# Q classes, which show up as `+`/`-` variants in the sampled names), so each
# takes one extra slot per size and is subtracted below — confirm against the JSON.
n_bead_sizes = 3
charged_beads = 5
n_bead_types = len(bead_types)
print(f'Number of bead types: {n_bead_types}')
# Guard the floor division below: a bead-type count that is not a multiple of
# the number of sizes would otherwise be truncated silently.
assert n_bead_types % n_bead_sizes == 0, f'Number of bead types ({n_bead_types}) is not a multiple of the number of bead sizes ({n_bead_sizes})'
n_bead_classes = n_bead_types // n_bead_sizes - charged_beads
assert n_bead_classes == len(bead_class_names), f'Inconsistent number of bead classes, bead sizes, and max charged index: {n_bead_classes} != {len(bead_class_names)}'
latent_dim = 5  # not used in the cells shown here
# Node feature width: bead class entries + bead size entries + 2 charge entries
# + 1 octanol-water transfer free energy.
feature_dim = n_bead_classes + n_bead_sizes + 2 + 1
print(f'Feature dimension: {feature_dim}')

Number of bead types: 96
Feature dimension: 33


In [7]:
# One collection per bead-type level inside the per-size database.
collection = database[f"level-{level}"]

In [None]:
# Generate every molecule with `n_beads` beads and store it in `collection`,
# writing in batches of `insert_batch_size` documents.
# NOTE: this cell is not idempotent — re-running it inserts every molecule again.
insert_batch_size = 10_000
molecule_generator = MoleculeGenerator(n_beads, bead_types)
molecule_inserts = []
for i, molecule in enumerate(molecule_generator.generate()):
    tg_molecule = molecule.as_torch_graph()
    bead_names = [b['name'] for b in molecule.beads]
    mol = {
        'name': str(molecule),
        'bead_names': bead_names,
        # Tensors are converted to nested lists so they can be stored as BSON arrays.
        'node_features': tg_molecule.x.tolist(),
        'edge_index': tg_molecule.edge_index.tolist()
    }
    molecule_inserts.append(InsertOne(mol))
    if len(molecule_inserts) == insert_batch_size:
        print(i, end="\r")
        collection.bulk_write(molecule_inserts)
        molecule_inserts = []
# Flush the final partial batch; without this the last (< insert_batch_size)
# molecules would never be written. bulk_write rejects an empty request list,
# hence the guard.
if molecule_inserts:
    collection.bulk_write(molecule_inserts)
    molecule_inserts = []

19040000

In [8]:
# Exact document count: the empty filter matches every document.
collection_size = collection.count_documents(filter={})
print(collection_size)

136870880


In [None]:
# Draw a random sample of stored molecules: print a few names, keep the first
# `n_names` names, and rebuild a PyG graph for each sampled document.
n_printed = 10
n_names = 10_000
n_sampled = 20_001  # number of documents consumed (indices 0..20_000)
names = []
# Request only the documents that are actually used. Asking $sample for the whole
# collection (>= 5% of it) makes MongoDB do a full collection scan plus a random
# sort instead of using its pseudo-random cursor.
cursor = collection.aggregate([{"$sample": {"size": n_sampled}}])
for i, mol in enumerate(cursor):
    if i < n_printed:
        print(mol['name'])
    if i < n_names:
        names.append(mol['name'])
    # PyG expects tensors, with edge_index as int64 in COO layout [2, num_edges];
    # the stored documents hold plain nested lists.
    graph = tg.data.Data(
        x=torch.tensor(mol['node_features']),
        edge_index=torch.tensor(mol['edge_index'], dtype=torch.long),
    )

P3 Q4+ SP3 SQ2-,0-1 0-2 2-3
N2 SN3 SQ5- X4,0-1 0-2 0-3 1-3 2-3
SC3 SP3 TC2 TX2,0-1 0-2 0-3 1-2 1-3 2-3
SC2 SQ1+ TQ2- TQ4+,0-2 1-2 1-3
Q3+ SQ4- SX1 TX4,0-2 0-3 1-2 2-3
SQ2+ TC5 TQ5+ TX1,0-2 1-2 2-3
N5 SN3 TC2 TN1,0-2 1-2 2-3
C5 N5 TQ1+ TQ1-,0-2 0-3 1-2
SN1 SQ4- SQ5+ SX3,0-1 0-2 0-3 1-3 2-3
SQ5- TC5 TX3 X3,0-2 1-2 1-3 2-3


In [None]:
# Close the MongoClient created at the top of the notebook and release its connections.
client.close()