In [1]:
import numpy as np
rng = np.random.default_rng()
import matplotlib.pyplot as plt
import matplotlib

cmap = matplotlib.colormaps.get('tab10').colors
import torch
import torch_geometric as tg
from tqdm import trange
import os
from scipy.spatial import KDTree

In [2]:
import numpy as np
rng = np.random.default_rng()
import os
import torch_geometric as tg
import torch
from tqdm import trange
from scipy.spatial import KDTree

class DefectDetectionDataset(tg.data.Dataset):
    '''
    This class bundles the creation and saving as well as loading of a dataset of 3D graphs. If an instance is created, the class will 
    check in root directory if the dataset is already processed. If not, the process() method will be called. Furthermore, the
    dataset will be loaded. If the dataset shall be calculated again, the process() method must be called explicitely.
    '''
    def __init__(self, root, n_graphs_per_type=100, transform=None, pre_transform=None):
        '''
        Args:
        - root (str): The directory where the dataset should be stored, divided into processed and raw dirs
        '''
        self.root = root
        self.n_graphs_per_type = n_graphs_per_type
        super().__init__(root, transform, pre_transform)
        
    @property
    def raw_file_names(self):
        '''
        If this file exists in the raw directory, the download will be skipped. Download not implemented.
        '''
        return 'raw.txt'
    
    @property
    def processed_file_names(self):
        '''
        If this file exists in the processed directory, processing will be skipped. 
        Note: This does smh not work, therefore files are ATM recalculated every time.
        '''
        return ['data_00000.pt']
    
    def download(self):
        '''
        Download not implemented.
        '''
        pass
    
    def len(self):
        '''
        Returns the number of graphs in the dataset.
        '''
        return len([f for f in os.listdir(os.path.join(self.root, 'processed')) if f.startswith('data')])
    
    def get(self, idx):
        '''
        Returns the graph at index idx. 
        '''
        data = torch.load(os.path.join(self.processed_dir, 'data_{:05d}.pt'.format(idx)))
        return data
    
    def process(self):
        '''
        Here creation, processing and saving of the dataset happens. 
        '''
        # Some attributes for all graphs:
        self.size = np.array([5,5,5])
        lattice_types = {
             0: {'name': 'aP', 'nodes': self._get_P_nodes, 'binding_angles': [  0,   0,   0], 'scale': [0, 0, 0]},
             1: {'name': 'mP', 'nodes': self._get_P_nodes, 'binding_angles': [ 90,   0,  90], 'scale': [0, 0, 0]},
             2: {'name': 'mS', 'nodes': self._get_S_nodes, 'binding_angles': [ 90,   0,  90], 'scale': [0, 0, 0]},
             3: {'name': 'oP', 'nodes': self._get_P_nodes, 'binding_angles': [ 90,  90,  90], 'scale': [0, 0, 0]},
             4: {'name': 'oS', 'nodes': self._get_S_nodes, 'binding_angles': [ 90,  90,  90], 'scale': [0, 0, 0]},
             5: {'name': 'oI', 'nodes': self._get_I_nodes, 'binding_angles': [ 90,  90,  90], 'scale': [0, 0, 0]},
             6: {'name': 'oF', 'nodes': self._get_F_nodes, 'binding_angles': [ 90,  90,  90], 'scale': [0, 0, 0]},
             7: {'name': 'tP', 'nodes': self._get_P_nodes, 'binding_angles': [ 90,  90,  90], 'scale': [1, 1, 0]},
             8: {'name': 'tI', 'nodes': self._get_I_nodes, 'binding_angles': [ 90,  90,  90], 'scale': [1, 1, 0]},
             9: {'name': 'hR', 'nodes': self._get_P_nodes, 'binding_angles': [  0,   0,   0], 'scale': [1, 1, 1]},
            10: {'name': 'hP', 'nodes': self._get_P_nodes, 'binding_angles': [ 90,  90, 120], 'scale': [1, 1, 0]},
            11: {'name': 'cP', 'nodes': self._get_P_nodes, 'binding_angles': [ 90,  90,  90], 'scale': [1, 1, 1]},
            12: {'name': 'cI', 'nodes': self._get_I_nodes, 'binding_angles': [ 90,  90,  90], 'scale': [1, 1, 1]},
            13: {'name': 'cF', 'nodes': self._get_F_nodes, 'binding_angles': [ 90,  90,  90], 'scale': [1, 1, 1]},
        }

        
        for n in trange(self.n_graphs_per_type * 14):
            # Get graph features:
            pos, edge_index, label = self._process_lattice(lattice_types[n % 14])
            label = np.expand_dims(label, axis=1)
            node_attr = self._get_node_attr(pos, edge_index)
            edge_attr = self._get_edge_attr(pos, edge_index)
            # Create data object:
            data = tg.data.Data(x          = torch.tensor(node_attr, dtype=torch.float), 
                                edge_index = torch.tensor(edge_index, dtype=torch.int64), 
                                edge_attr  = torch.tensor(edge_attr, dtype=torch.float), 
                                y          = torch.tensor(label, dtype=torch.float), 
                                pos        = torch.tensor(pos, dtype=torch.float))
            # Save data object:
            torch.save(data, os.path.join(self.processed_dir, 'data_{:05d}.pt'.format(n)))


    def _get_P_nodes(self, angles=np.array([90,90,90])):
        '''
        Get the nodes of a primitive lattice.
        '''
        scaling = np.sin(np.radians(angles))
        vec1 = np.arange(0,self.size[0],1)
        vec2 = np.arange(0,self.size[1]*scaling[2],1*scaling[2])
        vec3 = np.arange(0,self.size[2]*scaling[1],1*scaling[1])
        a, b, c = np.meshgrid(vec1,vec2, vec3)
        nodes = np.stack([a,b,c],axis=-1) # Stack them in a new axis
        nodes = np.reshape(nodes, (-1, 3)) # Reshape to an arr of nodes with shape (#nodes, 3)
        return nodes

    def _get_S_nodes(self):
        '''
        Get the nodes of a base-centred lattice.
        '''
        P = self._get_P_nodes()
        extra = P + np.array([0.5,0.5,0])
        return np.append(P,extra,axis=0)

    def _get_I_nodes(self):
        '''
        Get the nodes of a body-centred lattice.
        '''
        P = self._get_P_nodes()
        extra = P + 0.5
        return np.append(P,extra,axis=0)

    def _get_F_nodes(self):
        '''
        Get the nodes of a face-centred lattice.
        '''
        P = self._get_P_nodes()
        extra1 = P + np.array([0.5,0.5,0])
        extra2 = P + np.array([0,0.5,0.5])
        extra3 = P + np.array([0.5,0,0.5])
        return np.row_stack((P, extra1, extra2, extra3))
    

    def _process_lattice(self, arg_dict):
        '''
        Method that processes a lattice of a given type. The method is called with a dictionary holding parameters for one of the lattice types. It contains the following keys:
            - name: The name of the lattice type
            - nodes: The method to get the fitting fundamental lattice nodes
            - binding_angles: A list of binding angles [alpha, beta, gamma] of the lattice type. Angles are in degrees. 0° means to generate a independent random angle (0,180)°
            - scale: A list of scaling factors [x,y,z] for the lattice type. 0 means to generate a random scaling factor (0,2)
        '''
        # Get lattice angles
        angles = np.array(arg_dict['binding_angles'])
        if arg_dict['name'] == 'hR':
            # Special case for hR lattice as it has 3 identical but random angles
            angles = np.where(angles == 0, rng.uniform(46,89,1), angles)
        else:
            angles = np.where(angles == 0, rng.uniform(46,89,3), angles)
            
        # Get the fundamental lattice nodes
        if arg_dict['name'] in ['hR', 'hP']:
            # For hR and hP lattices we need to give the angles to the nodes method so that sheared connections are equally long
            nodes = arg_dict['nodes'](angles)
        else:
            nodes = arg_dict['nodes']()
        nodes = self._shear_nodes(nodes, angles)
        # Find random scale and apply gaussian noise to the lattice accordingly
        scale = np.array(arg_dict['scale'])
        scale = np.where(scale == 0, rng.uniform(0.1,3,3), scale)
        noise_level = 0.05 / scale  # At this step we scale the noise down, so that the scaling later on does not affect the noise level
        nodes += rng.normal(0, noise_level, nodes.shape)
        
        nodes, labels = self._displace_node(nodes)
        # Find the connections between the nodes in a given radius
        cons= self._get_cons_in_radius(nodes, 1.3+np.mean(noise_level))
        # Apply the saved scaling
        nodes *= scale
        
        # Add defects to the lattice
        #nodes, cons, labels = self._add_defects(nodes, cons, labels)
        return nodes, cons, labels

    def _displace_node(self, nodes):
        '''
        Method that dislaces one random node in the lattice by a random amount. Returns the new nodes and the label for classification. 
        The label is a one hot encoded array of shape (len(nodes)) where 1 markes the index od the displaced node.
        '''
        # Get random node and displacement
        node_ind = rng.integers(0, len(nodes))
        displacement = rng.normal(0, 1, 3)
        # Displace node, get label
        nodes[node_ind] += displacement
        labels = np.zeros(len(nodes))
        labels[node_ind] = 1
        return nodes, labels
    
    def _get_cons_in_radius(self, nodes, radius):
        '''
        Get the connections in a radius as well as the total number of cons for each node.
        '''
        tree = KDTree(nodes)
        cons = tree.query_pairs(radius, output_type='ndarray', p=2)
        cons = cons.T
        cons = np.column_stack((cons, cons[::-1])) # Add the reverse connections
        return cons

    def _shear_nodes(self, nodes, binding_angle):
        '''
        Shear nodes. Binding angle is a 3D vector with the Binding angle in each axis.
        '''
        delta = np.tan(np.radians(np.array(binding_angle)))
        assert not np.any(delta == 0), 'Binding angle cannot be 0'
        nodes = nodes.astype(float)
        nodes = nodes + np.stack((nodes[:,2]/delta[0] + nodes[:,1]/delta[2], nodes[:,2]/delta[1] , np.zeros_like(nodes[:,1])), axis=1)
        return nodes

    def _add_defects(self, nodes, edge_index, labels):
        '''
        Method that adds up to 10% of random defects (i.e. missing nodes) to the lattice. Should be called after _get_*_graph() but before
        _get_edge_attr() and _get_node_attr().
        '''
        # Draw up to 10% of unique random indices for nodes to be removed
        drop_indices = rng.choice(np.arange(len(nodes)), rng.integers(len(nodes)//10), replace=False)
        # Remove the nodes and labels
        nodes = np.delete(nodes, drop_indices, axis=0)
        labels = np.delete(labels, drop_indices, axis=0)
        # Delete every connection that refers to a removed node
        edge_index = np.delete(edge_index, np.where(np.isin(edge_index, drop_indices))[1], axis=1)
        
        # As edge_index refers to the original node indices, we need to adjust the indices of most connections
        # For this we create a mapping from old indices to new indices
        old_to_new = np.arange(len(nodes) + len(drop_indices))  # Start with an array of original indices; [0,1,2,3,4,5,...]
        old_to_new[drop_indices] = -1  # Mark the indices of the nodes to be deleted; eg. drop_indices = [1,3] -> [0,-1,2,-1,4,5,...]
        old_to_new = np.cumsum(old_to_new != -1) - 1  # Create a cumulative sum array; cumsum([True, False, True, False, True, True,...]) -1 -> [1,1,2,2,3,4,...] -1 -> [0,0,1,1,2,3,...]
        
        # # Update edge indices to reflect new node indices through broadcasting magic
        edge_index = old_to_new[edge_index]
        return nodes, edge_index, labels
        
    def _get_node_attr(self,nodes,cons):
        '''
        Method that returns the node attributes for each node in the graph. Should be called after creating the graph and adding defects.
        Returns an array of shape (len(pos) = #Nodes) with the entries [C] for each node.
            - C: Number of connections to other nodes
        '''
        # Get the number of connections for each node
        connection_counts = np.zeros(len(nodes))
        for edge in cons[0]:
            # Iterate over all edge start points and count the connections for each node. Start points sufficient, as connections are bidirectional.
            connection_counts[edge] += 1 
            
        return np.expand_dims(connection_counts, axis=1)
    
    def _get_edge_attr(self,nodes,cons):
        '''
        Method that returns the edge attributes for each edge in the graph. Should be called after creating the graph and adding defects.
        Returns an array of shape (len(edge_index[0])= #Edges, 2) with the entries [dx,dy] for each edge.
        '''
        # Get the edge vectors for each edge
        edge_vectors = nodes[cons[0]] - nodes[cons[1]]
        return edge_vectors
    

In [28]:
n = 30
dataset = DefectDetectionDataset(root='defect_graphs', n_graphs_per_type=n)
dataset.process()
train_loader = tg.loader.DataLoader(dataset[:n*12], batch_size=16, shuffle=True)
test_loader = tg.loader.DataLoader(dataset[n*12:], batch_size=16, shuffle=True)
train_loader.dataset[2]


100%|██████████| 420/420 [00:02<00:00, 158.52it/s]
  data = torch.load(os.path.join(self.processed_dir, 'data_{:05d}.pt'.format(idx)))


Data(x=[250, 1], edge_index=[2, 3564], edge_attr=[3564, 3], y=[250, 1], pos=[250, 3])

In [None]:
in_shape = dataset[0].num_features
#out_shape = dataset[0].y.shape[1]
#print(out_shape)
class GINEConv(torch.nn.Module):
    def __init__(self):
        super().__init__()
        #torch.manual_seed(12345)
        self.conv1 = tg.nn.GINEConv(torch.nn.Linear(in_shape, 10), edge_dim=3)
        self.conv2 = tg.nn.GINEConv(torch.nn.Linear(10, 20), edge_dim=3)
        self.conv3 = tg.nn.GINEConv(torch.nn.Linear(20, 10), edge_dim=3)
        self.conv4 = tg.nn.GINEConv(torch.nn.Linear(10, 1), edge_dim=3)
    def forward(self, x, edge_index, edge_attr, batch):
        x = torch.tanh(self.conv1(x, edge_index, edge_attr))
        x = torch.tanh(self.conv2(x, edge_index, edge_attr))
        x = torch.tanh(self.conv3(x, edge_index, edge_attr))
        x = torch.tanh(self.conv4(x, edge_index, edge_attr))  
        out = x
        return out
model = GINEConv()
#model.to(device)
print(model)

optimizer = torch.optim.NAdam(model.parameters())
criterion = torch.nn.CrossEntropyLoss()

def train(loader):
  model.train()
  for data in loader:
    #data = data.to(device)
    out = model(data.x, data.edge_index, data.edge_attr, data.batch)# Perform a single forward pass.
    loss = criterion(out, data.y.argmax(dim=1))  # Compute the loss.
    #print(loss)
    loss.backward()  # Derive gradients.
    optimizer.step()  # Update parameters based on gradients.
    optimizer.zero_grad() # Reset grads.
def test(loader):
    model.eval()
    correct = 0
    for data in loader:
        #data = data.to(device)
        out = model(data.x, data.edge_index, data.edge_attr, data.batch)
        pred = out.argmax(dim=1)  # Use the class with highest probability.
        true = data.y.argmax(dim=1)
        #print(true)
        correct += int((true == pred).sum())
    return correct / len(loader.dataset)

GINEConv(
  (conv1): GINEConv(nn=Linear(in_features=1, out_features=10, bias=True))
  (conv2): GINEConv(nn=Linear(in_features=10, out_features=20, bias=True))
  (conv3): GINEConv(nn=Linear(in_features=20, out_features=10, bias=True))
  (conv4): GINEConv(nn=Linear(in_features=10, out_features=1, bias=True))
)


  data = torch.load(os.path.join(self.processed_dir, 'data_{:05d}.pt'.format(idx)))


In [36]:
accs = []
acc = test(test_loader)
accs.append(acc)
for epoch in range(1, 30):
    train(train_loader)
    test_acc = test(test_loader)
    print(test_acc)
    accs.append(test_acc)
    np.savetxt('accs.txt', accs)
fig,ax = plt.subplots()
ax.plot(accs, label='GINEConv')
ax.legend()

  data = torch.load(os.path.join(self.processed_dir, 'data_{:05d}.pt'.format(idx)))


1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0


KeyboardInterrupt: 