In [5]:
from ase.io import read, write
from ase.build import make_supercell
import numpy as np
import os
import random


In [6]:
def generate_supercells(cif_files, input_dir, output_dir, N_MAX=4, verbose=False):
    # Convert Each CIF File (lots of gross text processing, beware!)
    dead_files = 0  # Files with no data
    for f in cif_files:
        f_name = f.split('.cif')[0]
        print("Name: ", f_name)
        # Alter some formatting to make the file load better into ASE, copy into output folder
        occupation = dict()  # For each site, what elements exist and how many
        site_tag_dict = dict()  # Dict with element tags and site name values
        # We replace the actual element with one of these so we can find the site later
        site_tags = ['H','He','Li','Be','B','C','N','O','F','Ne',  
                     'Na','Mg','Al','Si','P','S','Cl','Ar','K', 'Ca',
                     'Sc', 'Ti', 'V','Cr', 'Mn', 'Fe', 'Co', 'Ni',
                     'Cu', 'Zn', 'Ga', 'Ge', 'As', 'Se', 'Br', 'Kr',
                     'Rb', 'Sr', 'Y', 'Zr', 'Nb', 'Mo', 'Tc', 'Ru',
                     'Rh', 'Pd', 'Ag', 'Cd', 'In', 'Sn', 'Sb', 'Te',
                     'I', 'Xe','Cs', 'Ba','La', 'Ce', 'Pr', 'Nd', 'Pm',
                     'Sm', 'Eu', 'Gd', 'Tb', 'Dy', 'Ho', 'Er', 'Tm',
                     'Yb', 'Lu', 'Hf', 'Ta', 'W', 'Re', 'Os', 'Ir',
                     'Pt', 'Au', 'Hg', 'Tl', 'Pb', 'Bi', 'Po', 'At', 'Rn',
                     'Fr', 'Ra', 'Ac', 'Th', 'Pa', 'U', 'Np', 'Pu', 'Am',
                     'Cm', 'Bk', 'Cf', 'Es', 'Fm', 'Md', 'No', 'Lr',
                     'Rf', 'Db', 'Sg', 'Bh','Hs', 'Mt', 'Ds', 'Rg', 'Cn',
                     'Nh', 'Fl', 'Mc', 'Lv', 'Ts', 'Og']
        for k in range(2,7):  # Make it extra long
            site_tags += [k*i for i in site_tags]
        print(len(site_tags))

        with open(input_dir+f, "r") as file:
            contents = file.read().split('\n')
            fh  = open(output_dir+f, 'w') # Re-saves a better formatted file
            format_start = False
            write_start = False
            num_loops = 0  # number of times '_loop' is seen
            data_lookup = dict()
            for line in contents:
                # Record Occupation
                if num_loops>=2 and len(line)>1 and '#' not in line:  # Data stored in second _loop
                   if line[1] == '_':  # Add to data labels 
                        data_lookup[line] = len(data_lookup)  # Index for that data label
                   elif len(line.split()) > 4 :  # Data probably listed
                       spl = line.split()
                       try:
                           site_label = spl[data_lookup[' _atom_site_fract_x']] + " " + spl[data_lookup[' _atom_site_fract_y']] + " " + spl[data_lookup[' _atom_site_fract_z']] 
                       except: raise Exception("NO SITE DESCRIPTORS FOUND")
                        
                       if site_label not in occupation.keys():
                            occupation[site_label] = {}
                            tag = site_tags[len(site_tag_dict)]
                            site_tag_dict[tag] = site_label  # Remember this site is a given element tag
                       else:
                           tag = list(site_tag_dict.keys())[list(site_tag_dict.values()).index(site_label)]
                       #  Add element occupancy to
                       if ' _atom_site_occupancy' in data_lookup.keys():
                           occupation[site_label][spl[data_lookup[' _atom_site_type_symbol']]] = float(spl[data_lookup[' _atom_site_occupancy']])
                       else:  # No element occupancy data found
                           occupation[site_label][spl[data_lookup[' _atom_site_type_symbol']]] = 1
                        
                       spl[data_lookup[' _atom_site_type_symbol']] = tag  # Replace real element with tag
                       line = ' '+'\t'.join(spl)
               # Format
                if len(line)<1:  # End formatting
                    format_start = False
                if 'data' in line or '_cell' in line:
                    write_start = True
                # Copy Over text
                if format_start and "'" not in line:
                    spl = line.split()
                    fh.write(spl[0]+" '"+' '.join(spl[1:])+"' \n")
                elif write_start:
                    fh.write(line+'\n')
                if '_symmetry_equiv_pos_as_xyz' in line:
                    format_start = True
                if 'loop_' in line:
                    num_loops += 1
            fh.close()    
            file.close()
        
        # Throw error if file is empty
        if len(occupation) < 1:
            print("NO DATA FOUND FOR: ", f, "! Skipping...")
            dead_files += 1
            continue
          
        print("Occupation: ", occupation)
        # Load the original CIF file
        if verbose:  
            with(open(output_dir+f, "r") as file):
                contents = file.read().split('\n')
                for line in contents:
                    print(line)
            file.close()
            
        original_structure = read(output_dir+f) # Open formatted file copied into output dir
        
        # Define the supercell size adaptively
        # Find minimum element fraction for any given element
        x_min = 1
        for k1 in occupation.values():
            for k2 in k1.values():
                x_min = min(k2, x_min)
        n = np.ceil((5/x_min)**(1/3))  # Adaptive cell size to be <10% element representation error
        if x_min == 1:  # Use base cell if no mixed occupation lattice site
            supercell_dim = 1
        else:
            supercell_dim = min(n, N_MAX)  # Use a big enough supercell to limit representative error, max size N_MAX
        print("Supercell dim: ", supercell_dim)
        supercell_size = supercell_dim*np.eye(3) 
                
        # Create the supercell
        superlattice = make_supercell(original_structure, supercell_size)
        
        # Keep track of where each atom is in the superlattice
        site_list = list() # which site is at each atomic location in superlattice
        site_count = dict() # how many site of each exist
        for k in occupation.keys():
            site_count[k] = 0
        for atom in range(len(superlattice)):
            tag = superlattice[atom].symbol  # Name of tag element
            site = site_tag_dict[tag]  # Name of site this represents
            site_list.append(site)
            site_count[site] += 1

        if len(site_list) != len(superlattice):
            print(site_list)
            print(superlattice)
            raise Exception("Not every atom in super-lattice accounted for in occupation dictionary!")
        
        # Create list of atoms to draw from for each site
        site_sampler = dict()
        for site in occupation.keys():
            site_sampler[site] = list()
            # Make representative list of atoms to draw from
            for atom in occupation[site].keys():
                n_atoms = round( occupation[site][atom]*site_count[site] )
                site_sampler[site] = site_sampler[site] + [atom]*n_atoms
            # List too long, pop random element
            while len(site_sampler[site]) > site_count[site]: 
                random.shuffle(site_sampler[site])
                site_sampler[site].pop()
            # List too short, fill with vacancies
            while len(site_sampler[site]) < site_count[site]:
                site_sampler[site].append('')
    
        # Replace atoms based on occupation fraction
        vacancies = list()
        # Shuffle each site in the sampler dictionary
        for site in occupation.keys():
            random.shuffle(site_sampler[site])
        # For every atom in the lattice, draw an element to replace the atom with
        for atom in range(len(superlattice)):
            atom_site = site_list[atom]
            element_draw = site_sampler[atom_site][0]
            del site_sampler[atom_site][0]
            if element_draw == "":
                vacancies.append(atom)
            else:
                superlattice[atom].symbol = element_draw
        # Create oxygen vacancies in lattice by deletion
        for i in range(len(vacancies)):
            del superlattice[vacancies[i]-i]
   
        # Save the superlattice structure in a new CIF file
        write(output_dir+f_name+'_super.cif', superlattice)
    print("Number of dead files: ", dead_files)
    print("Finished: processed "+ str(len(cif_files)) +" files")

In [8]:

## Script Parameters
input_dir = "data_linear/"  # CIF files source directory
output_dir = "supercells_data/"  # CIF supercell output file directory

# Load all CIF files in
file_type = ".cif"
files = os.listdir(input_dir)
cif_files = [file for file in files if file.endswith(file_type)]

# Supercell max dimension
N_MAX = 4  # maximum size allowed for supercell (NxNxN unit cells)

# Make Output Directory if needed
if not os.path.isdir(output_dir):
    os.mkdir(output_dir)
# Atom('O', [27.765269999999994, 34.01027, 9.945936], index=1231)
generate_supercells(cif_files, input_dir, output_dir, N_MAX=N_MAX)

Name:  10004
3776
Occupation:  {'0.1759 0.0471 0.2942': {'O': 1.0}, '0.4243 0.2500 0.0146': {'Sm': 1.0}, '0.5276 0.2500 0.5966': {'O': 1.0}, '0.0000 0.0000 0.0000': {'Mn': 1.0}}
Supercell dim:  1
Name:  100055
3776
NO DATA FOUND FOR:  100055.cif ! Skipping...
Name:  100096
3776
Occupation:  {'0.5000 0.5000 0.2384': {'Sr': 1.0}, '0.0000 0.0000 0.2703': {'O': 0.94}, '0.0000 0.5000 0.0000': {'O': 0.85}, '0.0000 0.5000 0.5000': {'O': 1.0}, '0.0000 0.0000 0.5000': {'Co': 0.93, 'Mo': 0.07}, '0.0000 0.0000 0.0000': {'Co': 0.97, 'Mo': 0.03}}
Supercell dim:  4
Name:  100147
3776
Occupation:  {'0.2433 0.0000 0.0000': {'O': 0.77}, '0.2500 0.2500 0.2500': {'Ba': 1.0}, '0.5000 0.5000 0.5000': {'Co': 1.0}, '0.0000 0.0000 0.0000': {'Co': 0.56, 'Bi': 0.22, 'Sc': 0.22}}
Supercell dim:  3.0
Name:  100216
3776
Occupation:  {'0.5000 0.0000 0.0000': {'O': 0.917}, '0.5000 0.5000 0.5000': {'Sr': 1.0}, '0.0000 0.0000 0.0000': {'Co': 0.5, 'Ti': 0.5}}
Supercell dim:  3.0
Name:  100406
3776
Occupation:  {'0.2107

In [34]:
"""
# Math to convert cartesian to fractional coord
a_cell, b_cell, c_cell, alpha, beta, gamma = superlattice.get_cell_lengths_and_angles()
a_cell /= supercell_dim
b_cell /= supercell_dim
c_cell /= supercell_dim
alpha = np.deg2rad(alpha)
beta = np.deg2rad(beta)
gamma = np.deg2rad(gamma)
n_cell = (np.cos(alpha)-np.cos(gamma)*np.cos(beta))/np.sin(gamma)
M = np.array([[a_cell, 0, 0], [b_cell*np.cos(gamma), b_cell*np.sin(gamma), 0], [c_cell*np.cos(beta), c_cell*n_cell, c_cell*np.sqrt(np.sin(beta)**2-n_cell**2)]])
M_inv = np.linalg.inv(M)

frac_atom_pos = np.matmul(np.array(atom_pos), M_inv)
print("Frac pos: ", frac_atom_pos)
shifted_pos = np.mod(frac_atom_pos, np.ones((1,3))/supercell_dim)*supercell_dim  # shifted across supercell to original cell
"""

'\n# Math to convert cartesian to fractional coord\na_cell, b_cell, c_cell, alpha, beta, gamma = superlattice.get_cell_lengths_and_angles()\na_cell /= supercell_dim\nb_cell /= supercell_dim\nc_cell /= supercell_dim\nalpha = np.deg2rad(alpha)\nbeta = np.deg2rad(beta)\ngamma = np.deg2rad(gamma)\nn_cell = (np.cos(alpha)-np.cos(gamma)*np.cos(beta))/np.sin(gamma)\nM = np.array([[a_cell, 0, 0], [b_cell*np.cos(gamma), b_cell*np.sin(gamma), 0], [c_cell*np.cos(beta), c_cell*n_cell, c_cell*np.sqrt(np.sin(beta)**2-n_cell**2)]])\nM_inv = np.linalg.inv(M)\n\nfrac_atom_pos = np.matmul(np.array(atom_pos), M_inv)\nprint("Frac pos: ", frac_atom_pos)\nshifted_pos = np.mod(frac_atom_pos, np.ones((1,3))/supercell_dim)*supercell_dim  # shifted across supercell to original cell\n'