In [19]:
#from common import *
import torch_geometric.nn as gnn
#

class LinearBn(nn.Module):
    """Bias-free linear layer followed by 1-D batch normalisation and an optional activation.

    Args:
        in_channel: number of input features of the linear layer.
        out_channel: number of output features (also the batch-norm width).
        act: optional callable applied to the normalised output; skipped when ``None``.
    """

    def __init__(self, in_channel, out_channel, act=None):
        super().__init__()
        # The linear layer carries no bias (presumably because batch-norm follows it directly).
        self.linear = nn.Linear(in_channel, out_channel, bias=False)
        self.bn = nn.BatchNorm1d(out_channel, eps=1e-05, momentum=0.1)
        self.act = act

    def forward(self, x):
        """Run linear -> batch-norm -> activation, skipping any stage that is ``None``."""
        out = self.linear(x)
        for stage in (self.bn, self.act):
            if stage is not None:
                out = stage(out)
        return out


#message passing
class Net(torch.nn.Module):
    """Message-passing network that predicts one scalar per coupling (pair of nodes).

    Pipeline: node features are embedded by ``preprocess``; an edge-conditioned
    ``NNConv`` (mean aggregation) followed by a GRU update is applied
    ``num_message_passing`` times; a ``Set2Set`` readout gives a per-graph vector;
    for each coupling the graph vector is concatenated with the embeddings of its
    two nodes and passed to ``predict``, which emits ``num_target`` values from
    which the one matching the coupling's type is selected.

    Args:
        node_dim: number of raw features per node.
        edge_dim: number of raw features per edge.
        num_target: number of coupling types (width of the prediction head).
        node_hidden_dim: width of the node embedding (was hard-coded to 128).
        edge_hidden_dim: width of the last hidden layer of the edge network
            (was hard-coded to 128).
        num_message_passing: number of conv + GRU rounds (was hard-coded to 6).
        set2set_steps: ``processing_steps`` of the Set2Set readout (was hard-coded to 6).
    """

    def __init__(self, node_dim=13, edge_dim=5, num_target=8,
                 node_hidden_dim=128, edge_hidden_dim=128,
                 num_message_passing=6, set2set_steps=6):
        super(Net, self).__init__()

        self.num_message_passing = num_message_passing

        # Node embedding: node_dim -> 64 -> node_hidden_dim.
        self.preprocess = nn.Sequential(
            LinearBn(node_dim, 64),
            nn.ReLU(),
            LinearBn(64, node_hidden_dim),
        )
        # Edge network: maps each edge feature vector to a flattened
        # node_hidden_dim x node_hidden_dim weight matrix consumed by NNConv.
        edge_net = nn.Sequential(
            LinearBn(edge_dim, 32),
            nn.ReLU(),
            LinearBn(32, 64),
            nn.ReLU(),
            LinearBn(64, edge_hidden_dim),
            nn.ReLU(),
            LinearBn(edge_hidden_dim, node_hidden_dim * node_hidden_dim),
        )

        self.conv = gnn.NNConv(node_hidden_dim, node_hidden_dim, edge_net, aggr='mean', root_weight=True)
        self.gru  = nn.GRU(node_hidden_dim, node_hidden_dim)
        self.set2set = gnn.Set2Set(node_hidden_dim, processing_steps=set2set_steps)

        # Prediction head. Input width is 4 * node_hidden_dim:
        # 2 * node_hidden_dim from Set2Set plus the embeddings of the two coupled nodes.
        self.predict = nn.Sequential(
            LinearBn(4*node_hidden_dim, 512),
            nn.ReLU(),
            nn.Linear(512, num_target),
        )

    def forward(self, node, edge, edge_index, node_batch_index, coupling_index, coupling_type, coupling_batch_index):
        """Predict one value per coupling.

        Args:
            node: node feature matrix, one row per node.
            edge: edge feature matrix, one row per edge.
            edge_index: integer tensor with one (node, node) row per edge; it is
                transposed here before being handed to the conv layer.
            node_batch_index: graph id of every node, used by the Set2Set readout.
            coupling_index: integer tensor with one row of node indices per coupling
                (two indices per row, given the 4 * node_hidden_dim head input).
            coupling_type: integer type id per coupling; selects the output column.
            coupling_batch_index: graph id of every coupling; selects the pooled row.

        Returns:
            1-D tensor with one prediction per coupling.
        """
        # The conv layer receives the transposed (2 x num_edge) layout.
        edge_index = edge_index.t().contiguous()

        x = F.relu(self.preprocess(node))
        h = x.unsqueeze(0)  # initial GRU hidden state, with a leading dimension of 1

        for _ in range(self.num_message_passing):
            m    = F.relu(self.conv(x, edge_index, edge))
            x, h = self.gru(m.unsqueeze(0), h)
            x = x.squeeze(0)
        # x: num_node x node_hidden_dim

        pool = self.set2set(x, node_batch_index)  # per-graph readout
        # Repeat each graph's pooled vector once per coupling in that graph.
        pool = torch.index_select(
            pool,
            dim=0,
            index=coupling_batch_index
        )
        # Gather the embeddings of every coupled node and lay them side by side per coupling.
        x = torch.index_select(
            x,
            dim=0,
            index=coupling_index.view(-1)
        ).reshape(len(coupling_index),-1)

        x = torch.cat([pool,x],-1)
        predict = self.predict(x)

        # Keep only the output column that corresponds to each coupling's type.
        predict = torch.gather(predict,1,coupling_type.view(-1,1)).view(-1)
        return predict


# def criterion(predict, coupling_value):
#     predict = predict.view(-1)
#     coupling_value = coupling_value.view(-1)
#     assert(predict.shape==coupling_value.shape)
#
#     loss = F.mse_loss(predict, coupling_value)
#     return loss


def criterion(predict, coupling_value, eps=1e-9):
    """Return the log of the mean absolute error between predictions and targets.

    Args:
        predict: tensor of predictions; flattened before comparison.
        coupling_value: tensor of targets with the same number of elements.
        eps: lower bound applied to the MAE before taking the log. Without it a
            perfect fit yields ``log(0) = -inf`` and NaN gradients.

    Returns:
        Scalar tensor ``log(max(mean(|predict - coupling_value|), eps))``.

    Raises:
        AssertionError: if the flattened shapes differ.
    """
    # reshape (rather than view) also accepts non-contiguous inputs.
    predict = predict.reshape(-1)
    coupling_value = coupling_value.reshape(-1)
    assert(predict.shape==coupling_value.shape)

    mae = torch.abs(predict-coupling_value).mean()
    # Floor the MAE so the log (and its gradient) stays finite when the error is exactly zero.
    return torch.log(mae.clamp(min=eps))


##################################################################################################################

def make_dummy_data(node_dim, edge_dim, num_target, batch_size, device=None):
    """Build a random batch of graphs with random couplings for smoke tests.

    Every graph gets 8-17 nodes, 16-25 edges and 1-10 couplings. Node indices in
    ``edge_index`` and ``coupling_index`` are shifted by the number of nodes in
    the preceding graphs so they address rows of the concatenated ``node`` matrix.
    Values are drawn from NumPy's global random state.

    Args:
        node_dim: number of features per node.
        edge_dim: number of features per edge.
        num_target: number of coupling types; type ids are drawn from ``range(num_target)``.
        batch_size: number of graphs to generate.
        device: torch device for the returned tensors. ``None`` (default) uses
            CUDA when available and otherwise falls back to the CPU.

    Returns:
        Tuple ``(node, edge, edge_index, batch_node_index, coupling_index,
        coupling_type, coupling_value, batch_coupling_index)`` of tensors;
        feature/value tensors are float, index tensors are long.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    node = []
    edge = []
    edge_index  = []
    batch_node_index = []

    coupling_index = []
    coupling_type  = []
    coupling_value = []
    batch_coupling_index = []

    # Running count of nodes already emitted; replaces re-summing a list every iteration.
    node_offset = 0
    for b in range(batch_size):
        # Keep the order of random draws unchanged so seeded runs reproduce earlier data.
        N = np.random.choice(10)+8
        E = np.random.choice(10)+16
        node.append(np.random.uniform(-1,1,(N,node_dim)))
        edge.append(np.random.uniform(-1,1,(E,edge_dim)))

        edge_index.append(np.random.choice(N, (E,2))+node_offset)
        batch_node_index.extend([b]*N)

        #---
        C = np.random.choice(10)+1
        coupling_index.append(np.random.choice(N,(C,2))+node_offset)
        coupling_type.append(np.random.choice(num_target, C))
        coupling_value.append(np.random.uniform(-1,1, C))
        batch_coupling_index.extend([b]*C)

        #---
        node_offset += N

    def _float(chunks):
        # Concatenate per-graph arrays into one float tensor on the target device.
        return torch.from_numpy(np.concatenate(chunks)).float().to(device)

    def _long(array):
        # Convert an integer array into a long tensor on the target device.
        return torch.from_numpy(array).long().to(device)

    node = _float(node)
    edge = _float(edge)
    edge_index  = _long(np.concatenate(edge_index))
    batch_node_index = _long(np.array(batch_node_index))

    #---
    coupling_index = _long(np.concatenate(coupling_index))
    coupling_type  = _long(np.concatenate(coupling_type))
    coupling_value = _float(coupling_value)
    batch_coupling_index = _long(np.array(batch_coupling_index))

    return node,edge,edge_index, batch_node_index, coupling_index,coupling_type,coupling_value,batch_coupling_index



def run_check_net():
    """Smoke-test ``Net``: build a random batch, print its tensor shapes, then
    run one eval-mode forward pass on the GPU and print the predictions."""

    # dimensions of the synthetic batch
    node_dim, edge_dim, num_target, batch_size = 5, 7, 8, 16

    (node, edge, edge_index, batch_node_index,
     coupling_index, coupling_type, coupling_value, batch_coupling_index) = make_dummy_data(
        node_dim, edge_dim, num_target, batch_size)

    print('batch_size ', batch_size)
    print('----')
    graph_tensors = (
        ('node', node),
        ('edge', edge),
        ('edge_index', edge_index),
        ('batch_node_index', batch_node_index),
    )
    for label, tensor in graph_tensors:
        print(label, tensor.shape)
    print(batch_node_index)
    print('----')

    coupling_tensors = (
        ('coupling_index', coupling_index),
        ('coupling_type', coupling_type),
        ('coupling_value', coupling_value),
        ('batch_coupling_index', batch_coupling_index),
    )
    for label, tensor in coupling_tensors:
        print(label, tensor.shape)
    print(batch_coupling_index)
    print('')

    # model in inference mode on the GPU
    net = Net(node_dim=node_dim, edge_dim=edge_dim, num_target=num_target).cuda().eval()

    predict = net(
        node, edge, edge_index, batch_node_index,
        coupling_index, coupling_type, batch_coupling_index,
    )

    print('predict: ', predict.shape)
    print(predict)
    print('')



def run_check_train():
    """Overfit ``Net`` on one random batch with SGD to check that the loss can decrease.

    Prints the initial eval-mode loss and predictions, the training loss every
    10 iterations over 501 iterations, and finally the first five predictions
    next to their targets.
    """
    node_dim, edge_dim, num_target, batch_size = 15, 5, 12, 64

    batch = make_dummy_data(node_dim, edge_dim, num_target, batch_size)
    (node, edge, edge_index, batch_node_index,
     coupling_index, coupling_type, coupling_value, batch_coupling_index) = batch
    # Everything the network consumes; the target values are kept separate.
    inputs = (node, edge, edge_index, batch_node_index,
              coupling_index, coupling_type, batch_coupling_index)

    net = Net(node_dim=node_dim, edge_dim=edge_dim, num_target=num_target).cuda()
    net = net.eval()

    # Baseline forward pass before any training.
    predict = net(*inputs)
    loss = criterion(predict, coupling_value)

    print('*loss = %0.5f'%( loss.item(),))
    print('')

    print('predict: ', predict.shape)
    print(predict)
    print(coupling_value)
    print('')

    # dummy sgd to see if it can converge ...
    trainable = [p for p in net.parameters() if p.requires_grad]
    optimizer = optim.SGD(trainable, lr=0.01, momentum=0.9, weight_decay=0.0001)

    for line in ('--------------------', '[iter ]  loss       ', '--------------------'):
        print(line)

    optimizer.zero_grad()
    for i in range(501):
        net.train()
        optimizer.zero_grad()

        predict = net(*inputs)
        loss = criterion(predict, coupling_value)

        loss.backward()
        optimizer.step()

        if i % 10 == 0:
            print('[%05d] %8.5f  '%(
                i,
                loss.item(),
            ))
    print('')

    # Compare the last training-step predictions with the targets.
    print(predict[:5])
    print(coupling_value[:5])
    print('')





# main #################################################################
if __name__ == '__main__':
    import os

    # `__file__` is not defined when this cell runs inside a Jupyter/IPython kernel,
    # so fall back to a placeholder instead of raising NameError.
    script_name = os.path.basename(globals().get('__file__', '<notebook>'))
    print( '%s: calling main function ... ' % script_name)


    #run_check_net()
    run_check_train()

    print('\nsucess!')


ModuleNotFoundError: No module named 'torch_geometric'

In [4]:

def make_dummy_data(node_dim, edge_dim, num_target, batch_size, device=None):
    """Build a random batch of graphs with random couplings for smoke tests.

    Every graph gets 8-17 nodes, 16-25 edges and 1-10 couplings. Node indices in
    ``edge_index`` and ``coupling_index`` are shifted by the number of nodes in
    the preceding graphs so they address rows of the concatenated ``node`` matrix.
    Values are drawn from NumPy's global random state.

    Args:
        node_dim: number of features per node.
        edge_dim: number of features per edge.
        num_target: number of coupling types; type ids are drawn from ``range(num_target)``.
        batch_size: number of graphs to generate.
        device: torch device for the returned tensors. ``None`` (default) uses
            CUDA when available and otherwise falls back to the CPU.

    Returns:
        Tuple ``(node, edge, edge_index, batch_node_index, coupling_index,
        coupling_type, coupling_value, batch_coupling_index)`` of tensors;
        feature/value tensors are float, index tensors are long.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    node = []
    edge = []
    edge_index  = []
    batch_node_index = []

    coupling_index = []
    coupling_type  = []
    coupling_value = []
    batch_coupling_index = []

    # Running count of nodes already emitted; replaces re-summing a list every iteration.
    node_offset = 0
    for b in range(batch_size):
        # Keep the order of random draws unchanged so seeded runs reproduce earlier data.
        N = np.random.choice(10)+8
        E = np.random.choice(10)+16
        node.append(np.random.uniform(-1,1,(N,node_dim)))
        edge.append(np.random.uniform(-1,1,(E,edge_dim)))

        edge_index.append(np.random.choice(N, (E,2))+node_offset)
        batch_node_index.extend([b]*N)

        #---
        C = np.random.choice(10)+1
        coupling_index.append(np.random.choice(N,(C,2))+node_offset)
        coupling_type.append(np.random.choice(num_target, C))
        coupling_value.append(np.random.uniform(-1,1, C))
        batch_coupling_index.extend([b]*C)

        #---
        node_offset += N

    def _float(chunks):
        # Concatenate per-graph arrays into one float tensor on the target device.
        return torch.from_numpy(np.concatenate(chunks)).float().to(device)

    def _long(array):
        # Convert an integer array into a long tensor on the target device.
        return torch.from_numpy(array).long().to(device)

    node = _float(node)
    edge = _float(edge)
    edge_index  = _long(np.concatenate(edge_index))
    batch_node_index = _long(np.array(batch_node_index))

    #---
    coupling_index = _long(np.concatenate(coupling_index))
    coupling_type  = _long(np.concatenate(coupling_type))
    coupling_value = _float(coupling_value)
    batch_coupling_index = _long(np.array(batch_coupling_index))

    return node,edge,edge_index, batch_node_index, coupling_index,coupling_type,coupling_value,batch_coupling_index


In [13]:
import numpy as np
import torch
# Draw a tiny two-graph batch and unpack it into notebook-level names for inspection.
(
    node, edge, edge_index, batch_node_index,
    coupling_index, coupling_type, coupling_value, batch_coupling_index,
) = make_dummy_data(node_dim=5, edge_dim=5, num_target=5, batch_size=2)

In [14]:
node

tensor([[ 0.4723, -0.8067,  0.3258, -0.6208,  0.5560],
        [-0.9595,  0.7951, -0.9843,  0.8704, -0.7179],
        [-0.5064, -0.4117,  0.7831, -0.9848,  0.1357],
        [ 0.5371, -0.2131,  0.6461,  0.6492,  0.0830],
        [ 0.2731, -0.0661, -0.3318, -0.8232,  0.3908],
        [-0.4480,  0.1785, -0.7515,  0.2236, -0.5094],
        [ 0.4396, -0.5053, -0.3029, -0.9537,  0.0693],
        [-0.8159,  0.3805,  0.9136,  0.0340,  0.1214],
        [-0.9596, -0.6928,  0.1985, -0.6410, -0.4084],
        [-0.5625, -0.7217, -0.1158,  0.9381,  0.8925],
        [-0.8264, -0.3035, -0.4280, -0.4427, -0.2724],
        [ 0.6564,  0.8210, -0.6512,  0.4082, -0.8642],
        [ 0.0358,  0.4467,  0.7292, -0.2749, -0.6451],
        [ 0.8957,  0.8138, -0.9596,  0.9689,  0.4266],
        [ 0.8417, -0.2831,  0.3199, -0.3972, -0.5284],
        [-0.7090, -0.4546, -0.2164, -0.1280,  0.0338],
        [-0.9664, -0.5584, -0.8256, -0.9712, -0.8293],
        [-0.4981, -0.3515, -0.8003,  0.2347, -0.2453],
        [-

In [15]:
edge

tensor([[ 0.7957,  0.3686, -0.3072, -0.5834,  0.0216],
        [ 0.8477,  0.1664,  0.1391, -0.9147,  0.8251],
        [ 0.2829,  0.8254,  0.3780,  0.3306, -0.7319],
        [ 0.2086,  0.0928, -0.7925, -0.6176, -0.1256],
        [-0.9840,  0.3911, -0.6141,  0.5109,  0.6674],
        [ 0.6545,  0.6847, -0.8481,  0.1086,  0.8280],
        [-0.8996, -0.2289,  0.0357,  0.7494, -0.4473],
        [-0.8205,  0.9878,  0.3918,  0.0305,  0.5299],
        [ 0.9179, -0.6688, -0.4919,  0.6886, -0.7483],
        [ 0.6495, -0.0775,  0.5435, -0.4826,  0.5591],
        [-0.6250,  0.6022, -0.9049, -0.3680, -0.9929],
        [ 0.9259,  0.0506,  0.1670, -0.0649, -0.3554],
        [ 0.5866,  0.5458, -0.2946, -0.1616,  0.2676],
        [ 0.2844, -0.5653,  0.2541, -0.0351, -0.1361],
        [ 0.7393, -0.7709, -0.5858, -0.6075,  0.6806],
        [-0.1572,  0.7897, -0.5434, -0.9317,  0.8409],
        [ 0.1016, -0.1637,  0.0260, -0.9999, -0.7324],
        [ 0.6425, -0.7445,  0.3698, -0.8365, -0.2793],
        [-

In [16]:
edge_index

tensor([[ 1, 11],
        [ 5,  4],
        [13,  5],
        [ 2, 10],
        [ 9,  6],
        [ 4,  2],
        [10,  9],
        [ 4, 10],
        [ 5,  4],
        [ 5,  5],
        [ 7,  0],
        [ 7, 13],
        [ 5,  3],
        [ 6,  0],
        [13,  9],
        [13, 11],
        [ 3, 10],
        [29, 20],
        [19, 23],
        [23, 17],
        [28, 28],
        [15, 30],
        [18, 30],
        [19, 29],
        [30, 17],
        [29, 30],
        [14, 24],
        [19, 24],
        [26, 22],
        [21, 20],
        [23, 17],
        [16, 27],
        [16, 23],
        [19, 29],
        [19, 29],
        [22, 30],
        [25, 15],
        [22, 30],
        [20, 20]], device='cuda:0')

In [17]:
coupling_value

tensor([-7.4636e-01, -2.2672e-01, -2.7217e-01, -7.2136e-04,  7.2412e-01,
        -2.7352e-01], device='cuda:0')

In [18]:
coupling_index

tensor([[ 4,  0],
        [11,  0],
        [29, 25],
        [16, 14],
        [19, 22],
        [15, 17]], device='cuda:0')