In [None]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import pickle

# Import dataset with bifurcation

In [None]:
# NOTE(review): pickle.load can execute arbitrary code -- only load this file from a trusted source.
with open('../data/ThreeRoads_data.pkl', 'rb') as f:
    data = pickle.load(f)

# print info about all entries: the size of tensors, the value itself for anything else
for key, value in data.items():
    print(key, value.size() if isinstance(value, torch.Tensor) else value)

# train/test sets and the normalisation statistics used to (de)normalise x and y
data_tr, data_te = data['data_tr'], data['data_te']
x_m, y_m = data['x_m'], data['y_m']
x_std, y_std = data['x_std'], data['y_std']

In [None]:
# Training batches are reshuffled every epoch; the test set is read in a fixed order.
# batch_size=10000 presumably puts the whole test set in one batch -- several
# evaluation cells below break after the first test batch.
train_loader = torch.utils.data.DataLoader(data_tr, shuffle=True, batch_size=32)
test_loader = torch.utils.data.DataLoader(data_te, shuffle=False, batch_size=10000)

# Define loss functions

In [None]:
def KL_loss(logsig_i, mu_i, logsig_f=torch.tensor(0.0), mu_f=torch.tensor(0.0)):
    """KL divergence KL(N(mu_i, sigma_i^2) || N(mu_f, sigma_f^2)) between diagonal Gaussians.

    The divergence is summed over the last axis (latent dimensions) and averaged over
    all remaining entries. With the default second distribution this is the standard
    VAE term KL(q || N(0, 1)).

    Parameters
    ----------
    logsig_i, mu_i : torch tensor
        log standard deviation and mean of the first distribution
    logsig_f, mu_f : torch tensor, optional
        log standard deviation and mean of the second distribution (broadcast against
        the first), by default 0, i.e. a standard normal

    Returns
    -------
    torch tensor, 0-dim
        mean KL divergence
    """
    # Work with differences of log-sigmas instead of exp(logsig)**2 / exp(logsig_f)**2:
    # the latter overflows to inf/inf = nan once |logsig| exceeds ~44 in float32.
    log_ratio = logsig_i - logsig_f
    kl = (-2*log_ratio - 1
          + torch.exp(2*log_ratio)
          + ((mu_f - mu_i)*torch.exp(-logsig_f))**2)
    kl = 0.5*torch.sum(kl, dim=-1)
    return torch.mean(kl)

def MSELoss_allTargets(pred, target, return_indices=False):
    """MSE against the closest of several valid targets, chosen per data point.

    For each data point the squared error is summed over the predicted features for
    every alternative target; the smallest of those sums is kept, and the result is
    normalised by the total number of predicted values.

    Parameters
    ----------
    pred : torch tensor, [N, f]
        prediction per node, N = batch size, f = nr of predicted features
    target : torch tensor, [N, f, a]
        all possible targets, a = nr of possible alternatives
    return_indices : bool, optional
        also return which alternative was closest for each data point, by default False

    Returns
    -------
    MSE : torch tensor, 0-dim
        MSE loss using the closest alternative for each data point
    indices : torch tensor, [N,]
        index of the closest alternative (only when return_indices=True)
    """
    # squared error summed over features, one column per alternative: [N, a]
    sq_err_per_alt = (target - pred[..., None]).pow(2).sum(dim=1)
    best_err, best_idx = sq_err_per_alt.min(dim=-1)
    mse = best_err.sum() / pred.numel()
    return (mse, best_idx) if return_indices else mse

# Choose device

In [None]:
# CPU by default (convenient for debugging); set USE_CUDA = True to use a GPU when one is available.
USE_CUDA = False
device = torch.device('cuda' if USE_CUDA and torch.cuda.is_available() else 'cpu')

In [None]:
# if torch.cuda.is_available():
#     device = torch.device('cuda')
#     print('cuda available')
# else:
#     device = torch.device('cpu')
#     print('cuda not available')

# Define Model
As the base model for the VAE we use the same set model as the non-probabilistic model in ThreeRoads_non-probabilistic.ipynb, with one difference: the number of input and output features per node is now configurable.

In [None]:
class SetModel(torch.nn.Module):
    """Network on sets of exactly two nodes whose outputs swap when the inputs swap.

    The output for each node is computed from an embedding of that node alone and an
    embedding of the ordered pair (other node, this node).

    Parameters
    ----------
    size_in : int
        number of input features per node
    size_out : int
        number of output features per node
    """

    def __init__(self, size_in, size_out):
        super().__init__()
        hidden = 16
        # Attribute names and creation order fix the state_dict keys and the order in
        # which weights are initialised, so they are kept stable.
        self.mininet1 = self._mlp(size_in, hidden, hidden)     # single-node embedding
        self.mininet2 = self._mlp(2*size_in, hidden, hidden)   # ordered-pair embedding
        self.mininet3 = self._mlp(2*hidden, hidden, size_out)  # per-node readout

    @staticmethod
    def _mlp(n_in, n_hidden, n_out):
        """Linear -> LeakyReLU -> Linear."""
        return torch.nn.Sequential(
            torch.nn.Linear(n_in, n_hidden),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(n_hidden, n_out),
        )

    def forward(self, x):
        # x: [batch size, 2, size_in] -> output: [batch size, 2, size_out]
        node_a, node_b = x[:, 0], x[:, 1]
        pair_ab = x.flatten(start_dim=1)          # features of node 0 followed by node 1
        pair_ba = x.flip(1).flatten(start_dim=1)  # features of node 1 followed by node 0

        out_a = self.mininet3(torch.cat([self.mininet1(node_a), self.mininet2(pair_ba)], dim=1))
        out_b = self.mininet3(torch.cat([self.mininet1(node_b), self.mininet2(pair_ab)], dim=1))
        return torch.stack([out_a, out_b], dim=1)

In [None]:
class ProbabilisticSetModel(torch.nn.Module):
    """Conditional VAE for two-node sets, built from three `SetModel`s.

    - `setmodel1` embeds the initial state (1 feature per node) and `mlp1` maps each
      node embedding to (log sigma, mu) of that node's latent variable; this is the
      distribution sampled when `cond_on_final=False`.
    - `setmodel2` embeds (initial, final) feature pairs and `mlp2` maps
      [initial embedding, pair embedding] to (log sigma, mu); this is the distribution
      sampled when `cond_on_final=True`.
    - `setmodel3` decodes [initial embedding, latent sample] into 1 output per node.

    Parameters
    ----------
    embedding_dim : int
        size of the per-node embedding produced by setmodel1 / setmodel2
    latent_dim : int
        size of the per-node latent variable
    """

    def __init__(self, embedding_dim, latent_dim):
        super().__init__()

        # creation order is kept so seeded weight initialisation is unchanged
        self.setmodel1 = SetModel(1, embedding_dim)
        self.setmodel2 = SetModel(2, embedding_dim)
        self.mlp1 = torch.nn.Sequential(
            torch.nn.Linear(embedding_dim, 16),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(16, 2*latent_dim)
        )
        self.mlp2 = torch.nn.Sequential(
            torch.nn.Linear(embedding_dim*2, 16),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(16, 2*latent_dim)
        )
        self.setmodel3 = SetModel(embedding_dim+latent_dim, 1)

        self.embedding_dim = embedding_dim
        self.latent_dim = latent_dim

    def _split_latent_params(self, params):
        """Split a [batch size*2, 2*latent_dim] head output into (log sigma, mu), each [batch size*2, latent_dim]."""
        params = params.reshape(-1, 2, self.latent_dim)
        return params[:, 0], params[:, 1]

    def _sample(self, logsig, mu):
        """Reparameterised sample mu + exp(logsig)*eps, reshaped to [batch size, 2, latent_dim]."""
        eps = torch.randn_like(logsig)
        return (eps*torch.exp(logsig) + mu).reshape(-1, 2, self.latent_dim)

    def forward(self, x, y, cond_on_final=False):
        """Predict the final state of both nodes.

        Parameters
        ----------
        x : torch tensor, [batch size, 2]
            initial state, one value per node
        y : torch tensor, [batch size, 2] or None
            final state to condition on; only used (and required) when cond_on_final=True
        cond_on_final : bool, optional
            sample the latent from the distribution that also sees `y`, by default False

        Returns
        -------
        tuple
            cond_on_final=True: (logsig_i, mu_i, logsig_f, mu_f, prediction), the first four
            of shape [batch size, 2, latent_dim], prediction of shape [batch size, 2].
            cond_on_final=False: (prediction,)

        Raises
        ------
        ValueError
            if x or y have the wrong shape, or y is None when cond_on_final=True
        """
        # Validate before touching x.shape[0]: indexing a 0-d tensor would raise an
        # IndexError instead of the intended ValueError.
        if x.ndim != 2 or x.shape[-1] != 2:
            raise ValueError(f'x should have shape [batch size, 2], currently {x.shape}')
        if cond_on_final:
            if y is None:
                raise ValueError('y is required when cond_on_final=True')
            if y.ndim != 2 or y.shape[-1] != 2:
                raise ValueError(f'y should have shape [batch size, 2], currently {y.shape}')
            if x.shape[0] != y.shape[0]:
                raise ValueError(f'x and y should have the same batch size. Current shapes: x.shape={x.shape}, y.shape={y.shape}')

        x = x.unsqueeze(dim=-1)  # [batch size, 2, 1] (1 feature per node)

        # per-node embedding of the initial state, flattened to [batch size*2, embedding_dim]
        # so that each node is treated as a separate data point by the MLP heads
        h_init = self.setmodel1(x).flatten(end_dim=1)
        logsig_i, mu_i = self._split_latent_params(self.mlp1(h_init))

        if cond_on_final:
            y = y.unsqueeze(dim=-1)
            h_pair = self.setmodel2(torch.cat([x, y], dim=-1)).flatten(end_dim=1)
            logsig_f, mu_f = self._split_latent_params(self.mlp2(torch.cat((h_init, h_pair), dim=-1)))

        # z_i is always drawn (even when unused with cond_on_final=True); removing the draw
        # would shift the random stream of seeded runs.
        z_i = self._sample(logsig_i, mu_i)
        z = self._sample(logsig_f, mu_f) if cond_on_final else z_i

        h_init = h_init.reshape(-1, 2, self.embedding_dim)
        pred = self.setmodel3(torch.cat((h_init, z), dim=-1)).squeeze(-1)

        if cond_on_final:
            shape = (-1, 2, self.latent_dim)
            return (logsig_i.reshape(shape), mu_i.reshape(shape),
                    logsig_f.reshape(shape), mu_f.reshape(shape), pred)
        return (pred, )





In [None]:
# latent_dim=1 gives one latent value per node, so the latent space of the two nodes
# is 2-D and can be visualised further down. A higher latent_dim might give better
# results (although, from what was observed, results seemed poor either way).
model = ProbabilisticSetModel(embedding_dim=16, latent_dim=1).to(device)

# Smoke test: push one training batch through the conditioned model and print output shapes.
for batch in train_loader:
    x, y = batch[..., 0].to(device), batch[..., 1].to(device)
    outputs = model(x, y, cond_on_final=True)
    for tensor in outputs:
        print(tensor.shape)
    break  # one batch is enough to check the shapes

In [None]:
print(repr(model))  # module tree with layer sizes

In [None]:
# number of scalar values across all parameter tensors of the model
n_params = sum(map(torch.numel, model.parameters()))
print('Total nr of parameters:', n_params)

In [None]:
# Sanity check of the untrained network: unconditioned predictions vs the closest valid target.
model.eval()
with torch.no_grad():  # inference only -- no autograd graph for the (large) test batch
    for batch in test_loader:
        x = batch[..., 0].to(device)
        y = batch[..., 1].to(device)
        pred = model(x, y)[-1]
        print('Shapes:', pred.shape, y.shape)
        y123 = batch[..., 1:].to(device)
        print('Shape y123:', y123.shape)
        test_loss, inds = MSELoss_allTargets(pred, y123, return_indices=True)
        print('test RMSE', torch.sqrt(test_loss).item())

        # for each data point, pick the target option closest to the prediction
        ground_truth = np.take_along_axis(y123.cpu().numpy(),
                                          inds.reshape(-1, 1, 1).cpu().numpy(), axis=-1)
        plt.scatter(ground_truth, pred.cpu().numpy(), s=1)
        plt.xlabel('real')
        plt.ylabel('predicted')
        plt.title('Real vs predicted before training\n(to check initialization of network)')

        break  # only the first test batch is plotted
plt.show()

# Train

In [None]:
lr = 1e-3
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

losses = {
    # on training data, while conditioning on output:
    'MSE_train': [],
    'KL_train': [],
    'loss_train': [],  # total loss

    # on validation data, without conditioning on output:
    'MSE_val': [],
    'MSE_val_allTargets': [],  # MSE compared to closest target

    # on validation data, but still conditioned on output:
    'KL_val2': [],
    'MSE_val2': []
}


def _cvae_kl(out):
    """KL(posterior || prior) for the output of `model(..., cond_on_final=True)`.

    `out` is (logsig_i, mu_i, logsig_f, mu_f, prediction). The ``_i`` distribution is
    computed from the initial state only and is what gets sampled at inference time
    (the prior); the ``_f`` distribution also sees the final state and is what gets
    sampled during training (the posterior). The CVAE objective penalises
    KL(posterior || prior), and `KL_loss` computes KL(first || second), so the
    posterior parameters go first.
    """
    logsig_i, mu_i, logsig_f, mu_f, _ = out
    return KL_loss(logsig_f, mu_f, logsig_i, mu_i)


def _log_epoch_means(keys, epoch_values):
    """Append the mean of each per-batch list in `epoch_values` to `losses` and print it."""
    for key in keys:
        losses[key].append(np.mean(epoch_values[key]))
        print(f'\t{key:10} {losses[key][-1]:8.4}')


N_EPOCHS = 500  # e.g. 50 for a quick run
KL_increase_range = 50  # nr of epochs over which we increase the weight of the KL loss
for epoch in range(N_EPOCHS):
    print(f'epoch {epoch}')

    # ====================== 1) TRAIN ======================
    # KL weight ramps linearly from 0 to its maximum of 0.5 over the first KL_increase_range epochs
    KL_weight = np.clip(epoch/KL_increase_range, a_min=0.0, a_max=1.0)/2
    print(' \t KL_weight:', KL_weight)
    losses_temp = {key: [] for key in losses}

    model.train()
    for batch in train_loader:
        x = batch[..., 0].to(device)
        y = batch[..., 1].to(device)
        optimizer.zero_grad()

        out = model(x, y, cond_on_final=True)
        pred = out[-1]
        if torch.isnan(pred).any():
            raise ValueError('nan value in train prediction')

        MSE_loss = F.mse_loss(pred, y)
        KL_loss1 = _cvae_kl(out)

        loss = (1.0-KL_weight)*MSE_loss + KL_weight*KL_loss1
        loss.backward()
        optimizer.step()

        losses_temp['MSE_train'].append(MSE_loss.item())
        losses_temp['KL_train'].append(KL_loss1.item())
        losses_temp['loss_train'].append(loss.item())

    _log_epoch_means(['MSE_train', 'KL_train', 'loss_train'], losses_temp)

    # =============== 2) EVALUATE WITHOUT CONDITIONING ON OUTPUT ===============
    # Not conditioned on the output, so no posterior is computed and no KL loss is available.
    model.eval()
    with torch.no_grad():
        for batch in test_loader:
            x = batch[..., 0].to(device)
            y = batch[..., 1].to(device)
            pred = model(x, y, cond_on_final=False)[-1]
            if torch.isnan(pred).any():
                raise ValueError('nan value in validation prediction')

            # MSE against option 1 only, and against whichever option is closest
            MSE_loss = F.mse_loss(pred, y)
            y123 = batch[..., 1:].to(device)
            MSE_loss_alltarg = MSELoss_allTargets(pred, y123)

            losses_temp['MSE_val'].append(MSE_loss.item())
            losses_temp['MSE_val_allTargets'].append(MSE_loss_alltarg.item())

    _log_epoch_means(['MSE_val', 'MSE_val_allTargets'], losses_temp)

    # ================ 3) EVALUATE WITH CONDITIONING ON OUTPUT ================
    model.eval()
    with torch.no_grad():
        for batch in test_loader:
            x = batch[..., 0].to(device)
            y = batch[..., 1].to(device)
            out = model(x, y, cond_on_final=True)
            pred = out[-1]
            if torch.isnan(pred).any():
                raise ValueError('nan value in validation prediction')

            MSE_loss = F.mse_loss(pred, y)
            KL_loss1 = _cvae_kl(out)

            losses_temp['MSE_val2'].append(MSE_loss.item())
            losses_temp['KL_val2'].append(KL_loss1.item())

    _log_epoch_means(['MSE_val2', 'KL_val2'], losses_temp)

    # ======== Early stopping =========
    # Stop once, after the KL warm-up, neither the closest-target MSE nor the KL loss on
    # validation data improved by at least 0.5% (mean of last 5 epochs vs the 5 before).
    if (
        epoch > KL_increase_range
        and (np.mean(losses["MSE_val_allTargets"][-5:])
             > 0.995*np.mean(losses["MSE_val_allTargets"][-10:-5]))
        and (np.mean(losses["KL_val2"][-5:])
             > 0.995*np.mean(losses["KL_val2"][-10:-5]))
    ):
        print('======= !! Validation loss did not decrease enough =======')
        print('Stopping training')
        break


# Plot losses

In [None]:
# All tracked loss curves on one log-scaled axis.
fig, ax = plt.subplots()
for name, history in losses.items():
    ax.plot(history, label=name)
ax.legend()
ax.set(yscale='log', ylabel='loss', xlabel='epoch')
plt.show()


# Compare input, output, predicted output

In [None]:
# Print a few test examples: input, the three valid outputs, and the (unconditioned) prediction.
model.eval()
with torch.no_grad():  # inference only
    for batch in test_loader:
        batch = batch[:6]  # the first 6 data points
        x = batch[..., 0].to(device)
        y = batch[..., 1].to(device)
        pred = model(x, y)[-1]

        # undo the normalisation for readability
        x = x.cpu().numpy()*x_std + x_m
        y123 = batch[..., 1:].cpu().numpy()*y_std + y_m
        pred = pred.cpu().numpy()*y_std + y_m

        print('============== INPUT ==============')
        for row in x:
            print(*row, sep=', ')

        print('============== POSSIBLE OUTPUTS ==============')
        for row in y123:
            print(*row.T, sep=', ')

        print('============== PREDICTIONS ==============')
        for row in pred:
            print(*row.T, sep=', ')

        break  # only the first test batch, like the other example cells

# Plot results

The plot below shows 50 sampled predictions for each of the first 10 data points. Unlike the non-probabilistic model in ThreeRoads_non-probabilistic.ipynb, running the model repeatedly on the same input now gives a different result each time. The results form clusters around the three possible outputs, which is what we want. However, the clusters are not very sharp: quite a few results lie in between the possible options, which is not what we want.

In [None]:
# Draw N_DRAWS unconditioned predictions for each of the first N_SHOW data points of every test batch.
N_DRAWS, N_SHOW = 50, 10
model.eval()
preds_all = []
with torch.no_grad():  # inference only
    for batch in test_loader:
        # only the plotted points are fed to the model (predictions are per data point)
        first = batch[:N_SHOW]
        x = first[..., 0].to(device)
        y1 = first[..., 1].to(device)
        for _ in range(N_DRAWS):
            out = model(x, y1, cond_on_final=False)
            y_pred = out[-1].cpu().numpy()*y_std + y_m
            preds_all.extend(y_pred)

In [None]:
# Plot the bifurcations in the first 10 cases, with the sampled predictions on top
fig, ax = plt.subplots(figsize=(3,5), dpi=300)
fig.patch.set_facecolor("None")

model.eval()
for batch in test_loader:
    # denormalised initial state and the three valid final states
    x = batch[..., 0].numpy()*x_std + x_m
    y1, y2, y3 = (batch[..., k].numpy()*y_std + y_m for k in (1, 2, 3))
    option_styles = ((y1, 'tab:orange', 'Option 1'),
                     (y2, 'tab:green', 'Option 2'),
                     (y3, 'tab:red', 'Option 3'))

    for i in range(10):  # iterate over 10 first data points
        # labels starting with '_' are left out of the legend, so only i == 0 gets a legend entry
        hidden = '_'*i
        ax.scatter(*x[i], c='tab:blue', s=50, label=hidden + 'Initial')
        for target, colour, name in option_styles:
            ax.scatter(*target[i], c=colour, s=10, label=hidden + name)
        for target, _, _ in option_styles:
            ax.annotate('', xy=target[i], xytext=x[i],
                        arrowprops=dict(arrowstyle='->', facecolor='black'))

preds_all = np.array(preds_all)
ax.scatter(preds_all[:, 0], preds_all[:, 1], c='magenta', s=50, marker='x', label='Prediction', alpha=0.3)

handles, labels = ax.get_legend_handles_labels()
ax.legend(handles[:5], labels[:5])
ax.set_xlabel('Person 1')
ax.set_ylabel('Person 2')
ax.set_aspect('equal')

plt.show()

The plot below shows what happens when we condition on one of the target outputs (option 1), as is done during training. Most of the time the predictions cluster around option 1, the option being conditioned on.

In [None]:
# Plot the bifurcations in the first 10 cases, with predictions conditioned on option 1
fig, ax = plt.subplots(figsize=(5,8))
fig.patch.set_facecolor("None")

model.eval()
with torch.no_grad():  # inference only
    for batch in test_loader:
        x = batch[..., 0].numpy()*x_std + x_m
        y1 = batch[..., 1].numpy()*y_std + y_m
        y2 = batch[..., 2].numpy()*y_std + y_m
        y3 = batch[..., 3].numpy()*y_std + y_m

        for i in range(10):
            # labels starting with '_' are hidden from the legend, so only i == 0 is labelled
            ax.scatter(*x[i], c='tab:blue', s=50, label='_'*i + 'Initial')
            ax.scatter(*y1[i], c='tab:orange', s=10, label='_'*i + 'Option 1')
            ax.scatter(*y2[i], c='tab:green', s=10, label='_'*i + 'Option 2')
            ax.scatter(*y3[i], c='tab:red', s=10, label='_'*i + 'Option 3')

            for y_opt in (y1, y2, y3):
                ax.annotate('', xy=y_opt[i], xytext=x[i],
                            arrowprops=dict(arrowstyle='->', facecolor='black'))

        # 20 predictions for each of the 10 points, conditioned on option 1 as during training.
        # Only the plotted rows are fed to the model (predictions are per data point).
        first10 = batch[:10]
        x_in = first10[..., 0].to(device)
        y1_in = first10[..., 1].to(device)
        for _ in range(20):
            out = model(x_in, y1_in, cond_on_final=True)
            y_pred = out[-1].cpu().numpy()*y_std + y_m
            for i in range(10):
                ax.scatter(*y_pred[i], c='magenta', s=50, marker='x', label='Prediction', alpha=0.3)

handles, labels = ax.get_legend_handles_labels()
plt.legend(handles[:5], labels[:5])
ax.set_xlabel('Node 1')
ax.set_ylabel('Node 2')
ax.set_aspect('equal')

plt.show()

# Plot real vs predicted

The plots below show the real value per node versus the value predicted by the model. However, if we compare the predicted value simply with option 1 (the values that would be trained on), then the performance looks bad if the model happens to predict one of the other two valid solutions, even though this is actually exactly the behavior we want. Therefore, we pick the closest valid output option and use that as our ground truth.

For the second plot we again condition on option 1, just as we do during training.
It looks like the performance in both cases is similar, which is good news since it means that the model does not rely on the given output.

In [None]:
# Real (closest valid option) vs predicted, without conditioning on the output
plt.figure()
model.eval()
with torch.no_grad():  # inference only
    for batch in test_loader:
        x = batch[..., 0].to(device)
        y123 = batch[..., 1:].to(device)
        pred = model(x, None)[-1]  # y is not used when cond_on_final=False

        _, inds = MSELoss_allTargets(pred, y123, return_indices=True)

# for each data point, use the option closest to the prediction as ground truth (last test batch)
ground_truth = np.take_along_axis(y123.cpu().numpy(),
                                  inds.cpu().numpy().reshape(-1,1,1), axis=-1)
plt.scatter(ground_truth*y_std+y_m, pred.cpu().numpy()*y_std+y_m, s=1)
plt.axline([0,0], [1,1], c='tab:orange')
plt.gca().set_aspect('equal')
plt.xlabel('real')
plt.ylabel('predicted')
plt.title('Real (best match) vs predicted')
plt.show()

# Real (option 1) vs predicted, conditioned on option 1 as during training
plt.figure()
model.eval()
with torch.no_grad():
    for batch in test_loader:
        x = batch[..., 0].to(device)
        y = batch[..., 1].to(device)
        pred = model(x, y, cond_on_final=True)[-1]

y = y.cpu().numpy()
plt.scatter(y*y_std+y_m, pred.cpu().numpy()*y_std+y_m, s=1)
plt.axline([0,0], [1,1], c='tab:orange')
plt.gca().set_aspect('equal')
plt.xlabel('real')
plt.ylabel('predicted')
plt.title('Real vs predicted, conditioned on output')
plt.show()

# Contourplot of latent space
This requires latent_dim = 1, so that there is exactly one latent dimension per node and the latent space of the two nodes can be drawn as a 2-D plane.

In [None]:
# The latent-space plot uses one axis per node, which only works with a single latent value per node.
if model.latent_dim != 1:
    raise ValueError('Cannot visualize latent space if latent_dim is not 1')


In [None]:
import matplotlib.patches as mpatch

In [None]:
# %matplotlib qt
# Latent space for one fixed initial state (v1, v2) that has three valid outcomes.
N = 300  # number of latent samples drawn per distribution
v1, v2 = -40.0, 40.0
d = v2 - v1
# the three valid final states; after tiling, array shape [N, 2 nodes, 3 options]
options = [[v1-d/2, v2-d/2], [v1-d/2, v2+d/2], [v1+d/2, v2+d/2]]
options = np.tile(np.array(options).T, reps=(N, 1, 1))
alpha = 1.0
with torch.no_grad():
    model.eval()
    # the same normalised initial state repeated N times: [N, 2, 1]
    x_new = (torch.FloatTensor([v1,v2]*N).reshape(-1, 2, 1)-x_m)/x_std
    x_new = x_new.to(device)
    y = (torch.from_numpy(options).float().reshape(-1, 2, 3)-y_m)/y_std
    y = y.to(device)

    # apply first part of model to get latent space mu and sigma (mirrors ProbabilisticSetModel.forward)
    x2 = model.setmodel1(x_new)
    x2 = x2.flatten(end_dim=1)
    logsig_mu_i = model.mlp1(x2).reshape(-1, 2, model.latent_dim)
    logsig_i, mu_i = logsig_mu_i[:, 0], logsig_mu_i[:, 1]
    logsig_f = []
    mu_f = []
    for i in range(3):  # distribution conditioned on each of the three options
        y_temp = y[..., [i]]
        xy_pair = torch.cat([x_new,y_temp], dim=-1)
        x3 = model.setmodel2(xy_pair)
        x3 = x3.flatten(end_dim=1)
        x4 = torch.cat((x2, x3), dim=-1)
        logsig_mu_f = model.mlp2(x4).reshape(-1, 2, model.latent_dim)
        logsig_f.append(logsig_mu_f[:, 0])
        mu_f.append(logsig_mu_f[:, 1])

    # sample
    eps = torch.randn_like(logsig_i)
    z_i = eps*torch.exp(logsig_i) + mu_i
    z_f = []
    for i in range(3):
        eps = torch.randn_like(logsig_f[i])
        temp = eps*torch.exp(logsig_f[i]) + mu_f[i]
        temp = temp.reshape(-1, 2, model.latent_dim)
        z_f.append(temp)

    # turn sample into prediction
    x2 = x2.reshape(-1, 2, model.embedding_dim)
    x_new = torch.cat((x2, z_i.reshape(-1, 2, model.latent_dim)), dim=-1)
    x_new = model.setmodel3(x_new)
    pred = x_new.squeeze(-1).cpu().detach().numpy()*y_std + y_m
    preds = []
    for i in range(3):
        x_new = torch.cat((x2, z_f[i]), dim=-1)
        x_new = model.setmodel3(x_new)
        preds.append(x_new.squeeze(-1).cpu().detach().numpy()*y_std + y_m)

    # with latent_dim == 1: column 0 is node 1's latent, column 1 is node 2's latent
    mu_i = mu_i.reshape(-1, 2)
    logsig_i = logsig_i.reshape(-1, 2)
    z_i = z_i.reshape(-1, 2)
    for i in range(3):
        logsig_f[i] = logsig_f[i].reshape(-1, 2)
        mu_f[i] = mu_f[i].reshape(-1, 2)

    fig = plt.figure(figsize=(10,10))
    # draw ellipse patch centered at mu_i with axes given by np.exp(logsig_i);
    # the centre is moved to a CPU numpy array so this also works when device is a GPU
    ellipse = mpatch.Ellipse(mu_i[0].cpu().numpy(),
                             width=2*np.exp(logsig_i[0,0].cpu().detach().numpy()),
                             height=2*np.exp(logsig_i[0,1].cpu().detach().numpy()),
                             fc='None', edgecolor='black')
    plt.gca().add_patch(ellipse)
    colors = ['tab:orange', 'tab:blue', 'tab:red']
    for i in range(3):
        ellipse = mpatch.Ellipse(mu_f[i][0].cpu().numpy(),
                                 width=2*np.exp(logsig_f[i][0,0].cpu().detach().numpy()),
                                 height=2*np.exp(logsig_f[i][0,1].cpu().detach().numpy()),
                                 fc='None', edgecolor=colors[i])
        plt.gca().add_patch(ellipse)

    # plot samples
    plt.scatter(z_i[:, 0].cpu().detach().numpy(), z_i[:, 1].cpu().detach().numpy(), s=1, label='latent space samples', color='black', alpha=alpha)
    for i in range(3):
        plt.scatter(z_f[i][:, 0].cpu().detach().numpy(), z_f[i][:, 1].cpu().detach().numpy(), s=1, label=f'latent space samples, cond. on option {i}', alpha=alpha, color=colors[i])

    # plot means
    plt.scatter(mu_i[:, 0].cpu().detach().numpy(), mu_i[:, 1].cpu().detach().numpy(), s=50, label='latent space means', c='black', marker='x')
    for i in range(3):
        plt.scatter(mu_f[i][:, 0].cpu().detach().numpy(), mu_f[i][:, 1].cpu().detach().numpy(), s=50, marker='x', label=f'latent space means, cond. on option {i}', color=colors[i])

    # contourplot of z1, z2 vs pred, over +-5 sigma around the unconditioned mean
    minz = min(mu_i[0,0].cpu().item() - 5*np.exp(logsig_i[0,0].cpu().item()),
               mu_i[0,1].cpu().item() - 5*np.exp(logsig_i[0,1].cpu().item())
    )
    maxz = max(mu_i[0,0].cpu().item() + 5*np.exp(logsig_i[0,0].cpu().item()),
               mu_i[0,1].cpu().item() + 5*np.exp(logsig_i[0,1].cpu().item())
    )
    plt.gca().set_xlim([minz,  maxz])
    plt.gca().set_ylim([minz, maxz])

    Z1, Z2 = np.meshgrid(
        np.linspace(minz, maxz, 100),
        np.linspace(minz, maxz, 100))
    Z1 = torch.tensor(Z1).float().to(device).reshape(-1, 1, 1)
    Z2 = torch.tensor(Z2).float().to(device).reshape(-1, 1, 1)
    Z = torch.cat((Z1, Z2), dim=1)  # one (z_1, z_2) grid point per row: [100*100, 2, 1]

    # turn every grid point into a prediction and find the closest option for each
    x_new = (torch.FloatTensor([v1,v2]*100*100).reshape(-1, 2, 1)-x_m)/x_std
    x_new = x_new.to(device)
    x2 = model.setmodel1(x_new)
    x2 = x2.reshape(-1, 2, model.embedding_dim)
    x_new = torch.cat((x2, Z), dim=-1)
    x_new = model.setmodel3(x_new)
    pred_grid = x_new.squeeze(-1)
    _, inds = MSELoss_allTargets(pred_grid, y[[0]], return_indices=True)

    cnt = plt.contourf(Z1.cpu().detach().numpy().reshape(100, 100),
                       Z2.cpu().detach().numpy().reshape(100, 100),
                       inds.cpu().numpy().reshape(100, 100),
                       levels=[-0.5, 0.5, 1.5, 2.5],
                       cmap='viridis', alpha=0.5, zorder=-10)
    plt.colorbar(label='Closest option to prediction',
                 ticks=[0, 1, 2])
    plt.legend()
    plt.gca().set_aspect('equal')
    plt.xlabel('$z_1$')
    plt.ylabel('$z_2$')
    plt.gca().set_title('VAE latent space. Ellipses indicate standard deviation.')
plt.show()

In [None]:
# Plot the resulting predictions for the samples shown in the latent space plot above
fig, ax = plt.subplots(figsize=(5,8))
fig.patch.set_facecolor("None")

# initial state and the three valid final states (identical for every row of `options`)
x = [v1, v2]
y1, y2, y3 = (options[0, :, k] for k in range(3))

ax.scatter(*x, c='tab:green', s=50, label='Initial')

for i, y_temp in enumerate([y1, y2, y3]):
    ax.scatter(*y_temp, c=colors[i], s=10, label=f'Option {i+1}')

for y_temp in (y1, y2, y3):
    ax.annotate('', xy=y_temp, xytext=x,
                arrowprops=dict(arrowstyle='->', facecolor='black'))

for i, pred_temp in enumerate(preds):
    ax.scatter(*pred_temp.T, s=50, marker='x', label=f'Prediction, conditioned on option {i+1}', alpha=0.2, color=colors[i])

ax.scatter(*pred.T, s=50, marker='x', label='Prediction, no conditioning', alpha=0.2, color='black')

handles, labels = ax.get_legend_handles_labels()
ax.legend(handles, labels)
ax.set_xlabel('Node 1')
ax.set_ylabel('Node 2')
ax.set_aspect('equal')

plt.show()

# Calculate Wasserstein distance

In [None]:
from scipy.stats import wasserstein_distance_nd

N_PRED = 100  # number of predictions drawn per data point

model.eval()
wd = []  # one distance per data point, accumulated over ALL test batches
with torch.no_grad():  # inference only
    for batch in test_loader:
        batch = batch.clone().to(device)

        # [data points, 2 nodes, N_PRED] unconditioned, denormalised predictions
        preds = np.zeros((len(batch), 2, N_PRED))
        for i in range(N_PRED):
            pred = model(batch[..., 0], batch[..., 1], cond_on_final=False)[-1]
            preds[..., i] = pred.cpu().numpy()*y_std+y_m

        # [data points, 2 nodes, 3 options] valid outputs, denormalised
        reals = batch[..., 1:].cpu().numpy()*y_std + y_m

        print(preds.shape)
        print(reals.shape)

        # compare the cloud of predicted points with the set of valid outputs, one data point at a time
        for pred_cloud, real_options in zip(preds, reals):
            wd.append(wasserstein_distance_nd(pred_cloud.T, real_options.T))

print('Mean Wasserstein distance between real and predicted:', np.mean(wd))