# Train NN

In [None]:
import hydra
import numpy as np
import torch
from omegaconf import DictConfig
from omegaconf import OmegaConf

import mentflow as mf

import setup

In [None]:
# Choose the compute device at run time instead of hard-coding "mps": a fixed
# "device=mps" override only works on machines where torch's MPS backend is
# available. MPS is still preferred when present (same behavior as before);
# otherwise fall back to CUDA, then CPU.
# NOTE(review): assumes the downstream setup code hands cfg.device to torch,
# so "cuda"/"cpu" are accepted there as well — confirm in setup.py.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Compose the run configuration from rec_nd_1d_nn.yaml (searched under
# ../config) with Hydra-style "key=value" overrides, then echo it as YAML.
with hydra.initialize(version_base=None, config_path="../config"):
    cfg = hydra.compose(
        config_name="rec_nd_1d_nn.yaml",
        overrides=[
            "d=6",
            f"device={device}",
            "dist.name=rings",
            "meas.num=25",
            "train.batch_size=10000",
            "train.iters=200",
            "train.lr_patience=100",
            "train.penalty=500.0",
            "model.entest=cov",
            "seed=21",
        ],
    )
    print(OmegaConf.to_yaml(cfg))

In [None]:
# Step 1: generate training data from the composed config, using the factory
# callables provided by the project-local `setup` module.
data_factories = dict(
    make_dist=setup.make_dist,
    make_diagnostics=setup.make_diagnostics,
    make_transforms=setup.make_transforms,
)
transforms, diagnostics, measurements = setup.generate_training_data(
    cfg, **data_factories
)

# Step 2: build the model from the config and the generated data.
model = setup.setup_mentflow_model(
    cfg,
    measurements=measurements,
    diagnostics=diagnostics,
    transforms=transforms,
)

# Step 3: train the model. output_dir=None and notebook=True are passed
# through to setup.train_mentflow_model.
# NOTE(review): presumably this means "don't write outputs to disk" and
# "display inline" — confirm in setup.py.
setup.train_mentflow_model(
    cfg,
    model=model,
    output_dir=None,
    notebook=True,
    setup_plot=setup.setup_plot,
    setup_eval=setup.setup_eval,
)