
# SchNet S2EF training example

The purpose of this notebook is to demonstrate some of the basics of the Open Catalyst Project's (OCP) codebase and data. In this example, we will train a SchNet model to predict the energy and forces of a given structure (the S2EF task). First, ensure you have installed the OCP repo and all its dependencies according to the [README](https://github.com/Open-Catalyst-Project/ocp/blob/master/README.md).

Disclaimer: This notebook is for tutorial purposes; training baseline models on our larger datasets in this format is unlikely to be practical. As a next step, we recommend trying the command line examples.

## Environment Setup

In [3]:
# Clone OCP repo and install PyTorch
!git clone https://github.com/Open-Catalyst-Project/ocp.git 
!pip3 install torch==1.6.0 torchvision==0.7.0 -f https://download.pytorch.org/whl/cu101/torch_stable.html

Cloning into 'ocp'...
remote: Enumerating objects: 89, done.

In [5]:
# Install OCP dependencies
!pip3 install -r ./ocp/docs/source/tutorials/colab_requirements.txt

Collecting aiohttp==3.7.3
  Downloading https://files.pythonhosted.org/packages/ad/e6/d4b6235d776c9b33f853e603efede5aac5a34f71ca9d3877adb30492eb4e/aiohttp-3.7.3-cp36-cp36m-manylinux2014_x86_64.whl (1.3MB)
     |████████████████████████████████| 1.3MB 17.4MB/s
Collecting aiohttp-cors==0.7.0
  Downloading https://files.pythonhosted.org/packages/13/e7/e436a0c0eb5127d8b491a9b83ecd2391c6ff7dcd5548dfaec2080a2340fd/aiohttp_cors-0.7.0-py3-none-any.whl
Collecting aioredis==1.3.1
  Downloading https://files.pythonhosted.org/packages/b0/64/1b1612d0a104f21f80eb4c6e1b6075f2e6aba8e228f46f229cfd3fdac859/aioredis-1.3.1-py3-none-any.whl (65kB)
     |████████████████████████████████| 71kB 11.7MB/s
Collecting APScheduler==3.6.3
  Downloading https://files.pythonhosted.org/packages/f3/34/9ef20ed473c4fd2c3df54ef77a27ae3fc7500b16b192add4720cab8b2c09/APScheduler-3.6.3-py2.py3-none-any.whl (58kB)
     |████████████████████████████████| 61kB 9.6MB/s
Collecting ase==3.19.1

In [6]:
# Install OCP module after we have installed other dependencies 
!python3 -m pip install --upgrade git+https://github.com/Open-Catalyst-Project/ocp.git

Collecting git+https://github.com/Open-Catalyst-Project/ocp.git
  Cloning https://github.com/Open-Catalyst-Project/ocp.git to /tmp/pip-req-build-s7wled6g
  Running command git clone -q https://github.com/Open-Catalyst-Project/ocp.git /tmp/pip-req-build-s7wled6g
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
    Preparing wheel metadata ... done
Building wheels for collected packages: ocp-models
  Building wheel for ocp-models (PEP 517) ... done
  Created wheel for ocp-models: filename=ocp_models-0.0.1-py3-none-any.whl size=95312 sha256=03504c2207c81ee2c89eeaddafe274e97b531ef4c5a73b7b4828e76e33effad8
  Stored in directory: /tmp/pip-ephem-wheel-cache-f590spvr/wheels/c6/10/91/242ef2927f2768612c913b05f02c1f9b28d89de65c25aa0962
Successfully built ocp-models
Installing collected packages: ocp-models
Successfully installed ocp-models-0.0.1


## Imports

In [7]:
import torch
from ocpmodels import models
from ocpmodels.trainers import ForcesTrainer

In [7]:
# A simple sanity check that a GPU is available (prints True or False)
print(torch.cuda.is_available())

True


## The essential steps for training an OCP model

1) Download data

2) Preprocess data (if necessary)

3) Define or load a configuration (config), which includes the following:
   
   - task
   - model
   - optimizer
   - dataset
   - trainer

4) Train

5) Depending on the model/task, there might be an intermediate relaxation step

6) Predict

## Download and preprocess data

In [None]:
!python ./ocp/scripts/download_data.py --task s2ef --split "200k" --get-edges --num-workers 1 --ref-energy

--2021-02-17 01:05:04--  https://dl.fbaipublicfiles.com/opencatalystproject/data/s2ef_train_200K.tar
Resolving dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)... 104.22.75.142, 172.67.9.4, 104.22.74.142, ...
Connecting to dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)|104.22.75.142|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 359720960 (343M) [application/x-tar]
Saving to: ‘s2ef_train_200K.tar’


2021-02-17 01:05:09 (73.1 MB/s) - ‘s2ef_train_200K.tar’ saved [359720960/359720960]

Extracting contents...
Uncompressing s2ef_train_200K/s2ef_train_200K: 100% 80/80 [00:40<00:00,  1.99it/s]
Preprocessing data into LMDBs: 100% 199992/200000 [38:03<00:00, 97.68it/s]

Alternatively, you can download the train S2EF 200K dataset manually by following the instructions [here](https://github.com/Open-Catalyst-Project/ocp/blob/master/README.md#download-the-datasets). This example assumes the data were preprocessed with the `--get-edges` flag, as in the command above.
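Before pointing the trainer at the data, it can help to confirm that the directory actually contains LMDB shards. The helper below is a hypothetical sketch (not part of OCP) and assumes the preprocessing step writes one or more `*.lmdb` files into the target directory; adjust the glob pattern if your layout differs.

```python
from pathlib import Path


def list_lmdb_files(lmdb_dir):
    """Return the sorted *.lmdb files in lmdb_dir, raising if there are none.

    Assumption: preprocessed S2EF data is a directory of *.lmdb shard files
    (lock files such as *.lmdb-lock are ignored by the pattern).
    """
    path = Path(lmdb_dir)
    if not path.is_dir():
        raise FileNotFoundError(f"{path} is not a directory")
    shards = sorted(path.glob("*.lmdb"))
    if not shards:
        raise FileNotFoundError(f"no .lmdb files found in {path}")
    return shards
```

For example, `list_lmdb_files(train_src)` should return a non-empty list once `train_src` is set in the next cell; an exception here usually means the path is wrong.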

In [None]:
# Set the path to your local LMDB directory (update it to wherever the
# preprocessed S2EF 200k train LMDBs were written)
train_src = "/ocp/data/s2ef/200k/train"

## Define config

For this example, we will explicitly define the config; however, a set of default config files exists in the `configs` folder of this repository. Default config YAML files can easily be loaded with the `build_config` util (found in `ocp/ocpmodels/common/utils.py`). Loading a YAML config is preferable when launching jobs from the command line. We have included our best models' config files [here](https://github.com/Open-Catalyst-Project/ocp/tree/master/configs/s2ef).

**Task** 

In [None]:
task = {
    'dataset': 'trajectory_lmdb', # dataset used for the S2EF task
    'description': 'Regressing to energies and forces for DFT trajectories from OCP',
    'type': 'regression',
    'metric': 'mae',
    'labels': ['potential energy'],
    'grad_input': 'atomic forces',
    'train_on_free_atoms': True,
    'eval_on_free_atoms': True
}
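`train_on_free_atoms` and `eval_on_free_atoms` restrict the force loss and metrics to atoms that were free to move during the DFT relaxation; fixed atoms (typically the lower slab layers in OC20) are excluded. The NumPy sketch below shows the idea for one structure. The `fixed` convention (1 = fixed, 0 = free) is an assumption of this sketch, and `forces_mae` is a hypothetical helper, not OCP's metric code.

```python
import numpy as np


def forces_mae(pred, target, fixed=None, free_atoms_only=True):
    """Mean absolute force error over atoms and x/y/z components.

    pred, target: (n_atoms, 3) arrays of forces.
    fixed: optional (n_atoms,) array, 1 for fixed atoms and 0 for free atoms
    (assumed convention). When free_atoms_only is True, fixed atoms are dropped.
    """
    pred = np.asarray(pred, dtype=float)
    target = np.asarray(target, dtype=float)
    if free_atoms_only and fixed is not None:
        mask = np.asarray(fixed) == 0  # keep only free atoms
        pred, target = pred[mask], target[mask]
    return float(np.mean(np.abs(pred - target)))
```

With a large error on a single fixed atom, masking changes the metric substantially, which is why evaluating on free atoms only gives a more meaningful number for adsorbate/surface dynamics.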

**Model** - SchNet for this example

In [None]:
model = {
    'name': 'schnet',
    'hidden_channels': 1024, # if training is too slow for demonstration purposes, reduce the number of hidden channels
    'num_filters': 256,
    'num_interactions': 3,
    'num_gaussians': 200,
    'cutoff': 6.0
}
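SchNet does not feed raw interatomic distances to its filters; it first expands each distance (for atom pairs within `cutoff`) in a basis of `num_gaussians` Gaussians whose centers are spread between 0 and the cutoff. The NumPy sketch below illustrates what those two keys control. It follows the Gaussian smearing used by PyTorch Geometric's SchNet as we understand it; it is an illustration, not OCP's own code.

```python
import numpy as np


def gaussian_smearing(distances, start=0.0, stop=6.0, num_gaussians=200):
    """Expand distances (in Angstrom) in evenly spaced Gaussian basis functions.

    stop corresponds to the 'cutoff' key and num_gaussians to 'num_gaussians'.
    Each Gaussian's width is tied to the spacing between neighboring centers.
    """
    offsets = np.linspace(start, stop, num_gaussians)
    coeff = -0.5 / (offsets[1] - offsets[0]) ** 2
    d = np.asarray(distances, dtype=float).reshape(-1, 1)
    return np.exp(coeff * (d - offsets) ** 2)  # shape (n_distances, num_gaussians)
```

More Gaussians over the same cutoff gives narrower basis functions and a finer-grained view of distances, at the cost of a wider input to the filter network.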

**Optimizer**

In [None]:
optimizer = {
    'batch_size': 16, # if hitting GPU memory issues, lower this
    'eval_batch_size': 8,
    'num_workers': 8,
    'lr_initial': 0.0001,
    'lr_gamma': 0.1,
    'lr_milestones': [15, 20],
    'warmup_epochs': 10,
    'warmup_factor': 0.2,
    'max_epochs': 1, # used for demonstration purposes
    'force_coefficient': 100,
}
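We read the learning-rate keys as a linear warmup from `warmup_factor * lr_initial` to `lr_initial` over `warmup_epochs`, followed by multiplication by `lr_gamma` at each entry of `lr_milestones`. The function below is a sketch of that reading, not OCP's scheduler code; in particular, the real implementation may update per iteration rather than per epoch.

```python
from bisect import bisect_right


def lr_factor(epoch, warmup_epochs=10, warmup_factor=0.2,
              lr_milestones=(15, 20), lr_gamma=0.1):
    """Multiplier applied to lr_initial at a (possibly fractional) epoch."""
    if epoch <= warmup_epochs:
        # Linear ramp from warmup_factor (epoch 0) to 1.0 (end of warmup)
        alpha = epoch / warmup_epochs
        return warmup_factor * (1.0 - alpha) + alpha
    # After warmup: decay by lr_gamma for every milestone already passed
    return lr_gamma ** bisect_right(list(lr_milestones), epoch)
```

Under this reading, with `max_epochs: 1` and `warmup_epochs: 10` the demo run never leaves warmup: at epoch 1 the rate is `0.0001 * lr_factor(1) = 0.0001 * 0.28`, so the milestones only matter for longer runs.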

**Dataset**

For simplicity, `train_src` is used for all of the train/val/test sets. Feel free to replace it with the actual S2EF val and test sets, but doing so requires additional downloads and preprocessing. If you want to normalize your targets, `normalize_labels` must be set to `True`, and the corresponding target means and standard deviations need to be specified. These values have been precomputed for you and can be found in any of the [`base.yml`](https://github.com/Open-Catalyst-Project/ocp/blob/master/configs/s2ef/20M/base.yml#L5-L9) config files.

In [None]:
dataset = [
{'src': train_src, 'normalize_labels': False}, # train set 
{'src': train_src}, # val set (optional)
{'src': train_src} # test set (optional - writes predictions to disk)
]
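With `normalize_labels: True`, energies and forces are, as we understand it, standardized with the precomputed statistics during training and mapped back to physical units for predictions. The sketch below shows a train-set entry and the two transforms. The key names `target_mean`/`target_std` (energy) and `grad_target_mean`/`grad_target_std` (forces) are as we recall them from the linked `base.yml`; please verify them there. All numbers are placeholders, not OCP's values.

```python
import numpy as np

# Placeholder statistics: copy the real values from a base.yml config.
train_entry = {
    "src": "/ocp/data/s2ef/200k/train",
    "normalize_labels": True,
    "target_mean": 0.0,       # energy mean (placeholder)
    "target_std": 1.0,        # energy std (placeholder)
    "grad_target_mean": 0.0,  # force mean (placeholder)
    "grad_target_std": 1.0,   # force std (placeholder)
}


def normalize(x, mean, std):
    """Map raw targets to zero-mean, unit-variance training targets."""
    return (np.asarray(x, dtype=float) - mean) / std


def denormalize(x, mean, std):
    """Map model outputs back to physical units (eV for energy, eV/Å for forces)."""
    return np.asarray(x, dtype=float) * std + mean
```

`train_entry` would take the place of the first element of the `dataset` list above.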


**Trainer**

Use the `ForcesTrainer` for the S2EF and IS2RS tasks, and the `EnergyTrainer` for the IS2RE task.

In [None]:
trainer = ForcesTrainer(
    task=task,
    model=model,
    dataset=dataset,
    optimizer=optimizer,
    identifier="SchNet-example",
    run_dir="./", # directory to save results if is_debug=False. Prediction files are saved here, so be careful not to overwrite them!
    is_debug=False, # if True, do not save checkpoint, logs, or results
    is_vis=False,
    print_every=10,
    seed=0, # random seed to use
    logger="tensorboard", # logger of choice (tensorboard and wandb supported)
    local_rank=0,
    amp=False, # use PyTorch Automatic Mixed Precision (faster training and less memory usage)
)

amp: false
cmd:
  checkpoint_dir: checkpoints/2021-01-06-16-49-04-SchNet-example
  identifier: SchNet-example
  logs_dir: logs/tensorboard/2021-01-06-16-49-04-SchNet-example
  print_every: 10
  results_dir: results/2021-01-06-16-49-04-SchNet-example
  seed: 0
  timestamp: 2021-01-06-16-49-04-SchNet-example
dataset:
  normalize_labels: false
  src: /home/jovyan/projects/ocp/data/s2ef/200k/train-demo
logger: tensorboard
logger_entity: null
logger_project: null
model: schnet
model_attributes:
  cutoff: 6.0
  hidden_channels: 1024
  num_filters: 256
  num_gaussians: 200
  num_interactions: 3
optim:
  batch_size: 16
  eval_batch_size: 8
  force_coefficient: 100
  lr_gamma: 0.1
  lr_initial: 0.0001
  lr_milestones:
  - 15
  - 20
  max_epochs: 1
  num_workers: 64
  warmup_epochs: 10
  warmup_factor: 0.2
task:
  dataset: trajectory_lmdb
  description: Regressing to energies and forces for DFT trajectories from OCP
  eval_on_free_atoms: true
  grad_input: atomic forces
  labels:
  - potential e
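The `cmd` section above shows how the trainer names its output folders: a `YYYY-mm-dd-HH-MM-SS` timestamp joined to the `identifier`, placed under `checkpoints/`, `logs/<logger>/` and `results/` relative to `run_dir`. The hypothetical helper below reproduces that naming, inferred only from the printed paths, so you can anticipate where a run's checkpoint and predictions will land.

```python
from datetime import datetime


def run_dirs(identifier, logger="tensorboard", when=None):
    """Rebuild the run directory names printed under `cmd` (inferred layout)."""
    when = when or datetime.now()
    timestamp = f"{when.strftime('%Y-%m-%d-%H-%M-%S')}-{identifier}"
    return {
        "timestamp": timestamp,
        "checkpoint_dir": f"checkpoints/{timestamp}",
        "logs_dir": f"logs/{logger}/{timestamp}",
        "results_dir": f"results/{timestamp}",
    }
```

Because the timestamp changes on every run, each new trainer writes to a fresh folder; that is why the checkpoint path used later must name the specific run you want to reload.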

## Check the model

In [None]:
print(trainer.model)

OCPDataParallel(
  (module): SchNet(hidden_channels=1024, num_filters=256, num_interactions=3, num_gaussians=200, cutoff=6.0)
)


## Train

In [None]:
trainer.train()

forcesx_mae: 0.7450, forcesy_mae: 0.8024, forcesz_mae: 0.8352, forces_mae: 0.7942, forces_cos: 0.0523, forces_magnitude: 1.4755, energy_mae: 34.4331, energy_force_within_threshold: 0.0000, loss: 113.8546, epoch: 0.0003
forcesx_mae: 0.5871, forcesy_mae: 0.6545, forcesz_mae: 0.6278, forces_mae: 0.6232, forces_cos: 0.0263, forces_magnitude: 1.1090, energy_mae: 32.7792, energy_force_within_threshold: 0.0000, loss: 95.0260, epoch: 0.0035
forcesx_mae: 0.4987, forcesy_mae: 0.5432, forcesz_mae: 0.5586, forces_mae: 0.5335, forces_cos: 0.0030, forces_magnitude: 0.9548, energy_mae: 23.5220, energy_force_within_threshold: 0.0000, loss: 76.8902, epoch: 0.0067
forcesx_mae: 0.3893, forcesy_mae: 0.4586, forcesz_mae: 0.4432, forces_mae: 0.4303, forces_cos: -0.0035, forces_magnitude: 0.7461, energy_mae: 19.4026, energy_force_within_threshold: 0.0000, loss: 62.5090, epoch: 0.0099
forcesx_mae: 0.3424, forcesy_mae: 0.4016, forcesz_mae: 0.4249, forces_mae: 0.3896, forces_cos: 0.0140, forces_magnitude: 0.654

forcesx_mae: 0.1572, forcesy_mae: 0.1844, forcesz_mae: 0.1891, forces_mae: 0.1769, forces_cos: 0.0481, forces_magnitude: 0.2624, energy_mae: 3.4449, energy_force_within_threshold: 0.0000, loss: 21.1611, epoch: 0.1219
forcesx_mae: 0.1572, forcesy_mae: 0.1776, forcesz_mae: 0.1807, forces_mae: 0.1719, forces_cos: 0.0335, forces_magnitude: 0.2534, energy_mae: 3.0279, energy_force_within_threshold: 0.0000, loss: 20.1898, epoch: 0.1251
forcesx_mae: 0.1717, forcesy_mae: 0.1934, forcesz_mae: 0.2013, forces_mae: 0.1888, forces_cos: 0.0335, forces_magnitude: 0.2845, energy_mae: 3.6823, energy_force_within_threshold: 0.0000, loss: 22.6440, epoch: 0.1283
forcesx_mae: 0.1627, forcesy_mae: 0.1934, forcesz_mae: 0.2017, forces_mae: 0.1859, forces_cos: 0.0414, forces_magnitude: 0.2795, energy_mae: 3.6687, energy_force_within_threshold: 0.0000, loss: 22.2938, epoch: 0.1315
forcesx_mae: 0.1689, forcesy_mae: 0.2036, forcesz_mae: 0.1978, forces_mae: 0.1901, forces_cos: 0.0374, forces_magnitude: 0.2884, ene

forcesx_mae: 0.1680, forcesy_mae: 0.1768, forcesz_mae: 0.1716, forces_mae: 0.1721, forces_cos: 0.0381, forces_magnitude: 0.2594, energy_mae: 3.2181, energy_force_within_threshold: 0.0000, loss: 20.5549, epoch: 0.2435
forcesx_mae: 0.1275, forcesy_mae: 0.1469, forcesz_mae: 0.1566, forces_mae: 0.1437, forces_cos: 0.0293, forces_magnitude: 0.2094, energy_mae: 3.1564, energy_force_within_threshold: 0.0000, loss: 17.5119, epoch: 0.2467
forcesx_mae: 0.1334, forcesy_mae: 0.1614, forcesz_mae: 0.1614, forces_mae: 0.1521, forces_cos: 0.0446, forces_magnitude: 0.2214, energy_mae: 3.0059, energy_force_within_threshold: 0.0000, loss: 18.1482, epoch: 0.2499
forcesx_mae: 0.1482, forcesy_mae: 0.1669, forcesz_mae: 0.1801, forces_mae: 0.1650, forces_cos: 0.0295, forces_magnitude: 0.2427, energy_mae: 3.1646, energy_force_within_threshold: 0.0000, loss: 19.4738, epoch: 0.2531
forcesx_mae: 0.1390, forcesy_mae: 0.1679, forcesz_mae: 0.1772, forces_mae: 0.1614, forces_cos: 0.0364, forces_magnitude: 0.2357, ene

forcesx_mae: 0.1246, forcesy_mae: 0.1478, forcesz_mae: 0.1496, forces_mae: 0.1406, forces_cos: 0.0297, forces_magnitude: 0.2039, energy_mae: 3.5981, energy_force_within_threshold: 0.0000, loss: 17.6697, epoch: 0.3651
forcesx_mae: 0.1079, forcesy_mae: 0.1219, forcesz_mae: 0.1235, forces_mae: 0.1178, forces_cos: 0.0413, forces_magnitude: 0.1629, energy_mae: 3.3930, energy_force_within_threshold: 0.0000, loss: 15.1729, epoch: 0.3683
forcesx_mae: 0.1275, forcesy_mae: 0.1389, forcesz_mae: 0.1473, forces_mae: 0.1379, forces_cos: 0.0430, forces_magnitude: 0.1889, energy_mae: 3.6207, energy_force_within_threshold: 0.0000, loss: 17.3980, epoch: 0.3715
forcesx_mae: 0.1220, forcesy_mae: 0.1508, forcesz_mae: 0.1493, forces_mae: 0.1407, forces_cos: 0.0458, forces_magnitude: 0.1933, energy_mae: 3.4300, energy_force_within_threshold: 0.0000, loss: 17.3409, epoch: 0.3747
forcesx_mae: 0.1156, forcesy_mae: 0.1382, forcesz_mae: 0.1418, forces_mae: 0.1319, forces_cos: 0.0160, forces_magnitude: 0.1881, ene

forcesx_mae: 0.1069, forcesy_mae: 0.1163, forcesz_mae: 0.1190, forces_mae: 0.1141, forces_cos: 0.0367, forces_magnitude: 0.1587, energy_mae: 3.4257, energy_force_within_threshold: 0.0000, loss: 14.8809, epoch: 0.4867
forcesx_mae: 0.1068, forcesy_mae: 0.1461, forcesz_mae: 0.1417, forces_mae: 0.1316, forces_cos: 0.0375, forces_magnitude: 0.2006, energy_mae: 3.5625, energy_force_within_threshold: 0.0000, loss: 16.7261, epoch: 0.4899
forcesx_mae: 0.0971, forcesy_mae: 0.1192, forcesz_mae: 0.1329, forces_mae: 0.1164, forces_cos: 0.0240, forces_magnitude: 0.1689, energy_mae: 2.8136, energy_force_within_threshold: 0.0000, loss: 14.4461, epoch: 0.4931
forcesx_mae: 0.1010, forcesy_mae: 0.1302, forcesz_mae: 0.1387, forces_mae: 0.1233, forces_cos: 0.0503, forces_magnitude: 0.1742, energy_mae: 3.0475, energy_force_within_threshold: 0.0000, loss: 15.4902, epoch: 0.4963
forcesx_mae: 0.1098, forcesy_mae: 0.1340, forcesz_mae: 0.1413, forces_mae: 0.1284, forces_cos: 0.0480, forces_magnitude: 0.1877, ene

forcesx_mae: 0.1106, forcesy_mae: 0.1484, forcesz_mae: 0.1428, forces_mae: 0.1340, forces_cos: 0.0514, forces_magnitude: 0.1936, energy_mae: 3.0096, energy_force_within_threshold: 0.0000, loss: 16.3005, epoch: 0.6083
forcesx_mae: 0.0845, forcesy_mae: 0.1040, forcesz_mae: 0.1102, forces_mae: 0.0996, forces_cos: 0.0373, forces_magnitude: 0.1400, energy_mae: 3.1089, energy_force_within_threshold: 0.0000, loss: 13.0339, epoch: 0.6115
forcesx_mae: 0.1051, forcesy_mae: 0.1203, forcesz_mae: 0.1342, forces_mae: 0.1199, forces_cos: 0.0336, forces_magnitude: 0.1762, energy_mae: 3.4943, energy_force_within_threshold: 0.0000, loss: 15.4412, epoch: 0.6147
forcesx_mae: 0.0999, forcesy_mae: 0.1182, forcesz_mae: 0.1298, forces_mae: 0.1160, forces_cos: 0.0763, forces_magnitude: 0.1685, energy_mae: 3.3816, energy_force_within_threshold: 0.0000, loss: 14.9078, epoch: 0.6179
forcesx_mae: 0.0994, forcesy_mae: 0.1274, forcesz_mae: 0.1265, forces_mae: 0.1178, forces_cos: 0.0459, forces_magnitude: 0.1759, ene

forcesx_mae: 0.0885, forcesy_mae: 0.0954, forcesz_mae: 0.0963, forces_mae: 0.0934, forces_cos: 0.0425, forces_magnitude: 0.1254, energy_mae: 2.9841, energy_force_within_threshold: 0.0000, loss: 12.3427, epoch: 0.7299
forcesx_mae: 0.1132, forcesy_mae: 0.1279, forcesz_mae: 0.1598, forces_mae: 0.1336, forces_cos: 0.0574, forces_magnitude: 0.2056, energy_mae: 3.6703, energy_force_within_threshold: 0.0000, loss: 17.1168, epoch: 0.7331
forcesx_mae: 0.0969, forcesy_mae: 0.1117, forcesz_mae: 0.1281, forces_mae: 0.1123, forces_cos: 0.0467, forces_magnitude: 0.1594, energy_mae: 3.2045, energy_force_within_threshold: 0.0000, loss: 14.4537, epoch: 0.7363
forcesx_mae: 0.0985, forcesy_mae: 0.1093, forcesz_mae: 0.1228, forces_mae: 0.1102, forces_cos: 0.0337, forces_magnitude: 0.1560, energy_mae: 3.3397, energy_force_within_threshold: 0.0000, loss: 14.2750, epoch: 0.7395
forcesx_mae: 0.0891, forcesy_mae: 0.1040, forcesz_mae: 0.0997, forces_mae: 0.0976, forces_cos: 0.0513, forces_magnitude: 0.1336, ene

forcesx_mae: 0.0878, forcesy_mae: 0.1075, forcesz_mae: 0.1112, forces_mae: 0.1022, forces_cos: 0.0537, forces_magnitude: 0.1420, energy_mae: 2.7665, energy_force_within_threshold: 0.0000, loss: 13.0340, epoch: 0.8515
forcesx_mae: 0.0804, forcesy_mae: 0.1101, forcesz_mae: 0.1121, forces_mae: 0.1008, forces_cos: 0.0526, forces_magnitude: 0.1497, energy_mae: 3.0201, energy_force_within_threshold: 0.0000, loss: 13.0701, epoch: 0.8547
forcesx_mae: 0.1130, forcesy_mae: 0.1454, forcesz_mae: 0.1515, forces_mae: 0.1366, forces_cos: 0.0413, forces_magnitude: 0.2134, energy_mae: 3.7423, energy_force_within_threshold: 0.0000, loss: 17.4133, epoch: 0.8579
forcesx_mae: 0.1162, forcesy_mae: 0.1405, forcesz_mae: 0.1485, forces_mae: 0.1351, forces_cos: 0.0717, forces_magnitude: 0.2054, energy_mae: 3.0387, energy_force_within_threshold: 0.0000, loss: 16.6642, epoch: 0.8611
forcesx_mae: 0.0963, forcesy_mae: 0.1400, forcesz_mae: 0.1270, forces_mae: 0.1211, forces_cos: 0.0444, forces_magnitude: 0.1858, ene

forcesx_mae: 0.0903, forcesy_mae: 0.1047, forcesz_mae: 0.1131, forces_mae: 0.1027, forces_cos: 0.0740, forces_magnitude: 0.1434, energy_mae: 2.8340, energy_force_within_threshold: 0.0000, loss: 13.0621, epoch: 0.9731
forcesx_mae: 0.1015, forcesy_mae: 0.1113, forcesz_mae: 0.1251, forces_mae: 0.1126, forces_cos: 0.0519, forces_magnitude: 0.1649, energy_mae: 2.8163, energy_force_within_threshold: 0.0000, loss: 14.0856, epoch: 0.9763
forcesx_mae: 0.0894, forcesy_mae: 0.1152, forcesz_mae: 0.1167, forces_mae: 0.1071, forces_cos: 0.0416, forces_magnitude: 0.1568, energy_mae: 3.3290, energy_force_within_threshold: 0.0000, loss: 14.1662, epoch: 0.9795
forcesx_mae: 0.0756, forcesy_mae: 0.0945, forcesz_mae: 0.0963, forces_mae: 0.0888, forces_cos: 0.0476, forces_magnitude: 0.1240, energy_mae: 2.5540, energy_force_within_threshold: 0.0000, loss: 11.4540, epoch: 0.9827
forcesx_mae: 0.0855, forcesy_mae: 0.1059, forcesz_mae: 0.1267, forces_mae: 0.1061, forces_cos: 0.0665, forces_magnitude: 0.1618, ene

device 0: 100%|██████████| 6250/6250 [02:01<00:00, 51.41it/s]

forcesx_mae: 0.0894, forcesy_mae: 0.1088, forcesz_mae: 0.1151, forces_mae: 0.1044, forces_cos: 0.0610, forces_magnitude: 0.1540, energy_mae: 3.2270, energy_force_within_threshold: 0.0000, loss: 13.7160, epoch: 1.0000
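Each log line is a comma-separated list of `key: value` pairs. The hypothetical helper below turns one into a dict of floats, skipping the fragments left by truncated lines. It also makes it easy to confirm that `forces_mae` is the mean of the three component MAEs, e.g. (0.7450 + 0.8024 + 0.8352) / 3 = 0.7942 on the first line.

```python
def parse_metrics(line):
    """Turn a 'key: value, key: value' training log line into a dict of floats."""
    metrics = {}
    for item in line.split(","):
        key, sep, value = item.partition(":")
        if not sep or not value.strip():
            continue  # skip fragments of truncated lines such as "ene"
        metrics[key.strip()] = float(value)
    return metrics
```

Note that a value cut off mid-number (such as `forces_magnitude: 0.654` above) still parses, so check that a line is complete before relying on its last fields.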
### Predicting on test.



device 0: 100%|██████████| 6250/6250 [01:53<00:00, 55.20it/s]


Writing results to results/2021-01-06-16-49-04-SchNet-example/s2ef_predictions.npz


### Load Checkpoint
Once training has completed, the `Trainer` is, by default, loaded with the best checkpoint as determined by training or validation (if available) metrics. To load a `Trainer` directly with a pretrained model, specify the `checkpoint_path` of your previously trained model (inside the `checkpoint_dir` printed when the trainer was created above):

In [None]:
model = {
    'name': 'schnet',
    'hidden_channels': 1024, # if training is too slow for demonstration purposes, reduce the number of hidden channels
    'num_filters': 256,
    'num_interactions': 3,
    'num_gaussians': 200,
    'cutoff': 6.0
}

pretrained_trainer = ForcesTrainer(
    task=task,
    model=model,
    dataset=dataset,
    optimizer=optimizer,
    identifier="SchNet-example",
    run_dir="./", # directory to save results if is_debug=False. Prediction files are saved here, so be careful not to overwrite them!
    is_debug=False, # if True, do not save checkpoint, logs, or results
    is_vis=False,
    print_every=10,
    seed=0, # random seed to use
    logger="tensorboard", # logger of choice (tensorboard and wandb supported)
    local_rank=0,
    amp=False, # use PyTorch Automatic Mixed Precision (faster training and less memory usage)
)

pretrained_trainer.load_pretrained(checkpoint_path="checkpoints/2021-01-06-16-49-04-SchNet-example/checkpoint.pt")

amp: false
cmd:
  checkpoint_dir: checkpoints/2021-01-06-17-23-12-SchNet-example
  identifier: SchNet-example
  logs_dir: logs/tensorboard/2021-01-06-17-23-12-SchNet-example
  print_every: 10
  results_dir: results/2021-01-06-17-23-12-SchNet-example
  seed: 0
  timestamp: 2021-01-06-17-23-12-SchNet-example
dataset:
  normalize_labels: false
  src: /home/jovyan/projects/ocp/data/s2ef/200k/train-demo
logger: tensorboard
logger_entity: null
logger_project: null
model: schnet
model_attributes:
  cutoff: 6.0
  hidden_channels: 1024
  num_filters: 256
  num_gaussians: 200
  num_interactions: 3
optim:
  batch_size: 16
  eval_batch_size: 8
  force_coefficient: 100
  lr_gamma: 0.1
  lr_initial: 0.0001
  lr_milestones:
  - 15
  - 20
  max_epochs: 1
  num_workers: 64
  warmup_epochs: 10
  warmup_factor: 0.2
task:
  dataset: trajectory_lmdb
  description: Regressing to energies and forces for DFT trajectories from OCP
  eval_on_free_atoms: true
  grad_input: atomic forces
  labels:
  - potential e

True
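Hard-coding the timestamped checkpoint path is fragile, since every run creates a new folder. The hypothetical helper below picks the most recent `checkpoint.pt`, assuming the `checkpoints/<timestamp>-<identifier>/checkpoint.pt` layout seen in the `load_pretrained` call above. Because the timestamp comes first in each folder name, lexicographic order is also chronological order.

```python
from pathlib import Path


def latest_checkpoint(checkpoints_root="checkpoints", identifier=None):
    """Return the path of the newest checkpoint.pt under checkpoints_root, or None.

    Optionally keep only run folders whose name ends with '-<identifier>'.
    """
    candidates = sorted(Path(checkpoints_root).glob("*/checkpoint.pt"))
    if identifier is not None:
        candidates = [c for c in candidates if c.parent.name.endswith(f"-{identifier}")]
    return str(candidates[-1]) if candidates else None
```

For example: `pretrained_trainer.load_pretrained(checkpoint_path=latest_checkpoint(identifier="SchNet-example"))`.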

## Predict

If a test set has been provided in your config, predictions are generated and written to disk automatically when training completes. Otherwise, to make predictions on unseen data, a `torch.utils.data` DataLoader object must be constructed. Here we make predictions on our existing test set. Predictions are saved to `s2ef_{results_file}.npz` in your `results_dir`.

In [None]:
# make predictions on the existing test_loader
predictions = pretrained_trainer.predict(pretrained_trainer.test_loader, results_file="s2ef_results", disable_tqdm=False)

### Predicting on test.


device 0: 100%|██████████| 6250/6250 [01:53<00:00, 55.28it/s]


Writing results to results/2021-01-06-17-23-12-SchNet-example/s2ef_s2ef_results.npz
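To inspect the saved archive later without the trainer, NumPy can read it directly. The loader below is a sketch under an assumption this notebook does not confirm: that forces for all systems are stored as one stacked `(total_atoms, 3)` array, with a `chunk_idx` array of split points marking where each system's atoms begin. Check `np.load(path).files` on your own output before relying on these key names.

```python
import numpy as np


def load_s2ef_predictions(path):
    """Load an S2EF predictions .npz into (ids, energies, per-system forces).

    Assumed keys: 'energy', 'forces' (stacked over systems), 'chunk_idx'
    (split points into 'forces'), and optionally 'ids'.
    """
    with np.load(path, allow_pickle=True) as npz:
        energies = npz["energy"]
        forces = np.split(npz["forces"], npz["chunk_idx"])  # one array per system
        ids = npz["ids"] if "ids" in npz.files else None
    return ids, energies, forces
```

For the run above, the path would be `results/2021-01-06-17-23-12-SchNet-example/s2ef_s2ef_results.npz`.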


In [None]:
energies = predictions["energy"]
forces = predictions["forces"]
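A common one-number summary of predicted forces is the largest per-atom force magnitude in each structure, often called `fmax`; ASE optimizers use the same quantity as their convergence criterion (`run(fmax=...)`). The sketch below assumes `forces` is a sequence of `(n_atoms, 3)` arrays, one per structure, which we have not verified for `predict`'s return value; if you get one stacked array instead, split it per system first.

```python
import numpy as np


def max_force_norms(forces_per_system):
    """Largest per-atom force magnitude (fmax) for each structure.

    forces_per_system: iterable of (n_atoms, 3) arrays, one per structure.
    """
    return np.array([
        np.linalg.norm(np.asarray(f, dtype=float), axis=1).max()  # per-atom |F|, then max
        for f in forces_per_system
    ])
```

For example, `max_force_norms(forces)` gives one value per test structure in eV/Å.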