In [6]:
from data.generator import *

In [107]:
# Settings forwarded to RandomGraphDataset below.
# NOTE(review): presumably n is the number of nodes per graph and p the
# edge probability of the random-graph model — confirm against data.generator.
n = 6
p = 0.3

# Build (or load from ./data) a dataset with gen_num_graph=10 graphs.
# NOTE(review): no random seed is set in this notebook, so the generated
# graphs may differ between runs — confirm whether the generator seeds itself.
dataset = RandomGraphDataset(root='./data', gen_num_graph=10, n=n, p=p)

Generating 10 graphs


100%|██████████| 10/10 [00:00<00:00, 1385.81it/s]
Processing...
100%|██████████| 10/10 [00:00<00:00, 404.36it/s]
Done!


In [108]:
dataset[0]

CLRSData(edge_index=[2, 10], pos=[6], length=4, s=0, pi=[10], reach_h=[4, 6], pi_h=[4, 10], hints=[2], inputs=[2], outputs=[1])

In [134]:
import torch
from torch_geometric.nn import MessagePassing
from torch.nn import Linear

class MPNN(MessagePassing):
  """Max-aggregation message-passing layer followed by a two-layer MLP.

  Pipeline of one forward pass:
    1. message: a linear map of the concatenated features [x_i, x_j] per edge.
    2. aggregate: element-wise max over the messages arriving at each node.
    3. update: a linear map of [x, aggregated messages] per node.
    4. mlp: Linear -> ReLU -> Linear applied to the updated node features.
    5. optional activation.

  Args:
    in_channels: number of features per input node.
    hidden_channels: number of features of the messages, of the updated
      node state and of the returned node embeddings.
    activation: optional callable applied to the final output; skipped
      when None.
  """

  def __init__(self, in_channels, hidden_channels, activation=None):
    super().__init__(aggr='max')  # "Max" aggregation.
    self.in_channels = in_channels
    self.hidden_channels = hidden_channels
    self.messages = Linear(self.in_channels * 2, self.hidden_channels)
    self.update_fn = Linear(self.in_channels + self.hidden_channels, self.hidden_channels)
    self.activation = activation

    # The MLP consumes the output of update(), which has hidden_channels
    # features. Sizing its first layer as 2 * in_channels only worked when
    # hidden_channels happened to equal 2 * in_channels and raised a shape
    # mismatch otherwise (e.g. MPNN(128, 128)).
    self.mlp = torch.nn.Sequential(
        Linear(hidden_channels, hidden_channels),
        torch.nn.ReLU(),
        Linear(hidden_channels, hidden_channels)
    )

  def forward(self, x, edge_index):
    """Run one round of message passing and return new node embeddings.

    Args:
      x: node features, shape [N, in_channels].
      edge_index: edge list passed to propagate(), shape [2, E].

    Returns:
      Tensor of shape [N, hidden_channels].
    """
    out = self.propagate(edge_index, x=x)  # [N, hidden_channels]
    out = self.mlp(out)
    if self.activation is not None:
      out = self.activation(out)
    return out

  def message(self, x_i, x_j):
    """Compute one message per edge from the features of its two endpoints.

    x_i and x_j both have shape [E, in_channels]; the result has shape
    [E, hidden_channels].
    """
    pair = torch.cat([x_i, x_j], dim=1)  # [E, 2 * in_channels]
    return self.messages(pair)

  def update(self, aggr_out, x):
    """Combine each node's own features with its aggregated messages.

    aggr_out has shape [N, hidden_channels] and x has shape
    [N, in_channels]; the result has shape [N, hidden_channels].
    """
    combined = torch.cat([x, aggr_out], dim=1)  # [N, in_channels + hidden_channels]
    return self.update_fn(combined)

In [122]:
from models.encoder import Encoder

In [123]:
dataset[0].pi

tensor([1., 1., 0., 1., 1., 0., 0., 0., 1., 0.], dtype=torch.float64)

In [124]:
# create an encoder class
import torch.nn as nn

class Encoder(nn.Module):
    """Linear encoder that lifts raw features into a hidden representation.

    NOTE: this definition shadows the `Encoder` imported from
    `models.encoder` earlier in the notebook; from this cell on, `Encoder`
    refers to this class.

    Args:
        input_dim: number of features per input element.
        hidden_dim: size of the produced embedding.
    """

    def __init__(self, input_dim, hidden_dim=128):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        # Honour input_dim instead of hard-coding a single input feature;
        # for input_dim == 1 this is the same layer as before.
        self.lin = nn.Linear(input_dim, hidden_dim)

    def forward(self, x):
        """Encode `x` and return a tensor whose last dimension is hidden_dim.

        A 1-D tensor of length N is treated as N scalar features and reshaped
        to [N, 1]; this is only valid when input_dim == 1. Tensors with more
        dimensions must already have input_dim as their last dimension.
        """
        if x.dim() == 1:
            x = x.unsqueeze(-1)
        return self.lin(x)

In [125]:
# Last row of reach_h for the first graph. Index the dataset directly so this
# cell does not depend on `d1`, which is only assigned in a later cell.
dataset[0]['reach_h'][-1]

tensor([1., 1., 1., 1., 1., 1.], dtype=torch.float64)

In [126]:
encoder = Encoder(1, 128)

In [127]:
# Use the first graph of the dataset as the working sample.
d1 = dataset[0]
# Concatenate pos with row 1 of reach_h along the last dimension.
# NOTE(review): this comment previously said reach[-1] while the code indexes
# reach_h[1] — confirm which row is intended.
# NOTE(review): with pos=[6] and reach_h=[4, 6] (see the printed sample) the
# result is a flat vector of length 12, not a [6, 2] per-node feature matrix —
# confirm this layout is what the downstream encoder/MPNN should receive.
x = torch.cat([d1.pos, d1['reach_h'][1]], dim=-1)


In [128]:
x.shape

torch.Size([12])

In [129]:
# Cast x to the dtype of the encoder's weights before the forward pass
# (reach_h prints as float64 in this notebook), then inspect the output shape.
enc = encoder(x.to(encoder.lin.weight.dtype))
enc.shape

torch.Size([12, 128])

In [135]:
# in_channels=128 matches the hidden size the encoder was built with, so the
# encoded features can be fed to the layer directly.
mpnn = MPNN(128, 128)

# One round of message passing over the sample graph's edges.
mpnn(enc, d1.edge_index)

tensor([[-0.2359, -0.0652, -0.5170,  ..., -0.1245, -0.3441,  0.3432],
        [-0.1635, -0.1216, -0.4853,  ..., -0.1552, -0.3783,  0.3189],
        [-0.0921, -0.1402, -0.3366,  ..., -0.1429, -0.3408,  0.3960],
        ...,
        [-0.2694,  0.0019, -0.3031,  ..., -0.0201, -0.2309,  0.4045],
        [-0.2694,  0.0019, -0.3031,  ..., -0.0201, -0.2309,  0.4045],
        [-0.2694,  0.0019, -0.3031,  ..., -0.0201, -0.2309,  0.4045]],
       grad_fn=<AddmmBackward0>)

In [102]:
len(d1.edge_index[0])

2966