Here, we try a conventional, non-probabilistic approach on the Three Roads Problem, and show how it fails.

In [None]:
import pickle
import torch
import numpy as np
import matplotlib.pyplot as plt

# Open Dataset

In [None]:
# Read the pickled dataset dictionary from disk.
with open('../data/ThreeRoads_data.pkl', 'rb') as handle:
    data = pickle.load(handle)

# Summarise every entry: tensors are reported by size, anything else by value.
for name, entry in data.items():
    print(name, entry.size() if isinstance(entry, torch.Tensor) else entry)

# Pull out the train/test sets and the offsets/scales that later cells use
# to undo the normalisation (value * std + m).
data_tr, data_te = data['data_tr'], data['data_te']
x_m, y_m = data['x_m'], data['y_m']
x_std, y_std = data['x_std'], data['y_std']

In [None]:
# Shuffled mini-batches of 32 for optimisation.
train_loader = torch.utils.data.DataLoader(
    dataset=data_tr,
    batch_size=32,
    shuffle=True,
)
# Large, unshuffled batches (up to 10000 samples each) for evaluation.
test_loader = torch.utils.data.DataLoader(
    dataset=data_te,
    batch_size=10000,
    shuffle=False,
)

# Choose device

In [None]:
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('cuda available' if device.type == 'cuda' else 'cuda not available')

# Define Model 

In [None]:
class SetModel(torch.nn.Module):
    """Permutation-equivariant network for a set of two scalar nodes.

    The input is a ``(batch, 2)`` tensor holding one value per node and the
    output is a ``(batch, 2)`` tensor holding one prediction per node.
    Because the same sub-networks are applied to both nodes, swapping the
    two input columns swaps the two output columns.
    """

    def __init__(self):
        super().__init__()
        # Per-node encoder: a single scalar -> 10 features.
        self.mininet1 = self._mlp(1, 10)
        # Pair encoder: an ordered pair of scalars -> 10 features.
        self.mininet2 = self._mlp(2, 10)
        # Decoder: node features + pair features (20) -> one prediction.
        self.mininet3 = self._mlp(20, 1)

    @staticmethod
    def _mlp(n_in, n_out, width=10):
        """Build a 4-layer LeakyReLU MLP mapping ``n_in`` to ``n_out`` features."""
        return torch.nn.Sequential(
            torch.nn.Linear(n_in, width),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(width, width),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(width, width),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(width, n_out),
        )

    def forward(self, x):
        """Return one prediction per node for a ``(batch, 2)`` input ``x``."""
        # Encode each node on its own with the shared per-node encoder.
        node_1 = self.mininet1(x[:, 0:1])
        node_2 = self.mininet1(x[:, 1:2])

        # Encode the ordered pair (1, 2) and the swapped pair (2, 1).
        pair_12 = self.mininet2(x)
        pair_21 = self.mininet2(x[:, [1, 0]])

        # Decode each node from its own features plus the pair encoding
        # in which the *other* node comes first.
        out_1 = self.mininet3(torch.cat([node_1, pair_21], dim=1))
        out_2 = self.mininet3(torch.cat([node_2, pair_12], dim=1))

        # Stack the two per-node predictions back into a (batch, 2) tensor.
        return torch.cat([out_1, out_2], dim=1)

# Instantiate the network and move its parameters to the selected device.
model = SetModel().to(device)

In [None]:
# Test permutation equivariance on a single training batch: swapping the two
# input nodes must swap the two outputs (up to floating-point round-off).
# No gradients are needed for this check, so autograd is switched off.
with torch.no_grad():
    batch = next(iter(train_loader))
    x = batch[..., 0].to(device)
    y_pred = model(x)
    print(y_pred[:10])

    # flip inputs (the two nodes swap values)
    y_pred_flipped = model(x[:, [1, 0]])
    print(y_pred_flipped[:10])

    # Quantify the check instead of relying on eyeballing the two printouts.
    max_dev = (y_pred[:, [1, 0]] - y_pred_flipped).abs().max().item()
    print(f'max deviation from exact equivariance: {max_dev:.3e}')

# Train Model

In [None]:
lr = 1e-3
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# The loss module is stateless, so build it once instead of once per batch.
loss_fn = torch.nn.MSELoss()

# Per-epoch mean losses, used for the learning curves and early stopping.
MSE_train = []
MSE_val = []

for epoch in range(1000):

    # ====================== 1) TRAIN ======================
    model.train()
    print(f'epoch {epoch}', end=' ')
    MSE_train_temp = []
    for batch in train_loader:
        # Slice 0 of the last axis is the input, slice 1 the regression target.
        x = batch[..., 0].to(device)
        y = batch[..., 1].to(device)
        optimizer.zero_grad()
        pred = model(x)
        if torch.isnan(pred).any():
            raise ValueError('nan value in train prediction')

        MSE_loss = loss_fn(pred, y)

        MSE_loss.backward()
        optimizer.step()

        MSE_train_temp.append(MSE_loss.item())

    MSE_train.append(np.mean(MSE_train_temp))
    print(f'train {MSE_train[-1]:8.4}, ', end='')

    # ====================== 2) EVALUATE ======================
    model.eval()
    MSE_val_temp = []
    # No backward pass happens here, so skip building the autograd graph.
    with torch.no_grad():
        for batch in test_loader:
            x = batch[..., 0].to(device)
            y = batch[..., 1].to(device)
            pred = model(x)

            if torch.isnan(pred).any():
                raise ValueError('nan value in val prediction')

            MSE_val_temp.append(loss_fn(pred, y).item())

    MSE_val.append(np.mean(MSE_val_temp))
    # Report the epoch mean (as for training), not just the last batch.
    print(f'val {MSE_val[-1]:8.4}')

    # ======== Early stopping =========
    # Stop when the mean of the last 5 validation losses has improved by less
    # than 0.5% over the (up to) 5 epochs before them. The length guard skips
    # the first 5 epochs, where the reference window would be empty and
    # np.mean would emit a RuntimeWarning and return NaN.
    if len(MSE_val) > 5 and np.mean(MSE_val[-5:]) > 0.995*np.mean(MSE_val[-10:-5]):
        print('======= !! Validation loss did not decrease enough =======')
        print('Stopping training')
        break

In [None]:
# Learning curves: per-epoch mean MSE on a logarithmic y-axis.
for curve, curve_label in ((MSE_train, 'training'), (MSE_val, 'validation')):
    plt.plot(curve, label=curve_label)
plt.yscale('log')
plt.legend()

The training curve above shows that in the beginning, the network learns to make a prediction in the right range, but after that it does not improve any more.
Below, we plot the predicted outputs together with the 3 possible outputs to compare them.

In [None]:
# Collect de-normalised predictions for the first 10 samples of each test batch.
model.eval()
preds_all = []
for batch in test_loader:
    x = batch[..., 0].to(device)
    # Network output lives in normalised units; bring it back to numpy on the CPU.
    normalised = model(x).cpu().detach().numpy()
    # Undo the target normalisation.
    y_pred = normalised*y_std + y_m

    preds_all.extend(y_pred[:10])


In [None]:
# Display the collected de-normalised predictions (one 2-element array per sample).
preds_all

In [None]:
# Plot the bifurcations in the first 10 cases
fig, ax = plt.subplots(figsize=(3,5), dpi=300)
fig.patch.set_facecolor("None")

# (index on the last batch axis, marker colour, legend label) for each of the
# three possible outcomes stored next to the input.
option_styles = (
    (1, 'tab:orange', 'Option 1'),
    (2, 'tab:green', 'Option 2'),
    (3, 'tab:red', 'Option 3'),
)

# Only the very first sample gets real legend labels; every later artist is
# tagged '_nolegend_' (labels starting with an underscore are skipped by
# matplotlib), so the legend holds exactly one entry per category no matter
# how many batches the loader yields.
labelled = False
for batch in test_loader:
    # Batches come from the loader on the CPU, so de-normalise them directly;
    # the model is not involved in this plot.
    x = batch[..., 0].numpy()*x_std + x_m
    options = [(batch[..., k].numpy()*y_std + y_m, colour, name)
               for k, colour, name in option_styles]

    # Guard against a batch with fewer than 10 samples.
    for i in range(min(10, len(x))):
        ax.scatter(*x[i], c='tab:blue', s=50,
                   label='_nolegend_' if labelled else 'Initial')
        for y_opt, colour, name in options:
            ax.scatter(*y_opt[i], c=colour, s=10,
                       label='_nolegend_' if labelled else name)
            # Arrow from the initial state to this possible outcome.
            ax.annotate('', xy=y_opt[i], xytext=x[i],
                        arrowprops=dict(arrowstyle='->', facecolor='black'),
                        )
        labelled = True

# Overlay the model predictions gathered earlier (first 10 per test batch).
preds_all = np.array(preds_all)
ax.scatter(preds_all[:, 0], preds_all[:, 1], c='magenta', s=50, marker='x', label='Prediction')

ax.legend()
ax.set_xlabel('Person 1')
ax.set_ylabel('Person 2')
ax.set_aspect('equal')

plt.show()

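The plot above shows that the model predicts approximately the average of the 3 options every time. This is expected: a deterministic network trained with an MSE loss is pushed towards the mean of the possible targets for a given input, which lies between the three options rather than on any one of them.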

# Calculate Wasserstein distance

In [None]:
from scipy.stats import wasserstein_distance_nd

model.eval()

# One Wasserstein distance per test sample, accumulated over ALL batches.
# (Initialising this inside the batch loop would keep only the last batch.)
wd = []

# Pure inference: no autograd graph is needed.
with torch.no_grad():
    for batch in test_loader:
        batch = batch.to(device)

        preds = np.zeros((len(batch), 2, 100))  # create 100 predictions per data point

        # Draw 100 predictions for each data point. The loop is kept so the
        # protocol matches what a sampling model would need, although it
        # simply repeats the same forward pass on the same input here.
        for i in range(100):
            pred = model(batch[..., 0])
            preds[..., i] = pred.cpu().numpy()*y_std+y_m

        # Slices 1.. of the last axis hold the possible real outcomes.
        reals = batch[..., 1:].cpu().numpy()*y_std + y_m

        print(preds.shape)
        print(reals.shape)

        # Each sample compares a cloud of 100 predicted 2-D points with the
        # cloud of real 2-D outcomes; wasserstein_distance_nd expects arrays
        # of shape (n_points, n_dims), hence the transposes. As far as the
        # original author knows, this cannot be batched.
        for pred, real in zip(preds, reals):
            wd.append(wasserstein_distance_nd(pred.T, real.T))

print('Mean Wasserstein distance between real and predicted:', np.mean(wd))