In [1]:
%load_ext autoreload
%autoreload 2
import molpot as mpot
import torch
from ignite.engine.events import Events
from ignite.metrics import MeanAbsoluteError, MetricUsage

import logging
logging.basicConfig(level=logging.INFO)

ModuleNotFoundError: No module named 'molpot.logging'

In [2]:
def define_training_process(node=64, depth=4, nbasis=10, cutoff=4.5, device="cuda"):
    """Assemble a PiNet2 energy/force potential together with its trainer.

    Args:
        node: Width used for every entry of ``pi_nodes``, ``ii_nodes`` and
            ``pp_nodes`` of the PiNet2 model and for the hidden layers of the
            energy readout.
        depth: Forwarded as ``depth`` to ``mpot.nnp.PiNet2``.
        nbasis: First argument of ``GaussianRBF`` (presumably the number of
            radial basis functions -- confirm against molpot).
        cutoff: Value shared by ``GaussianRBF`` and ``CosineCutoff``. It used
            to be the literal 4.5 repeated in both calls; keeping it in one
            place prevents the two from drifting apart.
        device: Device name handed to ``mpot.Config.set_device``. Defaults to
            ``"cuda"``, which is what was hard-coded before.

    Returns:
        The ``mpot.PotentialTrainer`` after ``compile()``, with the LR and
        loss-weight schedulers attached and the ``e_mae`` / ``f_mae`` metrics
        registered.
    """
    config = mpot.Config()
    config.set_device(device)

    # Model: PiNet2 trunk, a summed batch-wise energy readout and a pair-force
    # readout, chained into one sequential potential.
    pinet = mpot.nnp.PiNet2(
        depth=depth,
        basis_fn=mpot.nnp.radial.GaussianRBF(nbasis, cutoff),
        cutoff_fn=mpot.nnp.cutoff.CosineCutoff(cutoff),
        pi_nodes=[node] * 2,
        ii_nodes=[node] * 4,
        pp_nodes=[node] * 4,
        activation=torch.nn.Tanh(),
    )
    e_readout = mpot.nnp.base.Batchwise(
        n_neurons=[node, node, 1],
        in_key=("pinet", "p1"),
        out_key="energy",
        reduce="sum",
    )
    f_readout = mpot.nnp.base.PairForce(in_key=("pinet", "p1"), out_key="forces")
    potential = mpot.potential.PotentialSeq(pinet, e_readout, f_readout)

    # Optimisation: Adam with an exponentially decaying learning rate.
    optimizer = torch.optim.Adam(potential.parameters(), lr=1e-4)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.994)

    # Loss: MSE on energy and forces; the trailing number is presumably the
    # term weight (1.0 vs 10.0) -- confirm against mpot.Constraint.add.
    loss_fn = mpot.Constraint()
    loss_fn.add("energy", torch.nn.MSELoss(), "energy", "energy", 1.0)
    loss_fn.add("forces", torch.nn.MSELoss(), "forces", "forces", 10.0)
    lwscheduler = mpot.engine.loss.ExponentialLW(
        loss_fn.get_constraint("forces"), gamma=0.99
    )

    trainer = mpot.PotentialTrainer(
        model=potential,
        optimizer=optimizer,
        loss_fn=loss_fn,
        device=config.device,
    )
    trainer.compile()
    trainer.add_lr_scheduler(scheduler)
    trainer.add_lw_scheduler(lwscheduler)
    # trainer.add_checkpoint("ckpt")

    # Training metrics are reset/reported on a 100-iteration cadence,
    # evaluation metrics once per epoch.
    train_metric_usage = MetricUsage(
        started=Events.ITERATION_STARTED(every=100),
        iteration_completed=Events.ITERATION_COMPLETED,
        completed=Events.ITERATION_COMPLETED(every=100),
    )
    eval_metric_usage = MetricUsage(
        started=Events.EPOCH_STARTED,
        iteration_completed=Events.ITERATION_COMPLETED,
        completed=Events.EPOCH_COMPLETED,
    )
    trainer.set_metric_usage(
        trainer=train_metric_usage, evaluator=eval_metric_usage
    )

    def _mae_factory(target):
        """Return a zero-argument factory building an MAE metric for ``target``."""

        def _pick(output):
            return output["predicts", target], output["labels", target]

        return lambda: MeanAbsoluteError(
            output_transform=_pick, device=config.device
        )

    # One registration per predicted quantity instead of two copy-pasted calls.
    for metric_name, target in (("e_mae", "energy"), ("f_mae", "forces")):
        trainer.add_metric(metric_name, _mae_factory(target))

    # trainer.attach_tensorboard(log_dir="tblog")
    return trainer


In [None]:
def define_dataloader(ds_path, molecule, total=1000, cutoff=5.0, seed=None):
    """Build the training and evaluation data loaders for one rMD17 molecule.

    Args:
        ds_path: Location passed to ``mpot.dataset.rMD17``.
        molecule: Molecule name passed to ``mpot.dataset.rMD17``.
        total: Forwarded as ``total`` to ``rmd17_ds.prepare`` (default 1000,
            the value that used to be hard-coded).
        cutoff: Cutoff of the ``NeighborList`` preprocessing step (default
            5.0, the value that used to be hard-coded).
        seed: Optional integer seed for the 95 % / 5 % random split. With the
            default ``None`` the split uses torch's global RNG exactly as
            before; passing a seed makes the split reproducible across runs.

    Returns:
        ``(train_dl, eval_dl)``: loaders with batch size 10 and 1 respectively.
    """
    rmd17_ds = mpot.dataset.rMD17(ds_path, molecule=molecule)
    rmd17_ds.prepare(
        total=total, preprocess=[mpot.process.NeighborList(cutoff=cutoff)]
    )

    # Only hand random_split a generator when a seed was requested, so the
    # unseeded call is identical to the previous behaviour.
    split_kwargs = {}
    if seed is not None:
        split_kwargs["generator"] = torch.Generator().manual_seed(seed)
    train_ds, eval_ds = torch.utils.data.random_split(
        rmd17_ds, [0.95, 0.05], **split_kwargs
    )

    train_dl = mpot.DataLoader(train_ds, batch_size=10)
    eval_dl = mpot.DataLoader(eval_ds, batch_size=1)
    return train_dl, eval_dl

In [None]:
class TrainModel(mpot.App):
    """App whose commands train the potential on one rMD17 molecule."""

    def cmd_run(self, ds_path: str, molecule: str, max_steps: int):
        """Train with explicitly supplied arguments.

        Args:
            ds_path: Dataset location forwarded to ``define_dataloader``.
            molecule: Molecule name forwarded to ``define_dataloader``.
            max_steps: Step budget for ``trainer.run``. Coerced with ``int``
                so a numeric string (e.g. ``"1000"``) is accepted as well.
        """
        train_dl, eval_dl = define_dataloader(ds_path, molecule)
        # define_training_process takes hyperparameters (node, depth, nbasis);
        # the loaders used to be passed here and were silently bound to
        # `node` and `depth`. The loaders belong to trainer.run only.
        trainer = define_training_process()
        trainer.run(train_data=train_dl, max_steps=int(max_steps), eval_data=eval_dl)

    def cmd_run_with_config(self, config_path):
        """Train using ``ds_path``, ``molecule`` and ``max_steps`` read from a config.

        Args:
            config_path: Passed to ``self.load_config``; the result is indexed
                with the keys ``'ds_path'``, ``'molecule'`` and ``'max_steps'``.
        """
        config = self.load_config(config_path)
        train_dl, eval_dl = define_dataloader(config['ds_path'], config['molecule'])
        # Same fix as in cmd_run: do not pass the loaders as hyperparameters.
        trainer = define_training_process()
        trainer.run(train_data=train_dl, max_steps=config['max_steps'], eval_data=eval_dl)

In [None]:
app = TrainModel()
# cmd_run declares max_steps as int, so pass the number 1000 rather than the
# string "1000".
app.cmd_run("/workspaces/train_pot/data/rmd17", "aspirin", 1000)

In [None]:
import optuna

# Module-level slot that the `objective` closure in TuningModel.cmd_run
# overwrites (via `global state`) with the value returned by the most recent
# trainer.run call, so it can be inspected from other cells afterwards.
state = None


class TuningModel(mpot.App):
    """App that searches over the network width ``node`` with Optuna."""

    def cmd_run(self, ds_path: str, molecule: str, n_trials: int):
        """Run ``n_trials`` Optuna trials and return the study.

        The study is named after the molecule and stored in the local SQLite
        file ``tuning.db``; an existing study of the same name is reused.
        Each trial trains a fresh trainer for 101 steps without evaluation
        data and reports the ``e_mae`` metric of the resulting state.
        """
        study = optuna.create_study(
            study_name=f"tuning_{molecule}",
            storage="sqlite:///tuning.db",
            load_if_exists=True,
        )

        # The loaders are built once and shared by every trial.
        train_loader, _eval_loader = define_dataloader(ds_path, molecule)

        def objective(trial):
            # Publish the latest run state in the module-level `state`.
            global state
            width = trial.suggest_int("node", 16, 64, step=16)
            state = define_training_process(node=width).run(
                train_data=train_loader, max_steps=101, eval_data=None
            )
            return state.metrics["e_mae"]

        study.optimize(objective, n_trials=n_trials)
        return study


# Launch a five-trial tuning run for aspirin and keep the resulting study.
app = TuningModel()
study = app.cmd_run(
    ds_path="/workspaces/train_pot/data/rmd17",
    molecule="aspirin",
    n_trials=5,
)

In [None]:
# Display Optuna's optimization-history figure for `study`, the object
# returned by TuningModel.cmd_run in the tuning cell.
optuna.visualization.plot_optimization_history(study)

In [None]:
import burr

class BatchTrainModel(mpot.App):
    """Placeholder app: the body is only ``...``, no commands are defined yet."""
    ...
