In [1]:
# import generator from '../data/generator.py'
import sys
import os
from generator import *

2024-03-11 11:04:40.777734: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


# Graph Generation

In [144]:
n = 6
p = 0.3

dataset = RandomGraphDataset(root='./data', gen_num_graph=200, n=n, p=p)

Generating 200 graphs


100%|██████████| 200/200 [00:00<00:00, 567.93it/s]
Processing...
100%|██████████| 200/200 [00:00<00:00, 332.04it/s]
Done!


In [5]:
import torch
from torch import nn
from encoder import Encoder
from decoder import Decoder
from mpnn import MPNN

class Network(nn.Module):
    def __init__(self, latent_dim=128):
        super(Network, self).__init__()
        self.encoder = Encoder(2, latent_dim)
        self.processor = MPNN(latent_dim*2, latent_dim)
        self.decoder = Decoder(latent_dim, 1)

    def forward(self, x, edge_index):
        z = self.encoder(x)
        h = torch.zeros(x.size(0), 128)
        processor_input = torch.cat([z, h], dim=1)
        x = self.processor(processor_input, edge_index)
        x = self.decoder(x)
        return x

In [62]:
from torch.functional import F
from torch import nn
import torch

class Loss(nn.Module):
    def __init__(self):
        super(Loss, self).__init__()

    def forward(self, batch:torch.tensor, batch_pred: torch.Tensor):

        # for every batch find the predicted and true values and send them to the calculate_loss function
        loss_x = 0
        loss_h = 0
        for i in range(batch.size(0)):
            data = batch[i]
            x = data.reach_h[-1] # true output value
            x_pred = batch_pred[i][-1] # predicted output value
            h_pred = batch_pred[i][:-2] # predicted hint values
            h = data.reach_h[:len(h_pred)] # true hint values
            loss_x += F.binary_cross_entropy(x, x_pred)
            print(loss_x)
            for i in range(h.size(1)):
                loss_h += F.binary_cross_entropy(h[:, i], h_pred[:, i])

        return loss_x, loss_h

In [218]:
import torch
from torch import nn
from encoder import Encoder
from decoder import Decoder
from mpnn import MPNN
from loss import Loss
from torch.functional import F

class Network(nn.Module):
    def __init__(self, latent_dim=128):
        super(Network, self).__init__()
        self.encoder = Encoder(2, latent_dim)
        self.processor = MPNN(latent_dim*2, latent_dim)
        self.decoder = Decoder(latent_dim, 1)

    def forward(self, batch, max_iter=10):
        input = torch.stack((batch.pos, batch.s), dim=1).float()
        h = torch.zeros(input.size(0), 128) # hidden state from the processor
        hints = batch.reach_h[1:] # hints from the reachability
        true_output = batch.reach_h[-1] # true_output expected from the reachability
        max_iter = hints.size(0)+1
        predictions = torch.zeros(max_iter, input.size(0))

        for i in range(max_iter):
            z = self.encoder(input) # the encoded input
            processor_input = torch.cat([z, h], dim=1) # the input to the processor
            h = self.processor(processor_input, batch.edge_index.long()) # the output of the processor
            y = self.decoder(h).view(batch.s.size())   # decoded state from the processor
            predictions[i] = y
            input = torch.stack((batch.pos, y), dim=1).float() # we update the input with the new state
        
        loss = self.calculate_loss(hints, predictions, true_output)
        return y, loss
    
    def calculate_loss(self, hints, predictions, true_output):
        loss_x = F.binary_cross_entropy(torch.sigmoid(predictions[-1]), true_output.type(torch.float))
        loss_h = 0
        for i in range(hints.size(0)):
            loss_h += F.binary_cross_entropy(torch.sigmoid(predictions[i]), hints[i].type(torch.float))
        return loss_x, loss_h

In [230]:
def train(model, train_dataset, validation_dataset=None, optimizer=None, epochs=10, batch_size=5):
    x_loss_weight = 0.5
    h_loss_weight = 1 - x_loss_weight
    for epoch in range(epochs):
        batch_count = len(train_dataset) // batch_size
        for i in range(batch_count):
            model.train()
            cumulated_loss = 0
            for j in range(i*batch_size, (i+1)*batch_size):
                graph = train_dataset[j] 
                y, loss = model(graph)
                loss_x = loss[0] # loss for the output
                loss_hints = loss[1] # loss for the hints
                cumulated_loss += x_loss_weight * loss_x + h_loss_weight * loss_hints # we combine the two losses
            
            cumulated_loss /= batch_size
            optimizer.zero_grad()
            cumulated_loss.backward()
            optimizer.step()

        if validation_dataset:
            model.eval()
            cumulated_loss_val = 0
            with torch.no_grad():
                for k in range(len(validation_dataset)):
                    graph = validation_dataset[k]
                    y, loss = model(graph)
                    loss_x, loss_hints = loss[0], loss[1]
                    cumulated_loss_val += x_loss_weight * loss_x + h_loss_weight * loss_hints

            cumulated_loss_val /= len(validation_dataset)

            print(f'Epoch {epoch}, loss {cumulated_loss.item()}, validation loss {cumulated_loss_val}')
        else:
            print(f'Epoch {epoch}, loss {cumulated_loss.item()}')

In [231]:
# split the dataset into training and testing using the train_test_split function
from torch.utils.data import random_split
dataset = RandomGraphDataset(root='./data', gen_num_graph=250, n=n, p=p)
train_dataset, test_dataset = random_split(dataset, [200, 50])

In [232]:
optimizer = torch.optim.Adam
model = Network()

train(model=model, train_dataset=train_dataset, validation_dataset=test_dataset,
      optimizer=optimizer(model.parameters()), epochs=10)

Epoch 0, loss 0.20106768608093262, validation loss 0.1900792121887207
Epoch 1, loss 0.08652420341968536, validation loss 0.0728931725025177
Epoch 2, loss 1.7268340587615967, validation loss 1.5496764183044434
Epoch 3, loss 1.7241554260253906, validation loss 1.547053575515747
Epoch 4, loss 1.7241947650909424, validation loss 1.5469694137573242
Epoch 5, loss 1.7242014408111572, validation loss 1.5469566583633423
Epoch 6, loss 0.06705305725336075, validation loss 0.05606688931584358
Epoch 7, loss 0.06038669869303703, validation loss 0.049940671771764755
Epoch 8, loss 0.05842023342847824, validation loss 0.04803124815225601
Epoch 9, loss 0.057915251702070236, validation loss 0.047455139458179474


In [259]:
i =31
graph = test_dataset[i]
print(graph.reach_h[-1])
print(torch.sigmoid(model(graph)[0]))

tensor([1., 1., 1., 1., 1., 1.], dtype=torch.float64)
tensor([1., 1., 1., 1., 1., 1.], grad_fn=<SigmoidBackward0>)


In [238]:
model(graph)[0]

tensor([3500525.5000, 3524301.5000, 3249399.0000, 3501471.5000, 3512054.0000,
        3518177.7500], grad_fn=<ViewBackward0>)

In [10]:
input

tensor([[0.0000, 0.0000],
        [0.1667, 1.0000],
        [0.3333, 0.0000],
        [0.5000, 0.0000],
        [0.6667, 0.0000],
        [0.8333, 0.0000]])

In [9]:
Network().forward(input, dataset[0].edge_index)

tensor([[-0.0748],
        [-0.0829],
        [-0.0641],
        [-0.0551],
        [-0.0525],
        [-0.0458]], grad_fn=<AddmmBackward0>)

In [177]:
class Encoder(nn.Module):
    def __init__(self, input_dim, hidden_dim=128):
        super(Encoder, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.lin = nn.Linear(input_dim, hidden_dim)

    def forward(self, x):
        if x.dim() == 1:
            x = x.unsqueeze(-1)
        return self.lin(x)

In [122]:
encoder = Encoder(2)
z = encoder(input)

In [124]:
h = torch.zeros(input.size(0), 128)

In [125]:
z_ = torch.cat([z, h], dim=1)
z_.size()

torch.Size([6, 256])

In [159]:
import torch
from torch_geometric.nn import MessagePassing
from torch.nn import Linear

class MPNN(MessagePassing):
  def __init__(self, in_channels, hidden_channels, activation=None):
    super(MPNN, self).__init__(aggr='max') #  "Max" aggregation.
    self.in_channels = in_channels
    self.hidden_channels = hidden_channels
    self.messages = Linear(self.in_channels * 2, self.hidden_channels)
    self.update_fn = Linear(self.in_channels + self.hidden_channels, self.hidden_channels)
    self.activation = activation

    self.mlp = torch.nn.Sequential(
        Linear(hidden_channels, hidden_channels),
        torch.nn.ReLU(),
        Linear(hidden_channels, self.hidden_channels)
    )
    
  def forward(self, x, edge_index):
    out = self.propagate(edge_index, x=x)
    out = self.mlp(out)
    if self.activation is not None:
      out = self.activation(out)
    return out
    
  def message(self, x_i, x_j):
    # x_i has shape [E, in_channels]
    # x_j has shape [E, in_channels]
    #print('MPNN => xi, xj', x_i.size(), x_j.size())
    tmp = torch.cat([x_i, x_j], dim=1)  # tmp has shape [E, 2 * in_channels]
    #print('MPNN => messages IN', tmp.size())
    m = self.messages(tmp)
    #print('MPNN => messages OUT', m.size())
    return m
  
  def update(self, aggr_out, x):
    # aggr_out has shape [N, out_channels]
    # x has shape [N, in_channels]
    #print(f'MPNN => x_i', x.size(), ' aggr_out ', aggr_out.size())
    tmp = torch.cat([x, aggr_out], dim=1)
    #print(f'MPNN => tmp', tmp.size())
    return self.update_fn(tmp)

In [160]:
processor = MPNN(256, 128)

In [161]:
hi = processor(z_, dataset[0].edge_index)

In [162]:
hi.size()

torch.Size([6, 128])