In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from pymatgen.core.structure import Structure, Lattice
from pymatgen.analysis.graphs import StructureGraph
from pymatgen.analysis.local_env import CutOffDictNN
import networkx as nx
import hashlib
from torch_geometric.utils.convert import to_networkx
from torch_geometric.data import Data
import numpy as np
import torch
import os
import yaml
from glob import glob
import random


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from mofgraph2vec.graph.mof2doc import MOF2doc

In [4]:
doc = MOF2doc(cif_path= "../data/cifs/", wl_step = 5)

In [5]:
data = doc.get_documents()

100%|██████████| 3260/3260 [00:11<00:00, 273.57it/s]


In [6]:
k_foldes = 5
random.shuffle(data)
num = int(len(data)/k_foldes)

In [32]:
slices = []
for k in range(k_foldes):
    if k == (k_foldes -1):
        slices.append(data[k*num:])
    else:
        slices.append(data[k*num: (k+1)*num])

In [34]:
len(slices)

5

In [43]:
indexes = np.arange(k_foldes)

for k in range(k_foldes):
    valid_data = slices[k]
    training_data = [d for id in indexes if id != k for d in slices[id]]
    print(len(valid_data), len(training_data))

652 2608
652 2608
652 2608
652 2608
652 2608


In [42]:
training_data[0]

['RSM3663']

In [3]:
with open(os.path.join("../src/mofgraph2vec/graph/tunedvesta.yml"), "r", encoding="utf8") as handle:
    _VESTA_CUTOFFS = yaml.load(handle, Loader=yaml.UnsafeLoader) 

def _get_distance(
    lattice, frac_coords_0, frac_coords_1, jimage
):
    """Get the distance between two fractional coordinates taking into account periodic boundary conditions.
    Parameters
    ----------
    lattice : Lattice
        pymatgen Lattice object
    frac_coords_0 : np.array
        fractional coordinates of the first atom
    frac_coords_1 : np.array
        fractional coordinates of the second atom
    jimage : Tuple[int, int, int]
        image of the second atom
    Returns
    -------
    float
        Distance between the two atoms
    """
    jimage = np.array(jimage)
    mapped_vec = lattice.get_cartesian_coords(jimage + frac_coords_1 - frac_coords_0)
    return np.linalg.norm(mapped_vec)

In [4]:
strategy = CutOffDictNN(cut_off_dict=_VESTA_CUTOFFS)

In [5]:
structure = Structure.from_file("../data/cifs/RSM0001.cif").get_primitive_structure()
sg = StructureGraph.with_local_env_strategy(structure, strategy)

In [52]:
path = "../data/cifs/RSM0001.cif"

In [54]:
meta_path = path.replace("cifs", "meta")
meta_path = meta_path.replace(".cif", ".pt")

In [6]:
sg.graph.edges

OutMultiEdgeView([(0, 10, 0), (0, 12, 0), (0, 11, 0), (0, 8, 0), (0, 7, 0), (0, 9, 0), (1, 4, 0), (2, 5, 0), (3, 6, 0), (4, 9, 0), (4, 8, 0), (5, 11, 0), (5, 12, 0), (6, 7, 0), (6, 10, 0)])

In [9]:
    def _get_node_features(structure: Structure):
        x = [site.specie.Z for site in structure]
        return np.vstack(x)

    def _get_edge_index_and_lengths(sg: StructureGraph):
        edge_idx = []
        distances = []

        lattice = sg.structure.lattice
        structure = sg.structure
        for edge in sg.graph.edges(keys=True, data=True):
            fc_0 = structure.frac_coords[edge[0]]
            fc_1 = structure.frac_coords[edge[1]]
            d = _get_distance(lattice, fc_0, fc_1, edge[-1]["to_jimage"])
            distances.append(d)
            edge_idx.append([edge[0], edge[1]])

        return (
            np.array(edge_idx).T,
            np.array(distances)
        )

In [10]:
x = _get_node_features(structure)
edge_idx, edge_attr = _get_edge_index_and_lengths(sg)

(0, 10, 0, {'to_jimage': (0, 0, 0)})
(0, 12, 0, {'to_jimage': (0, 1, -1)})
(0, 11, 0, {'to_jimage': (0, 1, 0)})
(0, 8, 0, {'to_jimage': (1, 0, -1)})
(0, 7, 0, {'to_jimage': (1, 0, 0)})
(0, 9, 0, {'to_jimage': (1, 1, -1)})
(1, 4, 0, {'to_jimage': (1, 0, 0)})
(2, 5, 0, {'to_jimage': (-1, 0, 0)})
(3, 6, 0, {'to_jimage': (0, 0, 0)})
(4, 9, 0, {'to_jimage': (0, 0, -1)})
(4, 8, 0, {'to_jimage': (0, 0, -1)})
(5, 11, 0, {'to_jimage': (0, 1, 0)})
(5, 12, 0, {'to_jimage': (0, 1, 0)})
(6, 7, 0, {'to_jimage': (0, 0, 0)})
(6, 10, 0, {'to_jimage': (0, 0, 0)})


In [12]:
edge_idx.shape

(2, 15)

In [13]:
data = Data(
            x=x, 
            edge_index=torch.Tensor(edge_idx),
            edge_attr=edge_attr
        )

In [14]:
graph = to_networkx(data)

In [23]:
features = data.x.flatten()
features

array([64,  1,  1,  1,  6,  6,  6,  8,  8,  8,  8,  8,  8])

In [24]:
features[0]

64

In [30]:
new_features = []

for node in graph.nodes:
    nebs = graph.neighbors(node)
    degs = [features[int(neb)] for neb in nebs]
    neb_features = sorted([str(deg) for deg in degs])
    nn_features = [str(features[node])]+sorted([str(deg) for deg in degs if len(degs)>0])
    nn_features = "_".join(nn_features)
    hash_object = hashlib.md5(nn_features.encode())
    hashing = hash_object.hexdigest()
    new_features.append(hashing)

In [49]:
data.x.flatten()

array([64,  1,  1,  1,  6,  6,  6,  8,  8,  8,  8,  8,  8])

In [50]:
features_to_WL = {}
for i, item in enumerate(data.x.flatten()):
    features_to_WL.update({i: item})

In [51]:
features_to_WL

{0: 64,
 1: 1,
 2: 1,
 3: 1,
 4: 6,
 5: 6,
 6: 6,
 7: 8,
 8: 8,
 9: 8,
 10: 8,
 11: 8,
 12: 8}

In [44]:
class WeisfeilerLehmanMachine:
    """
    Weisfeiler Lehman feature extractor class.
    """
    def __init__(self, graph, features, iterations):
        """
        Initialization method which also executes feature extraction.
        :param graph: The Nx graph object.
        :param features: Feature hash table.
        :param iterations: Number of WL iterations.
        """
        self.iterations = iterations
        self.graph = graph
        self.features = features
        self.nodes = self.graph.nodes()
        self.extracted_features = [str(v) for k, v in features.items()]
        self.do_recursions()

    def do_a_recursion(self):
        """
        The method does a single WL recursion.
        :return new_features: The hash table with extracted WL features.
        """
        new_features = {}
        for node in self.nodes:
            nebs = self.graph.neighbors(node)
            degs = [self.features[neb] for neb in nebs]
            features = [str(self.features[node])]+sorted([str(deg) for deg in degs])
            features = "_".join(features)
            hash_object = hashlib.md5(features.encode())
            hashing = hash_object.hexdigest()
            new_features[node] = hashing
        self.extracted_features = self.extracted_features + list(new_features.values())
        return new_features

    def do_recursions(self):
        """
        The method does a series of WL recursions.
        """
        for _ in range(self.iterations):
            self.features = self.do_a_recursion()
            print(self.features)

In [45]:
machine = WeisfeilerLehmanMachine(graph, features_to_WL, 4)

{0: 'b687ea24da315fef802670d6e98e7453', 1: '291dd475d0224126a68550d7c406f3b1', 2: '291dd475d0224126a68550d7c406f3b1', 3: '291dd475d0224126a68550d7c406f3b1', 4: '29b8c03f8d52a63d11887d6a70570940', 5: '29b8c03f8d52a63d11887d6a70570940', 6: '29b8c03f8d52a63d11887d6a70570940', 7: 'c9f0f895fb98ab9159f51fd0297e236d', 8: 'c9f0f895fb98ab9159f51fd0297e236d', 9: 'c9f0f895fb98ab9159f51fd0297e236d', 10: 'c9f0f895fb98ab9159f51fd0297e236d', 11: 'c9f0f895fb98ab9159f51fd0297e236d', 12: 'c9f0f895fb98ab9159f51fd0297e236d'}
{0: '3fa24e606490fe53d6b0f1ce52c9bc2d', 1: '2f72d8013c247d7dc696ad360ec04f96', 2: '2f72d8013c247d7dc696ad360ec04f96', 3: '2f72d8013c247d7dc696ad360ec04f96', 4: '4d2b25f2a9de9d1b45ec81e0f23804a3', 5: '4d2b25f2a9de9d1b45ec81e0f23804a3', 6: '4d2b25f2a9de9d1b45ec81e0f23804a3', 7: '815e6212def15fe76ed27cec7a393d59', 8: '815e6212def15fe76ed27cec7a393d59', 9: '815e6212def15fe76ed27cec7a393d59', 10: '815e6212def15fe76ed27cec7a393d59', 11: '815e6212def15fe76ed27cec7a393d59', 12: '815e6212def15

In [112]:
from glob import glob
import random

In [117]:
from pathlib import Path
from typing import Callable, Dict, List, Tuple, Union
from typing_extensions import TypeAlias
PathType: TypeAlias = Union[Path, str]

In [118]:
cifs = glob("../data/cifs/*.cif")

In [61]:
name = cifs[0].split("/")[-1].rstrip(".cif")

In [62]:
name

'RSM0008'

In [66]:
base = os.path.basename("../../../data/vec/embedding.csv")

In [67]:
os.path.splitext(base)

('embedding', '.csv')

In [75]:
sims = [('g_RSM1099', 0.9754330515861511), ('g_RSM0566', 0.9705005884170532), ('g_RSM0390', 0.9640724658966064)]

In [78]:
[docid for docid, sim in sims].index('g_RSM0390')

2

In [98]:
data = [0,1,0,3,4,0,1,1,0]

In [99]:
from collections import Counter

In [100]:
dis = Counter(data)
dis

Counter({0: 4, 1: 3, 3: 1, 4: 1})

In [101]:
times_count = [dis[word] for idx, word in enumerate(dis)]

In [102]:
times_count

[4, 3, 1, 1]

In [107]:
np.sum(np.array(times_count)<3)/len(times_count)

0.5

In [27]:
from mofgraph2vec.graph.mof2doc import MOF2doc

In [42]:
doc = MOF2doc(
    "../data/cifs/",
    1,
    1000
)

In [43]:
documents = doc.get_documents()

2023-01-30 12:27:14.109 | INFO     | mofgraph2vec.graph.mof2doc:get_documents:31 - Converting graphs to tokens. 
100%|██████████| 1000/1000 [00:03<00:00, 277.08it/s]


In [44]:
doc.distribution_analysis(4)

0.32193176783340716