In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import yaml
import json

import tensorflow as tf

from pathlib import Path
from pymatgen.core import Structure
from sklearn.model_selection import train_test_split
from megnet.models import MEGNetModel
from megnet.data.crystal import CrystalGraph

import pymatgen
from collections import defaultdict
from tqdm.notebook import tqdm

In [2]:
def read_pymatgen_dict(file):
    """Load a JSON file and build a pymatgen ``Structure`` from its contents.

    Parameters
    ----------
    file : path-like
        Location of a JSON document; it is opened in text mode for reading.

    Returns
    -------
    Structure
        The result of ``Structure.from_dict`` applied to the parsed JSON.
    """
    with open(file, "r") as handle:
        payload = json.load(handle)
    structure = Structure.from_dict(payload)
    return structure


def prepare_dataset(dataset_path):
    """Assemble a DataFrame of structures and targets from a dataset directory.

    The directory is expected to contain ``targets.csv`` (read with its first
    column as the index) and a ``structures`` sub-directory holding one JSON
    file per structure.

    Parameters
    ----------
    dataset_path : str or Path
        Root directory of the dataset.

    Returns
    -------
    pandas.DataFrame
        Indexed by structure file name without its extension, with a
        ``structures`` column (objects returned by ``read_pymatgen_dict``) and
        a ``targets`` column taken from ``targets.csv``.
    """
    dataset_path = Path(dataset_path)
    targets = pd.read_csv(dataset_path / "targets.csv", index_col=0)
    struct = {
        # ``Path.stem`` drops only the final suffix. ``str.strip(".json")``
        # would instead strip any of the characters ".", "j", "s", "o", "n"
        # from BOTH ends of the name and could mangle the identifier.
        item.stem: read_pymatgen_dict(item)
        for item in (dataset_path / "structures").iterdir()
    }

    data = pd.DataFrame(columns=["structures"], index=struct.keys())
    data = data.assign(structures=struct.values(), targets=targets)

    return data

In [3]:
def decompose(structure):
    """Count the sites of ``structure`` grouped by species formula.

    Parameters
    ----------
    structure : object exposing ``.sites``
        Each site must provide ``site.species.formula``.

    Returns
    -------
    collections.defaultdict
        Maps each formula string to the number of sites carrying it, with keys
        in order of first appearance. Missing keys default to ``0``.
    """
    counts = defaultdict(int)
    formulas = (site.species.formula for site in structure.sites)
    for formula in formulas:
        counts[formula] = counts[formula] + 1
    return counts

In [4]:
# Load the dataset, then annotate every structure with its per-species site
# counts and with the number of distinct species it contains.
data = prepare_dataset('data/dichalcogenides_public/')
data['decomposition'] = data['structures'].apply(decompose)
data['len_of_decomposition'] = data['decomposition'].map(len)

data.head()

Unnamed: 0,structures,targets,decomposition,len_of_decomposition
6142031bee0a3fd43fb47e23,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,0.2754,"{'Mo1': 63, 'Se1': 2, 'S1': 126}",3
6141d46031cf3ef3d4a9eee8,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,0.2839,"{'Mo1': 63, 'Se1': 1, 'S1': 126}",3
614211354e27a1844a5f05b4,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,1.145,"{'Mo1': 63, 'W1': 1, 'Se1': 1, 'S1': 126}",4
614346254e27a1844a5f0a14,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,1.1405,"{'Mo1': 63, 'W1': 1, 'Se1': 1, 'S1': 126}",4
6141e2eb9cbada84a8676ab7,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,1.8092,"{'Mo1': 63, 'W1': 1, 'Se1': 2, 'S1': 126}",4


In [5]:
# Rows whose composition is exactly 63 Mo, 1 W and 126 S sites.
w_only_composition = defaultdict(int, {'Mo1': 63, 'W1': 1, 'S1': 126})
sample = data[data['decomposition'] == w_only_composition].copy()

# Accumulator filled by ``extract``; reset here so re-running the cell starts clean.
ideal_structure_sites = []


def extract(structure):
    """Append every not-yet-collected non-W site of ``structure`` to ``ideal_structure_sites``."""
    for site in structure.sites:
        if site.species.formula == 'W1':
            continue
        if site in ideal_structure_sites:
            continue
        ideal_structure_sites.append(site)


for structure in tqdm(sample['structures'].values):
    extract(structure)

# Reference structure built from the union of the collected sites.
ideal_structure = Structure.from_sites(ideal_structure_sites)

  0%|          | 0/380 [00:00<?, ?it/s]

In [6]:
def find_differ_sites(structure_sites, ideal_structure_sites):
    """Return the sites in which a structure deviates from the ideal one.

    For every ideal site, the structure is searched for a site with exactly
    equal coordinates:

    * if one is found and its species formula differs, the structure's site
      is reported (a different species occupies that position);
    * if none is found, the ideal site itself is reported (the position is
      unoccupied in the structure);
    * if one is found with the same formula, nothing is reported.

    Parameters
    ----------
    structure_sites : iterable
        Sites of the structure under inspection; each exposes ``.coords`` and
        ``.species.formula``.
    ideal_structure_sites : iterable
        Sites of the reference structure, with the same attributes.

    Returns
    -------
    list
        Differing sites, ordered like ``ideal_structure_sites``.
    """

    def _coords_key(site):
        # Exact float equality is preserved: two coordinate vectors compare
        # equal element-wise iff their tuples of floats are equal.
        return tuple(float(c) for c in site.coords)

    # Index the structure once instead of rescanning it for every ideal site.
    # ``setdefault`` keeps the FIRST site at a position, matching a linear scan
    # that stops at the first coordinate match.
    first_site_at = {}
    for site in structure_sites:
        first_site_at.setdefault(_coords_key(site), site)

    differ_sites = []
    for ideal_site in ideal_structure_sites:
        key = _coords_key(ideal_site)
        if key not in first_site_at:
            # didn't find a site at this position
            differ_sites.append(ideal_site)
            continue
        occupant = first_site_at[key]
        if ideal_site.species.formula != occupant.species.formula:
            # another species sits at this position
            differ_sites.append(occupant)

    return differ_sites

In [7]:
# For each structure keep only the sites that differ from the ideal lattice,
# then wrap those sites into a Structure of their own.
data['representative'] = data['structures'].apply(
    lambda structure: find_differ_sites(structure.sites, ideal_structure_sites)
)
data['representative_str'] = data['representative'].apply(Structure.from_sites)
data.head()

Unnamed: 0,structures,targets,decomposition,len_of_decomposition,representative,representative_str
6142031bee0a3fd43fb47e23,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,0.2754,"{'Mo1': 63, 'Se1': 2, 'S1': 126}",3,"[[12.76126293 1.84192955 3.719751 ] Mo, [12...","[[12.76126293 1.84192955 3.719751 ] Mo, [12..."
6141d46031cf3ef3d4a9eee8,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,0.2839,"{'Mo1': 63, 'Se1': 1, 'S1': 126}",3,"[[ 3.19031583 12.89350732 3.719751 ] Mo, [ 6...","[[ 3.19031583 12.89350732 3.719751 ] Mo, [ 6..."
614211354e27a1844a5f05b4,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,1.145,"{'Mo1': 63, 'W1': 1, 'Se1': 1, 'S1': 126}",4,"[[-3.19031561 7.36771851 3.719751 ] W, [-3....","[[-3.19031561 7.36771851 3.719751 ] W, [-3...."
614346254e27a1844a5f0a14,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,1.1405,"{'Mo1': 63, 'W1': 1, 'Se1': 1, 'S1': 126}",4,"[[11.16610508 4.604824 3.719751 ] W, [20....","[[11.16610508 4.604824 3.719751 ] W, [20...."
6141e2eb9cbada84a8676ab7,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,1.8092,"{'Mo1': 63, 'W1': 1, 'Se1': 2, 'S1': 126}",4,"[[ 7.97578938 10.13061288 3.719751 ] W, [7.9...","[[ 7.97578938 10.13061288 3.719751 ] W, [7.9..."


In [11]:
# Sanity check that nothing was corrupted and that the classes were neither lost nor mixed up (the point counts are unchanged):

# data['representative_decomposition'] = data['representative_str'].apply(decompose)
# data['representative_decomposition'].value_counts()
# data['decomposition'].value_counts()

In [9]:
# One boolean mask per composition class, in the order mask1..mask5.
mask1, mask2, mask3, mask4, mask5 = (
    data['decomposition'] == defaultdict(int, composition)
    for composition in (
        {'Mo1': 63, 'W1': 1, 'Se1': 1, 'S1': 126},
        {'Mo1': 63, 'Se1': 1, 'S1': 126},
        {'Mo1': 63, 'W1': 1, 'S1': 126},
        {'Mo1': 63, 'Se1': 2, 'S1': 126},
        {'Mo1': 63, 'W1': 1, 'Se1': 2, 'S1': 126},
    )
)