In [1]:
import itertools
import collections
import random
import pickle

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import tqdm
import pymatgen.core
import pymatgen.io.ase
import pymatgen.analysis.dimensionality
from pymatgen.analysis.local_env import JmolNN


# Single seed shared by every random number generator used in this notebook.
RANDOM_SEED = 1234

# Seed both the stdlib and the legacy NumPy global generators.
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

# Turn off pandas' chained-assignment (SettingWithCopy) warning globally.
pd.set_option("mode.chained_assignment", None)

# Register `progress_apply` on pandas objects so long applies show a progress bar.
tqdm.tqdm.pandas()

In [2]:
# Load the raw 2D-materials dataset.
# NOTE: unpickling can execute arbitrary code; only read this file from a trusted source.
raw_df = pd.read_pickle("../DigitalEcosystem/raw_data/2d_mat_dataset_raw.pkl")

# Keep only the regression target and the ASE atoms objects the features are derived from.
cols_to_keep = ["bandgap (eV)", "atoms_object (unitless)"]

# Take an explicit copy: new columns are assigned to `df` later, and assigning to a
# bare column slice of `raw_df` is the pattern behind pandas' SettingWithCopyWarning.
df = raw_df[cols_to_keep].copy()

In [3]:
# Convert every ASE atoms object into a pymatgen structure, stored as a new column.
df["pymatgen_structure (unitless)"] = (
    df["atoms_object (unitless)"]
    .apply(pymatgen.io.ase.AseAtomsAdaptor.get_structure)
)

In [4]:
# Dataset-wide tallies of every element symbol, bond type and angle type encountered.
symbols_cols = collections.Counter()
bond_cols = collections.Counter()
angle_cols = collections.Counter()

neighbor_finder = JmolNN()

with tqdm.tqdm(total=len(df)) as progress:
    for structure in df["pymatgen_structure (unitless)"]:
        # symbol_set lists each element once, so this counts structures per element.
        symbols_cols.update(structure.symbol_set)

        for site_index, center in enumerate(structure.sites):
            shell = neighbor_finder.get_nn_shell_info(structure, site_index, 1)
            neighbors = [entry['site'] for entry in shell]

            # Bonds: each (site, neighbor) pair contributes half a bond.
            for neighbor in neighbors:
                first, second = sorted([center.specie, neighbor.specie])
                bond_cols[f"{first}-{second}"] += 0.5

            # Angles: every pair of neighbors around the central site, ends sorted.
            for pair in itertools.combinations(neighbors, 2):
                left, right = sorted(pair)
                angle_cols[f"{left.specie}-{center.specie}-{right.specie}"] += 1
        progress.update(1)

100%|██████████| 6351/6351 [15:06:58<00:00,  8.57s/it]       


In [5]:
# Persist the three dataset-wide counters so the slow counting pass need not be repeated.
counters_by_filename = {
    "symbols.pkl": symbols_cols,
    "bonds.pkl": bond_cols,
    "angles.pkl": angle_cols,
}
for filename, counter in counters_by_filename.items():
    with open(filename, "wb") as handle:
        pickle.dump(counter, handle)

In [6]:
def testmap(foo):
    foo[random.choice(["A", "B"])] = 1
    return foo
# Try the row-wise mapping on a copy of the first five rows only.
test = df.iloc[:5].copy()
test.apply(testmap, axis=1)

Unnamed: 0,A,B,atoms_object (unitless),bandgap (eV),pymatgen_structure (unitless)
0,,1.0,"(Atom('Ir', [0.0, 0.0, 0.0], index=0), Atom('F...",0.0,"[[0. 0. 0.] Ir, [ 1.75964097 1.01592735 22.08..."
1,1.0,,"(Atom('Ba', [2.476683476681, 1.429910903420999...",0.0,"[[ 2.47668348 1.4299109 19.09028155] Ba, [-2..."
2,1.0,,"(Atom('Tl', [2.63896615613751, 10.292177253854...",0.9814,"[[ 2.63896616 10.29217725 11.08346956] Tl, [3...."
3,1.0,,"(Atom('Mo', [1.5833675, 2.687975714894, 2.6388...",0.0,"[[1.5833675 2.68797571 2.63881737] Mo, [ 0.00..."
4,1.0,,"(Atom('Ir', [0.0, 0.0, 0.0], index=0), Atom('O...",0.0,"[[0. 0. 0.] Ir, [-1.24446100e-06 1.82188091e+..."


In [7]:
# Distinct symbol / bond / angle names seen anywhere in the dataset.
all_symbols = set(symbols_cols)
all_bonds = set(bond_cols)
all_angles = set(angle_cols)

def featurize(data):
    """Add element, bond and bond-angle count features to one dataset row.

    Parameters
    ----------
    data : mapping-like row (e.g. a ``pandas.Series`` from ``DataFrame.apply(axis=1)``)
        Must contain the key ``"pymatgen_structure (unitless)"``. The object stored
        there must expose ``sites``, and each site must expose ``specie``.

    Returns
    -------
    The same ``data`` object, modified in place, with these keys added:

    * ``"<El> (atoms)"``: number of sites occupied by element ``El``, plus
      ``"Total Atoms (atoms)"``.
    * ``"<A>-<B> (bonds)"``: bond counts with the two species sorted, plus
      ``"Total Bonds (bonds)"``.
    * ``"<A>-<C>-<B> (angles)"``: angle counts centred on ``C`` with the two end
      sites sorted, plus ``"Total Angles (angles)"``.

    Notes
    -----
    Relies on the module-level ``neighbor_finder`` for first-shell neighbours.
    """
    symbol_units = "atoms"
    bond_units = "bonds"
    angle_units = "angles"
    struct = data["pymatgen_structure (unitless)"]

    # Count one entry per site. `struct.symbol_set` holds each element only once,
    # so counting it would report 1 for every element and make "Total Atoms" the
    # number of distinct elements instead of the number of atoms.
    present_symbols = collections.Counter(site.specie.symbol for site in struct.sites)
    present_bonds = collections.Counter()
    present_angles = collections.Counter()

    # Record and Count Symbols
    for symbol, count in present_symbols.items():
        data[f"{symbol} ({symbol_units})"] = count
    data[f"Total Atoms ({symbol_units})"] = sum(present_symbols.values())

    for index, site in enumerate(struct.sites):
        connected = [i['site'] for i in neighbor_finder.get_nn_shell_info(struct, index, 1)]

        # Count Bonds
        # Each (site, neighbour) pair adds 0.5, so a bond reported from both of its
        # ends counts once -- assumes the neighbour search is symmetric (TODO confirm).
        for vertex in connected:
            start, end = sorted([site.specie, vertex.specie])
            bond = f"{start}-{end}"
            present_bonds[bond] += 0.5

        # Count Angles
        # Every unordered pair of neighbours forms one angle centred on `site`;
        # sorting the pair gives each angle type a single canonical name.
        for angle_start, angle_end in map(sorted, itertools.combinations(connected, 2)):
            angle = f"{angle_start.specie}-{site.specie}-{angle_end.specie}"
            present_angles[angle] += 1

    # Record Bonds
    for bond, count in present_bonds.items():
        data[f"{bond} ({bond_units})"] = count
    data[f"Total Bonds ({bond_units})"] = sum(present_bonds.values())

    # Record Angles
    for angle, count in present_angles.items():
        data[f"{angle} ({angle_units})"] = count
    data[f"Total Angles ({angle_units})"] = sum(present_angles.values())

    return data

# Featurize every row. Rows only receive the columns for features they contain,
# so the combined frame has gaps wherever a feature is absent from a structure.
all_data_features = df.progress_apply(featurize, axis=1)

# Fill those gaps with 0 in the generated feature columns only. A blanket
# fillna(0) would also turn any missing value in the original columns
# (e.g. a missing bandgap) into a 0, which would look like a real measurement.
all_data_features = all_data_features.fillna(
    dict.fromkeys(all_data_features.columns.difference(df.columns), 0)
)
all_data_features

100%|██████████| 6351/6351 [41:35<00:00,  2.54it/s]  


Unnamed: 0,Ac (atoms),Ac-Se (bonds),Ac-Se-Ac (angles),Ag (atoms),Ag-Ag (bonds),Ag-Ag-Ag (angles),Ag-Ag-Bi (angles),Ag-Ag-Br (angles),Ag-Ag-C (angles),Ag-Ag-Cl (angles),...,Zr-Zr-Sb (angles),Zr-Zr-Se (angles),Zr-Zr-Si (angles),Zr-Zr-Sn (angles),Zr-Zr-Te (angles),Zr-Zr-Zn (angles),Zr-Zr-Zr (angles),atoms_object (unitless),bandgap (eV),pymatgen_structure (unitless)
0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,"(Atom('Ir', [0.0, 0.0, 0.0], index=0), Atom('F...",0.0000,"[[0. 0. 0.] Ir, [ 1.75964097 1.01592735 22.08..."
1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,"(Atom('Ba', [2.476683476681, 1.429910903420999...",0.0000,"[[ 2.47668348 1.4299109 19.09028155] Ba, [-2..."
2,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,"(Atom('Tl', [2.63896615613751, 10.292177253854...",0.9814,"[[ 2.63896616 10.29217725 11.08346956] Tl, [3...."
3,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,"(Atom('Mo', [1.5833675, 2.687975714894, 2.6388...",0.0000,"[[1.5833675 2.68797571 2.63881737] Mo, [ 0.00..."
4,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,"(Atom('Ir', [0.0, 0.0, 0.0], index=0), Atom('O...",0.0000,"[[0. 0. 0.] Ir, [-1.24446100e-06 1.82188091e+..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
6346,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,"(Atom('Bi', [6.537243774895211, 2.363799214670...",0.1990,"[[ 6.53724377 2.36379921 10.57502831] Bi, [ 2..."
6347,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,"(Atom('Pt', [1.7596322562971602, 1.01592211741...",0.0000,"[[ 1.75963226 1.01592212 21.08758921] Pt, [-1..."
6348,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,"(Atom('Pt', [1.84215348881831, 1.0635655531000...",0.0000,"[[ 1.84215349 1.06356555 21.02069095] Pt, [-1..."
6349,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,"(Atom('Pt', [1.8200943700922698, 1.05082975916...",0.0000,"[[ 1.82009437 1.05082976 20.73570262] Pt, [-1..."


In [9]:
# Save the featurized table to the working directory. The frame still carries the
# "atoms_object (unitless)" and "pymatgen_structure (unitless)" object columns.
all_data_features.to_pickle("all_data_features.pkl")