### Hierarchical generation of candidate structures from space group, composition, and lattice parameters.

**Input:** Space group, lattice parameters, number of atoms for each element-type in the unit cell, and pair-wise minimum distances allowed.

**Algorithm:**
1. Get Wyckoff sites (symmetry operations) for the space group (attention should be paid to the space-group setting used in experiments).
2. Generate all possible combinations of Wyckoff sites satisfying the number of atoms (composition) in the cell. Filter out combinations that duplicate zero-degree-of-freedom Wyckoff sites.
3. Sort combinations so that those with fewer distinct Wyckoff sites are higher on the list.
4. For each combination on the list:
    1. Start with the Wyckoff positions of the element-type (e.g. "Pb") that has the fewest atoms. For each Wyckoff position of this element in the combination:
        1. Generate sets of atomic positions from each Wyckoff position's symmetry operations for each point on the [x,y,z] grid with a specified step size within [0,1] (considering which dimensions are active).
        2. Remove sets in which positions overlap or are closer than the distance threshold for same-element (A-A) pairs.
    2. Repeat the overlap/distance check across different Wyckoff sites of the same element. Skip those with atoms too close.
    3. If this is the first element-type considered (e.g. "Pb"), continue the loop.
    4. If this is not the first element-type considered (e.g. "S"): repeat the overlap/distance check between the newest element expanded in the loop (e.g. "S") and the element-types already generated (e.g. "Pb"). Skip those with atoms too close.
    5. Store structures that pass all distance filters.
      
**Returns:** A list of feasible sets of atomic positions that satisfy the composition, symmetry and distance constraints.
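
The distance checks in step 4 can be sketched as below. This is only an illustration under stated assumptions: the cell is orthorhombic (as in the example that follows), and the function names `naive_dist_sq` and `min_image_dist_sq` are hypothetical. The naive version mirrors a plain Cartesian difference of fractional coordinates inside one cell; the minimum-image version also considers periodic neighbours, which may matter for atoms sitting near opposite cell faces. pymatgen's `Lattice.get_distance_and_image` could likely serve the same purpose for general cells.

```python
def naive_dist_sq(f1, f2, abc):
    """Squared distance between fractional points, ignoring periodic images."""
    return sum(((u - v) * length) ** 2 for u, v, length in zip(f1, f2, abc))


def min_image_dist_sq(f1, f2, abc):
    """Squared minimum-image distance; valid for orthorhombic cells only."""
    total = 0.0
    for u, v, length in zip(f1, f2, abc):
        d = (u - v) % 1.0          # fractional separation folded into [0, 1)
        d = min(d, 1.0 - d)        # shortest way around the periodic axis
        total += (d * length) ** 2
    return total


abc = (8.482, 5.398, 6.959)        # cell lengths (a, b, c) in Angstrom
p, q = (0.05, 0.25, 0.5), (0.95, 0.25, 0.5)
d_tol_squared = 1.2 ** 2

# The pair looks far apart inside the cell but is close across the boundary.
print(naive_dist_sq(p, q, abc) < d_tol_squared)      # False
print(min_image_dist_sq(p, q, abc) < d_tol_squared)  # True
```

Whether the periodic check changes the final candidate list for a given threshold has not been tested here.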
 

*TODO:* This is just an initial prototype. We should test, clean up and make this code modular/better!

In [3]:
import itertools
import numpy as np
from pymatgen.core import Structure, Lattice  # root-level imports were removed in newer pymatgen releases
from pyxtal.symmetry import get_wyckoffs
from tqdm.notebook import tqdm

This example focuses on anglesite, PbSO4, which has space group 62. Experimental reports often use a non-standard setting of space group 62, i.e. Pbnm. In general, care must be taken that generated structures and experimental diffraction data are compared in the same setting.
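
As a small illustration of the setting issue, the sketch below permutes cell lengths from the Pbnm setting to the standard Pnma setting. It assumes the usual convention that Pbnm is the cab setting of Pnma, so that (a, b, c) in Pbnm corresponds to (c, a, b) in Pnma. The function name is hypothetical, and the input values are assumed to be the Pbnm cell lengths implied by the axis assignment in the setup cell.

```python
def pbnm_to_pnma_lengths(a_pbnm, b_pbnm, c_pbnm):
    """Map cell lengths from the Pbnm setting to the Pnma setting.

    Pnma: n-glide normal to a, mirror normal to b, a-glide normal to c.
    Pbnm: b-glide normal to a, n-glide normal to b, mirror normal to c.
    Matching the symmetry elements gives
    a_Pnma = b_Pbnm, b_Pnma = c_Pbnm, c_Pnma = a_Pbnm.
    """
    return (b_pbnm, c_pbnm, a_pbnm)


# Assumed Pbnm cell lengths (Angstrom) for illustration.
a, b, c = pbnm_to_pnma_lengths(6.959, 8.482, 5.398)
print(a, b, c)
```

Fractional coordinates would need the same axis permutation before structures from the two settings are compared.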

In [4]:
# SETUP
spg = 62
c=6.959; a=8.482; b=5.398; alpha=90; beta=90; gamma=90 
atoms = [(4,'Pb'), (16,'O'), (4,'S')]
# General overlap threshold:
d_tol = 1.2 
# Specific thresholds:
d_mins = {'Pb': 1.5*2, 'S': 1.70*2, 'O': 2.1, 'O-Pb': 2.4, 'Pb-S': 3.0}
# Number of grid points in [0,1] for each free (x, y or z) parameter.
npoints = 10
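
To get a feel for how `npoints` drives the cost, the sketch below estimates an upper bound on the number of grid points per combination of Wyckoff sites before any distance filtering. It is a rough estimate under assumptions: the helper name `grid_upper_bound` is hypothetical, and the `endpoints_coincide` flag reflects that a grid on [0,1] including both ends contains 0 and 1, which describe the same position once coordinates are wrapped back into the cell.

```python
def grid_upper_bound(n_free_per_site, npoints, endpoints_coincide=True):
    """Upper bound on distinct grid points for a list of Wyckoff sites.

    n_free_per_site: number of free variables (0 to 3) of each site.
    """
    per_axis = npoints - 1 if endpoints_coincide else npoints
    total = 1
    for n_free in n_free_per_site:
        total *= per_axis ** n_free   # a site with no free variable contributes 1
    return total


# One site with two free variables, e.g. (x, 1/4, z):
print(grid_upper_bound([2], 10))
# Three sites with 3, 2 and 0 free variables occupied by the same element:
print(grid_upper_bound([3, 2, 0], 10))
```

The bound grows as a power of `npoints`, which is why small changes in grid density can change the run time substantially.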

In [5]:
lattice = Lattice.from_parameters(a,b,c,alpha,beta,gamma)
atoms = sorted(atoms)
d_tol_squared = d_tol**2 # general overlap distance threshold
d_mins_squared = {k: v**2 for k,v in d_mins.items()}
wyckoffs = get_wyckoffs(spg)
multiplicities = [len(w) for w in wyckoffs]

def get_wyckoff_candidates(pos, npoints=5, d_min_squared=None, lattice=None):
    """
    This function generates all unique sets of atomic positions from a Wyckoff position's symmetry operations,
    on a grid of allowed free variables (with *n* equidistant points on [0,1]) (e.g. an n x n x n grid if a
    site has x, y and z all as free variables; or n x 1 x n if x and z are free variables). It takes into account
    a minimum allowed distance.
    Args:
        pos (list): Symmetry operations of the Wyckoff position
        npoints (int): number of equi-distant points on [0,1] for generation of the free-variable grid.
        d_min_squared (float): If given, squared distance between each atom pair in the generated set of positions 
            is compared to filter out unphysical structures.
        lattice (pymatgen.Lattice): lattice object used for distance computations.
    Returns:
        A set of *sets of atomic positions* generated satisfying the symmetry and distance constraints.

    """
    
    grid_xyz=[]
    for i in active_dim(pos):
        if i:
            grid_xyz.append(np.linspace(0,1.0,npoints)) 
        else:
            grid_xyz.append([0]) # if dimension is not active.

    candidates = []
    # forming a meshgrid for free x, y, and/or z parameters for the wyckoff site.
    for xyz in itertools.product(grid_xyz[0],
                                 grid_xyz[1],
                                 grid_xyz[2]):
        wyckoff_positions = []
        for so in pos: #apply symmetry operations of the wyckoff 
            product = so.operate(xyz) # applies both rotation and translation.
            warped = warp(product)    # make sure sites remain within the unit cells
            wyckoff_positions.append(tuple(warped))
    
        wyckoff_positions = frozenset(wyckoff_positions) # forming a set gets rid of duplicates (overlaps)
        
        if len(wyckoff_positions) == len(pos): # if no overlapping sites, store set of wyckoff positions
            skip_str = False
            if d_min_squared:
                for s1,s2 in itertools.combinations(wyckoff_positions,2):
                    if np.sum((
                                lattice.get_cartesian_coords(np.array(s1))
                                -lattice.get_cartesian_coords(np.array(s2)))**2 )< d_min_squared:
                        skip_str = True
                        break
            if skip_str:
                continue
            candidates.append(wyckoff_positions)
    return set(candidates)

def get_possible_combinations(multiplicities, target_n_atoms):
    """
    Helper function to find all possible combinations of Wyckoff positions of a given space group
         that would satisfy the target atom count.
    """
    return [q for i in range(len(multiplicities), 0, -1) 
          for q in itertools.combinations(enumerate(multiplicities), i) if sum([k[1] for k in q]) == target_n_atoms]

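def active_dim(pos):
    """Helper function that flags which of x, y, z are free variables of the Wyckoff position."""
    # A column with any non-zero entry means that variable appears in the position;
    # testing entries directly avoids cancellation in a column sum (e.g. for (x, -x, 0)).
    return np.any(pos[0].rotation_matrix != 0, axis=0)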

def warp(coord):
    """
    Puts fractional coordinates that fall outside back into [0,1].
    """
    for i in range(3):
        coord[i]=coord[i]%1
    return coord

It will be useful to know which Wyckoff sites have free variables (i.e. which are not fixed points with zero degrees of freedom).

In [6]:
all_active_wyckoffs = np.array([i for i in range(len(wyckoffs)) if sum(active_dim(wyckoffs[i]))>0])
all_active_wyckoffs

array([0, 1])
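
The idea behind detecting free variables can be illustrated without any crystallography library. The sketch below is a simplified stand-in, with hypothetical names, that takes the 3x3 rotation part of a site's first symmetry operation as nested lists and reports which of x, y, z appear in it. It also shows why checking for any non-zero entry per column is safer than checking the column sum, which can cancel for positions such as (x, -x, 0).

```python
def free_variables(rotation):
    """Return [x_free, y_free, z_free] for a 3x3 matrix given as nested lists."""
    return [any(row[j] != 0 for row in rotation) for j in range(3)]


def free_variables_by_sum(rotation):
    """Column-sum variant, shown only to illustrate the cancellation problem."""
    return [sum(row[j] for row in rotation) != 0 for j in range(3)]


site_x_quarter_z = [[1, 0, 0], [0, 0, 0], [0, 0, 1]]   # (x, 1/4, z)
site_x_minus_x = [[1, 0, 0], [-1, 0, 0], [0, 0, 0]]    # (x, -x, 0)

print(free_variables(site_x_quarter_z))        # [True, False, True]
print(free_variables(site_x_minus_x))          # [True, False, False]
print(free_variables_by_sum(site_x_minus_x))   # [False, False, False]
```

For the space group used in this notebook both tests should agree, since none of its sites has cancelling entries of this kind.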

Let's take a look at multiplicities:

In [7]:
multiplicities

[8, 4, 4, 4]

Next, we combine groups of Wyckoff sites that satisfy the composition requirements, and filter
out combinations that would repeat Wyckoff sites without an internal degree of freedom (which therefore cannot be occupied
by two different species).

In [8]:
filter_combinations=[]
for i in itertools.product(*[get_possible_combinations(multiplicities,a[0]) for a in atoms]):
    counter = np.zeros(len(wyckoffs))
    for j in i:
        for k in j:
            counter[k[0]] +=1
    t = np.argwhere( counter> 1 ).flatten()
    if False not in np.isin(t,all_active_wyckoffs):
        filter_combinations.append(i)

In [9]:
filter_combinations

[(((1, 4),), ((1, 4),), ((0, 8), (1, 4), (2, 4))),
 (((1, 4),), ((1, 4),), ((0, 8), (1, 4), (3, 4))),
 (((1, 4),), ((1, 4),), ((0, 8), (2, 4), (3, 4))),
 (((1, 4),), ((2, 4),), ((0, 8), (1, 4), (3, 4))),
 (((1, 4),), ((3, 4),), ((0, 8), (1, 4), (2, 4))),
 (((2, 4),), ((1, 4),), ((0, 8), (1, 4), (3, 4))),
 (((3, 4),), ((1, 4),), ((0, 8), (1, 4), (2, 4)))]
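
Note that in the combinations listed above, a single element never occupies the same Wyckoff site type twice, because `itertools.combinations` draws each site at most once. A structure in which one element sits on two symmetry-distinct copies of a free-variable site (for example two 4c sites with different x and z, which as far as I am aware is the case for oxygen in the reported anglesite structure) would therefore not be enumerated. The sketch below shows one possible way to allow such repeats; it is an assumption-laden illustration with a hypothetical function name, not the enumeration used in this notebook.

```python
import itertools


def combos_allowing_repeats(multiplicities, has_free_variable, target_n_atoms):
    """Enumerate multisets of site indices whose multiplicities sum to the target.

    Sites with free variables may repeat; fixed sites may appear at most once.
    """
    n_sites = len(multiplicities)
    max_size = target_n_atoms // min(multiplicities)
    found = []
    for size in range(1, max_size + 1):
        for combo in itertools.combinations_with_replacement(range(n_sites), size):
            if sum(multiplicities[i] for i in combo) != target_n_atoms:
                continue
            # Reject repeated use of a site that has no free variable.
            if any(combo.count(i) > 1 and not has_free_variable[i] for i in set(combo)):
                continue
            found.append(combo)
    return found


multiplicities = [8, 4, 4, 4]
has_free_variable = [True, True, False, False]   # sites 0 and 1 are active
oxygen_options = combos_allowing_repeats(multiplicities, has_free_variable, 16)
print(oxygen_options)
```

With these inputs the option `(0, 1, 1)`, one 8-fold site plus two copies of the 4-fold site with free variables, appears alongside the repeat-free options. Allowing repeats enlarges the search space, so it would interact with the run-time considerations discussed later.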

Let's sort these based on number of Wyckoff positions (prioritizing lower numbers)

In [10]:
filter_combinations = sorted(filter_combinations, 
       key=lambda x: sum([len(i) for i in x]))

In [11]:
filter_combinations

[(((1, 4),), ((1, 4),), ((0, 8), (1, 4), (2, 4))),
 (((1, 4),), ((1, 4),), ((0, 8), (1, 4), (3, 4))),
 (((1, 4),), ((1, 4),), ((0, 8), (2, 4), (3, 4))),
 (((1, 4),), ((2, 4),), ((0, 8), (1, 4), (3, 4))),
 (((1, 4),), ((3, 4),), ((0, 8), (1, 4), (2, 4))),
 (((2, 4),), ((1, 4),), ((0, 8), (1, 4), (3, 4))),
 (((3, 4),), ((1, 4),), ((0, 8), (1, 4), (2, 4)))]

In [12]:
g=lambda x: sum([len(i) for i in x]) 
num_unique_wyckoffs=[g(x) for x in filter_combinations]

Normally, num_unique_wyckoffs can be used to prioritize simpler candidate structures over more complex ones.
Here all combinations have the same number of Wyckoff sites.

In [14]:
num_unique_wyckoffs

[5, 5, 5, 5, 5, 5, 5]

Now, let's build a loop that goes through all structures and removes those that are likely to have atoms too close. This loop can take anywhere from a few seconds to a few hours, primarily controlled by the npoints and distance threshold parameters.

In [16]:
filter_further = [] # We will collect all "feasible" structures in this list
top_X_combinations = -1 # Slice end for filter_combinations: -1 skips the last combination (as in the run below); use None for all, or X for the top X.

counter=0
for combin in filter_combinations[:top_X_combinations]:
    print(counter, combin)
    rolling_good_base_strs = []
    for atom in range(len(combin)):
        elem_group = combin[atom]
        elem = atoms[atom][1]
        if elem in d_mins_squared:
            d_min_squared = d_mins_squared[elem]
        else:
            d_min_squared = None
        
        # FIRST WE WILL GET WYCKOFF SITE GRIDS;
        # AND REMOVE THOSE THAT OVERLAP ACROSS DIFFERENT SITES FOR THE SAME ATOM!
        _g = []
        print('{}.{}: Elem self loop: {}'.format(counter, combin, elem))
        for site in elem_group:
            _g.append(list(get_wyckoff_candidates(wyckoffs[site[0]],npoints=npoints, 
                                                  d_min_squared=d_min_squared, lattice=lattice)))
        within_elem_group = list(itertools.product(*_g))
        
        _d_tol_squared = d_min_squared if d_min_squared else d_tol_squared
        good_strs_within_elem_group = []
        
        for struct in tqdm(within_elem_group):
            skip_str = False
            for sub_pairs in itertools.combinations(struct,2):
                for s1,s2 in itertools.product(*sub_pairs):
                    if np.sum( (
                                    lattice.get_cartesian_coords(
                                                np.array(s1))
                                            -lattice.get_cartesian_coords(
                                                np.array(s2)))**2 ) < _d_tol_squared:
                        skip_str = True
                        break
                else:
                    continue
                break
            if not skip_str:
                good_strs_within_elem_group.append([[i for sub in struct for i in sub]])
            
        if atom==0:
            rolling_good_base_strs = good_strs_within_elem_group
            
        # NOW: we will combine good_strs_within_elem_group and rolling_good_base_strs and
        # remove any bad structures across these.
            
        if atom>0:
            print('{}.{}:  Elem pairs loop: {}'.format(counter, combin, elem))
            good_structures_merged = []
            for structs in tqdm( itertools.product(rolling_good_base_strs, good_strs_within_elem_group), 
                               total=len(rolling_good_base_strs)*len(good_strs_within_elem_group)):
                skip_str = False
                for i in range(len(structs[0])):
                    # different atoms of previous kind
                    atomgroup1 = structs[0][i]
                    atomgroup2 = structs[1][0]
                    pair = '-'.join(sorted([atoms[i][1], atoms[atom][1]]))
                    _d_tol_squared = max( d_mins_squared.get(pair, 0), d_tol_squared)     
                    for s1,s2 in itertools.product(atomgroup1,atomgroup2):
                        if np.sum( (lattice.get_cartesian_coords(
                                                    np.array(s1))
                                                -lattice.get_cartesian_coords(
                                                    np.array(s2)))**2 ) < _d_tol_squared:
                                skip_str = True
                                break
                    if skip_str:
                        break
                if not skip_str:
                    good_structures_merged.append(structs[0]+[structs[1][0]])
            rolling_good_base_strs = good_structures_merged
    filter_further+=rolling_good_base_strs
    counter+=1

0 (((1, 4),), ((1, 4),), ((0, 8), (1, 4), (2, 4)))
0.(((1, 4),), ((1, 4),), ((0, 8), (1, 4), (2, 4))): Elem self loop: Pb


HBox(children=(FloatProgress(value=0.0, max=72.0), HTML(value='')))


0.(((1, 4),), ((1, 4),), ((0, 8), (1, 4), (2, 4))): Elem self loop: S


HBox(children=(FloatProgress(value=0.0, max=68.0), HTML(value='')))


0.(((1, 4),), ((1, 4),), ((0, 8), (1, 4), (2, 4))):  Elem pairs loop: S


HBox(children=(FloatProgress(value=0.0, max=4896.0), HTML(value='')))


0.(((1, 4),), ((1, 4),), ((0, 8), (1, 4), (2, 4))): Elem self loop: O


HBox(children=(FloatProgress(value=0.0, max=5508.0), HTML(value='')))


0.(((1, 4),), ((1, 4),), ((0, 8), (1, 4), (2, 4))):  Elem pairs loop: O


HBox(children=(FloatProgress(value=0.0, max=373464.0), HTML(value='')))


1 (((1, 4),), ((1, 4),), ((0, 8), (1, 4), (3, 4)))
1.(((1, 4),), ((1, 4),), ((0, 8), (1, 4), (3, 4))): Elem self loop: Pb


HBox(children=(FloatProgress(value=0.0, max=72.0), HTML(value='')))


1.(((1, 4),), ((1, 4),), ((0, 8), (1, 4), (3, 4))): Elem self loop: S


HBox(children=(FloatProgress(value=0.0, max=68.0), HTML(value='')))


1.(((1, 4),), ((1, 4),), ((0, 8), (1, 4), (3, 4))):  Elem pairs loop: S


HBox(children=(FloatProgress(value=0.0, max=4896.0), HTML(value='')))


1.(((1, 4),), ((1, 4),), ((0, 8), (1, 4), (3, 4))): Elem self loop: O


HBox(children=(FloatProgress(value=0.0, max=5508.0), HTML(value='')))


1.(((1, 4),), ((1, 4),), ((0, 8), (1, 4), (3, 4))):  Elem pairs loop: O


HBox(children=(FloatProgress(value=0.0, max=359632.0), HTML(value='')))


2 (((1, 4),), ((1, 4),), ((0, 8), (2, 4), (3, 4)))
2.(((1, 4),), ((1, 4),), ((0, 8), (2, 4), (3, 4))): Elem self loop: Pb


HBox(children=(FloatProgress(value=0.0, max=72.0), HTML(value='')))


2.(((1, 4),), ((1, 4),), ((0, 8), (2, 4), (3, 4))): Elem self loop: S


HBox(children=(FloatProgress(value=0.0, max=68.0), HTML(value='')))


2.(((1, 4),), ((1, 4),), ((0, 8), (2, 4), (3, 4))):  Elem pairs loop: S


HBox(children=(FloatProgress(value=0.0, max=4896.0), HTML(value='')))


2.(((1, 4),), ((1, 4),), ((0, 8), (2, 4), (3, 4))): Elem self loop: O


HBox(children=(FloatProgress(value=0.0, max=68.0), HTML(value='')))


2.(((1, 4),), ((1, 4),), ((0, 8), (2, 4), (3, 4))):  Elem pairs loop: O


HBox(children=(FloatProgress(value=0.0, max=4256.0), HTML(value='')))


3 (((1, 4),), ((2, 4),), ((0, 8), (1, 4), (3, 4)))
3.(((1, 4),), ((2, 4),), ((0, 8), (1, 4), (3, 4))): Elem self loop: Pb


HBox(children=(FloatProgress(value=0.0, max=72.0), HTML(value='')))


3.(((1, 4),), ((2, 4),), ((0, 8), (1, 4), (3, 4))): Elem self loop: S


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


3.(((1, 4),), ((2, 4),), ((0, 8), (1, 4), (3, 4))):  Elem pairs loop: S


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


3.(((1, 4),), ((2, 4),), ((0, 8), (1, 4), (3, 4))): Elem self loop: O


HBox(children=(FloatProgress(value=0.0, max=5508.0), HTML(value='')))


3.(((1, 4),), ((2, 4),), ((0, 8), (1, 4), (3, 4))):  Elem pairs loop: O


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


4 (((1, 4),), ((3, 4),), ((0, 8), (1, 4), (2, 4)))
4.(((1, 4),), ((3, 4),), ((0, 8), (1, 4), (2, 4))): Elem self loop: Pb


HBox(children=(FloatProgress(value=0.0, max=72.0), HTML(value='')))


4.(((1, 4),), ((3, 4),), ((0, 8), (1, 4), (2, 4))): Elem self loop: S


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


4.(((1, 4),), ((3, 4),), ((0, 8), (1, 4), (2, 4))):  Elem pairs loop: S


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


4.(((1, 4),), ((3, 4),), ((0, 8), (1, 4), (2, 4))): Elem self loop: O


HBox(children=(FloatProgress(value=0.0, max=5508.0), HTML(value='')))


4.(((1, 4),), ((3, 4),), ((0, 8), (1, 4), (2, 4))):  Elem pairs loop: O


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


5 (((2, 4),), ((1, 4),), ((0, 8), (1, 4), (3, 4)))
5.(((2, 4),), ((1, 4),), ((0, 8), (1, 4), (3, 4))): Elem self loop: Pb


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


5.(((2, 4),), ((1, 4),), ((0, 8), (1, 4), (3, 4))): Elem self loop: S


HBox(children=(FloatProgress(value=0.0, max=68.0), HTML(value='')))


5.(((2, 4),), ((1, 4),), ((0, 8), (1, 4), (3, 4))):  Elem pairs loop: S


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


5.(((2, 4),), ((1, 4),), ((0, 8), (1, 4), (3, 4))): Elem self loop: O


HBox(children=(FloatProgress(value=0.0, max=5508.0), HTML(value='')))


5.(((2, 4),), ((1, 4),), ((0, 8), (1, 4), (3, 4))):  Elem pairs loop: O


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




Let's see how many structures are generated with npoints=10 and the distance thresholds specified above:

In [17]:
len(filter_further)

14488

Okay, now let's compare these to the experimental structure to see if we have a decent match in our list. The whole workflow below is really crude and needs a lot of cleanup and refinement. It is also slow.

In [20]:
original_str = Structure.from_file('PbSO4_AMS_DATA.cif')

In [21]:
from pymatgen.analysis.structure_matcher import StructureMatcher
sm = StructureMatcher(ltol=0.6,stol=0.6, angle_tol=25)

Let's get pymatgen.Structure objects from our coordinates, and compare those to the original structure. This will take a while...

In [24]:
species = ['Pb']*4+['S']*4+['O']*16
scores=[]
for m in range(len(filter_further)):
    sites = []
    for j in [list(i) for i in filter_further[m]]:
        for k in j:
            sites.append(k)
    s = Structure(lattice, species, sites)
    if m%100==0:
        print(m,' ', end='')
    scores.append( sm.get_rms_dist(s, original_str) )

0  100  200  300  400  500  600  700  800  900  1000  1100  1200  1300  1400  1500  1600  1700  1800  1900  2000  2100  2200  2300  2400  2500  2600  2700  2800  2900  3000  3100  3200  3300  3400  3500  3600  3700  3800  3900  4000  4100  4200  4300  4400  4500  4600  4700  4800  4900  5000  5100  5200  5300  5400  5500  5600  5700  5800  5900  6000  6100  6200  6300  6400  6500  6600  6700  6800  6900  7000  7100  7200  7300  7400  7500  7600  7700  7800  7900  8000  8100  8200  8300  8400  8500  8600  8700  8800  8900  9000  9100  9200  9300  9400  9500  9600  9700  9800  9900  10000  10100  10200  10300  10400  10500  10600  10700  10800  10900  11000  11100  11200  11300  11400  11500  11600  11700  11800  11900  12000  12100  12200  12300  12400  12500  12600  12700  12800  12900  13000  13100  13200  13300  13400  13500  13600  13700  13800  13900  14000  14100  14200  14300  14400  

In [28]:
# we should write a simple function to setup Structure from coordinates in filter_further...
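
A possible shape for such a helper is sketched below. It assumes each entry of `filter_further` is a list with one sub-list of fractional-coordinate tuples per element, in the same order as the sorted `atoms` list; the function names are hypothetical. The pymatgen call is left as a comment so the sketch runs on its own.

```python
def species_from_atoms(atoms):
    """Expand [(count, element), ...] into a flat species list."""
    return [element for count, element in atoms for _ in range(count)]


def flatten_sites(candidate):
    """Flatten per-element groups of coordinate tuples into one list of sites."""
    return [list(site) for group in candidate for site in group]


atoms = [(4, 'Pb'), (4, 'S'), (16, 'O')]   # sorted as in the setup above
species = species_from_atoms(atoms)

# Toy candidate with one element group per entry (not a real structure).
toy_candidate = [[(0.0, 0.25, 0.1)],
                 [(0.5, 0.25, 0.6)],
                 [(0.1, 0.2, 0.3), (0.9, 0.8, 0.7)]]
sites = flatten_sites(toy_candidate)
print(sites)

# With a full candidate, the structure could then be built as:
# s = Structure(lattice, species, flatten_sites(filter_further[m]))
```

Building the species list from `atoms` keeps it in step with the element ordering used during generation, rather than relying on a hand-written list.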

In [26]:
_scores = [(scores[i][0],i) for i in range(len(scores)) if scores[i] is not None]

In [27]:
min(_scores)

(0.3055005621261954, 8364)

We can save this structure we found and inspect.

In [30]:
species = ['Pb']*4+['S']*4+['O']*16
sites = []
for j in [list(i) for i in filter_further[min(_scores)[1]]]:
    for k in j:
        sites.append(k)
s = Structure(lattice, species, sites)
s.to(fmt='poscar', filename='best_so_far.vasp')  # keywords avoid the positional-order change between pymatgen versions

This structure looks reasonable, but it does not seem close enough to the actual structure. My test with npoints=12
(for which structure matching, rather than generation, is the most time-consuming step) yielded a structure that has a score of 0.29 but looks much closer to the original one.