# Train MENT-Flow

In [None]:
import hydra
import numpy as np
import torch
from omegaconf import DictConfig
from omegaconf import OmegaConf
from pprint import pprint

import mentflow as mf

import setup

In [None]:
# Choose the torch backend at runtime instead of hard-coding "mps":
# the original override only worked on Apple-silicon machines. MPS is still
# preferred when present (matching the original intent), then CUDA, then CPU.
# NOTE(review): assumes the config's `device` field is passed to torch as a
# device string, so "cuda"/"cpu" are accepted like "mps" — confirm in setup.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Compose the experiment config from ../../config/rec_2d_linear_flow.yaml,
# overriding the distribution, measurement noise, training schedule, seed,
# plotting style, number of generated transforms and device, then echo the
# resolved config so the run is self-documenting.
with hydra.initialize(version_base=None, config_path="../../config"):
    cfg = hydra.compose(
        config_name="rec_2d_linear_flow.yaml",
        overrides=[
            "dist.name=swissroll",
            "meas.noise_scale=0.15",
            "meas.noise_type=uniform",
            "train.epochs=20",
            "train.iters=200",
            "train.dmax=0.001",
            "train.penalty_step=5.0",
            "train.penalty_scale=1.1",
            "seed=21",
            "plot.line_kind=line",
            "gen.transforms=3",
            f"device={device}",
        ]
    )
    print(OmegaConf.to_yaml(cfg))

In [None]:
# Stage 1 — data: build the ground-truth distribution, the transforms and
# diagnostics, and the measurements the model will be fit to. The factories
# come from the local `setup` module.
transforms, diagnostics, measurements = setup.generate_training_data(
    cfg,
    make_dist=setup.make_dist,
    make_diagnostics=setup.make_diagnostics,
    make_transforms=setup.make_transforms,
)

# Stage 2 — model: construct the MENT-Flow model from the config together
# with the transforms, diagnostics and measurements produced above.
model = setup.setup_mentflow_model(
    cfg, transforms=transforms, diagnostics=diagnostics, measurements=measurements
)

# Stage 3 — training: run the training loop in notebook mode. With
# output_dir=None, presumably nothing is written to disk (confirm in setup);
# the return value is not kept.
setup.train_mentflow_model(
    cfg,
    model=model,
    setup_plot=setup.setup_plot,
    setup_eval=setup.setup_eval,
    output_dir=None,
    notebook=True,
)