In [1]:
from pathlib import Path
import numpy as np
import pandas as pd
import timeit
from pymatgen.io.cif import CifParser
from matminer.featurizers.site import GaussianSymmFunc, SiteElementalProperty,AGNIFingerprints
import os

BOND_MAX_DIST = 3.0  # Max distance (in angstroms) between two sites for them to be treated as bonded; used as the get_neighbor_list cutoff


In [2]:
def _parse_cifs(cifs: list) -> pd.DataFrame:
    """Load the first structure of each CIF into a frame indexed by file name.

    Columns are ``structure``, ``structure_name`` (file name) and ``structure_path``
    (parent directory). Because the file name is the index, two CIFs with the same
    name in different directories overwrite each other (the last one wins).
    """
    records = {}
    for cif in cifs:
        cif_path = Path(cif)
        records[cif_path.name] = {
            "structure": CifParser(cif).get_structures()[0],
            "structure_name": cif_path.name,
            "structure_path": str(cif_path.parent),
        }
    return pd.DataFrame.from_dict(records).T


def _site_name(structure_name, atom_idx) -> str:
    """Key under which a site's features are stored: ``<structure_name>_<atom index>``."""
    return "%s_%i" % (structure_name, atom_idx)


def _compute_site_features(data: pd.DataFrame, featurizers: list, verbose: bool = False) -> dict:
    """Run every ``(label, featurizer)`` pair on every site of every structure in ``data``.

    Returns a dict keyed by ``_site_name`` whose values hold ``structure_name``,
    ``structure_path`` and then each featurizer's columns, in featurizer order.
    """
    if verbose: print("Assembling site property dictionary")
    site_features = {}
    for index, row in data.iterrows():
        for atom_idx in range(row["structure"].num_sites):
            site_features[_site_name(index, atom_idx)] = {
                "structure_name": row["structure_name"],
                "structure_path": row["structure_path"],
            }

    for label, featurizer in featurizers:
        if verbose: print(label)
        colnames = featurizer.feature_labels()
        for index, row in data.iterrows():
            structure = row["structure"]
            if verbose: print(index)
            for atom_idx in range(structure.num_sites):
                feat = featurizer.featurize(structure, idx=atom_idx)
                site_features[_site_name(index, atom_idx)].update(zip(colnames, feat))
    return site_features


def _find_bonds(structure, max_dist: float) -> list:
    """List the bonds of ``structure`` as ``(i, j, coordination_number_of_i, length)`` tuples.

    A bond is any neighbor pair within ``max_dist`` angstroms from ``get_neighbor_list``.
    Only pairs with ``i < j`` are kept so each pair (per periodic image) is counted once.
    This also drops bonds between a site and its own periodic image. The coordination
    number is the total neighbor count of site ``i`` within ``max_dist``.
    """
    centers, points, _, distances = structure.get_neighbor_list(max_dist)
    centers = np.asarray(centers, dtype=int)
    # Count neighbors per site in one pass instead of list.count() per bond (was O(n_bonds^2)).
    coord_nums = np.bincount(centers, minlength=structure.num_sites)
    return [
        (centers[k], points[k], int(coord_nums[centers[k]]), distances[k])
        for k in range(len(centers))
        if centers[k] < points[k]
    ]


def _bond_features_for_structure(index, bonds: list, site_features: dict, delta_properties) -> dict:
    """Build one feature row per bond of a single structure, keyed by ``<index>_Atom<i>_Bond<n>``.

    Site features of the two atoms are written heavier atom first (``_atom1``) and the
    other atom second (``_atom2``); on equal atomic weight, site ``j`` becomes ``_atom1``.
    For properties in ``delta_properties`` an extra ``_diff`` column (atom1 - atom2) is added.
    """
    bond_len_sum = sum(bond_len for _, _, _, bond_len in bonds)
    rows = {}
    for bond_idx, (i, j, coord_num, bond_len) in enumerate(bonds):
        site_i = site_features[_site_name(index, i)]
        site_j = site_features[_site_name(index, j)]
        if site_i["site AtomicWeight"] > site_j["site AtomicWeight"]:
            atom1, atom2 = site_i, site_j
        else:
            atom1, atom2 = site_j, site_i

        row = {"structure_name": site_i["structure_name"], "structure_path": site_i["structure_path"]}
        for k in site_i:
            if k in ("structure_path", "structure_name"):
                continue
            if k in delta_properties:
                row[k + "_diff"] = atom1[k] - atom2[k]
            row[k + "_atom1"] = atom1[k]
            row[k + "_atom2"] = atom2[k]
        row["coordination_number"] = coord_num
        row["bond_length"] = bond_len
        # Despite the column name, this is the bond's share of the structure's summed bond length.
        row["volume_fraction"] = bond_len / bond_len_sum
        rows["%s_Atom%i_Bond%i" % (index, i, bond_idx)] = row
    return rows


def _save_features(df_features: pd.DataFrame, saveto: str) -> None:
    """Write ``df_features`` to ``saveto``, appending without a header if the file exists.

    An empty frame is skipped. Writing it would create a header-only file that later
    header-less appends could not line up with.
    """
    if df_features.empty:
        return
    if os.path.isfile(saveto):  # Append
        df_features.to_csv(saveto, mode='a', header=False)
    else:  # New file
        df_features.to_csv(saveto)


def featurize_dataset(cifs: list, verbose=False, saveto: str = "features.csv",
                      bond_max_dist: float = None) -> pd.DataFrame:
    """Featurize crystal structures bond-by-bond using elemental, geometric and chemical descriptors of local environments.

    The first structure of each CIF is parsed. Per-site features are computed with
    SiteElementalProperty, AGNIFingerprints and GaussianSymmFunc. Every pair of sites
    within ``bond_max_dist`` becomes one output row, combining both sites' features
    (heavier atom first) with the bond's coordination number, length and length fraction.

    :param cifs: list of paths to crystal structures in CIF format
    :param verbose: print each processing step and each structure name
    :param saveto: CSV file to write the features to; rows are appended (without a header) if it already exists
    :param bond_max_dist: bonding cutoff in angstroms; ``None`` uses the module-level BOND_MAX_DIST
    :return: DataFrame with one row per bond, indexed by ``<structure>_Atom<i>_Bond<n>``
    """
    if bond_max_dist is None:
        bond_max_dist = BOND_MAX_DIST

    ## Process Input Files
    if verbose: print("Parsing CIFs")
    data = _parse_cifs(cifs)

    ### SITE PROPERTIES ###
    # To add a feature, append another (label, featurizer) pair.
    property_list = ("Number", "AtomicWeight", "Row", "Column", "Electronegativity", "CovalentRadius")
    featurizers = [
        ("site elemental properties", SiteElementalProperty(properties=property_list)),
        ("AGNI", AGNIFingerprints(cutoff=5, directions=[None])),
        ("GSF", GaussianSymmFunc(cutoff=5)),
    ]
    site_features = _compute_site_features(data, featurizers, verbose=verbose)

    ### BOND PAIRS AND BOND PROPERTIES ###
    if verbose: print("Generating bond library")
    bonds = {}
    for index, row in data.iterrows():
        if verbose: print(index)
        bonds[index] = _find_bonds(row["structure"], bond_max_dist)

    # Build Dataframe by bonds
    if verbose: print("Copying over data to final dataframe")
    delta_properties = ["site Electronegativity", "site AtomicWeight"]  # For these properties, take the difference as a feature
    bond_features = {}
    for index in data.index:
        if verbose: print(index)
        bond_features.update(_bond_features_for_structure(index, bonds[index], site_features, delta_properties))

    ### SAVE FILE
    df_features = pd.DataFrame.from_dict(bond_features).T
    _save_features(df_features, saveto)
    return df_features


## Test featurization on a single file

In [3]:
# featurize_dataset appends to an existing CSV without a header, so remove the previous
# test output first; otherwise every re-run adds duplicate rows to test_feat.csv.
if os.path.isfile('test_feat.csv'):
    os.remove('test_feat.csv')
featurize_dataset(['supercells_data/15284_super.cif'], saveto='test_feat.csv')

Unnamed: 0,structure_name,structure_path,site Number_atom1,site Number_atom2,site AtomicWeight_diff,site AtomicWeight_atom1,site AtomicWeight_atom2,site Row_atom1,site Row_atom2,site Column_atom1,...,G4_0.005_1.0_1.0_atom2,G4_0.005_1.0_-1.0_atom1,G4_0.005_1.0_-1.0_atom2,G4_0.005_4.0_1.0_atom1,G4_0.005_4.0_1.0_atom2,G4_0.005_4.0_-1.0_atom1,G4_0.005_4.0_-1.0_atom2,coordination_number,bond_length,volume_fraction
15284_super.cif_Atom0_Bond0,15284_super.cif,supercells_data,57.0,8.0,122.90607,138.90547,15.9994,6.0,2.0,3.0,...,7.58149,1.915664,2.622812,4.123305,3.895691,0.084964,0.327758,12,2.757009,0.035407
15284_super.cif_Atom0_Bond1,15284_super.cif,supercells_data,57.0,8.0,122.90607,138.90547,15.9994,6.0,2.0,3.0,...,7.58149,1.915664,2.622812,4.123305,3.895691,0.084964,0.327758,12,2.757009,0.035407
15284_super.cif_Atom0_Bond2,15284_super.cif,supercells_data,57.0,8.0,122.90607,138.90547,15.9994,6.0,2.0,3.0,...,7.58149,1.915664,2.622812,4.123305,3.895691,0.084964,0.327758,12,2.757009,0.035407
15284_super.cif_Atom0_Bond3,15284_super.cif,supercells_data,57.0,8.0,122.90607,138.90547,15.9994,6.0,2.0,3.0,...,7.58149,1.915664,2.622812,4.123305,3.895691,0.084964,0.327758,12,2.757009,0.035407
15284_super.cif_Atom0_Bond4,15284_super.cif,supercells_data,57.0,8.0,122.90607,138.90547,15.9994,6.0,2.0,3.0,...,7.58149,1.915664,2.622812,4.123305,3.895691,0.084964,0.327758,12,2.757009,0.035407
15284_super.cif_Atom0_Bond5,15284_super.cif,supercells_data,57.0,8.0,122.90607,138.90547,15.9994,6.0,2.0,3.0,...,7.58149,1.915664,2.622812,4.123305,3.895691,0.084964,0.327758,12,2.757009,0.035407
15284_super.cif_Atom0_Bond6,15284_super.cif,supercells_data,57.0,8.0,122.90607,138.90547,15.9994,6.0,2.0,3.0,...,7.58149,1.915664,2.622812,4.123305,3.895691,0.084964,0.327758,12,2.757009,0.035407
15284_super.cif_Atom0_Bond7,15284_super.cif,supercells_data,57.0,8.0,122.90607,138.90547,15.9994,6.0,2.0,3.0,...,7.58149,1.915664,2.622812,4.123305,3.895691,0.084964,0.327758,12,2.757009,0.035407
15284_super.cif_Atom0_Bond8,15284_super.cif,supercells_data,57.0,8.0,122.90607,138.90547,15.9994,6.0,2.0,3.0,...,7.58149,1.915664,2.622812,4.123305,3.895691,0.084964,0.327758,12,2.757009,0.035407
15284_super.cif_Atom0_Bond9,15284_super.cif,supercells_data,57.0,8.0,122.90607,138.90547,15.9994,6.0,2.0,3.0,...,7.58149,1.915664,2.622812,4.123305,3.895691,0.084964,0.327758,12,2.757009,0.035407


## Featurize all data in the target folder in batches to limit memory use
Note: `featurize_dataset` appends to the output CSV if it already exists, so make sure no `features.csv` from a previous run is left in the output folder before starting.

In [None]:
# Batching files to reduce memory use
BATCH_SIZE = 5

# Load all CIF files in directory
file_type = "_super.cif"  # Use files with this ending in input_dir
input_dir = "supercells_data/"  # Input data directory
output_dir = "features/"  # Output directory
filename = "features.csv"  # Output filename for features
output_path = os.path.join(output_dir, filename)

# featurize_dataset appends to an existing CSV, so always start from a clean output file.
# (Previously this checked isdir() on the CSV path and removed ./features.csv instead,
# so stale results in output_dir were never cleared.)
os.makedirs(output_dir, exist_ok=True)
if os.path.isfile(output_path):
    os.remove(output_path)

# Sorted so batch membership and CSV row order are reproducible across runs/platforms.
cif_files = [os.path.join(input_dir, f) for f in sorted(os.listdir(input_dir)) if f.endswith(file_type)]

# Featurize all structures
n_batches = int(np.ceil(len(cif_files) / BATCH_SIZE))
print("{} Batches Total: ".format(n_batches))
for b in range(n_batches):
    print("Starting batch ", b)
    # Files belonging to this batch; the last batch may be shorter than BATCH_SIZE.
    batch_files = cif_files[b * BATCH_SIZE:(b + 1) * BATCH_SIZE]
    start = timeit.default_timer()
    # Results are persisted to output_path; the returned frame is not kept, to free memory.
    featurize_dataset(batch_files, saveto=output_path)
    print("Time elapsed: ", timeit.default_timer() - start)

print("Files processed: ", len(cif_files))

109 Batches Total: 
Starting batch  0




Time elapsed:  20.856205199845135
Starting batch  1




Time elapsed:  47.80316250026226
Starting batch  2




Time elapsed:  59.030629500281066
Starting batch  3




Time elapsed:  67.9972182996571
Starting batch  4




Time elapsed:  56.73128689965233
Starting batch  5




Time elapsed:  6.015937899705023
Starting batch  6




Time elapsed:  145.4959821999073
Starting batch  7




Time elapsed:  37.22448859969154
Starting batch  8




Time elapsed:  29.74734769994393
Starting batch  9




Time elapsed:  1.9083008002489805
Starting batch  10




Time elapsed:  185.4771854998544
Starting batch  11




Time elapsed:  56.13609299995005
Starting batch  12




Time elapsed:  8.526850999798626
Starting batch  13
Time elapsed:  18.775653299875557
Starting batch  14




Time elapsed:  23.64539139997214
Starting batch  15




Time elapsed:  6.346043900121003
Starting batch  16




Time elapsed:  24.374488700181246
Starting batch  17




Time elapsed:  33.59329480025917
Starting batch  18




Time elapsed:  5.493805999867618
Starting batch  19
Time elapsed:  66.1142989997752
Starting batch  20




Time elapsed:  88.58087999979034
Starting batch  21
Time elapsed:  146.08759950008243
Starting batch  22
Time elapsed:  5.398895199876279
Starting batch  23




Time elapsed:  133.94540790002793
Starting batch  24




Time elapsed:  44.863856999669224
Starting batch  25




Time elapsed:  0.4100621002726257
Starting batch  26




Time elapsed:  1.599288999568671
Starting batch  27
Time elapsed:  7.2087146998383105
Starting batch  28




Time elapsed:  69.9626413998194
Starting batch  29
Time elapsed:  0.3762924997135997
Starting batch  30




Time elapsed:  5.541683999821544
Starting batch  31




Time elapsed:  1.940760999917984
Starting batch  32




Time elapsed:  3.7302537001669407
Starting batch  33
Time elapsed:  0.4993479000404477
Starting batch  34




Time elapsed:  2.3459218996576965
Starting batch  35




Time elapsed:  0.5682883998379111
Starting batch  36




Time elapsed:  93.98562220018357
Starting batch  37




Time elapsed:  7.509525400120765
Starting batch  38




Time elapsed:  5.192387399729341
Starting batch  39




Time elapsed:  6.293244400061667
Starting batch  40




Time elapsed:  3.3400213001295924
Starting batch  41




Time elapsed:  9.860672499984503
Starting batch  42




Time elapsed:  3.5662551997229457
Starting batch  43




Time elapsed:  14.059765900019556
Starting batch  44




Time elapsed:  23.603835800196975
Starting batch  45




Time elapsed:  3.2147532999515533
Starting batch  46




Time elapsed:  10.556929100304842
Starting batch  47




Time elapsed:  2.508479599840939
Starting batch  48
Time elapsed:  0.63538320036605
Starting batch  49




Time elapsed:  10.749301200266927
Starting batch  50




Time elapsed:  6.564462899696082
Starting batch  51




Time elapsed:  16.131281400099397
Starting batch  52




Time elapsed:  148.86897530034184
Starting batch  53




Time elapsed:  50.85999820008874
Starting batch  54




Time elapsed:  22.025266599841416
Starting batch  55




Time elapsed:  66.78832860011607
Starting batch  56




Time elapsed:  8.56251489976421
Starting batch  57




Time elapsed:  16.320733800064772
Starting batch  58




Time elapsed:  35.33846260001883
Starting batch  59
Time elapsed:  19.85931109962985
Starting batch  60




Time elapsed:  37.710362900048494
Starting batch  61




Time elapsed:  11.647609300445765
Starting batch  62




Time elapsed:  30.26305399974808
Starting batch  63




Time elapsed:  14.046133500058204
Starting batch  64




Time elapsed:  5729.689133600332
Starting batch  65




In [None]:
   ## 1. Bag of Bonds
# NOTE(review): inactive scratch cell. Both snippets below are wrapped in bare string
# literals, so running this cell has no effect. They are not runnable as written: they
# use `verbos` (the parameter is named `verbose`), rely on `data`, `site_features` and
# `property_list`, which are locals of featurize_dataset, and call BagofBonds,
# BondFractions and LocalPropertyStatsNew, which are not imported in the cells above.
# Port them into featurize_dataset's featurizer list or delete this cell.
"""
print("bag of bonds")
BB = BagofBonds()
for index, row in data.iterrows():
    structure = row["structure"]
    if verbos:
        print(index)
    BB.fit([structure])
    feat = BB.bag(structure)
    print(feat)
    site = list(feat.keys())
    print(site[0])
    print(structure[site])

print("bond fraction")
BF = BondFractions()
for index, row in data.iterrows():
    structure = row["structure"]
    if verbos:
        print(index)
    feat = BF.fit_transform([structure])
    #print(feat)
    #print(BF.feature_labels())
"""
## 5. site difference stats 
"""
print("LPD")
LPD = LocalPropertyStatsNew(properties=property_list)
colnames = LPD._generate_column_labels(multiindex=False, return_errors=False)
for index, row in data.iterrows():
    structure = row["structure"]
    if verbos:
        print(index)
    for atomidx in range(structure.num_sites):
        feat = LPD.featurize(structure, idx=atomidx)
        site_name = "%s_%i" % (index, atomidx)
        site_features[site_name].update(dict(zip(colnames, feat)))
"""
    