In [None]:
%%capture
# Environment setup cell: installs the Python dependencies and downloads the
# dataset file (taata.json.gz) that the data-loading cell reads.
# NOTE: %%capture suppresses all output of this cell, so pip/wget failures are
# not shown here; they only surface later as import or file-read errors.
# Remove %%capture when debugging the environment.
# NOTE(review): no package versions are pinned, so results may drift between
# runs; consider pinning, and consider `%pip` instead of `!pip` so installs
# target the running kernel's environment.
# The installed torch version is interpolated into the torch-scatter wheel
# index URL below so a matching prebuilt wheel is selected.
from torch import __version__ as TORCH_VERSION
# NOTE(review): matminer is not imported by any cell visible here; confirm it is still needed.
!pip install matminer # install requirements to query sample data
!pip install torch-scatter -f https://data.pyg.org/whl/torch-{TORCH_VERSION}.html  # install torch scatter for aviary
!pip install -U git+https://github.com/CompRhys/aviary.git  # install aviary
# Download the gzipped JSON dataset from figshare to taata.json.gz.
!wget -O taata.json.gz https://figshare.com/ndownloader/files/34423997

In [None]:
import torch
import pandas as pd

from sklearn.model_selection import train_test_split as split
from pymatgen.core import Structure

from aviary.utils import results_multitask, train_ensemble
from aviary.wren.utils import get_aflow_label_from_spglib, count_wyckoff_positions

from aviary.wren.model import Wren
from aviary.wren.data import WyckoffData, collate_batch as wren_cb


In [None]:
# Load the downloaded dataset and rebuild pymatgen Structure objects from the
# dict representation stored in the JSON file.
df = pd.read_json("taata.json.gz")
df["final_structure"] = [Structure.from_dict(x) for x in df.final_structure]

# Derived per-structure columns: reduced formula, volume per atom, the Wyckoff
# label from get_aflow_label_from_spglib, and the number of sites.
df["composition"] = [x.composition.reduced_formula for x in df.final_structure]
df["volume_per_atom"] = [x.volume / len(x) for x in df.final_structure]
df["wyckoff"] = df["final_structure"].map(get_aflow_label_from_spglib)
df["n_sites"] = df.final_structure.map(len)

# Filtering thresholds (exclusive upper bounds).
MAX_WYCKOFF_POSITIONS = 16
MAX_N_SITES = 64
MAX_VOLUME_PER_ATOM = 500

# Build one boolean mask and filter once, after all columns exist. Adding a
# column to an already boolean-filtered frame triggers pandas'
# SettingWithCopyWarning; `.copy()` also makes `df` an independent frame for
# the cells below.
keep = (
    (df.wyckoff.map(count_wyckoff_positions) < MAX_WYCKOFF_POSITIONS)
    & (df.n_sites < MAX_N_SITES)
    & (df.volume_per_atom < MAX_VOLUME_PER_ATOM)
)
df = df[keep].copy()
assert len(df) > 0, "all structures were removed by the filters"

In [None]:
# Experiment configuration. The values below are bundled into data_params,
# setup_params and restart_params, which the training/evaluation cell passes
# on to aviary.

# Restart options (forwarded unchanged as restart_params).
resume, fine_tune, transfer = False, None, None

# Optimiser settings (forwarded as setup_params).
optim = "AdamW"
learning_rate = 3e-4
# NOTE(review): momentum is forwarded regardless of optimiser; confirm whether
# aviary uses it for AdamW or only for momentum-based optimisers.
momentum = 0.9
weight_decay = 1e-6
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Data-loader settings. With zero workers the data is loaded in the main
# process, which enables caching.
batch_size = 128
workers = 0

# Prediction targets: the three lists are paired up by position.
targets = ["E_vasp_per_atom"]
tasks = ["regression"]
losses = ["L1"]
robust = True

# Train/test split: fixed seed, test fraction, and training-index stride
# (1 keeps every training index).
data_seed = 42
test_size = 0.2
sample = 1

# Run bookkeeping.
ensemble = 1
run_id = 1
epochs = 100
log = False

data_params = dict(
    batch_size=batch_size,
    num_workers=workers,
    pin_memory=False,
    shuffle=True,
)

setup_params = dict(
    optim=optim,
    learning_rate=learning_rate,
    weight_decay=weight_decay,
    momentum=momentum,
    device=device,
)

restart_params = dict(resume=resume, fine_tune=fine_tune, transfer=transfer)

# Map each target name to its task type and loss function.
task_dict = {target: task for target, task in zip(targets, tasks)}
loss_dict = {target: loss for target, loss in zip(targets, losses)}

In [None]:
# Seed torch first so everything downstream that draws from its RNG is reproducible.
torch.manual_seed(0)

# Embedding choices and the name under which the model is saved.
elem_embedding = "matscholar200"
sym_emb = "bra-alg-off"
model_name = "wren-reg-test"

# Wren needs its own collate function; shuffle batches during training.
data_params.update(collate_fn=wren_cb, shuffle=True)

dataset = WyckoffData(
    df=df,
    elem_embedding=elem_embedding,
    sym_emb=sym_emb,
    task_dict=task_dict,
)
# Sizes the model needs, as reported by the dataset.
n_targets, elem_emb_len, sym_emb_len = (
    dataset.n_targets,
    dataset.elem_emb_len,
    dataset.sym_emb_len,
)

# Deterministic (seeded) hold-out split over the positional indices of the dataset.
all_idx = list(range(len(dataset)))
print(f"using {test_size} of training set as test set")
train_idx, test_idx = split(all_idx, random_state=data_seed, test_size=test_size)
test_set = torch.utils.data.Subset(dataset, test_idx)

# There is no separate validation split: the test set is reused for validation.
# Care must therefore be taken not to peek at the test set. The only valid model
# is the one obtained after the final epoch, with the epoch count decided before
# the experiment is run.
print("No validation set used, using test set for evaluation purposes")
val_set = test_set

# Optionally thin the training set by keeping every `sample`-th index.
train_set = torch.utils.data.Subset(dataset, train_idx[::sample])

# Wren hyper-parameters. n_targets and the two embedding lengths are taken from
# the dataset attributes; the hidden-layer sizes are fixed choices.
model_params = dict(
    task_dict=task_dict,
    robust=robust,
    n_targets=n_targets,
    elem_emb_len=elem_emb_len,
    elem_fea_len=32,
    sym_emb_len=sym_emb_len,
    sym_fea_len=32,
    n_graph=3,
    elem_heads=1,
    elem_gate=[256],
    elem_msg=[256],
    cry_heads=1,
    cry_gate=[256],
    cry_msg=[256],
    trunk_hidden=[128, 128],
    out_hidden=[64, 64],
)

# Train `ensemble` fold(s) for a fixed number of epochs. Note that val_set is
# the same Subset object as test_set.
train_ensemble(
    # what to train and where it is saved
    model_class=Wren,
    model_name=model_name,
    run_id=run_id,
    ensemble_folds=ensemble,
    model_params=model_params,
    # data
    train_set=train_set,
    val_set=val_set,
    data_params=data_params,
    # optimisation / restart behaviour
    epochs=epochs,
    setup_params=setup_params,
    restart_params=restart_params,
    loss_dict=loss_dict,
    log=log,
)

# Evaluation needs a fixed data order because predictions from the ensemble
# members are combined sample by sample.
data_params["shuffle"] = False

# eval_type="checkpoint": presumably the model saved at the end of training
# rather than the best-on-validation one. Because the validation set is the
# test set, the end-of-training model is the only one that has not been
# selected using test data. Confirm against aviary's results_multitask.
wren_results_dict = results_multitask(
    model_class=Wren,
    model_name=model_name,
    run_id=run_id,
    ensemble_folds=ensemble,
    test_set=test_set,
    data_params=data_params,
    robust=robust,
    task_dict=task_dict,
    device=device,
    eval_type="checkpoint",
    save_results=False,
)
# Backward-compatible alias: the original name said "roost", but these are
# results for the Wren model.
roost_results_dict = wren_results_dict
