In [None]:
# in Google Colab, uncomment this to install torch_geometric:
# !pip install torch_geometric

In [None]:
import torch
import torch.nn.functional as F
import torch_geometric as tg
import matplotlib.pyplot as plt
import numpy as np
import pickle
from torch_geometric.nn import MessagePassing

# Open Dataset

In [None]:
# NOTE: pickle.load can execute arbitrary code -- only open trusted data files.
with open('../data/FourNodeGraph_data.pkl', 'rb') as f:
    data = pickle.load(f)

# Summarise each entry: tensor shape, list length, or the raw value otherwise.
for key, value in data.items():
    if isinstance(value, torch.Tensor):
        summary = value.size()
    elif isinstance(value, list):
        summary = len(value)
    else:
        summary = value
    print(key, summary)

# Train/test splits, plus the statistics later used to undo the normalisation
# (value*std + mean).
data_tr, data_te = data['data_tr'], data['data_te']
x_m, y_m = data['x_m'], data['y_m']
x_std, y_std = data['x_std'], data['y_std']

In [None]:
# Reshuffle the training graphs every epoch so mini-batches differ between epochs.
train_loader = tg.loader.DataLoader(data_tr, batch_size=64, shuffle=True)
# Whole test set as a single batch (several analysis cells below treat the test
# set as one batch); sized from the data instead of a magic upper bound.
test_loader = tg.loader.DataLoader(data_te, batch_size=max(len(data_te), 1))

# Choose device

In [None]:
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
cuda_ok = torch.cuda.is_available()
device = torch.device('cuda' if cuda_ok else 'cpu')
print('cuda available' if cuda_ok else 'cuda not available')

# Define metrics

In [None]:
def KL_loss(logsig_i, mu_i, logsig_f=torch.tensor(0.0), mu_f=torch.tensor(0.0)):
    """KL divergence KL(N_i || N_f) between two diagonal Gaussians, averaged over the batch.

    Each Gaussian is given by its log standard deviation and its mean. The
    divergence is summed over the last dimension and averaged over all others.

    Parameters
    ----------
    logsig_i, mu_i : torch tensor, [..., d]
        log standard deviation and mean of the first distribution N_i
    logsig_f, mu_f : torch tensor broadcastable to [..., d], optional
        log standard deviation and mean of the second distribution N_f.
        If not given they are 0, i.e. N_f is a standard normal (mu=0, sig=1).

    Returns
    -------
    torch tensor, 0-dim
        mean over the leading dimensions of sum_d KL(N_i || N_f)
    """
    # Closed form, per dimension:
    #   2*KL = 2*(logsig_f - logsig_i) - 1 + sig_i^2/sig_f^2 + (mu_f - mu_i)^2/sig_f^2
    # The variance ratios are formed in log space (exp(2*(a - b)) rather than
    # exp(a)**2 / exp(b)**2) so that large log-sigmas do not overflow to inf/inf = nan.
    log_ratio = logsig_i - logsig_f
    inv_var_f = torch.exp(-2 * logsig_f)
    temp = -2 * log_ratio - 1 + torch.exp(2 * log_ratio) + (mu_f - mu_i) ** 2 * inv_var_f
    temp = 0.5 * torch.sum(temp, dim=-1)
    return torch.mean(temp)

def MSELoss_allTargets(pred, target, batch, return_indices=False):
    """Mean squared error against the best-matching target alternative of each graph.

    For every graph the squared error is computed against each of the `a`
    candidate targets; only the smallest one counts. The summed per-graph
    minima are then divided by the number of predicted values.

    Parameters
    ----------
    pred : torch tensor, [N, f]
        prediction per node (N = total nr of nodes, f = nr of predicted features)
    target : torch tensor, [N, f, a]
        all possible targets (a = nr of alternatives)
    batch : torch tensor, [N, ]
        index of the graph each node belongs to
    return_indices : bool, optional
        also return the index of the chosen alternative per graph, by default False

    Returns
    -------
    torch tensor, 0-dim
        the MSE using the closest alternative for each graph; when
        `return_indices` is True, a tuple (MSE, indices of shape [G]) instead
    """
    # per node and alternative: [N, f, a] -> [N, a]
    node_sq_err = (pred[..., None] - target).pow(2).sum(dim=1)
    # per graph and alternative: [G, a]
    graph_sq_err = tg.nn.global_add_pool(node_sq_err, batch=batch)
    # keep only the closest alternative of every graph
    best_err, best_idx = graph_sq_err.min(dim=-1)

    mse = best_err.sum() / pred.numel()
    return (mse, best_idx) if return_indices else mse

# Define model

In [None]:
class MyMessagePassingLayer(MessagePassing):
    """Message passing layer that updates both node and edge embeddings.

    Node update: each node's new embedding is computed (by `mlp_update`) from its
    previous embedding and the sum (aggr='add') of the messages it receives. Each
    message is computed (by `mlp_message`) from the embeddings of both endpoint
    nodes of an edge and that edge's attributes.

    Edge update: each edge's new embedding is computed (by `mlp_edge`) from its
    previous attributes and the embeddings of its two endpoint nodes. The edge
    update sees the node embeddings from *before* this layer's node update.

    Subclasses torch_geometric.nn.MessagePassing; `propagate` dispatches to
    `message` and `update`, and `edge_updater` dispatches to `edge_update`.
    """

    def __init__(self, node_in, edge_in, message_size, node_out, edge_out):

        """Initialize layer.

        Parameters
        ----------
        node_in : int
            previous node embedding size
        edge_in : int
            previous edge embedding size
        message_size : int
            size of the message
        node_out : int
            node embedding size after updating
        edge_out : int
            edge embedding size after updating
        """
        super().__init__(aggr='add')
        # message MLP: input is [x_i, x_j, edge_attr] (see `message`)
        self.mlp_message = torch.nn.Sequential(
                        torch.nn.Linear(2*node_in + edge_in, message_size),
                        torch.nn.LeakyReLU(),
                        torch.nn.Linear(message_size, message_size),
                        torch.nn.LeakyReLU())
        # node update MLP: input is [x, aggregated messages] (see `update`);
        # the last layer is linear (no activation)
        self.mlp_update = torch.nn.Sequential(
                        torch.nn.Linear(node_in + message_size,
                                        node_out),
                        torch.nn.LeakyReLU(),
                        torch.nn.Linear(node_out, node_out))
        # edge update MLP: input is [edge_attr, x_i, x_j] (see `edge_update`)
        self.mlp_edge = torch.nn.Sequential(
                        torch.nn.Linear(edge_in + 2*node_in, edge_out),
                        torch.nn.LeakyReLU(),
                        torch.nn.Linear(edge_out, edge_out))

    def forward(self, x, edge_index, edge_attr):
        """Apply one round of node and edge updates.

        Parameters
        ----------
        x : torch.tensor, shape [n, node_in]
            current node embedding for each of the n nodes
        edge_index : torch.tensor, shape [2, e]
            indices of all edges
        edge_attr : torch.tensor, shape [e, edge_in]
            edge attributes of each of the e edges

        Returns
        -------
        tuple of torch.tensor
            (new node embeddings [n, node_out], new edge embeddings [e, edge_out]);
            both are computed from the input x and edge_attr.
        """
        return self.propagate(edge_index, x=x, edge_attr=edge_attr), self.edge_updater(edge_index, x=x, edge_attr=edge_attr)

    def edge_update(self, edge_index, x_i, x_j, edge_attr):
        """Compute new edge embeddings from [edge_attr, x_i, x_j] (called by `edge_updater`).

        x_i / x_j are the embeddings of each edge's target / source node (PyG's
        `_i` / `_j` suffix convention); `edge_index` is accepted but unused here.
        """
        temp = torch.cat((edge_attr, x_i, x_j), dim=1)

        return self.mlp_edge(temp)

    def message(self, x_i, x_j, edge_attr):
        """Compute one message per edge from both endpoint embeddings and the edge attributes.

        parameters
        ----------
        x_i : torch.tensor, shape [e, node_in]
            node embeddings of target (receiving) nodes
        x_j : torch.tensor, shape [e, node_in]
            node embeddings of source nodes
        edge_attr : torch.tensor, shape [e, edge_in]
            attributes of each edge

        Returns
        -------
        torch.tensor, shape [e, message_size]
        """
        temp = torch.cat((x_i, x_j, edge_attr), dim=1)
        return self.mlp_message(temp)

    def update(self, aggr_out, x):
        """Combine each node's previous embedding with its summed incoming messages.

        aggr_out has shape [n, message_size]; returns shape [n, node_out].
        """
        temp = torch.cat((x, aggr_out), dim=1)
        return self.mlp_update(temp)


class SimpleGNN(torch.nn.Module):
    """Two stacked MyMessagePassingLayer layers with a leaky ReLU in between.

    The first layer has edge_in=0 and message_size=0, so its node update
    effectively sees only each node's own embedding, while it builds 32-dim edge
    embeddings from the endpoint nodes. It therefore expects `edge_attr` with 0
    columns (TODO confirm against the dataset). The second layer passes 32-dim
    messages over those edge embeddings and returns `node_out` features per
    node. Its edge output (edge_out=0) is discarded.
    """

    def __init__(self, node_in, node_out):
        super().__init__()

        self.conv1 = MyMessagePassingLayer(node_in=node_in, edge_in=0, message_size=0,
                                           node_out=32, edge_out=32)
        self.conv2 = MyMessagePassingLayer(node_in=32, edge_in=32, message_size=32,
                                           node_out=node_out, edge_out=0)

    def forward(self, x, edge_index, edge_attr):
        """Return node embeddings of shape [n, node_out]."""
        hidden_nodes, hidden_edges = self.conv1(x, edge_index, edge_attr=edge_attr)
        hidden_nodes = F.leaky_relu(hidden_nodes)
        hidden_edges = F.leaky_relu(hidden_edges)
        out_nodes, _unused_edges = self.conv2(hidden_nodes, edge_index, edge_attr=hidden_edges)
        return out_nodes

class ProbabilisticGNN1(torch.nn.Module):
    """Conditional-VAE style GNN that predicts one output value per node.

    - `embeddingGNN1` embeds the initial node features into 32-dim node embeddings.
    - From these, `mlp1` predicts a per-node Gaussian (logsig_i, mu_i) over a
      `latent_dim`-dim latent variable, conditioned on the input only.
    - With `cond_on_final=True`, `embeddingGNN2` also sees the ground-truth output
      y. `mlp2` then predicts a second Gaussian (logsig_f, mu_f) from both embeddings.
    - A latent sample is concatenated to the input embedding and decoded by
      `finalGNN`. The sample comes from the second Gaussian when conditioning on
      the output, otherwise from the first.
    """

    def __init__(self, latent_dim):
        super().__init__()

        self.latent_dim = latent_dim
        # embeds the initial node features (1 feature per node)
        self.embeddingGNN1 = SimpleGNN(1, 32)
        # embeds the initial features concatenated with the target (1 + 1 features)
        self.embeddingGNN2 = SimpleGNN(2, 32)
        # output layout is interleaved: reshaped to [N, latent_dim, 2] = (logsig, mu)
        self.mlp1 = torch.nn.Sequential(
                        torch.nn.LeakyReLU(),
                        torch.nn.Linear(32, 32),
                        torch.nn.LeakyReLU(),
                        torch.nn.Linear(32, 2*latent_dim)
                        )
        self.mlp2 = torch.nn.Sequential(
                        torch.nn.LeakyReLU(),
                        torch.nn.Linear(32*2, 32),
                        torch.nn.LeakyReLU(),
                        torch.nn.Linear(32, 2*latent_dim)
                        )
        self.finalGNN = SimpleGNN(32 + latent_dim, 1)

    @staticmethod
    def _select_target(y):
        """Return y as [N, f]; if several alternatives are given ([N, f, a]) use the first."""
        if y is None:
            raise ValueError('cond_on_final=True requires ground truth output data.y')
        if y.ndim == 3:
            return y[..., 0]
        if y.ndim == 2:
            return y
        raise ValueError(f'ground truth output y should have 2 or 3 dimensions, not {y.ndim}')

    def _gaussian_params(self, mlp, h):
        """Split the output of `mlp` into (logsig, mu), each of shape [N, latent_dim]."""
        logsig_mu = mlp(h).reshape(-1, self.latent_dim, 2)
        return logsig_mu[..., 0], logsig_mu[..., 1]

    @staticmethod
    def _sample(logsig, mu):
        """Reparameterised sample mu + exp(logsig) * eps, with eps ~ N(0, 1)."""
        return torch.randn_like(logsig) * torch.exp(logsig) + mu

    def forward(self, data, cond_on_final=False):
        """Run the model on a (batched) PyG data object.

        Parameters
        ----------
        data : torch_geometric data / batch
            needs `x`, `edge_index`, `edge_attr`; `y` is only read (and validated)
            when `cond_on_final` is True
        cond_on_final : bool, optional
            sample the latent from the output-conditioned Gaussian, by default False

        Returns
        -------
        (logsig_i, mu_i, logsig_f, mu_f, pred) if cond_on_final else (pred, );
        pred has shape [N, 1].
        """
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr

        # Gaussian over the latent, from the initial state only
        x2 = self.embeddingGNN1(x, edge_index, edge_attr)
        logsig_i, mu_i = self._gaussian_params(self.mlp1, x2)

        if not cond_on_final:
            # The ground truth is not needed here, so inference works without data.y.
            z_i = self._sample(logsig_i, mu_i)
            pred = self.finalGNN(torch.cat((x2, z_i), dim=-1), edge_index, edge_attr)
            return (pred, )

        # Gaussian over the latent, from the initial state and the final state.
        # Only the latent that is actually decoded is sampled.
        y = self._select_target(data.y)
        x3 = self.embeddingGNN2(torch.cat((x, y), dim=-1), edge_index, edge_attr)
        logsig_f, mu_f = self._gaussian_params(self.mlp2, torch.cat((x2, x3), dim=-1))
        z_f = self._sample(logsig_f, mu_f)
        pred = self.finalGNN(torch.cat((x2, z_f), dim=-1), edge_index, edge_attr)
        return logsig_i, mu_i, logsig_f, mu_f, pred

# Instantiate GNN

In [None]:
# Seed torch's global RNG so that weight initialisation and all later torch random
# draws (e.g. latent sampling) are reproducible on Restart & Run All.
# Some CUDA kernels may still be non-deterministic.
SEED = 0
torch.manual_seed(SEED)

model = ProbabilisticGNN1(latent_dim=8).to(device=device)
print(model)

n_params = sum(p.numel() for p in model.parameters())
print('Total nr of parameters:', n_params)

# Evaluate GNN before training
To test the initialization of the network, evaluate the GNN before training it. We use the `MSELoss_allTargets` function, which computes the mean squared error between the prediction and every possible target, and then, for each graph, keeps only the closest target.

In [None]:
# Sanity check of the untrained network (normalised units): predictions vs. the
# closest ground-truth alternative. No gradients are needed, so skip autograd.
model.eval()
with torch.no_grad():
    for batch in test_loader:
        batch = batch.clone().to(device)
        pred = model(batch)[-1]
        test_loss, inds = MSELoss_allTargets(pred, batch.y, batch=batch.batch, return_indices=True)

        # Pick, per graph, the alternative judged closest. The reshape assumes
        # 4 nodes per graph, 1 output feature and 2 alternatives.
        ground_truth = np.take_along_axis(batch.y.cpu().numpy().reshape(-1, 4, 1, 2),
                                          inds.cpu().numpy().reshape(-1, 1, 1, 1), axis=-1)
        plt.scatter(ground_truth, pred.cpu().numpy(), s=1)
        plt.xlabel('real')
        plt.ylabel('predicted')
        plt.title('Real vs predicted before training\n(to check initialization of network)')

# Train

In [None]:
lr = 5e-3
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

losses = {
            # on training data, while conditioning on output:
            'MSE_train': [],
            'KL_train': [],
            'loss_train': [], # total loss

            # on validation data, without conditioning on output:
            'MSE_val': [],
            'MSE_val_allTargets': [],  # MSE compared to closest target

            # on validation data, but still conditioned on output:
            'KL_val2': [],
            'MSE_val2': []
        }

KL_increase_range = 50  # nr of epochs over which we increase the weight of the KL loss
n_epochs = 500          # maximum nr of epochs (early stopping may end training sooner)


def _record_epoch_means(epoch_values, keys):
    """Append the epoch mean of each metric in `keys` to `losses` and print it."""
    for key in keys:
        losses[key].append(np.mean(epoch_values[key]))
        print(f'\t{key:10} {losses[key][-1]:8.4}')


for epoch in range(n_epochs):
    print(f'epoch {epoch}')

    # ====================== 1) TRAIN ======================
    # The KL weight ramps linearly from 0 to 0.5 over the first KL_increase_range epochs.
    KL_weight = np.clip(epoch/KL_increase_range, a_min=0.0, a_max=1.0)/2
    print(' \t KL_weight:', KL_weight)
    losses_temp = {key: [] for key in losses}

    model.train()
    for batch in train_loader:
        batch = batch.clone().to(device)
        optimizer.zero_grad()

        out = model(batch, cond_on_final=True)
        pred = out[-1]
        if torch.isnan(pred).any():
            raise ValueError('nan value in train prediction')

        # reconstruction loss against the first ground-truth alternative
        MSE_loss = torch.nn.MSELoss()(pred, batch.y[..., 0])
        # out[:-1] = (logsig_i, mu_i, logsig_f, mu_f): KL of the input-only Gaussian
        # w.r.t. the output-conditioned one.
        # NOTE(review): a standard CVAE uses KL(posterior || prior), i.e. the
        # reverse direction -- confirm this choice is intentional.
        KL_loss1 = KL_loss(*out[:-1])

        loss = (1.0-KL_weight)*MSE_loss + KL_weight*KL_loss1
        loss.backward()
        optimizer.step()

        losses_temp['MSE_train'].append(MSE_loss.item())
        losses_temp['KL_train'].append(KL_loss1.item())
        losses_temp['loss_train'].append(loss.item())

    _record_epoch_means(losses_temp, ['MSE_train', 'KL_train', 'loss_train'])

    # =============== 2) EVALUATE WITHOUT CONDITIONING ON OUTPUT ===============
    # Not conditioned on the output, so no KL loss can be computed here.
    # no_grad: avoid building an autograd graph over the whole test set every epoch.
    model.eval()
    with torch.no_grad():
        for batch in test_loader:
            batch = batch.clone().to(device)
            pred = model(batch, cond_on_final=False)[-1]
            if torch.isnan(pred).any():
                raise ValueError('nan value in validation prediction')

            MSE_loss = torch.nn.MSELoss()(pred, batch.y[..., 0])
            MSE_loss_alltarg = MSELoss_allTargets(pred, batch.y, batch=batch.batch)

            losses_temp['MSE_val'].append(MSE_loss.item())
            losses_temp['MSE_val_allTargets'].append(MSE_loss_alltarg.item())

    _record_epoch_means(losses_temp, ['MSE_val', 'MSE_val_allTargets'])

    # ================ 3) EVALUATE WITH CONDITIONING ON OUTPUT ================
    with torch.no_grad():
        for batch in test_loader:
            batch = batch.clone().to(device)
            out = model(batch, cond_on_final=True)
            pred = out[-1]
            if torch.isnan(pred).any():
                raise ValueError('nan value in validation prediction')

            MSE_loss = torch.nn.MSELoss()(pred, batch.y[..., 0])
            KL_loss1 = KL_loss(*out[:-1])

            losses_temp['MSE_val2'].append(MSE_loss.item())
            losses_temp['KL_val2'].append(KL_loss1.item())

    _record_epoch_means(losses_temp, ['MSE_val2', 'KL_val2'])

    # ======== Early stopping =========
    # Stop when, averaged over the last 5 epochs, neither the best-target validation
    # MSE nor the conditioned validation KL improved by at least 0.5% compared with
    # the 5 epochs before. This is only checked once the KL weight has stopped increasing.
    if (
        epoch > KL_increase_range
        and (np.mean(losses["MSE_val_allTargets"][-5:])
             > 0.995*np.mean(losses["MSE_val_allTargets"][-10:-5]))
        and (np.mean(losses["KL_val2"][-5:])
             > 0.995*np.mean(losses["KL_val2"][-10:-5]))
    ):
        print('======= !! Validation loss did not decrease enough =======')
        print('Stopping training')
        break

# Plot losses

In [None]:
# One curve per tracked metric, on a log scale.
fig, ax = plt.subplots()

for name, history in losses.items():
    ax.plot(history, label=name)
ax.legend()
ax.set_yscale('log')
ax.set_ylabel('loss')
ax.set_xlabel('epoch')
plt.show()


# Compare input, output, predicted output

In [None]:
N_SHOW = 6           # number of test graphs to print
NODES_PER_GRAPH = 4  # the reshapes below assume every graph has 4 nodes

model.eval()
with torch.no_grad():
    for batch in test_loader:
        batch = batch.clone().to(device)
        n_rows = N_SHOW * NODES_PER_GRAPH

        print('============== INPUT ==============')
        inputs = (batch.x[:n_rows]*x_std + x_m).cpu().numpy().astype(int)
        for graph_x in inputs.reshape(-1, NODES_PER_GRAPH):
            print(*graph_x, sep=', ')

        print('============== POSSIBLE OUTPUTS ==============')
        targets = (batch.y[:n_rows]*y_std + y_m).cpu().numpy()
        for graph_y in targets.reshape(-1, NODES_PER_GRAPH, 2):
            # one array per alternative, holding the value of each node
            print(*graph_y.T, sep=', ')

        print('============== PREDICTIONS ==============')
        pred = model(batch)[-1][:n_rows]*y_std + y_m
        for graph_pred in pred.cpu().numpy().reshape(-1, NODES_PER_GRAPH):
            print(*graph_pred, sep=', ')

# Plot real vs predicted

In [None]:
# Predictions not conditioned on the output vs. the closest ground-truth
# alternative per graph (plotted in de-normalised units).
model.eval()
with torch.no_grad():
    for batch in test_loader:
        batch = batch.clone().to(device)
        pred = model(batch, cond_on_final=False)[-1]
        print('Shapes:', pred.shape, batch.y.shape)
        test_loss, inds = MSELoss_allTargets(pred, batch.y, batch=batch.batch, return_indices=True)
        # RMSE in normalised units
        print('test RMSE', torch.sqrt(test_loss).item())

        # reshape assumes 4 nodes per graph, 1 output feature and 2 alternatives
        ground_truth = np.take_along_axis(batch.y.cpu().numpy().reshape(-1, 4, 1, 2),
                                          inds.cpu().numpy().reshape(-1, 1, 1, 1), axis=-1)
        plt.scatter(ground_truth*y_std+y_m, pred.cpu().numpy()*y_std+y_m, s=1)
        plt.axline([0, 0], [1, 1], c='tab:orange')
        plt.gca().set_aspect('equal')
        plt.xlabel('real')
        plt.ylabel('predicted')
        plt.title('Real (best match) vs predicted, not conditioned on output')

In [None]:
# Predictions conditioned on the output vs. the closest ground-truth alternative
# per graph (plotted in de-normalised units).
model.eval()
with torch.no_grad():
    for batch in test_loader:
        batch = batch.clone().to(device)
        pred = model(batch, cond_on_final=True)[-1]
        print('Shapes:', pred.shape, batch.y.shape)
        test_loss, inds = MSELoss_allTargets(pred, batch.y, batch=batch.batch, return_indices=True)
        # RMSE in normalised units
        print('test RMSE', torch.sqrt(test_loss).item())

        # reshape assumes 4 nodes per graph, 1 output feature and 2 alternatives
        ground_truth = np.take_along_axis(batch.y.cpu().numpy().reshape(-1, 4, 1, 2),
                                          inds.cpu().numpy().reshape(-1, 1, 1, 1), axis=-1)
        plt.axline([0, 0], [1, 1], c='tab:orange')
        plt.scatter(ground_truth*y_std + y_m, pred.cpu().numpy()*y_std + y_m, s=1)
        plt.gca().set_aspect('equal')
        plt.xlabel('real')
        plt.ylabel('predicted')
        plt.title('Real (best match) vs predicted, conditioned on output')

# Bifurcation plots

In [None]:
# %%
# Plot how ground truth bifurcates
N_PLOT = 100  # maximum number of test graphs to draw

fig, ax = plt.subplots()
fig.patch.set_facecolor("None")

for batch in test_loader:
    # never index past the last graph when the test set has fewer than N_PLOT graphs
    for ind in range(min(N_PLOT, batch.num_graphs)):
        base = 4*ind  # assumes graph `ind` occupies node rows base..base+3
        pos1 = (batch.x[[base+0, base+2], 0]*x_std+x_m).cpu().numpy()
        # final values of nodes (0, 2) and (1, 3), first ground-truth alternative
        pos2a = (batch.y[[base+0, base+2], 0, 0]*y_std+y_m).cpu().numpy()
        pos2b = (batch.y[[base+1, base+3], 0, 0]*y_std+y_m).cpu().numpy()

        ax.scatter(*pos1,  c='tab:blue',   s=10, label='Initial')
        ax.scatter(*pos2a, c='tab:orange', s=10, label='Final, node 0 & 2')
        ax.scatter(*pos2b, c='tab:red',    s=10, label='Final, node 1 & 3')

        for pos2 in (pos2a, pos2b):
            ax.annotate('', xy=pos2, xytext=pos1,
                        arrowprops=dict(arrowstyle='->', facecolor='black'))

# every graph repeats the same three labels; keep one legend entry per series
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles[:3], labels[:3])
ax.set_xlabel('Embedding node 0 and 1')
ax.set_ylabel('Embedding node 2 and 3')
ax.set_aspect('equal')
ax.set_title('Ground truth bifurcation')

plt.show()



In [None]:
# Plot how result bifurcates (not conditioned on output)
N_PLOT = 100  # maximum number of test graphs to draw

fig, ax = plt.subplots()
fig.patch.set_facecolor("None")

model.eval()
with torch.no_grad():
    for batch in test_loader:
        batch = batch.clone().to(device=device)
        pred = model(batch, cond_on_final=False)[-1]

        for ind in range(min(N_PLOT, batch.num_graphs)):
            base = 4*ind  # assumes graph `ind` occupies node rows base..base+3
            # Select column 0 so each position is a 1-D (x, y) pair, as in the
            # ground-truth plot (annotate expects an (x, y) pair).
            pos1 = (batch.x[[base+0, base+2], 0]*x_std+x_m).cpu().numpy()
            pos2a = (pred[[base+0, base+2], 0]*y_std+y_m).cpu().numpy()
            pos2b = (pred[[base+1, base+3], 0]*y_std+y_m).cpu().numpy()

            ax.scatter(*pos1,  c='tab:blue',   s=10, label='Initial')
            ax.scatter(*pos2a, c='tab:orange', s=10, label='Final, node 0 & 2')
            ax.scatter(*pos2b, c='tab:red',    s=10, label='Final, node 1 & 3')

            for pos2 in (pos2a, pos2b):
                ax.annotate('', xy=pos2, xytext=pos1,
                            arrowprops=dict(arrowstyle='->', facecolor='black'))

# every graph repeats the same three labels; keep one legend entry per series
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles[:3], labels[:3])
ax.set_xlabel('Embedding node 0 and 1')
ax.set_ylabel('Embedding node 2 and 3')
ax.set_aspect('equal')
ax.set_title('Predicted bifurcation (not conditioned on output)')

plt.show()

In [None]:
# Plot how result bifurcates (conditioned on output)
N_PLOT = 100  # maximum number of test graphs to draw

fig, ax = plt.subplots()
fig.patch.set_facecolor("None")

model.eval()
with torch.no_grad():
    for batch in test_loader:
        batch = batch.clone().to(device=device)
        pred = model(batch, cond_on_final=True)[-1]

        for ind in range(min(N_PLOT, batch.num_graphs)):
            base = 4*ind  # assumes graph `ind` occupies node rows base..base+3
            # Select column 0 so each position is a 1-D (x, y) pair, as in the
            # ground-truth plot (annotate expects an (x, y) pair).
            pos1 = (batch.x[[base+0, base+2], 0]*x_std+x_m).cpu().numpy()
            pos2a = (pred[[base+0, base+2], 0]*y_std+y_m).cpu().numpy()
            pos2b = (pred[[base+1, base+3], 0]*y_std+y_m).cpu().numpy()

            ax.scatter(*pos1,  c='tab:blue',   s=10, label='Initial')
            ax.scatter(*pos2a, c='tab:orange', s=10, label='Final, node 0 & 2')
            ax.scatter(*pos2b, c='tab:red',    s=10, label='Final, node 1 & 3')

            for pos2 in (pos2a, pos2b):
                ax.annotate('', xy=pos2, xytext=pos1,
                            arrowprops=dict(arrowstyle='->', facecolor='black'))

# every graph repeats the same three labels; keep one legend entry per series
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles[:3], labels[:3])
ax.set_xlabel('Embedding node 0 and 1')
ax.set_ylabel('Embedding node 2 and 3')
ax.set_aspect('equal')
ax.set_title('Predicted bifurcation (conditioned on output)')

plt.show()

# Check whether the bifurcation is at least consistent
If the predicted value of node 1 is higher than that of node 0, then the predicted value of node 3 should also be higher than that of node 2. This lets us check whether the bifurcation at least goes in a consistent direction, even if the predicted values themselves are not yet accurate.

In [None]:
# %% check if bifurcation at least is consistent-ish


def _report_bifurcation_consistency(cond_on_final, description):
    """Print, per test batch, the percentage of graphs whose predicted split of
    node pair (0, 1) has the same orientation as that of node pair (2, 3)."""
    model.eval()
    with torch.no_grad():
        for batch in test_loader:
            batch = batch.clone().to(device)
            pred = model(batch, cond_on_final=cond_on_final)[-1]
            # one row per graph; assumes 4 nodes per graph
            pred = (pred*y_std+y_m).cpu().numpy().reshape(-1, 4)

            same_direction = (pred[:, 0] > pred[:, 1]) == (pred[:, 2] > pred[:, 3])
            print(f'{description}: Bifurcation in the right direction {same_direction.mean()*100:.1f}% of the time')


_report_bifurcation_consistency(False, 'No conditioning on output')
_report_bifurcation_consistency(True, 'Conditioned on output')

# Calculate Wasserstein distance

In [None]:
from scipy.stats import wasserstein_distance_nd  # available from SciPy 1.13

N_SAMPLES = 100       # stochastic predictions drawn per test graph
NODES_PER_GRAPH = 4   # the reshapes below assume 4 nodes per graph
N_ALTERNATIVES = 2    # ground-truth alternatives per graph

model.eval()
rel_dist = []
wd = []  # one Wasserstein distance per test graph, collected over all batches
with torch.no_grad():
    for batch in test_loader:
        batch = batch.clone().to(device)

        # N_SAMPLES predictions per node, each with a freshly sampled latent
        preds = np.zeros((len(batch.x), N_SAMPLES))
        for i in range(N_SAMPLES):
            preds[:, i] = model(batch, cond_on_final=False)[-1].cpu().numpy().reshape(-1)*y_std + y_m

        preds = preds.reshape(-1, NODES_PER_GRAPH, N_SAMPLES)
        reals = batch.y.cpu().numpy().reshape(-1, NODES_PER_GRAPH, N_ALTERNATIVES)*y_std + y_m

        print('preds.shape:', preds.shape)
        print('reals.shape:', reals.shape)
        for pred, real in zip(preds, reals):  # iterate over all data points (as far as I know, this cannot be batched)
            # each sample / alternative is treated as one point in NODES_PER_GRAPH-dim space
            wd.append(wasserstein_distance_nd(pred.T, real.T))

            # distance of every sample to each alternative: [N_SAMPLES, N_ALTERNATIVES]
            dist = np.linalg.norm(pred[..., np.newaxis] - real[..., np.newaxis, :], axis=0)
            # 0 -> sample coincides with the first alternative, 1 -> with the second
            rel_dist.extend((dist[..., 0] / (dist[..., 0] + dist[..., 1])).tolist())

print(np.mean(wd))

In [None]:
# histogram of rel_dist: where each predicted sample lies between the two options
# (0 = distance to option 1 is zero, 1 = distance to option 2 is zero).
# rc_context keeps the larger font local to this figure instead of changing
# plt.rcParams for every later plot; the figure is shown inside the context.
with plt.rc_context({'font.size': 16}):
    fig, ax = plt.subplots(figsize=(5, 5))
    ax.hist(rel_dist,
            bins=np.linspace(0, 1, 40), density=True)
    # the fractional x label gets an even larger font
    ax.set_xlabel(r'$\frac{\text{Distance to option 1}}{\text{Distance to option 1 + Distance to option 2}}$', fontsize=18)
    ax.axvline(0.0, color='red', linestyle='--')
    ax.axvline(1.0, color='red', linestyle='--')
    ax.set_ylabel('Density')
    plt.show()

This shows that the VAE+GNN approach is only very slightly better than the non-probabilistic GNN approach.