In [1]:
import torch

In [4]:
class Graph:
    """Undirected graph with an adjacency list and learnable per-vertex embeddings.

    Vertices are the keys of ``adjacency_list``. Embeddings live in a
    ``torch.nn.Embedding`` table with ``num_vertices`` rows indexed
    ``0 .. num_vertices - 1``, so vertex ids are expected to be ints in that
    range for ``forward`` / ``print_embeddings`` to line up with the adjacency
    list (nothing here enforces this).

    NOTE(review): this class is not a ``torch.nn.Module``; to train the
    embeddings, pass ``graph.embeddings.parameters()`` to the optimizer.
    """

    def __init__(self, num_vertices, embedding_dim):
        """
        Args:
            num_vertices: number of rows in the embedding table.
            embedding_dim: length of each vertex embedding vector.
        """
        # vertex -> list of neighbouring vertices; an undirected edge u-v is
        # stored in both u's and v's lists.
        self.adjacency_list = {}
        self.num_vertices = num_vertices
        self.embedding_dim = embedding_dim
        self.embeddings = torch.nn.Embedding(num_vertices, embedding_dim)

    def add_vertex(self, vertex):
        """
        Add a new vertex with no neighbours; a no-op if it already exists.
        """
        if vertex not in self.adjacency_list:
            self.adjacency_list[vertex] = []

    def add_edge(self, vertex1, vertex2):
        """
        Add an undirected edge between two existing vertices.

        The edge is recorded in both vertices' neighbour lists; a self-loop
        (``vertex1 == vertex2``) is recorded once. If either vertex has not
        been added with ``add_vertex``, the call is silently ignored.
        """
        if vertex1 not in self.adjacency_list or vertex2 not in self.adjacency_list:
            return
        self.adjacency_list[vertex1].append(vertex2)
        # Without this guard a self-loop would be listed twice for the same vertex.
        if vertex1 != vertex2:
            self.adjacency_list[vertex2].append(vertex1)

    def get_embeddings(self):
        """
        Get the learnable vertex embeddings (the embedding table's weight).
        """
        return self.embeddings.weight

    def forward(self, edge_index):
        """
        Compute the forward pass of the graph neural network.

        Args:
            edge_index: a 2 x E index tensor (a pair of index sequences); the
                first row holds source vertices and the second row holds
                target vertices.

        Returns:
            A ``(num_vertices, embedding_dim)`` tensor of vertex embeddings.

        NOTE: message aggregation and the embedding update are not implemented
        yet, so the embeddings are currently returned unchanged.
        """
        row, col = edge_index
        table = self.embeddings.weight
        # Build the lookup indices on the table's device so forward keeps
        # working after the embeddings are moved (e.g. to a GPU).
        vertex_ids = torch.arange(self.num_vertices, device=table.device)
        node_embeddings = self.embeddings(vertex_ids)
        # Per-edge message: sum of the two endpoint embeddings.
        # TODO: aggregate `messages` into their target vertices and update
        # node_embeddings; until then the messages are unused.
        messages = node_embeddings[row] + node_embeddings[col]  # noqa: F841
        return node_embeddings

    def print_adjacency_list(self):
        """
        Print the graph in an adjacency list representation.
        """
        print("Adjacency List:")
        for vertex, neighbors in self.adjacency_list.items():
            print(f"{vertex}: {neighbors}")

    def print_embeddings(self):
        """
        Print the learnable vertex embeddings, one row per embedding index.
        """
        embeddings = self.get_embeddings()
        print("Vertex Embeddings:")
        for i, embedding in enumerate(embeddings):
            print(f"Vertex {i}: {embedding.tolist()}")

    def print_edges(self):
        """
        Print the edges in the graph.

        Each undirected edge is printed once per direction, because it is
        stored in both endpoints' neighbour lists.
        """
        edges = [
            (vertex, neighbor)
            for vertex, neighbors in self.adjacency_list.items()
            for neighbor in neighbors
        ]
        print("Edges:")
        for edge in edges:
            print(edge)

Adjacency List:
0: [1, 3]
1: [0, 2]
2: [1, 3]
3: [2, 0]
Edges:
(0, 1)
(0, 3)
(1, 0)
(1, 2)
(2, 1)
(2, 3)
(3, 2)
(3, 0)
Vertex Embeddings:
Vertex 0: [-0.8037563562393188, -0.8256404399871826]
Vertex 1: [-0.489565908908844, 0.039530087262392044]
Vertex 2: [-0.03903724625706673, -0.02241910807788372]
Vertex 3: [0.4732538163661957, -0.5184754729270935]
Updated Vertex Embeddings:
tensor([[-0.8038, -0.8256],
        [-0.4896,  0.0395],
        [-0.0390, -0.0224],
        [ 0.4733, -0.5185]], grad_fn=<EmbeddingBackward0>)


In [None]:
def main(seed=0):
    """Build a 4-vertex cycle graph, print it, and run one forward pass.

    Args:
        seed: value passed to ``torch.manual_seed`` before the graph (and its
            randomly initialised embeddings) is created, so the printed
            embeddings are reproducible. Pass ``None`` to skip seeding.
    """
    if seed is not None:
        torch.manual_seed(seed)

    num_vertices = 4
    embedding_dim = 2
    # Single source of truth for the edges: used both for the adjacency list
    # and for the edge_index tensor passed to forward(), so they cannot drift.
    edges = [(0, 1), (1, 2), (2, 3), (3, 0)]

    # Create a graph with 4 vertices and 2-dimensional embeddings
    graph = Graph(num_vertices, embedding_dim)
    for vertex in range(num_vertices):
        graph.add_vertex(vertex)
    for vertex1, vertex2 in edges:
        graph.add_edge(vertex1, vertex2)

    # Print the graph and the learnable vertex embeddings
    graph.print_adjacency_list()
    graph.print_edges()
    graph.print_embeddings()

    # Perform a forward pass; edge_index is 2 x E (row 0: sources, row 1: targets)
    edge_index = torch.tensor(edges).t()
    node_embeddings = graph.forward(edge_index)
    print("Updated Vertex Embeddings:")
    print(node_embeddings)


if __name__ == "__main__":
    main()