## Kabsch interpolation

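<mark>Warning: always check that the interpolated structures are physically sensible (for example, that no molecule has been broken apart and that no bonds are unphysically short). You can visualise the generated structures using [VESTA](https://jp-minerals.org/vesta/en/) or a similar program.</mark>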

If you use this code please consider:
- [Citing the associated paper](https://arxiv.org/abs/2302.08412) (currently under review)
- [Citing the Atomic Simulation Environment](https://wiki.fysik.dtu.dk/ase/faq.html#how-should-i-cite-ase)

Please see https://github.com/NU-CEM/Kabsch_interpolation for more information.

## Imports

In [6]:
import ase
from ase import io
from ase.geometry import analysis
from ase.build import molecule
from ase import Atoms
from ase import neighborlist

from IPython.display import HTML
import numpy as np
from numpy.linalg import norm 
from scipy.spatial.transform import Rotation as R
from scipy.spatial.transform import Slerp
from scipy import sparse

import itertools
from collections import Counter
from copy import deepcopy
from tempfile import NamedTemporaryFile

## Custom functions

In [7]:
def interpolate_structures(start_atoms, end_atoms, molecular_formulas=None, number_intermediates=9, fformat="vasp", reverse=True, molecular_indices=None, translation_species = None, mic_cutoff=0.5):
    """Write structures interpolated between start_atoms and end_atoms.

    Molecules are moved as rigid bodies. The translation of their centre of mass
    is interpolated linearly. So is the angle of the rotation that takes the start
    orientation onto the end orientation; that rotation is found with the Kabsch
    algorithm from the set of bond vectors connecting the molecule's atoms.
    Atoms that are not part of a molecule are translated linearly, without rotation.

    One file per step is written to the current working directory. Each is named
    'POSCAR_' + str(step).zfill(3) + '.' + fformat (e.g. POSCAR_004.vasp,
    POSCAR_-04.vasp). Step 0 is the start structure. Step number_intermediates+1
    applies the full start->end motion.

    Note: start_atoms and end_atoms are modified in place by start_end_mic
    (atoms that crossed a cell boundary are shifted by a lattice vector).

    Args:
    start_atoms (ase.Atoms) - ASE Atoms object for the start structure.
    end_atoms (ase.Atoms) - ASE Atoms object for the end structure (same atom ordering as start_atoms).
    molecular_formulas (list(str) or str)(optional) - molecular formulas to which Kabsch interpolation will be applied. If not set then molecular_indices is used to identify the atoms for Kabsch interpolation. Defaults to None.
    number_intermediates (int)(optional) - number of intermediate structures. Defaults to 9.
    fformat (string)(optional) - file format for writing the structures. For output options see https://wiki.fysik.dtu.dk/ase/ase/io/io.html. Defaults to "vasp".
    reverse (bool)(optional) - also create structures along the same path in the negative sense (e.g. interpolate between 0 and 52 degrees, and between 0 and -52 degrees). Defaults to True.
    molecular_indices (list(arrays))(optional) - list of arrays, each holding the indices of one molecule. If not set, molecules are found automatically from molecular_formulas. Defaults to None.
    translation_species (list(str) or str)(optional) - chemical symbols of the elements that are translated (without Kabsch interpolation). If not set then all non-molecular atoms are translated. Defaults to None.
    mic_cutoff (float)(optional) - if any Cartesian coordinate of any atom of a molecule (start or end) has an absolute value below this distance, the molecule is treated as possibly straddling a cell boundary and is unwrapped around its first atom. Defaults to 0.5.

    Raises:
    ValueError - if neither molecular_formulas nor molecular_indices is given."""

    # Validate the arguments before start_end_mic modifies the inputs in place.
    if molecular_formulas is None and molecular_indices is None:
        raise ValueError("either molecular_formulas or molecular_indices must be specified")

    # need to ensure minimum image convention between start and end structures
    start_atoms, end_atoms = start_end_mic(start_atoms, end_atoms)

    # Accept a single formula / species string as well as a list of them; a bare
    # string would otherwise be iterated or substring-matched character by character.
    if isinstance(molecular_formulas, str):
        molecular_formulas = [molecular_formulas]
    if isinstance(translation_species, str):
        translation_species = [translation_species]

    # get index of every atom that belongs to a molecule (to apply Kabsch interpolation).
    if molecular_indices is not None:
        molecular_indices_list = molecular_indices
    else:
        molecular_indices_list = find_molecules(start_atoms, molecular_formulas)

    translation_indices = get_translation_indices(start_atoms, translation_species, molecular_indices_list)

    # The full start->end motions depend only on the two structures, not on the
    # interpolation step, so compute them once rather than once per output file.
    start_positions = start_atoms.get_positions()
    translation_indices = np.asarray(translation_indices, dtype=int)
    atom_translations = (end_atoms.get_positions()[translation_indices]
                         - start_positions[translation_indices])
    molecule_paths = [_molecule_path(start_atoms, end_atoms, molecule_indices, mic_cutoff)
                      for molecule_indices in molecular_indices_list]

    if reverse:
        step_indices = range(-(number_intermediates+1), number_intermediates+2)
    else:
        step_indices = range(0, number_intermediates+2)

    for step_index in step_indices:

        # "amplitude" of interpolation: 0 gives the start structure, 1 the full
        # start->end motion, negative values go along the path in the reverse sense
        interval = step_index * (1/(number_intermediates+1))

        # translation-only atoms: linear interpolation of each atom's displacement
        positions = start_positions.copy()
        positions[translation_indices] += atom_translations * interval

        # molecules: scaled rotation about the centre of mass, then scaled COM translation
        for molecule_indices, start_molecule, axis, angle, translation in molecule_paths:
            molecule = deepcopy(start_molecule)
            molecule.rotate(angle * interval, axis, center="COM")
            molecule.translate(translation * interval)
            positions[molecule_indices] = molecule.positions

        atoms = deepcopy(start_atoms)
        # Assign the array directly (this bypasses constraints). set_positions()
        # applies constraints by default, which could keep constrained atoms
        # (e.g. from POSCAR selective dynamics) at their start positions.
        atoms.positions = positions
        ase.io.write('POSCAR_'+str(step_index).zfill(3)+"."+fformat,atoms, format=fformat)


def _molecule_path(start_atoms, end_atoms, molecule_indices, mic_cutoff):
    """Describe the rigid-body motion of one molecule from the start to the end structure.

    Returns a tuple (molecule_indices, start_molecule, axis, angle, translation):
    - start_molecule: the start geometry (unwrapped if needed) to which scaled motions are applied;
    - axis, angle: the Kabsch rotation from get_axis_angle (angle in degrees);
    - translation: the centre-of-mass displacement from get_translation."""
    start_molecule = start_atoms[molecule_indices]
    end_molecule = end_atoms[molecule_indices]

    # A molecule with a coordinate close to zero may straddle a cell face, with its
    # atoms on opposite sides of the cell. Move every atom to the periodic image
    # closest to the molecule's first atom. This keeps the molecule in one piece,
    # whereas wrapping each atom about the origin could split a molecule that also
    # lies near the middle of another cell axis.
    if ((np.abs(start_molecule.positions) < mic_cutoff).any()) or ((np.abs(end_molecule.positions) < mic_cutoff).any()):
        start_molecule.set_positions(_unwrap_around_first_atom(start_molecule.positions, start_atoms.cell))
        end_molecule.set_positions(_unwrap_around_first_atom(end_molecule.positions, start_atoms.cell))

    # rotation (Kabsch fit of the bond vectors) and COM translation
    axis, angle = get_axis_angle(get_molecule_vectors(start_molecule), get_molecule_vectors(end_molecule))
    if not np.all(np.isfinite(axis)):
        # A zero rotation can yield a 0/0 (NaN) axis; any axis is valid for a 0 degree turn.
        axis, angle = np.array([0.0, 0.0, 1.0]), 0.0
    translation = get_translation(start_molecule, end_molecule)

    return molecule_indices, start_molecule, axis, angle, translation


def _unwrap_around_first_atom(positions, cell):
    """Return positions moved to the periodic images closest to the first atom."""
    reference = positions[0]
    return reference + ase.geometry.geometry.find_mic(positions - reference, cell)[0]

def find_molecules(atoms, molecular_formulas):
    """Find the molecules in `atoms` whose chemical formula is in `molecular_formulas`.

    Atoms are grouped into connected components of the bond graph. The bond
    graph comes from an ASE NeighborList built with natural_cutoffs. Components
    whose formula matches one of the requested formulas are kept. A summary of
    what was found is printed.

    Args:
    atoms (ase.Atoms) - the structure to search.
    molecular_formulas (list(str) or str) - target formulas, written in any form
        accepted by ase.Atoms (e.g. "CNH6"). They are normalised with
        get_chemical_formula() before comparison. A single string is also accepted.

    Returns:
    list of numpy int arrays, one per matching molecule, each holding that
    molecule's atom indices in ascending order. Molecules are ordered by their
    lowest atom index."""

    # a bare string would otherwise be iterated character by character below
    if isinstance(molecular_formulas, str):
        molecular_formulas = [molecular_formulas]

    # create a matrix summarising the connectivity of the structure
    # this follows the example in the ASE documentation:
    # https://wiki.fysik.dtu.dk/ase/ase/neighborlist.html#ase.neighborlist.get_connectivity_matrix
    cutOff = ase.neighborlist.natural_cutoffs(atoms)
    neighborList = neighborlist.NeighborList(
        cutOff, self_interaction=False, bothways=True)
    neighborList.update(atoms)
    # matrix is a scipy sparse matrix as most atoms are not connected to one another.
    # Also allows nice analysis of the connected components
    matrix = neighborList.get_connectivity_matrix()
    # component_list[i] is the label of the connected component containing atom i
    _, component_list = sparse.csgraph.connected_components(matrix)

    # Group atom indices by component label in a single pass. Appending in index
    # order keeps each group ascending; sorting the groups orders the components
    # by their lowest atom index.
    components = {}
    for atom_index, label in enumerate(component_list):
        components.setdefault(label, []).append(atom_index)
    molIdxs_list = sorted(components.values())

    # standardise the input molecular formulas so that they are formatted ASE-style,
    # for comparison against the formulas of the components found above.
    target_formulas = {ase.Atoms(formula_string).get_chemical_formula()
                       for formula_string in molecular_formulas}

    # keep only the components that are the molecules we want to rotate
    found = []
    for molIdxs in molIdxs_list:
        formula = atoms[molIdxs].get_chemical_formula()
        if formula in target_formulas:
            found.append((formula, molIdxs))

    # report what has been found
    mol_counter = Counter(formula for formula, _ in found)
    for key in mol_counter:
        print("{} {} molecules have been found".format(mol_counter[key],key))
    print("The molecules have the following indices:")
    for _, molIdxs in found:
        print(molIdxs)

    return [np.array(molIdxs) for _, molIdxs in found]
        
def get_axis_angle(start_vectors, end_vectors):
    """Use the Kabsch algorithm to find the rotation taking start_vectors onto end_vectors.

    Args:
    start_vectors (array-like, shape (N, 3)) - vectors describing the start orientation.
    end_vectors (array-like, shape (N, 3)) - the corresponding vectors in the end
        orientation, in the same order as start_vectors.

    Returns:
    (unit_axis, angle) - unit rotation axis (numpy array of length 3) and angle in
    degrees. Together they describe the rotation (right-hand rule, as for scipy
    rotation vectors) that best maps start_vectors onto end_vectors. When the two
    sets are already aligned the angle is 0 and the axis is [0, 0, 1]. Any axis is
    valid for a zero turn, and this avoids a 0/0 NaN axis."""

    # align_vectors(a, b) returns the rotation that maps b onto a, i.e. end -> start
    # here; negating the angle below turns it into the start -> end rotation.
    transform = R.align_vectors(start_vectors, end_vectors)[0]
    rotvec = transform.as_rotvec()
    angle_rad = norm(rotvec)

    if angle_rad < 1e-12:
        return np.array([0.0, 0.0, 1.0]), 0.0

    return rotvec / angle_rad, -np.degrees(angle_rad)

def get_translation(start_molecule, end_molecule):
    """Return the centre-of-mass displacement of a molecule between two geometries.

    Args:
    start_molecule (ase.Atoms) - the molecule in its start position.
    end_molecule (ase.Atoms) - the same molecule in its end position.

    Returns:
    numpy array of length 3: end COM minus start COM, in the units of the atomic
    positions (Angstrom for ASE)."""
    origin = start_molecule.get_center_of_mass()
    return end_molecule.get_center_of_mass() - origin

def get_molecule_vectors(molecule_atoms):
    """Return the bond vectors of a molecule, in a deterministic order.

    Bonds are the pairs reported by an ASE NeighborList built with
    natural_cutoffs and bothways=False, so each bonded pair is listed once. For
    each pair (i, j) the minimum-image vector from atom i to atom j is returned.

    The start and end vectors of a molecule are paired up by their position in
    this list when the Kabsch rotation is fitted. The pairs are therefore sorted
    by (i, j), rather than taken in the sparse matrix's internal storage order.
    The pairing is only meaningful if both geometries have the same set of bonds.

    Args:
    molecule_atoms (ase.Atoms) - the atoms of a single molecule.

    Returns:
    list of numpy arrays of length 3, one per bond."""

    cutOff = ase.neighborlist.natural_cutoffs(molecule_atoms)
    neighborList = neighborlist.NeighborList(
        cutOff, self_interaction=False, bothways=False)
    neighborList.update(molecule_atoms)
    matrix = neighborList.get_connectivity_matrix()
    rows, columns = matrix.nonzero()
    # explicit ordering so start/end vectors correspond element by element
    pair_indices = sorted(zip(np.asarray(rows).tolist(), np.asarray(columns).tolist()))

    return [molecule_atoms.get_distance(i,j,mic=True,vector=True) for i,j in pair_indices]

def get_translation_indices(atoms, translation_species, molecular_indices_list):
    """Return the indices of all atoms that are to be translated only (with no
    rotational interpolation).

    Args:
    atoms (ase.Atoms) - the structure; each atom must expose .index and .symbol.
    translation_species (list(str) or str or None) - chemical symbols of the atoms
        to translate. A single symbol may be given as a plain string; it is matched
        exactly, so "Cs" does not select "C". If None, every atom that is not
        listed in molecular_indices_list is translated.
    molecular_indices_list (list(array-like of int)) - atom indices of each molecule.

    Returns:
    list of int, in the order the atoms are iterated. Also prints how many were found."""

    # a bare string would otherwise be substring-matched ("C" in "Cs" is True)
    if isinstance(translation_species, str):
        translation_species = [translation_species]

    if translation_species is None:
        # every atom not in a molecule; a set gives O(1) membership tests
        molecular_atoms = {int(index) for indices in molecular_indices_list for index in indices}
        translation_indices = [atom.index for atom in atoms
            if atom.index not in molecular_atoms]
    else:
        # every atom whose element is one of translation_species
        species = set(translation_species)
        translation_indices = [atom.index for atom in atoms
            if atom.symbol in species]
    print("{} translation-only atoms have been found"
        .format(len(translation_indices)))
    return translation_indices

def start_end_mic(start_atoms,end_atoms):
    """Undo cell-boundary wrapping between matching atoms of the start and end structures.

    An atom can cross a cell boundary during relaxation and be wrapped to the
    opposite side of the cell. This is checked for every atom and every lattice
    direction:
    - if the fractional coordinate of start exceeds that of end by more than 1/2,
      the start atom is shifted back by one lattice vector;
    - if end exceeds start by more than 1/2, the end atom is shifted back instead.
    Working in fractional coordinates makes this correct for non-orthogonal cells
    as well as orthorhombic ones. At most one lattice vector is removed per direction.

    The cell of start_atoms is used for both structures. Both objects are modified
    in place and returned. Positions are assigned directly, so constraints are not
    applied.

    Args:
    start_atoms (ase.Atoms) - start structure.
    end_atoms (ase.Atoms) - end structure with the same atom ordering.

    Returns:
    (start_atoms, end_atoms)."""

    cell = np.asarray(start_atoms.cell, dtype=float)
    start_positions = np.array(start_atoms.positions, dtype=float)
    end_positions = np.array(end_atoms.positions, dtype=float)

    # Fractional-coordinate difference (rows of cell are the lattice vectors).
    # pinv rather than inv: a cell with a zero lattice vector (no periodicity
    # along it) then gives no shift instead of raising.
    frac_diff = (start_positions - end_positions) @ np.linalg.pinv(cell)

    start_shift = (frac_diff > 0.5).astype(float)
    end_shift = (frac_diff < -0.5).astype(float)

    start_atoms.positions = start_positions - start_shift @ cell
    end_atoms.positions = end_positions - end_shift @ cell
    return start_atoms, end_atoms
    

## Example of how to generate interpolated structures

In [4]:
# Read the start and end structures, then write the interpolated structures
# (and, with reverse=True, the reverse-sense ones) as POSCAR_*.vasp files.
start_atoms = ase.io.read("./POSCAR_start.vasp")
end_atoms = ase.io.read("./POSCAR_end.vasp")
interpolate_structures(
    start_atoms,
    end_atoms,
    molecular_formulas=["CNH6"],            # CH3NH3 (methylammonium)
    number_intermediates=9,
    fformat="vasp",
    reverse=True,
    molecular_indices=None,                 # locate molecules from the formula
    translation_species=["Cs", "Pb", "I"],
    mic_cutoff=0.5,
)

14 CH6N molecules have been found
The molecules have the following indices:
[0, 14, 28, 42, 56, 70, 84, 98]
[1, 15, 29, 43, 57, 71, 85, 99]
[2, 16, 30, 44, 58, 72, 86, 100]
[3, 17, 31, 45, 59, 73, 87, 101]
[4, 18, 32, 46, 60, 74, 88, 102]
[5, 19, 33, 47, 61, 75, 89, 103]
[6, 20, 34, 48, 62, 76, 90, 104]
[7, 21, 35, 49, 63, 77, 91, 105]
[8, 22, 36, 50, 64, 78, 92, 106]
[9, 23, 37, 51, 65, 79, 93, 107]
[10, 24, 38, 52, 66, 80, 94, 108]
[11, 25, 39, 53, 67, 81, 95, 109]
[12, 26, 40, 54, 68, 82, 96, 110]
[13, 27, 41, 55, 69, 83, 97, 111]
67 translation-only atoms have been found


## Example of how to use ASE to check bond lengths

The code below is not required to generate the interpolated structures. It is an example of how to use ASE to monitor bond lengths and identify any suspect outliers in your generated structures.

In [5]:
def get_PbI_bonds(filepath):
    """Read the structure in `filepath` and report statistics of its unique Pb-I bonds.

    Prints the number of Pb-I bonds, their mean length and their standard deviation.

    Args:
    filepath (str) - path of any structure file that ase.io.read understands.

    Returns:
    list of Pb-I bond lengths for the (single) image read from the file."""
    structure = ase.io.read(filepath)
    bond_analysis = ase.geometry.analysis.Analysis(structure)
    pb_i_bonds = bond_analysis.get_bonds('Pb', 'I', unique=True)
    print("There are {} Pb-I bonds in a supercell of 16 12-atom primitive cells.".format(len(pb_i_bonds[0])))

    bond_lengths = bond_analysis.get_values(pb_i_bonds)
    mean_length = np.average(bond_lengths)
    print("The average Pb-I bond length is {}.".format(mean_length))
    spread = np.std(bond_lengths)
    print("The standard deviation is {}.".format(spread))

    first_image_lengths = bond_lengths[0]
    return first_image_lengths

def summarise_Kabsch(folder,displacement):
    """Print Pb-I bond statistics for the structure file `displacement` inside `folder`.

    Prints a header line and the statistics from get_PbI_bonds. When any Pb-I bonds
    were found, it also prints the longest and shortest bond.

    Args:
    folder (str) - directory containing the file, with or without a trailing separator.
    displacement (str) - file name of the structure, e.g. "POSCAR_000.vasp"."""
    import os

    print("~~~~"+displacement+"~~~~~")
    # os.path.join copes with a folder given without a trailing separator,
    # where plain concatenation would produce a wrong path
    bond_lengths = get_PbI_bonds(os.path.join(folder, displacement))
    # len() works for lists and arrays alike (array truthiness is ambiguous)
    if len(bond_lengths) > 0:
        print("The maximum bond is:", max(bond_lengths))
        print("The minimum bond is:", min(bond_lengths))
    
# Summarise every second generated structure, starting from POSCAR_000.vasp.
folder="./"
for poscar_name in ("POSCAR_000.vasp", "POSCAR_002.vasp", "POSCAR_004.vasp",
                    "POSCAR_006.vasp", "POSCAR_008.vasp", "POSCAR_010.vasp"):
    summarise_Kabsch(folder, poscar_name)

~~~~POSCAR_000.vasp~~~~~
There are 94 Pb-I bonds in a supercell of 16 12-atom primitive cells.
The average Pb-I bond length is 3.191361380622213.
The standard deviation is 0.060971403380010926.
The maximum bond is: 3.4118278587172934
The minimum bond is: 3.0724504463270645
~~~~POSCAR_002.vasp~~~~~
There are 95 Pb-I bonds in a supercell of 16 12-atom primitive cells.
The average Pb-I bond length is 3.1914473405310537.
The standard deviation is 0.06050571815005559.
The maximum bond is: 3.418126583958739
The minimum bond is: 3.0759907227815773
~~~~POSCAR_004.vasp~~~~~
There are 96 Pb-I bonds in a supercell of 16 12-atom primitive cells.
The average Pb-I bond length is 3.1919412052555316.
The standard deviation is 0.060983701887589385.
The maximum bond is: 3.4256673070080095
The minimum bond is: 3.0702041001002494
~~~~POSCAR_006.vasp~~~~~
There are 96 Pb-I bonds in a supercell of 16 12-atom primitive cells.
The average Pb-I bond length is 3.1907099415973335.
The standard deviation is 0.059