In [4]:
import sumolib
import networkx as nx
import matplotlib.pyplot as plt
import xml.etree.ElementTree as ET
import random

import torch
import torch_geometric
import torch_geometric.data as Data
import torch_geometric.utils as pyg_utils

import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
import numpy as np




First let's read in the graph file and see if it works properly

## Data handling

In [5]:
#Function to read in the network we work with into a networkx object with the nodes and edges, no features yet
def read_sumo_net(filename):
    net = sumolib.net.readNet(filename)
    G = nx.Graph()
    
    # Add nodes
    for node in net.getNodes():
        G.add_node(node.getID(), pos=(node.getCoord()))
    # Add edges
    for edge in net.getEdges():
        for lane in edge.getLanes():
            edge_id = lane.getEdge().getID()
            starting_node_id = net.getEdge(edge_id).getFromNode().getID()
            ending_node_id = net.getEdge(edge_id).getToNode().getID()
            G.add_edge(starting_node_id, ending_node_id, edge_id = edge_id)
    return G

#Function to add the features to the network graph we created already

def add_edge_features_from_xml(G, xml_filename, interval_begin):
    # Parse the XML file
    tree = ET.parse(xml_filename)
    root = tree.getroot()
    #Find the interval corresponding to the interval_begin time 
    interval = root.find(f'.//interval[@begin="{interval_begin}"]')
    #Extract all the features of the edges
    edges = interval.findall('.//edge')
    for edge in edges:
        edge_id = edge.get('id')
        edge_features = {}
        edge_features['left'] = edge.get('left')
        #We can add other features here
        #Iterate through the edges in the existing NetworkX graph
        for xml_edge_id, xml_edge_data in G.edges.items():
            if G.get_edge_data(xml_edge_id[0],xml_edge_id[1])['edge_id'] == edge_id:
                G.edges[xml_edge_id].update(edge_features)
    return G

def nx_to_pyg(graph):
    # Convert NetworkX graph to PyTorch Geometric Data object
    pyg_data = Data.Data()
    #We have to number the nodes, because that is how Data object works
    # Mapping between string node IDs and numerical indices
    node_id_to_index = {node_id: i for i, node_id in enumerate(graph.nodes)}

    # Set node features
    num_nodes = graph.number_of_nodes()
    node_features = np.zeros((num_nodes, 2))  # Assuming num_features is known, this is important to change, if we want to change something, altough I do not think that will be the case for us
    for i, (node, features) in enumerate(graph.nodes(data=True)):
        node_features[i] = [features['pos'][0], features['pos'][1]]  # Add node features accordingly, this case the coordinates
    pyg_data.x = torch.tensor(node_features, dtype=torch.float)

    # Set edge features and edge indices
    edge_index = []
    edge_features = []
    for u, v, features in graph.edges(data=True):
        # Map string node IDs to numerical indices
        u_index = node_id_to_index[u]
        v_index = node_id_to_index[v]
        edge_index.append([u_index, v_index])
        edge_features.append([float(features['left'])])  # Add edge features accordingly, if we add more features, we have to change this line

    pyg_data.edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
    pyg_data.edge_attr = torch.tensor(edge_features, dtype=torch.float)

    return pyg_data


#Function to plot the graph
def plot_graph(G):
    pos = nx.get_node_attributes(G, 'pos')
    nx.draw(G, pos, with_labels=False, node_size=10)
    plt.show()
Graph = read_sumo_net('s_gyor.net.xml')

In [7]:
G1 = read_sumo_net('s_gyor.net.xml')
G2 = add_edge_features_from_xml(G1,'gyor_forg_15_min.xml',"0.00")
pyg_data = nx_to_pyg(G2)
print(pyg_data)


Data(x=[413, 2], edge_index=[2, 504], edge_attr=[504, 1])


In [23]:
data = nx_to_pyg(G2)

class GNN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(GNN, self).__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, output_dim)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        return x


# Get the total number of edges in the dataset
total_num_edges = data.num_edges

# Calculate the number of edges to hide (70% of total number of edges)
num_edges_to_hide = int(total_num_edges * 0.7)
num_remaining_edges = total_num_edges - num_edges_to_hide

 # Randomly select edges to hide (70% of total edges)
edges_to_hide = random.sample(range(total_num_edges), num_edges_to_hide)
print(edges_to_hide)
print(len(edges_to_hide))


input_dim = num_remaining_edges
hidden_dim = 64
output_dim = num_edges_to_hide 

# Define the model
model = GNN(input_dim, hidden_dim, output_dim)


criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)


num_epochs = 100

for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()
    
   
    
    # Filter the edge features and edge indices for the selected edges
    hidden_edge_attr = data.edge_attr[edges_to_hide]
    hidden_edge_index = data.edge_index[:, edges_to_hide]
    
    output = model.forward()
    
    # Compute loss using the predicted features and the input features of the remaining edges

    remaining_edge_attr = data.edge_attr  # Assuming all edges' features are present initially
    remaining_edge_index = data.edge_index[:, ~edges_to_hide]  # Filter out the hidden edges
    loss = criterion(output, remaining_edge_attr)  # Assuming remaining_edge_attr contains the features of the remaining edges
    
    loss.backward()
    optimizer.step()
    
    if epoch % 10 == 0:
        print(f'Epoch {epoch+1}, Loss: {loss.item()}')

# NaN értékte cserélni a hiden edge-eket


[393, 158, 165, 454, 429, 24, 52, 379, 308, 311, 211, 146, 368, 236, 56, 222, 45, 134, 88, 247, 35, 461, 9, 107, 99, 306, 43, 34, 260, 28, 456, 487, 109, 53, 8, 89, 387, 501, 460, 472, 217, 144, 255, 30, 337, 188, 363, 245, 39, 318, 362, 301, 225, 20, 166, 113, 265, 292, 366, 147, 108, 502, 480, 416, 484, 61, 293, 267, 325, 198, 375, 250, 65, 115, 493, 251, 503, 261, 148, 152, 440, 241, 180, 434, 373, 409, 280, 383, 406, 27, 159, 69, 458, 330, 309, 128, 478, 405, 11, 203, 419, 66, 145, 209, 257, 433, 408, 496, 441, 233, 479, 172, 404, 300, 463, 314, 41, 126, 81, 119, 315, 313, 476, 278, 467, 279, 402, 344, 275, 398, 102, 361, 431, 477, 303, 322, 120, 352, 59, 237, 421, 498, 32, 0, 224, 149, 324, 319, 103, 3, 87, 71, 264, 273, 289, 132, 54, 316, 200, 42, 131, 57, 220, 474, 29, 329, 73, 439, 178, 38, 492, 445, 397, 355, 47, 181, 151, 101, 299, 423, 191, 6, 471, 193, 91, 112, 164, 162, 235, 190, 333, 452, 425, 199, 282, 142, 347, 64, 31, 15, 80, 79, 110, 62, 490, 407, 323, 189, 130, 155, 

TypeError: bad operand type for unary ~: 'list'