In [1]:
import torch
from torch import cdist
from torch.nn.functional import pdist
from tqdm.auto import tqdm
import numpy as np
from ase import Atoms
from ase.io import read, write
from ase.build import make_supercell
from ase.build.tools import sort as ase_sort
from ase.visualize import view
import warnings
from mendeleev import element
from mendeleev.fetch import fetch_table

In [2]:
def get_default_atoms(
            atom_type: str, 
            output_type: str ='number'
            ):
            """
            Get the default atoms based on the atom type and output type.

            Parameters:
            - atom_type (str): The type of atoms to retrieve. Accepts either "metal" or "ligand".
            - output_type (str): The type of output to retrieve. Accepts either "number" or "symbol". Defaults to "number".

            Returns:
            - atoms (list): The list of default atoms based on the atom type and output type.

            Raises:
            - ValueError: If the atom_type is not "metal" or "ligand".
            - ValueError: If the output_type is not "number" or "symbol".
            """
            if atom_type == 'metal':
                if output_type == 'number':
                    atoms = [
                        3, 4, 5, 11, 12, 13, 14, 19, 20, 21, 22, 23, 24, 
                        25, 26, 27, 28, 29, 30, 31, 32, 33, 37, 38, 39, 
                        40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 
                        52, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 
                        66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 
                        78, 79, 80, 81, 82, 83, 88
                    ]
                elif output_type == 'symbol':
                    atoms = [
                        'Li', 'Be', 'B', 'Na', 'Mg', 'Al', 'Si', 'K', 
                        'Ca', 'Sc', 'Ti', 'V', 'Cr', 'Mn', 'Fe', 'Co', 
                        'Ni', 'Cu', 'Zn', 'Ga', 'Ge', 'As', 'Rb', 'Sr', 
                        'Y', 'Zr', 'Nb', 'Mo', 'Tc', 'Ru', 'Rh', 'Pd', 
                        'Ag', 'Cd', 'In', 'Sn', 'Sb', 'Te', 'Cs', 'Ba', 
                        'La', 'Ce', 'Pr', 'Nd', 'Pm', 'Sm', 'Eu', 'Gd', 
                        'Tb', 'Dy', 'Ho', 'Er', 'Tm', 'Yb', 'Lu', 'Hf', 
                        'Ta', 'W', 'Re', 'Os', 'Ir', 'Pt', 'Au', 'Hg', 
                        'Tl', 'Pb', 'Bi', 'Ra'
                    ]
                else:
                    raise ValueError('FAILED: Invalid output_type, accepts only "number" or "symbol"')
            elif atom_type == 'ligand':
                if output_type == 'number':
                    atoms = [
                        1, 6, 7, 8, 9, 15, 16, 17, 34, 35, 53
                    ]
                elif output_type == 'symbol':
                    atoms = [
                        'H', 'C', 'N', 'O', 'F', 'P', 'S', 'Cl', 'Se', 'Br', 'I'
                    ]
                else:
                    raise ValueError('FAILED: Invalid output_type, accepts only "number" or "symbol"')
            else:
                raise ValueError('FAILED: Invalid atom_type, accepts only "metal" or "ligand"')
            return atoms

In [16]:
#structure_path = '../InOrgMatDataset/Dataset/CIFs/SimulatedTest/AntiFluorite_Fe2O.cif'
structure_path = '../InOrgMatDataset/Dataset/CIFs/COD_subset_cleaned/1008022.cif'
radii = [5, 10, 15]#, 20, 25]
metals = 'Default'
ligands = 'Default'
device = 'cuda'
_lightweight_mode = False
sort_atoms = False
disable_pbar = False

# Fetch node features and replace NaNs with 0.0
atomic_size_table = fetch_table('elements')[['atomic_number', 'atomic_radius']]
atomic_size_table['atomic_radius'] = atomic_size_table['atomic_radius'] / 100 # Convert pm to Å

In [17]:
# Handle metals and ligands
if metals == 'Default':
    metals = get_default_atoms('metal', output_type='number')
elif isinstance(metals, list):
    if isinstance(metals[0], str):
        try:
            from mendeleev import element
            metals = [element(elm).atomic_number for elm in metals]
        except ImportError:
            raise ImportError('FAILED: Please install mendeleev to use element symbols')
else:
    raise ValueError('FAILED: Please provide valid metals for generation of nanoparticles')

if ligands == 'Default':
    ligands = get_default_atoms('ligand', output_type='number')
elif isinstance(ligands, list):
    if isinstance(ligands[0], str):
        try:
            from mendeleev import element
            ligands = [element(elm).atomic_number for elm in ligands]
        except ImportError:
            raise ImportError('FAILED: Please install mendeleev to use element symbols')

# Read the input unit cell structure
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    unit_cell = read(structure_path)
cell_dims = np.array(unit_cell.cell.cellpar()[:3])
r_max = np.amax(radii)

# Create a supercell to encompass the entire range of nanoparticles and center it
size_check = np.array([False, False, False])
padding = np.array([0,0,0])
print(padding)
while not all(size_check):
    print(size_check)
    padding[~size_check] += 2 # Symmetric padding to ensure the particle does not exceed the supercell boundary
    print(padding)
    supercell_matrix = np.diag((np.ceil(r_max / cell_dims)) * 2 + padding)
    cell = make_supercell(prim=unit_cell, P=supercell_matrix)
    size_check = cell.get_positions().max(axis=0) >= (r_max * 2 + 5) # Check if the supercell is larger than diameter of largest particle + 5 Angstroms of padding
            
cell.center(about=0.)

atomic_numbers = cell.get_atomic_numbers()

# Find atomic radii
atomic_radii = torch.tensor(np.array([
    atomic_size_table.loc[atom-1].values
    for atom in atomic_numbers
    ], dtype='float'), device=device)

# Convert positions to torch and send to device
positions = torch.from_numpy(cell.get_positions()).to(dtype = torch.float32, device = device)

if _lightweight_mode:
    center_dists = torch.norm(positions, dim=1)
else:
    # Find all metals and center around the nearest metal
    metal_filter = torch.BoolTensor([a in metals for a in atomic_numbers]).to(device = device)
    ligand_filter = torch.BoolTensor([a in ligands for a in atomic_numbers]).to(device = device)
    center_dists = torch.norm(positions, dim=1)
    positions -= positions[metal_filter][torch.argmin(center_dists[metal_filter])]
    center_dists = torch.norm(positions, dim=1)
    ## Update the cell positions
    cell.positions = positions.cpu()

# Calculate distance matrix
cell_dists = cdist(positions, positions)

# Create mask of threshold for bonds
bond_threshold = torch.zeros_like(cell_dists, device=device)
for i, r1 in enumerate(atomic_radii[:,1]):
    bond_threshold[i,:] = (r1 + atomic_radii[:,1]) * 1.25
bond_threshold.fill_diagonal_(0.)

# Find edges
direction = torch.argwhere(cell_dists < bond_threshold).T

# Handle case with no edges
if len(direction[0]) == 0:
    min_dist = torch.amin(cell_dists[cell_dists > 0])
    direction = torch.argwhere(cell_dists < min_dist * 1.1).T

# Initialize nanoparticle lists and progress bar
nanoparticle_list = []
nanoparticle_sizes = []
pbar = tqdm(desc=f'Generating nanoparticles in range: [{np.amin(radii)},{np.amax(radii)}]', leave=False, total=len(radii), disable=disable_pbar)

# Generate nanoparticles for each radius
for r in sorted(radii, reverse=True):
    if _lightweight_mode:
        # Mask all atoms within radius
        incl_mask = (center_dists <= r) 
        
        # Modify objects based on mask
        cell = cell[incl_mask.cpu()]
        center_dists = center_dists[incl_mask]
        
    else:      
        # Mask all metal atoms within radius
        excl_mask = (center_dists > r) & metal_filter
        incl_mask = (center_dists <= r) & metal_filter

        excl_indices = torch.nonzero(excl_mask).flatten()
        incl_indices = torch.nonzero(incl_mask).flatten()

        included_edges = direction[:,(torch.isin(direction[0], incl_indices) + torch.isin(direction[1], incl_indices))]

        included_edges = included_edges[:,~(torch.isin(included_edges[0], excl_indices) + torch.isin(included_edges[1], excl_indices))]
        
        included_atoms = included_edges.unique()

        np_dists = center_dists[included_atoms]

        np_cell = cell[included_atoms.cpu()]

    
    # Determine NP size
    nanoparticle_size = torch.amax(np_dists) * 2

    # Sort the atoms
    if sort_atoms:
        sorted_cell = ase_sort(np_cell)
        if sorted_cell.get_atomic_numbers()[0] in ligands:
            sorted_cell = sorted_cell[::-1]

        # Append nanoparticle
        nanoparticle_list.append(sorted_cell)
    else:
        # Append nanoparticle
        nanoparticle_list.append(np_cell)

    # Append size
    nanoparticle_sizes.append(nanoparticle_size.item())

    pbar.update(1)
pbar.close()

[0 0 0]
[False False False]
[2 2 2]


Generating nanoparticles in range: [5,15]:   0%|          | 0/3 [00:00<?, ?it/s]

In [18]:
n_atoms_unit_cell = len(unit_cell)
for i, np_size in enumerate(nanoparticle_sizes):
    unit_vol = unit_cell.get_volume()
    np_vol = 4/3 * np.pi * (np_size/2)**3
    est_atoms = n_atoms_unit_cell * np_vol / unit_vol
    n_np_atoms = len(nanoparticle_list[i])
    n_atoms_error = (est_atoms - n_np_atoms) / n_np_atoms
    # Print results
    print(f'Radius: {(np_size//2)//5*5} Å,\tNP size: {np_size:.2f} Å,\tEstimated atoms: {est_atoms:.0f},\tActual atoms: {n_np_atoms},\tError: {n_atoms_error:.2%}')

Radius: 10.0 Å,	NP size: 27.82 Å,	Estimated atoms: 94,	Actual atoms: 97,	Error: -3.44%
Radius: 5.0 Å,	NP size: 17.60 Å,	Estimated atoms: 24,	Actual atoms: 27,	Error: -12.24%
Radius: 0.0 Å,	NP size: 0.00 Å,	Estimated atoms: 0,	Actual atoms: 1,	Error: -100.00%


In [21]:
view(nanoparticle_list[1], viewer='ngl')

HBox(children=(NGLWidget(), VBox(children=(Dropdown(description='Show', options=('All', 'Mo'), value='All'), D…

In [15]:
#write('test_np.xyz', nanoparticle_list[0])