In [8]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
# Color sequence of matplotlib's 'tab10' colormap, kept as a module-level palette.
cmap = matplotlib.colormaps.get('tab10').colors
import torch
import torch_geometric as tg
from gen_autoencoder_dataset import AutoencoderDataset
from torcheval.metrics import BinaryAUROC
from tqdm import trange

In [310]:
import numpy as np
from pytest import mark
# Module-level NumPy random generator; created without a seed, so its draws differ between kernel sessions.
rng = np.random.default_rng()
import os
import torch_geometric as tg
import torch
from tqdm import trange
from scipy.spatial import KDTree
import networkx as nx

class AutoencoderDataset(tg.data.Dataset):
    '''
    Bundles creation, saving and loading of a dataset of lattice graphs. When an instance is created, the base class
    checks the root directory for the files named in processed_file_names; if they are missing, process() is called.
    To regenerate the dataset, call process() explicitly.

    Each graph is stored as its own file 'data_XXXXX.pt' in the processed directory.
    '''
    def __init__(self, root, n_graphs_per_type=100, transform=None, pre_transform=None):
        '''
        Args:
        - root (str): The directory where the dataset should be stored, divided into processed and raw dirs
        - n_graphs_per_type (int): Number of graphs generated for each lattice type
        - transform (callable, optional): Forwarded to the base class
        - pre_transform (callable, optional): Forwarded to the base class and applied to every graph before saving
        '''
        self.root = root
        # Must be set before super().__init__, because the base class may call process() from its constructor.
        self.n_graphs_per_type = n_graphs_per_type
        super().__init__(root, transform, pre_transform)
        
    @property
    def raw_file_names(self):
        '''
        If this file exists in the raw directory, the download will be skipped. Download not implemented.
        '''
        return 'raw.txt'
    
    @property
    def processed_file_names(self):
        '''
        If this file exists in the processed directory, processing will be skipped. 
        Note: This does smh not work, therefore files are ATM recalculated every time.
        '''
        return ['data_00000.pt']
    
    def download(self):
        '''
        Download not implemented.
        '''
        pass
    
    def len(self):
        '''
        Returns the number of graphs in the dataset, i.e. the number of 'data_*.pt' files in the processed directory.
        '''
        # Use processed_dir (as get() does) instead of re-building the path by hand.
        return len([f for f in os.listdir(self.processed_dir) if f.startswith('data_') and f.endswith('.pt')])
    
    def get(self, idx):
        '''
        Returns the graph at index idx. 
        '''
        # NOTE(review): torch.load unpickles the file; only load datasets from trusted locations.
        # Recent PyTorch releases may additionally require weights_only=False here - confirm for the installed version.
        data = torch.load(os.path.join(self.processed_dir, 'data_{:05d}.pt'.format(idx)))
        return data
    
    def process(self):
        '''
        Here creation, processing and saving of the dataset happens. 

        Graph generation is delegated to planeGraphGenerator, which owns the lattice definitions and all helper
        methods. Lattice types are cycled through, so that n_graphs_per_type graphs are generated for each type.
        '''
        generator = planeGraphGenerator()
        lattice_types = generator.lattice_types
        # Derive the cycle length from the table itself; a hard-coded count would index missing keys.
        n_types = len(lattice_types)
        
        for n in trange(self.n_graphs_per_type * n_types):
            # Get graph features:
            pos, edge_index, label = generator._process_lattice(lattice_types[n % n_types])
            label = np.expand_dims(label, axis=1)
            node_attr = generator._get_node_attr(pos, edge_index)
            
            # Create data object:
            data = tg.data.Data(x          = torch.tensor(node_attr, dtype=torch.float), 
                                edge_index = torch.tensor(edge_index, dtype=torch.int64), 
                                y          = torch.tensor(label, dtype=torch.float), 
                                pos        = torch.tensor(pos, dtype=torch.float))
            # Honour the pre_transform that was accepted by the constructor:
            if self.pre_transform is not None:
                data = self.pre_transform(data)
            # Save data object:
            torch.save(data, os.path.join(self.processed_dir, 'data_{:05d}.pt'.format(n)))

class planeGraphGenerator():
    '''
    Generator for plane (2D) lattice graphs in which exactly one node is displaced from its lattice site.
    The available lattice types are listed in self.lattice_types.
    '''
    def __init__(self, size=(6,6), seed=None):
        '''
        Args:
        - size (tuple): The size of the lattice in x,y direction. The lattice will be size[0] x size[1] large.
        - seed (optional): Anything accepted by np.random.default_rng. Pass a fixed value to get reproducible graphs;
          None (default) draws fresh entropy.
        '''
        self.size = size
        # Per-instance random generator, so that results can be reproduced via `seed`.
        self.rng = np.random.default_rng(seed)
        self.lattice_types = {
            0: {'name': 'mP', 'nodes': self._get_P_nodes, 'binding_angle':   0, 'scale': [0, 0]},
            1: {'name': 'oP', 'nodes': self._get_P_nodes, 'binding_angle':  90, 'scale': [0, 0]},
            2: {'name': 'oC', 'nodes': self._get_C_nodes, 'binding_angle':  90, 'scale': [0, 0]},
            3: {'name': 'tP', 'nodes': self._get_P_nodes, 'binding_angle':  90, 'scale': [1, 1]},
            4: {'name': 'hP', 'nodes': self._get_P_nodes, 'binding_angle': 120, 'scale': [1, 1]}
        }
        
    def _get_P_nodes(self, angle=90):
        '''
        Get the nodes of a primitive lattice. The y spacing is multiplied by sin(angle), angle in degrees.
        Returns an array of shape (size[0]*size[1], 2).
        '''
        scaling = np.sin(np.radians(angle))
        vec1 = np.arange(0,self.size[0])
        vec2 = np.arange(0,self.size[1])*scaling
        a, b = np.meshgrid(vec1,vec2)
        nodes = np.stack([a,b],axis=-1) # Stack them in a new axis
        nodes = np.reshape(nodes, (-1, 2)) # Reshape to an arr of nodes with shape (#nodes, 2)
        return nodes

    def _get_C_nodes(self):
        '''
        Get the nodes of a centred lattice: the primitive lattice plus a copy shifted by (0.5, 0.5).
        '''
        P = self._get_P_nodes()
        extra = P + np.array([0.5,0.5])
        # np.vstack is the non-deprecated spelling of np.row_stack
        return np.vstack((P, extra))
    

    def _process_lattice(self, arg_dict):
        '''
        Method that processes a lattice of a given type. The method is called with a dictionary holding parameters for one of the lattice types. It contains the following keys:
            - name: The name of the lattice type
            - nodes: The method to get the fitting fundamental lattice nodes
            - binding_angle: binding angle alpha of the lattice type in degrees. 0 means to draw a random angle from [46,89) degrees
            - scale: A list of scaling factors [x,y] for the lattice type. 0 means to draw a random scaling factor from [0.3,3)

        Returns:
        - nodes: array (#nodes, 2) of scaled node positions
        - cons: array (2, #directed edges) of node indices
        - labels: array (#nodes,), 1 for the displaced node and 0 otherwise
        '''
        # Get lattice angles
        angle = arg_dict['binding_angle']
        if angle == 0:
            angle = self.rng.uniform(46, 89)
            
        # Get the fundamental lattice nodes
        if arg_dict['name'] in ['hP']:
            # For hP lattice we need to give the angle to the nodes method so that sheared connections are equally long
            nodes = arg_dict['nodes'](angle)
        else:
            nodes = arg_dict['nodes']()
        nodes = self._shear_nodes(nodes, angle)
        # Find random scale; the noise level is scaled down so that the scaling later on does not affect it
        scale = np.array(arg_dict['scale'])
        scale = np.where(scale == 0, self.rng.uniform(0.3,3,2), scale)
        noise_level = 0.05 / scale
        
        nodes, labels = self._displace_node(nodes)
        # Find the connections between the nodes in a given radius (computed on the unscaled lattice)
        cons = self._get_cons_in_radius(nodes, 1.3+np.mean(noise_level))
        # Apply the saved scaling
        nodes *= scale
        
        return nodes, cons, labels

    def _displace_node(self, nodes):
        '''
        Method that displaces one random node in the lattice by a random amount. Returns the new nodes and the label for classification. 
        The label is a one hot encoded array of shape (len(nodes)) where 1 marks the index of the displaced node.
        The input array is not modified; a float copy is returned.
        '''
        # Work on a float copy so that the caller's array stays untouched and integer input does not fail on +=
        nodes = np.array(nodes, dtype=float)
        lower, upper = np.min(nodes, axis=0), np.max(nodes, axis=0)
        # Get random node and displacement
        node_ind = self.rng.integers(0, len(nodes))
        displacement = self.rng.normal(0, 1, 2)
        # Redraw until the displaced node stays inside the bounding box of the lattice (roughly inside the lattice)
        while np.any(nodes[node_ind] + displacement < lower) or np.any(nodes[node_ind] + displacement > upper):
            displacement = self.rng.normal(0, 1, 2)
        # Displace node, get label
        nodes[node_ind] += displacement
        labels = np.zeros(len(nodes))
        labels[node_ind] = 1
        return nodes, labels
    
    def _get_cons_in_radius(self, nodes, radius):
        '''
        Get all connections between nodes that are at most `radius` apart (euclidean distance).
        Returns an array of shape (2, 2*#pairs): every pair is contained in both directions.
        '''
        tree = KDTree(nodes)
        cons = tree.query_pairs(radius, output_type='ndarray', p=2)
        cons = cons.T
        cons = np.column_stack((cons, cons[::-1])) # Add the reverse connections
        return cons

    def _shear_nodes(self, nodes, binding_angle):
        '''
        Shear nodes by binding angle (degrees): x is shifted by y / tan(binding_angle). Returns a new float array.
        '''
        delta = np.tan(np.radians(binding_angle))
        assert not np.any(delta == 0), 'Binding angle cannot be 0'
        nodes = nodes.astype(float)
        nodes = nodes + np.stack((nodes[:,1]/delta, np.zeros_like(nodes[:,1])), axis=1)
        return nodes

    def _add_defects(self, nodes, edge_index, labels):
        '''
        Method that adds up to 10% of random defects (i.e. missing nodes) to the lattice. Should be called before
        _get_edge_attr() and _get_node_attr().
        Returns the remaining nodes, the re-indexed edge_index and the remaining labels.
        '''
        n_nodes = len(nodes)
        # Draw the number of defects from [0, 10% of nodes]; endpoint=True also makes lattices with < 10 nodes valid
        n_drop = int(self.rng.integers(0, n_nodes // 10, endpoint=True))
        if n_drop == 0:
            return nodes, edge_index, labels
        drop_indices = self.rng.choice(n_nodes, n_drop, replace=False)
        keep = np.ones(n_nodes, dtype=bool)
        keep[drop_indices] = False
        # Remove the nodes and labels
        nodes = nodes[keep]
        labels = labels[keep]
        # Delete every connection that refers to a removed node
        edge_index = edge_index[:, keep[edge_index].all(axis=0)]
        # edge_index still refers to the original node indices: map old -> new index.
        # cumsum over the keep mask counts the surviving nodes up to each old index, e.g.
        # keep = [T, F, T, F, T, T] -> cumsum - 1 = [0, 0, 1, 1, 2, 3]
        old_to_new = np.cumsum(keep) - 1
        edge_index = old_to_new[edge_index]
        return nodes, edge_index, labels
        
    def _get_node_attr(self,nodes,cons):
        '''
        Method that returns the node attributes for each node in the graph. Should be called after creating the graph and adding defects.
        The node attributes have the shape (num_nodes, num_node_features). For each node, the node features are the following:
        - The magnitudes of the bond orientational order parameters for l=4,6,8,10 (4 features)
        Nodes without any neighbor get all-zero features. The edge list is treated as undirected and duplicate free.
        '''
        nodes = np.asarray(nodes, dtype=float)
        n_nodes = len(nodes)
        boo_arr = np.zeros((n_nodes, 4), dtype=float)
        edges = np.asarray(cons, dtype=np.int64)
        if edges.size == 0:
            return boo_arr
        
        # Symmetrise and de-duplicate, so that every neighbor of a node is counted exactly once
        directed = np.unique(np.concatenate((edges, edges[::-1]), axis=1), axis=1)
        src, dst = directed[0], directed[1]
        
        # Bond angle of every directed edge src -> dst
        bond = nodes[dst] - nodes[src]
        theta = np.arctan2(bond[:,1], bond[:,0])
        orders = np.array([4, 6, 8, 10])
        phases = np.exp(1j * orders[None,:] * theta[:,None])  # shape (#edges, 4)
        
        # Sum the phases per source node and normalise by its number of neighbors
        sums = np.zeros((n_nodes, 4), dtype=complex)
        np.add.at(sums, src, phases)
        counts = np.bincount(src, minlength=n_nodes)
        has_neighbors = counts > 0
        boo_arr[has_neighbors] = np.abs(sums[has_neighbors] / counts[has_neighbors, None])
        return boo_arr
        
    
    def _get_edge_attr(self,nodes,cons):
        '''
        Method that returns the edge attributes for each edge in the graph. Should be called after creating the graph and adding defects.
        Returns an array of shape (len(edge_index[0])= #Edges, 2) with the entries [dx,dy] for each edge.
        '''
        # Get the edge vectors for each edge
        edge_vectors = nodes[cons[0]] - nodes[cons[1]]
        return edge_vectors
    
    


def plot(nodes, cons, labels=None):
    '''
    Draw a graph: nodes as black dots, connections as black lines.

    Args:
    - nodes: array of node positions, indexed as nodes[:,0] (x) and nodes[:,1] (y)
    - cons: array whose columns are pairs of node indices, one column per connection
    - labels (optional): array with one entry per node; nodes whose label equals 1 are highlighted in red
    '''
    xs = nodes[:,0]
    ys = nodes[:,1]
    fig, ax = plt.subplots()
    ax.scatter(xs, ys, c='black')
    # One line segment per column of cons
    for pair in cons.T:
        ax.plot(xs[pair], ys[pair], c='black')
    if labels is not None:
        marked = labels==1
        ax.plot(xs[marked], ys[marked], c='red', markersize=10, marker='o', linestyle='None')
    ax.set_aspect('equal')
    ax.grid()

In [311]:
# Build a generator with the default lattice size (6, 6).
g = planeGraphGenerator()
# Generate one graph of lattice type 0 ('mP'): node positions, connections and the displaced-node labels.
nodes, cons, labels = g._process_lattice(g.lattice_types[0])
# Last expression of the cell: display the per-node feature array (one row per node, 4 features).
g._get_node_attr(nodes, cons)

  return boo_arr.astype(float)


array([[0.96871547, 0.36738165, 0.87681934, 0.58670755],
       [0.96762905, 0.12756135, 0.87520135, 0.23486685],
       [0.96762905, 0.12756135, 0.87520135, 0.23486685],
       [0.96762905, 0.12756135, 0.87520135, 0.23486685],
       [0.96762905, 0.12756135, 0.87520135, 0.23486685],
       [0.95814037, 0.26706692, 0.83794194, 0.17952011],
       [0.95392717, 0.44758342, 0.82073028, 0.38441885],
       [0.95814037, 0.26706692, 0.83794194, 0.17952011],
       [0.95814037, 0.26706692, 0.83794194, 0.17952011],
       [0.95814037, 0.26706692, 0.83794194, 0.17952011],
       [0.95814037, 0.26706692, 0.83794194, 0.17952011],
       [0.95392717, 0.44758342, 0.82073028, 0.38441885],
       [0.95392717, 0.44758342, 0.82073028, 0.38441885],
       [0.95814037, 0.26706692, 0.83794194, 0.17952011],
       [0.96859722, 0.13280761, 0.88001772, 0.02134662],
       [0.95983062, 0.23354653, 0.84354343, 0.30554217],
       [0.95814037, 0.26706692, 0.83794194, 0.17952011],
       [0.95392717, 0.44758342,