In [None]:
# In Google Colab, uncomment this to install torch_geometric
# (%pip installs into the running kernel's environment, unlike !pip):
# %pip install torch_geometric

In [None]:
import pickle
from pathlib import Path

import numpy as np
import torch
import torch_geometric as tg

In [None]:
# %%
# Create dataset
#   IN          OUT
# 0 A - B 2       0 A'- B' 2
#   |   |    -->    |   |
# 1 A - B 3       1 A"- B" 3
#
# Alternative dataset: every graph gets two integer values A and B whose
# absolute difference is at least `sep`.
# NOTE(review): `node_attr` built here is overwritten by the next cell
# (A - A -> B - C dataset), so under Restart & Run All only that cell's data
# reaches the rest of the notebook. Run only one of the two dataset cells.

n = 2000
x_max = 50
sep = 10  # the minimum difference between the two values

# torch.randint's upper bound is exclusive, hence the +1 terms:
# node_attr1 is drawn from [-x_max, x_max - sep], node_attr2 from [-x_max + sep - 1, x_max].
node_attr1 = torch.randint(-x_max, x_max-sep     +1, (n,1), dtype=torch.float32)
node_attr2 = torch.randint(-x_max +sep -1, x_max +1, (n,1), dtype=torch.float32)
# Pairs that are too close (node_attr2 - node_attr1 < sep): move node_attr1 up by
# sep and node_attr2 down by sep-1. For integers this gives
# node_attr1 - node_attr2 >= sep afterwards, with both values still in [-x_max, x_max].
# The remaining pairs already satisfy node_attr2 - node_attr1 >= sep.
bools = node_attr1 > node_attr2 -sep
node_attr1[bools] += sep
node_attr2[bools] -= (sep - 1)
node_attr = torch.cat((node_attr1, node_attr2), dim=-1)
print('node_attr.shape', node_attr.shape)

# shape: nr of graphs, nr of nodes per graph, nr of features per node
# Nodes 0 and 1 receive the first value (A), nodes 2 and 3 the second (B).
node_attr = node_attr[:, [0,0,1,1]].unsqueeze(-1)
print('node_attr.shape', node_attr.shape)


In [None]:
# %%
# Create dataset
#   IN          OUT
# 0 A - A 2       0 B - C 2
#   |   |    -->    |   |
# 1 A - A 3       1 C - B 3
#
# Every node of a graph carries the same integer value A; the B/C targets are
# derived from it later in the notebook.

SEED = 0
torch.manual_seed(SEED)  # the saved dataset (and later random draws) is reproducible

n = 2000      # number of graphs
x_max = 50    # node values are integers in [-x_max, x_max]

# torch.randint's upper bound is exclusive: +1 so x_max itself can be drawn
# (symmetric range, consistent with the A - B dataset cell).
node_attr1 = torch.randint(-x_max, x_max + 1, (n, 1), dtype=torch.float32)
# Copy the single value to all 4 nodes -> shape (nr of graphs, 4 nodes, 1 feature)
node_attr = torch.tile(node_attr1, (1, 4)).unsqueeze(-1)
print('node_attr.shape', node_attr.shape)

In [None]:

# Undirected 4-cycle 0-1-3-2-0; each edge is listed in both directions as a
# (source, target) row, i.e. shape (E, 2). It is transposed to PyG's (2, E)
# layout when the graphs are built.
edge_index = torch.tensor([[0,1], [1,0], [1,3], [3,1], [0,2], [2,0], [2,3], [3,2]], dtype=torch.long)

# No edge features: an empty (E, 0) float placeholder.
edge_attr = torch.empty(edge_index.shape[0], 0)

# Base value of every target before the +/- shift.
# To add an average term instead:
# c = 0.1
# base = node_attr + c*torch.mean(node_attr, dim=1, keepdim=True)
base = node_attr

shift = 5                        # B = A +/- shift, C = A -/+ shift
num_graphs = base.shape[0]       # taken from the data itself, not the global `n`
# Per-graph random sign, +1 or -1: decides whether B lies above or below A.
pm = torch.round(torch.rand(num_graphs)).reshape(-1, 1, 1)*2-1
offset = shift * pm              # shape (num_graphs, 1, 1)

# Alternative 0: nodes 0 and 3 move by +offset, nodes 1 and 2 by -offset.
# Alternative 1 is the mirror image; both labelings are valid targets.
node_sign = torch.tensor([1., -1., -1., 1.]).reshape(1, 4, 1)
alt0 = base + node_sign * offset
alt1 = base - node_sign * offset

# shape: nr of graphs, nr of nodes per graph, nr of target features per node, nr of valid alternatives
y = torch.stack((alt0, alt1), dim=-1)
print('y.shape', y.shape)

In [None]:
# Peek at the first few graphs only (shape: graphs, nodes, target features, alternatives).
y[:3]

In [None]:
# Scaling statistics (as plain Python floats) used to standardise the node
# features (x) and the targets (y): z = (value - mean) / std.
# NOTE(review): computed over the whole dataset before the train/test split,
# so test samples influence the scaling — consider training-only statistics.
x_m = node_attr.mean().item()
x_std = node_attr.std().item()

y_m = y.mean().item()
y_std = y.std().item()


In [None]:
# Create graphs: one PyG `Data` object per sample, with features and targets
# standardised by x_m/x_std and y_m/y_std. `.T` turns the (E, 2) edge list
# into PyG's (2, E) layout; `.clone()` gives each graph its own tensors.
data_list = [
    tg.data.Data(edge_index = edge_index.T,
                 x = ((x_i - x_m) / x_std).clone(),
                 y = ((y[i] - y_m) / y_std).clone(),
                 edge_attr = edge_attr,  # the same empty (E, 0) tensor for every graph
                 )
    for i, x_i in enumerate(node_attr)
]

print(data_list[0])

In [None]:
# %%
# Split in training and testing/validation data.
# The split is sequential (no shuffling); the samples are generated
# independently at random, so their order carries no structure.
tr_frac = 0.7  # fraction of data for training
temp = int(len(data_list) * tr_frac)  # number of training graphs
data_tr, data_te = data_list[:temp], data_list[temp:]


In [None]:
# Save the train/test graphs together with the scaling statistics, so that
# model outputs can be mapped back to original units: value = z * std + mean.
# The path is relative to the current working directory.
# NOTE: unpickling needs torch_geometric importable, and pickle.load must only
# be used on trusted files (it can execute arbitrary code).
out_path = Path('../data/FourNodeGraph_data.pkl')
out_path.parent.mkdir(parents=True, exist_ok=True)  # avoid FileNotFoundError on a fresh checkout / Colab
with out_path.open('wb') as f:
    pickle.dump({'data_tr': data_tr,
                 'data_te': data_te,
                 'x_m': x_m, 'x_std': x_std,
                 'y_m': y_m, 'y_std': y_std,
                 }, f)