# Train atomic and molecular systems
This notebook can be used to train the atomic and molecular systems H, He and H2.
If you prefer to use a Python script, use `train_atomic_molecular.py` instead.

In [None]:
import torch

In [None]:
import sys
# Make modules two directories up importable so that `continuum` and
# `trainer` (imported in the next cell) can be found. The path is relative
# to the current working directory, presumably the notebook's own
# directory; verify it points at the repository root.
sys.path.append('../../')

In [None]:
import continuum
from trainer import Trainer

### Choose system and make hparams
Each system has its own hparams. The Helium (He) settings are active by default. To train a different system, comment out the He block and uncomment only the hparams settings for the system you would like to train.

In [None]:
# Every system below starts from a private deepcopy of the defaults. Setting
# fields therefore never mutates the shared continuum.DEFAULT_HPARAMS object,
# which would leak into later cells or re-runs of this one.
from copy import deepcopy

# # hparams for Hydrogen atom (H)
# hparams = deepcopy(continuum.DEFAULT_HPARAMS)
# hparams.net = 'DriftResNet'
# hparams.potential = 'h_potential'
# hparams.number_of_particles = 1
# hparams.D = 3
# hparams.H = 256
# hparams.lr = 1e-2

## hparams for Helium atom (He) -- active by default
hparams = deepcopy(continuum.DEFAULT_HPARAMS)
hparams.net = 'PairDriftHelium'
hparams.potential = 'he_potential'
hparams.number_of_particles = 2
hparams.D = 3
hparams.lr = 1e-3

# # hparams for Hydrogen molecule (H2)
# hparams = deepcopy(continuum.DEFAULT_HPARAMS)
# hparams.net = 'PairDriftH2'
# hparams.potential = 'h2_param'
# hparams.R = 1.401  # use the "wide" block below for R = 2.8
# hparams.number_of_particles = 2
# hparams.D = 3
# hparams.H = 64
# hparams.lr = 5e-4

# # hparams for wide Hydrogen molecule (H2)
# hparams = deepcopy(continuum.DEFAULT_HPARAMS)
# hparams.net = 'PairDriftH2'
# hparams.potential = 'h2_param'
# hparams.R = 2.8  # wide H2: larger R than the 1.401 used above
# hparams.number_of_particles = 2
# hparams.D = 3
# hparams.H = 64
# hparams.lr = 1e-3

print(hparams)

## Train model

In [None]:
# Build the model from the hparams selected in the cell above.
model = continuum.Model(hparams)

In [None]:
# Name the run after the system selected in the hparams cell. The resume
# example at the end of this notebook pairs name='H2_wide' with checkpoints
# under results/H2_wide/, so the name appears to choose the results
# directory. A fixed 'He' would file H and H2 runs under the Helium results.
_RUN_NAMES = {
    'h_potential': 'H',
    'he_potential': 'He',
    'h2_param': 'H2',
}
# Fall back to the potential's own name for systems not listed above.
run_name = _RUN_NAMES.get(hparams.potential, hparams.potential)
if run_name == 'H2' and hparams.R == 2.8:
    # Matches the "wide Hydrogen molecule" settings and the 'H2_wide' run name.
    run_name = 'H2_wide'

# GPU index 3 is the device the original runs used; change it to a device
# that exists on your machine (e.g. [0]).
trainer = Trainer(name=run_name, gpus=[3], max_epochs=100)

In [None]:
# Train the model with the trainer configured in the previous cell.
trainer.fit(model)

### Resume training

If you would like to continue training one of our trained models, uncomment and run the following cells, setting `checkpoint_path`, `name`, and `version` to match the run you want to resume.

In [None]:
# checkpoint_path = 'results/H2_wide/version_0/_ckpt_epoch_91.ckpt'
# model = continuum.Model.load_from_checkpoint(checkpoint_path)
# trainer = Trainer(name='H2_wide', gpus=[3],
#                   version=0, resume_from_checkpoint=checkpoint_path,
#                   max_epochs=42)

In [None]:
# trainer.fit(model)