# Train MENT

In [None]:
import hydra
import numpy as np
import proplot as pplt
import torch
from omegaconf import DictConfig
from omegaconf import OmegaConf
from pprint import pprint

import mentflow as mf

import setup

In [None]:
# Command-line-style overrides applied on top of rec_nd_2d_ment.yaml,
# grouped by the config section each key belongs to.
overrides = [
    # Top-level, distribution and measurement settings.
    "d=3",
    "seed=21",
    "dist.name=rings",
    "meas.optics=corner",
    "meas.bins=20",
    # Model settings.
    "model.mode=sample",
    "model.samp.method=grid",
    "model.samp.res=50",
    "model.samp.noise=1.0",
    "model.verbose=true",
    # Training settings.
    "train.omega=1.0",
    "train.batch_size=1000000",
    # Evaluation settings.
    "eval.dist=none",
]

# Compose the config from ../config (Hydra resolves this path relative to
# the caller), then echo the resulting config as YAML.
with hydra.initialize(version_base=None, config_path="../config"):
    cfg = hydra.compose(config_name="rec_nd_2d_ment.yaml", overrides=overrides)
    print(OmegaConf.to_yaml(cfg))

In [None]:
# Generate the training data from `cfg`, using the factory callables
# supplied by the local `setup` module.
transforms, diagnostics, measurements = setup.generate_training_data(
    cfg,
    make_transforms=setup.make_transforms,
    make_diagnostics=setup.make_diagnostics,
    make_dist=setup.make_dist,
)

# Build the MENT model from the generated transforms/diagnostics/measurements.
model = setup.setup_ment_model(
    cfg,
    measurements=measurements,
    diagnostics=diagnostics,
    transforms=transforms,
)

# Train in notebook mode without writing to an output directory.
# Kept as the cell's final bare expression so Jupyter still displays its
# return value (if any), exactly as before.
setup.train_ment_model(
    cfg,
    model=model,
    setup_eval=setup.setup_eval,
    setup_plot=setup.setup_plot,
    notebook=True,
    output_dir=None,
)