In [1]:
import numpy as np
import torch, queue
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch.utils.data import Dataset
from src.transforms import *
from torch_geometric.data import Data, HeteroData

In [31]:
class GNN_VN_Model(torch.nn.Module):
    """
    GNN model that customizes the torch_geometric.graphgym.models.gnn.GNN
    to support specific handling of new conv layers.
    """
    def __init__(self, input=3, output=20, hidden=20, layers=2, 
                 layer_type='GATConv', activation='LeakyReLU', batches=False, **kwargs):
        super(GNN_VN_Model, self).__init__()

        torch.manual_seed(1234567)
        # Initialize the first layer
        graph_layer = globals()[layer_type]
        self.initial = graph_layer(input, hidden)
        
        # Initialize the subsequent layers
        self.module_list = nn.ModuleList([graph_layer(hidden, hidden) for _ in range(layers - 1)])
        
        # Output layer
        self.output = graph_layer(hidden, output)

        # activation function
        self.activation = globals()[activation]()

        # added by Chen; VN part.
        self.virtualnode_embedding = torch.nn.Embedding(1, hidden)
        self.mlp_virtualnode_list = torch.nn.ModuleList()
        torch.nn.init.constant_(self.virtualnode_embedding.weight.data, 0)
        for layer in range(layers  - 1):
            if batches:
                self.mlp_virtualnode_list.append(
                    torch.nn.Sequential(torch.nn.Linear(hidden, hidden), torch.nn.BatchNorm1d(hidden), torch.nn.ReLU(), \
                                        torch.nn.Linear(hidden, hidden), torch.nn.BatchNorm1d(hidden), torch.nn.ReLU()))
            else:
                self.mlp_virtualnode_list.append(
                    torch.nn.Sequential(torch.nn.Linear(hidden, hidden), torch.nn.ReLU(), \
                                        torch.nn.Linear(hidden, hidden), torch.nn.ReLU()))
                
    def forward(self, x, edge_index, h_blocks, h_levels, h_num):
        out = self.initial(x, edge_index)
        vn_direct = self.virtualnode_embedding(torch.zeros(h_num).to(edge_index.dtype).to(edge_index.device))
        vn_root = self.virtualnode_embedding(torch.zeros(1).to(edge_index.dtype).to(edge_index.device))

        for layer in self.module_list:
            # Get information from virtual nodes
            out = out + vn_direct[h_blocks]
            out = layer(out, edge_index)

            # Get information from real nodes + root virtual node
            vn_direct = global_add_pool(out, h_blocks) + vn_direct
            vn_direct = vn_direct + vn_root

            # Root VN gets information from vn_direct
            vn_root = global_add_pool(vn_direct, None, size=1) + vn_root
            
            for mlp_layer in self.mlp_virtualnode_list:
                vn_direct = mlp_layer(vn_direct)
                vn_root = mlp_layer(vn_root)
            
            
        return out

class TerrainHeteroData(HeteroData):
    def __inc__(self, key, value, *args, **kwargs):
        if key == 'src':
            return self['real'].x.size(0)
        if key == 'tar':
            return self['real'].x.size(0)
        return super().__inc__(key, value, *args, **kwargs)

class TerrainPatchesData(Data):
    def __inc__(self, key, value, *args, **kwargs):
        if key == 'src':
            return self.x.size(0)
        if key == 'tar':
            return self.x.size(0)
        return super().__inc__(key, value, *args, **kwargs)

In [32]:
x1 = torch.randn(3, 16)  
ei_1 = torch.tensor([
    [0, 1, 0, 2],
    [1, 0, 2, 0],
])

h_blocks = torch.tensor([0, 0, 1])
h_num = 2

gnn = GNN_VN_Model(input=16)

gnn(x2, ei_1, h_blocks, 1, h_num)


tensor([[-0.5097, -1.4739,  0.3424,  1.5051,  0.4448,  1.0967, -0.4457,  1.0000,
          0.2805, -1.1216, -0.5548, -0.0211, -0.7743, -0.3636, -0.3101,  0.2360,
         -0.9018, -0.2084,  0.0796, -1.0028],
        [-0.3301, -1.1256,  0.3964,  1.0767,  0.3273,  0.8585, -0.2303,  0.6894,
          0.2389, -0.5163, -0.7548, -0.1384, -0.6945, -0.3687, -0.1754,  0.2559,
         -0.8168, -0.1085,  0.2097, -0.8211],
        [-0.6196, -1.7179,  0.2739,  1.7983,  0.5208,  1.2735, -0.5614,  1.1988,
          0.3197, -1.5458, -0.4141,  0.0691, -0.8276, -0.3641, -0.4109,  0.2128,
         -0.9859, -0.2710, -0.0163, -1.1102]], grad_fn=<AddBackward0>)

In [15]:
x1 = torch.randn(3, 16)  
ei_1 = torch.tensor([
    [0, 1, 0, 2],
    [1, 0, 2, 0],
])

x2 = torch.randn(3, 16)  
ei_2 = torch.tensor([
    [0, 1, 0, 2],
    [1, 0, 2, 0],
])

z

In [16]:
data0 = TerrainPatchesData(x=x1, edge_index = ei_1, src =1, tar=2)
data1 = TerrainPatchesData(x=x2, edge_index = ei_2, src = 0, tar=1)
lst = [data0, data1]


In [22]:
gnn = GNN_VN_Model(input=16)

loader = DataLoader(lst, batch_size=2, follow_batch=[ 'src', 'tar'])
batch = next(iter(loader))
print(batch.src)

out = gnn(batch.x, batch.edge_index, batch=batch)
print(out.size())

tensor([1, 3])
tensor([0, 0, 0, 1, 1, 1])
vn_emb, batch.batch torch.Size([6, 20])
torch.Size([6, 20])


In [43]:
class TerrainHeteroData(HeteroData):
    def __inc__(self, key, value, *args, **kwargs):
        if key == 'src':
            return self['real'].x.size(0)
        if key == 'tar':
            return self['real'].x.size(0)
        return super().__inc__(key, value, *args, **kwargs)

class TerrainPatchesData(Data):
    def __inc__(self, key, value, *args, **kwargs):
        if key == 'src':
            return self.x.size(0)
        if key == 'tar':
            return self.x.size(0)
        return super().__inc__(key, value, *args, **kwargs)
    

In [44]:
def add_virtual_node(data):
    hetero_data = TerrainHeteroData()
    sz_features = data.x.size()[1]
    hetero_data.src = data.src
    hetero_data.tar = data.tar
    hetero_data['real'].x = data.x.double()
    hetero_data['real', 'e1', 'real'].edge_index = data.edge_index

    vn = torch.zeros(size = (1, sz_features), dtype=torch.double )
    hetero_data['vn'].x = vn
    vn_edge_index = [[], []]
    for i in range(data.x.size()[0]):
        vn_edge_index[0].append(0)
        vn_edge_index[1].append(i)
    hetero_data['vn', 'e2', 'real'].edge_index = torch.tensor(vn_edge_index, dtype=torch.long)

    return hetero_data

In [45]:
x1 = torch.randn(3, 16)  
ei_1 = torch.tensor([
    [0, 1, 0, 2],
    [1, 0, 2, 0],
])

x2 = torch.randn(3, 16)  
ei_2 = torch.tensor([
    [0, 1, 0, 2],
    [1, 0, 2, 0],
])

In [46]:
data0 = TerrainPatchesData(x=x1, edge_index = ei_1, src =1, tar=2)
data1 = TerrainPatchesData(x=x2, edge_index = ei_2, src = 0, tar=1)
vn1 = add_virtual_node(data0)
vn2 = add_virtual_node(data1)
lst = [data0, data1]
lst_vn = [vn1, vn2]
print(vn1)

TerrainHeteroData(
  src=1,
  [1mreal[0m={ x=[3, 16] },
  [1mvn[0m={ x=[1, 16] },
  [1m(real, e1, real)[0m={ edge_index=[2, 4] },
  [1m(vn, e2, real)[0m={ edge_index=[2, 3] }
)


In [47]:
loader = DataLoader(lst_vn, batch_size=2, follow_batch = ['src'])
batch = next(iter(loader))
print(batch['vn', 'e2', 'real'])
print(batch['src'])

{'edge_index': tensor([[0, 0, 0, 1, 1, 1],
        [0, 1, 2, 3, 4, 5]])}
tensor([1, 3])


In [29]:
data0 = TerrainPatchesData(x = data[0].x, edge_index = data[0].edge_index, src=np.int64(1), tar=np.int64(2))
data1 = TerrainPatchesData(x = data[1].x, edge_index = data[1].edge_index, src=np.int64(0), tar=np.int64(0))
lst = [data0, data1]

In [30]:
loader = DataLoader(lst, batch_size=2, follow_batch=[ 'src', 'tar'])
batch = next(iter(loader))
print(batch.src)

[1, 0]
