In [1]:
import pickle
import gzip
import json
import pandas as pd

import numpy as np
import networkx as nx
from tqdm import tqdm
from joblib import Parallel, delayed

In [2]:
!pip freeze

argon2-cffi @ file:///tmp/build/80754af9/argon2-cffi_1613037499734/work
ase==3.22.1
async-generator @ file:///home/ktietz/src/ci/async_generator_1611927993394/work
attrs @ file:///opt/conda/conda-bld/attrs_1642510447205/work
backcall @ file:///home/ktietz/src/ci/backcall_1611930011877/work
bleach @ file:///opt/conda/conda-bld/bleach_1641577558959/work
certifi==2021.10.8
cffi @ file:///opt/conda/conda-bld/cffi_1642701102775/work
charset-normalizer==2.0.11
cycler==0.11.0
Cython==0.29.27
debugpy @ file:///tmp/build/80754af9/debugpy_1637091799509/work
decorator @ file:///opt/conda/conda-bld/decorator_1643638310831/work
defusedxml @ file:///tmp/build/80754af9/defusedxml_1615228127516/work
entrypoints==0.3
fonttools==4.29.1
future==0.18.2
idna==3.3
importlib-metadata @ file:///tmp/build/80754af9/importlib-metadata_1638542770237/work
ipykernel @ file:///tmp/build/80754af9/ipykernel_1633534655931/work/dist/ipykernel-6.4.1-py3-none-any.whl
ipython @ file:///tmp/build/80754af

In [2]:
from pymatgen.transformations.site_transformations import (
    ReplaceSiteSpeciesTransformation,
    InsertSitesTransformation,
    RemoveSitesTransformation
)
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer, SpacegroupOperations
from pymatgen.core.sites import PeriodicSite



In [3]:
with gzip.open('data.pickle.gz', 'rb') as fh:
    data = pickle.load(fh, )

In [4]:
data.shape

(5933, 11)

In [5]:
data["idx"] = np.arange(len(data))

In [6]:
data.head()

Unnamed: 0_level_0,descriptor_id,energy,energy_per_atom,fermi_level,homo,lumo,initial_structure,defect_representation,formation_energy,formation_energy_per_site,band_gap,idx
_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1
6141cf0efbfd4bd9ab2c2f7e,6141cf0efbfd4bd9ab2c2f7c,-1391.3404,-7.284505,-0.199707,-0.6754,0.4698,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,[[-7.98855051 17.50569919 5.28204642] X0+],2.6457,2.6457,1.1452,0
6141cf0f51c1cbd9654b8870,6141cf0e51c1cbd9654b886e,-1384.5528,-7.28712,-0.220627,-0.6852,0.3991,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,"[[14.34365939 6.45412142 2.15745558] X0+, [9...",5.3063,2.65315,1.0843,1
6141cf0fe689ecc4c43cdd4b,6141cf0fe689ecc4c43cdd49,-1397.1961,-7.277063,-0.183537,-0.6931,1.1102,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,[[ 4.78547342 17.49833154 2.15486663] Se],0.279,0.279,1.8033,2
6141cf10b842c2e72e2f2d44,6141cf10b842c2e72e2f2d42,-1396.2576,-7.272175,-0.179802,-0.6916,1.1179,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,"[[ 9.57094697 20.26122598 2.15486663] Se, [20...",0.5795,0.28975,1.8095,3
6141cf1051c1cbd9654b8872,6141cf0e51c1cbd9654b886e,-1384.5327,-7.287014,-0.21319,-0.6718,0.4384,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,"[[ 7.96302799 17.50569919 2.15745558] X0+, [-...",5.3264,2.6632,1.1102,4


In [7]:
data.energy_per_atom.min()

-7.297261052631578

In [8]:
def get_element(defect_site):
    return defect_site.as_dict()["species"][0]["element"]

In [9]:
def get_layer(defect_site, layers_coords, atol=1e-02):
    """
    layers_coords - array/list of len 3
    output:
    -1 - S (bottom)
     0 - Mo
     1 - S (top)
    """
    z = defect_site.frac_coords[2]
    is_close = np.isclose(layers_coords, z, atol)
    assert np.max(is_close) != 0
    
    return np.argmax(np.isclose(layers_coords, z, atol)) - 1

In [10]:
def build_layers_coords(structure):
    zs = [site.frac_coords[2] for site in structure.sites]
    layer_coords = np.zeros(3)
    layer_coords[0], layer_coords[2] = np.min(zs), np.max(zs)
    layer_coords[1] = (layer_coords[0] + layer_coords[2]) / 2
    return layer_coords

In [11]:
s0 = data.iloc[10].initial_structure

In [12]:
layers_coords = build_layers_coords(s0)
# [0.1448, 0.2500, 0.3552]
layers_coords

array([0.144826, 0.25    , 0.355174])

In [13]:
data.iloc[19].defect_representation

Structure Summary
Lattice
    abc : 25.5225256 25.5225256 14.879004
 angles : 90.0 90.0 119.99999999999999
 volume : 8393.668021812642
      A : 25.5225256 0.0 1.5628039641098191e-15
      B : -12.761262799999994 22.10315553833868 1.5628039641098191e-15
      C : 0.0 0.0 14.879004
PeriodicSite: X0+ (14.3437, 0.9283, 2.1575) [0.5830, 0.0420, 0.1450]
PeriodicSite: Se (3.1903, 20.2612, 2.1549) [0.5833, 0.9167, 0.1448]

In [14]:
get_layer(data.iloc[19].defect_representation.sites[1], layers_coords)

-1

In [20]:
def build_suffix(layer):
    return "bottom" if layer == -1 else "top"

def build_suffix_Svac_Ssub(elements, layers):
    """
    both arguments of len 2
    determine S-layers where vacancy and substition are
    """
    assert len(elements) == 2
    assert len(layers) == 2
    
    vac_idx = elements.index("X")
    suffix_vac = build_suffix(layers[vac_idx])
    suffix_sub = build_suffix(layers[1-vac_idx])
    return f"X_{suffix_vac}_Se_{suffix_sub}"

def build_suffix_2_defects(elements, layers):
    """
    both arguments of len 2
    determine S-defects are on the same layer and which one
    """
    assert len(elements) == 2
    assert len(layers) == 2
    
    if layers[0] == layers[1]:
        return "same"
    return "diff" 
    
def classify_X(elements, layers):
    number, suffix = "", ""
    if len(elements) == 1:
        number = 3 if elements[0] == "Se" else 1
    else:
        suffix = build_suffix_2_defects(elements, layers)
        if "Se" not in elements:
            number = 2
        elif "X" not in elements:
            number = 4         
        else: 
            number = 5
    return number, suffix
def classify_V_or_S(elements, layers):
    number, suffix = "", ""
    if len(elements) == 1:
        number = 1
    else:
        mo_w_idx = layers.index(0)
        number, suffix = classify_X(elements[:mo_w_idx] + elements[mo_w_idx+1:],
                                    layers[:mo_w_idx] + layers[mo_w_idx+1:])
        number += 1
    return number, suffix

In [21]:
def classify(defect_representation, layers_coords, atol=1e-02):
    elements = [get_element(site) for site in defect_representation.sites]
    layers = [get_layer(site, layers_coords) for site in defect_representation.sites]
    prefix = ""
    if 0 not in layers:
        prefix = "X"
        number, suffix = classify_X(elements, layers)
    elif "W" in elements:
        prefix = "S"
        number, suffix = classify_V_or_S(elements, layers)
    else:
        number,suffix = classify_V_or_S(elements, layers)
        prefix = "V"
    return f"{prefix}{number}_{suffix}" if suffix else f"{prefix}{number}"

In [22]:
classify(data.iloc[19].defect_representation, layers_coords)

'X5_same'

In [23]:
classify(data.iloc[1].defect_representation, layers_coords)

'X2_diff'

In [24]:
data['subgroup'] = data.defect_representation.apply(lambda x: classify(x, layers_coords))

In [25]:
data.head()

Unnamed: 0_level_0,descriptor_id,energy,energy_per_atom,fermi_level,homo,lumo,initial_structure,defect_representation,formation_energy,formation_energy_per_site,band_gap,idx,subgroup
_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1
6141cf0efbfd4bd9ab2c2f7e,6141cf0efbfd4bd9ab2c2f7c,-1391.3404,-7.284505,-0.199707,-0.6754,0.4698,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,[[-7.98855051 17.50569919 5.28204642] X0+],2.6457,2.6457,1.1452,0,X1
6141cf0f51c1cbd9654b8870,6141cf0e51c1cbd9654b886e,-1384.5528,-7.28712,-0.220627,-0.6852,0.3991,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,"[[14.34365939 6.45412142 2.15745558] X0+, [9...",5.3063,2.65315,1.0843,1,X2_diff
6141cf0fe689ecc4c43cdd4b,6141cf0fe689ecc4c43cdd49,-1397.1961,-7.277063,-0.183537,-0.6931,1.1102,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,[[ 4.78547342 17.49833154 2.15486663] Se],0.279,0.279,1.8033,2,X3
6141cf10b842c2e72e2f2d44,6141cf10b842c2e72e2f2d42,-1396.2576,-7.272175,-0.179802,-0.6916,1.1179,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,"[[ 9.57094697 20.26122598 2.15486663] Se, [20...",0.5795,0.28975,1.8095,3,X4_diff
6141cf1051c1cbd9654b8872,6141cf0e51c1cbd9654b886e,-1384.5327,-7.287014,-0.21319,-0.6718,0.4384,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,"[[ 7.96302799 17.50569919 2.15745558] X0+, [-...",5.3264,2.6632,1.1102,4,X2_diff


In [26]:
counts = data.subgroup.value_counts()
counts.sort_index()

S1           1
S2          15
S3_diff    379
S3_same    364
S4          15
S5_diff    379
S5_same    364
S6_diff    715
S6_same    700
V1           1
V2          15
V3_diff    379
V3_same    364
V4          15
V5_diff    379
V5_same    364
V6_diff    715
V6_same    700
X1           1
X2_diff     10
X2_same      9
X3           1
X4_diff     10
X4_same      9
X5_diff     15
X5_same     14
Name: subgroup, dtype: int64

In [27]:
correct_groups = ["X1", "X2", "X3","X4","X5",
                  "V1", "V2", "V3", "V4", "V5", "V6", 
                  "S1", "S2", "S3", "S4", "S5", "S6"]
correct_groups_count = [1, 19, 1, 19, 29,
                        1, 15, 743, 15, 743, 1415,
                        1, 15, 743, 15, 743, 1415]
correct_groups_series = pd.Series(correct_groups_count, correct_groups)

In [28]:
data['group'] = data.subgroup.apply(lambda x: x[:2])

In [29]:
pd.testing.assert_series_equal(correct_groups_series.sort_index(),
                               data.group.value_counts().sort_index(),
                               check_names=False)

In [30]:
data.head(10)

Unnamed: 0_level_0,descriptor_id,energy,energy_per_atom,fermi_level,homo,lumo,initial_structure,defect_representation,formation_energy,formation_energy_per_site,band_gap,idx,subgroup,group
_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1
6141cf0efbfd4bd9ab2c2f7e,6141cf0efbfd4bd9ab2c2f7c,-1391.3404,-7.284505,-0.199707,-0.6754,0.4698,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,[[-7.98855051 17.50569919 5.28204642] X0+],2.6457,2.6457,1.1452,0,X1,X1
6141cf0f51c1cbd9654b8870,6141cf0e51c1cbd9654b886e,-1384.5528,-7.28712,-0.220627,-0.6852,0.3991,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,"[[14.34365939 6.45412142 2.15745558] X0+, [9...",5.3063,2.65315,1.0843,1,X2_diff,X2
6141cf0fe689ecc4c43cdd4b,6141cf0fe689ecc4c43cdd49,-1397.1961,-7.277063,-0.183537,-0.6931,1.1102,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,[[ 4.78547342 17.49833154 2.15486663] Se],0.279,0.279,1.8033,2,X3,X3
6141cf10b842c2e72e2f2d44,6141cf10b842c2e72e2f2d42,-1396.2576,-7.272175,-0.179802,-0.6916,1.1179,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,"[[ 9.57094697 20.26122598 2.15486663] Se, [20...",0.5795,0.28975,1.8095,3,X4_diff,X4
6141cf1051c1cbd9654b8872,6141cf0e51c1cbd9654b886e,-1384.5327,-7.287014,-0.21319,-0.6718,0.4384,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,"[[ 7.96302799 17.50569919 2.15745558] X0+, [-...",5.3264,2.6632,1.1102,4,X2_diff,X2
6141cf10b842c2e72e2f2d46,6141cf10b842c2e72e2f2d42,-1396.2563,-7.272168,-0.180389,-0.6915,1.1178,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,"[[-6.38063149 14.73543703 2.15486663] Se, [20...",0.5808,0.2904,1.8093,5,X4_same,X4
6141cf11cc0e69a0cf28ab35,6141cf10cc0e69a0cf28ab33,-1390.4044,-7.279604,-0.200633,-0.6706,0.4778,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,"[[ 3.19031561 14.73543703 2.15486663] Se, [4....",2.9437,1.47185,1.1484,6,X5_diff,X5
6141cf11b842c2e72e2f2d48,6141cf10b842c2e72e2f2d42,-1396.2481,-7.272126,-0.179317,-0.691,1.1158,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,"[[14.35642052 0.92096489 5.28463537] Se, [17...",0.589,0.2945,1.8068,7,X4_same,X4
6141cf11ae4fb853db2e3f14,6141cf11ae4fb853db2e3f12,-1380.0584,-7.225437,-0.310205,-0.4902,-0.1302,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,[[12.77402406 12.88613968 3.719751 ] X0+],7.1215,7.1215,0.36,8,V1,V1
6141cf11cc0e69a0cf28ab37,6141cf10cc0e69a0cf28ab33,-1390.4006,-7.279584,-0.200289,-0.6694,0.4781,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,"[[1.59515772 6.44675377 2.15486663] Se, [-6.39...",2.9475,1.47375,1.1475,9,X5_diff,X5


In [27]:
group0 = counts.index[0]
data0 = data[data.group == group0]

In [28]:
group0

'V5'

In [29]:
data0.iloc[1].defect_representation

Structure Summary
Lattice
    abc : 25.5225256 25.5225256 14.879004
 angles : 90.0 90.0 119.99999999999999
 volume : 8393.668021812642
      A : 25.5225256 0.0 1.5628039641098191e-15
      B : -12.761262799999994 22.10315553833868 1.5628039641098191e-15
      C : 0.0 0.0 14.879004
PeriodicSite: X0+ (15.9643, 7.3604, 3.7198) [0.7920, 0.3330, 0.2500]
PeriodicSite: Se (6.3806, 14.7354, 5.2846) [0.5833, 0.6667, 0.3552]
PeriodicSite: Se (20.7371, 6.4468, 5.2846) [0.9583, 0.2917, 0.3552]

In [31]:
j = s0.to_json()

In [32]:
with open("s0.json", "w") as fh:
    fh.write(j)

In [33]:
import nglview



In [34]:
# https://gist.github.com/lan496/3f60b6474750a6fd2b4237e820fbfea4
def plot3d(structure, spacefill=True, show_axes=True):
    from itertools import product
    from pymatgen.core import Structure
    from pymatgen.core.sites import PeriodicSite
    
    eps = 1e-8
    sites = []
    for site in structure:
        species = site.species
        frac_coords = np.remainder(site.frac_coords, 1)
        for jimage in product([0, 1 - eps], repeat=3):
            new_frac_coords = frac_coords + np.array(jimage)
            if np.all(new_frac_coords < 1 + eps):
                new_site = PeriodicSite(species=species, coords=new_frac_coords, lattice=structure.lattice)
                sites.append(new_site)
    structure_display = Structure.from_sites(sites)
    
    view = nglview.show_pymatgen(structure_display)
    view.add_unitcell()
    
    if spacefill:
        view.add_spacefill(radius_type='vdw', radius=0.5, color_scheme='element')
        view.remove_ball_and_stick()
    else:
        view.add_ball_and_stick()
        
    if show_axes:
        view.shape.add_arrow([-4, -4, -4], [0, -4, -4], [1, 0, 0], 0.5, "x-axis")
        view.shape.add_arrow([-4, -4, -4], [-4, 0, -4], [0, 1, 0], 0.5, "y-axis")
        view.shape.add_arrow([-4, -4, -4], [-4, -4, 0], [0, 0, 1], 0.5, "z-axis")
        
    view.camera = "perspective"
    return view

In [35]:
plot3d(s0)

NGLWidget()

In [36]:
d10 = data.iloc[10].defect_representation
s10 = data.iloc[10].initial_structure

In [37]:
s10

Structure Summary
Lattice
    abc : 25.5225256 25.5225256 14.879004
 angles : 90.0 90.0 119.99999999999999
 volume : 8393.668021812642
      A : 25.5225256 0.0 1.5628039641098191e-15
      B : -12.761262799999994 22.10315553833868 1.5628039641098191e-15
      C : 0.0 0.0 14.879004
PeriodicSite: Mo (0.0000, 1.8419, 3.7198) [0.0417, 0.0833, 0.2500]
PeriodicSite: Mo (-1.5952, 4.6048, 3.7198) [0.0417, 0.2083, 0.2500]
PeriodicSite: Mo (-3.1903, 7.3677, 3.7198) [0.0417, 0.3333, 0.2500]
PeriodicSite: Mo (-4.7855, 10.1306, 3.7198) [0.0417, 0.4583, 0.2500]
PeriodicSite: Mo (-6.3806, 12.8935, 3.7198) [0.0417, 0.5833, 0.2500]
PeriodicSite: Mo (-7.9758, 15.6564, 3.7198) [0.0417, 0.7083, 0.2500]
PeriodicSite: Mo (-9.5709, 18.4193, 3.7198) [0.0417, 0.8333, 0.2500]
PeriodicSite: Mo (-11.1661, 21.1822, 3.7198) [0.0417, 0.9583, 0.2500]
PeriodicSite: Mo (3.1903, 1.8419, 3.7198) [0.1667, 0.0833, 0.2500]
PeriodicSite: Mo (1.5952, 4.6048, 3.7198) [0.1667, 0.2083, 0.2500]
PeriodicSite: Mo (0.0000, 7.3677, 3

In [38]:
def get_x_y_dist(structure, size=(8,8)):
    ls = structure.lattice.lengths
    return ls[0] / size[0], ls[1] / size[1]

In [39]:
def sites_equal(s1, s2, atol=1e-2):
    return np.allclose(s1.coords, s2.coords, atol)

def find_site(structure, site, atol=1e-2):
    for i, s in enumerate(structure):
        if sites_equal(s, site, atol):
            return i
    raise ValueError(f"{site} not found in structure with atol={atol}")
    
def safe_find_site(sites, site, atol=1e-2):
    for i, s in enumerate(sites):
        if sites_equal(s, site, atol):
            return True, i
    return False, -1

In [40]:
def get_nn(structure, site, atol=1e-1):
    """
    returns indicies of neighbouring sites in structure for given site.
    """
    layer_coords = build_layers_coords(structure)
    layer = get_layer(site, layer_coords)
    nearest_neighbor_ids = []
    dx, dy = get_x_y_dist(structure)
    for i,s in enumerate(structure):
        if max(np.isclose(site.distance(s), [dx, dy], atol=atol)) == 1 and get_layer(s, layer_coords) == layer:
            nearest_neighbor_ids.append(i)    
    return nearest_neighbor_ids

In [69]:
def save_adj_dict(adj_dict, name):
    with open(f"{name}.json", "w") as json_file:
        json.dump(adj_dict, json_file)
        
def restore_adj_dict(group_name):
    def jsonKeys2int(x):
        if isinstance(x, dict):
            return {int(k):v for k,v in x.items()}
        return x

    adj_dct = {}
    with open(f"{group_name}.json", "r") as json_file:
         adj_dct = json.load(json_file, object_hook=jsonKeys2int)
    return adj_dct

In [42]:
def swap_sites(structure, defect_site, idx):
    defect_species = defect_site.species_string
    site_species = structure[idx].species_string
    if defect_species == site_species:
        return structure.copy(), defect_site
    new_defect_site = PeriodicSite(defect_site.species, 
                                   structure[idx].frac_coords, 
                                   structure[idx].lattice)
    if defect_species != "X0+":
        defect_idx = find_site(structure, defect_site)
        indices_species_map = {idx: defect_species, defect_idx: site_species}
        transformation = ReplaceSiteSpeciesTransformation(indices_species_map)
        return transformation.apply_transformation(structure), new_defect_site
    insert_transf = InsertSitesTransformation([site_species],
                                              [defect_site.frac_coords])
    remove_transf = RemoveSitesTransformation([idx])
    removed_s = remove_transf.apply_transformation(structure)
    return insert_transf.apply_transformation(removed_s), new_defect_site

In [43]:
swap_sites(s10, d10[0], get_nn(s10, d10[0])[3])[1]

PeriodicSite: Se (-4.7855, 17.4983, 2.1549) [0.2083, 0.7917, 0.1448]

In [44]:
d10[0]

PeriodicSite: Se (-3.1903, 14.7354, 2.1549) [0.2083, 0.6667, 0.1448]

In [45]:
s0 = data.iloc[0].initial_structure

In [46]:
d0 = data.iloc[0].defect_representation

In [47]:
def build_full_structure(s0, d0):
    site_species = "S"
    defect_site = d0[0]
    insert_transf = InsertSitesTransformation([site_species],
                                              [defect_site.frac_coords])
    return insert_transf.apply_transformation(s0)

full_structure = build_full_structure(s0, d0)

In [49]:
 data.iloc[2].defect_representation

Structure Summary
Lattice
    abc : 25.5225256 25.5225256 14.879004
 angles : 90.0 90.0 119.99999999999999
 volume : 8393.668021812642
      A : 25.5225256 0.0 1.5628039641098191e-15
      B : -12.761262799999994 22.10315553833868 1.5628039641098191e-15
      C : 0.0 0.0 14.879004
PeriodicSite: Se (4.7855, 17.4983, 2.1549) [0.5833, 0.7917, 0.1448]

In [50]:
d2 = data.iloc[2].defect_representation
d2

Structure Summary
Lattice
    abc : 25.5225256 25.5225256 14.879004
 angles : 90.0 90.0 119.99999999999999
 volume : 8393.668021812642
      A : 25.5225256 0.0 1.5628039641098191e-15
      B : -12.761262799999994 22.10315553833868 1.5628039641098191e-15
      C : 0.0 0.0 14.879004
PeriodicSite: Se (4.7855, 17.4983, 2.1549) [0.5833, 0.7917, 0.1448]

In [51]:
refl_coords = d2[0].frac_coords.copy()
refl_coords[-1] = layers_coords[2]
refl_site = PeriodicSite(d2[0].species, refl_coords, d2[0].lattice)
refl_site

PeriodicSite: Se (4.7855, 17.4983, 5.2846) [0.5833, 0.7917, 0.3552]

In [52]:
d2

Structure Summary
Lattice
    abc : 25.5225256 25.5225256 14.879004
 angles : 90.0 90.0 119.99999999999999
 volume : 8393.668021812642
      A : 25.5225256 0.0 1.5628039641098191e-15
      B : -12.761262799999994 22.10315553833868 1.5628039641098191e-15
      C : 0.0 0.0 14.879004
PeriodicSite: Se (4.7855, 17.4983, 2.1549) [0.5833, 0.7917, 0.1448]

In [51]:
checker.are_symmetrically_equivalent([d2[0]], [refl_site])


NameError: name 'checker' is not defined

In [None]:
finder = SpacegroupAnalyzer(cur.initial_structure)
print("The spacegroup is {}".format(finder.get_space_group_symbol()))

In [49]:
finder = SpacegroupAnalyzer(full_structure, symprec=1e-1)
checker = finder.get_space_group_operations()

In [None]:
 %time checker.are_symmetrically_equivalent(s0, swap_sites(s0, d0[0], get_nn(s0, d0[0])[0]))

In [None]:
%time checker.are_symmetrically_equivalent(s0, s0)

In [43]:
%time checker.are_symmetrically_equivalent(d0, d10)

CPU times: user 20.2 ms, sys: 4.05 ms, total: 24.3 ms
Wall time: 19.4 ms


False

In [None]:
for i, row in data0.iterrows():
    if checker.are_symmetrically_equivalent(cur.defect_representation.sites,
                                            row.defect_representation.sites):
        print(i)

In [None]:
s0[70].species_string

In [None]:
plot3d(s0)

In [None]:
plot3d(swap_sites(s0, d0[0], get_nn(s0, d0[0])[4]))

In [50]:
def generate_neighbouring_structures(structure, defects):
    """
    Only one swap.
    """
    result_structures = []
    defect_reprs = []
    for i, defect_site in enumerate(defects):
        nns = get_nn(structure, defect_site)
        for n_idx in nns:
            new_structure, new_defect_site = swap_sites(structure, 
                                                        defect_site, 
                                                        n_idx)
            new_defect_repr = defects.copy()
            new_defect_repr[i] = new_defect_site
            
            # check if swapped with another defect
            found, idx = safe_find_site(defects, new_defect_site)
            if found:
                swapped_defect_site = PeriodicSite(defects[idx].species, 
                                   defect_site.frac_coords, 
                                   defect_site.lattice)
                new_defect_repr[idx] = swapped_defect_site
            # check if got something new
            is_new = True
            for reprn in defect_reprs:
                if checker.are_symmetrically_equivalent(reprn, 
                                                        new_defect_repr, 
                                                        symm_prec=0.01):
                    is_new = False
                    break
            if is_new or len(defect_reprs) == 0:
                result_structures.append(new_structure)
                defect_reprs.append(new_defect_repr)
    return result_structures, defect_reprs

In [51]:
def get_energy_idx(defect_representation, group_df, checker):
    for i, row in group_df.iterrows():
        if checker.are_symmetrically_equivalent(defect_representation,
                                                row.defect_representation,
                                                symm_prec=0.01):
            return row.energy_per_atom, row.idx
    raise ValueError(f"Not found symmetrically equivalent structure for {defect_representation}")

In [97]:
def get_energy_idx_parallel(defect_representation, group_df, checker):
    def get_item(i, row):
        if checker.are_symmetrically_equivalent(defect_representation,
                                                row.defect_representation,
                                                symm_prec=0.01):
            return row.energy_per_atom, row.idx
        return None
    result = Parallel(n_jobs=-3)(
        delayed(get_item)(i, row) for i, row in group_df.iterrows()
    )
    for entry in result:
        if entry is not None:
            return entry
        
    raise ValueError(f"Not found symmetrically equivalent structure for {defect_representation}")

In [None]:
len(generate_neighbouring_structures(s0, d0))

In [101]:
group10 = data.iloc[10].group

In [103]:
ns24 = generate_neighbouring_structures(s10, d10)

In [115]:
data[data.group=='V5']

Unnamed: 0_level_0,descriptor_id,energy,energy_per_atom,fermi_level,homo,lumo,initial_structure,defect_representation,formation_energy,formation_energy_per_site,band_gap,idx,subgroup,group
_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1
6141cf14ee0a3fd43fb479d7,6141cf13ee0a3fd43fb479d5,-1378.2011,-7.215713,-0.315632,-0.5232,-0.1069,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,"[[3.20307696 7.36035079 3.719751 ] X0+, [23.9...",7.7028,2.567600,0.4163,24,V5_diff,V5
6141cf15ee0a3fd43fb479d9,6141cf13ee0a3fd43fb479d5,-1378.2003,-7.215708,-0.314099,-0.5316,-0.0991,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,"[[15.96433976 7.36035079 3.719751 ] X0+, [ ...",7.7036,2.567867,0.4325,27,V5_same,V5
6141cf15ee0a3fd43fb479db,6141cf13ee0a3fd43fb479d5,-1378.1731,-7.215566,-0.295020,-0.4712,-0.1161,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,"[[ 1.60791911 21.17482301 3.719751 ] X0+, [1...",7.7308,2.576933,0.3551,35,V5_same,V5
6141cf16ee0a3fd43fb479dd,6141cf13ee0a3fd43fb479d5,-1378.1841,-7.215624,-0.301147,-0.4767,-0.1264,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,"[[15.96433976 12.88613968 3.719751 ] X0+, [1...",7.7198,2.573267,0.3503,41,V5_diff,V5
6141cf17ee0a3fd43fb479df,6141cf13ee0a3fd43fb479d5,-1378.1780,-7.215592,-0.297514,-0.4760,-0.1182,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,"[[7.98855051 4.59745635 3.719751 ] X0+, [14.3...",7.7259,2.575300,0.3578,49,V5_same,V5
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
61437dc1ee0a3fd43fb47f9b,6141cf13ee0a3fd43fb479d5,-1378.1807,-7.215606,-0.302052,-0.4813,-0.1231,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,"[[0.01276126 7.36035079 3.719751 ] X0+, [17.5...",7.7232,2.574400,0.3582,5583,V5_diff,V5
61437e12ee0a3fd43fb47f9d,6141cf13ee0a3fd43fb479d5,-1378.1831,-7.215618,-0.302585,-0.4783,-0.1269,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,"[[ 4.79823481 15.64903412 3.719751 ] X0+, [1...",7.7208,2.573600,0.3514,5584,V5_same,V5
61439247ee0a3fd43fb47f9f,6141cf13ee0a3fd43fb479d5,-1378.1778,-7.215591,-0.298805,-0.4787,-0.1187,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,"[[-4.77271229 15.64903412 3.719751 ] X0+, [4...",7.7261,2.575367,0.3600,5611,V5_same,V5
61442d83ee0a3fd43fb47fa1,6141cf13ee0a3fd43fb479d5,-1378.1719,-7.215560,-0.296089,-0.4765,-0.1157,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,"[[-7.96302799 15.64903412 3.719751 ] X0+, [4...",7.7320,2.577333,0.3608,5746,V5_same,V5


In [None]:
get_energy_idx(ns10[1][10], data[data.group==group10], checker)

In [108]:
%time
get_energy_idx(ns10[1][1], data[data.group==group10], checker)

CPU times: user 3 µs, sys: 1 µs, total: 4 µs
Wall time: 9.54 µs


(-7.272177604166667, 231)

In [114]:
len(data[data.group==group10])

19

In [112]:
%time
get_energy_idx_parallel(ns10[1][1], data[data.group=="group10"], checker)

CPU times: user 4 µs, sys: 2 µs, total: 6 µs
Wall time: 13.1 µs




(-7.272177604166667, 231)

In [111]:
get_energy_idx_parallel(generate_neighbouring_structures(s10, d10), data[data.group==group10], checker)



AttributeError: 'list' object has no attribute 'is_periodic_image'

In [None]:
get_energy(data[data.group==group10].iloc[2].defect_representation, data[data.group==group10], checker)

In [None]:
data[data.group==group10].iloc[2].energy_per_atom

In [None]:
data[data.group==group10]

In [116]:
def generate_group_adj_dict(group_df, checker, thr=0.0):
    def get_item(i, row):
        adj_dict = {}
        cur_energy = row.energy_per_atom
        ns, defect_reprs = generate_neighbouring_structures(
            row.initial_structure, row.defect_representation
        )
        adj_dict[row.idx] = []
        for defect_repr in defect_reprs:
            energy, idx = get_energy_idx(defect_repr, group_df, checker)
            if energy - cur_energy <= thr:
                adj_dict[row.idx].append(idx)
        return adj_dict
            
    result = Parallel(n_jobs=-3)(
        delayed(get_item)(i, row) for i, row in group_df.iterrows()
    )
    adj_dict = {}
    for entry in result:
        adj_dict |= entry
#     for i, row in group_df.iterrows():
#         cur_energy = row.energy_per_atom
#         ns, defect_reprs = generate_neighbouring_structures(
#             row.initial_structure, row.defect_representation
#         )
#         adj_dict[row.idx] = []
#         for defect_repr in defect_reprs:
#             energy, idx = get_energy_idx(defect_repr, group_df, checker)
#             if energy - cur_energy <= thr:
#                 adj_dict[row.idx].append(idx)
    return adj_dict

In [None]:
ad = generate_group_adj_dict(data[data.group == 'V3'], checker)
ad

In [50]:
%time generate_group_adj_dict(data[data.group==group10], checker)

CPU times: user 1min 9s, sys: 9.98 s, total: 1min 19s
Wall time: 1min 7s


{3: [10, 231],
 5: [113, 15, 23],
 7: [48, 123],
 10: [],
 15: [113, 71],
 23: [113, 133],
 32: [3, 238],
 38: [32, 58],
 48: [5, 123, 23],
 58: [3, 231],
 71: [],
 113: [71],
 123: [5],
 133: [],
 173: [10],
 211: [238, 32],
 231: [10],
 238: [3],
 635: [173]}

In [57]:
data[data.idx==71].energy_per_atom

_id
6141cf18b842c2e72e2f2d58   -7.272178
Name: energy_per_atom, dtype: float64

In [58]:
data[data.idx==133].energy_per_atom

_id
6141cf1fb842c2e72e2f2d5e   -7.272173
Name: energy_per_atom, dtype: float64

In [59]:
data[data.idx==10].energy_per_atom

_id
6141cf12b842c2e72e2f2d4a   -7.272183
Name: energy_per_atom, dtype: float64

In [51]:
len(data[data.group==group10])

19

In [76]:
adj_dict10 = generate_subgroup_adj_dict(data[data.group==group10], checker)

In [53]:
def check_reachable(adj_dict, min_idx):
    is_reachable = True
    digraph = nx.from_dict_of_lists(adj_dict, create_using=nx.DiGraph)
    for node in digraph:
        if not nx.has_path(digraph, node, min_idx):
            print(f"No path from {node} to {min_idx}!")
            is_reachable = False
    return is_reachable

In [77]:
check_reachable(adj_dict10, 10)

True

In [54]:
def check_all_groups(groups, data, checker):
    all_adjs = {}
    all_reachability = {}
    for group in tqdm(groups):
        group_df = data[data.group == group]
        all_adjs[group] = generate_group_adj_dict(group_df, checker)
        min_idx = group_df.iloc[group_df.energy_per_atom.argmin()].idx
        all_reachability[group] = check_reachable(all_adjs[group], 
                                                     min_idx)
    return all_adjs, all_reachability

In [55]:
def check_group(group_df, checker):
    adj_dict = generate_group_adj_dict(group_df, checker)
    min_idx = group_df.iloc[group_df.energy_per_atom.argmin()].idx
    check_reachable(adj_dict, min_idx)

In [99]:
ad = generate_group_adj_dict(data[data.group == 'X2'], checker)



In [100]:
ad

{1: [],
 4: [4, 292, 208, 453],
 17: [230, 137, 165],
 45: [672, 769, 230, 17],
 137: [137, 230],
 165: [137],
 208: [208, 1, 582],
 230: [230, 769],
 292: [208],
 370: [1, 582, 597, 370],
 398: [4, 292, 398],
 453: [208, 582],
 500: [4, 453],
 551: [551, 137],
 566: [230, 137, 551],
 582: [1],
 597: [],
 672: [672],
 769: [672]}

In [96]:
tmp = {}
def two_args(a, b):
    tmp[a] = b
    return tmp
    
Parallel(n_jobs=-3)(
        delayed(two_args)(a=i, b=row) for i, row in zip([1,2,3], ("a","b","c"))
    )

[{1: 'a'}, {2: 'b'}, {3: 'c'}]

In [87]:
print(*((i, row) for i, row in group_df.iterrows()))

('6141cf11cc0e69a0cf28ab35', descriptor_id                                         6141cf10cc0e69a0cf28ab33
energy                                                              -1390.4044
energy_per_atom                                                      -7.279604
fermi_level                                                          -0.200633
homo                                                                   -0.6706
lumo                                                                    0.4778
initial_structure            [[1.27612629e-07 1.84192955e+00 3.71975100e+00...
defect_representation        [[ 3.19031561 14.73543703  2.15486663] Se, [4....
formation_energy                                                        2.9437
formation_energy_per_site                                              1.47185
band_gap                                                                1.1484
idx                                                                          6
subgroup               

Name: 6141cf1dcc0e69a0cf28ab4d, dtype: object) ('6141cf20cc0e69a0cf28ab4f', descriptor_id                                         6141cf10cc0e69a0cf28ab33
energy                                                              -1390.4106
energy_per_atom                                                      -7.279637
fermi_level                                                          -0.196105
homo                                                                    -0.675
lumo                                                                     0.472
initial_structure            [[1.27612629e-07 1.84192955e+00 3.71975100e+00...
defect_representation        [[6.38063136 3.68385933 5.28463537] Se, [3.177...
formation_energy                                                        2.9375
formation_energy_per_site                                              1.46875
band_gap                                                                 1.147
idx                                                    

Name: 6141cf90cc0e69a0cf28ab6b, dtype: object) ('6141d130cc0e69a0cf28ab6d', descriptor_id                                         6141cf10cc0e69a0cf28ab33
energy                                                              -1390.3134
energy_per_atom                                                      -7.279128
fermi_level                                                           -0.19916
homo                                                                   -0.6595
lumo                                                                    0.4725
initial_structure            [[1.27612629e-07 1.84192955e+00 3.71975100e+00...
defect_representation        [[1.59515772 6.44675377 2.15486663] Se, [1.582...
formation_energy                                                        3.0347
formation_energy_per_site                                              1.51735
band_gap                                                                 1.132
idx                                                    

In [64]:
save_adj_dict(ad, 'X2')

In [70]:
new_ad = restore_adj_dict('X2')

In [71]:
new_ad

{1: [],
 4: [4, 292, 208, 453],
 17: [230, 137, 165],
 45: [672, 769, 230, 17],
 137: [137, 230],
 165: [137],
 208: [208, 1, 582],
 230: [230, 769],
 292: [208],
 370: [1, 582, 597, 370],
 398: [4, 292, 398],
 453: [208, 582],
 500: [4, 453],
 551: [551, 137],
 566: [230, 137, 551],
 582: [1],
 597: [],
 672: [672],
 769: [672]}

In [80]:
group = 'X5'
group_df = data[data.group == group]
# check_group(group_df, checker)

In [76]:
group_df = data[(data.subgroup == 'X5_X_bottom_Se_bottom') | (data.subgroup == 'X5_X_top_Se_top')]
check_group(group_df, checker)

No path from 161 to 139!
No path from 322 to 139!


In [92]:
adj_dict

{12: [97, 153, 97, 153],
 14: [263, 97],
 20: [246, 121, 121, 246],
 30: [299, 180, 299, 180],
 40: [180, 180, 66, 66],
 54: [12, 14, 97],
 66: [180],
 97: [66],
 121: [153, 246],
 153: [97, 66, 246],
 180: [299],
 246: [40, 66],
 263: [180, 66, 30, 97],
 299: [],
 420: [54, 54]}

In [80]:
%time
all_adjs, all_reachability = check_all_subgroups(data.group.unique(), data, 
                                                 checker)

CPU times: user 2 µs, sys: 0 ns, total: 2 µs
Wall time: 6.44 µs


  2%|█                                           | 1/42 [00:03<02:17,  3.36s/it]


ValueError: Not found symmetrically equivalent structure!

In [61]:
data[data.group==group10].energy_per_atom.argmin()

1

In [62]:
data[data.group==group10].iloc[1]

descriptor_id                                         6141cf10b842c2e72e2f2d42
energy                                                              -1396.2591
energy_per_atom                                                      -7.272183
fermi_level                                                          -0.177558
homo                                                                   -0.6911
lumo                                                                    1.1164
initial_structure            [[1.27612629e-07 1.84192955e+00 3.71975100e+00...
defect_representation        [[-3.19031579 14.73543703  2.15486663] Se, [17...
formation_energy                                                         0.578
formation_energy_per_site                                                0.289
band_gap                                                                1.8075
idx                                                                         10
group                                               

In [None]:
get_nn(s10, d10[0])[0]

In [None]:
di = data.iloc[0].descriptor_id
data.loc[di]

In [77]:
data.group.unique()

array(['X1', 'X2', 'X3', 'X4', 'X5', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6',
       'S1', 'S2', 'S3', 'S4', 'S5', 'S6'], dtype=object)