In [None]:
%%capture
try:
    import google.colab  # noqa: F401

    IN_COLAB = True
except ImportError:
    IN_COLAB = False

if IN_COLAB:
    %%capture
    from torch import __version__ as TORCH_VERSION

    print(f"{TORCH_VERSION=}")

    !pip install -U git+https://github.com/CompRhys/aviary.git  # install aviary
    !wget -O taata.json.gz https://figshare.com/ndownloader/files/34423997

In [None]:
import gzip
import json

import pandas as pd
import torch
from pymatgen.analysis.prototypes import (
    count_wyckoff_positions,
    get_protostructure_label_from_spglib,
)
from pymatgen.core import Structure
from sklearn.model_selection import train_test_split as split
from torch.utils.data import DataLoader

from aviary.roost.data import CompositionData
from aviary.roost.data import collate_batch as roost_cb
from aviary.roost.model import Roost
from aviary.utils import results_multitask, train_ensemble

In [None]:
# Read the gzip-compressed JSON dump directly as UTF-8 text; this avoids keeping
# separate bytes and str copies of the whole file in memory.
with gzip.open("taata.json.gz", "rt", encoding="utf-8") as fin:
    data = json.load(fin)

# The JSON payload stores the rows under "data" and the column names under "columns".
df = pd.DataFrame(data["data"], columns=data["columns"])

# Rebuild pymatgen Structure objects from their dict form.
df["final_structure"] = [Structure.from_dict(x) for x in df.final_structure]

# Derive every column on the unfiltered frame. Assigning a new column to a frame
# that was just produced by boolean-mask filtering is what triggers pandas'
# SettingWithCopyWarning, so all assignments happen before any filtering.
df["composition"] = [x.composition.reduced_formula for x in df.final_structure]
df["volume_per_atom"] = [x.volume / len(x) for x in df.final_structure]
df["protostructure"] = df["final_structure"].map(get_protostructure_label_from_spglib)
df["n_sites"] = df["final_structure"].map(len)

# Keep only rows with fewer than 16 Wyckoff positions, fewer than 64 sites
# and a volume per atom below 500.
keep_mask = (
    (df["protostructure"].map(count_wyckoff_positions) < 16)
    & (df["n_sites"] < 64)
    & (df["volume_per_atom"] < 500)
)
df = df[keep_mask]

# NOTE for roost we keep only the lowest lying structures for each composition
df = df.sort_values(["composition", "E_vasp_per_atom"]).drop_duplicates(
    "composition", keep="first"
)

spglib: ssm_get_exact_positions failed.
spglib: get_bravais_exact_positions_and_lattice failed.
spglib: ssm_get_exact_positions failed.
spglib: get_bravais_exact_positions_and_lattice failed.
spglib: ssm_get_exact_positions failed.
spglib: get_bravais_exact_positions_and_lattice failed.
spglib: ssm_get_exact_positions failed.
spglib: get_bravais_exact_positions_and_lattice failed.
spglib: ssm_get_exact_positions failed.
spglib: get_bravais_exact_positions_and_lattice failed.
spglib: ssm_get_exact_positions failed.
spglib: get_bravais_exact_positions_and_lattice failed.
spglib: ssm_get_exact_positions failed.
spglib: get_bravais_exact_positions_and_lattice failed.
spglib: ssm_get_exact_positions failed.
spglib: get_bravais_exact_positions_and_lattice failed.
spglib: ssm_get_exact_positions failed.
spglib: get_bravais_exact_positions_and_lattice failed.
spglib: ssm_get_exact_positions failed.
spglib: get_bravais_exact_positions_and_lattice failed.
spglib: ssm_get_exact_positions failed.


In [None]:
resume = False
fine_tune = None
transfer = None

optim = "AdamW"
learning_rate = 3e-4
momentum = 0.9
weight_decay = 1e-6
batch_size = 128
workers = 0
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

targets = ["E_vasp_per_atom"]
tasks = ["regression"]
losses = ["L1"]
robust = True

data_seed = 42
test_size = 0.2
sample = 1

ensemble = 1
run_id = 1
epochs = 100
log = False

# NOTE setting workers to zero means that the data is loaded in the main
# process and enables caching

data_params = {
    "batch_size": batch_size,
    "num_workers": workers,
    "pin_memory": False,
    "shuffle": True,
}

setup_params = {
    "optim": optim,
    "learning_rate": learning_rate,
    "weight_decay": weight_decay,
    "momentum": momentum,
    "device": device,
}

restart_params = {
    "resume": resume,
    "fine_tune": fine_tune,
    "transfer": transfer,
}

task_dict = dict(zip(targets, tasks, strict=False))
loss_dict = dict(zip(targets, losses, strict=False))

In [None]:
# Seed torch's global RNG so that repeated runs give the same results.
torch.manual_seed(0)

elem_embedding = "matscholar200"
model_name = "roost-reg-test"

# Register the Roost collate function on the shared loader settings; the
# update is done in place so later cells that unpack data_params see it too.
data_params.update(collate_fn=roost_cb, shuffle=True)

dataset = CompositionData(df=df, task_dict=task_dict)
n_targets = dataset.n_targets

train_idx = [*range(len(dataset))]

print(f"using {test_size} of training set as test set")
train_idx, test_idx = split(train_idx, test_size=test_size, random_state=data_seed)
test_set = torch.utils.data.Subset(dataset, test_idx)

print("No validation set used, using test set for evaluation purposes")
# NOTE: with this option, take care not to peek at the test set. The only
# valid model is the one obtained after the final epoch, and the number of
# epochs has to be fixed before the experiment is run.
val_set = test_set

# Optionally thin out the training indices by taking every `sample`-th one.
train_set = torch.utils.data.Subset(dataset, train_idx[::sample])

train_loader = DataLoader(train_set, **data_params)
# Validation uses 16x larger batches and a fixed (unshuffled) order.
val_loader = DataLoader(
    val_set,
    **(data_params | {"batch_size": 16 * data_params["batch_size"], "shuffle": False}),
)

# Keyword arguments forwarded to the model via train_ensemble.
model_params = dict(
    task_dict=task_dict,
    robust=robust,
    n_targets=n_targets,
    elem_embedding=elem_embedding,
    elem_fea_len=64,
    n_graph=3,
    elem_heads=3,
    elem_gate=[256],
    elem_msg=[256],
    cry_heads=3,
    cry_gate=[256],
    cry_msg=[256],
    trunk_hidden=[128, 128],
    out_hidden=[64, 64],
)

train_ensemble(
    model_class=Roost,
    model_name=model_name,
    run_id=run_id,
    ensemble_folds=ensemble,
    epochs=epochs,
    train_loader=train_loader,
    val_loader=val_loader,
    log=log,
    setup_params=setup_params,
    restart_params=restart_params,
    model_params=model_params,
    loss_dict=loss_dict,
)

using 0.2 of training set as test set
No validation set used, using test set for evaluation purposes
Total Number of Trainable Parameters: 973,777
Dummy MAE: 1.2757
Epoch: [0/99]
    train: E_vasp_per_atom N 10 MAE 1.27      Loss 1.12      RMSE 1.59     
 evaluate: E_vasp_per_atom N 1 MAE 1.29      Loss 1.13      RMSE 1.59     
Epoch: [1/99]
    train: E_vasp_per_atom N 10 MAE 1.25      Loss 1.11      RMSE 1.59     
 evaluate: E_vasp_per_atom N 1 MAE 1.25      Loss 1.10      RMSE 1.55     
Epoch: [2/99]
    train: E_vasp_per_atom N 10 MAE 1.17      Loss 1.03      RMSE 1.50     
 evaluate: E_vasp_per_atom N 1 MAE 0.98      Loss 0.87      RMSE 1.30     
Epoch: [3/99]
    train: E_vasp_per_atom N 10 MAE 0.85      Loss 0.74      RMSE 1.23     
 evaluate: E_vasp_per_atom N 1 MAE 0.62      Loss 0.49      RMSE 0.87     
Epoch: [4/99]
    train: E_vasp_per_atom N 10 MAE 0.59      Loss 0.40      RMSE 0.89     
 evaluate: E_vasp_per_atom N 1 MAE 0.44      Loss 0.13      RMSE 0.59     
Epoch: [5/

TypeError: results_multitask() got an unexpected keyword argument 'test_set'

In [None]:
# Loader over the held-out split: same settings as data_params, but with
# 64x larger batches and no shuffling.
test_loader = DataLoader(
    test_set,
    **(data_params | {"batch_size": 64 * data_params["batch_size"], "shuffle": False}),
)

# Evaluate the model(s) of this run on the test loader; save_results=False is
# passed, so the results are only kept in roost_results_dict.
roost_results_dict = results_multitask(
    model_class=Roost,
    model_name=model_name,
    run_id=run_id,
    ensemble_folds=ensemble,
    test_loader=test_loader,
    robust=robust,
    task_dict=task_dict,
    device=device,
    eval_type="checkpoint",
    save_results=False,
)

~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
------------Evaluate model on Test Set------------
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Evaluating Model

Task: target_name='E_vasp_per_atom' on test set
Model Performance Metrics:
R2 Score: 0.9494 
MAE: 0.2701
RMSE: 0.3576
