In [None]:
%%capture
from torch import __version__ as TORCH_VERSION
!pip install pymatgen pybtex retrying  # install requirements to query sample data
!pip install torch-scatter -f https://data.pyg.org/whl/torch-{TORCH_VERSION}.html  # install torch scatter for aviary
!pip install -U git+https://github.com/CompRhys/aviary.git  # install aviary

In [None]:
import torch
import pandas as pd

from pymatgen.ext.optimade import OptimadeRester

from sklearn.model_selection import train_test_split as split

from aviary.utils import results_multitask, train_ensemble
from aviary.wren.utils import get_aflow_label_spglib, count_wyks

from aviary.roost.model import Roost
from aviary.roost.data import CompositionData, collate_batch as roost_cb

In [None]:
# Query the OPTIMADE provider registered under the alias "mp" for sample structures.
# NOTE(review): presumably `elements=["N"]` with `nelements=2` selects binary
# nitrogen-containing compounds -- confirm against the OptimadeRester docs.
provider_aliases = ["mp"]
opt = OptimadeRester(provider_aliases)
results = opt.get_structures(nelements=2, elements=["N"])

In [None]:

# One row per returned entry; the "mp" column (named after the queried
# provider alias) is renamed to the column name used in the rest of the notebook.
df = pd.DataFrame(results).rename(columns={"mp": "final_structure"})

# Derive the per-structure columns used downstream: the reduced formula,
# the volume divided by the number of sites, and the label produced by
# `get_aflow_label_spglib`.
structures = df["final_structure"]
df = df.assign(
    composition=structures.apply(lambda s: s.composition.reduced_formula),
    volume_per_atom=structures.apply(lambda s: s.volume / len(s)),
    wyckoff=structures.apply(get_aflow_label_spglib),
)

# Keep only rows that pass all three size/volume cut-offs.
wyckoff_count_ok = df["wyckoff"].apply(count_wyks) < 16
site_count_ok = df["final_structure"].apply(len) < 64
volume_ok = df["volume_per_atom"] < 500
df = df[wyckoff_count_ok & site_count_ok & volume_ok]

# Turn the index into a regular "material_id" column.
df = df.rename_axis("material_id").reset_index()

In [None]:
# ---- restart options ----
resume = False
fine_tune = None
transfer = None

# ---- optimiser and data-loading settings ----
optim = "AdamW"
learning_rate = 3e-4
momentum = 0.9
weight_decay = 1e-6
batch_size = 128
# NOTE setting workers to zero means that the data is loaded in the main
# process and enables caching
workers = 0
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---- targets: one task type and one loss per target (paired by position) ----
targets = ["volume_per_atom"]
tasks = ["regression"]
losses = ["L1"]
robust = True

# ---- data splitting ----
data_seed = 42
test_size = 0.2
sample = 1

# ---- run bookkeeping ----
ensemble = 1
run_id = 1
epochs = 100
log = False

# Bundle the settings above into the dictionaries passed to the training
# and evaluation helpers later in the notebook.
data_params = dict(
    batch_size=batch_size,
    num_workers=workers,
    pin_memory=False,
    shuffle=True,
)

setup_params = dict(
    optim=optim,
    learning_rate=learning_rate,
    weight_decay=weight_decay,
    momentum=momentum,
    device=device,
)

restart_params = dict(
    resume=resume,
    fine_tune=fine_tune,
    transfer=transfer,
)

# Map each target name to its task type / loss name.
task_dict = {target: task for target, task in zip(targets, tasks)}
loss_dict = {target: loss for target, loss in zip(targets, losses)}

In [None]:
torch.manual_seed(0)  # ensure reproducible results

model_name = "roost-reg-test"
elem_emb = "matscholar200"

# Set the loader options this cell relies on. `shuffle` is switched off for
# evaluation at the end of this cell, so it is re-enabled here to keep the
# cell re-runnable.
data_params.update(collate_fn=roost_cb, shuffle=True)

# Wrap the dataframe in the Roost dataset and read off the sizes the model needs.
dataset = CompositionData(df=df, elem_emb=elem_emb, task_dict=task_dict)
n_targets = dataset.n_targets
elem_emb_len = dataset.elem_emb_len

# Seeded random split of the dataset indices into train and test parts.
print(f"using {test_size} of training set as test set")
all_idx = list(range(len(dataset)))
train_idx, test_idx = split(all_idx, test_size=test_size, random_state=data_seed)
test_set = torch.utils.data.Subset(dataset, test_idx)

print("No validation set used, using test set for evaluation purposes")
# NOTE that when using this option care must be taken not to
# peek at the test-set. The only valid model to use is the one
# obtained after the final epoch where the epoch count is
# decided in advance of the experiment.
val_set = test_set

# Take every `sample`-th training index (sample = 1 keeps all of them).
train_set = torch.utils.data.Subset(dataset, train_idx[0::sample])

# Constructor arguments for the Roost model.
model_params = dict(
    # task definition and input/output sizes taken from the dataset
    task_dict=task_dict,
    robust=robust,
    n_targets=n_targets,
    elem_emb_len=elem_emb_len,
    # network hyperparameters
    elem_fea_len=64,
    n_graph=3,
    elem_heads=3,
    elem_gate=[256],
    elem_msg=[256],
    cry_heads=3,
    cry_gate=[256],
    cry_msg=[256],
    trunk_hidden=[128, 128],
    out_hidden=[64, 64],
)

# Train the ensemble; the test subset doubles as the validation set here.
train_ensemble(
    # what to train and how it is identified
    model_class=Roost,
    model_name=model_name,
    run_id=run_id,
    ensemble_folds=ensemble,
    model_params=model_params,
    # data
    train_set=train_set,
    val_set=val_set,
    data_params=data_params,
    # optimisation
    epochs=epochs,
    loss_dict=loss_dict,
    setup_params=setup_params,
    restart_params=restart_params,
    log=log,
)

# need fixed data order due to ensembling
data_params.update(shuffle=False)

# Evaluate the trained model(s) on the held-out test subset.
roost_results_dict = results_multitask(
    # which model(s) to evaluate
    model_class=Roost,
    model_name=model_name,
    run_id=run_id,
    ensemble_folds=ensemble,
    eval_type="checkpoint",
    # data to evaluate on
    test_set=test_set,
    data_params=data_params,
    # task configuration and execution options
    task_dict=task_dict,
    robust=robust,
    device=device,
    save_results=False,
)
