In [None]:
from BearRobot.Net.basic_net.mlp import MLP
import torch
from flow_matching.path.scheduler import CondOTScheduler
from flow_matching.path import AffineProbPath
from flow_matching.solver import Solver, ODESolver
from flow_matching.utils import ModelWrapper

import matplotlib.pyplot as plt

steps = int(0.5e+4)

In [None]:
def inf_train_gen(batch_size: int = 200, device: str = "cpu"):
    x1 = torch.rand(batch_size, device=device) * 4 - 2
    x2_ = torch.rand(batch_size, device=device) - torch.randint(high=2, size=(batch_size, ), device=device) * 2
    x2 = x2_ + (torch.floor(x1) % 2)

    data = torch.cat([x1[:, None], x2[:, None]], dim=1) / 0.45
    
    return data.float()

data = inf_train_gen(batch_size=1000)
print(data.shape)

def inf_train_gen_2(batch_size: int = 200, device: str = "cpu"):
    x1 = torch.rand(batch_size, device=device) * 4 - 2
    x2_ = torch.rand(batch_size, device=device) - torch.randint(high=2, size=(batch_size, ), device=device) * 2
    x2 = x2_ + (torch.floor(x1) % 2)

    data =  torch.cat([x1[:, None] * 0.5, x2[:, None]], dim=1) / 0.45
    
    return data.float()

data = inf_train_gen_2(batch_size=1000)
print(data.shape)

# Original

In [None]:
lr=1e-3
steps = int(1e+4)
bs = 4000
show_freq = int(500)

vf = MLP(input_size=3, hidden_sizes=[512, 512], output_size=2).to('cuda')

optim = torch.optim.Adam(params=vf.parameters(), lr=lr)

loss_fn = torch.nn.MSELoss()

loss_list = []
for i in range(steps):
    k = 10
    x1 = inf_train_gen(batch_size=bs, device='cuda')
    # x0 = torch.randn((bs, 2)).to(x1.device)
    x0 = inf_train_gen_2(batch_size=bs, device='cuda')
    noise = torch.randn_like(x0).to(x0.device) * 1
    x0 = noise
    # x0 = x0.reshape(bs, k, 2)
    # x1, x0 = find_closest_a_b_pairs(x1, x0)
    # x0 = select_min_mse(x0, x1)
    
    t = torch.rand((x1.shape[0], 1)).to(x1.device)
    
    xt = (1 - t) * x0 + t * x1
    input = torch.cat([xt, t], dim=1)
    vf_pre = vf(input)
    vf_gt = x1 - x0
    
    optim.zero_grad()
    # weight = torch.sqrt((x1 - x0) ** 2)
    # loss = torch.mean((vf_pre - vf_gt)) # ** 2 * weight) # + torch.mean(torch.linalg.norm(vf_pre, dim=1)) * 0.5
    loss = loss_fn(vf_gt, vf_pre)
    loss.backward()
    optim.step()
    
    if (i + 1) % show_freq == 0:
        loss = loss.detach().cpu().numpy().item()
        loss_list.append(loss)
        print(f"{i} step, loss={loss}")


# sample and draw
# flow_matching
from matplotlib import cm

class WrappedModel(ModelWrapper):
    def forward(self, x: torch.Tensor, t: torch.Tensor, **extras):
        t = t.reshape(1, 1).repeat(x.shape[0], 1)
        input = torch.cat([x, t], dim=1)
        return self.model(input)

wrapped_vf = WrappedModel(vf)

# step size for ode solver
step_size = 0.05

norm = cm.colors.Normalize(vmax=50, vmin=0)

batch_size = 500000  # batch size
T = torch.linspace(0,1,10)  # sample times
T = T.to(device='cuda')

# x_init = torch.randn((batch_size, 2), dtype=torch.float32, device='cuda')
x_init = inf_train_gen_2(batch_size).to('cuda')
noise = torch.randn_like(x_init).to(x_init.device)
x_init = noise

solver = ODESolver(velocity_model=wrapped_vf)  # create an ODESolver class
sol = solver.sample(time_grid=T, x_init=x_init, method='midpoint', step_size=step_size, return_intermediates=False)  # sample from the model


sol = sol.cpu().numpy()
T = T.cpu()

fig, axs = plt.subplots(1, 10,figsize=(20,20))

for i in range(10):
    H= axs[i].hist2d(sol[i,:,0], sol[i,:,1], 300, range=((-5,5), (-5,5)))
    
    cmin = 0.0
    cmax = torch.quantile(torch.from_numpy(H[0]), 0.99).item()
    
    norm = cm.colors.Normalize(vmax=cmax, vmin=cmin)
    
    _ = axs[i].hist2d(sol[i,:,0], sol[i,:,1], 300, range=((-5,5), (-5,5)), norm=norm)
    
    axs[i].set_aspect('equal')
    axs[i].axis('off')
    axs[i].set_title('t= %.2f' % (T[i]))
    
plt.tight_layout()
plt.show()

In [6]:
sol.shape

(500000, 2)

# OT Original

In [None]:
lr=1e-3
steps = int(1e+4)
bs = 4000
show_freq = int(500)

vf = MLP(input_size=3, hidden_sizes=[512, 512], output_size=2).to('cuda')

optim = torch.optim.Adam(params=vf.parameters(), lr=lr)

loss_fn = torch.nn.MSELoss()

loss_list = []
for i in range(steps):
    k = 10
    x1 = inf_train_gen(batch_size=bs, device='cuda')
    # x0 = torch.randn((bs, 2)).to(x1.device)
    x0 = inf_train_gen_2(batch_size=bs, device='cuda')
    # noise = torch.randn_like(x0).to(x0.device) * 1
    # x0 = x0.reshape(bs, k, 2)
    # x1, x0 = find_closest_a_b_pairs(x1, x0)
    # x0 = select_min_mse(x0, x1)
    
    t = torch.rand((x1.shape[0], 1)).to(x1.device)
    
    xt = (1 - t) * x0 + t * x1
    input = torch.cat([xt, t], dim=1)
    vf_pre = vf(input)
    vf_gt = x1 - x0
    
    optim.zero_grad()
    # weight = torch.sqrt((x1 - x0) ** 2)
    # loss = torch.mean((vf_pre - vf_gt)) # ** 2 * weight) # + torch.mean(torch.linalg.norm(vf_pre, dim=1)) * 0.5
    loss = loss_fn(vf_gt, vf_pre)
    loss.backward()
    optim.step()
    
    if (i + 1) % show_freq == 0:
        loss = loss.detach().cpu().numpy().item()
        loss_list.append(loss)
        print(f"{i} step, loss={loss}")

# sample and draw
# flow_matching
from matplotlib import cm

class WrappedModel(ModelWrapper):
    def forward(self, x: torch.Tensor, t: torch.Tensor, **extras):
        t = t.reshape(1, 1).repeat(x.shape[0], 1)
        input = torch.cat([x, t], dim=1)
        return self.model(input)

wrapped_vf = WrappedModel(vf)

# step size for ode solver
step_size = 0.05

norm = cm.colors.Normalize(vmax=50, vmin=0)

batch_size = 500000  # batch size
T = torch.linspace(0,1,10)  # sample times
T = T.to(device='cuda')

# x_init = torch.randn((batch_size, 2), dtype=torch.float32, device='cuda')
x_init = inf_train_gen_2(batch_size).to('cuda')
# noise = torch.randn_like(x_init).to(x_init.device)
# x_init = noise

solver = ODESolver(velocity_model=wrapped_vf)  # create an ODESolver class
sol = solver.sample(time_grid=T, x_init=x_init, method='midpoint', step_size=step_size, return_intermediates=True)  # sample from the model


sol = sol.cpu().numpy()
T = T.cpu()

fig, axs = plt.subplots(1, 10,figsize=(20,20))

for i in range(10):
    H= axs[i].hist2d(sol[i,:,0], sol[i,:,1], 300, range=((-5,5), (-5,5)))
    
    cmin = 0.0
    cmax = torch.quantile(torch.from_numpy(H[0]), 0.99).item()
    
    norm = cm.colors.Normalize(vmax=cmax, vmin=cmin)
    
    _ = axs[i].hist2d(sol[i,:,0], sol[i,:,1], 300, range=((-5,5), (-5,5)), norm=norm)
    
    axs[i].set_aspect('equal')
    axs[i].axis('off')
    axs[i].set_title('t= %.2f' % (T[i]))
    
plt.tight_layout()
plt.show()

# OT minibatch

In [None]:
from torchcfm.optimal_transport import OTPlanSampler
from sklearn.datasets import make_moons, make_circles

ot_sampler = OTPlanSampler(method="exact")
steps=15000
lr=1e-3
bs = 4001
show_freq = int(500)
source_data_type = 'gaussian'
target_data_type = '8square'
noise_scale = 0.
ot_minibatch = False

vf = MLP(input_size=3, hidden_sizes=[512, 512], output_size=2).to('cuda')

optim = torch.optim.Adam(params=vf.parameters(), lr=lr)

loss_fn = torch.nn.MSELoss()

loss_list = []
for i in range(steps):
    # x1 data generation
    if target_data_type == '8square':
        x1 = inf_train_gen(batch_size=bs, device='cuda')
    elif target_data_type == 'circle':
        x1_numpy, _ = make_circles(n_samples=bs, shuffle=True, noise=0, factor=0.2)
        x1 = torch.from_numpy(x1_numpy).to('cuda').to(torch.float32)
    else:
        raise NotImplementedError
    
    
    # x0 data generation
    if source_data_type == 's8square':
        x0 = inf_train_gen_2(batch_size=bs, device='cuda')  # squeezed 8 squre
        noise = torch.randn_like(x0).to(x0.device) * noise_scale
        x0 += noise
    elif source_data_type == 'gaussian':
        x0 = torch.randn_like(x1).to(x1.device)  # gaussian
    elif source_data_type == '2moons':
        x0, _ = make_moons(n_samples=bs, noise=3)  # 2 moons
        x0 = torch.from_numpy(x0).to(x1.device).to(x1.dtype)
    else:
        raise NotImplementedError
        
    if ot_minibatch:
        x0, x1 = ot_sampler.sample_plan(x0, x1) 
    
    t = torch.rand((x1.shape[0], 1)).to(x1.device)
    
    noise_new = torch.randn_like(x1).to(x1.device)
    xt = (1 - t) * x0 + t * x1
    input = torch.cat([xt, t], dim=1)
    vf_pre = vf(input)
    vf_gt = x1 - x0
    
    optim.zero_grad()
    loss = loss_fn(vf_gt, vf_pre)
    loss.backward()
    optim.step()
    
    if (i + 1) % show_freq == 0:
        loss = loss.detach().cpu().numpy().item()
        loss_list.append(loss)
        print(f"{i} step, loss={loss}")

In [None]:
# flow_matching
from matplotlib import cm

class WrappedModel(ModelWrapper):
    def forward(self, x: torch.Tensor, t: torch.Tensor, **extras):
        t = t.reshape(1, 1).repeat(x.shape[0], 1)
        input = torch.cat([x, t], dim=1)
        return self.model(input)

wrapped_vf = WrappedModel(vf)

# step size for ode solver
step_size = 1/10

norm = cm.colors.Normalize(vmax=50, vmin=0)

batch_size = 500000  # batch size
T = torch.linspace(0,1,10)  # sample times
T = T.to(device='cuda')


if source_data_type == 's8square':
    x_init = inf_train_gen_2(batch_size, device='cuda')  # squeezed 8 squre
    noise = torch.randn_like(x_init).to(x_init.device) * noise_scale
    x_init += noise
elif source_data_type == 'gaussian':
    x_init = torch.randn((batch_size, 2), dtype=torch.float32, device='cuda')
elif source_data_type == '2moons':
    x_init, _ = make_moons(n_samples=batch_size, noise=3)  # 2 moons
    x_init = torch.from_numpy(x_init).to('cuda').to(torch.float32)
else:
    raise NotImplementedError

solver = ODESolver(velocity_model=wrapped_vf)  # create an ODESolver class
sol = solver.sample(time_grid=T, x_init=x_init, method='midpoint', step_size=step_size, return_intermediates=True)  # sample from the model


sol = sol.cpu().numpy()
sol_all_data = sol.reshape(-1, 2)
sol_all_data_1 = sol_all_data[:, 0].max()
sol_all_data_2 = sol_all_data[:, 1].max()
T = T.cpu()

fig, axs = plt.subplots(1, 10,figsize=(20,20))

for i in range(10):
    H= axs[i].hist2d(sol[i,:,0], sol[i,:,1], 300, range=((-sol_all_data_1,sol_all_data_1), (-sol_all_data_2,sol_all_data_2)))
    
    cmin = 0.0
    cmax = torch.quantile(torch.from_numpy(H[0]), 0.99).item()
    
    norm = cm.colors.Normalize(vmax=cmax, vmin=cmin)
    
    _ = axs[i].hist2d(sol[i,:,0], sol[i,:,1], 300, range=((-sol_all_data_1,sol_all_data_1), (-sol_all_data_2,sol_all_data_2)), norm=norm)
    
    axs[i].set_aspect('equal')
    axs[i].axis('off')
    axs[i].set_title('t= %.2f' % (T[i]))
    
plt.tight_layout()
plt.show()