https://arxiv.org/abs/2505.04486

In [None]:
# # # download and update library
# !git clone https://github.com/Kemsekov/kemsekov_torch
# !cd kemsekov_torch && git pull

In [None]:
from matplotlib import pyplot as plt
import torch

# Toy 2-D data for flow matching: a Gaussian source (domain1) and a
# "three clusters + circle" target shape (domain2), after the swap below.
size = 10000
# Three noisy clusters: integer angles in {0, 2, 4} radians mapped to (sin, cos)
# with 0.1-scaled Gaussian noise per coordinate, then halved.
X = torch.randint(0,3,(size,))*2
# `1.0*X` gives a float tensor so randn_like is valid (X is an integer tensor here).
Y = torch.cos(X)+torch.randn_like(1.0*X)*0.1
X = torch.sin(X)+torch.randn_like(1.0*X)*0.1
domain1 = torch.stack([X,Y],-1)/2

# NOTE(review): `n` is not used by any visible cell.
n=2
# `size` points evenly spaced in angle on the unit circle, plus 0.01-scaled noise.
X = torch.linspace(-torch.pi,torch.pi,size)
Y = torch.cos(X)
X = torch.sin(X)
domain2 = torch.stack([X,Y],-1)
domain2+=torch.randn_like(domain2)*0.01

# Target shape: clusters and circle concatenated (2*size points), scaled by 4.
domain1 = torch.concat([domain1,domain2])*4

# Standard Gaussian noise with the same shape as the target shape.
domain2 = torch.randn_like(domain1)

# Swap names: from here on domain1 is the Gaussian noise and domain2 the target shape.
domain1,domain2 = domain2,domain1

# plt.scatter(*domain1.chunk(2,-1),label='domain1')
# plt.scatter(*domain2.chunk(2,-1),label='domain2')
# plt.tight_layout()
# plt.legend()

In [None]:
import random
import numpy as np
from kemsekov_torch.train import split_dataset

class PairedDataset(torch.utils.data.Dataset):
    """Unpaired dataset: item ``i`` is ``(random d1 sample, d2[i])``.

    ``d2`` is indexed deterministically, and a ``d1`` element is drawn uniformly at
    random (via Python's ``random`` module) on every access, so the two domains do
    not need to be aligned or even have the same length.

    Args:
        domain1: Indexable collection that source samples are drawn from at random.
        domain2: Indexable collection that is iterated by index.
        seed: Seed for the ``ind`` permutation (kept for backward compatibility).
    """
    def __init__(self,domain1,domain2,seed=1):
        # Shuffled index permutation over domain1, kept as a public attribute for
        # backward compatibility (it is not used by __getitem__). A local
        # RandomState yields the same permutation as seeding numpy's global RNG,
        # without clobbering global random state as a side effect.
        self.ind = np.arange(len(domain1))
        np.random.RandomState(seed).shuffle(self.ind)

        self.d1 = domain1
        self.d2 = domain2

    def __getitem__(self, index):
        # randrange(n) draws the same values as randint(0, n - 1).
        ind1 = random.randrange(len(self.d1))
        return self.d1[ind1], self.d2[index]

    def __len__(self):
        # The sampler's index goes into d2, so the length must come from d2;
        # using len(d1) raised IndexError (or skipped data) for unequal domains.
        return len(self.d2)
# Pair every domain2 (target) point with a randomly drawn domain1 (noise) point.
dataset = PairedDataset(domain1,domain2)
# Split and build loaders; `split_dataset` returns
# (train_dataset, test_dataset, train_loader, test_loader) per the unpacking below.
# test_size=0.05 presumably means a 5% held-out fraction — TODO confirm.
train_dataset,test_dataset,train_loader, test_loader = split_dataset(
    dataset,
    test_size=0.05,
    batch_size=64,
    num_workers=16
)

In [None]:
import torch.nn as nn

class FmModel(nn.Module):
    """MLP velocity field for flow matching.

    The network input is the concatenation ``[x, t, f]`` along the last
    dimension, and the output has the same width as ``x``.

    Args:
        input_dim: Width of ``x`` and of the predicted velocity (default 2).
        cond_dim: Width of the conditioning vector ``f`` (default 2).
        hidden_dim: Width of the hidden layers (default 256).

    The defaults reproduce the original ``Linear(5, 256)`` input layer and keep
    the same ``state_dict`` keys, so existing checkpoints still load.
    """
    def __init__(self, input_dim: int = 2, cond_dim: int = 2, hidden_dim: int = 256):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim + 1 + cond_dim, hidden_dim),  # x + time + f_emb
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, input_dim),
        )

    def forward(self, x, f: torch.Tensor, t: torch.Tensor):
        """Predict the velocity at points ``x`` and time ``t`` given condition ``f``.

        Args:
            x: Points with last dimension ``input_dim``; presumably (batch, input_dim).
            f: Conditioning with last dimension ``cond_dim``, batch-aligned with ``x``.
            t: A scalar (Python number or 0-dim tensor), a (batch,) tensor, or a
                (batch, 1) tensor. A scalar is broadcast to every row of ``x``.
        """
        t = torch.as_tensor(t, device=x.device)
        if t.dim() == 0:
            # A single time shared by the whole batch.
            t = t.expand(x.shape[0])
        if t.dim() == 1:
            t = t[:, None]
        xt = torch.concat([x, t, f], -1)
        return self.model(xt)

class VaeModel(nn.Module):
    """MLP variational autoencoder with a diagonal-Gaussian latent.

    ``encode`` maps inputs to ``(mu, logvar)``. ``sample`` draws a latent with the
    reparameterization trick and decodes it. ``forward`` does both.

    Args:
        input_dim: Size of the last dimension of inputs and reconstructions.
        embedding_dim: Latent size; the encoder emits ``2 * embedding_dim`` values
            that are split into ``mu`` and ``logvar``.
        hidden_dim: Width of the hidden layers (default 256, the original value).
    """
    def __init__(self,input_dim=2,embedding_dim=32,hidden_dim=256):
        super().__init__()
        self.enc = nn.Sequential(
            nn.Linear(input_dim,hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),

            nn.Linear(hidden_dim,hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),

            nn.Linear(hidden_dim,2*embedding_dim),
        )
        self.decode = nn.Sequential(
            nn.Linear(embedding_dim,hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),

            nn.Linear(hidden_dim,hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),

            nn.Linear(hidden_dim,input_dim),
        )

    def encode(self,x):
        """Return ``(mu, logvar)``, the two halves of the encoder output's last dim."""
        mu,logvar = self.enc(x).chunk(2,-1)
        return mu,logvar

    def sample(self,mu,logvar,variance_scale : float = 1.0):
        """Decode ``z = mu + eps * std * variance_scale`` with ``eps ~ N(0, I)``.

        ``logvar`` is a log-variance, so ``std = exp(logvar / 2)``. Using
        ``exp(logvar)`` would treat the variance as the standard deviation.
        ``variance_scale`` multiplies the standard deviation.

        NOTE(review): assumes kemsekov_torch's ``kl_divergence`` also treats its
        second argument as a log-variance — confirm.
        """
        std = torch.exp(0.5 * logvar)
        z = mu + torch.randn_like(mu) * std * variance_scale
        return self.decode(z)

    def forward(self,x):
        """Encode ``x``, decode one reparameterized draw; return ``(recon, mu, logvar)``."""
        mu,logvar = self.encode(x)
        s = self.sample(mu,logvar)
        return s,mu,logvar

In [None]:
import torch.nn.functional as F
from kemsekov_torch.train import train
from kemsekov_torch.metrics import r2_score
from kemsekov_torch.common_modules import kl_divergence


beta = 1

def compute_loss_and_metric(model,batch):
    """VAE training objective: reconstruction MSE plus ``beta``-weighted KL.

    Only the second element of ``batch`` (the domain2 points) is used; it is
    both the model input and the reconstruction target.

    Returns:
        ``(loss, metrics)`` with metrics ``'r2'`` (``r2_score(target, reconstruction)``)
        and ``'kl'``.

    NOTE(review): the flow-matching loss later in this file calls
    ``r2_score(pred, true)``, the opposite order. R^2 is not symmetric, so confirm
    which order kemsekov_torch.metrics.r2_score expects.
    """
    _, target = batch
    reconstruction, mu, logvar = model(target)
    kl = kl_divergence(mu, logvar, -1)

    recon_loss = F.mse_loss(target, reconstruction)
    loss = recon_loss + beta * kl

    metrics = {'r2': r2_score(target, reconstruction), 'kl': kl}
    return loss, metrics

# Setup for the VAE whose latent is used as the conditioning `f` in flow matching.
epochs=10
# 2-D inputs -> 2-D latent (VaeModel(input_dim=2, embedding_dim=2)).
vae = VaeModel(2,2)
optim = torch.optim.AdamW(vae.parameters(),1e-3)
# T_max = batches-per-epoch * epochs. This gives one cosine cycle over the run if
# `train` steps the scheduler once per batch — TODO confirm in kemsekov_torch.train.
sch = torch.optim.lr_scheduler.CosineAnnealingLR(optim,len(train_loader)*epochs)

# Training is disabled here; the next cell loads the best checkpoint from
# "runs/flow-matching-vae" instead.
# vae = train(
#     vae,
#     train_loader,
#     test_loader,
#     compute_loss_and_metric,
#     'runs/flow-matching-vae',
#     save_on_metric_improve=['r2'],
#     num_epochs=epochs,
#     optimizer=optim,
#     scheduler=sch,
#     accelerate_args={
#         'mixed_precision':'bf16'
#     },
#     ema_args={
#         'beta':0.999,
#         'power':1
#     }
# )

In [None]:
from kemsekov_torch.train import load_last_checkpoint,load_best_checkpoint
vae = load_best_checkpoint(vae,"runs/flow-matching-vae").cpu().eval()

# Visual check of the trained VAE on the first test batch.
d1, d2 = next(iter(test_loader))

variance_scale = 1
with torch.no_grad():
    mu,var = vae.encode(d2)                      # `var` holds the log-variance (VaeModel.encode)
    sample = vae.sample(mu,var,variance_scale)   # stochastic reconstruction
    decode = vae.decode(mu)                      # deterministic reconstruction from the mean
    generate = vae.decode(torch.randn_like(mu))  # decode latents drawn from N(0, I)

# One scatter panel per tensor, left to right.
panels = [('input', d2), ('decode', decode), ('sample', sample), ('generate', generate)]
plt.figure(figsize=(12,3))
for position, (title, points) in enumerate(panels, start=1):
    plt.subplot(1,4,position)
    plt.scatter(*points.chunk(2,-1))
    plt.title(title)
plt.tight_layout()

In [None]:
from kemsekov_torch.train import train
from kemsekov_torch.metrics import r2_score
from kemsekov_torch.flow_matching import FlowMatching

def get_f(d2,sample_scale=1):
    """Encode ``d2`` with the global ``vae`` and return one reparameterized latent.

    Side effect: moves the global ``vae`` to ``d2.device``. Runs under
    ``torch.no_grad()``, so no gradients flow into the VAE.

    Args:
        d2: Batch of points to encode.
        sample_scale: Multiplier on the latent standard deviation (0 returns ``mu``).

    Returns:
        ``mu + eps * std * sample_scale`` with ``std = exp(logvar / 2)``, matching
        the log-variance naming of ``VaeModel.encode``. ``exp(logvar)`` would be the
        variance, not the standard deviation.
    """
    with torch.no_grad():
        vae.to(d2.device)
        mu,logvar = vae.encode(d2)
        std = torch.exp(0.5 * logvar)
        return mu + torch.randn_like(mu) * std * sample_scale

def mse(pred,target,scale=1):
    """Mean of the element-wise squared error, each term multiplied by ``scale``."""
    squared_error = (pred - target) ** 2
    return (squared_error * scale).mean()

# NOTE(review): `num_epochs` is not used by the visible cells; the training cell
# below defines and uses `epochs` instead.
num_epochs=20
# Alternative FlowMatching configuration (custom function argument), disabled.
# scaler = torch.sin(torch.tensor([1]))
# fm = FlowMatching(lambda x: torch.sin(x)/scaler.to(x.device)/2+1)
fm = FlowMatching()

# Weight of the contrastive term subtracted from the flow-matching loss.
contrast_lambda = 0.1
def compute_loss_and_metric(model,batch):
    """Contrastive flow-matching objective for the velocity model ``model``.

    The batch's first element is used only for its shape: it is replaced by fresh
    Gaussian noise as the source. The second element is the target. The VAE
    conditioning is computed but multiplied by zero, so the model runs unconditioned.

    Loss: ``mse(pred, true) - contrast_lambda * mse(contrast, pred)``.

    Returns:
        ``(loss, {'r2': r2_score(pred_dir, true_dir)})``.
    """
    shape_ref, target = batch
    source = torch.randn_like(shape_ref)
    cond = get_f(target, 0.5)

    def velocity_fn(x, t):
        return model(x, cond * 0, t)

    pred_dir, true_dir, contrast, _ = fm.contrastive_flow_matching_pair(velocity_fn, source, target)

    fm_term = mse(pred_dir, true_dir)
    contrast_term = mse(contrast, pred_dir)
    loss = fm_term - contrast_lambda * contrast_term
    return loss, {'r2': r2_score(pred_dir, true_dir)}

# Train the velocity model `m` with the contrastive flow-matching loss defined above.
epochs=20
m = FmModel()
optim = torch.optim.AdamW(m.parameters(),1e-2)
# T_max = batches-per-epoch * epochs, assuming `train` steps the scheduler once
# per batch — TODO confirm in kemsekov_torch.train.
sch = torch.optim.lr_scheduler.CosineAnnealingLR(optim,len(train_loader)*epochs)

# NOTE(review): presumably checkpoints are saved when 'r2' improves and only one
# is kept (checkpoints_count=1), yet the next cell loads the *last* checkpoint —
# confirm this is intended.
train(
    m,
    train_loader,
    test_loader,
    compute_loss_and_metric,
    'runs/flow-matching',
    save_on_metric_improve=['r2'],
    checkpoints_count=1,
    num_epochs=epochs,
    optimizer=optim,
    scheduler=sch,
    accelerate_args={
        'mixed_precision':'bf16'
    },
    ema_args={
        'beta':0.99,
        'power':1
    }
)

In [None]:
from kemsekov_torch.train import load_last_checkpoint
from itertools import islice

m = load_last_checkpoint(m,"runs/flow-matching").eval().cpu()

# Gather the first `num_batches` training batches into one (d1, d2) pair of tensors.
# islice stops after exactly `num_batches` items, without fetching an extra batch.
num_batches = 5
batches = list(islice(train_loader, num_batches))
d1 = torch.concat([src for src, _ in batches], 0)
d2 = torch.concat([tgt for _, tgt in batches], 0)

In [None]:
# Latent conditioning for the collected targets. It is zeroed inside run_model,
# so sampling is unconditioned, the same as in the training loss (which uses f*0).
f = get_f(d2,1/2)
def run_model(x,t):
    """Velocity wrapper taking (x, t), the form passed to fm.sample."""
    return m(x,f*0,t)

# Transport the noise samples d1 with the learned velocity field. Presumably
# fm.sample integrates over `steps` steps, and churn_scale=0 disables stochastic
# churn — TODO confirm. It returns (final points, list of intermediates).
steps=64
churn_scale=0.00
d2_pred,paths = fm.sample(run_model,d1,steps,churn_scale=churn_scale,return_intermediates=True)
# d1_pred,paths2 = fm.sample(m,d2,steps,churn_scale=churn_scale,inverse=True,return_intermediates=True)

# Compare true target points with generated ones.
# plt.scatter(*d1.chunk(2,-1),label='d1')
plt.scatter(*d2.chunk(2,-1),label='d2')
plt.scatter(*d2_pred.chunk(2,-1),label='d2 pred')
# plt.scatter(*d1_pred.chunk(2,-1),label='d1 pred')
plt.legend()
plt.tight_layout()
plt.show()


from random import shuffle
plt.figure(figsize=(8,8))
# Stack the per-step states and move the step axis second. This assumes each
# entry of `paths` is an (N, 2) tensor, giving (N, steps, 2) — TODO confirm
# against fm.sample.
paths_stack = torch.stack(paths,0).transpose(0,1)

plt.scatter(*paths_stack[:,0,:].chunk(2,-1),c='blue',label='start')
plt.scatter(*paths_stack[:,-1,:].chunk(2,-1),c='orange',label='end')

# Draw about 1 in 11 trajectories (randint(0, 10) == 0) to keep the plot readable.
for trajectory in paths_stack:
    if random.randint(0,10)==0:
        plt.plot(*trajectory.chunk(2,-1),c="gray")
plt.legend()