In [12]:
import numpy as np
from torch.nn import LSTM
import torch.nn as nn
import torch
import torchvision
from torch.optim import Adam
from sklearn.metrics import *
from torch_geometric.nn import MetaLayer
from torch.nn import Sequential as Seq, Linear as Lin, ReLU
from torch_scatter import *
# device = 0
import argparse

In [122]:
class MyLayer(torch.nn.Module):
    """One graph-network block (edge -> node -> global update) built on ``MetaLayer``.

    Every sub-model is a two-layer MLP (Linear -> ReLU -> Linear). The default
    widths match the toy graph used in this notebook (2-d node features, 3-d
    edge features, 1-d global features in; 1-d edges, 2-d nodes, 2-d globals
    out), so ``MyLayer()`` builds exactly the same architecture as before.
    Pass other widths to reuse the layer on different data.

    Args:
        node_dim: width F_x of the input node features ``x``.
        edge_dim: width F_e of the input edge features ``edge_attr``.
        global_dim: width F_u of the input global features ``u``.
        hidden_dim: hidden width of every MLP.
        edge_out: width of the updated edge features.
        node_out: width of the updated node features; also the width of the
            per-edge messages averaged in the node update.
        global_out: width of the updated global features.
    """

    def __init__(self, node_dim=2, edge_dim=3, global_dim=1, hidden_dim=20,
                 edge_out=1, node_out=2, global_out=2):
        super(MyLayer, self).__init__()

        # Edge update input: [src, dest, edge_attr, u of the edge's graph].
        self.edge_mlp = Seq(Lin(2 * node_dim + edge_dim + global_dim, hidden_dim),
                            ReLU(), Lin(hidden_dim, edge_out))
        # Per-edge message input: [x[col], updated edge_attr].
        self.node_mlp_1 = Seq(Lin(node_dim + edge_out, hidden_dim),
                              ReLU(), Lin(hidden_dim, node_out))
        # Node update input: [mean of messages, u of the node's graph].
        self.node_mlp_2 = Seq(Lin(node_out + global_dim, hidden_dim),
                              ReLU(), Lin(hidden_dim, node_out))
        # Global update input: [u, per-graph mean of the updated node features].
        self.global_mlp = Seq(Lin(global_dim + node_out, hidden_dim),
                              ReLU(), Lin(hidden_dim, global_out))

        def edge_model(src, dest, edge_attr, u, batch):
            """Update edge features from both endpoints, the edge and its graph's u.

            src, dest: [E, F_x], where E is the number of edges.
            edge_attr: [E, F_e]
            u: [B, F_u], where B is the number of graphs.
            batch: [E] graph id of each edge, max entry B - 1.
            """
            out = torch.cat([src, dest, edge_attr, u[batch]], 1)
            return self.edge_mlp(out)

        def node_model(x, edge_index, edge_attr, u, batch):
            """Update node features from aggregated messages and the graph's u.

            x: [N, F_x], where N is the number of nodes.
            edge_index: [2, E] with max entry N - 1.
            edge_attr: [E, edge_out] (the already-updated edge features).
            u: [B, F_u]
            batch: [N] with max entry B - 1.
            """
            row, col = edge_index
            # One message per edge, built from the node in edge_index[1] ...
            out = torch.cat([x[col], edge_attr], dim=1)
            out = self.node_mlp_1(out)
            # ... averaged onto the node in edge_index[0]; dim_size keeps one
            # row per node even for nodes that never appear in `row`.
            out = scatter_mean(out, row, dim=0, dim_size=x.size(0))
            out = torch.cat([out, u[batch]], dim=1)
            return self.node_mlp_2(out)

        def global_model(x, edge_index, edge_attr, u, batch):
            """Update global features from u and the per-graph mean of x.

            x: [N, node_out] node features.
            edge_index: [2, E] with max entry N - 1.
            edge_attr: [E, edge_out]
            u: [B, F_u]
            batch: [N] with max entry B - 1.
            """
            # Without dim_size the pooled tensor only has max(batch) + 1 rows,
            # so a trailing graph with no nodes would make it shorter than u
            # and the concatenation below would fail.
            pooled = scatter_mean(x, batch, dim=0, dim_size=u.size(0))
            out = torch.cat([u, pooled], dim=1)
            return self.global_mlp(out)

        self.op = MetaLayer(edge_model, node_model, global_model)

    def forward(self, x, edge_index, edge_attr, u, batch):
        """Apply the MetaLayer (edge, node and global updates).

        Returns the MetaLayer output: a tuple of updated node features,
        edge features and global features.
        """
        return self.op(x, edge_index, edge_attr, u, batch)

In [199]:
## Input format expected by MyLayer
# Three graphs batched together: graph 0 = nodes 0-1, graph 1 = nodes 2-4,
# graph 2 = nodes 5-9 (2 + 3 + 5 = 10 nodes).
x = torch.rand(10, 2)                                   # node features [N=10, F_x=2]
edge_index = torch.tensor([[0, 2, 3, 6, 6, 7, 7],
                           [1, 3, 4, 5, 7, 8, 9]], dtype=torch.long)  # [2, E=7]
row, col = edge_index
edge_attr = torch.rand(7, 3)                            # edge features [E=7, F_e=3]
u = torch.tensor([[1.], [2.], [3.]])                    # global features [B=3, F_u=1]
batch = torch.tensor([0, 0, 1, 1, 1, 2, 2, 2, 2, 2], dtype=torch.long)  # graph id per node

# Sanity-check the hand-built batch before feeding it to the layer.
assert batch.size(0) == x.size(0), "one graph id per node"
assert edge_index.size(1) == edge_attr.size(0), "one feature row per edge"
assert int(edge_index.max()) < x.size(0), "edge endpoints must be valid node ids"
assert int(batch.max()) < u.size(0), "one global feature row per graph"
assert bool((batch[row] == batch[col]).all()), "edges must stay within one graph"
batch.size()

torch.Size([10])

In [124]:
# Inspect the random node features x (10 nodes x 2 features).
x

tensor([[0.8561, 0.2686],
        [0.6355, 0.9120],
        [0.9857, 0.5438],
        [0.5853, 0.1554],
        [0.4832, 0.5185],
        [0.0823, 0.0503],
        [0.1333, 0.3632],
        [0.7997, 0.7625],
        [0.3072, 0.1134],
        [0.5853, 0.3785]])

In [125]:
# Build the layer; its MLP input widths are sized for the toy graph above
# (2-d node, 3-d edge and 1-d global features).
layer = MyLayer()

In [126]:
# Run one edge -> node -> global update on the toy batch; the result is the
# tuple of updated node, edge and global features.
layer(x=x, edge_index=edge_index, edge_attr=edge_attr, u=u, batch=batch)

batch:  tensor([0, 1, 1, 2, 2, 2, 2]) src:  tensor([[0.8561, 0.2686],
        [0.9857, 0.5438],
        [0.5853, 0.1554],
        [0.1333, 0.3632],
        [0.1333, 0.3632],
        [0.7997, 0.7625],
        [0.7997, 0.7625]]) dest:  tensor([[0.6355, 0.9120],
        [0.5853, 0.1554],
        [0.4832, 0.5185],
        [0.0823, 0.0503],
        [0.7997, 0.7625],
        [0.3072, 0.1134],
        [0.5853, 0.3785]]) batch:  tensor([0, 1, 1, 2, 2, 2, 2])
x_columns :  tensor([[0.6355, 0.9120],
        [0.5853, 0.1554],
        [0.4832, 0.5185],
        [0.0823, 0.0503],
        [0.7997, 0.7625],
        [0.3072, 0.1134],
        [0.5853, 0.3785]])
out.shape :  torch.Size([7, 3]) edge_attr:  torch.Size([7, 1]) x[col].shape:  torch.Size([7, 2])
torch.Size([3, 3])


(tensor([[-0.1519, -0.3969],
         [-0.1384, -0.3714],
         [-0.1712, -0.5107],
         [-0.1498, -0.4985],
         [-0.1163, -0.4825],
         [-0.1018, -0.5654],
         [-0.1397, -0.5849],
         [-0.1506, -0.5916],
         [-0.1018, -0.5654],
         [-0.1018, -0.5654]], grad_fn=<AddmmBackward>), tensor([[ 0.0497],
         [-0.0121],
         [ 0.0072],
         [-0.0170],
         [ 0.0406],
         [-0.0662],
         [-0.0637]], grad_fn=<AddmmBackward>), tensor([[ 0.0187,  0.2285],
         [-0.0181,  0.3353],
         [-0.0488,  0.4398]], grad_fn=<AddmmBackward>))

In [7]:
# Ten random 2-d rows to feed the scatter_max demo.
arr = torch.rand((10, 2))
arr

tensor([[0.6307, 0.8820],
        [0.7498, 0.5209],
        [0.4813, 0.4273],
        [0.0594, 0.1129],
        [0.2499, 0.0397],
        [0.1566, 0.1587],
        [0.1767, 0.1347],
        [0.4546, 0.3925],
        [0.0209, 0.0132],
        [0.6159, 0.7637]])

In [15]:
# Column-wise max over consecutive pairs of rows of `arr` (the index groups
# rows 0-1, 2-3, 4-5, 6-7, 8-9). dim_size=10 requests 10 output rows although
# only 5 groups exist, so rows 5-9 come back empty (value 0, argmax -1 in the
# output below).
scatter_max(arr,torch.LongTensor([0,0,1,1,2,2,3,3,4,4]),dim = 0,dim_size = 10)

(tensor([[0.7498, 0.8820],
         [0.4813, 0.4273],
         [0.2499, 0.1587],
         [0.4546, 0.3925],
         [0.6159, 0.7637],
         [0.0000, 0.0000],
         [0.0000, 0.0000],
         [0.0000, 0.0000],
         [0.0000, 0.0000],
         [0.0000, 0.0000]]), tensor([[ 1,  0],
         [ 2,  2],
         [ 4,  5],
         [ 7,  7],
         [ 9,  9],
         [-1, -1],
         [-1, -1],
         [-1, -1],
         [-1, -1],
         [-1, -1]]))

In [41]:

# A plain tensor can be passed straight to an nn.Module: torch.autograd.Variable
# is a deprecated no-op wrapper since PyTorch 0.4. The batch is named `inputs`
# so it does not shadow the builtin `input`.
m = nn.Linear(20, 30)
inputs = torch.randn(128, 20)   # batch of 128 samples with 20 features each
output = m(inputs)              # -> [128, 30]
output.size()

torch.Size([128, 30])


In [128]:
# The values 1..8 laid out as a 2 x 2 x 2 float tensor.
t = torch.arange(1., 9.).view(2, 2, 2)

In [150]:
a = list()  # empty list that will hold 1 x 2 row tensors

In [151]:
# Assign the whole list at once so re-running this cell is idempotent:
# appending would grow `a` on every run and break the view(1, 2, 2) below.
a = [torch.Tensor([[1, 2]]), torch.Tensor([[3, 4]])]

In [164]:
# Stack the rows of `a` into a [1, 2, 2] tensor and keep only the first row.
torch.cat(a).view(1, 2, 2)[:, 0]

tensor([[1., 2.]])

In [169]:
arr = torch.zeros(2, 3)  # 2 x 3 float tensor of zeros (rebinds `arr`)

In [171]:
# FIXME(review): `src` and `dist` are not defined in any visible cell, so this
# raises NameError on a fresh kernel (the output below is stale: it is a 10 x 2
# float tensor, not an edge index). If it does run, it also overwrites the
# `edge_index` used with MyLayer above.
edge_index = torch.LongTensor([src,dist])

tensor([[1.0000, 0.8820],
        [0.7498, 0.5209],
        [0.4813, 0.4273],
        [0.0594, 0.1129],
        [0.2499, 0.0397],
        [0.1566, 0.1587],
        [0.1767, 0.1347],
        [0.4546, 0.3925],
        [0.0209, 0.0132],
        [0.6159, 0.7637]])

In [191]:
# NOTE: this rebinds `x`, replacing the node features used with MyLayer above.
x = torch.Tensor([[1], [2], [3]])
print(x.size())
# expand() gives a broadcast view of shape [10, 3, 1] without copying; the
# reshape(-1, 1) then returns 30 rows: [1, 2, 3] repeated 10 times.
x.expand(10, 3, 1).reshape(-1, 1).shape

torch.Size([3, 1])


torch.Size([30, 1])

In [204]:
# Fall back to CPU so the notebook still runs on machines without a GPU
# (a bare .cuda(0) raises there). NOTE: rebinds `batch` from the input cell.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
batch = torch.zeros(10, dtype=torch.long, device=device)  # all 10 nodes in graph 0

In [205]:
batch.add(1)  # new tensor with every graph id shifted by one; `batch` unchanged

tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], device='cuda:0')