In [None]:
import pickle
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch import nn, Tensor

# Open Dataset

In [None]:
# NOTE: pickle.load can execute arbitrary code -- only open dataset files from a trusted source.
with open('../data/ThreeRoads_data.pkl', 'rb') as fh:
    data = pickle.load(fh)

# Summarize every entry: tensors by their size, anything else by its value.
for name, entry in data.items():
    print(name, entry.size() if isinstance(entry, torch.Tensor) else entry)

# Train/test splits and the normalization statistics used to map back to raw units.
data_tr, data_te = data['data_tr'], data['data_te']
x_m, y_m = data['x_m'], data['y_m']
x_std, y_std = data['x_std'], data['y_std']

In [None]:
# Training: shuffled mini-batches of 32. Test: fixed order, batches of up to 10000
# samples (presumably the whole test split in one batch -- TODO confirm len(data_te)).
train_loader = torch.utils.data.DataLoader(dataset=data_tr, batch_size=32, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=data_te, batch_size=10000, shuffle=False)

# Choose device

In [None]:
# Prefer the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('cuda available' if device.type == 'cuda' else 'cuda not available')

# Define Model 

In [None]:
class Flow(nn.Module):
    """Velocity field v(t, x_t | cond) for a two-coordinate conditional flow model.

    The two coordinates ("nodes") of ``x_t`` and ``cond`` are treated symmetrically:
    the same mini-networks are shared between nodes, so swapping the two columns of
    both ``x_t`` and ``cond`` swaps the two output columns.

    Args:
        dim: dimension of a data sample. Currently unused -- the architecture is
            hard-wired to 2 coordinates; kept for interface compatibility.
        cond: dimension of the conditioning input. Currently unused -- ``cond`` is
            expected to have the same 2 columns as ``x_t``.
        h: hidden layer width of every mini-network.
    """

    def __init__(self, dim: int = 1, cond: int = 1, h: int = 16):
        super().__init__()  # initialize nn.Module exactly once
        # Per-node features from (t, x_i, cond_i).
        self.mininet1 = self._mlp(3, h, h)
        # Pairwise features from (t, x_a, x_b, cond_a, cond_b).
        self.mininet2 = self._mlp(5, h, h)
        # (node features, pair features) -> one velocity component per node.
        self.mininet3 = self._mlp(2 * h, h, 1)

    @staticmethod
    def _mlp(n_in: int, h: int, n_out: int) -> nn.Sequential:
        """Four Linear layers with LeakyReLU in between: n_in -> h -> h -> h -> n_out."""
        return nn.Sequential(
            nn.Linear(n_in, h),
            nn.LeakyReLU(),
            nn.Linear(h, h),
            nn.LeakyReLU(),
            nn.Linear(h, h),
            nn.LeakyReLU(),
            nn.Linear(h, n_out),
        )

    def forward(self, t: Tensor, x_t: Tensor, cond: Tensor) -> Tensor:
        """Predict the velocity at time ``t`` and state ``x_t`` given ``cond``.

        Args:
            t: current time, shape (batch, 1).
            x_t: current state, shape (batch, 2).
            cond: conditioning input, shape (batch, 2).

        Returns:
            Velocity, shape (batch, 2).
        """
        # Per-node inputs (t, x_i, cond_i), each of shape (batch, 3).
        x1 = torch.cat((t, x_t[:, 0:1], cond[:, 0:1]), 1)
        x2 = torch.cat((t, x_t[:, 1:2], cond[:, 1:2]), 1)

        # Shared mini-net 1 on each node separately -> (batch, h) each.
        x1 = self.mininet1(x1)
        x2 = self.mininet1(x2)

        # Shared mini-net 2 on the pair in order (1, 2) and in order (2, 1) -> (batch, h) each.
        x12 = torch.cat((t, x_t, cond), 1)
        x21 = torch.cat((t, x_t[:, [1, 0]], cond[:, [1, 0]]), 1)
        x12 = self.mininet2(x12)
        x21 = self.mininet2(x21)

        # Combine node and pair features into one output per node.
        x1 = self.mininet3(torch.cat([x1, x21], dim=1))
        x2 = self.mininet3(torch.cat([x2, x12], dim=1))

        return torch.cat([x1, x2], dim=1)  # shape (batch, 2)

    def step(self, x_t: Tensor, t_start: Tensor, t_end: Tensor, cond: Tensor) -> Tensor:
        """Advance ``x_t`` from ``t_start`` to ``t_end`` with one explicit midpoint step.

        Args:
            x_t: current state, shape (batch, 2).
            t_start: start time, a single-element tensor (expanded over the batch).
            t_end: end time, a tensor broadcastable against (batch, 1)
                (a 0-d tensor in this notebook).
            cond: conditioning input, shape (batch, 2).

        Returns:
            The state at ``t_end``, shape (batch, 2).
        """
        t_start = t_start.view(1, 1).expand(x_t.shape[0], 1)
        dt = t_end - t_start
        # Midpoint rule: evaluate the velocity after an Euler half-step, then take the
        # full step with that midpoint velocity.
        x_mid = x_t + self(x_t=x_t, t=t_start, cond=cond) * dt / 2
        return x_t + dt * self(t=t_start + dt / 2, x_t=x_mid, cond=cond)

In [None]:
# Build the velocity-field model on the selected device and report its size.
flow = Flow().to(device)
print(flow)

n_params = sum(map(torch.numel, flow.parameters()))
print('Total nr of parameters:', n_params)

In [None]:

# Conditional flow matching: regress the model's velocity at (t, x_t) onto the
# straight-line velocity x_1 - x_0 between a noisy source point x_0 and the target x_1.
N_EPOCHS = 1000
LR_INITIAL = 1e-3
LR_FINAL = 1e-5
LR_DROP_EPOCH = 200  # after this epoch the learning rate is lowered to LR_FINAL
NOISE_STD = 0.5      # std of the Gaussian source noise centred on `cond`

flow.train()
optimizer = torch.optim.Adam(flow.parameters(), LR_INITIAL)
loss_fn = nn.MSELoss()

MSE_loss = []  # mean training loss per epoch
for epoch in range(N_EPOCHS):
    if epoch % 100 == 0:
        print(epoch)
    loss_temp = []
    # The batch loop deliberately does not reuse the epoch variable name, so the
    # learning-rate check below compares the epoch and not the last batch index.
    for batch in train_loader:
        batch = batch.to(device)
        x_1 = batch[..., 1]  # final position (target)
        cond = batch[..., 0]  # initial position (condition)
        x_0 = cond + NOISE_STD * torch.randn_like(x_1)  # noisy source sample around the condition
        t = torch.rand(len(x_1), 1, device=device)

        # Point on the straight path x_0 -> x_1 at time t, and its constant velocity.
        x_t = (1 - t) * x_0 + t * x_1
        dx_t = x_1 - x_0

        optimizer.zero_grad()
        loss = loss_fn(flow(t=t, x_t=x_t, cond=cond), dx_t)
        loss.backward()
        optimizer.step()

        loss_temp.append(loss.item())

    if epoch == LR_DROP_EPOCH:
        # Lower the learning rate in place; this keeps Adam's moment estimates.
        for group in optimizer.param_groups:
            group['lr'] = LR_FINAL
    MSE_loss.append(np.mean(loss_temp))

In [None]:
# Training curve: mean MSE per epoch on a log scale. The line is labelled so that
# the legend has an entry (an unlabelled plot makes legend() warn and draw nothing).
fig, ax = plt.subplots()
ax.plot(MSE_loss, label='train MSE')
ax.set_yscale('log')
ax.set_xlabel('Epoch')
ax.set_ylabel('MSE loss')
ax.set_title('Training loss')
ax.legend()
plt.show()

# Sampling

In [None]:
# Visualize sampling: integrate the learned flow from Gaussian noise around the
# condition (t=0) to t=1 with the midpoint rule, and plot the state after every
# second step against the real outcome.
n_steps = 8
n_samples = 10  # predictions per data point

fig, axes = plt.subplots(1, n_steps//2 + 1, figsize=(20, 4), sharex=True, sharey=True)

time_steps = torch.linspace(0, 1.0, n_steps + 1).to(device)

for ax in axes:
    ax.set_aspect('equal')

flow.eval()
with torch.no_grad():  # inference only: don't build an autograd graph across the steps
    for j in range(n_samples):
        for batch in test_loader:
            batch = batch.to(device)
            cond = batch[..., 0]  # conditioning on the bet
            x = cond + 0.5*torch.randn_like(batch[..., 1])  # initial gaussian noise (same scale as training)
            bet = cond.cpu()*x_std + x_m
            real = batch[..., 1].cpu()*y_std + y_m

            # Panel 0: the starting noise. Underscore-prefixed labels for j > 0 are
            # ignored by legend(), so each entry appears once.
            pred = x.cpu()*y_std + y_m
            axes[0].scatter(bet, pred, s=1, label='_'*j + 'predicted', c='tab:orange')
            axes[0].scatter(bet, real, s=1, label='_'*j + 'real', c='tab:blue')

            for i in range(n_steps):
                x = flow.step(x_t=x, t_start=time_steps[i], t_end=time_steps[i + 1], cond=cond)
                if i % 2 == 1:  # panel k shows the state at time_steps[2k]
                    pred = x.cpu()*y_std + y_m
                    axes[i//2 + 1].scatter(bet, real, s=1, label='real', c='tab:blue')
                    axes[i//2 + 1].scatter(bet, pred, s=1, label='predicted', c='tab:orange')

# Titles, limits and legend only need to be set once, after all points are drawn.
for k, ax in enumerate(axes):
    ax.set_title(f't = {time_steps[2 * k]:.2f}')
axes[0].legend()
axes[0].set_xlim(-100.0, 100.0)  # shared across panels via sharex/sharey
axes[0].set_ylim(-100.0, 100.0)

plt.tight_layout()
plt.show()

In [None]:
# Draw 50 samples per test point with a finer 64-step midpoint integration and keep
# the samples of the first 10 points of each batch for the bifurcation plot below.
# `n_steps` and `time_steps` defined here are also used by later cells.
n_steps = 64
n_samples = 50  # samples per data point
n_keep = 10     # leading data points per batch whose samples are kept
time_steps = torch.linspace(0, 1.0, n_steps + 1).to(device)

flow.eval()
preds_all = []
with torch.no_grad():  # inference only: no autograd graph through 64 steps
    for batch in test_loader:
        batch = batch.to(device)
        # Take the condition from the current batch *before* building the initial
        # noise, so no stale `cond` from an earlier cell is used.
        cond = batch[..., 0]  # conditioning on the bet

        for _ in range(n_samples):
            x = cond + 0.5*torch.randn_like(batch[..., 1])  # initial gaussian noise (same scale as training)

            for i in range(n_steps):
                x = flow.step(x_t=x, t_start=time_steps[i], t_end=time_steps[i + 1], cond=cond)

            y_pred = x.cpu().detach().numpy()*y_std + y_m
            preds_all.extend(y_pred[:n_keep])

In [None]:
# Plot the bifurcations in the first 10 cases: each initial position with arrows to
# its three possible outcomes, overlaid with the flow's sampled predictions.
n_cases = 10  # leading data points drawn from each test batch
option_colors = ('tab:orange', 'tab:green', 'tab:red')

fig, ax = plt.subplots(figsize=(3, 5), dpi=300)
fig.patch.set_facecolor("None")

for batch in test_loader:
    x = batch[..., 0].numpy()*x_std + x_m
    outcomes = [batch[..., k].numpy()*y_std + y_m for k in (1, 2, 3)]

    for i in range(min(n_cases, len(x))):  # guard against batches with < n_cases rows
        ax.scatter(*x[i], c='tab:blue', s=50, label='Initial')
        for k, (y, color) in enumerate(zip(outcomes, option_colors), start=1):
            ax.scatter(*y[i], c=color, s=10, label=f'Option {k}')
            ax.annotate('', xy=y[i], xytext=x[i],
                        arrowprops=dict(arrowstyle='->', facecolor='black'))

preds_all = np.array(preds_all)
ax.scatter(preds_all[:, 0], preds_all[:, 1], c='magenta', s=50, marker='x', label='Prediction', alpha=0.3)

# One legend entry per label. Repeated labels come from plotting several cases and
# batches; de-duplicating keeps "Prediction" even when there is more than one batch.
handles, labels = ax.get_legend_handles_labels()
by_label = dict(zip(labels, handles))
ax.legend(list(by_label.values()), list(by_label.keys()))
ax.set_xlabel('Person 1')
ax.set_ylabel('Person 2')
ax.set_aspect('equal')

plt.show()

In [None]:
# Shape of the last sample batch produced by the prediction cell.
np.shape(y_pred)

# Calculate Wasserstein distance

In [None]:
from scipy.stats import wasserstein_distance_nd

# For each test point, draw N_PRED samples from the flow and compute the Wasserstein
# distance between these samples and the point's real outcomes (columns 1: of the batch).
# Uses `n_steps` / `time_steps` from the prediction cell above.
N_PRED = 100  # predictions per data point

flow.eval()
wd = []  # one distance per test point, accumulated over *all* batches
with torch.no_grad():  # inference only: no autograd graph through the integration
    for batch in test_loader:
        batch = batch.to(device)
        # Condition taken from the current batch before it is used below.
        cond = batch[..., 0]  # conditioning on the bet

        preds = np.zeros((len(batch), 2, N_PRED))
        for i in range(N_PRED):
            x = cond + 0.5*torch.randn_like(batch[..., 1])  # initial gaussian noise (same scale as training)
            for j in range(n_steps):
                x = flow.step(x_t=x, t_start=time_steps[j], t_end=time_steps[j + 1], cond=cond)
            preds[..., i] = x.cpu().numpy()*y_std + y_m

        reals = batch[..., 1:].cpu().numpy()*y_std + y_m

        print(preds.shape)
        print(reals.shape)

        # wasserstein_distance_nd expects (n_samples, n_dims) arrays, hence the transposes.
        for pred, real in zip(preds, reals):  # per data point (not batchable)
            wd.append(wasserstein_distance_nd(pred.T, real.T))

print('Mean Wasserstein distance between real and predicted:', np.mean(wd))