In [None]:
# in Google Colab, uncomment this to install torch_geometric:
# !pip install torch_geometric

In [None]:
import pickle
import torch
import matplotlib.pyplot as plt
import numpy as np
import torch_geometric as tg
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing

# Open dataset

In [None]:
# NOTE: pickle.load can execute arbitrary code -- only open data files from a trusted source.
with open('../data/FourNodeGraph_data.pkl', 'rb') as pkl_file:
    data = pickle.load(pkl_file)

# Summarise the contents: tensors by their size, lists by their length, anything else verbatim.
for key, value in data.items():
    if isinstance(value, torch.Tensor):
        summary = value.size()
    elif isinstance(value, list):
        summary = len(value)
    else:
        summary = value
    print(key, summary)

# Train/test graph lists and the normalization statistics of inputs (x) and targets (y).
data_tr, data_te = data['data_tr'], data['data_te']
x_m, y_m = data['x_m'], data['y_m']
x_std, y_std = data['x_std'], data['y_std']

In [None]:
# Reshuffle the training graphs every epoch so mini-batches differ between epochs
# (DataLoader does not shuffle by default). The test set is read in batches of up to
# 100000 graphs, i.e. a single batch for any test set of that size or smaller.
train_loader = tg.loader.DataLoader(data_tr, batch_size=64, shuffle=True)
test_loader = tg.loader.DataLoader(data_te, batch_size=100000)

# Choose device

In [None]:
# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
cuda_ok = torch.cuda.is_available()
device = torch.device('cuda' if cuda_ok else 'cpu')
print('cuda available' if cuda_ok else 'cuda not available')

# Define model

In [None]:
class MyMessagePassingLayer(MessagePassing):
    """Message-passing layer that updates both node and edge embeddings.

    For every edge ``j -> i`` (``x_j`` is the source node, ``x_i`` the target node):

    1. A message is computed by ``mlp_message`` from the concatenation of the target
       embedding ``x_i``, the source embedding ``x_j`` and the edge embedding.
    2. All messages arriving at a node are summed (``aggr='add'``), concatenated with the
       node's previous embedding and passed through ``mlp_update`` to give the new node
       embedding.
    3. Each edge embedding is updated by ``mlp_edge`` from the concatenation of the
       previous edge embedding and the embeddings of its target and source nodes.

    Both updates use the node embeddings passed *into* the layer; the edge update does
    not see the freshly updated node embeddings.
    """

    def __init__(self, node_in, edge_in, message_size, node_out, edge_out):
        """Build the three MLPs of the layer.

        Parameters
        ----------
        node_in : int
            Size of the incoming node embeddings.
        edge_in : int
            Size of the incoming edge embeddings (can be 0 if edges carry no features).
        message_size : int
            Size of each message. If 0, the node update only sees the node's own
            previous embedding.
        node_out : int
            Size of the node embeddings produced by this layer.
        edge_out : int
            Size of the edge embeddings produced by this layer.
        """
        super().__init__(aggr='add')
        # Input: [x_i, x_j, edge_attr] -> message.
        self.mlp_message = torch.nn.Sequential(
                        torch.nn.Linear(2*node_in + edge_in, message_size),
                        torch.nn.LeakyReLU(),
                        torch.nn.Linear(message_size, message_size),
                        torch.nn.LeakyReLU())
        # Input: [x, summed messages] -> new node embedding (no output activation).
        self.mlp_update = torch.nn.Sequential(
                        torch.nn.Linear(node_in + message_size,
                                        node_out),
                        torch.nn.LeakyReLU(),
                        torch.nn.Linear(node_out, node_out))
        # Input: [edge_attr, x_i, x_j] -> new edge embedding (no output activation).
        self.mlp_edge = torch.nn.Sequential(
                        torch.nn.Linear(edge_in + 2*node_in, edge_out),
                        torch.nn.LeakyReLU(),
                        torch.nn.Linear(edge_out, edge_out))

    def forward(self, x, edge_index, edge_attr):
        """Run one round of node and edge updates.

        Parameters
        ----------
        x : torch.Tensor, shape [N, node_in]
            Current embedding of each of the N nodes.
        edge_index : torch.Tensor, shape [2, E]
            Indices of the source and target node of each of the E edges.
        edge_attr : torch.Tensor, shape [E, edge_in]
            Current embedding of each of the E edges.

        Returns
        -------
        tuple of torch.Tensor
            ``(x_new, edge_attr_new)`` with shapes [N, node_out] and [E, edge_out].
        """
        x_new = self.propagate(edge_index, x=x, edge_attr=edge_attr)
        edge_attr_new = self.edge_updater(edge_index, x=x, edge_attr=edge_attr)
        return x_new, edge_attr_new

    def edge_update(self, edge_index, x_i, x_j, edge_attr):
        """Compute new edge embeddings from [edge_attr, x_i, x_j], shape [E, edge_out]."""
        # Parameter names matter: PyG inspects this signature to decide which tensors
        # to pass (x_i / x_j are the per-edge target / source node embeddings).
        edge_input = torch.cat((edge_attr, x_i, x_j), dim=1)
        return self.mlp_edge(edge_input)

    def message(self, x_i, x_j, edge_attr):
        """Compute one message per edge.

        Parameters
        ----------
        x_i : torch.Tensor, shape [E, node_in]
            Node embeddings of the target nodes.
        x_j : torch.Tensor, shape [E, node_in]
            Node embeddings of the source nodes.
        edge_attr : torch.Tensor, shape [E, edge_in]
            Embedding of each edge.

        Returns
        -------
        torch.Tensor, shape [E, message_size]
        """
        message_input = torch.cat((x_i, x_j, edge_attr), dim=1)
        return self.mlp_message(message_input)

    def update(self, aggr_out, x):
        """Combine each node's previous embedding with its summed incoming messages.

        Parameters
        ----------
        aggr_out : torch.Tensor, shape [N, message_size]
            Sum of the messages received by each node.
        x : torch.Tensor, shape [N, node_in]
            Previous node embeddings.

        Returns
        -------
        torch.Tensor, shape [N, node_out]
        """
        update_input = torch.cat((x, aggr_out), dim=1)
        return self.mlp_update(update_input)


class MyGNN(torch.nn.Module):
    """Three stacked MyMessagePassingLayer's mapping a 1-feature node input to a 1-feature output.

    The two hidden layers are each followed by a LeakyReLU on both node and edge
    embeddings; the last layer's edge embeddings are discarded.
    """

    def __init__(self, use_edge_weight=False):
        super().__init__()
        # `use_edge_weight` is accepted for interface compatibility but is not used.
        # Layer arguments: (node_in, edge_in, message_size, node_out, edge_out).
        self.conv1 = MyMessagePassingLayer(1, 0, 0, 32, 32)    # no edge features / messages yet
        self.conv2 = MyMessagePassingLayer(32, 32, 32, 32, 32)
        self.conv3 = MyMessagePassingLayer(32, 32, 32, 1, 0)   # edge output unused

    def forward(self, data):
        """Return the per-node prediction for a (batched) graph `data`."""
        edge_index = data.edge_index
        h, e = data.x, data.edge_attr

        # Hidden layers: message passing followed by LeakyReLU on nodes and edges.
        for layer in (self.conv1, self.conv2):
            h, e = layer(h, edge_index, edge_attr=e)
            h, e = F.leaky_relu(h), F.leaky_relu(e)

        node_out, _ = self.conv3(h, edge_index, edge_attr=e)
        return node_out

In [None]:
model = MyGNN()
model = model.to(device)
print(model)

# Total number of scalar parameters over all weight and bias tensors.
n_params = 0
for param in model.parameters():
    n_params += param.numel()
print('Total nr of parameters:', n_params)

# Evaluate GNN before training

In [None]:
# Sanity check: predictions of the untrained model vs. the (normalized) first target column.
model.eval()
with torch.no_grad():  # inference only: don't build an autograd graph for the whole test set
    for batch in test_loader:
        batch = batch.clone().to(device)

        real = batch.y[..., 0].cpu().numpy()
        pred = model(batch).cpu().numpy()

        plt.scatter(real, pred, s=1)
plt.xlabel('real')
plt.ylabel('predicted')

# Train

In [None]:
lr = 5e-3
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
ep_lower_lr = 0  # NOTE(review): not used below; no learning-rate schedule is applied

N_EPOCHS = 1000
# Early stopping: after ES_WARMUP epochs, stop as soon as the mean test MSE of the last
# ES_WINDOW epochs is not below ES_MIN_RATIO times the mean of the ES_WINDOW epochs before.
ES_WARMUP = 20
ES_WINDOW = 5
ES_MIN_RATIO = 0.995

MSE_train = []
MSE_test = []
for epoch in range(N_EPOCHS):
    model.train()
    print(f'epoch {epoch:5}', end=' ')

    error_train = []
    for batch in train_loader:
        batch = batch.clone().to(device)
        optimizer.zero_grad()
        out = model(batch)
        if torch.isnan(out).any():
            raise ValueError('nan value in train prediction')

        # Train on the first target column, reshaped to match `out` ([n_nodes, 1]).
        # Without this, an [n_nodes, 1] prediction minus an [n_nodes] target silently
        # broadcasts to an [n_nodes, n_nodes] matrix comparing every node with every node.
        target = batch.y[..., 0].reshape(out.shape)
        loss = F.mse_loss(out, target)
        loss.backward()

        optimizer.step()

        # Flatten per batch so batches of different sizes can be concatenated.
        error_train.append((out - target).detach().cpu().numpy().ravel())

    MSE_train.append(np.mean(np.concatenate(error_train)**2))
    print(f'train MSE {MSE_train[-1]:12.4}', end=' ')

    model.eval()
    error_test = []
    with torch.no_grad():  # evaluation only: no autograd graph needed
        for batch in test_loader:
            batch = batch.clone().to(device)
            out = model(batch)
            target = batch.y[..., 0].reshape(out.shape)
            error_test.append((out - target).cpu().numpy().ravel())

    MSE_test.append(np.mean(np.concatenate(error_test)**2))
    print(f'test MSE {MSE_test[-1]:12.4}')

    # ======== Early stopping =========
    # Stop training if loss did not decrease enough on validation data.
    # (The epoch check short-circuits, so the windows are never empty when averaged.)
    if (
        epoch > ES_WARMUP
        and
        (np.mean(MSE_test[-ES_WINDOW:])
             > ES_MIN_RATIO*np.mean(MSE_test[-2*ES_WINDOW:-ES_WINDOW]))
    ):

        print('======= !! Validation loss did not decrease enough =======')
        print('Stopping training')
        break




# Plot loss

In [None]:
# Learning curves; the errors are in normalized (standardized) target units.
plt.plot(MSE_train, label='train')
plt.plot(MSE_test, label='validation')
plt.xlabel('epoch')
plt.ylabel('MSE (normalized units)')
plt.legend()
plt.yscale('log')

If we plot the real versus the predicted final node attribute, we see that the model does manage to predict the contribution from the global average fairly accurately; however, its prediction is always off by about +5 or -5, which is exactly what you would expect if the model averages over the two possible outcomes.

In [None]:
model.eval()
with torch.no_grad():  # inference only
    for batch in test_loader:
        batch = batch.clone().to(device)
        # Undo the target normalization so the plot is in the original units.
        real = batch.y[..., 0].cpu().numpy()*y_std + y_m
        pred = model(batch).cpu().numpy()*y_std + y_m
        plt.scatter(real, pred, s=1)
plt.gca().set_aspect('equal')
plt.xlabel('real final node attribute')
plt.ylabel('predicted')
plt.axline([0, 0], [1, 1], c='tab:orange')  # y = x: a perfect prediction lies on this line

For comparison, we show below the simplest possible prediction: the node attribute does not change, so the predicted final node attribute equals the initial node attribute. The resulting 'fuzziness' around the diagonal is caused by the global average contribution, which this baseline ignores.

In [None]:
# Baseline "no change" predictor: the final node attribute equals the initial one.
for batch in test_loader:
    real = batch.y[..., 0].cpu().numpy()*y_std + y_m
    # batch.x is normalized with the *input* statistics, so undo it with x_std / x_m
    # (not y_std / y_m) before comparing in original units.
    pred = batch.x.cpu().numpy()*x_std + x_m
    plt.scatter(real, pred, s=1)
plt.gca().set_aspect('equal')
plt.ylabel('prediction: initial node attribute')
plt.xlabel('real final node attribute')
plt.axline([0, 0], [1, 1], c='tab:orange')  # y = x: a perfect prediction lies on this line

# Calculate Wasserstein distance
Since the GNN does successfully manage to predict the global average contribution but otherwise predicts the average of the possible outputs, it has an error of almost exactly 5.0 for each node. This means a single datapoint consisting of 4 nodes must move a 'distance' of $\sqrt{4\cdot 5^2} = 10$ to end up in either of the correct solutions. Since half of the probability mass must shift toward one solution and the other half toward the other solution, everything must shift a distance of about 10. We therefore expect the Wasserstein distance to be very close to 10.0.

In [None]:
from scipy.stats import wasserstein_distance_nd
model.eval()

NODES_PER_GRAPH = 4  # every graph in this dataset has four nodes

rel_dist = []
wd = []  # accumulated over all test batches, not only the last one
with torch.no_grad():  # inference only
    for batch in test_loader:
        batch = batch.clone().to(device)

        # One prediction per node (the model is deterministic, so more samples would be identical).
        preds = model(batch).cpu().numpy().reshape(-1, 1)*y_std + y_m

        preds = preds.reshape(-1, NODES_PER_GRAPH, 1)
        # Two candidate final values ("option 1" / "option 2") per node.
        reals = batch.y.cpu().numpy().reshape(-1, NODES_PER_GRAPH, 2)*y_std + y_m

        print('preds.shape:', preds.shape)
        print('reals.shape:', reals.shape)
        for pred, real in zip(preds, reals):  # one graph at a time (wasserstein_distance_nd is not batched)
            # Rows of pred.T / real.T are points in NODES_PER_GRAPH-dimensional space:
            # one predicted point versus the two real options.
            wd.append(wasserstein_distance_nd(pred.T, real.T))

            # Euclidean distance (over the nodes) from the prediction to each option,
            # then the relative distance to option 1: 0 = on option 1, 1 = on option 2.
            diff = pred[..., np.newaxis] - real[..., np.newaxis, :]
            dist = np.linalg.norm(diff, axis=0)
            rel_dist_temp = dist[..., 0] / (dist[..., 0] + dist[..., 1])
            rel_dist.extend(rel_dist_temp.flatten().tolist())

print(np.mean(wd))

In [None]:
# Histogram of where each prediction lies between the two options:
# 0 means it coincides with option 1, 1 means it coincides with option 2.
plt.rcParams.update({'font.size': 16})  # note: also changes the default font size of later figures
plt.figure(figsize=(5, 5))
bin_edges = np.linspace(0, 1, 40)
plt.hist(rel_dist, bins=bin_edges, density=True)
plt.xlabel(r'$\frac{\text{Distance to option 1}}{\text{Distance to option 1 + Distance to option 2}}$', fontsize=18)
# Mark the two options.
for boundary in (0.0, 1.0):
    plt.axvline(boundary, color='red', linestyle='--')
plt.ylabel('Density')

# Calculate Wasserstein distance when predicting that final = initial

In [None]:
from scipy.stats import wasserstein_distance_nd

# Baseline: predict that every node keeps its initial value (no model involved).
wd = []  # accumulated over all test batches, not only the last one
for batch in test_loader:
    # x is normalized with the input statistics, y with the target statistics.
    preds = batch.x.cpu().numpy().reshape(-1, 4, 1)*x_std + x_m
    reals = batch.y.cpu().numpy().reshape(-1, 4, 2)*y_std + y_m

    print('preds.shape:', preds.shape)
    print('reals.shape:', reals.shape)
    for pred, real in zip(preds, reals):  # one graph at a time (wasserstein_distance_nd is not batched)
        wd.append(wasserstein_distance_nd(pred.T, real.T))

print(np.mean(wd))

This shows that the Wasserstein distance is lower (i.e. better) when we use the GNN instead of simply predicting that the final node value equals the initial node value, but not by much.