In [23]:
import os
import time

from multiprocessing import Manager
import tqdm
import tqdm.contrib.concurrent

import pickle
import scipy
import random
import numpy as np
from scipy.spatial import ConvexHull
from scipy.spatial.distance import pdist, squareform

import ase
import ase.io
from ase.visualize import view, plot
import matplotlib.pyplot as plt

from ppafm import io

from typing import List, Tuple, Dict

elements = ['H' , 'He',
            'Li', 'Be',  'B',  'C',  'N',  'O',  'F', 'Ne', 
            'Na', 'Mg', 'Al', 'Si',  'P',  'S', 'Cl', 'Ar',
             'K', 'Ca', 
            'Sc', 'Ti',  'V', 'Cr', 'Mn', 'Fe', 'Co', 'Ni', 'Cu', 'Zn',
                        'Ga', 'Ge', 'As', 'Se', 'Br', 'Kr',
            'Rb', 'Sr',
             'Y', 'Zr', 'Nb', 'Mo', 'Tc', 'Ru', 'Rh', 'Pd', 'Ag', 'Cd',
                        'In', 'Sn', 'Sb', 'Te',  'I', 'Xe'
]

In [16]:
def return_number_of_keys_in_dict(d: dict) -> int:
    total = 0
    for k in d:
        total += len(d[k])
    return total

# https://stackoverflow.com/a/8453514
def random_unit_vector():
    vec = np.random.normal(size=3)
    vec /= np.linalg.norm(vec)
    return vec

def plane_from_3points(points):
    a = points[1] - points[0]
    b = points[2] - points[0]
    n = np.cross(a, b)
    n /= np.linalg.norm(n)
    d = -np.dot(n, points[0])
    return np.array([n[0], n[1], n[2], d])

def plane_from_2points(points):
    p0 = points[0] + np.array([5, 5, 5])
    a = points[0] - p0
    b = points[1] - p0
    n = np.cross(a, b)
    n /= np.linalg.norm(n)
    d = -np.dot(n, points[0])
    return np.array([n[0], n[1], n[2], d])

def zyz_rotation(alpha, beta, gamma):
    '''
    Extrinsic rotation of angle alpha around z, beta around y, and gamma around z.
    '''

    def R_y(x):
        return [[np.cos(x), 0, np.sin(x)], [0, 1, 0], [-np.sin(x), 0, np.cos(x)]]
    def R_z(x):
        return [[np.cos(x), -np.sin(x), 0], [np.sin(x), np.cos(x), 0], [0, 0, 1]]

    return np.dot(R_z(gamma), np.dot(R_y(beta), R_z(alpha)))

def cart_to_sph(vec):
    r = np.linalg.norm(vec)
    phi = np.arctan2(vec[1], vec[0])
    theta = np.arccos(vec[2]/r)
    return r, phi, theta

def get_convex_hull_eqs(xyz, angle_tolerance=5):
    '''
    Get coefficients for equations of planes of convex hull of point cloud. If plane normals
    are within angle_tolerance, they are considered to be the same plane and the other plane
    is ignored.
    '''
    xyz = xyz[:,:3]
    hull = ConvexHull(xyz)
    cosines = 1 - pdist(hull.equations[:,:3], 'cosine')
    angles = np.arccos(cosines) / np.pi * 180
    angles = squareform(angles)
    bad_inds = []
    for i, angle in enumerate(angles):
        if i in bad_inds:
            continue
        inds = np.where(angle < 5)[0]
        for ind in inds:
            if ind == i:
                continue
            if ind not in bad_inds:
                bad_inds.append(ind)
    eqs = np.delete(hull.equations, bad_inds, axis=0)
    return eqs, hull

def find_planar_segments(xyz, eqs, dist_tol=0.2, num_atoms=10):
    '''
    Find planar segments on the surface of a molecule. A planar segment is
    a surface plane that contains at least num_atoms atoms within dist_tol of
    the plane.
    '''
    planar_seg_eqs = []
    planar_seg_inds = []
    for i, eq in enumerate(eqs):
        eq = eq / np.linalg.norm(eq[:3])
        dist = np.abs(np.dot(xyz[:,:3], eq[:3]) + eq[-1])
        if len(np.where(dist <= dist_tol)[0]) >= num_atoms:
            planar_seg_eqs.append(eq)
            planar_seg_inds.append(i)
    return planar_seg_eqs, planar_seg_inds

def get_plane_elements(xyz, plane_eqs, dist_tol=0.5):
    '''
    Find and count the number of atoms of different elements in a molecule near planes.
    '''
    plane_elems = []
    for eq in plane_eqs:
        eq = eq / np.linalg.norm(eq[:3])
        dist = np.abs(np.dot(xyz[:,:3], eq[:3]) + eq[-1])
        plane_elems.append(set(xyz[dist <= dist_tol, -1].astype(int)))
    return plane_elems

def _convert_elemements(element_dict):
    element_dict_ = {}
    for i, e in enumerate(element_dict):
        if isinstance(e, str):
            element_dict_[elements.index(e)+1] = element_dict[e]
        else:
            element_dict_[e] = element_dict[e]
    return element_dict_

def choose_rotations_bias(
    xyz,
    flat,
    plane_bias=None,
    random_bias=None,
    angle_tolerance=5,
    elem_dist_tol=0.7,
    flat_dist_tol=0.1,
    flat_num_atoms=10
):

    n_vecs = []

    if plane_bias is not None:
        plane_bias = _convert_elemements(plane_bias)
    if random_bias is not None:
        random_bias = _convert_elemements(random_bias)
    
    if len(xyz) > 3:
        try:
            eqs, hull = get_convex_hull_eqs(xyz, angle_tolerance=angle_tolerance)
            vertices = hull.vertices
        except scipy.spatial.qhull.QhullError:
            print(f'A problematic molecule encountered.')
            return []
    elif len(xyz) == 3:
        eqs = plane_from_3points(xyz[:,:3])[None]
        vertices = np.array([0, 1, 2])
    elif len(xyz) == 2:
        eqs = plane_from_2points(xyz[:,:3])[None]
        vertices = np.array([0, 1])
    else:
        print(xyz)
        raise RuntimeError('Molecule with less than two atoms.')
    
    if flat:
        planar_seg_eqs, planar_seg_inds = find_planar_segments(xyz, eqs, dist_tol=flat_dist_tol, num_atoms=flat_num_atoms)
        for eq in planar_seg_eqs:
            n_vecs.append(eq[:3])
        eqs = np.delete(eqs, planar_seg_inds, axis=0)

    if plane_bias:
        plane_elems = get_plane_elements(xyz, eqs, dist_tol=elem_dist_tol)
        for eq, elems in zip(eqs, plane_elems):
            for e, p in plane_bias.items():
                if e in elems and (np.random.rand() <= p):
                    n_vecs.append(eq[:3])
                    break

    if random_bias:
        elems = set(xyz[vertices,-1].astype(int))
        for e in random_bias:
            if e not in elems:
                continue
            while random_bias[e] > 0:
                if random_bias[e] < 1 and np.random.rand() > random_bias[e]:
                    break
                while True:
                    n = random_unit_vector()
                    _, phi, theta = cart_to_sph(n)
                    new_xyz = xyz.copy()
                    new_xyz[:,:3] = np.dot(new_xyz[:,:3], zyz_rotation(-phi, -theta, 0).T)
                    eq = np.array([0, 0, 1, -new_xyz[:,2].max()])
                    elems = get_plane_elements(new_xyz, [eq], dist_tol=0.7)
                    if e in elems[0]:
                        break
                if len(n_vecs) > 0:
                    n_vecs_np = np.stack(n_vecs, axis=0)
                    angles = np.arccos(np.dot(n_vecs_np, n)/np.linalg.norm(n_vecs_np, axis=1)) / np.pi * 180
                    if all(angles > angle_tolerance):
                        n_vecs.append(n)
                        random_bias[e] -= 1
                else:
                    n_vecs.append(n)
                    random_bias[e] -= 1

    rotations = []
    for vec in n_vecs:
        _, phi, theta = cart_to_sph(vec)
        rotations.append(zyz_rotation(-phi, -theta, 0))

    return rotations

In [17]:
#database_path = '/Users/kurkil1/data/mol_database/'
database_path = '/l/mol_database/'
filenames = [
    os.path.join(database_path, f)
    for f in os.listdir(database_path)
    if f.endswith('.xyz')
]

In [18]:
def rotate_xyz(xyz, R):
    new_xyz = xyz.copy()
    new_xyz[:,:3] = np.dot(new_xyz[:,:3], R.T)
    return new_xyz

In [19]:
def return_rotations(
    filenames: List[str],
    valid_elements: List[str],
    flat: bool = True,
    plane_bias = None,
    random_bias = None,
    angle_tolerance: float = 5,
    elem_dist_tol: float = 0.7,
    flat_dist_tol: float = 0.1,
    flat_num_atoms: int = 10
) -> List[np.ndarray]:
    rotations = {}
    for filename in tqdm.tqdm(filenames):
        xyz, zs, qs, comment = io.loadXYZ(filename)
        cid = comment.split()[-1]
        if np.any(~np.isin(zs, valid_elements)):
            continue
        rots = choose_rotations_bias(
            xyz,
            flat=flat,
            plane_bias=plane_bias,
            random_bias=random_bias,
            angle_tolerance=angle_tolerance,
            elem_dist_tol=elem_dist_tol,
            flat_dist_tol=flat_dist_tol,
            flat_num_atoms=flat_num_atoms
        )
        if len(rots) == 0:
            continue
        rotations[cid] = rots
    return rotations

In [104]:
flat = True
flat_dist_tol = 0.3
elem_dist_tol = 0.3
angle_tolerance = 5
plane_bias = {
    'H' : 0.0,
    'C' : 0.0,
    'N' : 0.1,
    'O' : 0.0,
    'F' : 1,
    'Cl': 1,
    'Br': 1,
}
random_bias = {
    'H' : 0.1,
    'C' : 0.0,
    'N' : 0.4,
    'O' : 0.1,
    'F' : 1,
    'Cl': 1,
    'Br': 1,
}
flat_num_atoms = 6
valid_elements = np.array([1, 6, 7, 8, 9, 17, 35])

# Set random seeds for reproducibility
random.seed(0)
np.random.seed(0)

In [105]:
rotations = return_rotations(
    filenames[:10000],
    valid_elements,
    flat=flat,
    plane_bias=plane_bias,
    #random_bias=random_bias,
    angle_tolerance=angle_tolerance,
    elem_dist_tol=elem_dist_tol,
    flat_dist_tol=flat_dist_tol,
    flat_num_atoms=flat_num_atoms
)

  except scipy.spatial.qhull.QhullError:
 37%|███▋      | 3726/10000 [00:01<00:02, 2226.41it/s]

A problematic molecule encountered.


 49%|████▊     | 4855/10000 [00:02<00:02, 2185.68it/s]

A problematic molecule encountered.
A problematic molecule encountered.


 64%|██████▍   | 6390/10000 [00:02<00:01, 2067.61it/s]

A problematic molecule encountered.


 79%|███████▉  | 7932/10000 [00:03<00:00, 2158.97it/s]

A problematic molecule encountered.


 93%|█████████▎| 9261/10000 [00:04<00:00, 2217.54it/s]

A problematic molecule encountered.


100%|██████████| 10000/10000 [00:04<00:00, 2167.62it/s]

A problematic molecule encountered.





In [106]:
print(return_number_of_keys_in_dict(rotations))

14474


In [107]:
os.makedirs('rotations', exist_ok=True)
for counter, (mol_id, rots) in enumerate(rotations.items()):
    if len(rots) == 0:
        continue
    # Load molecule
    filename = os.path.join(database_path, f'{mol_id}.xyz')
    xyz, zs, qs, comment = io.loadXYZ(filename)
    original_mol = ase.Atoms(numbers=zs, positions=xyz[:, :3])
    original_mol.cell = np.eye(3) * 20
    original_mol.cell[2, 2] = 5.0
    original_mol.center()

    # Plot the original molecule and the rotated molecules
    fig = plt.figure(figsize=(10, 5))
    sf1, sf2 = fig.subfigures(1, 2, width_ratios=[1, len(rots)], wspace=0.1)
    
    # Plot original molecule
    axs = sf1.subplots(2, 1)
    ax = axs[0]
    plot.plot_atoms(original_mol, ax=ax, show_unit_cell=2)
    ax.set_title('Original')

    # Plot side-view of original molecule
    ax = axs[1]
    plot.plot_atoms(original_mol, ax=ax, show_unit_cell=2, rotation='-90x')

    axs = sf2.subplots(2, len(rots))
    if len(rots) == 1:
        axs = axs[:, np.newaxis]
    for i, R in enumerate(rots):
        new_xyz = rotate_xyz(xyz, R)
        
        new_mol = ase.Atoms(numbers=zs, positions=new_xyz[:, :3])
        new_mol.cell = np.eye(3) * 20
        new_mol.cell[2, 2] = 5.0
        new_mol.center()

        # Plot from top-view
        ax = axs[0,i]
        plot.plot_atoms(new_mol, ax=ax, show_unit_cell=2)

        # Plot from side-view
        ax = axs[1,i]
        plot.plot_atoms(new_mol, ax=ax, show_unit_cell=2, rotation='-90x')

    plt.tight_layout()
    plt.savefig(f'rotations/{counter}.png')
    plt.close()

    if counter > 100:
        break


In [99]:
def analyze_rotations(rotations: Dict[str, List[np.ndarray]]):
    '''
    Analyze the rotations to see the distributions of atom types near the top atom in rotated molecules.
    '''

    element_counts = {}

    for mol_id, rots in rotations.items():
        if len(rots) == 0:
            continue
        filename = os.path.join(database_path, f'{mol_id}.xyz')
        xyz, zs, qs, comment = io.loadXYZ(filename)
        for R in rots:
            xyz = rotate_xyz(xyz, R)

            # Filter out atoms that are too far from the top atom
            mask = xyz[:,2] > xyz[:,2].max() - 1.0
            xyz = xyz[mask]
            zs = zs[mask]

            for z in zs:
                if z not in element_counts:
                    element_counts[z] = 0
                element_counts[z] += 1

    # Divide each value by total number of counts
    total = sum(element_counts.values())
    for k in element_counts:
        element_counts[k] /= total

    return element_counts

In [100]:
new_counts = analyze_rotations(rotations)
# Print the counts in the increasing order of keys
for k in sorted(new_counts.keys()):
    print(f'{elements[k-1]:2s}: {new_counts[k]:.3f}')

H : 0.418
C : 0.436
N : 0.059
O : 0.070
F : 0.005
Cl: 0.010
Br: 0.002


In [61]:
counts = analyze_rotations(rotations)
# Print the counts in the increasing order of keys
for k in sorted(counts.keys()):
    print(f'{elements[k-1]:2s}: {counts[k]:.3f}')

H : 0.391
C : 0.458
N : 0.063
O : 0.070
F : 0.005
Cl: 0.010
Br: 0.003
