# Model Validation / Testing on OE62 from Checkpoint

In [None]:
import torch
import torch_geometric
import logging
from pathlib import Path
from tqdm import tqdm
import os

from ocpmodels import models
from ocpmodels.common import logger
from ocpmodels.common.utils import setup_logging, load_config
from ocpmodels.datasets import LmdbDataset
from ocpmodels.common.registry import registry
from ocpmodels.trainers import EnergyTrainer, ForcesTrainer


# Call the logging setup helper imported from ocpmodels.common.utils
# (presumably configures the log handlers/format used by the trainers --
# confirm against that module).
setup_logging()

# IPython autoreload extension: with mode 2, all modules are re-imported
# before each cell is executed, so local edits to library code are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

### Define model variant by choosing a config file
For each model, the following variants exist: **baseline**, **variant with Ewald message passing**, **increased cutoff** and **increased embedding size**.

Configs to choose from: 
- schnet_oe62_baseline.yml
- schnet_oe62_ewald.yml
- schnet_oe62_cutoff.yml
- schnet_oe62_embeddings.yml
----------------------------
- painn_oe62_baseline.yml
- painn_oe62_ewald.yml
- painn_oe62_cutoff.yml
- painn_oe62_embeddings.yml
----------------------------
- dpp_oe62_baseline.yml
- dpp_oe62_ewald.yml
- dpp_oe62_cutoff.yml
- dpp_oe62_embeddings.yml
----------------------------
- gemnet_oe62_baseline.yml
- gemnet_oe62_ewald.yml
- gemnet_oe62_cutoff.yml
- gemnet_oe62_embeddings.yml
----------------------------

In [None]:
# Directory the OE62 config file is read from.
config_dir = "configs_runtime_oe62"
#-----------Put your model variant here-----------
# This section evaluates on OE62, so the file must be one of the
# ``*_oe62_*.yml`` configs (an OC20 config name was referenced here before,
# which does not belong to this directory/section).
config_path = os.path.join(config_dir, "schnet_oe62_baseline.yml")

### Parse config file and initialize `EnergyTrainer` object for OE62

In [None]:
# Free cached, currently unused GPU memory before a new trainer is built.
torch.cuda.empty_cache()

# The parsed configuration is the first element of what load_config returns.
conf = load_config(config_path)[0]

# Everything the trainer needs lives in the "fixed" section of the config.
task, model, optimizer, name, logger, dataset = (
    conf["fixed"][section]
    for section in ("task", "model", "optimizer", "name", "logger", "dataset")
)

# Trainer used for energy prediction on OE62.
trainer = EnergyTrainer(
    # --- taken from the config file ---
    task=task,
    model=model,
    dataset=dataset,
    optimizer=optimizer,
    identifier=name,
    logger=logger,  # logger of choice (tensorboard and wandb supported)
    # --- run settings fixed for this notebook ---
    run_dir="./",
    is_debug=True,  # if True, do not save checkpoint, logs, or results
    print_every=5000,
    seed=0,  # random seed to use
    local_rank=0,
    amp=False,  # whether to use PyTorch Automatic Mixed Precision
)

### Load checkpoint file

After training your model (using the provided `seml` commands or the training notebook from this repository), paste the path to your checkpoint file below.

In [None]:
# Root folder of the saved runs. The "[your_checkpoint_dir]" segment below is
# a placeholder: replace it with the directory of your own training run.
checkpoint_dir = "checkpoints"
checkpoint_path = os.path.join(
    checkpoint_dir, "[your_checkpoint_dir]", "best_checkpoint.pt"
)

# Load the checkpoint file into the trainer created in the previous cell.
trainer.load_checkpoint(checkpoint_path=checkpoint_path)

### Validate or test model
Replace the argument below with `split="val"` to use the OE62 validation split instead.

In [None]:
# Run the evaluation on the OE62 test split.
metrics = trainer.validate(split="test")

# Each entry of `metrics` is a mapping; only its "metric" value is reported.
results = {key: metrics[key]["metric"] for key in metrics}

print(f"Results for configuration {name}: {results}")

# Model Validation on OC20 from Checkpoint

On OC20, only validation can be done locally. To generate results on the test set, follow the instructions on https://github.com/Open-Catalyst-Project/ocp to obtain files for submission on eval.ai.

### Define model variant by choosing a config file
For each model, the following variants exist: **baseline**, **variant with Ewald message passing**, **increased cutoff**.

Configs to choose from: 
- schnet_oc20_baseline.yml
- schnet_oc20_ewald.yml
- schnet_oc20_cutoff.yml
----------------------------
- painn_oc20_baseline.yml
- painn_oc20_ewald.yml
- painn_oc20_cutoff.yml
----------------------------
- dpp_oc20_baseline.yml
- dpp_oc20_ewald.yml
- dpp_oc20_cutoff.yml
----------------------------
- gemnet_oc20_baseline.yml
- gemnet_oc20_ewald.yml
- gemnet_oc20_cutoff.yml
----------------------------

In [None]:
# Directory the OC20 config file is read from.
config_dir = "configs_oc20"
#-----------Put your model variant here-----------
config_path = os.path.join(
    config_dir,
    "schnet_oc20_baseline.yml",
)

### Parse config file and initialize `ForcesTrainer` object for OC20

In [None]:
# Build ONE ForcesTrainer for OC20 validation.
#
# Constructing a trainer per subsplit in a loop only ever leaves the last one
# bound to `trainer`, so the checkpoint loading and validation cells that
# follow would silently operate on a single, arbitrary subsplit (and the other
# trainers are built for nothing). A single trainer on the combined dataset
# matches the validation cell below, which subsamples from all four splits.
torch.cuda.empty_cache()

# The parsed configuration is the first element of what load_config returns.
conf = load_config(config_path)[0]
task = conf["fixed"]["task"]
model = conf["fixed"]["model"]
optimizer = conf["fixed"]["optimizer"]
name = conf["fixed"]["name"]
logger = conf["fixed"]["logger"]
# Replace dataset_train by dataset_id, dataset_ood_ads, dataset_ood_both or
# dataset_ood_cat to validate only on a particular subsplit (note that the
# validation set subsampling option is currently just available on the
# combination of all four splits, specified by putting dataset_train below)
dataset = conf["fixed"]["dataset_train"]
trainer = ForcesTrainer(
    task=task,
    model=model,
    dataset=dataset,
    optimizer=optimizer,
    identifier=name,
    run_dir="./",
    # This notebook only evaluates an existing checkpoint, so nothing needs to
    # be written to disk (same setting as the OE62 trainer above).
    is_debug=True,  # if True, do not save checkpoint, logs, or results
    print_every=5000,
    seed=0,  # random seed to use
    logger=logger,  # logger of choice (tensorboard and wandb supported)
    local_rank=0,
    amp=False,  # whether to use PyTorch Automatic Mixed Precision
)
# NOTE: the trainer's datasets are deliberately NOT closed here -- the
# validation cell further down still has to read from them.

### Load checkpoint file

After training your model by using the provided `seml` commands, paste the path to your checkpoint file below.

In [None]:
# Root folder of the saved runs. The "[your_checkpoint_dir]" segment below is
# a placeholder: replace it with the directory of your own training run.
checkpoint_dir = "checkpoints"
checkpoint_path = os.path.join(
    checkpoint_dir, "[your_checkpoint_dir]", "best_checkpoint.pt"
)

# Load the checkpoint file into the OC20 trainer created in the previous cell.
trainer.load_checkpoint(checkpoint_path=checkpoint_path)

### Validate model
The setting below validates on a 1% subsample of all validation structures, drawn evenly from all four splits.
Replace the argument below with `split="val"` to use the full validation set (all four splits) instead. As this takes 100x as long, we recommend running it overnight.

In [None]:
# Evaluate on the subsampled validation set.
metrics = trainer.validate(split="val_sub")  # put split="val" for full validation set

# Each entry of `metrics` is a mapping; only its "metric" value is reported.
results = {key: metrics[key]["metric"] for key in metrics}

print(f"Results for configuration {name}: {results}")