In [1]:
import pickle
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch_geometric as tg

In [2]:
# Create dataset
# Each sample is a pair of endpoints A -- B with scalar attributes. The target
# is three alternative segments, each moving the endpoints by half of
# d = B - A:
#   IN          OUT
# A -- B  -->  A-d/2 -- B-d/2
#              A-d/2 -- B+d/2
#              A+d/2 -- B+d/2
# E.g. (d = 2):
# 1 -- 3  --> 0 -- 2
#             0 -- 4
#             2 -- 4
SEED = 0
# Seed the global torch RNG so node_attr and every later torch random draw in
# this notebook (e.g. the option shuffle) are reproducible across runs.
torch.manual_seed(SEED)

n = 2000            # number of samples
max_node_attr = 50  # attributes are drawn uniformly from [-max_node_attr, max_node_attr)

# Shape (n, 2, 1): n samples x 2 endpoints (A, B) x 1 feature, float32.
node_attr = torch.empty(n, 2, 1, dtype=torch.float32).uniform_(-max_node_attr, max_node_attr)
node_attr.shape

torch.Size([2000, 2, 1])

In [3]:
%matplotlib qt
plt.scatter(*node_attr[..., 0].T, s=1)
plt.gca().set_aspect('equal')

In [4]:
d = torch.diff(node_attr, dim=1)

y = torch.tile(node_attr, dims=(1,1,3))
# Option 1
y[..., [0]] -= d/2
# Option 2
y[:, [0], 1] -= d.reshape(-1, 1)/2
y[:, [1], 1] += d.reshape(-1, 1)/2
# Option 3
y[..., [2]] += d/2

In [5]:
# shuffle output options
ind = torch.randint(0,3, (n,1))
temp = torch.randint(1,3, (n,1))
ind2 = (ind + temp) % 3
ind3 = (ind - temp) % 3

inds = torch.cat((ind, ind2, ind3), dim=-1)
y = torch.take_along_dim(y, inds.reshape(-1,1,3), axis=-1)
y

tensor([[[ 52.0897,  52.0897,  24.5854],
         [ 24.5854,  -2.9190,  -2.9190]],

        [[  0.7180,  -7.9264,  -7.9264],
         [  9.3624,   9.3624,   0.7180]],

        [[-20.9834, -20.9834, -31.4212],
         [-41.8590, -31.4212, -41.8590]],

        ...,

        [[  9.2456,  78.9647,  78.9647],
         [-60.4736,   9.2456, -60.4736]],

        [[ -9.5224,  -8.6127,  -9.5224],
         [ -7.7030,  -7.7030,  -8.6127]],

        [[ 55.0310,   1.5804,  55.0310],
         [-51.8702, -51.8702,   1.5804]]])

In [6]:
# Check that the generated node attributes and targets make sense.
# Whatever the shuffled option order, every target entry must be its input
# endpoint moved by exactly +/- d/2 (atol absorbs float32 rounding).
offsets = y - node_attr
assert torch.allclose(offsets.abs(), (d / 2).abs().expand_as(y), atol=1e-4)

# Show a few samples: the input column (A, B), then the three target options.
for i in range(3):
    print(f'input:\n{node_attr[i]}\ntargets:\n{y[i]}\n')

tensor([[38.3375],
        [10.8332]])
tensor([[52.0897, 52.0897, 24.5854],
        [24.5854, -2.9190, -2.9190]])

tensor([[-3.6042],
        [ 5.0402]])
tensor([[ 0.7180, -7.9264, -7.9264],
        [ 9.3624,  9.3624,  0.7180]])

tensor([[-26.2023],
        [-36.6401]])
tensor([[-20.9834, -20.9834, -31.4212],
        [-41.8590, -31.4212, -41.8590]])



In [7]:
# Compute standard deviation and mean for scaling
x_m, x_std = torch.mean(node_attr), torch.std(node_attr)
y_m, y_std = torch.mean(y), torch.std(y)

y_std = y_std.item()
y_m = y_m.item()
x_std = x_std.item()
x_m = x_m.item()


In [8]:
node_attr_scaled = (node_attr - x_m) / x_std
y_scaled = (y - y_m) / y_std

In [9]:
# save as tensors
data = torch.cat([node_attr_scaled, y_scaled], dim=-1)

training_fraction = 0.8
n_training = int(n * training_fraction)

with open('../data/ThreeRoads_data.pkl', 'wb') as f:
    pickle.dump({'data_tr': data[:n_training].clone(),
                 'data_te': data[n_training:].clone(),
                 'x_m': x_m, 'x_std': x_std,
                 'y_m': y_m, 'y_std': y_std,
                 }, f)