In [1]:
import pickle
import os

import numpy as np
from torch_geometric.data import Data, Batch

import networkx as nx

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool, GATConv
from torch.nn.utils.rnn import pad_sequence

import torch.optim as optim

import matplotlib.pyplot as plt
import matplotlib.cm as cm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from math import sqrt

In [82]:
import matplotlib.pyplot as plt

import networkx as nx

def plot_graph_and_data(G, filename):
    """Render a graph and its per-node dynamic data to an image file.

    Draws, on a single 12x8 figure:
      * map nodes in black,
      * nodes whose ``dynamic_object_exist_probability`` contains a 1, coloured
        through the 'YlOrRd' colormap,
      * nodes whose ``traffic_light_detected`` contains a 1 in green,
      * the recorded dynamic-object positions of those nodes as blue dots whose
        opacity decreases with the node's index in the dynamic-node list,
      * every edge that is not a self-loop in gray.

    Args:
        G: graph exposing ``nodes(data=True)`` and ``edges()``. Every node must
            carry ``x``, ``y``, ``type``, ``dynamic_object_exist_probability``,
            ``dynamic_object_position_X``, ``dynamic_object_position_Y`` and
            ``traffic_light_detected``; the last four are read as sequences.
        filename: destination passed to ``Figure.savefig``.

    Raises:
        KeyError: if a node lacks one of the attributes listed above.
    """

    def _is_map_node(node_type):
        # `type` is a plain string on raw frame graphs, but
        # establish_node_correspondence stores it as a per-frame list
        # (padded with 0), so a bare `== 'map_node'` would never match there.
        if isinstance(node_type, (list, tuple)):
            return 'map_node' in node_type
        return node_type == 'map_node'

    nodes = list(G.nodes(data=True))
    pos = {node: (data['x'], data['y']) for node, data in nodes}

    # Keep only the positions of frames where the object exists with certainty.
    dyna_pos = {
        node: [
            (px, py)
            for px, py, prob in zip(data['dynamic_object_position_X'],
                                    data['dynamic_object_position_Y'],
                                    data['dynamic_object_exist_probability'])
            if prob == 1
        ]
        for node, data in nodes
    }

    map_nodes, dynamic_nodes, traffic_light_nodes = [], [], []
    node_colors = {}
    # pyplot.get_cmap replaces matplotlib.cm.get_cmap, which is deprecated
    # since Matplotlib 3.7 and removed in later releases.
    cmap = plt.get_cmap('YlOrRd')

    # A node may fall into several categories at once; it is then drawn once
    # per category, later layers on top.
    for node, data in nodes:
        if _is_map_node(data['type']):
            map_nodes.append(node)
        if 1 in data['dynamic_object_exist_probability']:
            dynamic_nodes.append(node)
            node_colors[node] = cmap(max(data['dynamic_object_exist_probability']))
        if 1 in data['traffic_light_detected']:
            traffic_light_nodes.append(node)

    fig, ax = plt.subplots(figsize=(12, 8))
    try:
        nx.draw_networkx_nodes(G, pos, nodelist=map_nodes, node_color='black',
                               node_size=2, label='Map Nodes', ax=ax)
        nx.draw_networkx_nodes(G, pos, nodelist=dynamic_nodes,
                               node_color=[node_colors[node] for node in dynamic_nodes],
                               node_size=5, label='Dynamic Object Nodes', ax=ax)
        nx.draw_networkx_nodes(G, pos, nodelist=traffic_light_nodes, node_color='green',
                               node_size=5, label='Traffic Light Nodes', ax=ax)

        dynamic_count = len(dynamic_nodes)
        for i, node in enumerate(dynamic_nodes):
            x_values = [pos[node][0]] + [p[0] for p in dyna_pos[node]]
            y_values = [pos[node][1]] + [p[1] for p in dyna_pos[node]]
            # i < dynamic_count, so alpha always lies in (0, 1].
            alpha = 1 - (i / dynamic_count)
            ax.scatter(x_values, y_values, color='blue', alpha=alpha, s=3)

        # Self-loops carry no visual information, so they are skipped.
        edges = [(u, v) for u, v in G.edges() if u != v]
        nx.draw_networkx_edges(G, pos, edgelist=edges, edge_color='gray', ax=ax)

        # Clamp the view to the node bounding box; min()/max() would raise on
        # an empty graph, so the limits are left at their defaults in that case.
        if pos:
            plot_padding = 0
            xs = [xy[0] for xy in pos.values()]
            ys = [xy[1] for xy in pos.values()]
            ax.set_xlim(min(xs) - plot_padding, max(xs) + plot_padding)
            ax.set_ylim(min(ys) - plot_padding, max(ys) + plot_padding)

        ax.legend(loc='upper right')
        fig.savefig(filename)
    finally:
        # Always release the figure: this function is called once per sequence
        # and leaked figures accumulate in pyplot's global registry.
        plt.close(fig)



In [83]:
def establish_node_correspondence(sequence):
    """Merge a sequence of frame graphs into one graph keyed by node position.

    Nodes from different frames are treated as the same node when their
    ``(x, y)`` attributes compare equal (exact equality, no tolerance).

    Each merged node stores:
      * ``x`` / ``y``: taken from the first frame in which the node appears,
      * ``nearest_traffic_light_detection_probability``: the value from the
        last frame in which the node appears,
      * ``traffic_light_detected``: one entry per frame, computed as
        ``int(nearest_traffic_light_detection_probability and
        dynamic_object_exist_probability)`` for every frame containing the
        node (including the first one) and 0 elsewhere,
      * every other node attribute as a list with one entry per frame, 0 for
        frames in which the node or the attribute is absent. If a frame node
        itself carries ``traffic_light_detected``, that value overrides the
        computed one for that frame.

    Edges are copied once; the attributes of the first occurrence win.

    Args:
        sequence: sized, ordered collection of graphs exposing
            ``nodes(data=True)``, ``nodes[n]`` and ``edges(data=True)``. Every
            node must carry ``x``, ``y``,
            ``nearest_traffic_light_detection_probability`` and
            ``dynamic_object_exist_probability``.

    Returns:
        tuple: ``(correspondence, complete_graph)`` where ``correspondence``
        maps ``(x, y)`` to the integer node id used in ``complete_graph``.
    """
    # Attributes stored once per node rather than once per frame.
    scalar_keys = ('x', 'y', 'nearest_traffic_light_detection_probability')

    correspondence = {}
    complete_graph = nx.Graph()
    num_frames = len(sequence)

    for frame_idx, graph in enumerate(sequence):
        for _, data in graph.nodes(data=True):
            position = (data['x'], data['y'])
            node_id = correspondence.get(position)
            if node_id is None:
                node_id = len(correspondence)
                correspondence[position] = node_id
                complete_graph.add_node(
                    node_id,
                    x=data['x'],
                    y=data['y'],
                    nearest_traffic_light_detection_probability=data['nearest_traffic_light_detection_probability'],
                    traffic_light_detected=[0] * num_frames,
                )

            merged = complete_graph.nodes[node_id]
            merged['nearest_traffic_light_detection_probability'] = data['nearest_traffic_light_detection_probability']
            # Computed for every occurrence; previously the frame in which a
            # node first appeared was skipped and always stayed at 0.
            merged['traffic_light_detected'][frame_idx] = int(
                data['nearest_traffic_light_detection_probability']
                and data['dynamic_object_exist_probability']
            )

            for key, value in data.items():
                if key in scalar_keys:
                    continue
                # Create the per-frame list lazily so attributes that first
                # show up in a later frame do not raise KeyError.
                if key not in merged:
                    merged[key] = [0] * num_frames
                merged[key][frame_idx] = value

    for graph in sequence:
        for u, v, data in graph.edges(data=True):
            u_complete = correspondence[(graph.nodes[u]['x'], graph.nodes[u]['y'])]
            v_complete = correspondence[(graph.nodes[v]['x'], graph.nodes[v]['y'])]
            if not complete_graph.has_edge(u_complete, v_complete):
                complete_graph.add_edge(u_complete, v_complete, **data)

    return correspondence, complete_graph


In [84]:
def process_sequence_for_gnn(sequence, plot_path='graph.png'):
    """Pre-process one sequence of frame graphs for the GNN pipeline.

    Only the first stage is implemented so far: the frames are merged into a
    single graph with ``establish_node_correspondence`` and that graph is
    optionally plotted with ``plot_graph_and_data``.

    Args:
        sequence: the frame graphs of one sequence, passed unchanged to
            ``establish_node_correspondence``.
        plot_path: file the merged graph is plotted to. The default keeps the
            historical behaviour of writing ``'graph.png'`` (overwritten on
            every call); pass ``None`` to skip plotting, e.g. while iterating
            over a whole dataset.

    Returns:
        int: the placeholder value 0 until the remaining stages exist.
    """
    # 1. Node correspondence: one merged graph for the whole sequence.
    _, complete_graph = establish_node_correspondence(sequence)
    if plot_path is not None:
        plot_graph_and_data(complete_graph, plot_path)

    # TODO 2. Temporal features: build a temporal graph from the correspondence.
    # TODO 3. Feature engineering: enrich the temporal graph and return it.
    return 0


def load_data(input_folder, batch_size=32):
    """Load all pickled sequences in a folder and return a generator of batches.

    Every regular file directly inside ``input_folder`` is unpickled and is
    expected to contain a list of sequences. All sequences are pooled,
    shuffled in place with ``np.random.shuffle`` (seed NumPy's global RNG
    beforehand for a reproducible order) and then served in batches, each
    sequence being passed through ``process_sequence_for_gnn`` lazily as its
    batch is requested.

    Args:
        input_folder: directory containing the pickle files.
        batch_size: maximum number of sequences per batch; the last batch may
            be smaller.

    Returns:
        generator: yields lists of processed sequences.

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

    all_sequences = []

    # os.listdir order is platform-dependent; sorting keeps the pre-shuffle
    # order (and therefore a seeded shuffle) reproducible.
    for file_name in sorted(os.listdir(input_folder)):
        file_path = os.path.join(input_folder, file_name)
        if not os.path.isfile(file_path):
            # Skip sub-directories such as Jupyter's .ipynb_checkpoints,
            # which would otherwise make open() fail.
            continue
        with open(file_path, 'rb') as f:
            # SECURITY: pickle.load can execute arbitrary code; only point
            # this at dataset files from a trusted source.
            sequences = pickle.load(f)
        # Ignore files holding an empty list (or None).
        if sequences:
            all_sequences.extend(sequences)

    np.random.shuffle(all_sequences)

    def batch_generator():
        for start in range(0, len(all_sequences), batch_size):
            yield [
                process_sequence_for_gnn(sequence)
                for sequence in all_sequences[start:start + batch_size]
            ]

    return batch_generator()

In [86]:
# Smoke test: build the batch generator and report the size of the first batch.
train_input_folder = "Training Dataset1/Sequence_Dataset"
batch_generator = load_data(train_input_folder)

# Pull a single batch only; the sentinel keeps an empty dataset silent.
batch = next(batch_generator, None)
if batch is not None:
    print(len(batch))

  cmap = cm.get_cmap('YlOrRd')


32
