In [None]:
import torch
from torch_geometric.data import Data, Batch

def create_graph_with_virtual_nodes(x_original, edge_index_original, E):
    """Build a graph in which every original node gets its own virtual node.

    For each of the ``N`` original nodes a virtual node with ``N + E`` random
    features is appended; virtual node ``N + i`` belongs to original node ``i``
    and is linked to it by a directed edge ``N + i -> i``. The last ``E``
    features of each virtual node are reused as the attribute of that edge,
    while the original edges receive all-zero attributes.

    Args:
        x_original: node feature matrix of shape ``(N, F)`` with ``F <= N + E``.
            Narrower matrices are right-padded with zeros so that original and
            virtual nodes share the common feature width ``N + E``.
        edge_index_original: ``(2, M)`` integer tensor with the original edges.
        E: number of extra feature dimensions, which is also the edge-attribute
            width.

    Returns:
        ``(data, virtual_node_indices)``: ``data`` is a ``Data`` object with
        ``2N`` nodes, ``M + N`` edges and an ``(M + N, E)`` ``edge_attr``;
        ``virtual_node_indices`` is ``list(range(N, 2 * N))`` (graph-local ids
        of the virtual nodes, to be excluded during prediction).

    Raises:
        ValueError: if ``x_original`` has more than ``N + E`` feature columns.
    """
    N = x_original.size(0)
    feature_width = N + E  # shared width of original (after padding) and virtual rows

    num_features = x_original.size(1)
    if num_features > feature_width:
        raise ValueError(
            f"x_original has {num_features} feature columns, but at most "
            f"N + E = {feature_width} are supported"
        )
    if num_features < feature_width:
        # torch.cat along dim=0 requires equal widths, so zero-pad the original
        # features (e.g. an N x N feature matrix) up to N + E columns.
        x_original = torch.cat(
            [x_original, x_original.new_zeros(N, feature_width - num_features)], dim=1
        )

    # Virtual node N + i belongs to original node i.
    x_virtual = torch.randn(N, feature_width, device=x_original.device)
    x_combined = torch.cat([x_original, x_virtual], dim=0)

    # One directed edge per pair: virtual node (row 0) -> original node (row 1).
    index_device = edge_index_original.device
    edge_index_virtual = torch.stack([
        torch.arange(N, 2 * N, dtype=torch.long, device=index_device),
        torch.arange(N, dtype=torch.long, device=index_device),
    ])
    edge_index_combined = torch.cat([edge_index_original, edge_index_virtual], dim=1)

    # The last E virtual features (columns N .. N+E-1) become the virtual edge
    # attributes. Slice from N instead of using [:, -E:], which would select
    # *every* column when E == 0.
    edge_attr_virtual = x_virtual[:, N:]

    # Original edges get all-zero attributes; the virtual edges come last.
    num_original_edges = edge_index_combined.size(1) - N
    edge_attr_combined = torch.cat(
        [torch.zeros(num_original_edges, E, device=x_virtual.device), edge_attr_virtual],
        dim=0,
    )

    data = Data(x=x_combined, edge_index=edge_index_combined, edge_attr=edge_attr_combined)

    # Graph-local indices of the virtual nodes, to be excluded during prediction.
    virtual_node_indices = list(range(N, 2 * N))

    return data, virtual_node_indices

# Example of creating a batch of graphs
def create_batch_of_graphs(batch_size, N, E):
    """Create ``batch_size`` random ring graphs with virtual nodes and collate them.

    Every graph has ``N`` original nodes with ``N``-dimensional random features
    (an ``N x N`` matrix) connected as a ring (edge ``i -> (i + 1) % N``), and is
    then augmented by ``create_graph_with_virtual_nodes(..., E)``.

    Args:
        batch_size: number of graphs to generate.
        N: number of original nodes per graph (also the raw feature width).
        E: extra dimensions forwarded to ``create_graph_with_virtual_nodes``.

    Returns:
        ``(batch, virtual_node_indices_list)``: the graphs collated with
        ``Batch.from_data_list`` and, per graph in order, the virtual-node
        indices returned for it. These indices are graph-local; they do not
        include the node offset each graph receives inside the collated batch.
    """
    graph_list, per_graph_virtual = [], []

    for _ in range(batch_size):
        node_features = torch.randn(N, N)
        ring_pairs = [(node, (node + 1) % N) for node in range(N)]
        ring_edges = torch.tensor(ring_pairs).t().contiguous()

        graph, virtual_ids = create_graph_with_virtual_nodes(node_features, ring_edges, E)
        graph_list.append(graph)
        per_graph_virtual.append(virtual_ids)

    return Batch.from_data_list(graph_list), per_graph_virtual

# Parameters
SEED = 0         # seed for the global torch RNG (node features are random)
batch_size = 4
N = 10  # Number of original nodes in each graph
E = 5   # Additional dimensions for virtual nodes

# Seed before any random tensor is drawn so Restart & Run All reproduces the outputs.
torch.manual_seed(SEED)

# Create a batch of graphs
batch, virtual_node_indices_list = create_batch_of_graphs(batch_size, N, E)

# Sanity checks: every graph holds N original + N virtual nodes, and each one
# reports the same graph-local virtual-node indices N .. 2N-1.
assert batch.num_graphs == batch_size
assert batch.num_nodes == batch_size * 2 * N
assert all(indices == list(range(N, 2 * N)) for indices in virtual_node_indices_list)

print(batch)
print("Indices of virtual nodes in each graph in the batch:", virtual_node_indices_list)


In [None]:
import torch
from torch_geometric.data import Data, Batch

def add_virtual_nodes_to_graph(data, E):
    """Replace the virtual half of an already-augmented graph with fresh virtual nodes.

    The graph is assumed to be the output of a previous augmentation step: its
    first ``N = num_nodes // 2`` nodes are the original nodes and the remaining
    ``N`` are virtual. The original node features are kept, the old virtual
    features are discarded, and ``N`` new virtual nodes with ``N + E`` random
    features take indices ``N .. 2N-1``. A new edge ``N + i -> i`` is appended
    for every original node ``i``; its attribute is the last ``E`` features of
    the new virtual node.

    NOTE(review): existing edges are kept as-is, so edges added by the earlier
    augmentation (e.g. ``N + i -> i``) remain and now duplicate the new virtual
    edges with stale attributes. Confirm whether they should be dropped.

    ``data`` is modified in place (``x``, ``edge_index``, ``edge_attr`` are
    reassigned) and also returned.

    Args:
        data: graph object exposing ``num_nodes``, ``x`` (``(2N, F)`` with
            ``F <= N + E``; narrower features are right-padded with zeros),
            ``edge_index`` (``(2, M)``) and ``edge_attr`` (``(M, E)`` or None).
        E: number of extra feature dimensions, also the edge-attribute width.

    Returns:
        ``(data, virtual_node_indices)`` with
        ``virtual_node_indices == list(range(N, 2 * N))``.

    Raises:
        ValueError: if ``num_nodes`` is odd, if the original features have more
            than ``N + E`` columns, or if an existing ``edge_attr`` is not ``E``
            columns wide.
    """
    if data.num_nodes % 2:
        raise ValueError(
            f"expected an even number of nodes (original + virtual), got {data.num_nodes}"
        )
    N = data.num_nodes // 2
    feature_width = N + E

    # Keep only the original nodes; the old virtual nodes are replaced below.
    x_original = data.x[:N]
    num_features = x_original.size(1)
    if num_features > feature_width:
        raise ValueError(
            f"node features have {num_features} columns, but at most "
            f"N + E = {feature_width} are supported"
        )
    if num_features < feature_width:
        # torch.cat along dim=0 requires equal widths: zero-pad the originals.
        x_original = torch.cat(
            [x_original, x_original.new_zeros(N, feature_width - num_features)], dim=1
        )

    # Validate before any mutation so a rejected graph is left untouched.
    if data.edge_attr is not None and data.edge_attr.size(1) != E:
        raise ValueError(
            f"existing edge_attr has {data.edge_attr.size(1)} columns, expected E = {E}"
        )

    # Virtual node N + i belongs to original node i.
    x_virtual = torch.randn(N, feature_width, device=data.x.device)
    x_combined = torch.cat([x_original, x_virtual], dim=0)

    # One directed edge per pair: virtual node (row 0) -> original node (row 1).
    index_device = data.edge_index.device
    edge_index_virtual = torch.stack([
        torch.arange(N, 2 * N, dtype=torch.long, device=index_device),
        torch.arange(N, dtype=torch.long, device=index_device),
    ])
    edge_index_combined = torch.cat([data.edge_index, edge_index_virtual], dim=1)

    # Last E virtual features (columns N .. N+E-1). Slicing with [:, -E:] would
    # select every column when E == 0.
    edge_attr_virtual = x_virtual[:, N:]

    if data.edge_attr is not None:
        edge_attr_combined = torch.cat([data.edge_attr, edge_attr_virtual], dim=0)
    else:
        # No existing attributes: existing edges get zeros, virtual edges come last.
        num_existing_edges = edge_index_combined.size(1) - N
        edge_attr_combined = torch.cat(
            [torch.zeros(num_existing_edges, E, device=x_virtual.device), edge_attr_virtual],
            dim=0,
        )

    # Update the graph data object in place.
    data.x = x_combined
    data.edge_index = edge_index_combined
    data.edge_attr = edge_attr_combined

    # Graph-local indices of the virtual nodes, to be excluded during prediction.
    virtual_node_indices = list(range(N, 2 * N))

    return data, virtual_node_indices

def process_existing_batch(batch, E):
    """Re-run virtual-node augmentation on every graph of a collated batch.

    The batch is split into individual graphs with ``batch.to_data_list()``.
    Each graph is passed through ``add_virtual_nodes_to_graph(graph, E)``, and
    the resulting graphs are collated again with ``Batch.from_data_list``.

    Args:
        batch: collated batch exposing ``to_data_list()``.
        E: extra feature dimensions forwarded to ``add_virtual_nodes_to_graph``.

    Returns:
        ``(modified_batch, virtual_node_indices_list)``, where the list holds, per
        graph and in batch order, the virtual-node indices returned by
        ``add_virtual_nodes_to_graph``.
    """
    per_graph_results = [
        add_virtual_nodes_to_graph(graph, E) for graph in batch.to_data_list()
    ]
    rebuilt_batch = Batch.from_data_list([graph for graph, _ in per_graph_results])
    return rebuilt_batch, [indices for _, indices in per_graph_results]

# Example usage
E = 5  # Additional dimensions for virtual nodes (rebinds the E set in the first cell)

# `batch` is the collated batch created by the first cell, so that cell must run first.

# Process the existing batch to add virtual nodes
modified_batch, virtual_node_indices_list = process_existing_batch(batch, E)

# Sanity checks: the old virtual nodes are replaced, so graph and node counts
# are unchanged, and every graph gains exactly one new edge per virtual node.
assert modified_batch.num_graphs == batch.num_graphs
assert modified_batch.num_nodes == batch.num_nodes
assert modified_batch.num_edges == batch.num_edges + sum(
    len(indices) for indices in virtual_node_indices_list
)

print(modified_batch)
print("Indices of virtual nodes in each graph in the batch:", virtual_node_indices_list)
