In [11]:
import math

from sklearn.datasets import make_blobs, make_circles, make_classification, make_moons

def generate_data(dataset_type, n_samples=1000, noise=0.1, random_state=None):
    """Generate a 2-D toy dataset.

    Parameters
    ----------
    dataset_type : str
        One of ``'moons'``, ``'circles'``, ``'classification'`` or ``'5gaussians'``.
    n_samples : int, default 1000
        Total number of points to generate.
    noise : float, default 0.1
        Passed as ``noise`` to ``make_moons`` / ``make_circles`` and used as the
        per-cluster standard deviation for ``'5gaussians'``; ignored for
        ``'classification'``.
    random_state : int, numpy.random.RandomState or None, default None
        Forwarded to the underlying sklearn generator so draws are reproducible.
        ``None`` means sklearn falls back to NumPy's global RNG.

    Returns
    -------
    X : ndarray of shape (n_samples, 2)
        The generated points.
    y : ndarray of shape (n_samples,)
        Integer class / component label of each point.

    Raises
    ------
    ValueError
        If ``dataset_type`` is not one of the supported names.
    """
    if dataset_type == 'moons':
        X, y = make_moons(n_samples=n_samples, noise=noise, random_state=random_state)
    elif dataset_type == 'circles':
        X, y = make_circles(n_samples=n_samples, noise=noise, factor=0.5,
                            random_state=random_state)
        # Standardize each coordinate to zero mean and unit variance.
        X = (X - X.mean(axis=0)) / X.std(axis=0)
    elif dataset_type == 'classification':
        X, y = make_classification(n_samples=n_samples, n_features=2, n_informative=2,
                                   n_redundant=0, n_clusters_per_class=1,
                                   random_state=random_state)
    elif dataset_type == '5gaussians':
        # Five isotropic Gaussians whose means are evenly spaced on the unit circle.
        # make_classification cannot express this: it has no `centers` argument, and
        # it requires n_classes * n_clusters_per_class <= 2 ** n_informative (5 > 4).
        centers = [[math.cos(2 * math.pi * i / 5), math.sin(2 * math.pi * i / 5)]
                   for i in range(5)]
        X, y = make_blobs(n_samples=n_samples, centers=centers, cluster_std=noise,
                          random_state=random_state)
    else:
        raise ValueError(f"Unknown dataset type: {dataset_type!r}")
    return X, y

In [16]:
import numpy as np
import torch
from cfm.modules import SimpleFlowModel
from cfm.utils import Trainer

# Seed every RNG this cell relies on so Restart & Run All reproduces the same data,
# initial weights and minibatch order. generate_data() is called without a
# random_state, so sklearn draws from NumPy's global RNG (hence np.random.seed);
# torch.manual_seed covers default parameter initialisation and DataLoader shuffling.
SEED = 0
N_SAMPLES = 10_000
BATCH_SIZE = 64
N_EPOCHS = 200
LEARNING_RATE = 1e-3

np.random.seed(SEED)
torch.manual_seed(SEED)

# Training data: the two-moons point cloud; the labels are discarded.
X, _ = generate_data('moons', n_samples=N_SAMPLES, noise=0.05)
dataset = torch.utils.data.TensorDataset(torch.tensor(X, dtype=torch.float32))
dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# The reload cell below rebuilds this same architecture before loading the
# checkpoint; keep the two configurations in sync.
flow_model = SimpleFlowModel(input_dim=2, time_dim=8, hidden_dim=32)

optimizer = torch.optim.Adam(flow_model.parameters(), lr=LEARNING_RATE)
# NOTE(review): the meaning of sigma / sample_from_coupling / from_random_gaussian is
# defined in cfm.utils.Trainer — presumably sigma is the conditional-path noise level
# and training starts from standard-Gaussian source samples; confirm there.
trainer = Trainer(flow_model, dataloader, n_epochs=N_EPOCHS, sigma=0.005,
                  sample_from_coupling=None, optimizer=optimizer)
trainer.train(from_random_gaussian=True)

torch.save(flow_model.state_dict(), 'flow_model.pth')

Epoch [200/200], Loss: 0.9855: 100%|██████████| 200/200 [00:15<00:00, 12.90it/s]


In [17]:
from cfm.utils import FlowModelPipeline

# Use the GPU when present, but fall back to CPU so the cell also runs on
# machines without a CUDA device.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Must match the configuration used to train `flow_model` above so the saved
# state dict fits this module.
model = SimpleFlowModel(input_dim=2, time_dim=8, hidden_dim=32)
pipeline = FlowModelPipeline.from_pretrained(model, state_dict_path='./flow_model.pth', device=device)

# 100 starting points drawn from a standard 2-D Gaussian.
seeds = torch.randn(100, 2)
samples = pipeline.sample(seeds, n_steps=100)

In [18]:
# Animate the trajectories of the same `seeds` that were passed to pipeline.sample.
animation_steps = 100
pipeline.generate_animation(seeds, n_steps=animation_steps)