Adapted from `examples/standalone_flow_matching.ipynb` in the Facebook Research Flow Matching GitHub repository (https://github.com/facebookresearch/flow_matching/blob/main/examples/standalone_flow_matching.ipynb).

This notebook attempts to turn a standard Gaussian distribution into a distribution with two sharp peaks at -X and +X, where X is the amount you bet.



In [None]:
import torch
from torch import nn, Tensor
import pickle
import numpy as np
import matplotlib.pyplot as plt


In [None]:
# Load the pickled dataset and unpack the entries used by the cells below.
# NOTE: pickle can execute arbitrary code on load, so only open files you trust.
with open('../data/HeadsOrTails_data.pkl', 'rb') as pkl_file:
    data = pickle.load(pkl_file)

# Train / test splits.
data_tr, data_te = data['data_tr'], data['data_te']
# Scale (std) and offset (mean) used later to map normalised values back to
# euros: x_* for the bet column, y_* for the winnings column.
x_std, x_m = data['x_std'], data['x_m']
y_std, y_m = data['y_std'], data['y_m']

In [None]:
class Flow(nn.Module):
    """Conditional velocity-field MLP with a midpoint-rule integrator.

    The network maps the concatenation ``(t, x_t, cond)`` to a velocity with
    the same dimension as ``x_t``.
    """

    def __init__(self, dim: int = 1, cond: int = 1, h: int = 64):
        """Build the network.

        Args:
            dim: dimension of a data sample.
            cond: dimension of the conditioning variable (the bet here).
            h: width of each hidden layer.
        """
        super().__init__()
        in_features = cond + dim + 1  # +1 for the scalar time input
        layers = [nn.Linear(in_features, h), nn.ELU()]
        for _ in range(2):
            layers += [nn.Linear(h, h), nn.ELU()]
        layers.append(nn.Linear(h, dim))
        self.net = nn.Sequential(*layers)

    def forward(self, t: Tensor, x_t: Tensor, cond: Tensor) -> Tensor:
        """Return the predicted velocity for state ``x_t`` at time ``t`` given ``cond``.

        The three inputs are concatenated along the last dimension, in the
        order ``t``, ``x_t``, ``cond``.
        """
        features = torch.cat((t, x_t, cond), -1)
        return self.net(features)

    def step(self, x_t: Tensor, t_start: Tensor, t_end: Tensor, cond: Tensor) -> Tensor:
        """Advance ``x_t`` from ``t_start`` to ``t_end`` with one explicit midpoint step.

        ``t_start`` must hold a single element; it is broadcast to one time
        value per row of ``x_t``.
        """
        t_start = t_start.view(1, 1).expand(x_t.shape[0], 1)
        dt = t_end - t_start

        # Half an Euler step gives the state estimate at the interval midpoint.
        v_start = self(x_t=x_t, t=t_start, cond=cond)
        x_mid = x_t + v_start * dt / 2

        # The velocity at the midpoint drives the full step.
        v_mid = self(t=t_start + dt / 2, x_t=x_mid, cond=cond)
        return x_t + dt * v_mid

# Choose device

In [None]:
# Use the GPU when PyTorch reports CUDA support, otherwise fall back to the CPU.
cuda_ok = torch.cuda.is_available()
device = torch.device('cuda' if cuda_ok else 'cpu')
print('cuda available' if cuda_ok else 'cuda not available')

# Training

In [None]:
# Seed PyTorch's global random number generator so that weight
# initialisation and noise sampling are repeatable between runs.
seed = 0
torch.manual_seed(seed)

In [None]:
# Small shuffled mini-batches for optimisation; the test split is served in
# batches of up to 100000 rows.
train_loader = torch.utils.data.DataLoader(dataset=data_tr, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=data_te, batch_size=100000, shuffle=True)

In [None]:
flow = Flow().to(device)

optimizer = torch.optim.Adam(flow.parameters(), 1e-3)
loss_fn = nn.MSELoss()

n_epochs = 10000
lr_drop_epoch = 1000  # epoch after which training continues with lr = 1e-5

MSE_loss = []  # mean training loss of every epoch
for epoch in range(n_epochs):
    if epoch % 100 == 0:
        print(epoch)
    loss_temp = []
    # The batch loop must not reuse the epoch counter's name: doing so makes
    # the learning-rate check below compare against the last batch index.
    for batch in train_loader:
        batch = batch.to(device)
        x_1 = batch[:, [1]]   # data sample (column 1)
        cond = batch[:, [0]]  # conditioning value (column 0, the bet)
        x_0 = torch.randn_like(x_1)                 # Gaussian noise sample
        t = torch.rand(len(x_1), 1, device=device)  # uniform time in [0, 1)

        # Point on the straight line between noise and data, and the
        # constant velocity along that line (the regression target).
        x_t = (1 - t) * x_0 + t * x_1
        dx_t = x_1 - x_0

        optimizer.zero_grad()
        loss = loss_fn(flow(t=t, x_t=x_t, cond=cond), dx_t)
        loss.backward()
        optimizer.step()

        loss_temp.append(loss.item())

    if epoch == lr_drop_epoch:
        # A fresh optimizer lowers the learning rate; this also resets Adam's
        # moment estimates.
        optimizer = torch.optim.Adam(flow.parameters(), 1e-5)
    MSE_loss.append(np.mean(loss_temp))

# Plot losses

In [None]:
# Per-epoch training loss on a logarithmic y-axis.
loss_fig, loss_ax = plt.subplots()
loss_ax.plot(MSE_loss)
loss_ax.set_yscale('log')
loss_ax.set_xlabel('epoch')
loss_ax.set_ylabel('loss')
plt.show()


In [None]:
# Plot the training loss smoothed with a 100-epoch moving average.
window = 100
kernel = np.full(window, 1 / window)
smoothed_loss = np.convolve(MSE_loss, kernel, mode='valid')

plt.figure()
plt.plot(smoothed_loss)
plt.yscale('log')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.show()

# Sampling

In [None]:
%matplotlib qt
n_steps = 128

# One panel for t = 0 plus one panel for every second integration step.
fig, axes = plt.subplots(1, n_steps//2 + 1, figsize=(20, 4), sharex=True, sharey=True)

time_steps = torch.linspace(0, 1.0, n_steps + 1).to(device)

for ax in axes:
    ax.set_aspect('equal')

axes[0].set_title(f't = {time_steps[0]:.2f}')
axes[0].set_xlim(-100.0, 100.0)
axes[0].set_ylim(-100.0, 100.0)

# Sampling needs no gradients. Without no_grad() autograd would keep the
# graph of every forward pass of a trajectory alive, which is very costly
# for batches this large.
with torch.no_grad():
    for j in range(10):  # 10 predictions per data point
        for batch in test_loader:
            batch = batch.to(device)
            x = torch.randn_like(batch[:, [1]])  # initial gaussian noise
            cond = batch[:, [0]]  # conditioning on the bet

            # Bets and real winnings in euros; constant along the trajectory.
            bet = cond.cpu()*x_std+x_m
            real = batch[:, 1].cpu()*y_std+y_m

            pred = x.cpu()[:, 0]*y_std+y_m
            # A leading underscore hides a label from the legend, so only
            # the j == 0 scatters get a legend entry.
            axes[0].scatter(bet, pred, s=1, label='_'*j+'predicted', c='tab:orange')
            axes[0].scatter(bet, real, s=1, label='_'*j+'real', c='tab:blue')

            for i in range(n_steps):
                x = flow.step(x_t=x, t_start=time_steps[i], t_end=time_steps[i + 1], cond=cond)

                if i % 2 == 1:
                    pred = x.cpu()[:, 0]*y_std+y_m
                    panel = axes[i//2 + 1]
                    panel.scatter(bet, real, s=1, label='real', c='tab:blue')
                    panel.scatter(bet, pred, s=1, label='predicted', c='tab:orange')
                    panel.set_title(f't = {time_steps[i + 1]:.2f}')

axes[0].legend()

plt.tight_layout()
plt.show()

In [None]:
# result at t=1 only
# Relies on `n_steps` and `time_steps` as left by an earlier cell.
fig, ax = plt.subplots(figsize=(4,3), dpi=200)

# Sampling needs no gradients; no_grad() stops autograd from retaining the
# graph of every integration step.
with torch.no_grad():
    for batch in test_loader:
        batch = batch.to(device)
        cond = batch[:, [0]]  # conditioning on the bet
        bet = cond.cpu()*x_std+x_m
        real = batch[:, 1].cpu()*y_std+y_m
        for j in range(10):  # 10 predictions per data point
            x = torch.randn_like(batch[:, [1]])  # initial gaussian noise
            for i in range(n_steps):
                x = flow.step(x_t=x, t_start=time_steps[i], t_end=time_steps[i + 1], cond=cond)

            pred = x.cpu()[:, 0]*y_std+y_m
            # A leading underscore hides the label, so only j == 0 is listed.
            ax.scatter(bet, pred, s=1, label='_'*j+'predicted', c='tab:orange')
        # The real points are identical for every j: draw them once, last,
        # so that they stay on top of the predictions.
        ax.scatter(bet, real, s=1, label='true', c='tab:blue')
ax.legend()
ax.set_ylabel('winnings (€)')
ax.set_xlabel('bet (€)')
ax.set_aspect('equal')
fig.subplots_adjust(left=0.2, bottom=0.2)

# Histogram for bet = €50

In [None]:
# Evolution of the histograms
n_steps = 32
time_steps = torch.linspace(0, 1.0, n_steps + 1).to(device)
print(len(time_steps))
# Integration-step indices at which a histogram is drawn (first one is t = 0).
inds_to_plot = np.linspace(0, n_steps, 5).astype(int)
print(inds_to_plot)

fig, axes = plt.subplots(1, len(inds_to_plot), figsize=(15, 3), sharex=True, sharey=True)

N = 10000
x = torch.randn(N, 1).to(device)  # initial gaussian noise
# A bet of 50, normalised with the stored bet mean/std, repeated for every sample.
cond = (torch.ones((N, 1), dtype=torch.float32, device=device)*50-data['x_m'])/data['x_std']  # conditioning on the bet

bins = np.linspace(-150, 150, 50)
axes[0].hist(x.cpu().detach()[:, 0]*data['y_std']+data['y_m'],
             bins=bins, density=True)
axes[0].set_title(f't = {time_steps[0]:.2f}')
axes[0].set_ylabel('probability density')
axes[0].set_xlabel('predicted winnings (€)')

fig.suptitle(f'Samples from the flow, bet = 50€, {n_steps} steps')

# Dashed markers at +50 and -50 on every panel.
for ax in axes:
    ax.axvline(50, color='r', linestyle='dashed')
    ax.axvline(-50, color='r', linestyle='dashed')

# Map each plotted step index to its panel; panel 0 (t = 0) is drawn above.
panel_for_step = {int(step): panel
                  for panel, step in enumerate(inds_to_plot) if panel > 0}

# Sampling needs no gradients; no_grad() stops autograd from retaining the
# graph of every integration step.
with torch.no_grad():
    for i in range(1, n_steps + 1):
        x = flow.step(x_t=x, t_start=time_steps[i-1], t_end=time_steps[i], cond=cond)

        panel = panel_for_step.get(i)
        if panel is not None:
            axes[panel].hist(
                x.cpu()[:, 0]*data['y_std']+data['y_m'],
                bins=bins, density=True)
            axes[panel].set_title(f't = {time_steps[i]:.2f}')
            axes[panel].set_xlabel('predicted winnings (€)')

fig.subplots_adjust(top=0.8, bottom=0.2)

In [None]:
%matplotlib inline
# Result only: histogram of the samples currently held in `x`,
# converted back to euros.
plt.rcParams.update({'font.size': 16})
fig, ax = plt.subplots(1, figsize=(5, 5))

bins = np.linspace(-75, 75, 40)
final_samples = x.cpu().detach()[:, 0]
preds = (final_samples*data['y_std']+data['y_m']).numpy()
ax.hist(preds, bins=bins, density=True)
ax.set_xlabel('predicted winnings (€)')
ax.set_ylabel('Density')
# Dashed markers at +50 and -50.
for marker in (50, -50):
    ax.axvline(marker, color='r', linestyle='dashed')
fig.subplots_adjust(top=0.8, left=0.15)
plt.show()

In [None]:
# Number of predictions strictly below zero.
(preds < 0).sum()

In [None]:
# Number of predictions strictly above zero.
(preds > 0).sum()

# Calculate Wasserstein distance

In [None]:
from scipy.stats import wasserstein_distance

n_steps = 128
time_steps = torch.linspace(0, 1.0, n_steps + 1).to(device)
n_preds = 100  # number of predictions drawn per data point

flow.eval()
# Distances are collected across ALL test batches; creating this list inside
# the batch loop would make the final mean cover only the last batch.
wd = []
# Sampling needs no gradients; no_grad() stops autograd from retaining the
# graph of every integration step.
with torch.no_grad():
    for batch in test_loader:
        batch = batch.to(device)

        preds = np.zeros((len(batch), n_preds))  # one row per data point

        cond = batch[:, [0]]  # conditioning on the bet
        for i in range(n_preds):
            x = torch.randn_like(batch[:, [1]])  # initial gaussian noise
            for j in range(n_steps):
                x = flow.step(x_t=x, t_start=time_steps[j], t_end=time_steps[j + 1], cond=cond)
            preds[:, i] = x.cpu().numpy()[:, 0]*y_std+y_m

        reals = batch[:, 1:].cpu().numpy()*y_std + y_m

        print(preds.shape)
        print(reals.shape)

        # One distance per data point: its predictions against its real value(s).
        wd.extend(wasserstein_distance(pred, real) for pred, real in zip(preds, reals))

print('Mean Wasserstein distance between real and predicted:', np.mean(wd))