In [48]:
import pickle
import os

import numpy as np
from torch_geometric.data import Data, Batch

import networkx as nx

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool
from torch.nn.utils.rnn import pad_sequence

In [64]:
def preprocess_graph(G):
    node_features = []
    for _, data in G.nodes(data=True):
        features = [
            data['x'], data['y'],
            data['dynamic_object_exist_probability'],
            data['dynamic_object_position_X'], data['dynamic_object_position_Y'],
            data['dynamic_object_velocity_X'], data['dynamic_object_velocity_Y'],
            data['nearest_traffic_light_detection_probability']
        ]
        node_features.append(features)
    
    edge_index = []
    for edge in G.edges():
        edge_index.append([edge[0], edge[1]])
        edge_index.append([edge[1], edge[0]])  # Add reverse edge for undirected graph
    
    return torch.tensor(node_features, dtype=torch.float), torch.tensor(edge_index, dtype=torch.long).t().contiguous()

def prepare_batch(batch_graphs, max_nodes):
    x_seq, edge_index_seq = [], []
    x_last, edge_index_last = [], []
    y = []
    batch_last = []
    seq_lengths = []

    for batch_idx, graphs in enumerate(batch_graphs):
        seq_x, seq_edge_index = [], []
        for i, G in enumerate(graphs):
            x, edge_index = preprocess_graph(G)

            # Pad the graph with additional nodes if necessary
            if x.shape[0] < max_nodes:
                padding = torch.zeros((max_nodes - x.shape[0], x.shape[1]))
                x = torch.cat([x, padding], dim=0)

            if i < 3:
                seq_x.append(x)
                seq_edge_index.append(edge_index)
            else:  # Last graph
                x_last.append(x)
                edge_index_last.append(edge_index + len(torch.cat(x_last)))
                y.append(x[:, 2:])  # Target features
                batch_last.extend([batch_idx] * x.shape[0])  # Add batch index for each node

        x_seq.append(seq_x)
        edge_index_seq.append(seq_edge_index)
        seq_lengths.append(len(seq_x))

    # Pad x_seq
    padded_x_seq = []
    for batch in zip(*x_seq):
        padded_batch = pad_sequence(batch, batch_first=True)
        padded_x_seq.append(padded_batch)

    padded_x_seq = torch.stack(padded_x_seq, dim=1)  # [batch_size, seq_len, max_nodes, features]

    # Process edge_index_seq
    max_edges = max(edge_index.shape[1] for batch in edge_index_seq for edge_index in batch)  # Maximum number of edges in the batch
    processed_edge_index_seq = []
    for batch in edge_index_seq:
        batch_edge_index = []
        for edge_index in batch:
            # Pad edge_index if necessary
            if edge_index.shape[1] < max_edges:
                padding = torch.zeros((2, max_edges - edge_index.shape[1]), dtype=edge_index.dtype)
                edge_index = torch.cat([edge_index, padding], dim=1)
            batch_edge_index.append(edge_index)
        processed_edge_index_seq.append(torch.stack(batch_edge_index))

    edge_index_seq = torch.stack(processed_edge_index_seq)

    # Concatenate all x_last and edge_index_last
    x_last = torch.cat(x_last, dim=0)
    edge_index_last = torch.cat(edge_index_last, dim=1)
    y = torch.cat(y, dim=0)
    batch_last = torch.tensor(batch_last, dtype=torch.long)
    seq_lengths = torch.tensor(seq_lengths, dtype=torch.long)

    return padded_x_seq, edge_index_seq, x_last, edge_index_last, y, batch_last, seq_lengths

def load_sequence_data(input_folder, batch_size=32):
    all_sequences = []
    max_nodes = 0

    for file_name in os.listdir(input_folder):
        file_path = os.path.join(input_folder, file_name)
        print(f"Processing file: {file_name}")
        with open(file_path, 'rb') as f:
            sequences = pickle.load(f)

            # Check if the sequences list is not empty
            if sequences:
                all_sequences.extend(sequences)

                # Update the maximum number of nodes
                max_nodes = max(max_nodes, max(G.number_of_nodes() for graphs in sequences for G in graphs))
                print(f"Max number of nodes: {max_nodes}")

    # Shuffle the sequences
    np.random.shuffle(all_sequences)

    def batch_generator():
        for i in range(0, len(all_sequences), batch_size):
            batch = all_sequences[i:i+batch_size]

            # Process the graphs in the batch
            processed_batch = []
            for graphs in batch:
                processed_graphs = []
                for G in graphs:
                    processed_graphs.append(G)  # Append the graph object instead of the tuple
                processed_batch.append(processed_graphs)

            yield processed_batch

    return batch_generator(), max_nodes

class GCNLayer(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(GCNLayer, self).__init__()
        self.conv = GCNConv(in_channels, out_channels)
    
    def forward(self, x, edge_index):
        return F.relu(self.conv(x, edge_index))

class GraphSequenceNN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(GraphSequenceNN, self).__init__()
        self.gcn1 = GCNLayer(input_dim, hidden_dim)
        self.gcn2 = GCNLayer(hidden_dim, hidden_dim)
        self.rnn = nn.GRU(hidden_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)
    
    def forward(self, data):
        x_seq, edge_index_seq, x_last, edge_index_last, batch_last, seq_lengths = data
        
        batch_size, seq_len, max_nodes, _ = x_seq.size()
        
        gcn_out_seq = []
        for i in range(seq_len):
            x = x_seq[:, i]
            edge_index = edge_index_seq[:, i]
            
            out = self.gcn1(x.reshape(-1, x.size(-1)), edge_index.reshape(2, -1))
            out = self.gcn2(out, edge_index.reshape(2, -1))
            
            # Global mean pooling
            out = global_mean_pool(out, torch.arange(batch_size).repeat_interleave(max_nodes).to(out.device))
            gcn_out_seq.append(out)
        
        gcn_out_seq = torch.stack(gcn_out_seq, dim=1)  # [batch_size, seq_len, hidden_dim]
        
        # RNN layer
        packed_input = nn.utils.rnn.pack_padded_sequence(gcn_out_seq, seq_lengths.cpu(), batch_first=True, enforce_sorted=False)
        _, h_n = self.rnn(packed_input)
        
        # Process last graph
        out_last = self.gcn1(x_last, edge_index_last)
        out_last = self.gcn2(out_last, edge_index_last)
        
        # Global mean pooling for last graph
        out_last = global_mean_pool(out_last, batch_last)
        
        # Combine RNN output with last graph output
        combined = h_n.squeeze(0) + out_last
        
        # Final prediction
        pred = self.fc(combined)
        return pred

In [65]:
input_folder = "Sequence_Dataset"

In [66]:
# Hyperparameters
input_dim = 8  # Number of node features
hidden_dim = 64
output_dim = 6  # Number of output features
learning_rate = 0.001
num_epochs = 100
batch_size = 32

# Model, loss, and optimizer
model = GraphSequenceNN(input_dim, hidden_dim, output_dim)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Training loop
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    batch_generator, max_nodes = load_sequence_data(input_folder, batch_size)
    
    for batch in batch_generator:
        data = prepare_batch(batch, max_nodes)
        x_seq, edge_index_seq, x_last, edge_index_last, y, batch_last, seq_lengths = data
        
        optimizer.zero_grad()
        output = model((x_seq, edge_index_seq, x_last, edge_index_last, batch_last, seq_lengths))
        loss = criterion(output, y)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss:.4f}")

# After training, you can use the model for predictions
model.eval()
#with torch.no_grad():
    # Prepare your test data
#    test_batch_generator, _ = load_sequence_data(test_input_folder, batch_size)
#    for test_batch in test_batch_generator:
#        test_data = prepare_batch(test_batch, max_nodes)
#        predictions = model(test_data)
        # Process predictions as needed

Processing file: Cleaned_Sequence_1.pkl
Max number of nodes: 874
Processing file: Cleaned_Sequence_10.pkl
Max number of nodes: 874
Processing file: Cleaned_Sequence_11.pkl
Processing file: Cleaned_Sequence_12.pkl
Max number of nodes: 874
Processing file: Cleaned_Sequence_13.pkl
Max number of nodes: 874
Processing file: Cleaned_Sequence_14.pkl
Max number of nodes: 874
Processing file: Cleaned_Sequence_15.pkl
Max number of nodes: 874
Processing file: Cleaned_Sequence_2.pkl
Max number of nodes: 874
Processing file: Cleaned_Sequence_3.pkl
Max number of nodes: 874
Processing file: Cleaned_Sequence_4.pkl
Max number of nodes: 874
Processing file: Cleaned_Sequence_5.pkl
Max number of nodes: 874
Processing file: Cleaned_Sequence_6.pkl
Max number of nodes: 874
Processing file: Cleaned_Sequence_7.pkl
Max number of nodes: 874
Processing file: Cleaned_Sequence_8.pkl
Max number of nodes: 874
Processing file: Cleaned_Sequence_9.pkl
Max number of nodes: 874


RuntimeError: index 27970 is out of bounds for dimension 0 with size 27968