In [498]:
import torch
import torch.nn as nn
from torch import inverse

import torch_geometric as tg
import torch_geometric.nn as gnn
from torch_geometric.data import Data, Batch, Dataset, DataLoader


from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree, to_dense_adj

In [27]:
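# edge_index: [2, E]  (E = number of edges; row 0 = source node, row 1 = target node)
# x: [N, d]           (N = number of nodes, d = node-feature dimension)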


In [160]:
# Toy directed graph with 5 nodes: edges 0->1, 1->2, 2->3, 3->4, 4->1
# (row 0 = source, row 1 = target) and one integer feature per node.
edges = torch.tensor([[0,1,2,3,4],[1,2,3,4,1]])
x = torch.tensor([[10,20,30,40,50]]).T
x.shape

torch.Size([5, 1])

In [161]:
# Wrap node features and connectivity in a PyG `Data` graph object.
G = Data(x = x, edge_index = edges)
G

Data(edge_index=[2, 5], x=[5, 1])

In [128]:
# Inspect basic graph statistics and the stored attributes.
print(G.num_edges)
print(G.num_nodes)
print(G.num_node_features)
print(G.keys)
print(G['x'])

5
5
1
['x', 'edge_index']
tensor([[10],
        [20],
        [30],
        [40],
        [50]])


In [129]:
# Collate three copies of G into a single Batch object
# (15 nodes / 15 edges in total, per the output below).
batch = Batch.from_data_list([G, G, G])
batch

Batch(batch=[15], edge_index=[2, 15], ptr=[4], x=[15, 1])

In [130]:
# Index into the batch, count its graphs, and split it back into Data objects.
print(batch[0])
print(batch.num_graphs)
print(batch.to_data_list())

Data(edge_index=[2, 5], x=[5, 1])
3
[Data(edge_index=[2, 5], x=[5, 1]), Data(edge_index=[2, 5], x=[5, 1]), Data(edge_index=[2, 5], x=[5, 1])]


In [206]:
def total_degree(edge_index, num_nodes):
    """Sum of outgoing and incoming edge counts for every node.

    Arguments
    ---------
    edge_index : torch.Tensor [2, num_edges]
        Row 0 = source node, row 1 = target node of each edge.
    num_nodes : int

    Returns
    -------
    total_degree : torch.Tensor [num_nodes]
        ``out_degree + in_degree``; a self-loop (i, i) therefore counts twice.
    """
    outgoing = out_degree(edge_index, num_nodes)
    incoming = in_degree(edge_index, num_nodes)
    return outgoing + incoming

In [437]:
def out_degree(edge_index, num_nodes):
    """Number of edges leaving each node.

    Arguments
    ---------
    edge_index : torch.Tensor [2, num_edges]
        Row 0 holds the source node of every edge.
    num_nodes : int

    Returns
    -------
    out_degree : torch.Tensor [num_nodes]
    """
    sources = edge_index[0]
    return degree(sources, num_nodes)

In [436]:
def in_degree(edge_index, num_nodes):
    """Number of edges arriving at each node.

    Arguments
    ---------
    edge_index : torch.Tensor [2, num_edges]
        Row 1 holds the target node of every edge.
    num_nodes : int

    Returns
    -------
    in_degree : torch.Tensor [num_nodes]
    """
    targets = edge_index[1]
    return degree(targets, num_nodes)

In [238]:
class GNN(tg.nn.MessagePassing):
    """Minimal sum-aggregation message-passing layer.

    For every node i with incoming neighbours N(i) (edges j -> i):

        x_i' = sum_{j in N(i)} fc(x_j)

    NOTE(review): the next cell defines another class named ``GNN``, which
    shadows this one as soon as it runs.

    Arguments
    ---------
    d_x : int
        Input node-feature dimension.
    d_m : int
        Message / output feature dimension.
    """

    def __init__(self, d_x=1, d_m=1):
        # Zero-argument super() avoids looking up the global name ``GNN``,
        # which a later cell rebinds to a different class.
        super().__init__(aggr = 'add', flow = 'source_to_target', node_dim = -2)

        self.fc = nn.Linear(d_x, d_m)

    def message(self, x_j):
        """phi: message sent along each edge j -> i.

        Arguments
        ---------
        x_j : torch.Tensor [E, d_x]
            Source-node features, lifted to one row per edge by ``propagate``.

        Returns
        -------
        m_ij : torch.Tensor [E, d_m]
            Message from a single neighbour j to node i.
        """
        # The original returned the whole node matrix ``x`` ([N, d_x]), which is
        # not a per-edge message and left ``self.fc`` unused.
        return self.fc(x_j)

    def update(self, aggr_out):
        """gamma: identity on the aggregated messages.

        Arguments
        ---------
        aggr_out : torch.Tensor [N, d_m]
            Output of the 'add' aggregation.

        Returns
        -------
        x_k : torch.Tensor [N, d_m]
        """
        return aggr_out

    def forward(self, x, edge_index):
        """Run one round of message passing.

        Arguments
        ---------
        x : torch.Tensor [N, d_x]
            Node features x_{k-1}.
        edge_index : torch.Tensor [2, E]

        Returns
        -------
        x_k : torch.Tensor [N, d_m]
        """
        # propagate takes the connectivity first and node features as a keyword;
        # the original passed ``x`` in the edge_index slot and no features at all.
        return self.propagate(edge_index, x=x)

In [376]:
class GNN(tg.nn.MessagePassing):
    """Mean-aggregation message-passing layer.

    For every node i with incoming neighbours N(i) (edges j -> i):

        x_i' = W_i x_i + mean_{j in N(i)} ReLU(W_j x_j)

    Arguments
    ---------
    d : int
        Node-feature dimension (input and output).
    """

    def __init__(self, d):
        super(GNN, self).__init__(aggr='mean', flow='source_to_target', node_dim=-2)
        # Creation order (W_j before W_i) is kept so seeded initialisation matches.
        self.W_j = nn.Linear(d, d)  # applied to neighbour features
        self.W_i = nn.Linear(d, d)  # applied to the node's own features
        self.act = nn.ReLU()

    def message(self, x_j):
        """Per-edge message ReLU(W_j x_j).

        x_j : torch.Tensor [E, d]  ->  m_j : torch.Tensor [E, d]
        """
        return self.act(self.W_j(x_j))

    def update(self, aggr_out, x):
        """Add the node's own transformed features to the averaged messages.

        aggr_out : torch.Tensor [N, d]
        x : torch.Tensor [N, d]
        returns x_k : torch.Tensor [N, d]
        """
        return aggr_out + self.W_i(x)

    def forward(self, edge_index, x):
        """One layer of message passing.

        edge_index : torch.Tensor [2, E]
        x : torch.Tensor [N, d]
        returns x_k : torch.Tensor [N, d]
        """
        return self.propagate(edge_index, x=x)

In [527]:
# New 3-node example; cast to float so it can go through the nn.Linear layers.
# NOTE(review): this rebinds `x`, replacing the integer tensor from the earlier cell.
x = torch.tensor([[10, 20, 30]]).T.float()
x

tensor([[10.],
        [20.],
        [30.]])

In [528]:
# Directed edges 0->1, 0->2, 2->1 (row 0 = source, row 1 = target).
edge_index = torch.tensor([[0,0,2],[1,2,1]])
edge_index

tensor([[0, 0, 2],
        [1, 2, 1]])

In [529]:
# Forward pass of a freshly initialised GNN with d=1. The weights are random and
# no seed is set, so the values below differ on every run.
# TODO: call torch.manual_seed(...) first for reproducible output.
GNN(1)(edge_index, x)

tensor([[ 8.0344],
        [18.6940],
        [25.1396]], grad_fn=<AddBackward0>)

In [664]:
# NOTE(review): this cell (In [664]) ran before the `GCN` class cell below
# (In [665]); on a fresh kernel it raises NameError. Move it after the class.
GCN(1)(edge_index, x).shape

torch.Size([3, 1])

In [665]:
class GCN(tg.nn.MessagePassing):
    """Graph convolution with symmetric degree normalisation.

    With h = W(x) and self-loops added to the graph:

        x_i' = sum_{j -> i} h_j / sqrt(deg(j) * deg(i))

    where deg is the in-degree after self-loops are added.

    Arguments
    ---------
    d : int
        Node-feature dimension (input and output).
    """

    def __init__(self, d):
        super().__init__(aggr = 'add', flow = 'source_to_target')
        self.W = nn.Linear(d, d)
        # NOTE(review): kept for backward compatibility but not applied in
        # `update` (the original computed it there and discarded the result).
        self.act = nn.Softmax(-1)

    def normalize_edges(self, edge_index, num_nodes):
        """Per-edge normalisation 1 / sqrt(deg(source) * deg(target)).

        Arguments
        ---------
        edge_index : torch.tensor [2, E]
        num_nodes : int

        Returns
        -------
        norm : torch.tensor [E, 1]
        """
        d_inv_sqrt = in_degree(edge_index, num_nodes = num_nodes).pow(-0.5)
        # Nodes with zero in-degree would give inf; zero them instead. This cannot
        # happen after add_self_loops, but protects direct callers.
        d_inv_sqrt = d_inv_sqrt.masked_fill(torch.isinf(d_inv_sqrt), 0.)
        out_nodes, in_nodes = edge_index
        norm = d_inv_sqrt[out_nodes] * d_inv_sqrt[in_nodes]
        return norm.view(-1, 1)

    def message(self, x_j, norm):
        """Scale each source-node feature by its edge normalisation.

        Arguments
        ---------
        x_j : torch.tensor [E, d]
        norm : torch.tensor [E, 1]
            Precomputed in `forward` and passed through `propagate`.

        Returns
        -------
        m_j : torch.tensor [E, d]
        """
        return norm * x_j

    def update(self, aggr_out):
        """Return the summed, normalised messages unchanged.

        Arguments
        ---------
        aggr_out : torch.tensor [N, d]

        Returns
        -------
        x_k : torch.tensor [N, d]
        """
        return aggr_out

    def forward(self, edge_index, x):
        """One graph-convolution layer.

        Arguments
        ---------
        edge_index : torch.tensor [2, E]
        x : torch.tensor [N, d]

        Returns
        -------
        x_k : torch.tensor [N, d]
        """
        x = self.W(x)
        num_nodes = x.shape[-2]
        # Self-loops let each node keep its own features and ensure in-degree >= 1.
        edge_index, _ = add_self_loops(edge_index, num_nodes = num_nodes)
        # The original passed an undefined name `norm` here; compute it once per call.
        norm = self.normalize_edges(edge_index, num_nodes)
        return self.propagate(edge_index, x = x, norm = norm)

In [807]:
# Display the current `edge_index` (defined in an earlier cell).
edge_index

tensor([[0, 0, 2],
        [1, 2, 1]])

In [862]:
class SingleHeadAttention(tg.nn.MessagePassing):
    """One graph-attention head.

    For every edge j -> i (self-loops included):

        e_ij = LeakyReLU(a([W x_i || W x_j]))
        a_ij = softmax of e_ij over all edges entering node i
        h_i  = Softmax over nodes (dim -2) of sum_j a_ij * W x_j

    Arguments
    ---------
    d_x : int
        Input feature dimension.
    d_h : int
        Output (hidden) feature dimension.
    """

    def __init__(self, d_x, d_h):
        super().__init__()

        self.W = nn.Linear(d_x, d_h)
        self.a = nn.Linear(2*d_h, 1)

    def attention(self, x_i, x_j):
        """Unnormalised attention score for every edge.

        Arguments
        ---------
        x_i : torch.tensor [E, d_x]
        x_j : torch.tensor [E, d_x]

        Returns
        -------
        e_ij : torch.tensor [E, 1]
        """
        scores = self.a(torch.cat([self.W(x_i), self.W(x_j)], dim = -1))
        return nn.LeakyReLU()(scores)

    def alpha(self, edge_index, x_i, x_j, num_nodes=None):
        """Attention coefficients normalised over each target node's incoming edges.

        Arguments
        ---------
        edge_index : torch.tensor [2, E]
            Row 0 = neighbour j, row 1 = receiving node i.
        x_i : torch.tensor [E, d_x]
        x_j : torch.tensor [E, d_x]
        num_nodes : int, optional
            Number of nodes; inferred from ``edge_index`` when omitted.

        Returns
        -------
        a_ij : torch.tensor [E, 1]
        """
        from torch_geometric.utils import softmax as segment_softmax

        _, node = edge_index
        e_ij = self.attention(x_i, x_j)
        # Grouped softmax over edges sharing a target node. This replaces a Python
        # loop that read the *global* notebook variable `x` to decide how many
        # nodes to visit, which was wrong for any other graph size.
        return segment_softmax(e_ij, node, num_nodes=num_nodes)

    def message(self, edge_index, x_i, x_j, size_i):
        """Attention-weighted transformed neighbour features.

        Arguments
        ---------
        edge_index : torch.tensor [2, E]
        x_i : torch.tensor [E, d_x]
        x_j : torch.tensor [E, d_x]
        size_i : int
            Number of receiving nodes, supplied by `propagate`.

        Returns
        -------
        m_j : torch.tensor [E, d_h]
        """
        a_ij = self.alpha(edge_index, x_i, x_j, num_nodes=size_i)
        return a_ij * self.W(x_j)

    def update(self, aggr_out):
        """
        Arguments
        ---------
        aggr_out : torch.tensor [N, d_h]

        Returns
        -------
        h : torch.tensor [N, d_h]
        """
        # NOTE(review): Softmax(-2) normalises each feature across *all nodes*
        # (and across graphs when x comes from a Batch), not per node. GAT layers
        # typically apply an element-wise nonlinearity here; confirm this is intended.
        return nn.Softmax(-2)(aggr_out)

    def forward(self, edge_index, x):
        """
        Arguments
        ---------
        edge_index : torch.tensor [2, E]
        x : torch.tensor [N, d_x]

        Returns
        -------
        h : torch.tensor [N, d_h]
        """
        # Self-loops let each node attend to itself.
        edge_index, _ = add_self_loops(edge_index, num_nodes = x.shape[-2])
        return self.propagate(edge_index = edge_index, x = x)

In [855]:
class MultiHeadAttention(nn.Module):
    """K independent attention heads, concatenated and projected back to d_h.

    Arguments
    ---------
    K : int
        Number of heads.
    d_x : int
        Input feature dimension.
    d_h : int
        Per-head and final output dimension.
    """

    def __init__(self, K, d_x, d_h):
        super(MultiHeadAttention, self).__init__()
        head_list = [SingleHeadAttention(d_x, d_h) for _ in range(K)]
        self.heads = nn.ModuleList(head_list)
        # Maps the concatenated [N, K*d_h] head outputs back to [N, d_h].
        self.W = nn.Linear(K * d_h, d_h)

    def forward(self, edge_index, x):
        """
        Arguments
        ---------
        edge_index : torch.tensor [2, E]
        x : torch.tensor [N, d_x]

        Returns
        -------
        h : torch.tensor [N, d_h]
        """
        outputs = tuple(head(edge_index, x) for head in self.heads)
        concatenated = torch.cat(outputs, dim=-1)
        return self.W(concatenated)

In [776]:
def undirected_message(self, edge_index, x_i, x_j, num_nodes=None):
    """Loop-based reference implementation of attention-weighted aggregation.

    Meant to be used like a method: ``self`` must provide
    ``attention(x_i, x_j) -> [E, 1]`` scores and a feature transform ``W``.
    For every receiving node i, the scores of the edges entering i are
    softmax-normalised and used to weight the transformed source features.

    Unlike ``MessagePassing.message``, this returns the already-aggregated
    per-node result ([N, d_h]), not one message per edge.

    NOTE(review): the original rationale said ``MessagePassing`` cannot be used
    for undirected graphs because attention weights differ per node. That appears
    not to hold: an undirected edge stored as j->i and i->j gets a separate score
    per direction, each normalised over its own target's incoming edges, which a
    per-target softmax inside ``message`` already does. Confirm before relying on it.

    Arguments
    ---------
    edge_index : torch.Tensor [2, E]
        Row 0 = neighbour (source) j, row 1 = receiving node i.
    x_i : torch.Tensor [E, d_x]
        Features of each edge's receiving node.
    x_j : torch.Tensor [E, d_x]
        Features of each edge's neighbour node.
    num_nodes : int, optional
        Number of output rows; defaults to ``edge_index[1].max() + 1``
        (0 when there are no edges).

    Returns
    -------
    m : torch.Tensor [num_nodes, d_h]
        Aggregated output per node; nodes without incoming edges get zeros.
    """
    _, node = edge_index
    e_ij = self.attention(x_i, x_j)
    # Use the provided x_j; the original re-read neighbours from the global
    # notebook variable `x` and overwrote this parameter.
    h_j = self.W(x_j)

    if num_nodes is None:
        num_nodes = int(node.max()) + 1 if node.numel() > 0 else 0

    m = h_j.new_zeros((num_nodes,) + tuple(h_j.shape[1:]))
    # Only nodes that actually receive edges need work; the rest stay zero.
    for i in torch.unique(node).tolist():
        i_edges = node == i
        a_ij = torch.softmax(e_ij[i_edges], dim=-2)
        m[i] = (a_ij * h_j[i_edges]).sum(0)
    return m