In [None]:
%%capture
from torch import __version__ as TORCH_VERSION

print(f"{TORCH_VERSION=}")

!pip install -U git+https://github.com/CompRhys/aviary.git  # install aviary
!wget -O taata.json.gz https://figshare.com/ndownloader/files/34423997

In [None]:
import gzip
import json

import pandas as pd
import torch
from pymatgen.analysis.prototypes import (
    count_wyckoff_positions,
    get_protostructure_label_from_spglib,
)
from pymatgen.core import Structure
from sklearn.model_selection import train_test_split as split

from aviary.utils import results_multitask, train_ensemble
from aviary.wrenformer.data import collate_batch as wrenformer_cb
from aviary.wrenformer.data import df_to_in_mem_dataloader
from aviary.wrenformer.model import Wrenformer

In [None]:
# Read the gzipped JSON dump; text mode decodes the UTF-8 payload on the fly so
# the parsed object can be obtained straight from the file handle.
with gzip.open("taata.json.gz", "rt", encoding="utf-8") as fin:
    data = json.load(fin)

df = pd.DataFrame(data["data"], columns=data["columns"])

# Turn the serialised structure dicts back into pymatgen Structure objects.
df["final_structure"] = [Structure.from_dict(x) for x in df.final_structure]

# Derive every feature column on the full frame *before* filtering. Assigning a
# new column to a frame that was just produced by boolean indexing (as was done
# for "n_sites") triggers pandas' SettingWithCopyWarning.
df["composition"] = [x.composition.reduced_formula for x in df.final_structure]
df["volume_per_atom"] = [x.volume / len(x) for x in df.final_structure]
df["wyckoff"] = df["final_structure"].map(get_protostructure_label_from_spglib)
df["n_sites"] = df.final_structure.map(len)

# Keep only rows with fewer than 16 Wyckoff positions, fewer than 64 sites and a
# volume per atom below 500 (same thresholds as the three sequential filters
# this replaces).
keep_mask = (
    (df.wyckoff.map(count_wyckoff_positions) < 16)
    & (df.n_sites < 64)
    & (df.volume_per_atom < 500)
)
# `.copy()` makes the filtered frame independent of the unfiltered one so later
# cells can modify it without chained-assignment warnings.
df = df[keep_mask].copy()

spglib: ssm_get_exact_positions failed.
spglib: get_bravais_exact_positions_and_lattice failed.
spglib: ssm_get_exact_positions failed.
spglib: get_bravais_exact_positions_and_lattice failed.
spglib: ssm_get_exact_positions failed.
spglib: get_bravais_exact_positions_and_lattice failed.
spglib: ssm_get_exact_positions failed.
spglib: get_bravais_exact_positions_and_lattice failed.
spglib: ssm_get_exact_positions failed.
spglib: get_bravais_exact_positions_and_lattice failed.
spglib: ssm_get_exact_positions failed.
spglib: get_bravais_exact_positions_and_lattice failed.
spglib: ssm_get_exact_positions failed.
spglib: get_bravais_exact_positions_and_lattice failed.
spglib: ssm_get_exact_positions failed.
spglib: get_bravais_exact_positions_and_lattice failed.
spglib: ssm_get_exact_positions failed.
spglib: get_bravais_exact_positions_and_lattice failed.
spglib: ssm_get_exact_positions failed.
spglib: get_bravais_exact_positions_and_lattice failed.
spglib: ssm_get_exact_positions failed.


In [None]:
resume = False
fine_tune = None
transfer = None

optim = "AdamW"
learning_rate = 3e-4
momentum = 0.9
weight_decay = 1e-6
batch_size = 128
workers = 0
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

targets = ["E_vasp_per_atom"]
tasks = ["regression"]
losses = ["L1"]
robust = True

data_seed = 42
test_size = 0.2
sample = 1

ensemble = 1
run_id = 1
epochs = 3
log = False

# NOTE setting workers to zero means that the data is loaded in the main
# process and enables caching

data_params = {
    "batch_size": batch_size,
    "num_workers": workers,
    "pin_memory": False,
    "shuffle": True,
}

setup_params = {
    "optim": optim,
    "learning_rate": learning_rate,
    "weight_decay": weight_decay,
    "momentum": momentum,
    "device": device,
}

restart_params = {
    "resume": resume,
    "fine_tune": fine_tune,
    "transfer": transfer,
}

task_dict = dict(zip(targets, tasks, strict=False))
loss_dict = dict(zip(targets, losses, strict=False))

In [None]:
# Seed torch's global RNG so that results are reproducible.
torch.manual_seed(0)

model_name = "wrenformer-reg-test"
input_col = "wyckoff"
embedding_type = "wyckoff"

data_params.update(collate_fn=wrenformer_cb, shuffle=True)

print(f"using {test_size} of training set as test set")
train_df, test_df = split(df, test_size=test_size, random_state=data_seed)

print("No validation set used, using test set for evaluation purposes")
# NOTE: with this option take care not to peek at the test set. The only
# valid model is the one obtained after the final epoch, and the number of
# epochs has to be fixed before the experiment is run.
val_df = test_df

# Arguments shared by every in-memory data loader built in this cell.
data_loader_kwargs = {
    "id_col": "material_id",  # TODO this should take a list of columns
    "input_col": input_col,
    "target_col": targets[0],  # TODO this should take a list of columns
    "embedding_type": embedding_type,
    "device": device,
}

train_loader = df_to_in_mem_dataloader(
    train_df, batch_size=batch_size, shuffle=True, **data_loader_kwargs
)

# The validation loader is built from the test split (val_df is test_df) and
# uses 16x larger, unshuffled batches.
val_loader = df_to_in_mem_dataloader(
    val_df, batch_size=16 * batch_size, shuffle=False, **data_loader_kwargs
)

# One output per regression task; for any other task type the output count is
# the largest label value in the training split plus one.
n_targets = [
    train_df[col].max() + 1 if kind != "regression" else 1
    for col, kind in task_dict.items()
]

# Width of the input features, read from the last axis of the first entry of
# the first tensor held by the training loader.
n_features = train_loader.tensors[0][0].shape[-1]

model_params = {
    "task_dict": task_dict,
    "robust": robust,
    "n_targets": n_targets,
    "n_features": n_features,
    "d_model": 128,
    "n_attn_layers": 6,
    "n_attn_heads": 4,
    "trunk_hidden": (1024, 512),
    "out_hidden": (256, 128, 64),
    "embedding_aggregations": ("mean",),
}

# Keyword arguments identifying the run, shared by training and evaluation.
run_kwargs = {
    "model_class": Wrenformer,
    "model_name": model_name,
    "run_id": run_id,
    "ensemble_folds": ensemble,
}

# NOTE(review): the output saved with this notebook ends in
# "TypeError: train_ensemble() got an unexpected keyword argument 'train_loader'",
# so the aviary version that was imported did not accept loader arguments --
# confirm the installed aviary matches the API used here.
train_ensemble(
    epochs=epochs,
    train_loader=train_loader,
    val_loader=val_loader,
    log=log,
    setup_params=setup_params,
    restart_params=restart_params,
    model_params=model_params,
    loss_dict=loss_dict,
    **run_kwargs,
)

# Separate loader over the same test split with 64x larger batches.
test_loader = df_to_in_mem_dataloader(
    test_df, batch_size=64 * batch_size, shuffle=False, **data_loader_kwargs
)

# The variable keeps its historical "roost" name although the model evaluated
# here is Wrenformer.
roost_results_dict = results_multitask(
    test_loader=test_loader,
    robust=robust,
    task_dict=task_dict,
    device=device,
    eval_type="checkpoint",
    save_results=False,
    **run_kwargs,
)

using 0.2 of training set as test set
No validation set used, using test set for evaluation purposes


TypeError: train_ensemble() got an unexpected keyword argument 'train_loader'