In [10]:
import copy
import itertools
import os
import sys
from pathlib import Path
folder = Path(os.path.abspath(''))
if str(folder) not in sys.path:
    sys.path.insert(0, str(folder))
lattice_folder = folder.joinpath('CSPML/CSPML_latest_codes')
if str(lattice_folder) not in sys.path:
    sys.path.insert(1, str(lattice_folder))
from CSPML.CSPML_latest_codes.KmdPlus import StatsDescriptor, formula_to_composition 
import pandas as pd
import numpy as np
from pymatgen.core.composition import Composition
from mp_api.client import MPRester
from chgnet.model.model import CHGNet
from chgnet.model.dynamics import MolecularDynamics
from pymatgen.core import Structure
import warnings
warnings.filterwarnings("ignore", module="pymatgen")
warnings.filterwarnings("ignore", module="ase")
import pickle
import math
from scipy.spatial import distance_matrix
import shutil
import tensorflow as tf
from tensorflow.keras.utils import to_categorical

In [7]:
# Read all templates.
# NOTE: pd.read_pickle and pickle.load execute arbitrary code embedded in the file;
# only load these data files from a trusted source.
MP_stable = pd.read_pickle("CSPML/CSPML_latest_codes/data_set/MP_stable_20211107.pd.xz")

# Element-level descriptors of shape (94, 58).
# The index (first CSV column) is used later as the element-symbol lookup table.
element_features = pd.read_csv("CSPML/CSPML_latest_codes/data_set/element_features.csv", index_col= 0)

# Load test data (90 crystals).
test_data = pd.read_pickle("CSPML/CSPML_latest_codes/data_set/all_searching_targets_20211107_with_predictions.pd.xz")

# Load the pre-trained models.
# The file is opened as a plain binary stream and unpickled directly.
# NOTE(review): despite the .xz suffix no decompression is applied here — confirm the file is an uncompressed pickle.
with open("CSPML/CSPML_latest_codes/data_set/CSPML_models.xz", "rb") as f:
    models = pickle.load(f)
    
# Load stats.
# Row 0 is subtracted from and row 1 divides the query fingerprint in Screening_candidates.
cmpfgp_stable_meanstd = np.load("CSPML/CSPML_latest_codes/data_set/cmpfgp_stable_meanstd_20211107.npy") 

# Load element dissimilarity.
# Indexed as [query_element_index, template_element_index] in Structure_prediction.
element_dissimilarity = np.load('CSPML/CSPML_latest_codes/data_set/element_dissimilarity.npy')

# NOTE: element_features, cmpfgp_stable_meanstd, models and element_dissimilarity are bound
# as default argument values by the functions defined later, so this cell must run first.

In [8]:
# Ensemble prediction.
def ensemble_models(X, models):
    """Average the second-column prediction of every model in ``models``.

    Parameters
    ----------
    X : array-like
        Input handed unchanged to each model's ``predict(X, verbose=0)``.
    models : sequence
        Objects indexable by ``0 .. len(models) - 1`` whose ``predict`` returns
        a 2-D array; column 1 is taken as the score of interest.

    Returns
    -------
    numpy.ndarray
        Mean over models of column 1, one value per row of ``X``.
    """
    per_model_scores = []
    for idx in range(len(models)):
        probabilities = models[idx].predict(X, verbose=0)
        per_model_scores.append(probabilities[:, 1])

    return np.mean(np.array(per_model_scores), axis=0)

# Formula to ratio label.
def formula_to_ratiolabel(formula):
    """Build a reduced integer-ratio label such as ``"3:1:1"`` for each formula.

    Parameters
    ----------
    formula : sequence of str
        Chemical formulas.

    Returns
    -------
    numpy.ndarray (dtype=object)
        One colon-separated label per formula; the integers are listed in
        descending order of the values returned by ``formula_to_composition``
        and divided by their greatest common divisor.

    Notes
    -----
    Assumes ``formula_to_composition`` returns a per-element weight vector
    whose largest entries correspond to the elements present — confirm in KmdPlus.
    """
    # Weight vectors for all formulas, stacked row-wise.
    weights = np.array([formula_to_composition(entry) for entry in formula])

    labels = []
    for entry, weight_row in zip(formula, weights):
        descending = np.sort(weight_row)[::-1]
        composition = Composition(entry)
        # Scale the weights by the total atom count to recover atom counts.
        atom_counts = composition.num_atoms * descending
        counts = [int(round(atom_counts[pos])) for pos in range(len(composition))]

        # Reduce cases like "O2" or "Na2O2" to their smallest integer ratio.
        divisor = math.gcd(*counts)
        if divisor != 1:
            counts = [int(round(value / divisor)) for value in counts]

        labels.append(":".join(str(value) for value in counts))

    return np.array(labels, dtype = "object")

# Screening for CSPML.
def Screening_candidates(query_formula, top_K, templates, cutoff = 0.5, element_features = element_features,
                        meanstd = cmpfgp_stable_meanstd, models = models):
    """Rank template crystals for each query formula with the model ensemble.

    Parameters
    ----------
    query_formula : sequence of str
        Formulas to find templates for.
    top_K : int
        Maximum number of templates kept per query (before the cutoff filter).
    templates : pandas.DataFrame
        Must provide the columns ``comp_ratio_label``, ``cmpfgp``,
        ``materials_id`` and ``pretty_formula``.
    cutoff : float
        Candidates whose ensemble score is not strictly greater are dropped.
    element_features : pandas.DataFrame
        Element descriptors passed to ``StatsDescriptor``.
    meanstd : array-like
        ``meanstd[0]`` is subtracted from and ``meanstd[1]`` divides the query descriptor.
    models : sequence
        Models forwarded to ``ensemble_models``.

    Returns
    -------
    list of dict
        One dict per query with keys ``query_formula``, ``topK_formula``,
        ``topK_id`` and ``topK_pred``. When no candidate is found the values
        are ``[]``, ``[]`` and ``0`` respectively.
    """
    # Standardised composition descriptor of every query formula.
    scaled_query = (StatsDescriptor(query_formula, element_features) - meanstd[0]) / meanstd[1]

    # Composition-ratio label of every query formula.
    ratio_labels = formula_to_ratiolabel(query_formula)

    results = []
    for q in range(len(query_formula)):
        # Only templates sharing the query's composition ratio are eligible.
        same_ratio_rows = np.where(templates.comp_ratio_label.values == ratio_labels[q])[0]

        # Values reported when nothing is found for this query.
        kept_id, kept_pred, kept_formula = [], 0, []

        if len(same_ratio_rows) < 1:
            print(f"None of the candidates had the same composition ratio as {query_formula[q]}.")
        else:
            pool = templates.iloc[same_ratio_rows]

            # Stack the per-template fingerprints into one 2-D array.
            pool_fingerprints = np.array(list(pool.cmpfgp.values))
            # Model input: element-wise absolute difference to the query descriptor.
            differences = np.abs(pool_fingerprints - scaled_query[q, :])
            scores = ensemble_models(differences, models)

            # Highest scores first, truncated to top_K.
            best_rows = np.argsort(scores)[::-1][:top_K]
            best_id = pool.materials_id.values[best_rows]
            best_pred = scores[best_rows]
            best_formula = pool.pretty_formula.values[best_rows]

            above_cutoff = (best_pred > cutoff)
            if not above_cutoff.any():
                print(f"None of the candidates had the class probabilities greater than {cutoff} at {query_formula[q]}.")
            else:
                kept_id = best_id[above_cutoff]
                kept_pred = best_pred[above_cutoff]
                kept_formula = best_formula[above_cutoff]

        results.append({"query_formula":query_formula[q],"topK_formula":kept_formula,
                        "topK_id":kept_id,"topK_pred":kept_pred})

    return results

# CSPML.
def Structure_prediction(query_formula, top_K, templates, cutoff = 0.5, element_features = element_features,
            meanstd = cmpfgp_stable_meanstd, models = models, element_dissimilarity = element_dissimilarity,
            SI = False, save_cif = False, save_cif_filename = ""):
    """Predict crystal structures by element substitution into screened templates.

    For every query formula the templates returned by ``Screening_candidates``
    are copied and their species replaced by the query's elements. Elements
    are matched by composition rank; when several elements share the same
    amount, the assignment with the smallest summed ``element_dissimilarity``
    is chosen.

    Parameters
    ----------
    query_formula : sequence of str
        Formulas to predict structures for.
    top_K : int
        Maximum number of templates considered per query.
    templates : pandas.DataFrame
        Template table; in addition to the columns used for screening it must
        provide ``materials_id`` and ``structure``.
    cutoff, element_features, meanstd, models
        Forwarded unchanged to ``Screening_candidates``. The index of
        ``element_features`` is also used as the element-symbol table.
    element_dissimilarity : numpy.ndarray
        Square matrix indexed as ``[query_element_index, template_element_index]``.
    SI : bool
        If true, also return the screening results.
    save_cif : bool
        If true, write every predicted structure to
        ``<save_cif_filename>/<formula>_<rank>.cif`` (rank starts at 1).
    save_cif_filename : str
        Target directory for the CIF files; it must already exist. An empty
        string means the current working directory.

    Returns
    -------
    list of list
        ``predictions[i]`` holds the predicted structures for
        ``query_formula[i]`` (empty when screening found nothing).
        When ``SI`` is true the tuple ``(predictions, screened)`` is returned.
    """
    # Screening top_K candidates using pre-trained model for each query formula.
    screened = Screening_candidates(query_formula, top_K, templates, cutoff, element_features,
                        meanstd, models)

    element_symbol = element_features.index.values
    predictions = []

    for i in range(len(query_formula)):

        predicted_structures = []
        candidate_ids = screened[i]["topK_id"]
        candidate_formulas = screened[i]["topK_formula"]
        scr_num = len(candidate_ids)

        if scr_num > 0:
            # Query-side quantities depend only on i, so compute them once
            # instead of once per candidate.
            vec = formula_to_composition(query_formula[i])
            N_ele = int(np.count_nonzero(vec))
            # Element indices of the query, largest amount first.
            comp_index = np.argsort(vec)[::-1][:N_ele]
            # Matching amounts and their distinct values, both descending.
            comp = np.sort(vec)[::-1][:N_ele]
            keys = np.unique(comp)[::-1]
            q_ele = element_symbol[comp_index]

            for j in range(scr_num):

                # Element indices of the j-th suggested template, largest amount first.
                vec_sug = formula_to_composition(candidate_formulas[j])
                comp_sug_index = np.argsort(vec_sug)[::-1][:N_ele]

                # Walk the groups of equal amount in descending order, so the
                # concatenated result stays aligned with comp_index.
                replacement = []
                for key in keys:
                    members = (comp == key)
                    query_group = comp_index[members]
                    template_group = comp_sug_index[members]

                    if len(query_group) == 1:
                        # Replacement is unique.
                        replacement.append(template_group)
                    else:
                        # Replacement is not unique: pick the permutation of the
                        # template elements with minimal total dissimilarity.
                        pmt = list(itertools.permutations(template_group))
                        dis_sum = np.array([element_dissimilarity[query_group, list(p)].sum()
                                            for p in pmt])
                        replacement.append(np.array(pmt[int(np.argmin(dis_sum))]))

                rep_index = np.concatenate(replacement)
                rep_ele = element_symbol[rep_index]
                # Maps template element -> query element.
                rep_dict = dict(zip(rep_ele, q_ele))

                # Generating top-jth candidate structure for ith query formula.
                # .iloc[0] selects the first matching row by position, whatever
                # the DataFrame index looks like.
                matched = templates[templates.materials_id.values == candidate_ids[j]]
                query_str = copy.deepcopy(matched.structure.iloc[0])
                query_str.replace_species(rep_dict)
                predicted_structures.append(query_str)

                # Save the structure object as a .cif file into dir = filename (if save_cif=True).
                if save_cif:
                    # os.path.join keeps an empty directory relative to the
                    # working directory instead of producing "/<name>.cif".
                    text = os.path.join(save_cif_filename, f"{query_formula[i]}_{j+1}.cif")
                    query_str.to(filename=text)

        predictions.append(predicted_structures)

    # Return the predicted structures (+ optionally the supplementary information of the predicted structures).
    if SI:
        return predictions, screened
    return predictions

In [None]:
def predict_structure(formulas, templates, workdir):
    """Run CSPML on ``formulas`` (top 5 templates each), saving CIFs into ``workdir``, then load CHGNet.

    NOTE(review): this function currently returns ``None``; the predicted
    structures, the screening results and the loaded CHGNet model are all
    discarded. It looks unfinished — confirm the intended return value.
    """
    structures, screening_info = Structure_prediction(
        formulas,
        5,
        templates,
        SI=True,
        save_cif=True,
        save_cif_filename=workdir,
    )
    chgnet_model = CHGNet.load()

In [None]:
# Output directory for the predicted CIF files (results should match
# cif_files_for_90crystals/predicted_structures (pre DFT)).
# Any previous run's directory is deleted first so stale files cannot survive.
new_dir = "CSPML_test90"
if os.path.exists(new_dir):
    shutil.rmtree(new_dir)
os.mkdir(new_dir)

# Run CSPML on the test formulas, keeping up to 10 templates per query and
# writing one CIF per predicted structure into new_dir.
predictions, screened = Structure_prediction(
    query_formula=test_data.pretty_formula.values,
    top_K=10,
    templates=MP_stable,
    SI=True,
    save_cif=True,
    save_cif_filename=new_dir,
)