In [1]:
from itertools import islice, permutations
from multiprocessing import Pool

import numpy as np
import torch
from pymongo import MongoClient, UpdateOne, DESCENDING
from torch import nn
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

from chespex.molecules import Molecule

In [2]:
# Dataset selection: molecules made of `n_beads` beads, described at bead-resolution `level`.
n_beads, level = 4, 1
# (number of bead types, number of bead classes) for each level, indexed by `level`.
n_bead_types, n_bead_classes = [(15, 4), (45, 13), (96, 27)][level]

# Database `molecules-{n_beads}`, collection `level-{level}` on the local MongoDB server.
client = MongoClient("mongodb://localhost:27017")
database = client.get_database(f"molecules-{n_beads}")
collection = database.get_collection(f"level-{level}")

In [3]:
def transform(batch):
    """Replace the raw node features of a collated graph batch with model inputs.

    Raw columns of ``batch.x`` as consumed here:
      1  -> node type index, one-hot over ``n_bead_classes`` categories
      0  -> node size index, one-hot over 3 categories
      2  -> node charge index, one-hot over 3 categories, keeping only
            categories 0 and 2 (TODO confirm why the middle one is dropped)
      3: -> octanol-water transfer free energy, divided by 10
    The new ``batch.x`` concatenates these in the order type, size, charge,
    free energy.

    ``batch`` is modified in place, and the batch moved to the current CUDA
    device is returned.
    """
    raw = batch.x
    one_hot = nn.functional.one_hot
    encoded_parts = [
        one_hot(raw[:, 1].to(torch.int64), n_bead_classes),
        one_hot(raw[:, 0].to(torch.int64), 3),
        one_hot(raw[:, 2].to(torch.int64), 3)[:, ::2],
        raw[:, 3:] / 10,  # Octanol-water transfer free energy
    ]
    batch.x = torch.cat(encoded_parts, dim=1)
    return batch.cuda()

In [None]:
# Load the trained model for this bead count / bead-type vocabulary.
model_name = f"../models/model-{n_beads}-{n_bead_types}-5d.pt"
print(f"Loading model from {model_name}")
# weights_only=False unpickles the whole model object, which can execute arbitrary
# code: only load model files from a trusted source.
# .eval() switches the module to inference mode, so any dropout / batch-norm layers
# behave deterministically while encoding (assumes the pickled object is an nn.Module).
model = torch.load(model_name, weights_only=False).eval()

In [6]:
# Coarsening map between bead-class resolutions (presumably indexed by the finer level - 1):
#   parent_map[0]: 13 classes -> 4 parent classes (Q, P, N, C)
#   parent_map[1]: 27 classes -> 13 parent classes
# The class counts match `n_bead_classes`, and the keys of parent_map[0] are exactly
# the values of parent_map[1].
parent_map = [
    dict(
        Q1='Q', Q2='Q',
        P3='P', P2='P', P1='P',
        N3='N', N2='N', N1='N',
        C3='C', C2='C', C1='C', X2='C', X1='C',
    ),
    dict(
        Q5='Q2', Q4='Q2', Q3='Q2', Q2='Q1', Q1='Q1',
        P6='P3', P5='P3', P4='P2', P3='P2', P2='P1', P1='P1',
        N6='N3', N5='N3', N4='N2', N3='N2', N2='N1', N1='N1',
        C6='C3', C5='C3', C4='C2', C3='C2', C2='C1', C1='C1',
        X4='X2', X3='X2', X2='X1', X1='X1',
    ),
]

In [7]:
# Encode every molecule document with the trained encoder and store its latent vector
# on the same document under `key_name`, `batch_size` molecules at a time.
filter_query = {}  # e.g. {key_name: {"$exists": False}} to only fill in missing vectors
# Counting the whole collection is slow, so pre-recorded totals per level are used when
# no filter is applied (presumably recorded for the molecules-4 database — TODO confirm).
# `size` only drives the progress display: the final partial batch is flushed whether
# or not it matches the real number of documents.
size = collection.count_documents(filter_query) if len(filter_query) > 0 else {2: 136870880, 1: 6742680, 0: 89960}[level]
key_name = "latent_space_test"
batch_size = 20_000  # molecules per encoder forward pass and per bulk write


def _to_graph(mol):
    """Build a CPU-resident torch_geometric graph from one molecule document."""
    return Data(
        x=torch.tensor(mol["node_features"], dtype=torch.float32),
        edge_index=torch.tensor(mol["edge_index"], dtype=torch.int64),
    )


def _encode_and_store(docs):
    """Encode a non-empty list of molecule documents and $set each latent vector by _id."""
    graphs = [_to_graph(mol) for mol in docs]
    # Collate once on the CPU; `transform` moves the whole batch to the GPU in one copy.
    batch = next(iter(DataLoader(graphs, batch_size=len(graphs), shuffle=False)))
    batch = transform(batch)
    with torch.no_grad():
        # fmod keeps every coordinate inside (-5, 5), sign following the encoder output.
        latent_space = torch.fmod(model.encoder(batch), 5)
    # A single device-to-host copy for the batch instead of one per molecule.
    latent_rows = latent_space.cpu().tolist()
    collection.bulk_write([
        UpdateOne({"_id": mol["_id"]}, {"$set": {key_name: row}})
        for mol, row in zip(docs, latent_rows)
    ])


cursor = collection.find(filter_query, batch_size=batch_size).sort("_id", DESCENDING)
n_processed = 0
while True:
    # islice keeps pulling from the same cursor, so the last (short) chunk is handled too.
    docs = list(islice(cursor, batch_size))
    if not docs:
        break
    n_processed += len(docs)
    print(f'Processing {n_processed:,} / {size:,}', end="\r", flush=True)
    _encode_and_store(docs)

Processing 6,742,680 / 6,742,680

In [None]:
# Show the stored schema by printing every field of one arbitrary document.
example_doc = collection.find_one()
for field in example_doc:
    print(f'{field}: {example_doc[field]}')

_id: 666f2642ab68cb0585610d42
name: Q2+,
bead_names: ['Q2+']
node_features: [[0.0, 0.0, 2.0, -19.73332977294922]]
edge_index: [[], []]
latent_space: [1.2288522720336914, 2.4856317043304443, 2.344226360321045, -2.741889476776123, 5.2077131271362305]
parent: Q+,
group: 11 6 12 5 11
latent_space_test: [0.524034321308136, -3.7589316368103027, 1.5780901908874512, -0.3245792090892792, -2.533147096633911, -0.5984180569648743, -2.137758493423462, 4.545452117919922, -2.7681477069854736, 0.33618515729904175]


In [9]:
# All reads and writes are done; release the MongoDB client's connections.
client.close()