In [15]:
import torch
from torchcfm.optimal_transport import OTPlanSampler
import time

import matplotlib.pyplot as plt
import numpy as np
import ot as pot
import torch
import torchdyn
from torchdyn.core import NeuralODE
from torchcfm.conditional_flow_matching import *
from torchcfm.models.models import *
from torchcfm.utils import *


In [16]:
# Run on the GPU when PyTorch reports CUDA as available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [17]:
# Synthetic 2-D target dataset: points drawn uniformly at random from six
# axis-aligned rectangles. Each rectangle contributes
# `weight * sample_multiplier` points, and the rectangles are concatenated in
# the order listed below.

# Grid spacing of an earlier deterministic (np.arange-based) layout of the same
# rectangles. The random sampling below does not read it; the name is kept in
# case other cells still reference it.
distance = 0.4
sample_multiplier = 10

# A dedicated, seeded generator makes the dataset identical on every
# "Restart & Run All" without touching NumPy's global random state.
DATA_SEED = 0
data_rng = np.random.default_rng(DATA_SEED)

# (weight, (x_low, x_high), (y_low, y_high))
region_specs = [
    (16, (0, 2), (0, 8)),
    (4, (2, 4), (0, 2)),
    (16, (10, 12), (0, 8)),
    (8, (12, 16), (6, 8)),
    (8, (12, 16), (3, 5)),
    (8, (12, 16), (0, 2)),
]

# One vectorised draw per rectangle; `low`/`high` broadcast over the rows so
# column 0 is the x coordinate and column 1 is the y coordinate.
points_np = np.concatenate(
    [
        data_rng.uniform(
            low=[x_range[0], y_range[0]],
            high=[x_range[1], y_range[1]],
            size=(weight * sample_multiplier, 2),
        )
        for weight, x_range, y_range in region_specs
    ]
)

# NumPy produces float64; cast to float32 explicitly because that is the dtype
# torch.tensor() gave for the original list of Python floats, and the rest of
# the notebook mixes `x` with float32 model outputs.
x = torch.tensor(points_np, dtype=torch.float32).to(device)

In [18]:
def sample_conditional_pt(x0, x1, t, sigma):
    """
    Sample a point on the noisy straight-line path between x0 and x1.

    The sample is drawn from N(t * x1 + (1 - t) * x0, sigma), see (Eq.14) [1]:
    the interpolated mean plus `sigma` times standard-normal noise.

    Parameters
    ----------
    x0 : Tensor, shape (bs, *dim)
        represents the source minibatch
    x1 : Tensor, shape (bs, *dim)
        represents the target minibatch
    t : FloatTensor, shape (bs)
        interpolation time for each element of the batch
    sigma : float
        scale multiplying the standard-normal noise added to the mean

    Returns
    -------
    xt : Tensor, shape (bs, *dim)

    References
    ----------
    [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
    """
    # Give t one singleton axis per non-batch axis of x0 so it broadcasts
    # against the full sample shape.
    trailing_axes = (1,) * (x0.dim() - 1)
    t_b = t.reshape(-1, *trailing_axes)

    mean = t_b * x1 + (1 - t_b) * x0
    noise = torch.randn_like(x0)
    return mean + sigma * noise

In [19]:
def compute_conditional_vector_field(x0, x1):
    """
    Return the regression target for the velocity model: the displacement from x0 to x1.

    This is the conditional vector field ut(x1|x0) = x1 - x0, see Eq.(15) [1].

    Parameters
    ----------
    x0 : Tensor, shape (bs, *dim)
        represents the source minibatch
    x1 : Tensor, shape (bs, *dim)
        represents the target minibatch

    Returns
    -------
    ut : conditional vector field ut(x1|x0) = x1 - x0

    References
    ----------
    [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
    """
    displacement = x1 - x0
    return displacement

In [20]:
def MMD(x, y, gamma=2):
    """
    Estimate the squared maximum mean discrepancy between two sample sets.

    Uses the kernel k(a, b) = 1 / (1 + ||a - b||^2 / gamma^2) and returns
    mean(k(x, x)) + mean(k(y, y)) - 2 * mean(k(x, y)), where each mean runs over
    every pair of rows (the diagonal terms are included).

    The result is computed entirely from `x` and `y`, so it lives on their
    device and keeps their dtype; no notebook-global device is consulted.

    Parameters
    ----------
    x : Tensor, shape (n, d)
        first set of samples, one sample per row
    y : Tensor, shape (m, d)
        second set of samples, one sample per row; m may differ from n
    gamma : float, default 2
        kernel length scale; squared distances are divided by gamma ** 2

    Returns
    -------
    Tensor, scalar
        the estimate described above
    """
    # Gram matrices of inner products; torch.mm requires 2-D inputs.
    xx, yy, xy = torch.mm(x, x.t()), torch.mm(y, y.t()), torch.mm(x, y.t())

    # Squared norms of every row, read off the Gram-matrix diagonals.
    sq_x = xx.diag()
    sq_y = yy.diag()

    # Pairwise squared distances via ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b
    dxx = sq_x[:, None] + sq_x[None, :] - 2.0 * xx
    dyy = sq_y[:, None] + sq_y[None, :] - 2.0 * yy
    dxy = sq_x[:, None] + sq_y[None, :] - 2.0 * xy

    scale = gamma ** 2
    k_xx = 1 / (1 + dxx / scale)
    k_yy = 1 / (1 + dyy / scale)
    k_xy = 1 / (1 + dxy / scale)
    return k_xx.mean() + k_yy.mean() - 2 * k_xy.mean()

In [21]:
def _build_velocity_mlp(in_dim, width, out_dim, n_linear):
    """
    Build a fully connected ReLU network with `n_linear` Linear layers.

    The layout is Linear(in_dim, width), then (n_linear - 2) Linear(width, width)
    layers, then Linear(width, out_dim), with a ReLU between consecutive Linear
    layers and none after the last one.

    Raises
    ------
    ValueError
        if `n_linear` is smaller than 2 (an input and an output layer are required).
    """
    if n_linear < 2:
        raise ValueError(f"n_linear must be at least 2, got {n_linear}")
    layers = [torch.nn.Linear(in_dim, width), torch.nn.ReLU()]
    for _ in range(n_linear - 2):
        layers.extend([torch.nn.Linear(width, width), torch.nn.ReLU()])
    layers.append(torch.nn.Linear(width, out_dim))
    return torch.nn.Sequential(*layers)


def _sample_trajectory(node, n_samples, num_dims):
    """
    Integrate `node` from t=0 to t=1 starting from `n_samples` standard-normal points.

    Returns whatever `node.trajectory` returns for a 100-point time grid; the
    callers below index it as (time, sample, coordinate).
    """
    start_points = torch.tensor(
        np.random.normal(0.0, 1.0, size=(n_samples, num_dims))
    ).float().to(device)
    with torch.no_grad():
        return node.trajectory(start_points, t_span=torch.linspace(0, 1, 100))


# Train a velocity network with the flow-matching regression loss for every
# (width, depth) configuration, log MMD against the dataset before and after
# training to results.txt, and save the loss curve and diagnostic plots.
# The context manager guarantees results.txt is closed even if training fails.
with open("results.txt", "w") as f:
    for i in [128]:  # hidden width
        for j in [8]:  # number of Linear layers in the network
            loss_arr = []
            run_tag = "l" + str(j) + "w" + str(i)
            f.write("OTFlow layers: " + str(j) + " width: " + str(i) + "\n")
            # NOTE(review): ot_sampler, FM and batch_size are created here but the
            # training loop below never reads them (the loader uses the full dataset
            # as a single batch). Kept so any other cell relying on them still works.
            ot_sampler = OTPlanSampler(method="exact")
            sigma = 1e-8
            dim = 2
            batch_size = 256
            # The network input is the point concatenated with the time, hence dim + 1.
            # It must live on the same device as the data and the ODE start points.
            model = _build_velocity_mlp(dim + 1, i, dim, j).to(device)

            num_dims = 2
            # The wrapper holds a reference to `model`, so this one ODE object
            # reflects the current weights both before and after training.
            node = NeuralODE(
                torch_wrapper(model), solver="dopri5", sensitivity="adjoint", atol=1e-4, rtol=1e-4
            )
            traj = _sample_trajectory(node, x.shape[0], num_dims)
            f.write("MMD before training: " + str(MMD(traj[-1, :, :], x).item()) + "\n")

            optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
            FM = ConditionalFlowMatcher(sigma=sigma)

            # Full-batch loader: every pass yields the whole (shuffled) dataset once.
            dataloader = torch.utils.data.DataLoader(x, batch_size=int(x.shape[0]), shuffle=True)

            start = time.time()
            for k in range(5000):
                for x_batch in dataloader:
                    optimizer.zero_grad()

                    x1 = x_batch.to(device)
                    # Built from a NumPy array, so x0 (and t, via type_as) are float64;
                    # the .float() casts below bring the model inputs back to float32.
                    x0 = torch.tensor(
                        np.random.normal(0.0, 1.0, size=(len(x1), num_dims))
                    ).to(device)

                    t = torch.rand(x0.shape[0]).type_as(x0)
                    xt = sample_conditional_pt(x0, x1, t, sigma=sigma).float()
                    ut = compute_conditional_vector_field(x0, x1)

                    vt = model(torch.cat([xt, t[:, None]], dim=-1).float())
                    loss = torch.mean((vt - ut) ** 2)

                    loss.backward()
                    optimizer.step()
                    loss_arr.append(loss.item())

                if (k + 1) % 1000 == 0:
                    end = time.time()
                    print(f"{k+1}: loss {loss.item():0.3f} time {(end - start):0.2f}")
                    start = end

            np.savetxt(run_tag + ".txt", loss_arr, fmt='%f')

            traj = _sample_trajectory(node, x.shape[0], num_dims)
            f.write("MMD after training: " + str(MMD(traj[-1, :, :], x).item()) + "\n")
            f.write("\n")

            # Larger sample for plotting; matplotlib needs host-side NumPy arrays,
            # so both the trajectory and the dataset are copied off the device.
            traj = _sample_trajectory(node, 10000, num_dims)
            traj = traj.cpu().numpy()
            x_np = x.detach().cpu().numpy()
            n = 1000  # number of trajectories / samples drawn in the figures
            col_red = '#c61826'
            col_dark_red = '#590d08'
            col_blue = '#01024d'

            # Figure 1: flow paths (faint), start points (blue) and end points (red).
            plt.figure()
            plt.scatter(traj[:, :n, 0], traj[:, :n, 1], s=0.1, alpha=0.1, c=col_dark_red)
            plt.scatter(traj[0, :n, 0], traj[0, :n, 1], s=5, alpha=0.8, c=col_blue)
            plt.scatter(traj[-1, :n, 0], traj[-1, :n, 1], s=5, alpha=1, c=col_red)
            plt.xlabel("X")
            plt.ylabel("Y")
            plt.ylim(-5, 10)
            plt.xlim(-5, 18)
            plt.savefig("ICfm_plot_LE_gen_" + run_tag + ".png")
            plt.close()

            # Figure 2: dataset versus generated end points.
            plt.figure()
            plt.scatter(x_np[:, 0], x_np[:, 1], alpha=1, s=25, label="Original Data", c=col_blue)
            plt.scatter(traj[-1, :n, 0], traj[-1, :n, 1], s=25, alpha=1, label="Generated Data", c=col_red)
            plt.xlabel("X")
            plt.ylabel("Y")
            plt.ylim(-5, 10)
            plt.xlim(-5, 18)
            plt.legend()
            plt.savefig("ICfm_plot_LE_gen_smp_" + run_tag + ".png")
            plt.close()

            # Figure 3: training loss, one value per optimizer step.
            plt.figure()
            plt.plot(loss_arr)
            plt.xlabel("Epochs")
            plt.ylabel("Flow Matching Loss")
            plt.ylim(0, 20)
            plt.savefig("ICfm_LE_loss_plot_" + run_tag + ".png")
            plt.close()

1000: loss 5.726 time 15.49
2000: loss 5.195 time 22.07
3000: loss 5.817 time 16.51
4000: loss 6.632 time 16.67
5000: loss 5.395 time 17.04
