In [None]:
import pickle
import torch
import matplotlib.pyplot as plt
import numpy as np
import torch_geometric as tg
import torch.nn.functional as F
from torch import nn, Tensor
from torch_geometric.nn import MessagePassing

# Open dataset

In [None]:
# NOTE: unpickling can execute arbitrary code; only load this file from a trusted source.
with open('../data/FourNodeGraph_data.pkl', 'rb') as f:
    data = pickle.load(f)

# Print a short summary of every entry: tensor size, list length, or the raw value.
for key, value in data.items():
    if isinstance(value, torch.Tensor):
        summary = value.size()
    elif isinstance(value, list):
        summary = len(value)
    else:
        summary = value
    print(key, summary)

# Unpack the entries used in the rest of the notebook.
data_tr, data_te = data['data_tr'], data['data_te']
x_m, y_m = data['x_m'], data['y_m']
x_std, y_std = data['x_std'], data['y_std']

In [None]:
# Inspect the edge connectivity (edge_index) of the first training graph.
data_tr[0].edge_index

In [None]:
# Alternative that was tried (signed attributes):
# edge_attr = torch.FloatTensor([1, 1, -1, -1, -1, -1, 1, 1.0]).reshape(-1, 1)

# One constant attribute (1.0) per edge for 8 edges, stored as a float32 column of shape [8, 1].
edge_attr = torch.ones((8, 1), dtype=torch.float32)

In [None]:
# Attach the constant edge-attribute tensor to every training and test graph.
# All graphs reference the same tensor object; no per-graph copy is made.
for graph in (*data_tr, *data_te):
    graph.edge_attr = edge_attr

In [None]:
# Display the first training graph (now including the edge_attr assigned above).
data_tr[0]

In [None]:
# Mini-batches of 64 graphs for training. No `shuffle` argument is passed here.
# NOTE(review): if the loader default is "no shuffling", every epoch visits the graphs in the same order — confirm intended.
train_loader = tg.loader.DataLoader(data_tr, batch_size=64)
# Very large batch size for evaluation — presumably so the whole test set arrives as one batch; verify against len(data_te).
test_loader = tg.loader.DataLoader(data_te, batch_size=100000)

# Choose device

In [None]:
# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
print('cuda available' if use_cuda else 'cuda not available')

# Define Model 

In [None]:
class MyMessagePassingLayer(MessagePassing):
    """Message-passing layer that returns new node AND new edge embeddings.

    Node update
        Each edge produces a message from the embeddings of its two end
        nodes and its edge attributes. Messages are summed per receiving
        node (``aggr='add'``) and combined with that node's previous
        embedding by an MLP.

    Edge update
        Each edge embedding is recomputed by an MLP from its previous value
        and the embeddings of its two end nodes.

    Both updates read the node embeddings passed to ``forward``; the edge
    update does not see the freshly updated node embeddings.
    """

    def __init__(self, node_in, edge_in, message_size, node_out, edge_out):
        """Build the message, node-update and edge-update MLPs.

        Parameters
        ----------
        node_in : int
            Size of the incoming node embedding.
        edge_in : int
            Size of the incoming edge embedding.
        message_size : int
            Size of the message sent along each edge.
        node_out : int
            Size of the node embedding after the update.
        edge_out : int
            Size of the edge embedding after the update.
        """
        super().__init__(aggr='add')

        # (x_i, x_j, edge_attr) -> message; ends with an activation.
        self.mlp_message = nn.Sequential(
            nn.Linear(2 * node_in + edge_in, message_size),
            nn.LeakyReLU(),
            nn.Linear(message_size, message_size),
            nn.LeakyReLU(),
        )
        # (previous node embedding, aggregated messages) -> new node embedding.
        self.mlp_update = nn.Sequential(
            nn.Linear(node_in + message_size, node_out),
            nn.LeakyReLU(),
            nn.Linear(node_out, node_out),
        )
        # (edge_attr, x_i, x_j) -> new edge embedding.
        self.mlp_edge = nn.Sequential(
            nn.Linear(edge_in + 2 * node_in, edge_out),
            nn.LeakyReLU(),
            nn.Linear(edge_out, edge_out),
        )

    def forward(self, x, edge_index, edge_attr):
        """Run one round of node and edge updates.

        Parameters
        ----------
        x : torch.Tensor, shape [n, node_in]
            Current embedding of each of the n nodes.
        edge_index : torch.Tensor, shape [2, e]
            Indices of all e edges.
        edge_attr : torch.Tensor, shape [e, edge_in]
            Current embedding of each of the e edges.

        Returns
        -------
        tuple of torch.Tensor
            ``(new_x, new_edge_attr)`` with shapes ``[n, node_out]`` and
            ``[e, edge_out]``.
        """
        new_nodes = self.propagate(edge_index, x=x, edge_attr=edge_attr)
        new_edges = self.edge_updater(edge_index, x=x, edge_attr=edge_attr)
        return new_nodes, new_edges

    def edge_update(self, edge_index, x_i, x_j, edge_attr):
        """Compute the new edge embedding from ``(edge_attr, x_i, x_j)``.

        ``x_i`` / ``x_j`` are the per-edge node embeddings gathered by the
        ``MessagePassing`` base class; ``edge_index`` is accepted but unused.
        """
        edge_inputs = torch.cat((edge_attr, x_i, x_j), dim=1)
        return self.mlp_edge(edge_inputs)

    def message(self, x_i, x_j, edge_attr):
        """Compute the message sent along every edge.

        Parameters
        ----------
        x_i : torch.Tensor, shape [e, node_in]
            Node embeddings at the receiving end of each edge (PyG ``_i``
            convention).
        x_j : torch.Tensor, shape [e, node_in]
            Node embeddings of the source nodes.
        edge_attr : torch.Tensor, shape [e, edge_in]
            Edge embeddings.
        """
        message_inputs = torch.cat((x_i, x_j, edge_attr), dim=1)
        return self.mlp_message(message_inputs)

    def update(self, aggr_out, x):
        """Combine each node's previous embedding with its summed messages."""
        node_inputs = torch.cat((x, aggr_out), dim=1)
        return self.mlp_update(node_inputs)


class MyGNN(torch.nn.Module):
    """Stack of three ``MyMessagePassingLayer`` blocks: 3 node features in, 1 value per node out."""

    def __init__(self, use_edge_weight=False):
        """Create the three layers.

        Parameters
        ----------
        use_edge_weight : bool, optional
            Accepted for compatibility; not used anywhere in this class.
        """
        super().__init__()

        # Arguments are (node_in, edge_in, message_size, node_out, edge_out).
        # NOTE(review): message_size is 0 in the first layer, so its messages carry no features — confirm intended.
        self.conv1 = MyMessagePassingLayer(3, 1, 0, 16, 16)
        self.conv2 = MyMessagePassingLayer(16, 16, 16, 16, 16)
        # edge_out is 0 here; the edge output of this layer is discarded in forward().
        self.conv3 = MyMessagePassingLayer(16, 16, 16, 1, 0)

    def forward(self, data):
        """Apply the three layers to ``data`` (reads ``x``, ``edge_index``, ``edge_attr``) and return the node output."""
        edge_index = data.edge_index
        h, e = data.x, data.edge_attr

        # First two layers: message passing followed by a leaky ReLU on nodes and edges.
        for layer in (self.conv1, self.conv2):
            h, e = layer(h, edge_index, edge_attr=e)
            h, e = F.leaky_relu(h), F.leaky_relu(e)

        # Last layer: no activation, edge embeddings are dropped.
        h, _ = self.conv3(h, edge_index, edge_attr=e)
        return h

In [None]:
class Flow(nn.Module):
    """Velocity-field model for flow matching on graphs.

    Wraps ``MyGNN``: the GNN receives, per node, the original node features
    concatenated with the time ``t`` and the current state ``x_t`` and
    returns one value per node.
    """

    def __init__(self):
        # The module is initialised exactly once (the original called the
        # parent constructor twice).
        super().__init__()
        self.GNN = MyGNN()

    def forward(self, t, x_t, batch) -> Tensor:
        """Evaluate the velocity field.

        Parameters
        ----------
        t : Tensor
            Time per node; concatenated to the node features along the last
            dimension.
        x_t : Tensor
            Current state per node; concatenated likewise.
        batch :
            Graph batch providing ``x``, ``edge_index`` and ``edge_attr``.
            It is cloned, so the caller's ``batch.x`` is left untouched.

        Returns
        -------
        Tensor
            Output of the GNN for every node.
        """
        # Work on a copy so the conditioning features of the caller's batch
        # are not overwritten by the augmented input.
        batch2 = batch.clone()
        batch2.x = torch.cat((batch2.x, t, x_t), dim=-1)
        return self.GNN(batch2)

    def step(self, x_t: Tensor, t_start: Tensor, t_end: Tensor, batch) -> Tensor:
        """Advance ``x_t`` from ``t_start`` to ``t_end`` with one midpoint step.

        Parameters
        ----------
        x_t : Tensor
            Current state, one row per node.
        t_start : Tensor
            Single-element tensor with the start time; broadcast to one row
            per node.
        t_end : Tensor
            End time of the step.
        batch :
            Graph batch passed through to ``forward``.

        Returns
        -------
        Tensor
            State at ``t_end``, same shape as ``x_t``.
        """
        # One time value per row of x_t.
        t_start = t_start.view(1, 1).expand(x_t.shape[0], 1)
        dt = t_end - t_start

        # Midpoint rule: half Euler step, then a full step using the
        # velocity evaluated at the midpoint.
        v_start = self(t=t_start, x_t=x_t, batch=batch)
        x_mid = x_t + v_start * dt / 2
        v_mid = self(t=t_start + dt / 2, x_t=x_mid, batch=batch)
        return x_t + dt * v_mid

In [None]:
# Build the model on the selected device and show its module tree.
flow = Flow().to(device)
print(flow)

# Count every scalar parameter of the model.
n_params = sum(map(torch.numel, flow.parameters()))
print('Total nr of parameters:', n_params)

# Train model

In [None]:
def MSELoss_allTargets(pred, target, batch, return_indices=False):
    """MSE against the closest of several alternative targets, chosen per graph.

    For every graph the squared error to each alternative target is summed
    over its nodes and features; the alternative with the smallest sum is
    selected. The selected sums are added up and divided by the number of
    elements in ``pred``.

    Parameters
    ----------
    pred : torch tensor, [N, f]
        Prediction per node (N = total nr of nodes, f = nr of predicted features).
    target : torch tensor, [N, f, a]
        All possible targets (a = nr of alternatives).
    batch : torch tensor, [N, ]
        Index of the graph each node belongs to.
    return_indices : bool, optional
        If True, also return the index of the selected alternative for every
        graph. Default is False.

    Returns
    -------
    torch tensor (scalar)
        The loss; when ``return_indices`` is True, a tuple
        ``(loss, indices)`` with ``indices`` of shape [G] (G = nr of graphs).
    """
    # Squared error per node, feature and alternative: [N, f, a].
    sq_err = (pred.unsqueeze(-1) - target) ** 2
    # Sum over features -> [N, a], then over the nodes of each graph -> [G, a].
    per_graph = tg.nn.global_add_pool(sq_err.sum(dim=1), batch=batch)
    # Closest alternative for every graph.
    best = per_graph.min(dim=-1)

    # Average over all predicted elements in the batch.
    loss = best.values.sum() / pred.numel()

    return (loss, best.indices) if return_indices else loss

In [None]:
symmetric_matching = True

N_EPOCHS = 1000
LR_INITIAL = 1e-3
LR_FINAL = 1e-5
LR_SWITCH_EPOCH = 200  # after this epoch the optimizer is re-created with LR_FINAL

optimizer = torch.optim.Adam(flow.parameters(), LR_INITIAL)
loss_fn = nn.MSELoss()

MSE_loss = []  # mean training loss of every epoch
# The epoch counter has its own name: the original reused `i` for the batch
# index, so the learning-rate switch below compared against the index of the
# last batch instead of the epoch number.
for epoch in range(N_EPOCHS):
    if epoch % 100 == 0:
        print(epoch)
    loss_temp = []
    for batch in train_loader:
        batch = batch.to(device)

        cond = batch.x  # initial position (condition)
        # Start point: condition plus Gaussian noise (already on the batch's device).
        x_0 = cond + torch.randn_like(batch.y[..., 0])

        if symmetric_matching:
            # Per graph, pick the alternative target that is closest to the noisy start point.
            _, inds = MSELoss_allTargets(x_0, batch.y, batch.batch, return_indices=True)
            inds = inds[batch.batch]  # one index per node
            x_1 = torch.take_along_dim(batch.y, inds.reshape(-1, 1, 1), dim=-1).squeeze(-1)
        else:
            # Always use the first alternative as target.
            x_1 = batch.y[..., 0]

        # One random time per graph, broadcast to all nodes of that graph.
        n_graphs = int(batch.batch.max()) + 1
        t = torch.rand(n_graphs, 1, device=device)[batch.batch]

        # Point on the straight line between start and target, and its velocity.
        x_t = (1 - t) * x_0 + t * x_1
        dx_t = x_1 - x_0

        optimizer.zero_grad()
        loss = loss_fn(flow(t=t, x_t=x_t, batch=batch), dx_t)
        loss.backward()
        optimizer.step()

        loss_temp.append(loss.item())

    if epoch == LR_SWITCH_EPOCH:
        # Continue with a fresh Adam optimizer at the smaller learning rate.
        optimizer = torch.optim.Adam(flow.parameters(), LR_FINAL)
    MSE_loss.append(np.mean(loss_temp))

# Plot loss

In [None]:
# Per-epoch mean training loss on a log scale. The curve is labelled so that
# plt.legend() has an entry to show.
plt.plot(MSE_loss, label='Training loss (per-epoch mean)')
plt.yscale('log')
plt.xlabel('Epoch')
plt.ylabel('MSE loss')
plt.legend()

# Sampling

In [None]:
# %matplotlib qt
n_steps = 8

# One panel for t = 0 and one for every second integration step.
fig, axes = plt.subplots(1, n_steps//2 + 1, figsize=(20, 4), sharex=True, sharey=True)

time_steps = torch.linspace(0, 1.0, n_steps + 1).to(device)

for ax in axes:
    ax.set_aspect('equal')
    ax.set_xlabel('Target (option 0)')
axes[0].set_ylabel('Sample')

# Sampling only: no gradients are needed, so no autograd graph is recorded.
with torch.no_grad():
    for _ in range(10):  # 10 predictions per data point
        for batch in test_loader:
            batch = batch.to(device)
            # Start from the condition plus Gaussian noise.
            x = batch.x + torch.randn_like(batch.x)

            # De-normalised first target alternative; identical for every panel.
            real = batch.y[..., 0].cpu()*y_std+y_m

            axes[0].scatter(real, x.cpu()*y_std+y_m, s=1, c='tab:blue')
            axes[0].set_title(f't = {time_steps[0]:.2f}')

            for i in range(n_steps):
                x = flow.step(x_t=x, t_start=time_steps[i], t_end=time_steps[i + 1], batch=batch)

                if i % 2 == 1:
                    panel = axes[i//2 + 1]
                    panel.scatter(real, x.cpu()*y_std+y_m, s=1, c='tab:blue')
                    panel.set_title(f't = {time_steps[i + 1]:.2f}')

plt.tight_layout()
plt.show()

# Bifurcation plots

In [None]:
# %%
# Plot how ground truth bifurcates
fig, ax = plt.subplots()
fig.patch.set_facecolor("None")

flow.eval()
for batch in test_loader:
    for ind in range(100):
        # Node indices of graph `ind` (4 consecutive nodes per graph).
        nodes_02 = [4*ind+0, 4*ind+2]
        nodes_13 = [4*ind+1, 4*ind+3]

        # De-normalised positions: initial (nodes 0 & 2) and final for both node pairs.
        pos1 = (batch.x[nodes_02, 0]*x_std+x_m).cpu().detach().numpy()
        pos2a = (batch.y[nodes_02, 0, 0]*y_std+y_m).cpu().detach().numpy()
        pos2b = (batch.y[nodes_13, 0, 0]*y_std+y_m).cpu().detach().numpy()

        points = (
            (pos1, 'tab:blue', 'Initial'),
            (pos2a, 'tab:orange', 'Final, node 0 & 2'),
            (pos2b, 'tab:red', 'Final, node 1 & 3'),
        )
        for pos, colour, name in points:
            ax.scatter(*pos, c=colour, s=10, label=name)

        # Arrows from the initial point to each of the two final points.
        for end in (pos2a, pos2b):
            ax.annotate('', xy=end, xytext=pos1,
                        arrowprops=dict(arrowstyle='->', facecolor='black'))

    # Every scatter call adds a legend entry; keep only the first three.
    handles, labels = ax.get_legend_handles_labels()
    plt.legend(handles[:3], labels[:3])
    ax.set_xlabel('Embedding node 0 and 1')
    ax.set_ylabel('Embedding node 2 and 3')
    ax.set_aspect('equal')
    ax.set_title('Ground truth bifurcation')

    plt.show()



In [None]:
# Plot how result bifurcates (not conditioned on output)
fig, ax = plt.subplots()
fig.patch.set_facecolor("None")

flow.eval()
for batch in test_loader:
    batch = batch.clone().to(device=device)

    # Integrate from the noisy condition over all time steps; sampling only,
    # so no autograd graph is recorded.
    with torch.no_grad():
        x = batch.x + torch.randn_like(batch.x)  # condition + Gaussian noise
        for i in range(n_steps):
            x = flow.step(x_t=x, t_start=time_steps[i], t_end=time_steps[i + 1], batch=batch)
    pred = x

    # Plot at most 100 graphs, and never more than the batch contains (4 nodes per graph).
    n_plot = min(100, pred.shape[0] // 4)
    for ind in range(n_plot):
        nodes_02 = [4*ind+0, 4*ind+2]
        nodes_13 = [4*ind+1, 4*ind+3]

        # Select feature column 0 so every position is a 1-D (x, y) pair, and
        # move to the CPU before de-normalising, as the other cells do.
        pos1 = (batch.x[nodes_02, 0].cpu()*x_std+x_m).numpy()
        pos2a = (pred[nodes_02, 0].cpu()*y_std+y_m).numpy()
        pos2b = (pred[nodes_13, 0].cpu()*y_std+y_m).numpy()

        ax.scatter(*pos1,  c='tab:blue',   s=10, label='Initial')
        ax.scatter(*pos2a, c='tab:orange', s=10, label='Final, node 0 & 2')
        ax.scatter(*pos2b, c='tab:red',    s=10, label='Final, node 1 & 3')

        ax.annotate('', xy=pos2a, xytext=pos1,
                    arrowprops=dict(arrowstyle='->', facecolor='black'),
                    )
        ax.annotate('', xy=pos2b, xytext=pos1,
                    arrowprops=dict(arrowstyle='->', facecolor='black'),
                    )

    # Every scatter call adds a legend entry; keep only the first three.
    handles, labels = ax.get_legend_handles_labels()
    plt.legend(handles[:3], labels[:3])
    ax.set_xlabel('Embedding node 0 and 1')
    ax.set_ylabel('Embedding node 2 and 3')
    ax.set_aspect('equal')
    ax.set_title('Predicted bifurcation')

    plt.show()

# Calculate Wasserstein distance
Also compute the 'relative distance' (TODO: find a better name): distance to option 1 / (distance to option 1 + distance to option 2).

In [None]:
# Same construction as the earlier test loader (batch_size=100000); repeated here, presumably so this
# section can be re-run after test_loader is redefined with batch_size=1 further down — confirm.
test_loader = tg.loader.DataLoader(data_te, batch_size=100000)

In [None]:
from scipy.stats import wasserstein_distance_nd

n_samples = 100  # predictions generated per node

flow.eval()
# Both accumulators live outside the batch loop so the results cover every
# batch of the test set (the original reset `wd` for each batch).
wd = []        # Wasserstein distance per graph
rel_dist = []  # d(option 1) / (d(option 1) + d(option 2)) per sample
for batch in test_loader:
    batch = batch.clone().to(device)

    preds = np.zeros((len(batch.x), n_samples))

    # Sampling only: no autograd graph is recorded.
    with torch.no_grad():
        for i in range(n_samples):
            x = batch.x + torch.randn_like(batch.x)  # condition + Gaussian noise

            for j in range(n_steps):
                x = flow.step(x_t=x, t_start=time_steps[j], t_end=time_steps[j + 1], batch=batch)

            preds[..., i] = x.squeeze(-1).cpu().numpy()*y_std+y_m

    # Regroup per graph (4 nodes each): preds [G, 4, n_samples], reals [G, 4, 2].
    preds = preds.reshape(-1, 4, n_samples)
    reals = batch.y.cpu().numpy().reshape(-1, 4, 2)*y_std + y_m

    print('preds.shape:', preds.shape)
    print('reals.shape:', reals.shape)
    for pred, real in zip(preds, reals):  # iterate over all data points (as far as I know, this cannot be batched)
        # Samples are rows: n_samples predicted 4-vectors vs. the 2 target 4-vectors.
        wd.append(wasserstein_distance_nd(pred.T, real.T))

        # Euclidean distance of every sample to each of the two targets: [n_samples, 2].
        diff = pred[..., np.newaxis] - real[..., np.newaxis, :]
        dist = np.linalg.norm(diff, axis=0)
        rel_dist_temp = dist[..., 0] / (dist[..., 0] + dist[..., 1])
        rel_dist.extend(rel_dist_temp.flatten().tolist())


print(np.mean(wd))

In [None]:
# Density histogram of the relative distance over all samples,
# with dashed red lines at the two extremes 0 and 1.
plt.rcParams.update({'font.size': 16})
plt.figure(figsize=(5,5))
plt.hist(rel_dist, density=True, bins=np.linspace(0, 1, 40))
# The fraction is hard to read at the default size, so the x label uses a larger font.
plt.xlabel(r'$\frac{\text{Distance to option 1}}{\text{Distance to option 1 + Distance to option 2}}$', fontsize=18)
for bound in (0.0, 1.0):
    plt.axvline(bound, color='red', linestyle='--')
plt.ylabel('Density')

# Check symmetry

In [None]:
# One graph per batch: the next cell draws many samples for the first test graph only.
test_loader = tg.loader.DataLoader(data_te, batch_size=1)

In [None]:
n_steps = 32
time_steps = torch.linspace(0, 1.0, n_steps + 1).to(device)
n_samples = 10000  # number of predictions drawn for the single graph

# Only the first batch (one graph) of the loader is used.
batch = next(iter(test_loader)).clone().to(device)
print('batch.x', batch.x.shape)
print('batch.y', batch.y.shape)

preds = np.zeros((4, n_samples))  # one column per sample, one row per node

# Sampling only: without no_grad every one of the n_samples * n_steps midpoint
# steps would record an autograd graph.
with torch.no_grad():
    for i in range(n_samples):
        cond = batch.x[:4]  # initial position (condition)
        x = cond + torch.randn_like(cond)  # initial Gaussian noise

        for j in range(n_steps):
            x = flow.step(x_t=x, t_start=time_steps[j], t_end=time_steps[j + 1], batch=batch)

        preds[..., i] = x.squeeze(-1).cpu().numpy()*y_std+y_m

preds_1graph = preds.reshape(4, -1)
reals_1graph = batch.y[:4].cpu().numpy().reshape(4, 2)*y_std + y_m

print('preds_1graph.shape:', preds_1graph.shape)
print('reals_1graph.shape:', reals_1graph.shape)

In [None]:
# Difference of every sample to each of the two targets, per node: [4, n_samples, 2].
diff = preds_1graph[:, :, None] - reals_1graph[:, None, :]
# Euclidean norm over the node axis -> one distance per sample and target: [n_samples, 2].
dist = np.linalg.norm(diff, axis=0)
for arr in (diff, dist):
    print(arr.shape)

In [None]:
# Overlaid density histograms of the distance to each of the two targets (shared bins).
bins = np.linspace(-5, 25, 100)
for col, name in enumerate(('Distance to option 1', 'Distance to option 2')):
    plt.hist(dist[:, col], bins=bins, alpha=0.5, label=name, density=True)
plt.xlabel('Distance')
plt.ylabel('Density')
plt.legend()

In [None]:

# Density histogram of the relative distance for the single graph,
# with dashed red lines at the two extremes 0 and 1.
plt.rcParams.update({'font.size': 16})
plt.figure(figsize=(5,5))
d_opt1, d_opt2 = dist[:, 0], dist[:, 1]
plt.hist(d_opt1/(d_opt1 + d_opt2), density=True, bins=np.linspace(0, 1, 40))
# The fraction is hard to read at the default size, so the x label uses a larger font.
plt.xlabel(r'$\frac{\text{Distance to option 1}}{\text{Distance to option 1 + Distance to option 2}}$', fontsize=18)
for bound in (0.0, 1.0):
    plt.axvline(bound, color='red', linestyle='--')
plt.ylabel('Density')

In [None]:
# Number of samples that end up closer to option 1 than to option 2.
np.sum(dist[:, 0] < dist[:, 1])

In [None]:
# Number of samples that end up closer to option 2 than to option 1.
np.sum(dist[:, 0] > dist[:, 1])

In [None]:
# Number of samples that are exactly equally far from both options.
np.sum(dist[:, 0] == dist[:, 1])

In [None]:
# Mask of samples that are strictly closer to one option than to the other.
closer_to_1 = dist[:, 0] < dist[:, 1]
closer_to_2 = dist[:, 0] > dist[:, 1]
bools = np.logical_or(closer_to_1, closer_to_2)
print(np.sum(bools))
dist[bools]

In [None]:
# Shape of the distance array (norm of `diff` taken over axis 0).
dist.shape

In [None]:
# Sanity check: True if any distance is NaN.
np.isnan(dist).any()

In [None]:
# Distance to option 1 against distance to option 2, one point per sample, equal axis scaling.
d_opt1, d_opt2 = dist[:, 0], dist[:, 1]
plt.scatter(d_opt1, d_opt2, s=1)
plt.xlabel('Distance to option 1')
plt.ylabel('Distance to option 2')
plt.gca().set_aspect('equal')