# Four-Node Graph Problem: Introduction

This test system is a simple example of bifurcation on a graph, where the permutation symmetry is broken.

The input graph has four nodes connected in a square. Each node has an initial node attribute $x_i$. Two of the nodes are identical, with initial attribute $a$, and the other two are identical to each other, with initial attribute $b$.

The output node attributes are then
$$x_i + 0.1 \cdot \frac{1}{4}\sum_{j=1}^{4} x_j \pm 5$$
The term $0.1 \cdot \frac{1}{4}\sum_{j=1}^{4} x_j$ adds one tenth of the average of all node attributes to each node, which forces the nodes to consider the entire graph.

In the term $\pm 5$, connected non-identical nodes will get the same sign, but identical nodes (which are also connected by an edge) will get an opposite sign. This forces two identical nodes to diverge: one of them must increase its value by 5, the other must decrease it. The nodes need to coordinate to figure out which one picks plus and which picks minus. They must also consider the other nodes, because non-identical nodes connected by an edge must pick the same sign.
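
The sign rule above can be checked by brute force. The sketch below is only an illustration: it assumes the node ordering [a, a, b, b] and a square with undirected edges 0-1, 0-2, 1-3 and 2-3, and the names `EDGES`, `IDENTICAL` and `valid_sign_assignments` are made up for this example. It enumerates all 16 sign assignments and keeps those that satisfy the rule.

```python
from itertools import product

# Assumed node ordering: nodes 0 and 1 share value a, nodes 2 and 3 share value b.
# Assumed square edges (undirected): 0-1, 0-2, 1-3, 2-3.
EDGES = [(0, 1), (0, 2), (1, 3), (2, 3)]
IDENTICAL = {(0, 1), (2, 3)}  # connected pairs with equal input values


def valid_sign_assignments():
    """Return every assignment of +1/-1 to the four nodes that obeys the sign rule."""
    valid = []
    for signs in product((+1, -1), repeat=4):
        ok = True
        for i, j in EDGES:
            if (i, j) in IDENTICAL:
                ok = ok and signs[i] != signs[j]  # identical neighbours must split
            else:
                ok = ok and signs[i] == signs[j]  # non-identical neighbours must agree
        if ok:
            valid.append(signs)
    return valid


solutions = valid_sign_assignments()
print(solutions)
```

Under these assumptions only two assignments survive, and they are mirror images of each other.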

This leads to two different possibilities for the entire output graph.
If the input node attributes are 
$$x^{\textrm{in}}_i = [a,a,b,b]$$
then the possible outputs are 
$$x^{\textrm{out}}_i = [a+g+5, a+g-5, b+g+5, b+g-5]$$
or
$$x^{\textrm{out}}_i = [a+g-5, a+g+5, b+g-5, b+g+5]$$
with $g$ the global contribution $g = 0.1 \frac{1}{4}(a + a + b + b)$.
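
As a worked example, the two outputs can be computed directly from the formulas above. This is a minimal sketch, not the code that generated the dataset: the function name `four_node_outputs` and its keyword arguments are hypothetical. The inputs a = -46 and b = 12 are the values of the first training sample shown in the dataset section.

```python
def four_node_outputs(a, b, weight=0.1, shift=5.0):
    """Return both valid outputs for the input [a, a, b, b]."""
    g = weight * (a + a + b + b) / 4           # global contribution g
    base = [a + g, a + g, b + g, b + g]        # every node receives the same g
    signs = [+1, -1, +1, -1]                   # one of the two valid sign patterns
    first = [v + s * shift for v, s in zip(base, signs)]
    second = [v - s * shift for v, s in zip(base, signs)]  # mirrored pattern
    return first, second


first, second = four_node_outputs(-46.0, 12.0)
print([round(v, 4) for v in first])
print([round(v, 4) for v in second])
```

Here $g = 0.1 \cdot (-17) = -1.7$, so the first solution is approximately [-42.7, -52.7, 15.3, 5.3] and the second swaps the signs within each pair.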


In [29]:
# in Google Colab, uncomment this to install torch_geometric:
# !pip install torch_geometric

In [30]:
import pickle
import torch
import numpy as np
import matplotlib.pyplot as plt
import torch_geometric as tg

# The Dataset
Let's take a look at what the dataset actually looks like.

In [31]:
with open('../data/FourNodeGraph_data.pkl', 'rb') as f:
    data = pickle.load(f)

for key, value in data.items():
    if isinstance(value, torch.Tensor):
        print(key, value.size())
    elif isinstance(value, list):
        print(key, len(value))
    else:
        print(key, value)

data_tr = data['data_tr']
data_te = data['data_te']
x_m = data['x_m']
y_m = data['y_m']
x_std = data['x_std']
y_std = data['y_std']

data_tr 1400
data_te 600
x_m 0.6632500290870667
x_std 29.427982330322266
y_m 0.729574978351593
y_std 31.057353973388672


In [32]:
# look at the first data point of the training and test sets
print(data_tr[0])
print(data_te[0])

Data(x=[4, 1], edge_index=[2, 8], edge_attr=[8, 0], y=[4, 1, 2])
Data(x=[4, 1], edge_index=[2, 8], edge_attr=[8, 0], y=[4, 1, 2])


Both data_tr and data_te are lists of torch_geometric.data.Data objects, with 1400 and 600 data points respectively.
Each Data object represents a graph with N = 4 nodes and E = 4 undirected edges, and holds the following quantities:
* node features x, shape [N, 1]: the initial node value
* edge_index, shape [2, 2*E] = [2, 8]: indices that indicate which nodes are connected by an edge. Each edge is stored in both directions.
* edge_attr, shape [2*E, 0] = [8, 0]: the edge features, empty because this problem has no edge features
* final node values y, shape [N, 1, n_sol] = [N, 1, 2]: both possible solutions. To train on only one solution, take the first one: graph.y[:, :, 0], where graph is an element of data_tr or data_te.
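
To make the layout of y concrete, the sketch below uses a NumPy array as a stand-in for the torch tensor graph.y (an assumption made so the example runs without the dataset file). The numbers are the un-normalized final values of the first training sample.

```python
import numpy as np

# Stand-in for graph.y of one sample: shape [N, 1, n_sol] = [4, 1, 2].
# The last axis indexes the two possible solutions.
y = np.array([
    [[-42.7, -52.7]],
    [[-52.7, -42.7]],
    [[15.3, 5.3]],
    [[5.3, 15.3]],
])

y_first = y[:, :, 0]    # first solution, shape [4, 1]
y_second = y[:, :, 1]   # second solution, shape [4, 1]

print(y.shape, y_first.shape)
print(y_first.ravel(), y_second.ravel())
```

Slicing the last axis keeps the [N, 1] shape of the node features, so the selected solution can be compared with a prediction of the same shape as x.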

In [34]:
# Take a look at the node attributes
for i in range(3):
    print('Initial node attributes:')
    print(f'{str(data_tr[i].x*x_std+x_m):30}')
    print('Final node attributes (each column is one possible option):')
    print(f'{data_tr[i].y*y_std+y_m}')

Initial node attributes:
tensor([[-46.],
        [-46.],
        [ 12.],
        [ 12.]])
Final node attributes (each column is one possible option):
tensor([[[-42.7000, -52.7000]],

        [[-52.7000, -42.7000]],

        [[ 15.3000,   5.3000]],

        [[  5.3000,  15.3000]]])
Initial node attributes:
tensor([[39.],
        [39.],
        [-2.],
        [-2.]])
Final node attributes (each column is one possible option):
tensor([[[35.8500, 45.8500]],

        [[45.8500, 35.8500]],

        [[-5.1500,  4.8500]],

        [[ 4.8500, -5.1500]]])
Initial node attributes:
tensor([[-21.],
        [-21.],
        [-36.],
        [-36.]])
Final node attributes (each column is one possible option):
tensor([[[-18.8500, -28.8500]],

        [[-28.8500, -18.8500]],

        [[-33.8500, -43.8500]],

        [[-43.8500, -33.8500]]])


Above you can see that the zeroth and first node always have the same initial value, as do the second and third node.
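
The values printed above are rescaled with x*x_std+x_m, which suggests that the stored features are standardized. The sketch below assumes the forward transform is (x - x_m) / x_std; that is inferred from the rescaling and is not confirmed by the dataset itself. The helper names `normalize` and `denormalize` are hypothetical, and the statistics are rounded copies of the printed ones.

```python
# Statistics as printed when loading the dataset (rounded for readability).
x_m, x_std = 0.66325, 29.42798


def normalize(x_raw, mean, std):
    """Assumed forward transform: subtract the mean, divide by the std."""
    return (x_raw - mean) / std


def denormalize(x_norm, mean, std):
    """Inverse transform, as used in the notebook: x * std + mean."""
    return x_norm * std + mean


x_raw = -46.0
x_norm = normalize(x_raw, x_m, x_std)
x_back = denormalize(x_norm, x_m, x_std)
print(x_norm)
print(x_back)
```

The round trip recovers the raw value up to floating point error, which is why the rescaled tensors above show values such as -46 and 12.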

In [35]:
data_tr[0].edge_index

tensor([[0, 1, 1, 3, 0, 2, 2, 3],
        [1, 0, 3, 1, 2, 0, 3, 2]])

From the edge_index above, we can see that the nodes are connected in a square like this:
```text
0 - 1
|   |
2 - 3
```
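
The square can also be read off programmatically. The sketch below copies the printed edge_index into plain Python lists (so it runs without torch), builds a neighbor set for each node, and checks that every edge is stored in both directions. The variable names are chosen for this example only.

```python
# edge_index as printed above: row 0 holds source nodes, row 1 target nodes.
edge_index = [
    [0, 1, 1, 3, 0, 2, 2, 3],
    [1, 0, 3, 1, 2, 0, 3, 2],
]

# Collect the neighbors of every node.
neighbors = {}
for src, dst in zip(*edge_index):
    neighbors.setdefault(src, set()).add(dst)

# Every directed edge should have its reverse in the list.
directed = set(zip(*edge_index))
is_symmetric = all((dst, src) in directed for src, dst in directed)

for node in sorted(neighbors):
    print(node, sorted(neighbors[node]))
print(is_symmetric)
```

Each node has exactly two neighbors and the diagonals 0-3 and 1-2 are absent, which matches the square drawn above.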

In the plot below, you can see that the initial values of the nodes are distributed essentially randomly, except that the two distinct values in a graph (node 0 and node 2) always differ by at least 10.

In [36]:
%matplotlib qt
# plot the values of the zeroth and second node against each other
node_attr = np.array([graph.x.numpy()[[0, 2],0] for graph in data_tr])
node_attr = node_attr * x_std + x_m
plt.scatter(*node_attr.T, s=1)
plt.gca().set_aspect('equal')
plt.xlabel('Node 0')
plt.ylabel('Node 2')
plt.grid()

In [37]:
# create loaders for easy access to the data
train_loader = tg.loader.DataLoader(data_tr, batch_size=64)
# batch size larger than the test set, so the whole test set arrives in one batch
test_loader = tg.loader.DataLoader(data_te, batch_size=100000)

The plot below shows the final value of each node (first solution only) against its initial value. Similar to what we saw with the Three Roads problem, it shows that we cannot predict the final value of a node from only its initial value, because the nodes interact. For each initial value there are two clusters of possible final values. There are two because of the choice between +5 and -5, and the averaging term makes both clusters a bit fuzzy.
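
The fuzziness comes from the global term: for a fixed initial value a, the final value still depends on the value b of the other pair through g = 0.1 * (2a + 2b) / 4. The sketch below illustrates this with made-up values of b; the function name `final_values` is hypothetical.

```python
def final_values(a, b, weight=0.1, shift=5.0):
    """Both possible final values of a node with input a, given the other pair's input b."""
    g = weight * (2 * a + 2 * b) / 4   # global contribution
    return a + g + shift, a + g - shift


a = 10.0
results = {}
for b in (-50.0, 0.0, 50.0):           # hypothetical values for the other pair
    upper, lower = final_values(a, b)
    results[b] = (upper, lower)
    print(b, round(upper, 2), round(lower, 2))
```

The gap between the two branches stays at 10, while each branch shifts with b. This gives two parallel, slightly smeared bands rather than two sharp lines.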

In [16]:
# %%
# Plot in vs out per node (first solution only), rescaled to the original units
for batch in test_loader:
    x_in = batch.x[:, 0].detach().numpy() * x_std + x_m
    y_out = batch.y[:, 0, 0].detach().numpy() * y_std + y_m

    plt.scatter(x_in, y_out, s=1)
    plt.gca().set_aspect('equal')
    plt.xlabel('initial node attribute')
    plt.ylabel('final node attribute')
    plt.axline([0,0], [1,1], c='tab:orange')

plt.show()


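The following plot shows how the node values split apart. Before the split, a single blue dot (initial value of node 0 against initial value of node 2) describes the whole graph, because nodes 0 and 1 share one value and nodes 2 and 3 share the other. After the split, two dots are needed to represent the four different node values: orange for the final values of nodes 0 and 2, and red for the final values of nodes 1 and 3.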

In [38]:
# %%
# Plot how ground truth bifurcates
fig, ax = plt.subplots()
fig.patch.set_facecolor("None")

for ind in range(100):
    graph = data_tr[ind]
    pos1 = graph.x[[0, 2], 0].numpy()*x_std+x_m
    pos2a = graph.y[[0, 2], 0, 0].numpy()*y_std+y_m
    pos2b = graph.y[[1, 3], 0, 0].numpy()*y_std+y_m

    ax.scatter(*pos1, c='tab:blue', s=10, label='Initial')
    ax.scatter(*pos2a, c='tab:orange', s=10, label='Final, node 0 & 2')
    ax.scatter(*pos2b, c='tab:red', s=10, label='Final, node 1 & 3')

    ax.annotate('', xy=pos2a, xytext=pos1,
                arrowprops=dict(arrowstyle='->', facecolor='black'),
                )
    ax.annotate('', xy=pos2b, xytext=pos1,
                arrowprops=dict(arrowstyle='->', facecolor='black'),
                )

handles, labels = ax.get_legend_handles_labels()
plt.legend(handles[:3], labels[:3])
ax.set_xlabel('Embedding node 0 and 1')
ax.set_ylabel('Embedding node 2 and 3')
ax.set_aspect('equal')

plt.show()