# Perovデータセットの形式を確かめて、MEGNetデータセットに変換する

In [6]:
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from jarvis.db.figshare import data as jdata
from jarvis.core.atoms import Atoms
from sklearn.model_selection import train_test_split
from pymatgen.core import Element as Element_pmg
from IPython.display import display, HTML, clear_output

from pymatgen.core import Structure
from pymatgen.analysis.structure_matcher import StructureMatcher

from sklearn.model_selection import StratifiedShuffleSplit


%matplotlib inline

# Perovデータセットの形式を確かめる

In [2]:
perov_data_path = '../data/perov_5/test.csv'
df = pd.read_csv(perov_data_path)
display(df.head())
print(df.columns)

Unnamed: 0.1,Unnamed: 0,material_id,cif,formula,heat_all,heat_ref,dir_gap,ind_gap
0,996,3961,# generated using pymatgen\ndata_TiOsNOF\n_sym...,TiOsOFN,1.16,1.438331,0.0,0.0
1,13356,11922,# generated using pymatgen\ndata_TlRuO2F\n_sym...,RuTlO2F,1.66,1.604896,0.0,0.0
2,17389,6694,# generated using pymatgen\ndata_BiPtN2O\n_sym...,PtBiON2,1.56,1.569679,0.0,0.0
3,5951,3335,# generated using pymatgen\ndata_VRhNOF\n_symm...,VRhOFN,1.56,1.566713,0.0,0.0
4,10279,18565,# generated using pymatgen\ndata_CuAsN3\n_symm...,CuAsN3,1.96,1.957102,0.0,0.0


Index(['Unnamed: 0', 'material_id', 'cif', 'formula', 'heat_all', 'heat_ref',
       'dir_gap', 'ind_gap'],
      dtype='object')


In [3]:
cif_str = df.cif.loc[0]
with open("test.cif", "w") as f:
    f.write(cif_str)
from pymatgen.core import Structure

structure = Structure.from_file("test.cif")
print(structure)

Full Formula (Ti1 Os1 N1 O1 F1)
Reduced Formula: TiOsNOF
abc   :   4.056322   4.056322   4.056322
angles:  90.000000  90.000000  90.000000
pbc   :       True       True       True
Sites (5)
  #  SP           a    b    c
---  ----  --------  ---  ---
  0  Ti    0.609211  0    0
  1  Os    0.776199  0.5  0.5
  2  N     0.742001  0.5  0
  3  O     0.231433  0.5  0.5
  4  F     0.415284  0    0.5


In [7]:
from pymatgen.core import Structure, Lattice
from pymatgen.io.cif import CifWriter

def create_cif_string(data_atom_dict):
    # 入力データ
    structure_dict = {
        'lattice_mat': data_atom_dict['lattice_mat'],
        'coords': data_atom_dict['coords'],
        'elements': data_atom_dict['elements'],
        'cartesian': data_atom_dict['cartesian']
    }

    # Latticeオブジェクトの作成
    lattice = Lattice(structure_dict['lattice_mat'])

    # Structureの作成
    structure = Structure(
        lattice,
        structure_dict['elements'],
        structure_dict['coords'],
        coords_are_cartesian=structure_dict['cartesian']
    )

    # CIFファイルとして保存
    cif_writer = CifWriter(structure)
    return str(cif_writer)

In [8]:
from collections import Counter

def elements_to_formula(elements):
    counts = Counter(elements)
    # 元素記号の順序を維持（リスト順）
    unique_order = []
    for e in elements:
        if e not in unique_order:
            unique_order.append(e)
    # 化学式を構築
    formula = ''
    for elem in unique_order:
        count = counts[elem]
        formula += elem
        if count > 1:
            formula += str(count)
    return formula

# JARVIS Superconductorをダウンロードして、perov datasetの形式と合わせる
- 文献URL: https://www.nature.com/articles/s41524-022-00933-1
- 80:10:10 training-validation-testing data split

In [56]:
sc_data_path = '../data/supercon'
os.makedirs(sc_data_path, exist_ok=True)

In [22]:
dataset = jdata("supercon_3d")
data = dataset[0]
data.keys()

Obtaining supercond. Tc dataset 1058...
Reference:https://www.nature.com/articles/s41524-022-00933-1


100%|██████████| 8.83M/8.83M [00:10<00:00, 823kiB/s] 


Loading the zipfile...
Loading completed.


dict_keys(['stability', 'jid', 'atoms', 'cfid', 'wlog', 'lamb', 'Tc', 'a2F', 'a2F_original_x', 'a2F_original_y', 'press'])

In [57]:
data_dict_list = []
for data_i, data in enumerate(dataset):
    cif_string = create_cif_string(data['atoms'])
    formula = elements_to_formula(data['atoms']['elements'])
    data_dict_list.append({
        'material_id' :data_i,
        'cif': cif_string,
        'formula': formula,
        'tc':data['Tc'],
    })
df = pd.DataFrame(data_dict_list)

In [58]:
# dfをシャッフルし、80:10:10にtrain, val, testに分割する
# 80%をtrain、10%をval、10%をtestに分割
train_df, test_df = train_test_split(df, test_size=0.2, random_state=42)
val_df, test_df = train_test_split(test_df, test_size=0.5, random_state=42)
train_df.to_csv(os.path.join(sc_data_path, 'train.csv'), index=False)
val_df.to_csv(os.path.join(sc_data_path, 'val.csv'), index=False)
test_df.to_csv(os.path.join(sc_data_path, 'test.csv'), index=False)

# MEGNetをダウンロードして、perov datasetの形式と合わせる
- ALIGNNにおけるデータ分割の比率は、60000–5000–4239

In [3]:
dataset = jdata("megnet")
data = dataset[0]
data.keys()

Obtaining MEGNET-3D CFID dataset 69k...
Reference:https://pubs.acs.org/doi/10.1021/acs.chemmater.9b01294
Loading the zipfile...
Loading completed.


dict_keys(['id', 'desc', 'formula', 'e_hull', 'gap pbe', 'mu_b', 'elastic anisotropy', 'bulk modulus', 'shear modulus', 'atoms', 'e_form'])

In [4]:
megnet_data_path = '../data/megnet'
os.makedirs(megnet_data_path, exist_ok=True)

### perovskiteデータを探索し、ラベルをつける

In [9]:
import copy

matcher = StructureMatcher(stol=0.1, angle_tol=5, primitive_cell=True, scale=True)

In [10]:
perov_coords = [
    [0, 0, 0],         # B
    [0.5, 0.5, 0.5],   # A
    [0.5, 0, 0.5],     # X
    [0, 0.5, 0.5],     # X
    [0.5, 0.5, 0],     # X
]

def check_perovskite_unitcell(data, perov_coords):
    # ペロブスカイト構造っぽい候補の座標を、reference構造の座標と比較する
    data_dict = copy.deepcopy(data['atoms'])
    data_dict['elements'] = ['H']*5 # 座標の比較のみを行うため、元素はHに置き換え
    # --- Structureオブジェクトに変換 ---
    lattice = Lattice(data_dict['lattice_mat'])
    structure1 = Structure(lattice, data_dict['elements'], data_dict['coords'], coords_are_cartesian=data_dict['cartesian'])

    # 参照となる座標を持つ辞書を作成
    reference_pereovskie_dict = {
        'lattice_mat': data_dict['lattice_mat'],
        'coords':perov_coords,
        'elements': ['H']*5,
        'abc': data_dict['abc'],
        'angles': data_dict['angles'],
        'cartesian': data_dict['cartesian'],
        'props': data_dict['props'],
    }
    # --- Structureオブジェクトに変換 ---
    lattice = Lattice(reference_pereovskie_dict['lattice_mat'])
    structure2 = Structure(lattice, reference_pereovskie_dict['elements'], reference_pereovskie_dict['coords'], coords_are_cartesian=reference_pereovskie_dict['cartesian'])
    # 構造の比較
    return matcher.fit(structure1, structure2)


In [13]:
from typing import List, Tuple, Optional
from smact import Element
from tqdm.notebook import tqdm
from pymatgen.core import Element as Element_pmg
from collections import Counter
from typing import List, Tuple, Optional
from collections import Counter
import itertools

def element_count_grouped(species_list):
    counts = Counter(species_list)
    return [[counts[elem]] for elem in counts]

def elec_neutral_check_SUPER_COMMON(num_i:int, total:int, elements: List[str], stoichs: List[List[int]], return_all_ox_states:bool=False) -> Tuple[bool, Optional[Tuple[int]]]:
    """
    Check for electrical neutrality using PyMatGen icsd_oxidation_states method by evaluating possible oxidation states combinations.

    Args:
        num_i (int): Index of the structure (for tqdm display)
        total (int): Total number of structures. (for tqdm display)
        elements (List[str]): List of element symbols.
        stoichs (List[List[int]]): List of lists containing stoichiometries.
        return_all_ox_states (bool): Whether to return all possible oxidation states combinations.

    Returns:
        Tuple[bool, Optional[Tuple[int]]]: A tuple where the first element is a boolean indicating 
                                           whether the input is electrically neutral, and the second 
                                           element is a tuple of oxidation states that make it neutral 
                                           (or None if no neutral combination is found).

    Examples:
        >>> elec_neutral_check_SUPER_COMMON(5, 10, elements=['Ti', 'O'], stoichs=[[1], [2]])
        (True, , ['Ti', 'O', 'O'], (4, -2, -2)))
        >>> elec_neutral_check_SUPER_COMMON(5, 10, elements = ['Ti', 'Al', 'O'], stoichs = [[1],[1],[1]])
        (False, ['Ti', 'Al', 'O'], None)
        >>> elec_neutral_check_SUPER_COMMON(5, 10, elements=['He', 'O'], stoichs=[[1], [2]])
        (False, ['H', 'O', 'O'], None)
    """   
    all_elements = []
    for elem, stoi in zip(elements, stoichs):
        assert len(stoi) == 1
        all_elements.extend([elem]*stoi[0])
    ox_combos = [
        list(set(Element_pmg(elem).icsd_oxidation_states) & set(Element_pmg(elem).oxidation_states) & set(Element_pmg(elem).ionic_radii.keys()) & set(Element_pmg(elem).common_oxidation_states))
        for elem in all_elements    
    ]

    # check excluding non-oxidation state elements
    if any([len(ox) == 0 for ox in ox_combos]):
        return False, all_elements, None

    lengths = np.array([len(sublist) for sublist in ox_combos])
    product_of_lengths = np.prod(lengths)

    if return_all_ox_states:
        all_neutral_ox_states = []
        for ox_states in tqdm(itertools.product(*ox_combos), total=product_of_lengths,leave=False, desc=f"neutral check ({num_i+1}/{total}) by PMG"):
            if sum(ox_states) == 0:
                all_neutral_ox_states.append(ox_states)
        return len(all_neutral_ox_states)>0, all_elements, all_neutral_ox_states

    else:
        for ox_states in tqdm(itertools.product(*ox_combos), total=product_of_lengths,leave=False, desc=f"neutral check ({num_i+1}/{total}) by PMG"):
            if sum(ox_states) == 0:
                return True, all_elements, ox_states
            
        return False, all_elements, None

In [14]:


# 1分弱かかる(40s)
megnet_data_dict_list = []
for data in dataset:
    cif_string = create_cif_string(data['atoms'])
    formula = elements_to_formula(data['atoms']['elements'])
    num_data = len(data['atoms']['elements'])
    megnet_data_dict = {
        'gap': data['gap pbe'],
        'e_form':data['e_form'],
        'cif': cif_string,
        'formula': formula,
        '100more':num_data > 100,
        'tolerance': 0,
    }
    # 角度が90度に近いものを探す
    if (np.abs(np.array(data['atoms']['angles'])-90) < 5).all() and len(data['atoms']['coords']) == 5 and (data['atoms']['cartesian']==False):
        # perovskiteの候補
        if check_perovskite_unitcell(data, perov_coords):
            # 電気的中性を確認
            is_neutral, all_elements, ox_states = elec_neutral_check_SUPER_COMMON(0, 1, data['atoms']['elements'], element_count_grouped(data['atoms']['elements']), return_all_ox_states=True)
            if is_neutral:
                print(f"Found a candidate for perovskite with elements: {all_elements}, oxidation states: {ox_states}")
                # ここで、見つかった構造を保存するなどの処理を行うことができます
                ox_states_set = [set(ox) for ox in ox_states]
                if {-2, 2, 4} in ox_states_set:
                    idx = ox_states_set.index({-2, 2, 4}) 
                    Asite_ox, Bsite_ox, Xsite_ox = 2, 4, -2
                elif {-1, 1, 2} in ox_states_set:
                    idx = ox_states_set.index({-1, 1, 2})
                    Asite_ox, Bsite_ox, Xsite_ox = 1, 2, -1
                else:
                    continue
                    raise Exception('Perovskiteではない？')
                perov_ox_states = ox_states[idx]
                
                Asite_id = perov_ox_states.index(Asite_ox)
                Bsite_id = perov_ox_states.index(Bsite_ox)
                X_site_ids = list(range(5))
                X_site_ids.remove(Asite_id)
                X_site_ids.remove(Bsite_id)


                Asite_radii = Element_pmg(data['atoms']['elements'][Asite_id]).ionic_radii[Asite_ox]
                Bsite_radii = Element_pmg(data['atoms']['elements'][Bsite_id]).ionic_radii[Bsite_ox]
                Xsite1_radii = Element_pmg(data['atoms']['elements'][X_site_ids[0]]).ionic_radii[Xsite_ox]
                Xsite2_radii = Element_pmg(data['atoms']['elements'][X_site_ids[1]]).ionic_radii[Xsite_ox]
                Xsite3_radii = Element_pmg(data['atoms']['elements'][X_site_ids[2]]).ionic_radii[Xsite_ox]
                Xsite_radii = np.mean([Xsite1_radii, Xsite2_radii, Xsite3_radii])
                tolerance_val = (Asite_radii + Xsite_radii)/(np.sqrt(2)*(Bsite_radii + Xsite_radii))
                if 0.8 <= tolerance_val < 1.0:
                    megnet_data_dict['tolerance'] = 1.0

                
            else:
                print(f"Not electrically neutral: {all_elements}")
            clear_output(wait=True)

    megnet_data_dict_list.append(megnet_data_dict)


neutral check (1/1) by PMG:   0%|          | 0/1 [00:00<?, ?it/s]

Found a candidate for perovskite with elements: ['Ba', 'Ti', 'O', 'O', 'O'], oxidation states: [(2, 4, -2, -2, -2)]


In [15]:
df = pd.DataFrame(megnet_data_dict_list)
df['material_id'] = df.index
df['100more'] = df['100more'].astype(float)
df[['material_id','cif','gap', 'e_form', 'formula', '100more', 'tolerance']]

Unnamed: 0,material_id,cif,gap,e_form,formula,100more,tolerance
0,0,# generated using pymatgen\ndata_As\n_symmetry...,0.0000,0.107405,As2,0.0,0.0
1,1,# generated using pymatgen\ndata_Hf\n_symmetry...,0.0000,0.181111,Hf,0.0,0.0
2,2,# generated using pymatgen\ndata_BaTe\n_symmet...,1.5930,-1.790168,TeBa,0.0,0.0
3,3,# generated using pymatgen\ndata_Hf2S\n_symmet...,0.0000,-1.253224,S2Hf4,0.0,0.0
4,4,# generated using pymatgen\ndata_Nb4CoSi\n_sym...,0.0064,-0.343178,Si2Co2Nb8,0.0,0.0
...,...,...,...,...,...,...,...
69197,69197,# generated using pymatgen\ndata_WO3\n_symmetr...,1.1162,-2.182536,W2O6,0.0,0.0
69198,69198,# generated using pymatgen\ndata_NiO2\n_symmet...,1.5195,-0.549193,NiO2,0.0,0.0
69199,69199,# generated using pymatgen\ndata_P\n_symmetry_...,0.0000,0.141298,P4,0.0,0.0
69200,69200,# generated using pymatgen\ndata_WO3\n_symmetr...,1.3694,-2.184021,W8O24,0.0,0.0


In [16]:
stratify_label = df['100more'].astype(str) + '_' + df['tolerance'].astype(str)
sss = StratifiedShuffleSplit(n_splits=1, test_size=(df.shape[0]-60000)/df.shape[0], random_state=42)
for train_idx, test_idx in sss.split(df, stratify_label):
    train_df = df.iloc[train_idx].reset_index(drop=True)
    test_val_df = df.iloc[test_idx].reset_index(drop=True)

sss = StratifiedShuffleSplit(n_splits=1, test_size=(test_val_df.shape[0]-5000)/test_val_df.shape[0], random_state=42)
for val_idx, test_idx in sss.split(test_val_df, stratify_label.iloc[test_idx]):
    val_df = test_val_df.iloc[val_idx].reset_index(drop=True)
    test_df = test_val_df.iloc[test_idx].reset_index(drop=True)
print(train_df.shape, val_df.shape, test_df.shape)

(60000, 7) (5000, 7) (4202, 7)


In [17]:
# train, val, testに分割してデータを保存
train_df.to_csv(os.path.join(megnet_data_path, 'train.csv'), index=False)
val_df.to_csv(os.path.join(megnet_data_path, 'val.csv'), index=False)
test_df.to_csv(os.path.join(megnet_data_path, 'test.csv'), index=False)

In [18]:
print("tolerance が0.8 ~ 1.0のペロブスカイト構造の割合",(train_df['tolerance']==1.0).sum()/train_df.shape[0])

tolerance が0.8 ~ 1.0のペロブスカイト構造の割合 0.00125


In [19]:
perov_aug_train_df = pd.concat([
    train_df,pd.concat([train_df[train_df['tolerance']==1.0]]*350, axis=0)
],axis=0)
perov_aug_train_df.material_id = range(perov_aug_train_df.shape[0])
print("tolerance が0.8 ~ 1.0のペロブスカイト構造の割合",(perov_aug_train_df['tolerance']==1.0).sum()/perov_aug_train_df.shape[0])

tolerance が0.8 ~ 1.0のペロブスカイト構造の割合 0.30521739130434783


In [20]:
megnet_perov_data_path = '../data/megnet_perov'
os.makedirs(megnet_perov_data_path, exist_ok=True)
perov_aug_train_df.to_csv(os.path.join(megnet_perov_data_path, 'train.csv'), index=False)
val_df.to_csv(os.path.join(megnet_perov_data_path, 'val.csv'), index=False)
test_df.to_csv(os.path.join(megnet_perov_data_path, 'test.csv'), index=False)