## Introduction

The goal of this lab is to implement a simple Conditional Flow Matching (CFM) model.
It uses the same toy data as the GAN example.

We start by showing that no deep learning is required (in low dimension): we implement it with scikit-learn, using plain k-nearest-neighbors regression.

Then, we write a PyTorch implementation (using neural networks and stochastic gradient descent).


## Toy data generation

In [None]:
# styling the notebook solutions
from IPython.core.display import HTML
def css_styling():
    """Return an HTML snippet with the stylesheet used for `.solution` elements.

    The stylesheet collapses such elements to a single black line labelled
    "solution" and reveals them (gray background) when hovered.
    """
    stylesheet = """<style>.solution { background: black !important; overflow-y: hidden; &:not(:hover) {height: 1em !important; &::before { content: "solution"; color: red; } } &:hover { background: gray !important; transition: 1s linear 1s; } }</style>"""
    return HTML(stylesheet)
css_styling()

In [None]:
# %%
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
import sklearn
import scipy
from torch.utils.data import DataLoader
import torch.autograd as autograd

In [None]:
# Create a synthetic dataset: 2D points scattered around a circle
nb_samples = 1000
radius = 1
nz = .1

# polar coordinates: noisy radius around `radius`, angle uniform in [0, 2*pi)
r = radius + nz * torch.randn(nb_samples)
theta = torch.rand(nb_samples) * 2 * torch.pi

# convert to cartesian coordinates, one (x, y) row per sample
X_train = torch.stack((r * torch.cos(theta), r * torch.sin(theta)), dim=1)

# clean up to avoid using variables below, by mistake
del radius, nz, r, theta

In [None]:
# Visualize the data: scatter plot of the training samples
fig, ax = plt.subplots(figsize=(3, 3))
ax.scatter(
    X_train[:, 0],
    X_train[:, 1],
    s=20,
    alpha=0.8,
    edgecolor='k',
    marker='o',
    label='original samples',
)
ax.grid(alpha=0.5)
ax.legend(loc='best')
fig.tight_layout()
plt.show()
# %%

## Conditional flow matching using kNN regression

As conditioning variable, we choose a pair made of a source point (Gaussian noise) and a target point (a sample from the dataset of interest), namely $z = (x_0, x_1)$.

In [None]:
#### Simple version with sklearn
# number of (x0, x1, t) training points generated for the regressor
nb_fit = 100000

x_train_numpy = X_train.numpy()
# one random row index in [0, nb_samples) per training point
sampled_indices = np.random.randint(0, nb_samples, nb_fit)

# TODO IMPLEMENT FLOW MATCHING
# - get nb_fit 2D "source" points -> x0
# - get nb_fit 2D "target" points -> x1
# - get t time instants uniformly in [0,1]
# - compute x (the linear interpolation between x0 and x1 at time t)
# (N stands for nb_fit in the shape hints below)
x0 = ........ # N,2
x1 = ........ #  N,2
t = ........ # N,1  (keep it 2-D: it is concatenated with x along the last axis below)
x = ........ # N,2

# sanity check on the interpolated points
assert x.shape == (nb_fit, 2)

# conditional velocity field: the regression target is the displacement
# from the source point to the target point
u = x1 - x0 # N,2

# Import the subpackage explicitly: a bare `import sklearn` is not guaranteed
# to make `sklearn.neighbors` available as an attribute.
import sklearn.neighbors

# regress the velocity u from the pair (position x, time t)
u_in = np.concatenate([x, t], axis=-1) # N,3
u_out = u # N,2
velocity_estimator = sklearn.neighbors.KNeighborsRegressor(n_neighbors=100) # question: try n_neighbors=1 , what do you get below? how can you explain it?
velocity_estimator.fit(u_in, u_out)

# clean up to avoid using variables below, by mistake
del x0, x1, t, x, u, u_in, u_out, x_train_numpy

<pre class="solution">
x0 = np.random.randn(nb_fit, 2) # N,2
x1 = x_train_numpy[sampled_indices] #  N,2
t = np.random.uniform(0, 1, (nb_fit, 1)) # N,1  (or N)
x = x0 * (1-t) + x1 * t # N,2
</pre>

In [None]:
# Visualize the flow

def view_field(v, at_t):
    """Draw the velocity field predicted by `v` at time `at_t` as a quiver plot.

    The field is sampled on a regular 20x20 grid covering [-1.2, 1.2]^2 and
    each arrow is colored by its negated speed.

    Parameters
    ----------
    v : object with a `predict` method
        Called with an (n_points, 3) array whose rows are (x, y, t); the
        result must hold one (vx, vy) row per input row.
    at_t : float
        Time at which the field is evaluated (same value at every grid point).
    """
    l1x = np.linspace(-1.2, 1.2, 20)
    l1y = np.linspace(-1.2, 1.2, 20)
    # with the default 'xy' indexing, both arrays have shape (l1y.size, l1x.size)
    gx, gy = np.meshgrid(l1x, l1y)
    u_in = np.stack(
        [gx.ravel(), gy.ravel(), np.full(gx.size, at_t, dtype=float)],
        axis=-1,
    )
    # Reshape with the grid's own shape instead of (l1x.size, l1y.size): the
    # two only coincide when both axes have the same number of points.
    u_out = np.asarray(v.predict(u_in)).reshape(gx.shape + (2,))
    speed = np.sqrt((u_out ** 2).sum(axis=-1))
    _, ax = plt.subplots(figsize=(5, 5))
    ax.set_title(f"t={at_t}")
    ax.quiver(gx, gy, u_out[..., 0], u_out[..., 1], -speed)

In [None]:
view_field(velocity_estimator, .9)

In [None]:
# show the estimated field at several times between 0 and 1
for at_t in (0, .1, .3, .5, .7, .9):
    view_field(velocity_estimator, at_t)


In [None]:
#### visualize the generation of samples
def view_generation(v, n=1000, dt=0.01, plot_at=(0, .4, .5, .8, .9, 1)):
    """Generate samples by Euler integration of `v` and scatter-plot snapshots.

    Starts from `n` standard-normal 2D points at t=0 and repeatedly applies
    ``x <- x + dt * v.predict([x, t])``. One snapshot is drawn at the first
    step whose time reaches each requested instant (so the t=0 snapshot shows
    the initial noise), and integration stops after the last snapshot.

    Parameters
    ----------
    v : object with a `predict` method
        Called with an (n, 3) array whose rows are (x, y, t); must return an
        (n, 2) array of velocities.
    n : int
        Number of points to generate.
    dt : float
        Integration step; must be positive.
    plot_at : sequence of float
        Times at which to plot. It is neither required to be sorted nor
        modified. Several times reached within the same step share one plot.
        An empty sequence makes the function return without doing anything.

    Raises
    ------
    ValueError
        If `dt` is not positive.
    """
    if dt <= 0:
        raise ValueError(f"dt must be positive, got {dt}")
    pending = sorted(plot_at)  # private, ordered copy of the requested times
    if not pending:
        return
    x = np.random.randn(n, 2)
    # Time is computed as step * dt (not accumulated with +=) and compared
    # with a small tolerance, so rounding drift cannot add or skip a step.
    eps = 1e-6 * dt
    step = 0
    while True:
        t = step * dt
        if t >= pending[0] - eps:
            # consume every requested time reached at this step
            while pending and t >= pending[0] - eps:
                pending.pop(0)
            plt.figure(figsize=(3, 3))
            plt.scatter(x[:, 0], x[:, 1], s=20, alpha=0.2)
            plt.title(f"t={t:g}")
            plt.grid(alpha=0.5)
            plt.tight_layout()
            plt.show()
            if not pending:
                break
        # explicit Euler step using the velocity at the current time
        x += dt * v.predict(np.concatenate([x, np.full((n, 1), t)], axis=-1))
        step += 1

view_generation(velocity_estimator)




## Challenge: do the same with torch

(the main structure is given as a guide)

In [None]:
# Create a torch model, (x,y,t) -> (vx, vy)
# you can use an MLP for instance

In [None]:
class Velocity(nn.Module):
    """Velocity-field network mapping rows (x, y, t) to rows (vx, vy).

    `self.model` is left undefined (None) as an exercise. As long as it is
    None, `forward` returns random values of the expected output shape so
    that code calling the network can still be run.
    """

    def __init__(self):
        super().__init__()
        # TODO BUILD AN MLP MODEL (e.g. 3->10->20->20->2, with ELU activations)
        self.model = None

    def forward(self, z):
        """Return the predicted velocity for each row of `z`.

        Parameters
        ----------
        z : torch.Tensor
            Batch of inputs; only `z.shape[0]` is used while the model is
            undefined, otherwise `z` is passed to `self.model` unchanged.

        Returns
        -------
        torch.Tensor
            `self.model(z)` once the model is defined, otherwise a
            (z.shape[0], 2) tensor of standard-normal noise.
        """
        if self.model is None:
            # placeholder output until the TODO in __init__ is completed
            return torch.randn(z.shape[0], 2)
        return self.model(z)


<pre class="solution">
self.model = nn.Sequential(
    nn.Linear(3, 10), nn.ELU(),
    nn.Linear(10, 20), nn.ELU(),
    nn.Linear(20, 20), nn.ELU(),
    nn.Linear(20, 2),
)  
</pre>

In [None]:
# Create a model, an optimizer (only one, unlike a GAN, which has 2 networks), and a data loader

In [None]:
# taking inspiration from GANs

In [None]:
# the velocity network to train
velocity = Velocity()

n_epochs = 100  # number of iterations of the outer training loop below
batch_size = 32  # size of the batch

##### TODO init
# Both placeholders must be replaced before running the training loop, which
# calls `optimizer.zero_grad()` / `optimizer.step()` and iterates over
# `dataloader` to get the minibatches of real data.
optimizer = None
dataloader = None


<pre class="solution">
optimizer = torch.optim.Adam(velocity.parameters(), lr=0.01)
dataloader = DataLoader(X_train, batch_size, shuffle=True)
</pre>

In [None]:
# Make your training loop
# At each step,
# - draw a minibatch of z, that is a minibatch of pairs (x0,x1), as before, but with a minibatch
# - draw a minibatch of t values between 0 and 1
# - compute x as the linear interpolation between x0 and x1 (at time t)
# - compute u = x1 - x0
# - compute a l2 loss on the prediction ||network([x, t]) - u||²

In [None]:
# taking inspiration from GANs and from above

In [None]:
# Training loop scaffold: the TODO below must define `u_cond` and `u_pred`
# before the loss line can run.
for epoch in range(n_epochs):
    total_train_loss = 0  # running sum of the batch losses for this epoch
    for i, x in enumerate(dataloader):
        
        x1 = x.type(torch.float32)  # real data

        optimizer.zero_grad()  # reset the gradients before this step's backward pass

        # TODO
        # ...
        u_cond = None  # placeholder: conditional (target) velocity for the batch

        u_pred = None  # placeholder: velocity predicted by the network

        
        # squared error summed (not averaged) over the batch
        loss = torch.sum((u_pred - u_cond)**2)
        total_train_loss += loss.item()

        loss.backward()
        optimizer.step()

    if (epoch+1) % 10 == 0:
        # Summed loss divided by nb_samples: a per-sample average, assuming the
        # dataloader yields the nb_samples training points once per epoch.
        print(epoch+1, total_train_loss / nb_samples)


<pre class="solution">
x0 = torch.randn_like(x1).type(torch.float32)
t = torch.rand(x1.shape[0], 1).type(torch.float32)
x = x0 * (1-t) + x1 * t
u_cond = x1 - x0
u_pred = velocity(torch.concatenate([x, t], axis=-1))
</pre>

In [None]:
# Question:
#
# The loss is expressed in squared spatial units per unit of time:
# a value of 2 means, roughly, that integrating from t=0 to t=1 gives an error of about sqrt(2) ≈ 1.4 spatial units (bigger than our box).
#
# Is the loss good?
# Is this consistent with whether it works (below)? How can we explain that it does?

In [None]:
# wrapping your model to reuse the above visualizations

In [None]:
class Torch2Sklearn:
    """Adapter exposing a torch model through a `predict(array) -> array` method."""

    def __init__(self, model):
        """Store the torch callable that `predict` will evaluate."""
        self.model = model

    def predict(self, x):
        """Evaluate the wrapped model on `x` and return the result as a NumPy array.

        `x` is converted with `torch.Tensor(x)` and the model is run with
        gradient tracking disabled.
        """
        inputs = torch.Tensor(x)
        with torch.no_grad():
            outputs = self.model(inputs)
        return outputs.numpy()

In [None]:
wrapped = Torch2Sklearn(velocity)

In [None]:
view_field(wrapped, 0.1)
view_field(wrapped, 0.9)

In [None]:
view_generation(wrapped)