In [2]:
import torch.nn as nn
import torch
import torch.functional as F


class MLP(nn.Module):
    """MLP with linear output"""
    def __init__(self, num_layers, input_dim, hidden_dim, output_dim):

        super().__init__()
        self.linear_or_not = True  # default is linear model
        self.num_layers = num_layers
        self.output_dim = output_dim
        self.input_dim = input_dim

        if num_layers < 1:
            raise ValueError("number of layers should be positive!")
        elif num_layers == 1:
            # Linear model
            self.linear = nn.Linear(input_dim, output_dim)
        else:
            # Multi-layer model
            self.linear_or_not = False
            self.linears = torch.nn.ModuleList()
            self.batch_norms = torch.nn.ModuleList()

            self.linears.append(nn.Linear(input_dim, hidden_dim))
            for layer in range(num_layers - 2):
                self.linears.append(nn.Linear(hidden_dim, hidden_dim))
            self.linears.append(nn.Linear(hidden_dim, output_dim))

            for layer in range(num_layers - 1):
                self.batch_norms.append(nn.BatchNorm1d((hidden_dim)))

    def forward(self, x):
        if self.linear_or_not:
            # If linear model
            return self.linear(x)
        else:
            # If MLP
            h = x
            for i in range(self.num_layers - 1):
                h = F.relu(self.batch_norms[i](self.linears[i](h)))
            return self.linears[-1](h)


In [12]:
import torch_geometric.nn as pyg_nn

class GINEConvESLapPE(pyg_nn.conv.MessagePassing):
    """GINEConv Layer with EquivStableLapPE implementation.

    Modified torch_geometric.nn.conv.GINEConv layer to perform message scaling
    according to equiv. stable PEG-layer with Laplacian Eigenmap (LapPE):
        ICLR 2022 https://openreview.net/pdf?id=e95i1IHcWj
    """
    def __init__(self, nn, eps=0., train_eps=False, edge_dim=None, **kwargs):
        kwargs.setdefault('aggr', 'add')
        super().__init__(**kwargs)
        self.nn = nn
        self.initial_eps = eps
        if train_eps:
            self.eps = torch.nn.Parameter(torch.Tensor([eps]))
        else:
            self.register_buffer('eps', torch.Tensor([eps]))
        if edge_dim is not None:
            if hasattr(self.nn[0], 'in_features'):
                in_channels = self.nn[0].in_features
            else:
                in_channels = self.nn[0].in_channels
            self.lin = pyg_nn.Linear(edge_dim, in_channels)
        else:
            self.lin = None

        if hasattr(self.nn[0], 'in_features'):
            out_dim = self.nn[0].out_features
        else:
            out_dim = self.nn[0].out_channels
        print("out dimension: ", out_dim)
        # Handling for Equivariant and Stable PE using LapPE
        # ICLR 2022 https://openreview.net/pdf?id=e95i1IHcWj
        self.mlp_r_ij = torch.nn.Sequential(
            torch.nn.Linear(1, out_dim), torch.nn.ReLU(),
            torch.nn.Linear(out_dim, 1),
            torch.nn.Sigmoid())
        self.reset_parameters()

    def reset_parameters(self):
        pyg_nn.inits.reset(self.nn)
        self.eps.data.fill_(self.initial_eps)
        if self.lin is not None:
            self.lin.reset_parameters()
        pyg_nn.inits.reset(self.mlp_r_ij)

    def forward(self, x, edge_index, edge_attr=None, pe_LapPE=None, size=None):
        # if isinstance(x, Tensor):
        #     x: OptPairTensor = (x, x)

        # propagate_type: (x: OptPairTensor, edge_attr: OptTensor)
        out = self.propagate(edge_index, x=x, edge_attr=edge_attr,
                             PE=pe_LapPE, size=size)

        x_r = x[1]
        if x_r is not None:
            out += (1 + self.eps) * x_r

        return self.nn(out)

    def message(self, x_j, edge_attr, PE_i, PE_j):
        if self.lin is None and x_j.size(-1) != edge_attr.size(-1):
            raise ValueError("Node and edge feature dimensionalities do not "
                             "match. Consider setting the 'edge_dim' "
                             "attribute of 'GINEConv'")

        if self.lin is not None:
            edge_attr = self.lin(edge_attr)

        # Handling for Equivariant and Stable PE using LapPE
        # ICLR 2022 https://openreview.net/pdf?id=e95i1IHcWj
        r_ij = ((PE_i - PE_j) ** 2).sum(dim=-1, keepdim=True)
        r_ij = self.mlp_r_ij(r_ij)  # the MLP is 1 dim --> hidden_dim --> 1 dim

        return ((x_j + edge_attr).relu()) * r_ij

    def __repr__(self):
        return f'{self.__class__.__name__}(nn={self.nn})'

In [13]:
from torch_geometric.nn import Linear as Linear_pyg
import torch.nn.functional as F
hidden_dim = 64
gineconvlayer = GINEConvESLapPE(nn.Sequential(Linear_pyg(hidden_dim, hidden_dim),
                                   nn.ReLU(), 
                                   Linear_pyg(hidden_dim, hidden_dim)))
total_params = sum(p.numel() for p in gineconvlayer.parameters())
print(f"Number of parameters: {total_params}")

out dimension:  64
Number of parameters: 8513


In [15]:
from torch_geometric.nn import Linear as Linear_pyg
import torch.nn.functional as F
class VGNLayer(nn.Module):
    def __init__(self, num_clusters, hidden_dim):
        super().__init__()
        self.num_clusters = num_clusters
        self.model = nn.ModuleList()
        for _ in range(num_clusters):
            gin_nn = nn.Sequential(Linear_pyg(hidden_dim, hidden_dim),
                                   nn.ReLU(), 
                                   Linear_pyg(hidden_dim, hidden_dim))
            self.model.append(GINEConvESLapPE(gin_nn))

    def forward(self, x, masks, edge_index, edge_attr=None, pe_LapPE=None, size=None):
        x_in = x
        for cluster in range(self.num_clusters): 
            x_tmp = x.clone()
            x = self.model[cluster](x, edge_index, edge_attr, pe_LapPE)
            x = torch.einsum('i,ij->ij', masks[cluster], x) + x_tmp
        
            if self.batch_norm:
                x = self.bn[cluster](x)
        
        x = F.relu(x)
        x = F.dropout(x, p=0.5, training=self.training)
        if self.residual:
            x = x_in + x
        
        return x


In [16]:
model = VGNLayer(3, 64)
total_params = sum(p.numel() for p in model.parameters())
print(f"Number of parameters: {total_params}")


out dimension:  64
out dimension:  64
out dimension:  64
Number of parameters: 25539


In [1]:
import torch
import torch.nn.functional as F

mlp_output = torch.randn(20, 3)
# raw_masks = F.softmax(mlp_output, dim=-1)
print("MLP output: ", mlp_output)
# print("MLP square: ", torch.abs(mlp_output))
raw_masks = torch.nn.functional.normalize(mlp_output, p=1, dim=1)
print("Raw masks: ", raw_masks)
print(torch.sum(raw_masks))

node_feature = torch.ones(20, 5)

masks = torch.transpose(raw_masks, 0, 1)
print("Masks: ", masks)
updated_feature = torch.einsum('i,ij->ij', masks[0], node_feature)
print("updated features: ", updated_feature)

MLP output:  tensor([[ 0.7243,  1.6338,  0.6300],
        [-0.3052, -0.4431, -0.1276],
        [-0.4478, -0.8402, -0.3304],
        [ 1.2718, -1.4993, -0.6587],
        [ 0.6263,  0.0116, -0.2667],
        [-0.6787, -0.1957,  0.9593],
        [-0.5144, -1.3066, -1.5799],
        [ 0.0691, -0.4715, -1.1053],
        [-0.7641, -2.1179, -0.1819],
        [ 1.5850,  0.8496,  1.8045],
        [ 1.2271, -1.3980,  1.3648],
        [-0.3439, -0.3490, -0.1845],
        [ 0.9314, -1.0427, -0.2660],
        [-1.6444, -0.7975, -0.5241],
        [-1.1393, -1.0675, -0.4113],
        [ 0.5542, -0.7812,  0.8467],
        [ 1.4799,  0.6455, -0.5353],
        [ 1.1460,  0.8029,  1.8106],
        [ 0.4462,  1.3766, -1.2718],
        [-0.3623,  1.0443, -1.1626]])
Raw masks:  tensor([[ 0.2424,  0.5468,  0.2108],
        [-0.3485, -0.5059, -0.1457],
        [-0.2767, -0.5192, -0.2042],
        [ 0.3708, -0.4371, -0.1920],
        [ 0.6924,  0.0128, -0.2949],
        [-0.3701, -0.1067,  0.5232],
        [-0.

In [15]:
a = torch.Tensor([1,2,3])
torch.nn.functional.normalize(a, dim=0)


tensor([0.2673, 0.5345, 0.8018])

In [6]:
import torch
a = torch.randn(5, 3)
for i in range(5):
    a[i][0] = 2
b = torch.nn.functional.normalize(a, dim=0)
print(a)
print(b)

tensor([[ 2.0000, -0.5828, -0.1285],
        [ 2.0000, -0.1852,  1.4881],
        [ 2.0000, -0.0696, -1.0252],
        [ 2.0000, -0.2703,  0.5161],
        [ 2.0000,  0.5234,  0.5290]])
tensor([[ 0.4472, -0.6841, -0.0657],
        [ 0.4472, -0.2173,  0.7606],
        [ 0.4472, -0.0817, -0.5240],
        [ 0.4472, -0.3173,  0.2638],
        [ 0.4472,  0.6143,  0.2704]])
