In [1]:
# pip install -r requirements.txt

# pip install --upgrade numpy

# conda install -c conda-forge nglview 
import nglview as nv



In [113]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import yaml
import json

import tensorflow as tf

from pathlib import Path
from pymatgen.core import Structure
from sklearn.model_selection import train_test_split
from megnet.models import MEGNetModel
from megnet.data.crystal import CrystalGraph

import pymatgen
from collections import defaultdict

In [115]:
def read_pymatgen_dict(file):
    """Load a pymatgen ``Structure`` from a JSON file.

    Args:
        file: Path of a JSON file holding a structure dictionary.

    Returns:
        The object produced by ``Structure.from_dict`` for the parsed JSON.
    """
    with open(file, "r") as handle:
        payload = json.loads(handle.read())
    return Structure.from_dict(payload)


def prepare_dataset(dataset_path):
    """Assemble the structures and their target values into one DataFrame.

    Reads ``targets.csv`` (first column used as the index) and every ``*.json``
    file found in the ``structures`` sub-directory of ``dataset_path``.

    Args:
        dataset_path: Directory containing ``targets.csv`` and ``structures/``.

    Returns:
        DataFrame indexed by the structure file name without its extension,
        with a ``structures`` column holding the loaded objects and a
        ``targets`` column taken from ``targets.csv``.
    """
    dataset_path = Path(dataset_path)
    targets = pd.read_csv(dataset_path / "targets.csv", index_col=0)
    struct = {
        # ``stem`` removes only the extension. ``str.strip(".json")`` treats its
        # argument as a character set and would also eat leading/trailing
        # 'j', 's', 'o', 'n' and '.' characters that belong to the id itself.
        item.stem: read_pymatgen_dict(item)
        for item in (dataset_path / "structures").iterdir()
        # Skip stray entries (checkpoint folders, OS metadata files, ...).
        if item.suffix == ".json"
    }

    data = pd.DataFrame(columns=["structures"], index=struct.keys())
    data = data.assign(structures=struct.values(), targets=targets)

    return data

In [116]:
def decompose(structure):
    """Count how many sites of ``structure`` carry each species formula.

    Args:
        structure: Object exposing ``sites``; every site must provide
            ``species.formula``.

    Returns:
        ``defaultdict(int)`` mapping formula string to number of sites, with
        keys in order of first appearance.
    """
    counts = defaultdict(int)
    formulas = (site.species.formula for site in structure.sites)
    for formula in formulas:
        counts[formula] = counts[formula] + 1
    return counts

In [117]:
# Load the dataset, then attach each structure's per-species site counts
# and the number of distinct species it contains.
data = (
    prepare_dataset('data/dichalcogenides_public/')
    .assign(decomposition=lambda frame: frame['structures'].apply(decompose))
    .assign(len_of_decomposition=lambda frame: frame['decomposition'].apply(len))
)

data.head()

Unnamed: 0,structures,targets,decomposition,len_of_decomposition
6142031bee0a3fd43fb47e23,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,0.2754,"{'Mo1': 63, 'Se1': 2, 'S1': 126}",3
6141d46031cf3ef3d4a9eee8,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,0.2839,"{'Mo1': 63, 'Se1': 1, 'S1': 126}",3
614211354e27a1844a5f05b4,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,1.145,"{'Mo1': 63, 'W1': 1, 'Se1': 1, 'S1': 126}",4
614346254e27a1844a5f0a14,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,1.1405,"{'Mo1': 63, 'W1': 1, 'Se1': 1, 'S1': 126}",4
6141e2eb9cbada84a8676ab7,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,1.8092,"{'Mo1': 63, 'W1': 1, 'Se1': 2, 'S1': 126}",4


In [6]:
n = 100
# Keep only structures made of 63 Mo, 1 W and 126 S sites and draw ``n`` of them.
# A fixed ``random_state`` makes the draw reproducible, so the positional
# inspections in the following cells show the same rows after a kernel restart.
sample = data[data['decomposition'] == defaultdict(int, {'Mo1': 63, 'W1': 1, 'S1': 126})].sample(n, random_state=0)

In [7]:
i = 2
# ``sample`` is indexed by string labels, so select the i-th row by position
# explicitly instead of relying on the deprecated integer fallback of ``Series[i]``.
print(sample['targets'].iloc[i])
sample['structures'].iloc[i]

1.1041


Structure Summary
Lattice
    abc : 25.5225256 25.5225256 14.879004
 angles : 90.0 90.0 119.99999999999999
 volume : 8393.668021812642
      A : 25.5225256 0.0 1.5628039641098191e-15
      B : -12.761262799999994 22.10315553833868 1.5628039641098191e-15
      C : 0.0 0.0 14.879004
PeriodicSite: Mo (0.0000, 1.8419, 3.7198) [0.0417, 0.0833, 0.2500]
PeriodicSite: Mo (-1.5952, 4.6048, 3.7198) [0.0417, 0.2083, 0.2500]
PeriodicSite: Mo (-3.1903, 7.3677, 3.7198) [0.0417, 0.3333, 0.2500]
PeriodicSite: Mo (-4.7855, 10.1306, 3.7198) [0.0417, 0.4583, 0.2500]
PeriodicSite: Mo (-6.3806, 12.8935, 3.7198) [0.0417, 0.5833, 0.2500]
PeriodicSite: Mo (-7.9758, 15.6564, 3.7198) [0.0417, 0.7083, 0.2500]
PeriodicSite: Mo (-9.5709, 18.4193, 3.7198) [0.0417, 0.8333, 0.2500]
PeriodicSite: Mo (-11.1661, 21.1822, 3.7198) [0.0417, 0.9583, 0.2500]
PeriodicSite: Mo (3.1903, 1.8419, 3.7198) [0.1667, 0.0833, 0.2500]
PeriodicSite: Mo (1.5952, 4.6048, 3.7198) [0.1667, 0.2083, 0.2500]
PeriodicSite: Mo (0.0000, 7.3677, 3

In [27]:
i = 3
# ``sample`` is indexed by string labels, so select the i-th row by position
# explicitly instead of relying on the deprecated integer fallback of ``Series[i]``.
print(sample['targets'].iloc[i])
sample['structures'].iloc[i]

0.7035


Structure Summary
Lattice
    abc : 25.5225256 25.5225256 14.879004
 angles : 90.0 90.0 119.99999999999999
 volume : 8393.668021812642
      A : 25.5225256 0.0 1.5628039641098191e-15
      B : -12.761262799999994 22.10315553833868 1.5628039641098191e-15
      C : 0.0 0.0 14.879004
PeriodicSite: Mo (0.0000, 1.8419, 3.7198) [0.0417, 0.0833, 0.2500]
PeriodicSite: Mo (-1.5952, 4.6048, 3.7198) [0.0417, 0.2083, 0.2500]
PeriodicSite: Mo (-3.1903, 7.3677, 3.7198) [0.0417, 0.3333, 0.2500]
PeriodicSite: Mo (-4.7855, 10.1306, 3.7198) [0.0417, 0.4583, 0.2500]
PeriodicSite: Mo (-6.3806, 12.8935, 3.7198) [0.0417, 0.5833, 0.2500]
PeriodicSite: Mo (-7.9758, 15.6564, 3.7198) [0.0417, 0.7083, 0.2500]
PeriodicSite: Mo (-9.5709, 18.4193, 3.7198) [0.0417, 0.8333, 0.2500]
PeriodicSite: Mo (-11.1661, 21.1822, 3.7198) [0.0417, 0.9583, 0.2500]
PeriodicSite: Mo (3.1903, 1.8419, 3.7198) [0.1667, 0.0833, 0.2500]
PeriodicSite: Mo (1.5952, 4.6048, 3.7198) [0.1667, 0.2083, 0.2500]
PeriodicSite: Mo (0.0000, 7.3677, 3

In [69]:
# Position of the W site in the site list of the 4th sampled structure
# (row chosen by position via ``.iloc``; the index itself holds string labels).
[site.species.formula for site in sample['structures'].iloc[3].sites].index('W1')

63

In [None]:
# def get_W1(sites_list):
#     pos = list(map(lambda x: x.species.formula, sample.structures[i].sites)).index('W1')
#     return sites_list[pos]

# Посмотрим, как выглядят структуры состава {'Mo1': 63, 'W1': 1, 'S1': 126}

In [27]:
i = 12
# Rows are addressed by position with ``.iloc``: the index holds string labels,
# and integer keys on ``Series[...]`` only work through a deprecated fallback.
print(sample['targets'].iloc[i])
view = nv.show_pymatgen(sample['structures'].iloc[i])
view.add_unitcell()
view

1.0839


NGLWidget()

In [28]:
i = 3
# Rows are addressed by position with ``.iloc``: the index holds string labels,
# and integer keys on ``Series[...]`` only work through a deprecated fallback.
print(sample['targets'].iloc[i])
view = nv.show_pymatgen(sample['structures'].iloc[i])
view.add_unitcell()
view

0.6527999999999999


NGLWidget()

In [30]:
i = 5
# Rows are addressed by position with ``.iloc``: the index holds string labels,
# and integer keys on ``Series[...]`` only work through a deprecated fallback.
print(sample['targets'].iloc[i])
view = nv.show_pymatgen(sample['structures'].iloc[i])
view.add_unitcell()
view

1.039


NGLWidget()

In [41]:
i = 14
# Rows are addressed by position with ``.iloc``: the index holds string labels,
# and integer keys on ``Series[...]`` only work through a deprecated fallback.
print(sample['targets'].iloc[i])
view = nv.show_pymatgen(sample['structures'].iloc[i])
view.add_unitcell()
view

0.9602


NGLWidget()

## Вывод: есть синие атомы и отсутствующие жёлтые (2); хочется извлечь их позиции и строить модель по ним

Число позиций ограничено, плюс есть 2 симметрии: при них энергия не меняется.

# Построим идеальную структуру

In [55]:
from tqdm.notebook import tqdm

In [56]:
# Work on every structure made of 63 Mo, 1 W and 126 S sites.
sample = data[data['decomposition'] == defaultdict(int, {'Mo1': 63, 'W1': 1, 'S1': 126})].copy()

# Accumulates the distinct non-W sites seen across all selected structures.
ideal_structure_sites = []


def extract(structure):
    """Add every not-yet-collected non-W site of ``structure`` to ``ideal_structure_sites``."""
    for candidate in structure.sites:
        if candidate.species.formula == 'W1':
            continue
        if candidate in ideal_structure_sites:
            continue
        ideal_structure_sites.append(candidate)


for structure in tqdm(sample['structures'].values):
    extract(structure)

# Build one structure out of the union of collected sites.
ideal_structure = Structure.from_sites(ideal_structure_sites)

  0%|          | 0/380 [00:00<?, ?it/s]

In [137]:
# Number of collected ideal sites vs. sites in the first selected structure
# (first row taken by position; the index holds string labels).
len(ideal_structure_sites), len(sample['structures'].iloc[0].sites)

(192, 190)

In [132]:
# Per-species site counts of the structure assembled from the collected non-W sites.
decompose(ideal_structure)

defaultdict(int, {'Mo1': 64, 'S1': 128})

In [131]:
# Per-species site counts of the first selected structure
# (first row taken by position; the index holds string labels).
decompose(sample['structures'].iloc[0])

defaultdict(int, {'Mo1': 63, 'W1': 1, 'S1': 126})

In [136]:
# Render ideal_structure in an nglview widget and draw its unit cell.
view = nv.show_pymatgen(ideal_structure)
view.add_unitcell()
view

NGLWidget()

In [135]:
# One yellow atom is missing on the side and another one slightly to the right.
# The first row is taken by position (``.iloc``) because the index holds string labels.
view = nv.show_pymatgen(sample['structures'].iloc[0])
view.add_unitcell()
view

NGLWidget()

# Получение датасета

In [153]:
def find_differ_sites(structure_sites, ideal_structure_sites, w_index=63):
    """Return the W site followed by the ideal sites missing from a structure.

    Args:
        structure_sites: Sites of the structure being analysed.
        ideal_structure_sites: Sites of the reference (defect-free) structure.
        w_index: Position of the W site inside ``structure_sites``. Defaults
            to 63, where W appears to sit in this dataset; pass another
            index if that does not hold.

    Returns:
        A list whose first element is ``structure_sites[w_index]`` and whose
        remaining elements are the ideal sites that are absent from
        ``structure_sites`` and whose coordinates differ from the W site's.
    """
    site_W1 = structure_sites[w_index]
    # The result list must be the one that is appended to and returned; the
    # original created it as ``missed_sites_coords`` and then used
    # ``missed_sites``, which fails with NameError on a fresh kernel.
    missed_sites = [site_W1]
    for x in ideal_structure_sites:
        # The ideal site located at the W position is not a vacancy: skip it.
        if x not in structure_sites and (x.coords != site_W1.coords).any():
            missed_sites.append(x)
    return missed_sites

In [154]:
# W site and missing ideal sites of the first selected structure
# (first row taken by position; the index holds string labels).
find_differ_sites(sample['structures'].iloc[0].sites, ideal_structure_sites)

[PeriodicSite: W (-1.5952, 10.1306, 3.7198) [0.1667, 0.4583, 0.2500],
 PeriodicSite: S (12.7613, 3.6839, 5.2846) [0.5833, 0.1667, 0.3552],
 PeriodicSite: S (17.5467, 0.9210, 5.2846) [0.7083, 0.0417, 0.3552]]

## Как составить датасет? Атомы S неотличимы друг от друга, а атом W отличим

In [157]:
# Start again from every structure made of 63 Mo, 1 W and 126 S sites.
sample = data[data['decomposition'] == defaultdict(int, {'Mo1': 63, 'W1': 1, 'S1': 126})].copy()

sample['differ_sites'] = sample['structures'].apply(
    lambda structure: find_differ_sites(structure.sites, ideal_structure_sites)
)

# The first entry of ``differ_sites`` is the W site: expose its three
# ``coords`` components, then its three ``frac_coords`` components, as columns.
for i in range(3):
    sample[f'W_coords_{i}'] = sample['differ_sites'].apply(lambda sites: sites[0].coords[i])

for i in range(3):
    sample[f'W_frac_coords_{i}'] = sample['differ_sites'].apply(lambda sites: sites[0].frac_coords[i])

sample.head()

Unnamed: 0,structures,targets,decomposition,len_of_decomposition,differ_sites,W_coords_0,W_coords_1,W_coords_2,W_frac_coords_0,W_frac_coords_1,W_frac_coords_2
61428c13baaf234b35290702,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,1.0033,"{'Mo1': 63, 'W1': 1, 'S1': 126}",3,"[[-1.59515772 10.13061288 3.719751 ] W, [12....",-1.595158,10.130613,3.719751,0.166667,0.458333,0.25
614282ffbaaf234b352906f4,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,0.653,"{'Mo1': 63, 'W1': 1, 'S1': 126}",3,"[[-9.57094697 18.41929621 3.719751 ] W, [14....",-9.570947,18.419296,3.719751,0.041667,0.833333,0.25
6141f6aabaaf234b3529052a,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,1.1048,"{'Mo1': 63, 'W1': 1, 'S1': 126}",3,"[[-1.59515772 21.18219065 3.719751 ] W, [ 9....",-1.595158,21.182191,3.719751,0.416667,0.958333,0.25
6141d415baaf234b352902f6,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,0.6528,"{'Mo1': 63, 'W1': 1, 'S1': 126}",3,"[[15.95157863 12.89350732 3.719751 ] W, [ 4....",15.951579,12.893507,3.719751,0.916667,0.583333,0.25
61422120baaf234b35290614,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,1.1063,"{'Mo1': 63, 'W1': 1, 'S1': 126}",3,"[[12.76126293 18.41929621 3.719751 ] W, [ 3....",12.761263,18.419296,3.719751,0.916667,0.833333,0.25


In [158]:
from sklearn.model_selection import cross_validate
from sklearn.neighbors import KNeighborsRegressor

from sklearn.metrics import make_scorer

In [159]:
def my_custom_loss_func(y_true, y_pred, tolerance=0.02):
    """Fraction of predictions whose absolute error is below ``tolerance``.

    Despite its name this is a score: larger values mean more predictions
    fall within the tolerance.

    Args:
        y_true: Reference values (any array-like).
        y_pred: Predicted values (array-like of the same shape).
        tolerance: Absolute error strictly below which a prediction counts
            as a hit. Defaults to 0.02.

    Returns:
        Mean of the element-wise ``|y_true - y_pred| < tolerance`` indicator.
    """
    # Convert to arrays so the comparison is positional (no pandas index
    # alignment) and plain Python sequences are accepted as well.
    abs_error = np.abs(np.asarray(y_true) - np.asarray(y_pred))
    return np.mean(abs_error < tolerance)

# Wrap the metric as a scorer object; greater_is_better=True declares that higher values are better.
score = make_scorer(my_custom_loss_func, greater_is_better=True)

In [162]:
knn = KNeighborsRegressor(n_neighbors=2)

# Name the feature columns explicitly: slicing ``sample.columns[-6:]`` silently
# picks different columns as soon as another column is appended to ``sample``.
feature_columns = [f'W_coords_{axis}' for axis in range(3)] + [f'W_frac_coords_{axis}' for axis in range(3)]

cv_results = cross_validate(knn, sample[feature_columns], sample['targets'], cv=3, scoring=score)
cv_results['test_score']

array([0.40944882, 0.44094488, 0.37301587])