In [22]:
import torch
from torchcfm.optimal_transport import OTPlanSampler
import time

import matplotlib.pyplot as plt
import numpy as np
import ot as pot
import torch
import torchdyn
from torchdyn.core import NeuralODE
from torchcfm.conditional_flow_matching import *
from torchcfm.models.models import *
from torchcfm.utils import *


In [23]:
# Use a CUDA GPU when PyTorch can see one; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [24]:
# Seed the RNGs this notebook draws from so Restart & Run All reproduces the results:
# numpy (data below, noise in the training cell) and torch (weight init and
# torch-based sampling).
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

distance = 0.4  # grid spacing of an earlier, disabled grid-layout variant of the data; unused
sample_multiplier = 10

# Training data: 16 * sample_multiplier points drawn uniformly from the square [2, 6] x [2, 6].
# A single (N, 2) draw consumes the numpy stream in the same order as drawing (x, y) pairs one by one.
x = torch.tensor(
    np.random.uniform(2, 6, size=(16 * sample_multiplier, 2)),
    dtype=torch.float32,  # keep float32 so it matches the float32 trajectories compared against it
).to(device)

In [25]:
def sample_conditional_pt(x0, x1, t, sigma):
    """
    Sample x_t from the Gaussian probability path with mean t * x1 + (1 - t) * x0
    and standard deviation sigma, see (Eq.14) [1].

    Parameters
    ----------
    x0 : Tensor, shape (bs, *dim)
        source minibatch (value of the path at t = 0)
    x1 : Tensor, shape (bs, *dim)
        target minibatch (value of the path at t = 1)
    t : FloatTensor, shape (bs)
        one time per sample; reshaped to broadcast over all non-batch axes
    sigma : float or Tensor
        scale applied to standard-normal noise

    Returns
    -------
    xt : Tensor, shape (bs, *dim)
        interpolated point plus sigma-scaled Gaussian noise

    References
    ----------
    [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
    """
    # (bs,) -> (bs, 1, ..., 1) so each sample's time multiplies every feature of that sample.
    broadcast_shape = (-1,) + (1,) * (x0.dim() - 1)
    t_b = t.reshape(broadcast_shape)
    mean = t_b * x1 + (1 - t_b) * x0
    noise = torch.randn_like(x0)
    return mean + sigma * noise

In [26]:
def compute_conditional_vector_field(x0, x1):
    """
    Return the conditional velocity ut(x1|x0) = x1 - x0 of the straight-line path, see Eq.(15) [1].

    The velocity does not depend on t: every point on the segment from x0 to x1
    moves with the same constant displacement.

    Parameters
    ----------
    x0 : Tensor, shape (bs, *dim)
        source minibatch
    x1 : Tensor, shape (bs, *dim)
        target minibatch

    Returns
    -------
    ut : Tensor, shape (bs, *dim)
        elementwise difference x1 - x0

    References
    ----------
    [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
    """
    velocity = x1 - x0
    return velocity

In [27]:
def MMD(x, y, gamma=2):
    """Biased (V-statistic) estimate of the squared MMD between two point sets.

    Uses the inverse multiquadric kernel k(a, b) = 1 / (1 + ||a - b||^2 / gamma^2):
        MMD^2 ~= mean k(x, x') + mean k(y, y') - 2 * mean k(x, y)

    Parameters
    ----------
    x : Tensor, shape (n, d)
        first sample set
    y : Tensor, shape (m, d)
        second sample set; same device, dtype and feature dimension d as x
    gamma : float, default 2
        kernel length scale

    Returns
    -------
    0-d Tensor on the inputs' device and in their dtype.
    """

    def _sq_dists(a, b):
        # ||a_i||^2 + ||b_j||^2 - 2 <a_i, b_j>. Round-off can push this slightly
        # below zero for (near-)identical points, so clamp to keep it a valid distance.
        a_sq = (a * a).sum(dim=1, keepdim=True)      # (na, 1)
        b_sq = (b * b).sum(dim=1, keepdim=True).t()  # (1, nb)
        return (a_sq + b_sq - 2.0 * (a @ b.t())).clamp_min(0.0)

    def _kernel(d2):
        return 1.0 / (1.0 + d2 / gamma**2)

    k_xx = _kernel(_sq_dists(x, x))
    k_yy = _kernel(_sq_dists(y, y))
    k_xy = _kernel(_sq_dists(x, y))
    return k_xx.mean() + k_yy.mean() - 2.0 * k_xy.mean()

In [28]:
import matplotlib.patches as patches

# Hyper-parameter sweep over hidden width `i` and number of Linear layers `j`.
# For each configuration:
#   - train a minibatch-OT conditional flow matching model on the data `x`;
#   - write MMD(generated, x) before and after training to results.txt;
#   - save the per-step losses to l{j}w{i}.txt, plus three diagnostic plots.
sigma = 1e-4
dim = 2
num_dims = dim
n_epochs = 2000
log_every = 1000
n = 1000  # number of trajectories drawn in the scatter plots
col_red = '#c61826'
col_dark_red = '#590d08'
col_blue = '#01024d'


def build_vector_field(width, n_layers):
    """Build an MLP mapping [x, t] (dim + 1 inputs) to a velocity (dim outputs).

    `n_layers` counts Linear layers and must be >= 2. Hidden layers are `width`
    wide with ReLU in between. The model is moved to `device`, where the training
    tensors live.
    """
    if n_layers < 2:
        raise ValueError(f"n_layers must be >= 2, got {n_layers}")
    layers = [torch.nn.Linear(dim + 1, width), torch.nn.ReLU()]
    for _ in range(n_layers - 2):
        layers += [torch.nn.Linear(width, width), torch.nn.ReLU()]
    layers.append(torch.nn.Linear(width, dim))
    return torch.nn.Sequential(*layers).to(device)


def sample_trajectory(model, n_samples, n_steps=100):
    """Integrate the learned ODE from standard-normal noise over t in [0, 1].

    The result is indexed below as traj[time_step, sample, coordinate].
    """
    node = NeuralODE(
        torch_wrapper(model), solver="dopri5", sensitivity="adjoint", atol=1e-4, rtol=1e-4
    )
    z = torch.tensor(np.random.normal([0] * dim, 1, size=(n_samples, dim))).float().to(device)
    with torch.no_grad():
        return node.trajectory(z, t_span=torch.linspace(0, 1, n_steps))


# matplotlib cannot convert CUDA tensors, so keep a host copy of the data for plotting.
x_np = x.cpu().numpy()

with open("results.txt", "w") as f:  # closed even if training raises
    for i in [8, 16, 24, 32]:  # hidden width
        for j in [2, 4]:  # number of Linear layers
            tag = "l" + str(j) + "w" + str(i)
            loss_arr = []
            f.write("OTFlow layers: " + str(j) + " width: " + str(i) + "\n")
            ot_sampler = OTPlanSampler(method="exact")
            model = build_vector_field(i, j)

            traj = sample_trajectory(model, x.shape[0])
            f.write("MMD before training: " + str(MMD(traj[-1, :, :], x).item()) + "\n")

            optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
            FM = ConditionalFlowMatcher(sigma=sigma)
            # Full-batch training: each epoch is one optimizer step on a shuffled copy of all of x.
            dataloader = torch.utils.data.DataLoader(x, batch_size=int(x.shape[0]), shuffle=True)

            start = time.time()
            for k in range(n_epochs):
                for x_batch in dataloader:
                    optimizer.zero_grad()

                    x1 = x_batch.to(device)
                    x0 = torch.tensor(np.random.normal([0] * num_dims, 1, size=(len(x1), num_dims))).to(device)
                    # Re-pair noise and data with torchcfm's exact minibatch OT plan before regressing.
                    x0, x1 = ot_sampler.sample_plan(x0.float(), x1.float())

                    t, xt, ut = FM.sample_location_and_conditional_flow(x0, x1)
                    t = t.float().to(device)
                    xt = xt.float().to(device)
                    ut = ut.float().to(device)
                    vt = model(torch.cat([xt, t[:, None]], dim=-1))
                    loss = torch.mean((vt - ut) ** 2)

                    loss.backward()
                    optimizer.step()
                    loss_arr.append(loss.item())

                if (k + 1) % log_every == 0:
                    end = time.time()
                    print(f"{k+1}: loss {loss.item():0.3f} time {(end - start):0.2f}")
                    start = end

            np.savetxt(tag + ".txt", loss_arr, fmt='%f')

            traj = sample_trajectory(model, x.shape[0])
            f.write("MMD after training: " + str(MMD(traj[-1, :, :], x).item()) + "\n")
            f.write("\n")

            traj = sample_trajectory(model, 10000).cpu().numpy()

            # Full flow: paths (dark red), noise at t=0 (blue), generated samples at t=1 (red).
            fig, ax = plt.subplots()
            ax.scatter(traj[:, :n, 0], traj[:, :n, 1], s=0.1, alpha=0.1, c=col_dark_red)
            ax.scatter(traj[0, :n, 0], traj[0, :n, 1], s=5, alpha=0.8, c=col_blue)
            ax.scatter(traj[-1, :n, 0], traj[-1, :n, 1], s=5, alpha=1, c=col_red)
            ax.set(xlabel="X", ylabel="Y", xlim=(-5, 10), ylim=(-5, 10))
            fig.savefig("cfm_plot_Square_gen_" + tag + ".png")
            plt.close(fig)

            # Generated samples vs. training data, with the [2, 6] x [2, 6] sampling square outlined.
            fig, ax = plt.subplots()
            ax.scatter(x_np[:, 0], x_np[:, 1], alpha=1, s=25, label="Original Data", c=col_blue)
            ax.scatter(traj[-1, :n, 0], traj[-1, :n, 1], s=25, alpha=1, label="Generated Data", c=col_red)
            ax.add_patch(patches.Rectangle((2, 2), 4, 4, linewidth=1, edgecolor=col_dark_red, facecolor='none'))
            ax.set(xlim=(0, 8), ylim=(0, 8))
            ax.legend()
            fig.savefig("cfm_plot_Square_gen_smp_" + tag + ".png")
            plt.close(fig)

            # Per-step flow matching loss.
            fig, ax = plt.subplots()
            ax.plot(loss_arr)
            ax.set(xlabel="Epochs", ylabel="Flow Matching Loss", ylim=(0, 10))
            fig.savefig("cfm_Square_loss_plot_" + tag + ".png")
            plt.close(fig)

1000: loss 0.085 time 7.69
2000: loss 0.099 time 7.03
1000: loss 0.087 time 6.12
2000: loss 0.072 time 6.15
1000: loss 0.120 time 6.18
2000: loss 0.101 time 5.72
1000: loss 0.090 time 6.75
2000: loss 0.067 time 9.09
1000: loss 0.090 time 8.61
2000: loss 0.073 time 6.94
1000: loss 0.077 time 14.79
2000: loss 0.066 time 9.32
1000: loss 0.078 time 6.78
2000: loss 0.067 time 5.03
1000: loss 0.091 time 7.29
2000: loss 0.075 time 7.66
