In [1]:
from datetime import datetime
import numpy as np
import torch
from torch import nn
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from pymongo import MongoClient, UpdateOne
from itertools import permutations
from chespex.molecules import Molecule

In [2]:
# Run parameters read by the rest of the notebook.
n_beads, level = 4, 2
n_bead_types, n_bead_classes = 96, 27

# MongoDB handles: the database name is derived from `n_beads`, the collection
# name from `level`.
client = MongoClient("mongodb://localhost:27017")
database = client[f"molecules-{n_beads}"]
collection = database[f"level-{level}"]

In [3]:
# parent_map[level - 1] is the table used to rewrite the bead names of a
# molecule stored at `level` into the bead names of its parent one level up.
# The parent-assignment cell walks each table in *reverse insertion order* and
# applies the entries as substring replacements, so the key order below is
# significant and must be kept when editing.
parent_map = [
    # Index 0: used when level == 1.
    {
        "Q1": "Q",
        "Q2": "Q",
        "P3": "P",
        "P2": "P",
        "P1": "P",
        "N3": "N",
        "N2": "N",
        "N1": "N",
        "C3": "C",
        "C2": "C",
        "C1": "C",
        # X beads collapse onto "C".
        "X2": "C",
        "X1": "C",
    },
    # Index 1: used when level == 2.
    {
        "Q5": "Q2",
        "Q4": "Q2",
        "Q3": "Q2",
        "Q2": "Q1",
        "Q1": "Q1",
        "P6": "P3",
        "P5": "P3",
        "P4": "P2",
        "P3": "P2",
        "P2": "P1",
        "P1": "P1",
        "N6": "N3",
        "N5": "N3",
        "N4": "N2",
        "N3": "N2",
        "N2": "N1",
        "N1": "N1",
        "C6": "C3",
        "C5": "C3",
        "C4": "C2",
        "C3": "C2",
        "C2": "C1",
        "C1": "C1",
        "X4": "X2",
        "X3": "X2",
        "X2": "X1",
        "X1": "X1",
    },
]

In [None]:
# Set parents
def _parent_bead_names(bead_names, bead_parent_map):
    """Rewrite every bead name into the name of its parent bead.

    Each key of ``bead_parent_map`` is replaced (as a substring) by its value,
    visiting the entries in reverse insertion order. With the tables in
    ``parent_map`` this applies e.g. 'Q2' -> 'Q1' before 'Q3' -> 'Q2', so a
    parent name that was just written is not rewritten a second time.

    Args:
        bead_names: iterable of bead name strings of the child molecule.
        bead_parent_map: dict mapping child bead name -> parent bead name.

    Returns:
        List of parent bead names, in the same order as ``bead_names``.
    """
    parent_names = []
    for bead in bead_names:
        for child_name, parent_name in reversed(bead_parent_map.items()):
            bead = bead.replace(child_name, parent_name)
        parent_names.append(bead)
    return parent_names


def _equivalent_molecule_names(molecule_name):
    """List the spellings of ``molecule_name`` under symmetric index relabelings.

    ``molecule_name`` is parsed as ``"<beads>,<bonds>"``: beads separated by
    whitespace, bonds separated by whitespace, each bond a '-'-joined list of
    bead indices. Every permutation of the bead indices that leaves the bead
    sequence unchanged is applied to the bonds; the indices inside each bond
    and the bonds themselves are then sorted as strings.

    Indices are relabelled token by token, so multi-digit indices (10 or more
    beads) are handled correctly.
    NOTE(review): sorting is lexicographic on strings, as before; confirm this
    matches the stored names when indices have more than one digit.

    Returns:
        List of candidate names (one per symmetry-preserving permutation, the
        identity included), in ``itertools.permutations`` order.
    """
    beads_string, bonds_string = molecule_name.split(",")
    beads = beads_string.split()
    bonds = [bond.split("-") for bond in bonds_string.split()]
    names = []
    for permutation in permutations(range(len(beads))):
        # Keep only relabelings i -> permutation[i] that map each bead onto a
        # bead with the same name.
        if any(beads[i] != beads[j] for i, j in enumerate(permutation)):
            continue
        relabel = {str(i): str(j) for i, j in enumerate(permutation)}
        new_bonds = sorted(
            "-".join(sorted(relabel.get(index, index) for index in bond))
            for bond in bonds
        )
        names.append(beads_string + "," + " ".join(new_bonds))
    return names


def _molecule_exists(mongo_collection, name):
    """Return True if ``mongo_collection`` holds a document with this ``name``."""
    # Only existence matters, so fetch nothing but the _id.
    return mongo_collection.find_one({"name": name}, {"_id": 1}) is not None


if level > 0:
    parent_collection = database.get_collection(f"level-{level - 1}")
    update_list = []
    # Only the fields used below are requested (_id is returned by default).
    needed_fields = {"name": 1, "bead_names": 1, "edge_index": 1}
    for mol_idx, mol in enumerate(collection.find({}, needed_fields)):
        # Reconstruct parent molecule
        parent_beads = _parent_bead_names(mol['bead_names'], parent_map[level - 1])
        parent_molecule = str(Molecule.reconstruct(
            parent_beads, [[] for _ in parent_beads], mol['edge_index'], []
        ))
        # Look the parent up under its reconstructed name first, then under
        # every symmetry-equivalent spelling of that name.
        parent_name = None
        if _molecule_exists(parent_collection, parent_molecule):
            parent_name = parent_molecule
        else:
            result = _equivalent_molecule_names(parent_molecule)
            for new_parent in result:
                if _molecule_exists(parent_collection, new_parent):
                    parent_name = new_parent
                    break
            if parent_name is None:
                print(f"Parent not found for {mol['name']}: {result}")
        if parent_name is not None:
            update_list.append(
                UpdateOne({"_id": mol["_id"]}, {"$set": {"parent": parent_name}})
            )
        # Flush pending updates in batches to bound memory use.
        if len(update_list) > 1000:
            print(mol_idx, end="\r", flush=True)
            collection.bulk_write(update_list)
            update_list = []
        # Coarse progress marker on disk for long runs.
        if mol_idx % 100000 == 0:
            with open('parent-update.log', 'a') as f:
                f.write(f"{datetime.now()}: {mol_idx}\n")
    # Write whatever is left after the last full batch.
    if len(update_list) > 0:
        collection.bulk_write(update_list)