In [77]:
import os
import glob
import pickle

import pandas as pd
import numpy as np

from pymatgen.io.cif import CifParser

In [2]:
# Load the pickled space-group object from datasets/space_group.pkl.
# NOTE(review): `space_group` is not referenced by any cell visible here, and
# unpickling can execute arbitrary code — only load files from a trusted source.
with open("datasets/space_group.pkl", mode="rb") as space_group_file:
    space_group = pickle.load(space_group_file)

In [95]:
cif_list = glob.glob("datasets/cif/*.cif")

def to_metadata(cif: str) -> list[str]:
    """Split a CIF file path into ``[material_id, formula_pretty, cif_path]``.

    File names are expected to look like ``{formula}_{material_id}_computed.cif``
    (e.g. ``Cs_mp-1_computed.cif``).

    Args:
        cif: path as returned by ``glob``; ``/`` and ``\\`` separators are both accepted.

    Returns:
        ``[material_id, formula, path]`` where ``path`` uses ``/`` separators.

    Raises:
        ValueError: if the file name does not contain a ``_``-separated formula/id pair.
    """
    # glob yields "datasets/cif\\X.cif" on Windows but "datasets/cif/X.cif" on POSIX;
    # normalising separators first makes the parsing platform-independent.
    path = cif.replace("\\", "/")
    filename = path.rsplit("/", 1)[-1]
    parts = filename.split("_")
    if len(parts) < 2:
        raise ValueError(f"unexpected CIF file name (expected '<formula>_<material_id>_...'): {cif!r}")
    formula, material_id = parts[0], parts[1]
    return [material_id, formula, path]

# Sort by id prefix ("mp", "mvc", ...) and then numerically by the id suffix, so mp-2 sorts before mp-10.
cif_records = sorted(
    map(to_metadata, cif_list),
    key=lambda rec: (rec[0].split("-")[0], int(rec[0].split("-")[1])),
)
cif_meta_df = pd.DataFrame(cif_records, columns=["material_id", "formula_pretty", "cif_path"])
cif_meta_df

Unnamed: 0,material_id,formula_pretty,cif_path
0,mp-1,Cs,datasets/cif/Cs_mp-1_computed.cif
1,mp-2,Pd,datasets/cif/Pd_mp-2_computed.cif
2,mp-3,Cs,datasets/cif/Cs_mp-3_computed.cif
3,mp-4,Nd,datasets/cif/Nd_mp-4_computed.cif
4,mp-7,S,datasets/cif/S_mp-7_computed.cif
...,...,...,...
124186,mvc-16821,CaCr2O4,datasets/cif/CaCr2O4_mvc-16821_computed.cif
124187,mvc-16832,V2ZnO4,datasets/cif/V2ZnO4_mvc-16832_computed.cif
124188,mvc-16833,AlV2O4,datasets/cif/AlV2O4_mvc-16833_computed.cif
124189,mvc-16834,MgMn2O4,datasets/cif/MgMn2O4_mvc-16834_computed.cif


In [4]:
# Full materials dump (2022-08-14), keyed per material; each value holds a "summary" entry.
# NOTE(review): this used to be a hard-coded absolute local path. It is now relative to
# the notebook's working directory, like the other `datasets/...` paths — confirm the
# notebook is run from the project root.
# pickle.load can execute arbitrary code: only load this file from a trusted source.
with open(os.path.join("datasets", "mat_full_data_2022-08-14.pkl"), "rb") as f:
    data = pickle.load(f)

In [38]:
# Summary-document fields to extract for each material, in DataFrame column order.
col = [
    "material_id", "formula_pretty", "nsites", "nelements", "volume", "density",
    "cbm", "vbm", "density_atomic",
    "uncorrected_energy_per_atom", "energy_per_atom", "formation_energy_per_atom",
    "energy_above_hull", "is_stable", "equilibrium_reaction_energy_per_atom",
    "band_gap", "efermi", "is_gap_direct", "is_metal",
    "is_magnetic", "total_magnetization", "total_magnetization_normalized_vol",
    "total_magnetization_normalized_formula_units", "num_magnetic_sites",
    "num_unique_magnetic_sites",
    "k_voigt", "k_reuss", "k_vrh", "g_voigt", "g_reuss", "g_vrh",
    "universal_anisotropy", "homogeneous_poisson",
    "weighted_surface_energy_EV_PER_ANG2", "weighted_surface_energy",
    "weighted_work_function", "surface_anisotropy",
    "e_total", "e_ionic", "e_electronic", "e_ij_max", "shape_factor",
]

In [41]:
# Flatten each material's summary document into one row holding the fields in `col`
# (missing fields become None).
data_records = []
for entry in data.values():
    summary = entry["summary"]
    # Entries whose summary is a plain string are skipped — presumably
    # error/placeholder values from the download step (TODO confirm).
    if isinstance(summary, str):
        continue
    # Pydantic v2 renamed `.dict()` to `.model_dump()`; use whichever exists.
    summary_dict = summary.model_dump() if hasattr(summary, "model_dump") else summary.dict()
    data_records.append([summary_dict.get(field) for field in col])

# Sort by id prefix ("mp", ...) and then numerically by the id suffix, so mp-2 sorts before mp-10.
data_records.sort(key=lambda row: (row[0].split("-")[0], int(row[0].split("-")[1])))
data_df = pd.DataFrame(data_records, columns=col)
data_df

Unnamed: 0,material_id,formula_pretty,nsites,nelements,volume,density,cbm,vbm,density_atomic,uncorrected_energy_per_atom,...,homogeneous_poisson,weighted_surface_energy_EV_PER_ANG2,weighted_surface_energy,weighted_work_function,surface_anisotropy,e_total,e_ionic,e_electronic,e_ij_max,shape_factor
0,mp-1,Cs,1,1,114.051805,1.935039,,,114.051805,-0.856633,...,0.381721,0.003839,0.061508,2.033364,0.053543,,,,,5.202531
1,mp-2,Pd,1,1,15.490302,11.408077,,,15.490302,-5.179882,...,0.377313,0.090531,1.450470,5.062258,0.056446,,,,,5.054431
2,mp-3,Cs,2,1,246.245437,1.792477,,,123.122718,-0.799005,...,1.822111,,,,,,,,,
3,mp-4,Nd,1,1,35.307869,6.783742,,,35.307869,-4.628184,...,0.368649,,,,,,,,,
4,mp-7,S,6,1,189.199136,1.688544,2.8851,0.3738,31.533189,-4.071949,...,0.204456,,,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
85625,mp-1405450,Ca5(CrN3)2,13,3,214.655613,3.004769,3.7774,2.9937,16.511970,-6.933645,...,,,,,,16.407425,9.995456,6.411969,0.968516,
85626,mp-1408285,TiO2,12,2,149.481845,3.548800,3.8415,1.3954,12.456820,-8.793056,...,,,,,,19.633246,14.797667,4.835578,0.992636,
85627,mp-1443513,SbF4,10,2,140.922341,4.660405,3.3434,0.8231,14.092234,-4.462696,...,,,,,,101.706192,98.726060,2.980132,8.488323,
85628,mp-1443834,Ca2TaWO6,10,4,130.201792,6.898926,6.3635,4.1046,13.020179,-8.002111,...,,,,,,21.644732,16.768746,4.875986,2.674938,


In [67]:
# Attach the CIF path to each Materials Project ("mp-") summary row; id and formula must both match.
data_df.merge(
    cif_meta_df.loc[cif_meta_df["material_id"].str.contains("mp-", regex=False)],
    on=["material_id", "formula_pretty"],
)

Unnamed: 0,material_id,formula_pretty,nsites,nelements,volume,density,cbm,vbm,density_atomic,uncorrected_energy_per_atom,...,weighted_surface_energy_EV_PER_ANG2,weighted_surface_energy,weighted_work_function,surface_anisotropy,e_total,e_ionic,e_electronic,e_ij_max,shape_factor,cif_path
0,mp-1,Cs,1,1,114.051805,1.935039,,,114.051805,-0.856633,...,0.003839,0.061508,2.033364,0.053543,,,,,5.202531,datasets/cif/Cs_mp-1_computed.cif
1,mp-2,Pd,1,1,15.490302,11.408077,,,15.490302,-5.179882,...,0.090531,1.450470,5.062258,0.056446,,,,,5.054431,datasets/cif/Pd_mp-2_computed.cif
2,mp-3,Cs,2,1,246.245437,1.792477,,,123.122718,-0.799005,...,,,,,,,,,,datasets/cif/Cs_mp-3_computed.cif
3,mp-4,Nd,1,1,35.307869,6.783742,,,35.307869,-4.628184,...,,,,,,,,,,datasets/cif/Nd_mp-4_computed.cif
4,mp-7,S,6,1,189.199136,1.688544,2.8851,0.3738,31.533189,-4.071949,...,,,,,,,,,,datasets/cif/S_mp-7_computed.cif
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
81591,mp-1287570,CaMn2O4,14,3,180.925463,3.927309,3.7481,2.5192,12.923247,-7.272676,...,,,,,,,,,,datasets/cif/CaMn2O4_mp-1287570_computed.cif
81592,mp-1291802,Zn2FeWO6,20,4,223.000097,6.947445,4.6971,2.4363,11.150005,-6.130564,...,,,,,,,,,,datasets/cif/Zn2FeWO6_mp-1291802_computed.cif
81593,mp-1296431,Ca2VWO6,20,4,245.829211,5.551592,5.6186,5.6500,12.291461,-7.361888,...,,,,,,,,,,datasets/cif/Ca2VWO6_mp-1296431_computed.cif
81594,mp-1306369,Ca(NiO2)2,14,3,164.369613,4.474634,3.5520,2.9515,11.740687,-5.246323,...,,,,,,,,,,datasets/cif/Ca(NiO2)2_mp-1306369_computed.cif


In [71]:
# All summary rows whose pretty formula is SbO2 (several distinct material ids share it).
data_df.loc[data_df["formula_pretty"].eq("SbO2")]

Unnamed: 0,material_id,formula_pretty,nsites,nelements,volume,density,cbm,vbm,density_atomic,uncorrected_energy_per_atom,...,homogeneous_poisson,weighted_surface_energy_EV_PER_ANG2,weighted_surface_energy,weighted_work_function,surface_anisotropy,e_total,e_ionic,e_electronic,e_ij_max,shape_factor
174,mp-230,SbO2,24,2,325.188676,6.281215,7.3967,5.507,13.549528,-6.054466,...,0.251782,,,,,19.61389,14.232513,5.381377,0.690277,
1327,mp-1819,SbO2,12,2,160.882551,6.348047,7.5946,5.5658,13.406879,-6.05253,...,,,,,,20.902202,15.562884,5.339318,,
7904,mp-13866,SbO2,24,2,296.579139,6.887133,8.325,7.891,12.357464,-5.71466,...,,,,,,,,,,
12335,mp-22071,SbO2,24,2,291.612145,7.004441,,,12.150506,-5.59458,...,,,,,,,,,,
58238,mp-1041974,SbO2,6,2,105.999028,4.81745,,,17.666505,-5.653786,...,,,,,,,,,,
59067,mp-1044574,SbO2,24,2,477.478952,4.277843,3.9305,1.6334,19.894956,-5.9638,...,,,,,,,,,,
59107,mp-1044716,SbO2,12,2,227.367813,4.491797,,,18.947318,-5.738578,...,,,,,,,,,,
59303,mp-1045682,SbO2,12,2,176.814879,5.776041,5.9618,5.1251,14.734573,-5.917327,...,,,,,,,,,,
59763,mp-1047300,SbO2,12,2,196.112665,5.20767,,,16.342722,-5.774261,...,,,,,,,,,,


In [68]:
# Attach summary properties to the non-"mp-" CIFs (e.g. "mvc-") by matching on formula alone.
# NOTE(review): a formula-only key is many-to-many — every summary row sharing a formula
# (e.g. the several SbO2 entries) pairs with every CIF of that formula, so rows repeat and
# the id columns split into material_id_x / material_id_y. Confirm this is intended.
data_df.merge(
    cif_meta_df.loc[~cif_meta_df["material_id"].str.contains("mp-", regex=False)],
    on=["formula_pretty"],
)

Unnamed: 0,material_id_x,formula_pretty,nsites,nelements,volume,density,cbm,vbm,density_atomic,uncorrected_energy_per_atom,...,weighted_surface_energy,weighted_work_function,surface_anisotropy,e_total,e_ionic,e_electronic,e_ij_max,shape_factor,material_id_y,cif_path
0,mp-230,SbO2,24,2,325.188676,6.281215,7.3967,5.5070,13.549528,-6.054466,...,,,,19.613890,14.232513,5.381377,0.690277,,mvc-6033,datasets/cif/SbO2_mvc-6033_computed.cif
1,mp-230,SbO2,24,2,325.188676,6.281215,7.3967,5.5070,13.549528,-6.054466,...,,,,19.613890,14.232513,5.381377,0.690277,,mvc-6570,datasets/cif/SbO2_mvc-6570_computed.cif
2,mp-230,SbO2,24,2,325.188676,6.281215,7.3967,5.5070,13.549528,-6.054466,...,,,,19.613890,14.232513,5.381377,0.690277,,mvc-6936,datasets/cif/SbO2_mvc-6936_computed.cif
3,mp-230,SbO2,24,2,325.188676,6.281215,7.3967,5.5070,13.549528,-6.054466,...,,,,19.613890,14.232513,5.381377,0.690277,,mvc-9477,datasets/cif/SbO2_mvc-9477_computed.cif
4,mp-230,SbO2,24,2,325.188676,6.281215,7.3967,5.5070,13.549528,-6.054466,...,,,,19.613890,14.232513,5.381377,0.690277,,mvc-9642,datasets/cif/SbO2_mvc-9642_computed.cif
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
5690,mp-1397292,MoWO6,8,3,116.798353,5.342467,1.7570,0.3231,14.599794,-7.250113,...,,,,39.787059,33.506856,6.280203,10.917787,,mvc-5693,datasets/cif/MoWO6_mvc-5693_computed.cif
5691,mp-1402840,WF4,10,2,164.862637,5.234223,0.9645,-0.2005,16.486264,-5.491438,...,,,,5.414787,3.155525,2.259262,0.207780,,mvc-14582,datasets/cif/WF4_mvc-14582_computed.cif
5692,mp-1405450,Ca5(CrN3)2,13,3,214.655613,3.004769,3.7774,2.9937,16.511970,-6.933645,...,,,,16.407425,9.995456,6.411969,0.968516,,mvc-11129,datasets/cif/Ca5(CrN3)2_mvc-11129_computed.cif
5693,mp-1443834,Ca2TaWO6,10,4,130.201792,6.898926,6.3635,4.1046,13.020179,-8.002111,...,,,,21.644732,16.768746,4.875986,2.674938,,mvc-5962,datasets/cif/Ca2TaWO6_mvc-5962_computed.cif


In [195]:
from pymatgen.core.periodic_table import Element
from pymatgen.core.composition import Composition
from pymatgen.core.molecular_orbitals import MolecularOrbitals

import torch
from torch_geometric.data import Data

def count_elements(compound):
    """Return the per-element amounts of ``compound`` as a fixed-length vector.

    The vector has one slot per member of pymatgen's ``Element`` enum, in enum
    order, so vectors built from different compounds line up position by position.

    Args:
        compound: anything ``Composition`` accepts, e.g. a formula string such as
            ``"Fe2O3"`` or a single element symbol.

    Returns:
        list[float]: amount of each element in ``compound`` (0.0 when absent).
    """
    counts = {el.symbol: 0.0 for el in Element}
    # element_composition maps oxidation-state species (e.g. Fe2+) back to their bare
    # element, so the symbol lookup below cannot raise KeyError.
    for element, amount in Composition(compound).element_composition.items():
        # Keep fractional amounts (partial occupancies) instead of truncating them with int().
        counts[element.symbol] += float(amount)
    return list(counts.values())

def get_orbital_energy(
    elem,
    orbitals=('1s', '2s', '2p', '3s', '3p', '3d', '4s', '4p', '4d', '4f',
              '5s', '5p', '5d', '6s', '6p', '5f', '6d', '7s'),
):
    """Return the atomic-orbital energies of ``elem`` as a fixed-length vector.

    Args:
        elem: formula/element accepted by pymatgen ``MolecularOrbitals``
            (called here with a single element symbol).
        orbitals: orbital labels defining the order and length of the output;
            orbitals absent for ``elem`` are reported as 0.0.

    Returns:
        list[float]: one energy per label in ``orbitals``.

    Raises:
        ValueError: if pymatgen reports an orbital label not listed in ``orbitals``.
            Such a label would otherwise silently change the feature-vector length.
    """
    energies = dict.fromkeys(orbitals, 0.0)
    for orb in MolecularOrbitals(elem).aos_as_list():
        # Each entry is indexed as [element, orbital_label, energy].
        label, energy = orb[1], orb[2]
        if label not in energies:
            raise ValueError(f"orbital {label!r} of {elem!r} is not in the configured orbital list")
        energies[label] = energy
    return list(energies.values())

In [196]:
# Build a fully connected PyG graph for one sample structure.
# NOTE(review): 74440 is an arbitrary positional row of cif_meta_df; column 2 is "cif_path".
path = cif_meta_df.iloc[74440, 2]
parser = CifParser(path)
sites = parser.get_structures()[0].as_dict()['sites']

x = []
pos = []
for site in sites:
    # Only the first species listed on each site is used; any other species on a
    # partially occupied site are dropped. NOTE(review): confirm this is acceptable.
    specie = site['species'][0]['element']
    # Node features: per-element count vector (one-hot here, since `specie` is a single
    # symbol) + orbital energies + the site's 'abc' coordinates.
    x.append([*count_elements(specie), *get_orbital_energy(specie), *site['abc']])
    # Node position: the site's 'xyz' coordinates.
    pos.append(site['xyz'])

# Every ordered pair (i, j) with i != j, in row-major order (the same order as a
# nested i/j loop), computed with numpy instead of a Python O(n^2) loop.
pos_arr = np.asarray(pos, dtype=float).reshape(-1, 3)
src, dst = np.nonzero(~np.eye(len(pos_arr), dtype=bool))
# NOTE(review): plain Euclidean distance between 'xyz' positions — periodic images are
# not considered, so atoms near opposite cell faces appear far apart. If that is not
# intended, consider pymatgen's Structure.distance_matrix (minimum-image distances).
dist = np.linalg.norm(pos_arr[src] - pos_arr[dst], axis=-1)

x = torch.tensor(x, dtype=torch.float)
edges_idx = torch.tensor(np.stack([src, dst]), dtype=torch.long)
edges_val = torch.tensor(dist, dtype=torch.float)
pos = torch.tensor(pos, dtype=torch.float)

graph = Data(x=x, edge_index=edges_idx, edge_attr=edges_val, pos=pos)
graph

Data(x=[44, 139], edge_index=[2, 1892], edge_attr=[1892], pos=[44, 3])

In [198]:
import networkx as nx

# Undirected networkx view of `graph`: nodes 0..num_nodes-1, and each (i, j) / (j, i)
# pair from edge_index collapses into a single undirected edge.
edges = graph.edge_index.t().tolist()
G = nx.empty_graph(graph.num_nodes)
G.add_edges_from(edges)

In [None]:
from matplotlib import pyplot as plt

# Node positions keyed by node index, taken from the PyG graph itself rather than the
# global `pos` variable, so this cell does not depend on leftover state.
G_pos = {i: p.numpy() for i, p in enumerate(graph.pos)}

def _format_axes(ax):
    """Label the x, y and z axes of a 3-D Axes."""
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_zlabel("z")

fig, ax = plt.subplots(subplot_kw={"projection": "3d"})
node_xyz = np.array([G_pos[v] for v in sorted(G)])
edge_xyz = np.array([(G_pos[u], G_pos[v]) for u, v in G.edges()])
ax.scatter(*node_xyz.T, s=100, ec="w")
for vizedge in edge_xyz:
    ax.plot(*vizedge.T, color="tab:gray", alpha=0.5)

_format_axes(ax)
ax.set_title("Structure graph (nodes at site positions)")
fig.tight_layout()
plt.show()

In [218]:
# CIF metadata rows whose material id contains "mvc".
cif_meta_df.loc[cif_meta_df["material_id"].str.contains("mvc", regex=False)]

Unnamed: 0,material_id,formula_pretty,cif_path
121745,mvc-2,Nb2Zn2BiO8,datasets/cif/Nb2Zn2BiO8_mvc-2_computed.cif
121746,mvc-3,Nb2Zn2SbO8,datasets/cif/Nb2Zn2SbO8_mvc-3_computed.cif
121747,mvc-4,Nb2Zn2CuO8,datasets/cif/Nb2Zn2CuO8_mvc-4_computed.cif
121748,mvc-6,Nb2Zn2TeO8,datasets/cif/Nb2Zn2TeO8_mvc-6_computed.cif
121749,mvc-8,Nb2Zn2SnO8,datasets/cif/Nb2Zn2SnO8_mvc-8_computed.cif
...,...,...,...
124186,mvc-16821,CaCr2O4,datasets/cif/CaCr2O4_mvc-16821_computed.cif
124187,mvc-16832,V2ZnO4,datasets/cif/V2ZnO4_mvc-16832_computed.cif
124188,mvc-16833,AlV2O4,datasets/cif/AlV2O4_mvc-16833_computed.cif
124189,mvc-16834,MgMn2O4,datasets/cif/MgMn2O4_mvc-16834_computed.cif
