In [1]:
#import pandas as pd
import torch
import torch_geometric
from torch_geometric.data import Dataset, Data
import numpy as np 
import os

In [2]:
import xml.etree.ElementTree as ET
import pprint
# Pretty-printer for ad-hoc inspection of the parsed structures.
pp = pprint.PrettyPrinter(indent=4)

# Parse the raw XML export once; the following cells read the models from root[0].
RAW_XML_PATH = '../data/raw/all2.xml'
tree = ET.parse(RAW_XML_PATH)
root = tree.getroot()

In [3]:
xml_models = root[0]

# Sorted rather than plain set order: str hashing is randomized per process,
# so list(set(...)) yields a different order -- and therefore different class
# indices -- after every kernel restart. Sorting makes the encoding stable.
model_classes = sorted({m.attrib["modeltype"] for m in xml_models})

num_model_classes = len(model_classes)


def get_model_class(model):
    """Return the integer class index of an XML model element.

    The index is the position of the element's "modeltype" attribute in
    ``model_classes``. Raises KeyError if the attribute is missing and
    ValueError if the model type was not present when ``model_classes``
    was built.
    """
    return model_classes.index(model.attrib["modeltype"])

In [4]:
xml_models = root[0]
model_data = []  # NOTE(review): not used in the visible cells; kept in case later cells rely on it
node_classes = []
edge_classes = []

for m in xml_models:
    # Node vocabulary: the "class" attribute of every direct INSTANCE child.
    node_classes.extend(instance.attrib["class"] for instance in m.findall("INSTANCE"))

    # Edge vocabulary: the lower-cased text of each connector's ATTRIBUTE
    # named "Type"; an empty Type maps to "none". Raises StopIteration if a
    # connector has no Type attribute at all.
    for connector in m.findall("CONNECTOR"):
        type_attr = next(attr for attr in connector.findall("ATTRIBUTE")
                         if attr.get("name") == "Type")
        edge_type = type_attr.text
        edge_classes.append("none" if edge_type is None else edge_type.lower())

# Sorted so the index of each class (used for node features, edge_attr and
# the one-hot edge labels) is identical across kernel restarts; plain set
# order depends on per-process string hash randomization.
node_classes = sorted(set(node_classes))
edge_classes = sorted(set(edge_classes))

num_node_classes = len(node_classes)
num_edge_classes = len(edge_classes)

In [5]:
class EnterpriseModelDatasetLP(Dataset):
    """PyG dataset with one graph per model element in ``root[0]``.

    For each XML model, INSTANCE children become nodes and CONNECTOR
    children become directed FROM -> TO edges. Node features are node-class
    indices, edge attributes are edge-type indices, and ``y`` is a one-hot
    encoding of each edge's type.

    NOTE(review): the graphs are read from the notebook-global ``root``
    (parsed in an earlier cell), not from ``raw_file_names``; ``xml_models``,
    ``node_classes``, ``edge_classes``, ``num_node_classes`` and
    ``num_edge_classes`` are notebook globals as well.
    """

    def __init__(self, root, filename, test=False, transform=None, pre_transform=None,
                 use_cache=False):
        """
        root = Where the dataset should be stored. This folder is split
        into raw_dir (downloaded dataset) and processed_dir (processed data).
        filename = raw file name, relative to raw_dir.
        test = if True, graphs are saved/loaded as data_test_{i}.pt instead
        of data_{i}.pt.
        use_cache = if True, processing is skipped when all processed files
        already exist; if False (default, previous behavior) the graphs are
        rebuilt on every instantiation.
        """
        self.test = test
        self.filename = filename
        # Must be set before super().__init__, which reads processed_file_names.
        self.use_cache = use_cache
        # NOTE(review): y is one-hot over edge classes (width num_edge_classes),
        # yet num_classes is set to the number of node classes -- confirm intent.
        self.num_classes = num_node_classes
        super(EnterpriseModelDatasetLP, self).__init__(root, transform, pre_transform)

    @property
    def raw_file_names(self):
        """ If this file exists in raw_dir, the download is not triggered.
            (The download func. is not implemented here)
        """
        return self.filename

    @property
    def processed_file_names(self):
        """If all of these files exist in processed_dir, processing is skipped.

        Without ``use_cache`` a name that ``process`` never writes is returned,
        so the graphs are rebuilt every time the dataset is instantiated.
        """
        if not self.use_cache:
            return "unimplemented.pt"
        return [self._processed_name(i) for i in range(len(xml_models))]

    def _processed_name(self, idx):
        """File name (inside processed_dir) of the idx-th processed graph."""
        prefix = 'data_test' if self.test else 'data'
        return f'{prefix}_{idx}.pt'

    @staticmethod
    def _node_position(name_to_index, name):
        """Look up a node row by instance name; ValueError if the name is unknown."""
        try:
            return name_to_index[name]
        except KeyError:
            raise ValueError(f"connector references unknown instance {name!r}") from None

    def download(self):
        pass

    def process(self):
        """Convert every XML model under ``root[0]`` into a saved ``Data`` graph.

        Per model:
        - x: node-class index of each INSTANCE, shape [num_nodes, 1]
        - edge_index: FROM -> TO node rows of each CONNECTOR, shape [2, num_edges]
        - edge_attr: edge-type index of each CONNECTOR, shape [num_edges, 1]
        - y: one-hot edge type, shape [num_edges, num_edge_classes]

        Raises ValueError when a node class / edge type is missing from the
        global class lists or a connector names an unknown instance.
        """
        xml_models = root[0]

        for index, m in enumerate(xml_models):
            instances = [el for el in m if el.tag == "INSTANCE"]
            connectors = [el for el in m if el.tag == "CONNECTOR"]

            nodes = []
            # Instance name -> node row, built once per model instead of
            # rebuilding a name list for every connector. setdefault keeps the
            # first occurrence, matching the list.index() semantics it replaces.
            name_to_index = {}
            for position, instance in enumerate(instances):
                node_class = instance.attrib["class"]
                node_name = instance.attrib["name"]
                name_to_index.setdefault(node_name, position)
                nodes.append([node_classes.index(node_class)])

            adjacency_list = []
            edges = []
            y = []
            for connector in connectors:
                edge_type = next(filter(lambda attr: attr.get("name") == "Type",
                                        connector.findall("ATTRIBUTE"))).text
                edge_type = "none" if edge_type is None else edge_type.lower()

                from_name = connector.find("FROM").get("instance")
                to_name = connector.find("TO").get("instance")
                from_index = self._node_position(name_to_index, from_name)
                to_index = self._node_position(name_to_index, to_name)
                edge_class = edge_classes.index(edge_type)

                adjacency_list.append([from_index, to_index])
                edges.append([edge_class])
                one_hot = [0] * num_edge_classes
                one_hot[edge_class] = 1
                y.append(one_hot)
                # TODO (translated from German): do not identify the edge type,
                # only whether an edge exists or not.

            # reshape() with explicit sizes keeps 2-D shapes even for models
            # without instances/connectors (torch.tensor([]) would be 1-D and
            # edge_index would end up as shape [0] instead of [2, 0]).
            data = Data(
                x=torch.tensor(nodes, dtype=torch.float).reshape(len(nodes), 1),
                edge_index=torch.tensor(adjacency_list, dtype=torch.int64)
                .reshape(len(adjacency_list), 2)
                .t()
                .contiguous(),
                edge_attr=torch.tensor(edges, dtype=torch.float).reshape(len(edges), 1),
                y=torch.tensor(y, dtype=torch.float).reshape(len(y), num_edge_classes),
            )
            torch.save(data, os.path.join(self.processed_dir, self._processed_name(index)))

    def len(self):
        """Number of graphs: one per child of the notebook-global ``xml_models``."""
        return len(xml_models)

    def get(self, idx):
        """ - Equivalent to __getitem__ in pytorch
            - Is not needed for PyG's InMemoryDataset
        """
        # NOTE(review): newer torch releases default torch.load to
        # weights_only=True, which may refuse to unpickle Data objects;
        # pass weights_only=False if loading fails -- confirm for the installed version.
        return torch.load(os.path.join(self.processed_dir, self._processed_name(idx)))

In [6]:
# Instantiate the dataset rooted at ../data/ep_data; the raw file name is
# resolved relative to its raw_dir.
dataset_ep = EnterpriseModelDatasetLP(
    root="../data/ep_data",
    filename="../raw/all2.xml",
)

Processing...
Done!


In [7]:
print("DATASET LOADED")
# Sanity check: show the summary of the graph stored at index 2.
sample_graph = dataset_ep[2]
print(sample_graph)

DATASET LOADED
Data(x=[26, 1], edge_index=[2, 26], edge_attr=[26, 1], y=[26, 19])
