# New molecule/ff input generation pipeline

### Outline

```tex
input:      [molecule names]



ps <- Pseudoatoms()


for name in molecule names:
    trappe_name <- search_trappe_names()
    if trappe_name:
        name <- trappe_name
    
    mol <- get_mol(name)
    
    if trappe_name:
        data <- load_trappe()
        map <- map_atoms(mol, data)

        for atom in map:
            assign atom_type
            atom.SetCharge(data[atom])

        interactions <- get_interactions(mol)
    
    else:
        ps.load_references()
    

    make_molecule_file(name, mol, interactions)
    out_names.append(name)


make_ff_ps_files(ps)



output:     {molecule}.def, force_field_mixing_rules.def, force_field.def, pseudo_atoms.def


```

In [1]:
import os
# Point `__file__` at a notebook path inside the current working directory.
# NOTE(review): presumably needed because project code imported below resolves
# paths relative to `__file__`, which is not set when running as a notebook —
# confirm which module relies on it.
__file__ = os.path.join(os.getcwd(), "input_gen.ipynb")

In [2]:
from student.agent.tools.tools_raspa import RaspaTool
import os
import json
import requests
from student.agent.utils import *
from student.agent.tools.input_gen.pseudoatoms import PseudoAtoms, Atom
from student.agent.tools.input_gen.utils_molecules import *
from student.agent.tools.input_gen.generate_mol_definition import *

Using device: cpu


In [56]:
class TrappeLoader(RaspaTool):
    """
    Builds RASPA-style input files from TraPPE force-field data.

    For every requested molecule a ``{name}.def`` molecule definition is
    written; afterwards ``pseudo_atoms.def``, ``force_field_mixing_rules.def``
    and ``force_field.def`` are written from the pseudoatoms accumulated over
    all processed molecules.
    """

    def __init__(self, path=None):
        """
        Parameters
        ----------
        path : optional
            Forwarded unchanged to ``RaspaTool.__init__``.
        """
        super().__init__("molecule definition generator", "", path)
        # {molecule name: TraPPE molecule_ID} for the supported families.
        self.molecules = self.load_molecule_names()
        # Pseudoatoms accumulated over every molecule processed so far.
        self.ps_overall = PseudoAtoms()
        # Molecule names that must never be returned by `_search_name`.
        self.blacklist = ["methyl acetate", "ethyl acetate", "methyl propionate", "vinyl acetate"]
        # When False, `make_file` becomes a no-op (nothing is written to disk).
        self.make_files = True
        # Forwarded to `get_mol` / `align_mol_indeces` for diagnostic printing.
        self.verbose = False

    def reset(self):
        """Forget all pseudoatoms accumulated by previous `run` calls."""
        self.ps_overall = PseudoAtoms()

    def run(self, molecule_names : List[str]):
        """
        Generate the definition file of every molecule plus the shared
        force-field / pseudoatom files.

        Parameters
        ----------
        molecule_names : list of str or str
            Molecule names to look up; a single string is treated as a
            one-element list. Names without a TraPPE match are skipped.
        """
        # isinstance (rather than `type(x) == str`) also keeps working when the
        # builtin `type` has been rebound in an interactive session.
        if isinstance(molecule_names, str):
            molecule_names = [molecule_names]

        molecule_names = [name.replace(" ", "_") for name in molecule_names]

        for name in molecule_names:

            res = self._search_name(name)
            if res is None:
                continue

            mol_id = self.get_molecule_id(res)

            mol_def = self.build_molecule_definition(mol_id, name)
            self.make_file(mol_def, f"{name}.def")

        self.make_ff_ps_files()
        return

    def make_ff_ps_files(self):
        """Write the three force-field files built from `self.ps_overall`."""
        self.make_file(self.ps_overall.build_pseudoatoms(), "pseudo_atoms.def")
        self.make_file(self.ps_overall.build_ff_mixing(), "force_field_mixing_rules.def")
        self.make_file(self.ps_overall.build_ff(), "force_field.def")
        return

    ################################################################################
    ################################      Utils          ###########################
    ################################################################################

    def make_file(self, content, file_name):
        """
        Write `content` to `file_name` inside the tool's output directory.

        Parameters
        ----------
        content : str or list of str
            A list is joined with newlines before writing.
        file_name : str
            File name relative to ``self.get_path(full=True)``.
        """
        # Honour the switch: previously this branch was a `pass`, so files were
        # written even with `make_files = False`.
        if self.make_files is False:
            return
        if isinstance(content, list):
            content = "\n".join(content)

        output_dir = self.get_path(full=True)
        with open(os.path.join(output_dir, file_name), "w") as f:
            f.write(content)

    def _load_trappe_names(self):
        """
        Return the TraPPE molecule list, using a JSON cache on disk.

        The list is read from ``trappe_molecule_list.json`` under
        ``self.get_path(full=False)`` when that file exists; otherwise it is
        requested from the TraPPE search endpoint and cached there.
        """
        # URL to scrape
        url = "http://trappe.oit.umn.edu/scripts/search_select.php"
        # check if the data is already downloaded
        path = self.get_path(full=False)
        file_path = os.path.join(path, "trappe_molecule_list.json")
        try:
            with open(file_path) as f:
                return json.load(f)
        except FileNotFoundError:
            pass
        os.makedirs(path, exist_ok=True)
        res_dict = json.loads(request_by_post(url))['search']

        with open(file_path, "w") as f:
            json.dump(res_dict, f)

        return res_dict

    def load_molecule_names(self, families=("UA", "small")):
        """
        Build the ``{name: molecule_ID}`` table for the given TraPPE families.

        ``<em>`` markup is removed from the names and a leading ``n-`` prefix
        is dropped (so ``n-pentane`` is stored as ``pentane``).

        Parameters
        ----------
        families : iterable of str
            Values of the ``family`` field to keep.
        """
        mols = self._load_trappe_names()
        molecules = {}
        for m in mols:
            if m['family'] in families:
                name = m["name"].replace("<em>", "").replace("</em>", "")
                if name.startswith("n-"):
                    name = name[2:]
                molecules[name] = m["molecule_ID"]
        return molecules

    def get_molecule_id(self, mol):
        """Return the TraPPE molecule_ID for the name `mol`, or None if unknown."""
        return self.molecules.get(mol, None)

    def molecule_names(self):
        """Return a view of all known molecule names."""
        return self.molecules.keys()

    def _search_name(self, query, score_cutoff=80):
        """
        Fuzzy-match `query` against the known molecule names.

        Returns the best non-blacklisted name, or None when nothing reaches
        `score_cutoff`.
        """
        candidates = self.molecule_names()
        matches = quick_search(query, candidates, limit=5, score_cutoff=score_cutoff)
        # Each match is a sequence whose first element is the molecule name
        # (that element is what gets returned below), so the blacklist has to
        # be checked against match[0]; comparing the whole match never filtered
        # anything.
        matches = [match for match in matches if match[0] not in self.blacklist]

        if len(matches) == 0:
            return None

        return matches[0][0]

    def parse_section(self, param_str: str, section_key: str, min_parts: int) -> list:
        """
        Extract the rows of one section from a TraPPE parameter dump.

        A section starts at the line beginning with `section_key` and ends at
        the next line beginning with ``"#,"``. Blank lines are ignored. Rows
        are split on commas; rows with fewer than `min_parts` fields are
        dropped and longer rows are truncated to `min_parts` fields.

        Returns
        -------
        list of list of str
            One list of exactly `min_parts` stripped fields per kept row.
        """
        lines = param_str.splitlines()
        section_lines = []
        in_section = False
        for line in lines:
            if not line.strip():
                continue
            if line.startswith(section_key):
                in_section = True
                continue
            if in_section:
                if line.startswith("#,"):
                    break  # end of this section
                parts = [p.strip() for p in line.split(",")]
                if len(parts) >= min_parts:
                    section_lines.append(parts[:min_parts])
        return section_lines

    def build_molecule_definition(self, id, name) -> list:
        """
        Assemble the lines of a molecule definition file.

        Parameters
        ----------
        id :
            TraPPE molecule_ID used to download properties and parameters.
        name : str
            Molecule name; underscores are turned back into spaces before the
            structure lookup via `get_mol`.

        Returns
        -------
        list of str
            The file content, one entry per line (the last entry is empty).

        Raises
        ------
        RuntimeError
            If `get_mol` returns None for `name`.
        """
        Tc, pc, acentric_factor = get_trappe_properties(id)
        params = self.load_trappe_parameters(id)

        ps = params["pseudoatoms"]
        self.ps_overall += ps

        mol = get_mol(name.replace("_", " "), verbose=self.verbose)
        # Validate before touching the molecule: RemoveHs(None) would fail with
        # an unrelated error, and the message must name the molecule (not None).
        if mol is None:
            raise RuntimeError(f"No molecule could be generated for {name!r}")
        mol = Chem.RemoveHs(mol)

        atoms = ps.get_atoms_main()
        bonds = params['bonds']
        mol = align_mol_indeces(mol, atoms, bonds, verbose=self.verbose)

        interactions = get_intramol_interactions(mol)
        n_vdw = len(interactions['vdw'])
        n_coulomb = len(interactions['coulomb'])

        lines = []

        # Critical constants section
        lines.append("# critical constants: Temperature [T], Pressure [Pa], and Acentric factor [-]")
        lines.append(f"{Tc}")
        lines.append(f"{pc}")
        lines.append(f"{acentric_factor}")

        # Molecular composition section
        lines.append("# Number Of Atoms")
        lines.append(f"{params['num_atoms']}")
        lines.append("# Number Of Groups")
        lines.append(f"{params['num_groups']}")

        # Group information
        lines.append(f"# {params['group_name']}")
        lines.append(f"{params['group_flexibility']}")
        lines.append("# number of atoms")
        lines.append(f"{params['group_atom_count']}")

        # Atomic positions: rows are numbered by position in the list, the
        # index stored with each pseudoatom is not used here.
        lines.append("# atomic positions")
        for i, (_, atom_type) in enumerate(ps.get_atoms()):
            lines.append(f"{i} {atom_type}")

        # Intramolecular interaction flags
        lines.append("# Chiral centers Bond  BondDipoles Bend  UrayBradley InvBend  Torsion Imp. Torsion Bond/Bond Stretch/Bend Bend/Bend Stretch/Torsion Bend/Torsion IntraVDW IntraCoulomb")
        intramol = params['intramolecular_flags'] + [n_vdw, n_coulomb]
        # Column widths line the counts up under the header above.
        flag_format = (
            "{:16d}"  # Chiral centers
            "{:5d}"   # Bond
            "{:13d}"  # BondDipoles
            "{:5d}"   # Bend
            "{:13d}"  # UrayBradley
            "{:8d}"   # InvBend
            "{:9d}"   # Torsion
            "{:13d}"  # Imp. Torsion
            "{:10d}"  # Bond/Bond
            "{:13d}"  # Stretch/Bend
            "{:10d}"  # Bend/Bend
            "{:16d}"  # Stretch/Torsion
            "{:13d}"  # Bend/Torsion
            "{:9d}"   # IntraVDW
            "{:12d}"  # IntraCoulomb
        )
        lines.append(flag_format.format(*intramol))

        # Bond stretching parameters
        lines.append("# Bond stretch: atom n1-n2, type, parameters")
        for atom1, atom2, bond_type, force_constant, eq_length in params['bond_stretches']:
            lines.append(f"{atom1} {atom2} {bond_type} {force_constant} {eq_length}")

        # Bond bending parameters (if available)
        if params.get("bond_bends"):
            lines.append("# Bond bending: atom n1-n2-n3, type, parameters")
            for atom1, atom2, atom3, bend_type, force_constant, theta in params["bond_bends"]:
                lines.append(f"{atom1} {atom2} {atom3} {bend_type} {force_constant} {theta}")

        # Torsion parameters (if available)
        if params.get("bond_torsions"):
            lines.append("# Torsion: atom n1-n2-n3-n4, type, parameters")
            for atom1, atom2, atom3, atom4, torsion_type, c0, c1, c2, c3 in params["bond_torsions"]:
                lines.append(f"{atom1} {atom2} {atom3} {atom4} {torsion_type} {c0} {c1} {c2} {c3}")

        # Intra-molecular interactions
        lines.append(get_intramol_string(interactions))

        # Partial reinsertion moves
        lines.append(get_nr_fixed_section(mol))

        lines.append("")
        return lines

    @staticmethod
    def _zero_based_indices(range_str, expected):
        """
        Split a dash-separated, 1-indexed atom range into zero-based indices.

        Surrounding spaces and quotes are stripped from the range and from
        every entry. Returns None when the range does not have exactly
        `expected` entries; raises ValueError when an entry is not an integer.
        """
        parts = [p.strip(' "\'') for p in range_str.strip(' "\'').split('-')]
        if len(parts) != expected:
            return None
        return [int(p) - 1 for p in parts]

    def load_trappe_parameters(self, molecule_id: int) -> dict:
        """
        Retrieves TraPPE parameters for a given molecule_id and returns a dictionary.

        Rows whose atom range has the wrong number of entries, or whose
        numeric fields cannot be parsed, are skipped.

        Returns
        -------
        dict with the keys
            ``num_atoms``, ``num_groups``, ``group_name``,
            ``group_flexibility``, ``group_atom_count``;
            ``intramolecular_flags`` : list of 13 ints (the counts up to and
            including Bend/Torsion, in header order);
            ``bond_stretches`` : list of ``(a1, a2, "RIGID_BOND", "", "")``;
            ``bond_bends`` : list of ``(a1, a2, a3, "HARMONIC_BEND", k, theta)``;
            ``bond_torsions`` : list of
            ``(a1, a2, a3, a4, "TRAPPE_DIHEDRAL", c0, c1, c2, c3)``;
            ``pseudoatoms`` : the parsed `PseudoAtoms` object;
            ``bonds`` : list of ``(a1, a2)`` zero-based index pairs.
        """

        PARAM_STRING = download_parameters(molecule_id)

        # --- Pseudoatom Section ---
        pseudoatoms = self.parse_section(PARAM_STRING, "#,(pseudo)atom", 6)
        ps = PseudoAtoms()
        ps.parse_trappe(pseudoatoms)

        num_atoms = len(pseudoatoms)

        num_groups = 1  # (could be set to 2 if num_atoms > 8, etc.)
        group_name = "Group"
        group_flexibility = "flexible"
        group_atom_count = num_atoms // num_groups

        # --- Bond Stretching Parameters ---
        # Every bond is emitted as RIGID_BOND with empty parameters. (A
        # HARMONIC_BOND variant with force constant 96500 and the parsed
        # equilibrium length was considered but is not used.)
        stretches = self.parse_section(PARAM_STRING, "#,stretch", 4)
        bond_stretches = []
        bonds = []
        for _, bond_range, _bond_type, length_str in stretches:
            try:
                indices = self._zero_based_indices(bond_range, 2)
                if indices is None:
                    continue
                # Validation only: rows without a numeric length are skipped.
                float(length_str)
            except ValueError:
                continue
            atom1, atom2 = indices
            bond_stretches.append((atom1, atom2, "RIGID_BOND", "", ""))
            bonds.append((atom1, atom2))

        # --- Bond Bending Parameters ---
        # Row layout: [index, bend_range, bend_type, theta, k_theta]
        bends = self.parse_section(PARAM_STRING, "#,bend", 5)
        bond_bends = []
        for _, bend_range, _bend_type, theta_str, k_theta_str in bends:
            try:
                indices = self._zero_based_indices(bend_range, 3)
                if indices is None:
                    continue
                theta = float(theta_str)
                force_constant = float(k_theta_str)
            except ValueError:
                continue
            bond_bends.append((*indices, "HARMONIC_BEND", force_constant, theta))

        # --- Torsion Parameters ---
        # Row layout: [index, torsion_range, torsion_type, c0, c1, c2, c3]
        torsions = self.parse_section(PARAM_STRING, "#,torsion", 7)
        bond_torsions = []
        for _, torsion_range, _torsion_type, c0_str, c1_str, c2_str, c3_str in torsions:
            try:
                indices = self._zero_based_indices(torsion_range, 4)
                if indices is None:
                    continue
                coefficients = [float(c) for c in (c0_str, c1_str, c2_str, c3_str)]
            except ValueError:
                continue
            bond_torsions.append((*indices, "TRAPPE_DIHEDRAL", *coefficients))

        # Interaction counts in the column order of the definition file header;
        # IntraVDW / IntraCoulomb are appended by the caller.
        fields = [
            0,                      # Chiral centers
            len(bond_stretches),    # Bond
            0,                      # BondDipoles
            len(bond_bends),        # Bend
            0,                      # UrayBradley
            0,                      # InvBend
            len(bond_torsions),     # Torsion
            0,                      # Imp. Torsion
            0,                      # Bond/Bond
            0,                      # Stretch/Bend
            0,                      # Bend/Bend
            0,                      # Stretch/Torsion
            0,                      # Bend/Torsion
        ]
        parameters = {
            "num_atoms": num_atoms,
            "num_groups": num_groups,
            "group_name": group_name,
            "group_flexibility": group_flexibility,
            "group_atom_count": group_atom_count,
            "intramolecular_flags": fields,
            "bond_stretches": bond_stretches,
            "bond_bends": bond_bends,
            "bond_torsions": bond_torsions,
            "pseudoatoms" : ps,
            "bonds" : bonds
        }
        return parameters


def get_intramol_interactions(mol: "Chem.Mol", k: int = 4) -> dict:
    """
    Collect the atom pairs that interact intramolecularly.

    A pair ``(i, j)`` with ``i < j`` is a van-der-Waals pair when its
    shortest bond path has at least `k` bonds. Such a pair is additionally a
    Coulomb pair when both atoms carry a non-zero ``"charge"`` double
    property.

    Parameters
    ----------
    mol : rdkit.Chem.Mol
        Molecule whose atoms may carry a ``"charge"`` property.
    k : int
        Minimum number of bonds separating an interacting pair.

    Returns
    -------
    dict
        ``{"vdw": [(i, j), ...], "coulomb": [(i, j), ...]}``.
    """
    n_atoms = mol.GetNumAtoms()
    interactions = {"vdw": [], "coulomb": []}
    print("Warning: Coulomb interactions only for heteroatoms currently!")

    def is_coulomb(a1, a2):
        # Both atoms need a charge. The previous check tested a1 twice, so an
        # uncharged a2 slipped through to GetDoubleProp.
        if not (a1.HasProp("charge") and a2.HasProp("charge")):
            return False
        return a1.GetDoubleProp("charge") != 0 and a2.GetDoubleProp("charge") != 0

    for i in range(n_atoms):
        for j in range(i + 1, n_atoms):
            path = Chem.rdmolops.GetShortestPath(mol, i, j)
            # The path lists atoms, so the bond count is one less.
            if len(path) - 1 < k:
                continue

            interactions["vdw"].append((i, j))
            if is_coulomb(mol.GetAtomWithIdx(i), mol.GetAtomWithIdx(j)):
                interactions["coulomb"].append((i, j))

    return interactions

In [62]:
from rdkit import Chem
from collections import defaultdict

def align_mol_indeces(mol, atoms, bonds, verbose=False):
    '''
    Renumber the atoms of an RDKit molecule so they follow an external
    atom numbering, and attach the external charges.

    Atoms are matched greedily: a mol atom is paired with the first unused
    external atom that has the same element symbol and the same *set* of
    neighbour element symbols. There is no backtracking, so ambiguous
    topologies may be matched in an unintended order.

    Parameters
    ----------
    mol : rdkit.Chem.Mol
        The RDKit molecule.
    atoms : dict
        {index: (main_atom, charge)} - external indices and atom properties.
    bonds : list of tuples
        [(index1, index2), ...] - external bond pairs using atom indices.
    verbose : bool
        Print the symbol map and both adjacency graphs.

    Returns
    -------
    rdkit.Chem.Mol
        A renumbered molecule whose atom order follows the sorted external
        indices; every atom carries its charge as the double property
        "charge".

    Raises
    ------
    ValueError
        If a mol atom has no matching external atom, or an external atom is
        left without a mol atom.
    '''

    # Pseudo-group prefixes (CHx, CFx, ...) and the element they stand for.
    prefix_to_element = (
        ("CH", "C"), ("CF", "C"), ("NH", "N"), ("OH", "O"),
        ("SH", "S"), ("Hx", "H"), ("PH", "P"),
    )

    def extract_main_atom(label):
        # Labels without a known prefix are used as the symbol unchanged.
        for prefix, element in prefix_to_element:
            if label.startswith(prefix):
                return element
        return label

    # 1. Element symbol of every external atom
    atom_idx_to_symbol = {
        i: extract_main_atom(data[0]) for i, data in atoms.items()
    }

    # 2. Adjacency graph of the external atoms
    adjacency = defaultdict(set)
    for i, j in bonds:
        adjacency[i].add(j)
        adjacency[j].add(i)

    # 3. The same two structures for the RDKit molecule
    mol_atom_symbols = [atom.GetSymbol() for atom in mol.GetAtoms()]
    mol_adjacency = defaultdict(set)
    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        mol_adjacency[i].add(j)
        mol_adjacency[j].add(i)

    if verbose:
        print(atom_idx_to_symbol)
        print(adjacency)
        print(mol_adjacency)

    # 4. Greedy matching of mol atoms to external atoms
    atom_map = {}
    used_atoms = set()

    for mol_idx, symbol in enumerate(mol_atom_symbols):
        mol_neighbors = {mol_atom_symbols[n] for n in mol_adjacency[mol_idx]}
        for ext_idx, ext_symbol in atom_idx_to_symbol.items():
            if ext_symbol != symbol or ext_idx in used_atoms:
                continue
            ext_neighbors = {atom_idx_to_symbol[n] for n in adjacency[ext_idx]}
            if ext_neighbors == mol_neighbors:
                atom_map[mol_idx] = ext_idx
                used_atoms.add(ext_idx)
                break
        else:
            raise ValueError(f"Could not find a match for mol atom {mol_idx} ({symbol})")

    # Every external atom needs a partner, otherwise the renumbering below
    # cannot be built (previously this surfaced as a bare KeyError).
    ext_order = sorted(atoms.keys())
    unmatched = [ext_idx for ext_idx in ext_order if ext_idx not in used_atoms]
    if unmatched:
        raise ValueError(f"External atoms {unmatched} have no matching atom in mol")

    # Inverse map: external index -> mol index
    ext_to_mol_map = {v: k for k, v in atom_map.items()}

    # new_order[p] is the current mol index of the atom that moves to position p
    new_order = [ext_to_mol_map[ext_idx] for ext_idx in ext_order]
    mol = Chem.RenumberAtoms(mol, new_order)

    # Position p now holds external atom ext_order[p]; look the charge up by
    # that key so the external indices need not be exactly 0..n-1.
    for ext_idx, atom in zip(ext_order, mol.GetAtoms()):
        atom.SetDoubleProp("charge", atoms[ext_idx][1])

    return mol


### Test

In [63]:
t = TrappeLoader()
t.verbose = True



In [64]:
t._search_name("n-pentane", score_cutoff=88)

In [65]:
t.ps_overall

<student.agent.tools.input_gen.pseudoatoms.PseudoAtoms at 0x134a0ada0>

In [66]:
t.run("n-pentane")

SMILES found for  n-pentane  : CCCCC
{0: 'C', 1: 'C', 2: 'C', 3: 'C', 4: 'C'}
defaultdict(<class 'set'>, {0: {1}, 1: {0, 2}, 2: {1, 3}, 3: {2, 4}, 4: {3}})
defaultdict(<class 'set'>, {0: {1}, 1: {0, 2}, 2: {1, 3}, 3: {2, 4}, 4: {3}})


## Pseudoatom typing

In [67]:
t = TrappeLoader()



In [70]:
# Survey every known TraPPE molecule and record, per main atom, which
# pseudoatom types occur and which parameter tuples each type is used with.
atom_types = {}         # main_atom : {ps_type}
type_to_params = {}     # ps_type : {(epsilon, sigma, charge)}

# `mol_id` rather than `id`: at notebook top level the loop variable is a
# global and would rebind the builtin `id` for the rest of the session.
for name, mol_id in t.molecules.items():
    trappe_parameters = download_parameters(mol_id)
    pseudoatoms = t.parse_section(trappe_parameters, "#,(pseudo)atom", 6)
    for atom in pseudoatoms:
        try:
            # Row layout: [index, main_atom, ps_type, epsilon, sigma, charge].
            int(atom[0])  # validation only: the row index must be an integer
            main_atom = atom[1]
            ps_type = atom[2]
            epsilon = float(atom[3])
            sigma = float(atom[4])
            charge = float(atom[5])
        except ValueError:
            # Rows with a non-numeric field are reported and left out.
            print("Missing parameters for", name)
            continue

        atom_types.setdefault(main_atom, set()).add(ps_type)
        type_to_params.setdefault(ps_type, set()).add((epsilon, sigma, charge))

Missing parameters for methyl acetate
Missing parameters for ethyl acetate
Missing parameters for methyl propionate
Missing parameters for vinyl acetate


### Type to label

In [71]:
# Print, for each main pseudoatom, the number of distinct types recorded for it.
for main_atom, ps_types in atom_types.items():
    n_types = len(ps_types)
    print(main_atom, n_types)

CH4 1
CH3 13
CH2 16
CH 13
C 14
O 16
H 4
S 5
N 5
M 6
CF3 1
CF2 1
P 2
F 1


In [74]:
# main_atom : {ps_type : label}
# A main atom with exactly one pseudoatom type uses that type string as its
# label; every other type starts out as None until a label is assigned by hand.
type_to_label = {}

# The loop variable must not be called `type`: at notebook top level it is a
# global and would rebind the builtin, breaking every later `type(...)` call
# in this session (including those made inside classes defined here).
for main_atom, ps_types in atom_types.items():
    single_type = len(ps_types) == 1
    type_to_label[main_atom] = {
        ps_type: (ps_type if single_type else None) for ps_type in ps_types
    }

# Hand-written label tables. Each assignment below REPLACES the dict that the
# initialisation loop created for that main atom, so a pseudoatom type missing
# from a table has no entry at all (not even a None placeholder).
# Keys are the pseudoatom type strings collected into `atom_types`; values are
# the short labels, which must be unique across all tables.
# NOTE(review): the bracketed atom in a key presumably marks the site being
# labelled within its bonding environment — confirm against the TraPPE data.

# --- CH3 sites ---
type_to_label["CH3"] = {
    '[CH3]-C#N': "CH3_cn",
    '[CH3]-S-CHx': "CH3_sc",
    '[CH3]-P': "CH3_p",
    '[CH3]-O-H': "CH3_oh",
    '[CH3]-CH=O': "CH3_co",
    '[CH3]-C-CH-O-P': "CH3_ccop",
    '[CH3]-S-S-CHx': "CH3_ssc",
    '[CH3]-SH': "CH3_sh",
    '[CH3]-CHx': "CH3_chx",
    '[CH3]-CH-O-P': "CH3_cop",
    '[CH3]-P(=O)-(OCH3)2': "CH3_po",
    '[CH3]-O-CHx': "CH3_ocg",
    '[CH3]-O-P': "CH3_op"
}

# --- CH2 sites ---
type_to_label["CH2"] = {
    '[CH2]-O-CH2': "CH2_oc",
    'O-[CH2]-CH2': "o_CH2_c",
    '[CH2]=CHx': "h2C=c",
    'CHx-[CH2]-C#N': "c_CH2cn",
    'O-[CH2]-CH2-CH2': "o_CH2_cc",
    'CHx-[CH2]-O-CHy': "c_CH2_oc",
    'CHx-[CH2]-S-S-CHx': "c_CH2_ssr",
    'CHx-[CH2]-CH=O': "c_CH2_cho",
    'CHx-[CH2]-SH': "c_CH2_sh",
    'O-[CH2]-O': "o_CH2_o",
    'CH2-[CH2]-CH2': "ch2_CH2_ch2",
    'O-CH2-[CH2]-CH2': "oc_CH2_c",
    'O-[CH2]-CH2-O': "oCH2_c_o",
    'CHx-[CH2]-CHx': "c_CH2_c",
    'CHx-[CH2]-O-H': "c_CH2_oh",
    'CHx-[CH2]-S-CHx': "c_CH2_sc"
}

# --- CH sites ---
type_to_label["CH"] = {

    'CHx=[CH](sp2)-CHy(sp3)': "c=CH_csp3",
    'CHx=[CH](sp2)-CHy(sp2)': "c=CH_csp2",

    '(CHx)2-[CH]-O-H': "r2_CH_oh",
    '(CHx)2-[CH]-CHx': "r2_CH_c",
    '(CH3)2-[CH]-O-P': "r2_CH_op",
    '(CHx)2-[CH]-SH': "r2_CH_sh",
    '(CHx)2-[CH]-O-CHy': "r2_CH_oc",

    'CHx-[CH]=O': "c_CH=o",
    'CH3-[CH](-C)-O-P': "me_CH(c)_op",

    'CH+[CH]+N': "c+CH+n",
    'CH+[CH]+S': "c+CH+s",
    'CH+[CH]+CH': "c+CH+ch",
    'N+[CH]+N': "n+CH+n"
}

# --- C sites (no attached hydrogens) ---
type_to_label["C"] = {
    'O=[C]=O': "C_co2",
    'CHx-[C]#N': "nCr",

    'CHx=[C](sp2)-CHy(sp2)': "c=C_csp2",
    'CHx=[C](sp2)-CHy(sp3)': "c=C_csp3",

    '(CHx)3-[C]-O-H': "me3_C_oh",
    '(CHx)3-[C]-O-CHy': "me3_C_or",
    '(CHx)3-[C]-CHx': "me3_C_c",
    '(CH3)3-[C]-CH-O-P': "me3_C_cop",
    '(CHx)3-[C]-SH': "me3_C_sh",

    'CH+[C](-NO2)+CH': "c+C(no2)+c",
    'CH+[C](-CHx)+CH': "c+C(c)+c",
    'CH+[C](+CH)+CH': "c+C(+c)+c",

    'CHx-O-[C]=O': "roC=o",
    '[C]=O': "C=o",
}

# --- O sites ---
type_to_label["O"] = {
    'CHx-[O]-CHy': "chx_O_chx",             # ether (generic)
    'CH2-[O]-CH2-O': "ch2_O_ch2_o",         # ether (double)
    'CH2-[O]-CH2-CH2': "ch2_O_ch2_ch2",     # ether (c chain)
    'CH2-[O]-CH2': "ch2_O_ch2",             # ether

    'CH=[O]': "O=ch",                       # aldehyde
    'C=[O]': "O=c",                         # carbonyl
    'CHx-O-C=[O]': "O=co",                  # carbonyl acid
    'CHx-[O]-C=O': "cO_c=o",                # ester
    'P-[O]-CH3': "pOme",                    # pOme
    'CHx-[O]-H': "hO_c",                    # alcohol

    'CHx-N[O]2': "Ono_c",                   # nitro
    'P-[O]-CH': "pO_ch",                    # pOr
    '[O]=P-F': "O=pf",                      # fp=O
    '[O]=P(-CH3)-(OCH3)2': "O=pco",         # O=p(c)(oc)2

    '[O]=C=O': "O_co2",                     # co2
    '[O]=O': "O_o2",                        # o2
}

# --- H sites ---
type_to_label["H"] = {
    'O-[H]': "Ho",      # alcohol H
    'S-[H]': "Hs",      # thiol H
    '[H]-NH2': "H_nh3", # NH3
    '[H]-SH': "H_h2s",  # SH2
}

# --- S sites ---
type_to_label["S"] = {
    'CHx-[S]-H': "hS_chx",
    'H-[S]-H': "S_h2s",
    'CHx-[S]-CHx': "chx_S_chx",
    'CHx-[S]-S-CHx': "chx_Ss_chx",
    'CH+[S]+CH': "c+S+c"
}

# --- N sites ---
type_to_label["N"] = {
    'CH+[N]+CH': "c+N+c",
    '[N]-H3': "N_nh3",
    'CHx-C#[N]': "Nc_chx",
    'CHx-[N]O2': "o2N_chx",
    '[N]#N': "N_n2"
}

# --- M sites ---
type_to_label["M"] = {
    '[M]pi': "M_pi",
    '[M]center': "M_center",
    'H2S-[M]': "M_h2s",
    '[M]-NH3': "M_nh3",
    'O=[M]=O': "M_o2",
    'N#[M]#N': "M_n2"
}

# --- P sites ---
type_to_label["P"] = {
    'F-[P](=O)-O-CH': "fPoome",
    'CH3-[P](=O)-(OCH3)2': "mePoome"
}


In [75]:
# Sanity check: a label must not be used for two different pseudoatom types.
from collections import Counter

# Flatten the per-main-atom tables into one list of labels.
labels = [label for mapping in type_to_label.values() for label in mapping.values()]
print(len(labels), len(set(labels)))

# Labels that occur more than once (expected to be empty).
label_counts = Counter(labels)
duplicates = [label for label, occurrences in label_counts.items() if occurrences > 1]
duplicates

98 98


[]

In [76]:
import pickle

# Persist the pseudoatom-type -> label tables; the path is relative to the
# current working directory.
with open("ps_type_to_label.pkl", "wb") as pickle_file:
    pickle.dump(type_to_label, pickle_file)

In [77]:
# Load with this:
# (pickle executes code from the file while loading, so only open a
# ps_type_to_label.pkl that was written by this notebook.)
import pickle

with open("ps_type_to_label.pkl", "rb") as pickle_file:
    type_to_label = pickle.load(pickle_file)

### Type to parameters

In [76]:
# Almost all ps_types have unique parameters, except for these:
# (a type is listed when it was seen with more than one
# (epsilon, sigma, charge) tuple across the surveyed molecules)
for ps_type, param_sets in type_to_params.items():
    if len(param_sets) == 1:
        continue
    print(ps_type, param_sets)

CH2-[CH2]-CH2 {(56.3, 3.88, 0.0), (52.5, 3.91, 0.0), (51.0, 3.89, 0.0)}
CH2-[O]-CH2 {(155.0, 2.39, -0.44), (190.0, 2.2, -0.41), (155.0, 2.39, -0.36), (29.0, 3.1, -0.42)}
O-[CH2]-O {(56.3, 3.88, 0.36), (52.5, 3.91, 0.36)}
CH2-[O]-CH2-O {(190.0, 2.2, -0.425), (155.0, 2.39, -0.395)}
O-[CH2]-CH2 {(52.5, 3.91, 0.17), (56.3, 3.88, 0.245), (56.3, 3.88, 0.16)}
O-CH2-[CH2]-CH2 {(52.5, 3.91, 0.05), (56.3, 3.88, 0.045)}
CH+[CH]+CH {(48.0, 3.74, 0.0), (50.5, 3.695, 0.0)}


In [84]:
list(atom_types.keys())

['CH4',
 'CH3',
 'CH2',
 'CH',
 'C',
 'O',
 'H',
 'S',
 'N',
 'M',
 'CF3',
 'CF2',
 'P',
 'F']

In [84]:
type_to_params

{'CH4': {(148.0, 3.73, 0.0)},
 '[CH3]-CHx': {(98.0, 3.75, 0.0)},
 'CHx-[CH2]-CHx': {(46.0, 3.95, 0.0)},
 '(CHx)2-[CH]-CHx': {(10.0, 4.68, 0.0)},
 '(CHx)3-[C]-CHx': {(0.5, 6.4, 0.0)},
 '[CH2]=CHx': {(85.0, 3.675, 0.0)},
 'CHx=[CH](sp2)-CHy(sp3)': {(47.0, 3.73, 0.0)},
 'CHx=[CH](sp2)-CHy(sp2)': {(52.0, 3.71, 0.0)},
 'CHx=[C](sp2)-CHy(sp3)': {(20.0, 3.85, 0.0)},
 'CHx=[C](sp2)-CHy(sp2)': {(22.0, 3.85, 0.0)},
 'CHx-[CH2]-O-H': {(46.0, 3.95, 0.265)},
 'CHx-[O]-H': {(93.0, 3.02, -0.7)},
 'O-[H]': {(0.0, 0.0, 0.435)},
 '[CH3]-O-CHx': {(98.0, 3.75, 0.25)},
 'CHx-[O]-CHy': {(55.0, 2.8, -0.5)},
 'CHx-[CH2]-O-CHy': {(46.0, 3.95, 0.25)},
 '(CHx)2-[CH]-O-H': {(10.0, 4.33, 0.265)},
 '(CHx)3-[C]-O-H': {(0.5, 5.8, 0.265)},
 '[CH3]-O-H': {(98.0, 3.75, 0.265)},
 'CHx-[CH2]-SH': {(46.0, 3.95, 0.171)},
 'CHx-[S]-H': {(232.0, 3.62, -0.377)},
 'S-[H]': {(0.0, 0.0, 0.206)},
 '(CHx)2-[CH]-SH': {(10.0, 4.68, 0.171)},
 '(CHx)3-[C]-SH': {(0.5, 6.4, 0.171)},
 '[CH3]-SH': {(98.0, 3.75, 0.171)},
 '(CHx)3-[C]-O-CHy'