## Import part

In [1]:
import pickle
import networkx.algorithms.isomorphism as iso
import networkx as nx
from tqdm.notebook import trange

import numpy as np
import pymatgen.core as core

import os
import re
from pymatgen.analysis.structure_matcher import StructureMatcher
from pymatgen.io.ase import AseAtomsAdaptor
from ase.visualize.plot import plot_atoms
from pymatgen.core.periodic_table import DummySpecies, Element, Species, get_el_sp
import matplotlib.pyplot as plt
from crystals.datastructures import VariableStructure
from copy import deepcopy

from crystals.datastructures import State
from crystals.matchers import GraphIsomorphismMatcher
from crystals.selectors import DefectsNoChangeSelector, DefectLocalitySelector
from crystals.graphs import GraphManager
import pickle
from tqdm.autonotebook import tqdm
import networkx.algorithms.isomorphism as iso
from networkx import similarity
import networkx as nx
from crystals import datastructures as ds, datasets, fp
import copy


sm = StructureMatcher(primitive_cell=False)


  from tqdm.autonotebook import tqdm


In [None]:
# %load_ext autoreload
# %autoreload 0

In [2]:
import matplotlib.pylab as plt
from pymatgen.core.periodic_table import DummySpecies
from pymatgen.io.ase import AseAtomsAdaptor
from ase.visualize.plot import plot_atoms
from pymatgen.core.structure import Structure
from ase.utils.structure_comparator import SymmetryEquivalenceCheck
from pymatgen.util import coord
comp = SymmetryEquivalenceCheck()

In [3]:
def switch_group(group_name, dataset):
    dataset.processed_data = None
    dataset.filter_group(group_name)
    return dataset.prepare()  # return np.array of States
    
def load_data(group_name=None):
    dataset = datasets.MoS2Dataset(group_name)
    if group_name:
        switch_group(group_name, dataset)
    return dataset

In [4]:
dataset = load_data()

In [5]:
switch_group('X5_diff', dataset)
len(dataset)

15

In [7]:
two_def_groups = ['S2', 'S4', 'V2', 'V4', 'X2_diff', 'X2_same', 'X4_diff', 'X4_same', 'X5_diff', 'X5_same']
three_def_groups = ['S3_diff', 'S3_same', 'S5_diff', 'S5_same', 'S6_diff', 'S6_same',
                    'V3_diff', 'V3_same', 'V5_diff', 'V5_same', 'V6_diff', 'V6_same',]
# check all items of two_def_groups are in dataset.groups.keys()
assert all([x in dataset.groups.keys() for x in two_def_groups])

In [8]:
from scipy.spatial import transform
def perturb_shift(state: State, shift=None):
    new_state = deepcopy(state)
    if shift is None:
        shift = np.random.randint(0, 8, 3) * 0.125
        shift[2] = 0
    new_state.initial_structure.translate_sites(range(len(new_state.initial_structure)), shift)

    for d in new_state.defects:
        fc = np.mod(d.frac_coords + shift, 1)
        d.frac_coords = fc
    return new_state

def perturb_rotate(state: State, theta=None):
    new_state = deepcopy(state)
    if theta is None:
        theta = (np.random.randint(2) + 1) * 2 * np.pi / 3
    new_state.initial_structure = ds.rotate_structure(new_state.initial_structure, theta)
    new_state.defects = ds.rotate_sites(new_state.defects, theta)
    return new_state

def pertrube_state(state: State, count: int=10, kind: str=None):
    new_state = state
    for _ in range(count):
        if kind is None:
            kind = np.random.choice(['shift', 'rotate'])
        if kind == 'shift':
            new_state = (perturb_shift(new_state))
        elif kind == 'rotate':
            new_state = (perturb_rotate(new_state))
        else:
            assert False, 'Unknown pertrubation kind'
    return new_state

def shuffle_defects(state: State):
    # state.initial_structure.sort(key=lambda x: np.random.random())
    state.defects.sort(key=lambda x: np.random.random())
    return state


def generate_true_false_pairs(data, false_count=None, true_ratio=1.0, verbose=False):
    true_pairs = []
    false_pairs = []
    if false_count is None:
        false_count = len(data) * (len(data) - 1) / 2
    true_count = int(true_ratio * false_count)

    pert_count = max([1, int(true_count / len(data))])
    if verbose:
        print(">>>", true_count, false_count, pert_count )
    stop = 0
    pbar = tqdm(total=false_count, disable=not verbose)
    for i in range(len(data)):
        for j in range(i, len(data)):
            if i == j:
                for _ in range(pert_count):
                    true_pairs.append([data[i], shuffle_defects(pertrube_state(data[i]))])
            else:
                false_pairs.append([data[i], data[j]])
                pbar.update(1)
                if len(false_pairs) >= false_count:
                    stop = 1
                    break
        if stop:
            break

    if len(true_pairs) < true_count:
        for _ in range(true_count - len(true_pairs)):
            i = np.random.randint(len(data))
            true_pairs.append([data[i], pertrube_state(data[i])])
    return true_pairs, false_pairs

def generate_true_pairs(data, true_count, pert=10, verbose=False, seed=None):
    true_pairs = []
    if seed is not None:
        np.random.seed(seed)
    pert_per_state = int(np.round(true_count / len(data), 0))
    assert pert_per_state > 1, f"pert_per_state must be > 1 {(pert_per_state)}"
    if verbose:
        print(f">>> {true_count} ~ {pert_per_state} perturbations per state")
    pbar = tqdm(total=true_count, disable=not verbose)
    for _ in range(true_count):
        state = data[np.random.randint(len(data))]
        for _ in range(pert_per_state):
            shuffled_state = shuffle_defects(pertrube_state(state, count=pert))
            assert compare_sm(state.initial_structure, shuffled_state.initial_structure), "What???"
            true_pairs.append([state, shuffled_state])
            pbar.update(1)
            if pbar.n >= true_count:
                return true_pairs
    assert False, "how did I get here?"

In [9]:

smatcher = StructureMatcher(primitive_cell=False)
def compare_sm(_sA, _sB):
    return smatcher.fit(_sA, _sB)

In [10]:
def merge_pairs(true_pairs, false_pairs):
    pairs = []
    targets = np.zeros(len(true_pairs) + len(false_pairs))
    targets[:len(true_pairs)] = 1
    pairs = deepcopy(true_pairs)
    pairs.extend(false_pairs)
    return pairs, targets

In [11]:
def get_the_node(struct: Structure):

    # compare node Z
    # compare node z coord distribution (single one - the one)
    node_Zs = dict([(i, s.specie.Z) for i, s in enumerate(struct.sites)])
    sorted_node_Zs = sorted(node_Zs.items(), key=lambda x: x[1])
    max_Z = sorted_node_Zs[-1][1]
    if sorted_node_Zs[-2][1] != max_Z:
        return sorted_node_Zs[-1][0]
    min_Z = sorted_node_Zs[0][1]
    if sorted_node_Zs[1][1] != min_Z:
        return sorted_node_Zs[0][0]

    node_zs = dict([(i, s.frac_coords[2]) for i, s in enumerate(struct.sites)])
    sorted_node_zs = sorted(node_zs.items(), key=lambda x: x[1])
    min_z = sorted_node_zs[0][1]
    if sorted_node_zs[1][1] != min_z:
        return sorted_node_zs[0][0]
    max_z = sorted_node_zs[-1][1]
    if sorted_node_zs[-2][1] != max_z:
        return sorted_node_zs[-1][0]
    if len(struct.sites) == 2:  # pick a random one
        return 0
    return None    

def align_struc(struct: Structure, idx_center: int):

    node = struct.sites[idx_center]
    # print(">>> node before shift", node)
    struct = ds.center_at_site(struct, idx_center)
    node = struct.sites[idx_center]  # refresh after rotation
    # print(">>> node after shift", node)

    max_distance = 0
    node2 = None
    image2 = None
    idx2 = None
    for i, s in enumerate(struct.sites):
        d, image = node.distance_and_image(s)
        if d > max_distance:
            max_distance = d
            node2 = s
            image2 = image
            idx2 = i
    assert node2 is not None, "Cannot find the other one"
    # print(">>> node2", node2, image2, idx2, max_distance)

    dx_max = 0
    angle_max = 0
    angle = 2 * np.pi / 3
    struct_tmp = deepcopy(struct)
    node = struct_tmp.sites[idx_center]
    node2 = struct_tmp.sites[idx2]
    for i in range(0,3):
        dx = node2.frac_coords[0] - node.frac_coords[0]
        dy = node2.frac_coords[1] - node.frac_coords[1]
        if dx > dx_max:
            dx_max = dx
            angle_max = angle * i
        # print(f">>>==: {angle * i:.2f}, dx {dx:.2f}, dy {dy:.2f}, dx_max {dx_max:.2f}, angle_max {angle_max:.2f}" )
        ds.rotate_sites_inplace(struct_tmp.sites, angle, anchor=node.coords)
    ds.rotate_structure(struct, angle_max, anchor=node.coords)
    return struct

def normalize_struct_asymmetry(struct: Structure) -> Structure:
    center_idx = get_the_node(struct)
    struct = align_struc(struct, center_idx)
    return struct

In [17]:
s = dataset[0].initial_structure


In [18]:
from pymatgen.io.ase import AseAtomsAdaptor

a = AseAtomsAdaptor.get_atoms(s)
# save to file
from ase.io import write
write('test.xyz', a)


## Generate true pairs dataset

In [12]:
from collections import OrderedDict
group_pair_lens = OrderedDict()
pairs_same_group = []
true_pairs_dict = {}

#check if variable is defined
try:
    pair_groups
except NameError:
    pair_groups = {}

# export_groups = ['S3_same',]
export_groups = three_def_groups

for group in tqdm(export_groups):
    filename = f'data/data_pairs_{group}_true.pkl'
    if os.path.exists(filename):
        print("File exists: ", filename)
        continue
    pair_groups[group] = switch_group(group, dataset)
    print(f'group {group} has {len(dataset)} items')
    true_count = len(dataset) * 5
    true_pairs = generate_true_pairs(pair_groups[group], true_count=true_count, pert=10, verbose=True, seed=0)
    true_pairs_dict[group] = true_pairs
    pairs_same_group.extend(true_pairs)
    group_pair_lens[group] = len(true_pairs)
    with open(filename, 'wb') as f:
        pickle.dump(true_pairs, f)
    print("Saved to ", filename)
print(f"Generated {len(pairs_same_group)} pairs")

  0%|          | 0/12 [00:00<?, ?it/s]

File exists:  data/data_pairs_S3_diff_true.pkl
File exists:  data/data_pairs_S3_same_true.pkl
File exists:  data/data_pairs_S5_diff_true.pkl
File exists:  data/data_pairs_S5_same_true.pkl
File exists:  data/data_pairs_S6_diff_true.pkl
File exists:  data/data_pairs_S6_same_true.pkl
File exists:  data/data_pairs_V3_diff_true.pkl
File exists:  data/data_pairs_V3_same_true.pkl
File exists:  data/data_pairs_V5_diff_true.pkl
File exists:  data/data_pairs_V5_same_true.pkl
File exists:  data/data_pairs_V6_diff_true.pkl
File exists:  data/data_pairs_V6_same_true.pkl
Generated 0 pairs


## Check different matchers/FPs

In [None]:
pairs = copy.copy(pairs_same_group)
targets = pairs_targets

In [None]:
len(pairs), len(targets)

In [None]:
class ThreeDefectsFP(fp.TwoDefectsFP):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)


    def get_key(self, struct: Structure) -> tuple:
        assert len(struct) == 3, "Only works for three defects"
        # struct = ds.normalize_struct(struct)
        result = []
        sum_images = np.zeros(3)
        for i in range(3):
            site1 = struct.sites[i]
            site2 = struct.sites[(i + 1) % 3]
            start, end, is_directed = get_directed_edge(site1, site2)
            dZ = end.specie.Z - start.specie.Z
            dz = (end.coords[2] - start.coords[2]) * np.sign(dZ)
            proj = fp.rotate_vector_x_proj(struct.lattice, start.coords, end.coords, n_turns=self.n_turns)
            # _, image = struct.lattice.get_distance_and_image(start.frac_coords, end.frac_coords)
            # sum_images += image 
            # print(">>", image)
            # image = np.dot(image, np.array([4,16,16*4]))
            if not is_directed:
                proj = np.abs(proj)
            result.append(np.array([np.min(proj), np.max(proj), dz]))
            # result.append(np.array([np.min(proj), np.max(proj)]))

        result.sort(key=lambda x: x[0]*100 + x[1])
        span = np.pad(np.array(sum_images[0:2]).reshape(1, 2), [(0, 2), (0, 0)])
        print(">>> span", span[0, :])
        arr_result = np.concatenate([np.array(result) , span], axis=1)
        # arr_result = np.concatenate(np.array(result), np.ones(3) * sum_images)
        return arr_result

    def fp_distance(self, struct1: Structure, struct2: Structure) -> float:
            proj1 = self.get_key(struct1)
            proj2 = self.get_key(struct2)
            return np.linalg.norm(proj1 - proj2)

class DescriptorFP(fp.FingerprintAlgorithm):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def match(self, state1: State, state2: State, **kwargs) -> bool:
        return state1.idx == state2.idx



In [None]:
idx = 504
_s1_d = pairs[idx][0].defect_structure()
_s2_d = pairs[idx][1].defect_structure()
print(_s1_d.sites, "\n", _s2_d.sites)
matcherThreeDefects = ThreeDefectsFP(selector)
matcherThreeDefects.get_key(normalize_struct_asymmetry(_s1_d)), matcherThreeDefects.get_key(normalize_struct_asymmetry(_s2_d))

In [None]:
get_the_node(_s1_d), get_the_node(_s2_d)

In [None]:
_s1_d.sites, _s1_d.distance_matrix, _s2_d.sites, _s2_d.distance_matrix

In [None]:
true_indices = np.where(pairs_targets == 1)[0]
false_indices = np.where(pairs_targets == 0)[0]



In [None]:
def distances(struct):
    dm = struct.distance_matrix
    return np.sort(np.array([dm[0, 1], dm[0, 2], dm[1, 2]]))

for j in true_indices[:9]:
    i = true_indices[j]
    if i > len(pairs):
        print("break")
        break
    _s1_d = pairs[i][0].defect_structure()
    _s2_d = pairs[i][1].defect_structure()
    print(_s1_d.sites, "\n", _s2_d.sites)
    _s1_d = normalize_struct_asymmetry(_s1_d)
    _s2_d = normalize_struct_asymmetry(_s2_d)
    k1, k2 = (matcherThreeDefects.get_key(_s1_d), matcherThreeDefects.get_key(_s2_d))
    d1 = distances(_s1_d)
    d2 = distances(_s2_d)
    span1 = k1[0, 3:]
    span2 = k2[0, 3:]
    span1_sum = np.sum(np.abs(span1))
    span2_sum = np.sum(np.abs(span2))
    same_dist = np.allclose(d1, d2)
    if span1_sum != span2_sum:
        print(i, "...", span1, span2, same_dist, targets[i], "\n", k1, "\n", k2)

In [None]:
class TwoDefectsFP(fp.FingerprintAlgorithm):
    default_threshold: float = 1e-2

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.n_turns = 3

    def _build_key(self, proj, is_directed: bool, unify: bool = False):
        key = np.round(proj, 5)
        if not is_directed:
            key = np.abs(key)
        if unify:
            argmax_idx = np.argmax(key)
            pmax = key[argmax_idx]
            if key[(argmax_idx + 2) % 3] == pmax:
                argmax_idx = (argmax_idx + 2) % 3
            key = np.roll(key, -argmax_idx)
        return key

    def get_key(self, struct: Structure, unify: bool = True) -> tuple:
        assert len(struct) == 2, "Only works for two defects"
        start, end, is_directed = get_directed_edge(struct.sites[0], struct.sites[1])
        proj = rotate_vector_x_proj(struct.lattice, start.coords, end.coords, n_turns=self.n_turns)
        key = self._build_key(proj, is_directed, unify)
        return key
    
    def match(self, state1: State, state2: State, **kwargs) -> bool:
        struct1 = state1.defect_structure()
        struct2 = state2.defect_structure()
        assert len(struct1) == 2, "Only works for two defects"
        assert len(struct2) == 2, "Only works for two defects"
        if np.abs(struct1[0].distance(struct1[1]) - struct2[0].distance(struct2[1])) > self.default_threshold:
            return False
        key1 = self.get_key(struct1)
        key2 = self.get_key(struct2)
        return np.linalg.norm(key1 - key2) < self.default_threshold
    
    def get_phase(self, key1, key2) -> int:
        for i in range(3):
            if np.allclose(key1, key2):
                return i
            key2 = np.roll(key2, -1)
        return None

        
def get_directed_edge(site0, site1):
    delta_z = np.round(site0.z - site1.z, 5)
    if site0.specie.Z > site1.specie.Z:
        return site0, site1, True
    elif site0.specie.Z < site1.specie.Z:
        return site1, site0, True
    elif delta_z < 0: 
        return site0, site1, True
    elif delta_z > 0:
        return site1, site0, True
    return site0, site1, False

In [None]:
class DistancePlusMatcher(matchers.CrystalMatcher):
    def __init__(self, selector, *args, **kwargs):
        super().__init__(selector, *args, **kwargs)
        self.selector = selector
        self.fp = TwoDefectsFP(selector)
        self.iso = matchers.GraphIsomorphismMatcher(selector)


    def match(self, state1: State, state2: State, **kwargs) -> bool:
        if not self.iso.match(state1, state2):
            return False
        return self.fp.match(state1, state2)

In [None]:
from sklearn import metrics
from crystals import matchers, fp, selectors, structure

sm = structure.StructureManager(dataset)
selector = selectors.DefectsNoChangeSelector(sm)


matchers_dict = dict(
    matcher2def = TwoDefectsFP(selector),
    matcherDistancePlus = DistancePlusMatcher(selector),
    # matcher3def = ThreeDefectsFP(selector),
    # matcherDescriptor = DescriptorFP(selector),
    # matcherSpaceGroup = fp.SpaceGroupFP(selector),
    matcherIsomLocality = matchers.GraphIsomorphismMatcher(selector)
    )

In [None]:
# check GI with locality
# selector = selectors.DefectLocalitySelector(sm)

# gmA = GraphManager(dataset)
# graphA = dataset.graphs

# matchers_dict = dict(
#     matcherIsomLocality = matchers.GraphIsomorphismMatcher(selector)
#     )

In [None]:

preds = dict()
metric = dict()
denergies = np.zeros(len(pairs))
for key, matcher in matchers_dict.items():
    preds[key] = np.zeros(len(pairs))
    pred = preds[key]
    for i in trange(len(pairs), desc=key):
        pred[i] = matcher.match(pairs[i][0], pairs[i][1])
        denergies[i] = np.abs(pairs[i][0].energy - pairs[i][1].energy)
    fpr, tpr, thresholds = metrics.roc_curve(targets, pred)
    accuracy = metrics.accuracy_score(targets, pred)
    precision = metrics.precision_score(targets, pred)
    recall = metrics.recall_score(targets, pred)
    f1 = metrics.f1_score(targets, pred)
    metric[key] = dict(fpr=fpr, tpr=tpr, thresholds=thresholds, accuracy=accuracy, precision=precision, recall=recall, f1=f1)


In [None]:
for key, m in metric.items():
    if key == 'matcherIsom':
        continue
    fpr, tpr, thresholds = m['fpr'], m['tpr'], m['thresholds']
    accuracy = m['accuracy']
    precision = m['precision']
    recall = m['recall']
    f1 = m['f1']
    print(f'{key}: Accuracy: {accuracy:.3f}, Precision: {precision:.3f}, Recall: {recall:.3f}, F1: {f1:.3f}')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.plot(fpr, tpr, label=f'{key.replace("matcher", "")} ROC curve (area = {metrics.auc(fpr, tpr):.3f})')
plt.legend();

In [None]:
algo = 'matcher2def'
# algo = 'matcherIsomLocality'
indices_pred_true = np.argwhere(preds[algo] > 0.5).flatten()
indices_pred_false = np.argwhere(preds[algo] < 0.5).flatten()
indices_false = np.argwhere(targets < 0.5).flatten()
indices_true = np.argwhere(targets > 0.5).flatten()
indices_fp = np.intersect1d(indices_pred_true, indices_false)
indices_fn = np.intersect1d(indices_pred_false, indices_true)

denergies_fp = denergies[indices_fp]
distances_fp = preds[algo][indices_fp]

plt.scatter(distances_fp, denergies_fp, alpha=0.2)
plt.title(f"Delta Energy for {algo.upper()} FP");

In [None]:
dict(FP_len=len(indices_fp), FPs=indices_fp[:10], FN_len=len(indices_fn), FNs=indices_fn[:10])

In [None]:
denergies_true = denergies[np.argwhere(targets > 0.5).flatten()]
denergies_false = denergies[np.argwhere(targets < 0.5).flatten()]
denergies_fp = denergies[indices_fp]
# plt.hist([denergies_false, denergies_fp], bins=30, alpha=0.5, label=['false', 'FP']);
plt.hist([denergies_fp], bins=30, alpha=0.5, label=[ 'FP']);
plt.xlabel('Delta Energy (eV)')
plt.legend()
dict(   
    mean_denergy_false=format(np.mean(denergies_false), '.4f'), 
    mean_denergy_fp=format(np.mean(denergies_fp), '.5f'),
    len_fp=len(denergies_fp),
)

## Examine FPs/FNs

In [None]:
def pair2structs(pairs, idx, group_pair_lens):
    s1 = deepcopy(pairs[idx][0].initial_structure)
    s2 = deepcopy(pairs[idx][1].initial_structure)
    s1_d = Structure.from_sites(pairs[idx][0].defects)
    s2_d = Structure.from_sites(pairs[idx][1].defects)
    for s in s1_d.sites:
        nl = s1.get_neighbor_list(0.9, [s,])[0]
        if len(nl) == 0:
            s1.append(s.specie, s.coords, coords_are_cartesian=True)
    for s in s2_d.sites:
        nl = s2.get_neighbor_list(0.9, [s,])[0]
        if len(nl) == 0:
            s2.append(s.specie, s.coords, coords_are_cartesian=True)
    assert len(s1.sites) == 192, f"not right size: {len(s1.sites)}"
    assert len(s2.sites) == 192, f"not right size: {len(s2.sites)}"
    #     s1.remove_sites([i for i in range(192, len(s1.sites))])
    # if len(s2.sites) > 192:
    #     s2.remove_sites([i for i in range(192, len(s2.sites))])
    delta = s1_d.frac_coords[0][2] - s1_d.frac_coords[1][2]
    group_idx = ([0,] + [i + 1 for (i, k) in enumerate(np.cumsum(list(group_pair_lens.values()))) if idx > k])[0]
    group_name = list(group_pair_lens.keys())[group_idx]
    was_set1 = set([s.properties['was'] for s in s1_d.sites])
    was_set2 = set([s.properties['was'] for s in s2_d.sites])
    print(f"{idx}, {s1_d.distance_matrix[1,0]:.3}, {s1_d.species}, {was_set1}, {group_name}")
    # print(f"{s2_d.distance_matrix[1,0]:.3}, {s2_d.species}, {was_set2}, {two_def_groups[group_idx]}")
    return s1, s2, s1_d, s2_d

In [None]:
bad_idx =   931
matcher = TwoDefectsFP(selector)
s1, s2, s1_d, s2_d = pair2structs(pairs, bad_idx, group_pair_lens)
# key1, key2 = matchers_dict['matcher2def'].get_key(s1_d), matchers_dict['matcher2def'].get_key(s2_d)
# key1, key2 = matchers_dict['matcher3def'].get_key(s1_d), matchers_dict['matcher3def'].get_key(s2_d)
# key1, key2 = matcher.get_key(s1_d), matcher.get_key(s2_d)
print(s1_d.distance_matrix, "\n", s2_d.distance_matrix)
targets[bad_idx], preds['matcher2def'][bad_idx]

In [None]:
matcher.get_key(s1_d), matcher.get_key(s2_d), matcher.get_key(s1_d, unify=False), matcher.get_phase(matcher.get_key(s1_d), matcher.get_key(s1_d, unify=False))

In [None]:
get_directed_edge(s1_d[0], s1_d[1])

In [None]:
s1_d_r = ds.rotate_structure(s1_d)
matcher.get_key(s1_d), matcher.get_key(s1_d_r)

In [None]:
fig = plt.figure(figsize=(22, 14))
(ax1, ax2) = fig.subplots(1, 2)

ax1.axis('off')
ax2.axis('off')

ds.plot_struct(s1, ax=ax1, defects_only=True, normalize=True)
ds.plot_struct(s2, ax=ax2, defects_only=True, normalize=True)


## here I am

In [None]:
fig = plt.figure(figsize=(22, 14))
(ax1, ax2) = fig.subplots(1, 2)


ds.plot_struct(normalize_struct_asymmetry(s1_d), ax=ax1, defects_only=True, normalize=False)
ds.plot_struct(normalize_struct_asymmetry(s2_d), ax=ax2, defects_only=True, normalize=False)


In [None]:
normalize_struct_asymmetry(s1_d).sites, normalize_struct_asymmetry(s2_d).sites  

In [None]:
normalize_struct_asymmetry(s1_d).sites, normalize_struct_asymmetry(s2_d).sites  

In [None]:
# format numpy array for printing
np.set_printoptions(precision=3, suppress=True)
s1_d.distance_matrix

In [None]:
%matplotlib inline
from ipywidgets import fixed, interactive
import matplotlib.pyplot as plt
from IPython.display import display
from pymatgen.core.structure import Structure

    
def fp(x, y, theta, phi, initial_structure, defects_only=False):
    t = np.array([x, y, 0])
    assert isinstance(initial_structure, Structure), "invalid type"

    ts = deepcopy(initial_structure)
    if defects_only:
        ts.remove_species(["S", ])
        ts.replace_species({Element("Mo"): Element("H")})
        ts.replace_species({DummySpecies(): 'O'})
    ts.translate_sites(list(range(len(ts))), vector=t)
    ts = ds.rotate_structure(ts, theta=-theta)
    ts = ds.rotate_structure(ts, theta=-phi, axis=[1, 0, 0], anchor=[0, 0, 0])
    ase_atoms = AseAtomsAdaptor.get_atoms(ts)
    # defects = ts.get_defects(from_struct=True)
    # coords = coord.pbc_diff(defects.frac_coords[1], defects.frac_coords[0])[0:2]
    # angles = angles_debug(defects)

    # print(coord.pbc_shortest_vectors(ts.lattice, defects.frac_coords[0], defects.frac_coords))
    
    for i in range(len(ts)):
        site_i = ts[i]
        if site_i.specie.symbol == 'H':
            continue
        for j in range(i+1, len(ts)):
            site_j = ts[j]
            if site_j.specie.symbol == 'H' or i == j:
                continue
            print(i, site_i.specie.symbol, j, site_j.specie.symbol, coord.pbc_diff(site_i.frac_coords, site_j.frac_coords)[0:2])
    plt.figure(2, figsize=(5, 3))
    ax = plt.subplot()
    # ax.set_title(", ".join([f"{n}: {a:.3f}" for n, a in zip([r"$\alpha$", r"$\beta$", r"$\gamma$", r"$\sum$"], angles)]))
    ax.xaxis.set_visible(False)
    ax.yaxis.set_visible(False)
    plot_atoms(ase_atoms, ax=ax, radii=0.5, rotation=('12x, 0y, 0z'), show_unit_cell=False)

interactive_plot = interactive(fp, 
                               x=(-0.5, 0.5, 0.125), y=(-0.5, 0.5, 0.125), 
                               theta=(-2*np.pi, 2*np.pi, 2/3*np.pi),
                               phi=fixed(0),
                            #    phi=(-np.pi, np.pi, np.pi),
                            #    initial_structure=fixed(structA))
                               initial_structure=fixed(s1),
                               defects_only=fixed(True))
output = interactive_plot.children[-1]
output.layout.height = '300px'
display(interactive_plot)


In [None]:
interactive_plot = interactive(fp, 
                               x=(-0.5, 0.5, 0.125), y=(-0.5, 0.5, 0.125), 
                               theta=(-2*np.pi, 2*np.pi, 2/3*np.pi),
                               phi=fixed(0),
                            #    phi=(-np.pi, np.pi, np.pi),
                            #    initial_structure=fixed(structA))
                               initial_structure=fixed(s2),
                               defects_only=fixed(True))
output = interactive_plot.children[-1]
output.layout.height = '300px'
display(interactive_plot)

In [None]:
np.mod([0.5, 0.7, -0.5], 1) 

In [None]:
comp = SymmetryEquivalenceCheck()
def compare_ase(_sA, _sB):
    return comp.compare(AseAtomsAdaptor.get_atoms(_sA), AseAtomsAdaptor.get_atoms(_sB))


In [None]:
pairs[bad_idx][0].energy - pairs[bad_idx][1].energy

In [None]:
bad_idx = 1519
s1, s2, s1_d, s2_d = pair2structs(pairs, bad_idx, group_pair_lens)
    
vs = VariableStructure(s2, s2_d.sites)
# vs.add_replacement_defect(33, 'O')
# for i in range(5):
#     vs.add_replacement_defect(np.random.randint(192), 'O')

fig = plt.figure(figsize=(22, 14))
(ax1, ax2, ax3) = fig.subplots(1, 3)

ax1.axis('off')
ax2.axis('off')
ax3.axis('off')

s1 = ds.translate_structure(s1, [0, -3 * 0.125, 0])
ds.plot_struct(s1, ax=ax1, defects_only=True, normalize=False)
ds.plot_struct(ds.rotate_structure(s1, theta=-2*np.pi/3), ax=ax2, defects_only=True, normalize=False)
ds.plot_struct(ds.rotate_structure(s1, theta=-4*np.pi/3), ax=ax3, defects_only=True, normalize=False)


In [None]:
def x_proj_fp2(s: Structure, n_turns: int=3, agg=np.sort):
    assert len(s) == 2, "FP2, invalid structure"
    return x_proj_edge(s[0], s[1], agg=agg, n_turns=n_turns)
    
def x_proj_fpN(s: Structure, n_turns: int=3, agg=np.sort):
    # assert len(s) == 2, "FP2, invalid structure"
    proj_dict = dict()
    for i, site in enumerate(s.sites):
        for j, site2 in enumerate(s.sites):
            if i == j: continue
            key = str(sorted([i, j]))
            if key in proj_dict: continue
            # print(">> proj for: " , s[i], s[j])
            proj = x_proj_edge(s[i], s[j], agg=agg, n_turns=n_turns)
            # print(key, proj)
            proj_dict[key] = np.array(proj[0:3])
    # print(proj_dict.keys())
    proj_fp = np.sort(np.concatenate(list(proj_dict.values()), axis=0))
    return proj_fp
    

def rotate_vector_x_proj(lattice, coords_start, coords_end, theta=-2*np.pi/3, n_turns=3):
    from scipy.linalg import expm

    axis = np.array([0, 0, 1])
    theta %= 2 * np.pi
    proj = np.zeros(n_turns)

    rm = expm(np.cross(np.eye(3), axis) * theta)
    frac_start = lattice.get_fractional_coords(coords_start)
    for i in range(n_turns):
        frac_end = lattice.get_fractional_coords(coords_end)
        x_proj, y_proj, z_proj = coord.pbc_diff(frac_end, frac_start)
        # print(">>>", x_proj, y_proj, z_proj, coords_start)
        proj[i] = x_proj if x_proj >= -0.49 else 0.5
        coords_end = ((np.dot(rm, np.array(coords_end - coords_start).T)).T + coords_start).ravel()
    return proj
    

def x_proj_edge(site1, site2, n_turns=3, agg=np.sort):
    length = site1.distance(site2)
    if site1.specie.Z >= site2.specie.Z:
        start, end = site1, site2
    else:
        start, end = site2, site1
    proj = rotate_vector_x_proj(site1.lattice, start.coords, end.coords, n_turns=n_turns)
    # if site1.specie.Z == site2.specie.Z:
    #     proj = np.abs(proj)
    return agg(proj)

In [None]:
s1_d.sites

In [None]:
x_proj_edge(*s1_d.sites)

In [None]:
from pymatgen.core.lattice import Lattice
a0 = Structure(Lattice.hexagonal(2, 0.5), ['O', 'H'], [[0, 0, 0], [0.5, 0.0, 0]], to_unit_cell=True)
x_proj_edge(*a0.sites)

In [None]:
ds.plot_struct(vs * (2, 2, 1), defects_only=True, normalize=False)


In [None]:
vs2 = ds.rotate_structure(vs, theta=2 * np.pi / 3)
ds.plot_struct(vs2 * (2, 2, 1), defects_only=True, normalize=False)

In [None]:
def defect_distances(struct, specie):
    sd = struct * (2, 2, 1)
    sd.remove_species(["S", "H", "Mo"])
    sd.replace_species({DummySpecies(): 'O'})
    defect = [d for d in sd.sites if d.specie.symbol == specie][0]
    distances = []
    for s in sd.sites:
        print(s.specie, defect.distance_and_image_from_frac_coords(s.frac_coords)[0])
    return sd.distance_matrix

dm0 = defect_distances(vs, 'O')


In [None]:
dm1= defect_distances(vs2, 'O')


In [None]:
fig = plt.figure(figsize=(22, 14))
(ax1, ax2) = fig.subplots(1, 2)

ax1.axis('off')
ax2.axis('off')

ds.plot_struct(vs, ax=ax1, defects_only=True, normalize=False)
ds.plot_struct(vs2, ax=ax2, defects_only=True, normalize=False)

In [None]:
import networkx as nx

G1 = nx.from_numpy_matrix(dm0, create_using=nx.Graph)
G2 = nx.from_numpy_matrix(dm1, create_using=nx.Graph)

edge_comp = iso.numerical_edge_match("weight", 0.0)
node_comp = iso.numerical_node_match('Z', 0)
nx.is_isomorphic(G1, G2, edge_match=edge_comp, node_match=node_comp)