Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/cgcnn-example.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def main( # noqa: C901
weight_decay=1e-6,
batch_size=128,
workers=0,
device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
device="cuda" if torch.cuda.is_available() else "cpu",
**kwargs,
):

Expand Down
2 changes: 1 addition & 1 deletion examples/roost-example.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def main( # noqa: C901
weight_decay=1e-6,
batch_size=128,
workers=0,
device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
device="cuda" if torch.cuda.is_available() else "cpu",
**kwargs,
):
if not len(targets) == len(tasks) == len(losses):
Expand Down
2 changes: 1 addition & 1 deletion examples/wren-example.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def main( # noqa: C901
weight_decay=1e-6,
batch_size=128,
workers=0,
device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
device="cuda" if torch.cuda.is_available() else "cpu",
**kwargs,
):

Expand Down
41 changes: 41 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import os

import pytest
import torch
from matminer.datasets import load_dataset

from aviary.cgcnn.utils import get_cgcnn_input
from aviary.wren.utils import get_aflow_label_spglib

torch.manual_seed(0) # ensure reproducible results (applies to all tests)


@pytest.fixture(scope="session")
def df_matbench_phonons():
"""Return a pandas dataframe with the data from the Matbench phonons dataset."""

df = load_dataset("matbench_phonons")
df[["lattice", "sites"]] = [get_cgcnn_input(x) for x in df.structure]
df["material_id"] = [f"mb_phdos_{i}" for i in range(len(df))]
df["composition"] = [x.composition.formula.replace(" ", "") for x in df.structure]

df["phdos_clf"] = [1 if x > 450 else 0 for x in df["last phdos peak"]]

return df


@pytest.fixture(scope="session")
def df_matbench_phonons_wyckoff(df_matbench_phonons):
"""Getting Aflow labels is expensive so we split into a separate fixture to avoid
paying for it unless needed.
"""
df_matbench_phonons["wyckoff"] = [
get_aflow_label_spglib(x) for x in df_matbench_phonons.structure
]

return df_matbench_phonons


@pytest.fixture(scope="session")
def tests_dir():
return os.path.dirname(os.path.abspath(__file__))
31 changes: 4 additions & 27 deletions tests/test_cgcnn_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,15 @@

import numpy as np
import torch
from matminer.utils.io import load_dataframe_from_json
from sklearn.metrics import accuracy_score, roc_auc_score
from sklearn.model_selection import train_test_split as split

from aviary.cgcnn.data import CrystalGraphData, collate_batch
from aviary.cgcnn.model import CrystalGraphConvNet
from aviary.cgcnn.utils import get_cgcnn_input
from aviary.utils import results_multitask, train_ensemble

torch.manual_seed(0) # ensure reproducible results


def test_cgcnn_clf():
data_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "data/matbench_phonons.json.gz"
)
def test_cgcnn_clf(df_matbench_phonons):
elem_emb = "cgcnn92"
targets = ["phdos_clf"]
tasks = ["classification"]
Expand All @@ -44,26 +37,14 @@ def test_cgcnn_clf():
weight_decay = 1e-6
batch_size = 128
workers = 0
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device = "cuda" if torch.cuda.is_available() else "cpu"

task_dict = dict(zip(targets, tasks))
loss_dict = dict(zip(targets, losses))

assert os.path.exists(data_path), f"{data_path} does not exist!"

df = load_dataframe_from_json(data_path)
df["lattice"] = [None] * len(df)
df["sites"] = [None] * len(df)
df[["lattice", "sites"]] = df.apply(
lambda x: get_cgcnn_input(x.structure), axis=1, result_type="expand"
)
df["material_id"] = [f"mb_phdos_{i}" for i in range(len(df))]
df["composition"] = df.structure.apply(
lambda x: x.composition.formula.replace(" ", "")
dataset = CrystalGraphData(
df=df_matbench_phonons, elem_emb=elem_emb, task_dict=task_dict
)
df["phdos_clf"] = np.where((df["last phdos peak"] > 450), 1, 0)

dataset = CrystalGraphData(df=df, elem_emb=elem_emb, task_dict=task_dict)
n_targets = dataset.n_targets
elem_emb_len = dataset.elem_emb_len
nbr_fea_len = dataset.nbr_fea_dim
Expand Down Expand Up @@ -166,7 +147,3 @@ def test_cgcnn_clf():

assert ens_acc > 0.85
assert ens_roc_auc > 0.9


if __name__ == "__main__":
test_cgcnn_clf()
30 changes: 4 additions & 26 deletions tests/test_cgcnn_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,15 @@

import numpy as np
import torch
from matminer.utils.io import load_dataframe_from_json
from sklearn.metrics import r2_score
from sklearn.model_selection import train_test_split as split

from aviary.cgcnn.data import CrystalGraphData, collate_batch
from aviary.cgcnn.model import CrystalGraphConvNet
from aviary.cgcnn.utils import get_cgcnn_input
from aviary.utils import results_multitask, train_ensemble

torch.manual_seed(0) # ensure reproducible results


def test_cgcnn_regression():
data_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "data/matbench_phonons.json.gz"
)
def test_cgcnn_regression(df_matbench_phonons):
elem_emb = "cgcnn92"
targets = ["last phdos peak"]
tasks = ["regression"]
Expand All @@ -44,25 +37,14 @@ def test_cgcnn_regression():
weight_decay = 1e-6
batch_size = 128
workers = 0
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device = "cuda" if torch.cuda.is_available() else "cpu"

task_dict = dict(zip(targets, tasks))
loss_dict = dict(zip(targets, losses))

assert os.path.exists(data_path), f"{data_path} does not exist!"

df = load_dataframe_from_json(data_path)
df["lattice"] = [None] * len(df)
df["sites"] = [None] * len(df)
df[["lattice", "sites"]] = df.apply(
lambda x: get_cgcnn_input(x.structure), axis=1, result_type="expand"
)
df["material_id"] = [f"mb_phdos_{i}" for i in range(len(df))]
df["composition"] = df.structure.apply(
lambda x: x.composition.formula.replace(" ", "")
dataset = CrystalGraphData(
df=df_matbench_phonons, elem_emb=elem_emb, task_dict=task_dict
)

dataset = CrystalGraphData(df=df, elem_emb=elem_emb, task_dict=task_dict)
n_targets = dataset.n_targets
elem_emb_len = dataset.elem_emb_len
nbr_fea_len = dataset.nbr_fea_dim
Expand Down Expand Up @@ -164,7 +146,3 @@ def test_cgcnn_regression():
assert r2 > 0.7
assert mae < 150
assert rmse < 300


if __name__ == "__main__":
test_cgcnn_regression()
25 changes: 4 additions & 21 deletions tests/test_roost_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,15 @@

import numpy as np
import torch
from matminer.utils.io import load_dataframe_from_json
from sklearn.metrics import accuracy_score, roc_auc_score
from sklearn.model_selection import train_test_split as split

from aviary.roost.data import CompositionData, collate_batch
from aviary.roost.model import Roost
from aviary.utils import results_multitask, train_ensemble

torch.manual_seed(0) # ensure reproducible results


def test_roost_clf():
data_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "data/matbench_phonons.json.gz"
)
def test_roost_clf(df_matbench_phonons):
elem_emb = "matscholar200"
targets = ["phdos_clf"]
tasks = ["classification"]
Expand All @@ -41,21 +35,14 @@ def test_roost_clf():
weight_decay = 1e-6
batch_size = 128
workers = 0
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device = "cuda" if torch.cuda.is_available() else "cpu"

task_dict = dict(zip(targets, tasks))
loss_dict = dict(zip(targets, losses))

assert os.path.exists(data_path), f"{data_path} does not exist!"

df = load_dataframe_from_json(data_path)
df["material_id"] = [f"mb_phdos_{i}" for i in range(len(df))]
df["composition"] = df.structure.apply(
lambda x: x.composition.formula.replace(" ", "")
dataset = CompositionData(
df=df_matbench_phonons, elem_emb=elem_emb, task_dict=task_dict
)
df["phdos_clf"] = np.where((df["last phdos peak"] > 450), 1, 0)

dataset = CompositionData(df=df, elem_emb=elem_emb, task_dict=task_dict)
n_targets = dataset.n_targets
elem_emb_len = dataset.elem_emb_len

Expand Down Expand Up @@ -162,7 +149,3 @@ def test_roost_clf():

assert ens_acc > 0.9
assert ens_roc_auc > 0.9


if __name__ == "__main__":
test_roost_clf()
24 changes: 4 additions & 20 deletions tests/test_roost_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,15 @@

import numpy as np
import torch
from matminer.utils.io import load_dataframe_from_json
from sklearn.metrics import r2_score
from sklearn.model_selection import train_test_split as split

from aviary.roost.data import CompositionData, collate_batch
from aviary.roost.model import Roost
from aviary.utils import results_multitask, train_ensemble

torch.manual_seed(0) # ensure reproducible results


def test_roost_regression():
data_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "data/matbench_phonons.json.gz"
)
def test_roost_regression(df_matbench_phonons):
elem_emb = "matscholar200"
targets = ["last phdos peak"]
tasks = ["regression"]
Expand All @@ -41,20 +35,14 @@ def test_roost_regression():
weight_decay = 1e-6
batch_size = 128
workers = 0
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device = "cuda" if torch.cuda.is_available() else "cpu"

task_dict = dict(zip(targets, tasks))
loss_dict = dict(zip(targets, losses))

assert os.path.exists(data_path), f"{data_path} does not exist!"

df = load_dataframe_from_json(data_path)
df["material_id"] = [f"mb_phdos_{i}" for i in range(len(df))]
df["composition"] = df.structure.apply(
lambda x: x.composition.formula.replace(" ", "")
dataset = CompositionData(
df=df_matbench_phonons, elem_emb=elem_emb, task_dict=task_dict
)

dataset = CompositionData(df=df, elem_emb=elem_emb, task_dict=task_dict)
n_targets = dataset.n_targets
elem_emb_len = dataset.elem_emb_len

Expand Down Expand Up @@ -160,7 +148,3 @@ def test_roost_regression():
assert r2 > 0.7
assert mae < 150
assert rmse < 300


if __name__ == "__main__":
test_roost_regression()
30 changes: 6 additions & 24 deletions tests/test_wren_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,15 @@

import numpy as np
import torch
from matminer.utils.io import load_dataframe_from_json
from sklearn.metrics import accuracy_score, roc_auc_score
from sklearn.model_selection import train_test_split as split

from aviary.utils import results_multitask, train_ensemble
from aviary.wren.data import WyckoffData, collate_batch
from aviary.wren.model import Wren
from aviary.wren.utils import get_aflow_label_spglib

torch.manual_seed(0) # ensure reproducible results


def test_wren_clf():
data_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "data/matbench_phonons.json.gz"
)
def test_wren_clf(df_matbench_phonons_wyckoff):
elem_emb = "matscholar200"
sym_emb = "bra-alg-off"
targets = ["phdos_clf"]
Expand All @@ -44,23 +37,16 @@ def test_wren_clf():
weight_decay = 1e-6
batch_size = 128
workers = 0
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device = "cuda" if torch.cuda.is_available() else "cpu"

task_dict = dict(zip(targets, tasks))
loss_dict = dict(zip(targets, losses))

assert os.path.exists(data_path), f"{data_path} does not exist!"

df = load_dataframe_from_json(data_path)
df["wyckoff"] = df.structure.apply(get_aflow_label_spglib)
df["material_id"] = [f"mb_phdos_{i}" for i in range(len(df))]
df["composition"] = df.structure.apply(
lambda x: x.composition.formula.replace(" ", "")
)
df["phdos_clf"] = np.where((df["last phdos peak"] > 450), 1, 0)

dataset = WyckoffData(
df=df, elem_emb=elem_emb, sym_emb=sym_emb, task_dict=task_dict
df=df_matbench_phonons_wyckoff,
elem_emb=elem_emb,
sym_emb=sym_emb,
task_dict=task_dict,
)
n_targets = dataset.n_targets
elem_emb_len = dataset.elem_emb_len
Expand Down Expand Up @@ -171,7 +157,3 @@ def test_wren_clf():

assert ens_acc > 0.85
assert ens_roc_auc > 0.9


if __name__ == "__main__":
test_wren_clf()
Loading