In [1]:
import torch
import torch.nn as nn
from torch_geometric.nn import MessagePassing
from typing import List

In [14]:
class MLP(nn.Module):

    def __init__(self,
                 input_dim: int,
                 output_dim: int,
                 num_neurons: List[int] = [64, 32],
                 hidden_act: str = 'ReLU',
                 out_act: str = 'Identity',
                 input_norm: str = 'None',
                 output_norm: str = 'None'):
        super(MLP, self).__init__()

        assert input_norm in ['Batch', 'Layer', 'None']
        assert output_norm in ['Batch', 'Layer', 'None']

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_neurons = num_neurons
        self.hidden_act = getattr(nn, hidden_act)()
        self.out_act = getattr(nn, out_act)()

        input_dims = [input_dim] + num_neurons
        output_dims = num_neurons + [output_dim]

        self.lins = nn.ModuleList()
        for i, (in_dim, out_dim) in enumerate(zip(input_dims, output_dims)):
            self.lins.append(nn.Linear(in_dim, out_dim))

        self.input_norm = self._get_norm_layer(input_norm, input_dim)
        self.output_norm = self._get_norm_layer(output_norm, output_dim)

    def forward(self, xs):
        xs = self.input_norm(xs)
        for i, lin in enumerate(self.lins[:-1]):
            xs = lin(xs)
            xs = self.hidden_act(xs)
        xs = self.lins[-1](xs)
        xs = self.out_act(xs)
        xs = self.output_norm(xs)
        return xs

    @staticmethod
    def _get_norm_layer(norm_method, dim):
        if norm_method == 'Batch':
            in_norm = nn.BatchNorm1d(dim)
        elif norm_method == 'Layer':
            in_norm = nn.LayerNorm(dim)
        elif norm_method == 'None':
            in_norm = nn.Identity()  # kinda placeholder
        else:
            raise RuntimeError("Not implemented normalization layer type {}".format(norm_method))
        return in_norm

    def _get_act(self, is_last):
        return self.out_act if is_last else self.hidden_act

In [15]:
class INLayer(MessagePassing):

    def __init__(self,
                 edge_indim: int,
                 edge_outdim: int,
                 node_indim: int,
                 node_outdim: int,
                 node_aggregator: str = 'add',
                 residual: bool = True,
                 **mlp_params):
        super(INLayer, self).__init__(aggr=node_aggregator)

        self.edge_model = MLP(input_dim=edge_indim + 2 * node_indim,
                              output_dim=edge_outdim,
                              **mlp_params)
        self.node_model = MLP(input_dim=edge_outdim + node_indim,
                              output_dim=node_outdim,
                              **mlp_params)

        self.residual = residual

    def forward(self,
                nf: torch.tensor,
                ef: torch.tensor,
                edge_idx: torch.tensor):
        nf_residual, ef_residual = nf, ef
        uef = self.edge_update(nf, ef, edge_idx)
        unf = self.propagate(edge_index=edge_idx, x=nf, edge_features=uef)
        if self.residual:
            unf, uef = unf + nf_residual, uef + ef_residual
        return unf, uef

    def edge_update(self, nf, ef, edge_index):
        row, col = edge_index
        x_i, x_j = nf[row], nf[col]
        uef = self.edge_model(torch.cat([x_i, x_j, ef], dim=-1))
        return uef

    def message(self, edge_features: torch.tensor):
        return edge_features

    def update(self,
               aggr_msg: torch.tensor,
               x: torch.tensor):
        unf = self.node_model(torch.cat([x, aggr_msg], dim=-1))
        return unf

In [16]:
edge_index = torch.tensor([[0, 1, 1, 2],
                            [1, 0, 2, 1]], dtype=torch.long)

n_nodes, n_edges = 3, 4

nf = torch.randn(n_nodes, 5)
ef = torch.randn(n_edges, 7)

layer = INLayer(7, 7, 5, 5)
unf, uef = layer(nf, ef, edge_index)
print(nf, ef)
print(unf, uef)

tensor([[ 0.8237,  1.2862,  0.2866,  0.5154, -1.2957],
        [ 1.3282, -0.0202, -1.6885, -0.0882,  1.2830],
        [-0.1396,  2.2594, -1.4746,  1.2366,  1.2765]]) tensor([[ 1.0289, -0.6426,  0.2979, -0.6827, -0.1849,  2.2004, -0.5025],
        [ 0.6252, -0.2165,  1.1382, -1.0760, -0.4326, -2.5266,  0.4149],
        [-1.1183, -0.4361, -0.9389, -0.8873, -0.9246,  0.1235, -0.3848],
        [-0.7542, -0.9807, -1.0409, -0.1864, -1.1971,  0.6327, -1.5552]])
tensor([[ 0.9691,  1.2410,  0.1860,  0.6028, -1.4554],
        [ 1.4674, -0.0973, -1.8262,  0.0207,  1.0966],
        [ 0.0510,  2.2055, -1.7267,  1.4641,  1.0188]], grad_fn=<AddBackward0>) tensor([[ 1.3906, -0.5732,  0.1154, -0.7796, -0.0753,  2.1897, -0.6814],
        [ 0.7467, -0.1715,  1.1247, -0.9585, -0.3087, -2.5348,  0.2288],
        [-0.9195, -0.4060, -1.0766, -0.9250, -0.8202,  0.0387, -0.5024],
        [-0.6658, -0.8893, -1.0313, -0.1558, -0.9605,  0.6137, -1.6716]],
       grad_fn=<AddBackward0>)


In [17]:
from torch_geometric.data import DataLoader
from torch_geometric.data import Data
nf = torch.randn(1024, n_nodes, 5)
ef = torch.randn(1024, n_edges, 7)
temp = DataLoader([Data(x=x, edge_index=edge_index, edge_attr=edge_attr, num_nodes=n_nodes) for x, edge_attr in zip(nf, ef)], batch_size=32)
data = next(iter(temp))
print(data)

DataBatch(x=[96, 5], edge_index=[2, 128], edge_attr=[128, 7], num_nodes=96, batch=[96], ptr=[33])




In [26]:
from torch_geometric.data import Batch
nf = torch.randn(32, 20, 2)
ef = torch.randn(32, 380, 1)
data_list = [Data(x=x, edge_index=edge_index, edge_attr=edge_attr) for x, edge_attr in zip(nf, ef)]
data_batch = Batch.from_data_list(data_list)
print(data_batch)
print(data_batch.x.shape)
print(data_batch.edge_index.shape)
print(data_batch.edge_attr.shape)

DataBatch(x=[640, 2], edge_index=[2, 128], edge_attr=[12160, 1], batch=[640], ptr=[33])
torch.Size([640, 2])
torch.Size([2, 128])
torch.Size([12160, 1])


In [28]:
print(edge_index.size())

torch.Size([2, 4])


In [27]:
layer = INLayer(7, 7, 5, 5)
unf, uef = layer(data.x, data.edge_attr, data.edge_index)
print(unf.size(), uef.size())

torch.Size([96, 5]) torch.Size([128, 7])


In [8]:
adj_matrix = torch.ones(20, 20)
adj_matrix.fill_diagonal_(0) # No self-loops
edge_index = torch.nonzero(adj_matrix)
print(edge_index.size())

torch.Size([380, 2])


In [11]:
node_feature = torch.randn((32, 20, 2))
edge_feature = torch.norm(node_feature[:, edge_index[:, 0], :] - node_feature[:, edge_index[:, 1], :], dim=-1, keepdim=True)
print(edge_feature.size())

torch.Size([32, 380])


In [6]:
print(torch.norm(torch.ones((1, 2)) - torch.zeros((1, 2)), p=2, dim=-1, keepdim=True))

tensor([[1.4142]])


In [35]:
a = torch.randn((32, 32, 20, 2))
print(a.size()[:2])
batch_size = a.size()[:-2]
shape_size = a.size()[-2:]
b = torch.rand((32 * 32, 20, 2))
print(b.view(*batch_size, *shape_size).size())

torch.Size([32, 32])
torch.Size([32, 32, 20, 2])
