In [None]:
from os.path import expanduser
home = expanduser("~")
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
# from workflow_utils import exec_parallel, grouper, uniquify, flatten
from scipy.stats import gaussian_kde
from scipy.integrate import quad
from matplotlib import cm
from pymatgen.core.composition import Composition
from pymatgen.core.periodic_table import Species, Element
from adjustText import adjust_text


### CONSTANTS ####################################################################
##################################################################################
# ROOT_LABEL = 'cifs_mp_n2-5'
ROOT_LABEL = 'cifs_icsd_ionic_compounds'
MOD_LABEL = '_maxSites100'
DATASET_LABEL = f"{ROOT_LABEL}{MOD_LABEL}"
PATH_PREFIX = f'{home}/Projects/storage/datasets'
CIF_LOC = f'{PATH_PREFIX}/{ROOT_LABEL}/2_succ_rewrite_aflow{MOD_LABEL}'
COMP_LOC = f'{PATH_PREFIX}/comparison-output_{DATASET_LABEL}_match1.0'
FIG_LOC = f'../figures/{DATASET_LABEL}_match1.0'

# Input JSON filenames (inside COMP_LOC) for each supported dataset, as
# (structure table, prototype table). Both datasets currently use the same
# names; they are listed separately so either can change independently.
_INPUT_FILES = {
    'cifs_mp_n2-5': ('aug_data_struc_4_filtered.json', 'aug_data_proto_4_filtered.json'),
    'cifs_icsd_ionic_compounds': ('aug_data_struc_4_filtered.json', 'aug_data_proto_4_filtered.json'),
}

# Fail here with a clear message instead of a NameError on IN_STRUC in a later cell.
if ROOT_LABEL not in _INPUT_FILES:
    raise ValueError(f"Unknown ROOT_LABEL {ROOT_LABEL!r}; expected one of {sorted(_INPUT_FILES)}")

_struc_file, _proto_file = _INPUT_FILES[ROOT_LABEL]
IN_STRUC = f"{COMP_LOC}/{_struc_file}"
IN_PROTO = f"{COMP_LOC}/{_proto_file}"

In [None]:
### Parse data

struc_df = pd.read_json(IN_STRUC)
proto_df = pd.read_json(IN_PROTO)
print(struc_df.columns.to_numpy())

# Formula = leading token of each file name (everything before the first '_').
formula_col = [file_name.split('_')[0] for file_name in struc_df.fname]
struc_df['formula'] = formula_col

# Pair each row's `symbols` list with its `oxi_states` list, element by element,
# to build a list of pymatgen Species per row.
species_col = []
for row_syms, row_oxis in zip(struc_df.symbols, struc_df.oxi_states):
    species_col.append([Species(el, ox) for el, ox in zip(row_syms, row_oxis)])
struc_df['species'] = species_col

# For the ICSD dataset every row is marked non-theoretical.
# NOTE(review): presumably this dataset ships without a 'theoretical' column — confirm.
if ROOT_LABEL == 'cifs_icsd_ionic_compounds':
    struc_df['theoretical'] = [False] * len(struc_df)

In [None]:
### Extract entries in the protos of interest

# Output columns: identifiers, the structure-type label, and every column whose
# name starts with 'type0' or 'type1'.
cols_to_use = (
    ['formula', 'ident', 'label']
    + [col for col in struc_df.columns if col.startswith('type0')]
    + [col for col in struc_df.columns if col.startswith('type1')]
)

# Structure-type name -> prototype labels (values of proto_df.proto) grouped under it.
# Some names list two labels differing only in the order of the trailing letters
# (e.g. 'AB_cF8_225_a_b' / 'AB_cF8_225_b_a').
select = {'Wurtzite':    ['AB_hP4_186_b_b'],
          'NiAs':        ['AB_hP4_194_c_a', 'AB_hP4_194_a_c'],
          'Zinc Blende': ['AB_cF8_216_c_a', 'AB_cF8_216_a_c'],
          'Rock Salt':   ['AB_cF8_225_a_b', 'AB_cF8_225_b_a'],
          'CsCl':        ['AB_cP2_221_b_a', 'AB_cP2_221_a_b'],
         }

select_names = sorted(select.keys())

# Loop-invariant row mask, computed once: keep only rows explicitly flagged
# theoretical == False (elementwise comparison, intentionally not `is False`).
not_theoretical = struc_df.theoretical == False  # noqa: E712

select_strucs = {}
for name in select_names:
    subdfs = []
    for proto in select[name]:
        for struc_rep in proto_df[proto_df.proto == proto].struc_rep:
            subdf = struc_df[(struc_df.struc_rep == struc_rep) & not_theoretical].copy()
            if len(subdf) > 0:
                subdf['label'] = name  # scalar broadcasts to every row
                subdfs.append(subdf[cols_to_use])
    # pd.concat raises ValueError on an empty list; fall back to an empty frame with
    # the expected columns so a structure type with no matches does not abort the cell.
    select_strucs[name] = pd.concat(subdfs) if subdfs else pd.DataFrame(columns=cols_to_use)

### Check for duplicates: report formulae occurring more than once within one structure type
for name in select_names:
    formula_counts = select_strucs[name].formula.value_counts()
    for form in sorted(formula_counts[formula_counts > 1].index):
        print(name, form)

### Examine example
select_strucs['CsCl'].head()

In [None]:
# Number of selected entries per structure type. Computed from select_strucs rather
# than `subdfs`, which after the previous cell only holds the sub-frames of the last
# structure name processed.
pd.Series({name: len(select_strucs[name]) for name in select_names}, name='n_entries')

In [None]:
### Collate dataset and write
# Stack the per-structure-type tables, in sorted-name order, into one frame and save it.
frames_in_order = [select_strucs[key] for key in select_names]
df = pd.concat(frames_in_order)
out_path = 'AB_compounds.json'
df.to_json(out_path)