# Convert to Ensemble

This notebook is designed to perform a single task: it takes a dataset tarball and transforms the structures inside it using our ensemble model. You will find the available options below, in the Variables section. Datasets generated with this notebook can later be used to generate the LMDB files compatible with the Open Catalyst Project models.

*Note: For the sake of simplicity, this notebook loads the complete dataset into RAM. As the FG and BM datasets are small, this should not be a problem for most modern computers. However, if this notebook is to be used with a larger dataset, implementing a batch function is strongly recommended.*

## Variables

### Paths

In [None]:
from pathlib import Path

# All dataset files live under a single working directory.
ROOT_DIR = Path("./datasets/")
# Base name shared by the input tarball and the generated one.
DS_NAME = "BM_dataset_lite"
# Input: the original dataset tarball (xz-compressed).
TARBALL = ROOT_DIR.joinpath(f"{DS_NAME}.tar.xz")
# Output: tarball that will hold the ensemble version of the dataset.
TARBALL_DEST = ROOT_DIR.joinpath(f"{DS_NAME}_ensemble.tar.xz")

### Ensemble

In [None]:
# Slack added to the summed radii of two Voronoi neighbours when deciding
# whether they are bonded.
VORONOI_TOLERANCE = 0.5
# Multiplier applied to the Cordero radii of the metal elements only.
CORDERO_SCALING_FACTOR = 1.5
# Elements treated as metals when building the ensemble.
METALS = (
    "Ag", "Au", "Cd", "Co", "Cu", "Fe", "Ir",
    "Ni", "Os", "Pd", "Pt", "Rh", "Ru", "Zn",
)
# Non-metal elements; their radii are used without scaling.
MOL_ELEM = ("C", "H", "O", "N", "S")

#### Radii

In [None]:
# Covalent radii of Cordero et al., "Covalent radii revisited",
# Dalton Trans. 2008, 2832-2838 (values in angstrom as published there).
# C uses the sp3 value; Mn, Fe and Co use the high-spin values.
CORDERO = {
    'Ac': 2.15, 'Al': 1.21, 'Am': 1.80, 'Sb': 1.39, 'Ar': 1.06
    , 'As': 1.19, 'At': 1.50, 'Ba': 2.15, 'Be': 0.96, 'Bi': 1.48
    , 'B' : 0.84, 'Br': 1.20, 'Cd': 1.44, 'Ca': 1.76, 'C' : 0.76
    , 'Ce': 2.04, 'Cs': 2.44, 'Cl': 1.02, 'Cr': 1.39, 'Co': 1.50
    , 'Cu': 1.32, 'Cm': 1.69, 'Dy': 1.92, 'Er': 1.89, 'Eu': 1.98
    , 'F' : 0.57, 'Fr': 2.60, 'Gd': 1.96, 'Ga': 1.22, 'Ge': 1.20
    , 'Au': 1.36, 'Hf': 1.75, 'He': 0.28, 'Ho': 1.92, 'H' : 0.31
    , 'In': 1.42, 'I' : 1.39, 'Ir': 1.41, 'Fe': 1.52, 'Kr': 1.16
    , 'La': 2.07, 'Pb': 1.46, 'Li': 1.28, 'Lu': 1.87, 'Mg': 1.41
    , 'Mn': 1.61, 'Hg': 1.32, 'Mo': 1.54, 'Nd': 2.01, 'Ne': 0.58
    , 'Np': 1.90
    , 'Ni': 1.24, 'Nb': 1.64, 'N' : 0.71, 'Os': 1.44, 'O' : 0.66
    , 'Pd': 1.39, 'P' : 1.07, 'Pt': 1.36, 'Pu': 1.87, 'Po': 1.40
    , 'K' : 2.03, 'Pr': 2.03, 'Pm': 1.99, 'Pa': 2.00, 'Ra': 2.21
    , 'Rn': 1.50, 'Re': 1.51, 'Rh': 1.42, 'Rb': 2.20, 'Ru': 1.46
    , 'Sm': 1.98, 'Sc': 1.70, 'Se': 1.20, 'Si': 1.11, 'Ag': 1.45
    , 'Na': 1.66, 'Sr': 1.95, 'S' : 1.05, 'Ta': 1.70, 'Tc': 1.47
    , 'Te': 1.38, 'Tb': 1.94, 'Tl': 1.45, 'Th': 2.06, 'Tm': 1.90
    # Tungsten's symbol is 'W' (it was previously mistyped as 'Wf').
    , 'Sn': 1.39, 'Ti': 1.60, 'W' : 1.62, 'U' : 1.96, 'V' : 1.53
    , 'Xe': 1.40, 'Yb': 1.87, 'Y' : 1.90, 'Zn': 1.22, 'Zr': 1.75
}

### Special Surfaces

These molecules have a special surface energy that is not computed from the energies file but taken from this dict instead.

In [None]:
# Reference energy subtracted from each listed system in place of the energy
# of its "<prefix>-0000" entry (see get_energy_from_dict).
SPECIAL_SURF = {f"ru-mol{idx}": -725.4400795 for idx in range(1, 6)}

## Auxiliary Functions

### Voronoi

In [None]:
import numpy as np
from scipy.spatial import Voronoi
from itertools import product

from pyRDTP.geomio import file_to_mol, mol_to_file
from pyRDTP.molecule import Molecule

def connectivity_search_voronoi(molecule: Molecule,
                                tolerance:float,
                                metal_rad_dict:dict,
                                center:bool=False) -> Molecule:
    """Detect bonds in a periodic structure from Voronoi neighbours.

    Every atom is replicated into 27 images by shifting its fractional
    coordinates by -1, 0 or +1 along each lattice vector, and a Voronoi
    tessellation of all images is built. Two atoms are bonded when images of
    them share a Voronoi ridge and their minimum-image cartesian distance is
    not larger than ``metal_rad_dict[el_a] + metal_rad_dict[el_b] + tolerance``.

    Args:
        molecule: structure to analyse. It is modified in place: each bond is
            registered with ``connection_add`` on the first atom of the pair
            (whether pyRDTP mirrors it on the second atom is not visible
            here), and ``molecule.pairs`` is overwritten with the bonded
            index pairs.
        tolerance: slack added to the summed radii.
        metal_rad_dict: radius per element symbol; every element appearing
            in a candidate pair must be a key (``KeyError`` otherwise).
        center: when True the structure is moved to the box centre for the
            search, and its cartesian and direct coordinates are restored
            afterwards.

    Returns:
        ``molecule`` itself. A one-atom molecule is returned untouched.
    """
    if len(molecule.atoms) == 1:
        return molecule

    if center:
        saved_cartesian = np.copy(molecule.coords_array('cartesian'))
        saved_direct = np.copy(molecule.coords_array('direct'))
        molecule.move_to_box_center()

    frac = np.copy(molecule.coords_array('direct'))
    n_atoms = frac.shape[0]
    # Integer lattice shifts for the home cell and its 26 neighbours.
    shifts = np.asarray(list(product([-1, 0, 1], repeat=3)))
    # images[k, i] is atom i displaced by shifts[k]; flattening shift-major
    # means row k * n_atoms + i belongs to atom i.
    images = (frac[np.newaxis, :, :] + shifts[:, np.newaxis, :]).reshape(
        len(shifts) * n_atoms, frac.shape[1])
    images = np.dot(images, molecule.cell_p.direct)
    owner = np.tile(np.arange(n_atoms), len(shifts))

    # Map Voronoi neighbour pairs of images back to atom indices, keep each
    # unordered pair once and drop atoms paired with their own images.
    candidates = owner[Voronoi(images).ridge_points]
    candidates = np.unique(np.sort(candidates, axis=1), axis=0)
    candidates = candidates[candidates[:, 0] != candidates[:, 1]]

    # Bond cut-offs depend only on the two elements, so cache them.
    cutoff_by_elements = {}
    bonded = []
    for pair in candidates:
        first, second = pair
        el_first = molecule.atoms[first].element
        el_second = molecule.atoms[second].element
        key = frozenset((el_first, el_second))
        if key not in cutoff_by_elements:
            cutoff_by_elements[key] = (metal_rad_dict[el_first]
                                       + metal_rad_dict[el_second]
                                       + tolerance)
        distance = molecule.distance(first, second, system='cartesian', minimum=True)
        if cutoff_by_elements[key] >= distance:
            bonded.append(pair)
            molecule.atoms[first].connection_add(molecule.atoms[second])
    molecule.pairs = np.asarray(bonded)

    if center:
        molecule.coords_update(saved_cartesian, 'cartesian')
        molecule.coords_update(saved_direct, 'direct')
    return molecule


def mol_to_ensemble(molecule: Molecule,
                    voronoi_tolerance: float,
                    scaling_factor: float,
                    second_order: bool,
                    metals: tuple[str, ...],
                    mol_elem: tuple[str, ...],
                    radii: dict[str, float]
                   ) -> Molecule:
    """Cut the adsorbate and the metal atoms bonded to it out of ``molecule``.

    The connectivity of the whole cell is computed with
    ``connectivity_search_voronoi``, using ``radii[el] * scaling_factor`` for
    the elements in ``metals`` and ``radii[el]`` for those in ``mol_elem``.
    The ensemble is every non-metal atom plus every atom bonded to one. With
    ``second_order`` the atoms bonded to the metal atoms of that first shell
    are added as well. Copies of the selected atoms are placed in a new
    ``Molecule`` with a copy of the original cell, and its connectivity is
    recomputed.

    Args:
        molecule: full structure. Its atoms receive connections from the
            first Voronoi pass, which mutates them in place.
        voronoi_tolerance: slack added to summed radii when testing bonds.
        scaling_factor: multiplier applied to the radii of ``metals``.
        second_order: also keep the neighbours of first-shell metal atoms.
        metals: element symbols treated as metals.
        mol_elem: non-metal element symbols; unscaled radii.
        radii: radius per element symbol; must cover ``metals`` and
            ``mol_elem``.

    Returns:
        A new ``Molecule`` holding only the ensemble.

    Raises:
        KeyError: if an element of ``metals``/``mol_elem`` is missing from
            ``radii``, or the structure contains an element in neither tuple.
    """
    # Metals get scaled radii. An element listed in both tuples ends up
    # unscaled because mol_elem is applied last.
    elem_rad = {metal: radii[metal] * scaling_factor for metal in metals}
    elem_rad.update((element, radii[element]) for element in mol_elem)

    # 1) Define the whole connectivity of the cell.
    molecule = connectivity_search_voronoi(molecule, voronoi_tolerance, elem_rad)

    # 2) Collect every non-metal atom together with its direct neighbours.
    ensemble = []
    for atom in molecule.atoms:
        if atom.element in metals:
            continue
        for neighbour in atom.connections + [atom]:
            if neighbour not in ensemble:
                ensemble.append(neighbour)

    # 3) Optionally add the neighbours of the metal atoms found in step 2.
    #    Compare the element symbol, not the Atom object, against `metals`.
    #    Iterate over a snapshot: walking the list while appending to it would
    #    keep following metal-metal bonds and pull in the whole slab.
    if second_order:
        first_shell = [atom for atom in ensemble if atom.element in metals]
        for atom in first_shell:
            for neighbour in atom.connections:
                if neighbour not in ensemble:
                    ensemble.append(neighbour)

    # 4) Build a standalone Molecule from copies of the selected atoms. Drop
    #    any connections carried over by the copies (presumably still
    #    referencing atoms of the full cell) before re-bonding.
    new_molecule = Molecule("")
    new_molecule.atom_add_list([atom.copy() for atom in ensemble])
    new_molecule.connection_clear()
    new_molecule.cell_p_add(molecule.cell_p.copy())

    # 5) Define the connectivity of the final ensemble.
    return connectivity_search_voronoi(new_molecule, voronoi_tolerance, elem_rad)

In [None]:
from pyRDTP.geomio import VaspContcar, MolObj

def vasp_str_read(s):
    """Parse VASP geometry text into a pyRDTP structure.

    ``s`` is handed straight to ``VaspContcar.read``; ``tar_file`` passes an
    iterator of decoded CONTCAR/POSCAR lines. The parsed data goes through
    pyRDTP's universal format into a ``MolObj``, and the result of
    ``write(bulk=False)`` is returned. It is presumably a ``Molecule``, since
    callers feed it to ``mol_to_ensemble`` (TODO confirm).
    """
    parsed = VaspContcar()
    parsed.read(s)
    builder = MolObj()
    builder.universal_read(parsed.universal_convert())
    return builder.write(bulk=False)

def mol_vasp_write(m):
    """Serialise a pyRDTP structure ``m`` back to VASP CONTCAR text.

    The structure is loaded into a ``MolObj``, converted through pyRDTP's
    universal format into a ``VaspContcar``, and that writer's output is
    returned (``tar_file`` encodes it as UTF-8, so it is a ``str``).
    """
    source = MolObj()
    source.read(m)
    writer = VaspContcar()
    writer.universal_read(source.universal_convert())
    return writer.write()
    
def to_ensemble_str(s, second_order=False):
    """Convert one VASP geometry into the CONTCAR text of its ensemble.

    Uses the notebook-level settings ``VORONOI_TOLERANCE``,
    ``CORDERO_SCALING_FACTOR``, ``METALS``, ``MOL_ELEM`` and ``CORDERO``.

    Args:
        s: CONTCAR/POSCAR content in the form accepted by ``vasp_str_read``
            (``tar_file`` passes an iterator of decoded lines).
        second_order: forwarded to ``mol_to_ensemble`` as its
            ``second_order`` flag. Defaults to False, the value that was
            previously hard-coded here.

    Returns:
        The text produced by ``mol_vasp_write`` for the ensemble.
    """
    ensemble = mol_to_ensemble(
        molecule=vasp_str_read(s),
        voronoi_tolerance=VORONOI_TOLERANCE,
        scaling_factor=CORDERO_SCALING_FACTOR,
        second_order=second_order,
        metals=METALS,
        mol_elem=MOL_ELEM,
        radii=CORDERO,
    )
    return mol_vasp_write(ensemble)

## Tarball

In [None]:
import re
from io import BytesIO
import tarfile

def tar_string(s, fn, tar):
    """Store the bytes ``s`` in the open, writable archive ``tar`` as ``fn``.

    The member gets a default ``TarInfo`` header apart from its name and size.
    """
    member = tarfile.TarInfo(name=fn)
    member.size = len(s)
    tar.addfile(tarinfo=member, fileobj=BytesIO(s))
    
def get_energy_from_dict(s, ener_dict, special_dict):
    """Return the energy of system ``s`` relative to its reference.

    * ``s`` in ``special_dict``: ``ener_dict[s] - special_dict[s]``.
    * ``s`` without ``-``, or exactly ``<prefix>-0000``: raw ``ener_dict[s]``.
    * otherwise: ``ener_dict[s] - ener_dict["<prefix>-0000"]``, where
      ``<prefix>`` is the text before the first ``-``.

    Values of ``ener_dict`` are converted with ``float`` (they may be
    strings). A missing key raises ``KeyError``.
    """
    energy = float(ener_dict[s])
    if s in special_dict:
        return energy - special_dict[s]
    prefix, dash, rest = s.partition('-')
    if not dash or rest == '0000':
        return energy
    return energy - float(ener_dict[f"{prefix}-0000"])

# Member-name patterns, compiled once and used by tar_file via re.match
# (anchored at the start of the name only). Raw strings keep "\." a regex
# escape instead of an invalid string escape (SyntaxWarning on Python 3.12+).

# Reference geometries "<...>-0000.contcar" / ".poscar"; tar_file copies
# these unchanged.
M_RE = re.compile(r".*-0000\.(contcar|poscar)")
# Other geometries. NOTE: "\.*" means "zero or more dots", so this matches
# any name that contains "contcar" or "poscar".
C_RE = re.compile(r".*\.*(contcar|poscar)")
# Energy listings, e.g. ".../energies.dat".
E_RE = re.compile(r".*energies.*dat")
def tar_file(fn, tar, tar_dest, special_dict):

    
    print(fn)
    match fn:
        case s if M_RE.match(s):
            tar_string(tar.extractfile(s).read(), s, tar_dest)
        case s if C_RE.match(s):
            byte_lines = map(
                lambda x: x.decode("utf-8")
                , tar.extractfile(s).readlines()) 
            geom = to_ensemble_str(byte_lines)
            tar_string(geom.encode("utf-8"), s, tar_dest)
        case s if E_RE.match(s): 
            energies_dict = dict(map(
                lambda x: x.split()
                , tar.extractfile(s).read().decode('utf-8').rstrip().split('\n')))
            tar_string(
                b'\n'.join(map(
                    lambda x: f'{x} {get_energy_from_dict(x, energies_dict, special_dict)}'.encode("utf-8")
                    , energies_dict.keys()))
                , s
                , tar_dest)
        case s if "groups.dat" in s: tar_string(tar.extractfile(s).read(), s, tar_dest)
        case _:
            info = tarfile.TarInfo(fn)
            info.type = tarfile.DIRTYPE
            tar_dest.addfile(info)

In [None]:
from collections import deque

# Run every member of the source tarball through tar_file. The context
# managers close both archives even if converting a member raises.
with (
    tarfile.open(TARBALL, mode="r:xz") as tar_ds,
    tarfile.open(TARBALL_DEST, mode="w:xz") as tar_target,
):
    names = tar_ds.getnames()
    for name in names:
        tar_file(name, tar_ds, tar_target, SPECIAL_SURF)