In [1]:
# import generator from '../data/generator.py'
import sys
import os
from generator import *


2024-03-07 16:25:44.266929: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


# Graph Generation

In [2]:
# Generation parameters, forwarded as RandomGraphDataset's `n` (presumably nodes
# per graph) and `p` (presumably edge probability) — TODO confirm in generator.py.
n, p = 6, 0.3

dataset = RandomGraphDataset(
    root='./data',
    gen_num_graph=10,
    n=n,
    p=p,
)

In [3]:
import torch
from torch import nn
from encoder import Encoder
from decoder import Decoder
from mpnn import MPNN

class Network(nn.Module):
    """Encode-process-decode network over a single graph.

    Node features are projected into a latent space by ``Encoder``, concatenated
    with a zero-initialised hidden state, passed through one ``MPNN`` step and
    mapped to ``output_dim`` values per node by ``Decoder``.

    NOTE(review): ``Encoder``/``MPNN``/``Decoder`` are looked up when the network
    is constructed, so notebook-local redefinitions shadow the imported modules
    if those cells have already run.

    Args:
        latent_dim: width of the encoder output and of the hidden state.
        input_dim: number of input features per node (default 2, as before).
        output_dim: number of values predicted per node (default 1, as before).
    """

    def __init__(self, latent_dim=128, input_dim=2, output_dim=1):
        super().__init__()
        self.latent_dim = latent_dim
        self.encoder = Encoder(input_dim, latent_dim)
        # The processor consumes [encoding | hidden state], i.e. 2 * latent_dim columns.
        self.processor = MPNN(latent_dim * 2, latent_dim)
        self.decoder = Decoder(latent_dim, output_dim)

    def forward(self, x, edge_index):
        """Run one encode-process-decode pass.

        Args:
            x: node features (assumed ``[N, input_dim]`` — TODO confirm against Encoder).
            edge_index: graph connectivity, passed through to the processor.

        Returns:
            The decoder output for every node.
        """
        z = self.encoder(x)
        # Zero hidden state whose width follows latent_dim (previously hard-coded
        # to 128, which broke any other latent_dim) and whose dtype/device match z.
        h = z.new_zeros(x.size(0), self.latent_dim)
        processor_input = torch.cat([z, h], dim=1)
        x = self.processor(processor_input, edge_index)
        return self.decoder(x)

In [5]:
# Per-node input features of the first generated graph: `pos` and `s` stacked
# column-wise (presumably two 1-D per-node tensors -> shape [N, 2]).
# NOTE(review): `input` shadows the Python builtin; kept because later cells use it.
input = torch.stack([dataset[0].pos, dataset[0].s], dim=1)

In [6]:
net = Network()
# Cast the node features to the model's parameter dtype so the first Linear layer
# accepts them; next(...) reads only the first parameter instead of building a list
# of all of them.
input = input.to(next(net.parameters()).dtype)

In [7]:
# Evaluate the `net` built above (the model `input` was cast for) via __call__ rather
# than .forward(), so nn.Module hooks run; the previous code built a second, unrelated
# Network just for this call.
net(input, dataset[0].edge_index)

tensor([[-0.0262],
        [-0.0408],
        [-0.0350],
        [-0.0424],
        [-0.0466],
        [-0.0516]], grad_fn=<AddmmBackward0>)

In [177]:
class Encoder(nn.Module):
    """Project node features into a ``hidden_dim``-wide latent space.

    A single ``nn.Linear(input_dim, hidden_dim)``; a 1-D input (one scalar per
    node) is first turned into a single column before the projection.

    NOTE(review): this notebook-local class shadows the ``Encoder`` imported
    from the ``encoder`` module earlier in the notebook.
    """

    def __init__(self, input_dim, hidden_dim=128):
        super().__init__()
        self.input_dim, self.hidden_dim = input_dim, hidden_dim
        self.lin = nn.Linear(input_dim, hidden_dim)

    def forward(self, x):
        """Apply the linear projection, promoting a 1-D ``x`` to shape ``[N, 1]``."""
        features = x.unsqueeze(-1) if x.dim() == 1 else x
        return self.lin(features)

In [122]:
encoder = Encoder(input_dim=2)  # hidden_dim keeps its default of 128
z = encoder(input)

In [124]:
# Zero initial hidden state: its width follows the encoder's hidden_dim instead of a
# hard-coded 128, and new_zeros inherits z's dtype and device.
h = z.new_zeros(input.size(0), encoder.hidden_dim)

In [125]:
# Processor input: the encoding followed by the zero hidden state, side by side.
z_ = torch.cat((z, h), 1)
z_.shape

torch.Size([6, 256])

In [159]:
import torch
from torch_geometric.nn import MessagePassing
from torch.nn import Linear

class MPNN(MessagePassing):
    """Single message-passing layer (max aggregation by default).

    For every edge the layer builds a message ``messages([x_i, x_j])``. Here
    ``x_i`` is the target node's features and ``x_j`` the source node's, following
    PyG's default ``source_to_target`` flow. Messages are aggregated per target
    node with ``aggr`` and combined with the node's own features via
    ``update_fn([x, aggr_out])``. The result then goes through a two-layer MLP
    and an optional activation.

    NOTE(review): this notebook-local class shadows the ``MPNN`` imported from
    the ``mpnn`` module earlier in the notebook.

    Args:
        in_channels: width of the input node features.
        hidden_channels: width of messages, updated node states and the output.
        activation: optional callable applied to the output (``None`` = identity).
        aggr: PyG aggregation scheme; defaults to ``'max'`` (previous behaviour).
    """

    def __init__(self, in_channels, hidden_channels, activation=None, aggr='max'):
        super().__init__(aggr=aggr)
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.messages = Linear(in_channels * 2, hidden_channels)
        self.update_fn = Linear(in_channels + hidden_channels, hidden_channels)
        self.activation = activation

        self.mlp = torch.nn.Sequential(
            Linear(hidden_channels, hidden_channels),
            torch.nn.ReLU(),
            Linear(hidden_channels, hidden_channels),
        )

    def forward(self, x, edge_index):
        """Propagate over ``edge_index``, then apply the MLP and optional activation.

        Returns a ``[N, hidden_channels]`` tensor.
        """
        out = self.propagate(edge_index, x=x)
        out = self.mlp(out)
        if self.activation is not None:
            out = self.activation(out)
        return out

    def message(self, x_i, x_j):
        # x_i, x_j: [E, in_channels] -> concatenated [E, 2 * in_channels] -> [E, hidden_channels]
        return self.messages(torch.cat([x_i, x_j], dim=1))

    def update(self, aggr_out, x):
        # aggr_out: [N, hidden_channels]; x: [N, in_channels]
        # -> [N, in_channels + hidden_channels] -> [N, hidden_channels]
        return self.update_fn(torch.cat([x, aggr_out], dim=1))

In [160]:
# Input width = [encoding | hidden state] columns of z_; output width = encoder's latent size.
# Replaces the magic numbers 256/128 so the cell stays consistent if those widths change.
processor = MPNN(z_.size(1), encoder.hidden_dim)

In [161]:
hi = processor(x=z_, edge_index=dataset[0].edge_index)

In [162]:
hi.shape

torch.Size([6, 128])