# Model Training

In this tutorial we are going to train a model from scratch on a molecular dataset from the MD17 collection.
Start by creating a project folder and downloading the dataset.

## Acquiring a dataset

You can obtain the benzene dataset with DFT labels either by running the following code or by downloading it manually from this [link](http://www.quantum-machine.org/gdml/data/xyz/benzene2018_dft.zip). Apax uses ASE to read datasets, so make sure to convert your own data into an ASE-readable format (extxyz, traj, etc.). Be careful: the downloaded dataset has to be modified before it can be read, as is done by the `apax.utils.datasets.mod_md17` function.

In [None]:
from pathlib import Path
from apax.utils.datasets import download_md17_benzene_DFT, mod_md17

data_path = Path("project")

file_path = download_md17_benzene_DFT(data_path)
file_path = mod_md17(file_path)



## Configuration files

Next, we require a configuration file that specifies the model and training parameters.
To get users up and running quickly, our command line interface provides an easy way to generate input templates.
The provided templates come in two levels of verbosity: minimal and full.
In the following we are going to use a minimal input file. To see a complete list and explanation of all parameters, consult the documentation page LINK.
For more information on the CLI, simply run `apax -h`.

In [None]:
!apax -h

The following command creates a minimal configuration file in the working directory.

In [None]:
!apax template train

Open the resulting `config.yaml` file in an editor of your choice and fill in the data path field with the path to the dataset you just downloaded.
For the purposes of this tutorial we will train on 1000 data points and validate the model on 200 more during training. The units of the labels also have to be specified. Apax splits the data randomly, but it is also possible to supply pre-split training and validation datasets.
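
To make the random split concrete, here is a minimal sketch of drawing disjoint training and validation indices. It only illustrates the idea and is not apax's implementation. The function name `split_indices`, the seed, and the dataset size of 5000 are assumptions chosen for this example.

```python
import random


def split_indices(n_total, n_train, n_valid, seed=0):
    """Return disjoint, randomly chosen train and validation index lists."""
    if n_train + n_valid > n_total:
        raise ValueError("n_train + n_valid exceeds the dataset size")
    rng = random.Random(seed)
    # Shuffle all indices once, then cut two non-overlapping slices.
    shuffled = rng.sample(range(n_total), n_total)
    return shuffled[:n_train], shuffled[n_train:n_train + n_valid]


# Hypothetical dataset size; 1000 / 200 match the numbers used in this tutorial.
train_idx, valid_idx = split_indices(n_total=5000, n_train=1000, n_valid=200)
print(len(train_idx), len(valid_idx))
```

Because both slices come from the same shuffled list, no structure can end up in both the training and the validation set.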

The filled-in configuration file should look similar to this one:

```yaml
n_epochs: 1000
data:
    data_path: md17.extxyz
    n_train: 1000
    energy_unit: kcal/mol
    pos_unit: Ang
    ....
```

The configuration file can also be modified with the `mod_config` utility function provided by apax.


In [None]:
from apax.utils.helpers import mod_config
import yaml


config_path = Path("config.yaml")

config_updates = {
    "n_epochs": 10,
    "data": {
        "experiment": "benzene_dft_cli",
        "directory": "project/models",
        "data_path": str(file_path),
        "energy_unit": "kcal/mol",
        "pos_unit": "Ang",
    }
}
config_dict = mod_config(config_path, config_updates)

with open("config.yaml", "w") as conf:
    yaml.dump(config_dict, conf, default_flow_style=False)
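
The update dictionary above is nested, so only the listed keys are meant to change. The following sketch shows one common way such a nested update can be merged into an existing configuration. It is an illustration under that assumption, not the source of `mod_config`. The helper `deep_update` and the example dictionaries are hypothetical.

```python
def deep_update(base, updates):
    """Return a copy of `base` with `updates` merged in recursively."""
    merged = dict(base)
    for key, value in updates.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            # Both sides are dicts: merge them instead of overwriting.
            merged[key] = deep_update(merged[key], value)
        else:
            merged[key] = value
    return merged


# Hypothetical template and updates, loosely modelled on the config above.
template = {"n_epochs": 1000, "data": {"data_path": "<PATH>", "n_train": 1000}}
changes = {"n_epochs": 10, "data": {"data_path": "md17.extxyz"}}
print(deep_update(template, changes))
```

With a recursive merge, `n_train` survives even though the update only mentions `data_path`. A plain `dict.update` would replace the whole `data` section.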



To check whether a configuration file is valid, we provide the `validate` command. This is especially convenient when submitting training runs on a compute cluster.


In [None]:
!apax validate train config.yaml

Configuration files are validated using Pydantic, and the errors reported by the `validate` command give precise instructions on how to fix the input file.
For example, after setting `n_epochs` to `-1000`, `validate` gives the following feedback:

In [None]:
config_updates = {
    "n_epochs": -1000,
}
config_dict = mod_config(config_path, config_updates)

with open("error_config.yaml", "w") as conf:
    yaml.dump(config_dict, conf, default_flow_style=False)

In [None]:
!apax validate train error_config.yaml
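
As a rough illustration of how a Pydantic model rejects such a value, consider the sketch below. `TrainSketch` and `validation_errors` are hypothetical names used only here. The real apax schema has many more fields and may declare `n_epochs` differently.

```python
from pydantic import BaseModel, PositiveInt, ValidationError


class TrainSketch(BaseModel):
    """Hypothetical one-field stand-in for a training config schema."""

    n_epochs: PositiveInt


def validation_errors(config):
    """Return the locations of all fields that fail validation."""
    try:
        TrainSketch(**config)
    except ValidationError as err:
        return [error["loc"] for error in err.errors()]
    return []


print(validation_errors({"n_epochs": -1000}))
print(validation_errors({"n_epochs": 10}))
```

Each reported error carries the location of the offending field, which is what allows a validator to point directly at the line of the input file that needs fixing.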

## Training

Model training can be started by running

In [None]:
!apax train config.yaml



During training, apax displays a progress bar that tracks the validation loss.
The progress bar is optional and can be turned off in the config. LINK
The default configuration writes training metrics to a CSV file, but TensorBoard is also supported.
One can specify which to use by adding the following section to the input file:

```yaml
callbacks:
    - CSV
```

If training is interrupted for any reason, re-running the above `train` command will resume training from the latest checkpoint.

Furthermore, an apax training run can easily be started from within a script.

In [None]:
from apax.train.run import run

config_path = Path("config.yaml")

config_updates = {
    "n_epochs": 100,
    "data": {
        "experiment": "benzene_dft_script",
        "directory": "project/models",
        "data_path": str(file_path),
        "energy_unit": "kcal/mol",
        "pos_unit": "Ang",
    }
}

config_dict = mod_config(config_path, config_updates)

run(config_dict)

In [None]:
import csv
import matplotlib.pyplot as plt
import numpy as np


path = "project/models/benzene_dft_script/log.csv"

keys = ["energy_mae", "forces_mse", "forces_mae", "loss"]
data_dict = {}

with open(path, 'r') as file:
    reader = csv.reader(file)

    # Extract the headers (keys) from the first row
    headers = next(reader)

    # Initialize empty lists for each key
    for header in headers:
        data_dict[header] = []

    # Read the rest of the rows and append values to the corresponding key
    for row in reader:
        for idx, value in enumerate(row):
            key = headers[idx]
            data_dict[key].append(float(value))

for key in keys:
    fig, axes = plt.subplots(1, 2, sharey=True, sharex=True, figsize=(18, 4))

    val = np.array(data_dict[f"val_{key}"])
    train = np.array(data_dict[f"train_{key}"])
    epoch = np.array(data_dict["epoch"])

    axes[0].plot(epoch, val)
    axes[1].plot(epoch, train)

    axes[0].set_ylabel(f"val_{key}")
    axes[0].set_xlabel(r"epoch")
    axes[1].set_ylabel(f"train_{key}")
    axes[1].set_xlabel(r"epoch")
    fig.show()
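
Besides plotting, it is often useful to know at which epoch the validation loss was lowest. The sketch below reads the same kind of CSV log and returns that epoch. It assumes the `epoch` and `val_loss` column names used in the plotting cell above. The helper `best_epoch` is a hypothetical name and not part of apax.

```python
import csv


def best_epoch(log_path, metric="val_loss"):
    """Return (epoch, value) for the row with the smallest `metric`."""
    with open(log_path, newline="") as file:
        rows = list(csv.DictReader(file))
    if not rows:
        raise ValueError(f"{log_path} contains no data rows")
    best = min(rows, key=lambda row: float(row[metric]))
    # The epoch may be logged as "3" or "3.0"; accept both.
    return int(float(best["epoch"])), float(best[metric])
```

For the run above, it could be called as `best_epoch("project/models/benzene_dft_script/log.csv")`.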

## Evaluation

After training is completed and we are satisfied with our choice of hyperparameters and the validation loss, we can evaluate the model on the test set.
We provide a separate command for test set evaluation:


In [None]:
!apax evaluate config.yaml


TODO pretty print results to the terminal

Congratulations, you have successfully trained and evaluated your first apax model!

## A Closer Look At Training Parameters

TODO

To remove all the created files and clean up your working directory, run

In [None]:
!rm -r project config.yaml error_config.yaml