In [None]:
#imports
import lightning as L
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import wandb
from lightning.pytorch import Trainer
from lightning.pytorch.loggers import WandbLogger

from {{cookiecutter.repo_name}}.models.flow_matching import FlowMatching, FlowMatchingCFG
from {{cookiecutter.repo_name}}.data.MNIST_datamodule import MNISTDataModule
from {{cookiecutter.repo_name}}.data.moons_datamodule import MoonsDataModule
from {{cookiecutter.repo_name}}.networks.mlp import MLP
from {{cookiecutter.repo_name}}.networks.unet import UNet
from {{cookiecutter.repo_name}}.modules.schedulers import LinearScheduler
from {{cookiecutter.repo_name}}.modules.samplers import GaussianSampler
from {{cookiecutter.repo_name}}.modules.solvers import EulerSolver
from {{cookiecutter.repo_name}}.utils import show_imgs


In [None]:
# Train an unconditional-guidance FlowMatching model on the moons data and
# print samples generated for labels 0 and 1.

# Seed Python, NumPy and torch (and dataloader workers) so that training and
# sampling give the same result on every Restart & Run All.
L.seed_everything(42, workers=True)

moons_dm = MoonsDataModule()

# --- Train Flow Matching Model ---
print("Training Flow Matching Model")
# NOTE(review): MoonsNet is neither imported nor defined in the visible cells,
# so this line raises NameError on a fresh kernel. MLP is imported above but
# unused — confirm which network class is intended and import it.
flow_model_moons = MoonsNet()
flow_model = FlowMatching(
    model=flow_model_moons,
    alpha_beta_scheduler=LinearScheduler(data_dim=2),
    sampler=GaussianSampler(target_shape=(2,)),
)
flow_trainer = Trainer(
    max_epochs=7,
    accelerator="auto",
    devices="auto",
    log_every_n_steps=50,
)
flow_trainer.fit(flow_model, moons_dm)

# Sample from the trained model for labels 0 and 1; steps=50 is presumably the
# number of solver steps — confirm against FlowMatching.generate_samples.
generated_samples = flow_model.generate_samples(steps=50, labels=[0, 1])
print(generated_samples)
print("Flow Matching Model training complete.")

In [None]:
# Train the FlowMatchingCFG variant on the moons data (two classes) and print
# samples generated for labels 0 and 1 with guidance_scale=3.0.

# Seed Python, NumPy and torch (and dataloader workers) so that training and
# sampling give the same result on every Restart & Run All.
L.seed_everything(42, workers=True)

moons_dm = MoonsDataModule()

# --- Train Flow Matching Model (FlowMatchingCFG variant) ---
print("Training Flow Matching Model")
# NOTE(review): MoonsNet is neither imported nor defined in the visible cells,
# so this line raises NameError on a fresh kernel. MLP is imported above but
# unused — confirm which network class is intended and import it.
flow_model_moons = MoonsNet()
flow_model_cfg = FlowMatchingCFG(
    model=flow_model_moons,
    num_classes=2,
    alpha_beta_scheduler=LinearScheduler(data_dim=2),
    sampler=GaussianSampler(target_shape=(2,)),
)
flow_trainer = Trainer(
    max_epochs=17,
    accelerator="auto",
    devices="auto",
    log_every_n_steps=50,
)
flow_trainer.fit(flow_model_cfg, moons_dm)

# guidance_scale is forwarded to generate_samples; its exact semantics are
# defined by FlowMatchingCFG (not visible here).
generated_samples = flow_model_cfg.generate_samples(steps=50, labels=[0, 1], guidance_scale=3.0)
print(generated_samples)
print("Flow Matching Model training complete.")

In [None]:
# Train a FlowMatching model with a UNet backbone on MNIST and display one
# generated image per digit label 0-9.

# Seed Python, NumPy and torch (and dataloader workers) so that training and
# sampling give the same result on every Restart & Run All.
L.seed_everything(42, workers=True)

mnist_dm = MNISTDataModule(num_workers=1)

# --- Train Flow Matching Model ---
print("Training Flow Matching Model")
flow_model_unet = UNet()
# No scheduler or sampler is passed here, unlike the moons cells, so this
# relies on FlowMatching's own defaults — presumably sized for MNIST; confirm.
flow_model = FlowMatching(flow_model_unet)
flow_trainer = Trainer(max_epochs=2, accelerator="auto", devices="auto", log_every_n_steps=50)
flow_trainer.fit(flow_model, mnist_dm)

# Request one sample for each label 0..9.
generated_samples = flow_model.generate_samples(labels=list(range(10)), steps=50)
show_imgs(generated_samples)
print("Flow Matching Model training complete.")

In [None]:
# Train the FlowMatchingCFG variant with a UNet backbone on MNIST and display
# generated images for labels 0-9 plus two samples for label 10.

# Seed Python, NumPy and torch (and dataloader workers) so that training and
# sampling give the same result on every Restart & Run All.
L.seed_everything(42, workers=True)

mnist_dm = MNISTDataModule(num_workers=1)

# --- Train Flow Matching Model (FlowMatchingCFG variant) ---
print("Training Flow Matching Model")
flow_model_unet = UNet()
# No num_classes, scheduler or sampler is passed here, unlike the moons CFG
# cell, so this relies on FlowMatchingCFG's own defaults — confirm they match
# MNIST.
flow_model_cfg = FlowMatchingCFG(flow_model_unet)
flow_trainer = Trainer(max_epochs=2, accelerator="auto", devices="auto", log_every_n_steps=50)
flow_trainer.fit(flow_model_cfg, mnist_dm)

# Labels 0..9 followed by label 10 twice.
# NOTE(review): 10 is outside the MNIST digit range; presumably it is the extra
# "unconditional" class index used by FlowMatchingCFG — confirm.
generated_samples = flow_model_cfg.generate_samples(
    labels=[*range(10), 10, 10],
    steps=50,
)
show_imgs(generated_samples)
print("Flow Matching Model training complete.")