How to predict on new materials with saved pytorch file #63
Comments
I think the composition data object requires a dummy target column in the dataframe with the same name as the target the model was trained on. Can you test this, and if it's not that, can you share the |
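For illustration, the dummy-column suggestion above could look roughly like the following. This is only a sketch: the target name `"target"`, the example IDs, and the compositions are placeholders, and the real column name has to match whatever target the model was trained on.

```python
import pandas as pd

# Hypothetical target name; replace with the name used during training.
TARGET_NAME = "target"

# Hypothetical new materials to predict on (IDs and formulas are placeholders).
df = pd.DataFrame(
    {
        "material_id": ["id-0", "id-1"],
        "composition": ["NaCl", "BaTiO3"],
    }
)

# Add a dummy target column so the dataframe has the column the dataset expects.
# The values are never meaningful; they only need to exist.
df[TARGET_NAME] = 0.0

# Sanity check that the expected columns are all present before building the dataset.
required = ["material_id", "composition", TARGET_NAME]
missing = [col for col in required if col not in df.columns]
print("missing columns:", missing)
```

Because the target values are placeholders, any error metrics reported against them carry no information.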
I added a dummy column for
The materials I'm interested in are not necessarily in Materials Project, so they don't have a |
Also,
class CompositionData(Dataset):
    """Dataset class for the Roost composition model."""

    def __init__(
        self,
        df: pd.DataFrame,
        task_dict: dict[str, str],
        elem_embedding: str = "matscholar200",
        inputs: str = "composition",
        identifiers: Sequence[str] = ("material_id", "composition"),
    ):
        """Data class for Roost models.

        Args:
            df (pd.DataFrame): Pandas dataframe holding input and target values.
            task_dict (dict[str, "regression" | "classification"]): Map from target names to task
                type.
            elem_embedding (str, optional): One of "matscholar200", "cgcnn92", "megnet16",
                "onehot112" or path to a file with custom element embeddings.
                Defaults to "matscholar200".
            inputs (str, optional): df column name holding material compositions.
                Defaults to "composition".
            identifiers (list, optional): df columns for distinguishing data points. Will be
                copied over into the model's output CSV. Defaults to ["material_id", "composition"].
        """

I would think that provided the df has columns |
Okay I corrected a typo in the column names (used
Okay, this error comes from the data loader, which plays an important role: it batches the compositions in a way that allows different compositions to have different numbers of elements. I ran all the experiments in the papers through the CLI args of the example scripts, so the native scripting experience probably needs some work and definitely more explanation. The first quirk here is that you need to make the test set into a subset: aviary/examples/roost-example.py Line 111 in e2bfd50
Then the easiest way will be to call the aviary/tests/test_roost_regression.py Lines 119 to 134 in e2bfd50
where the model_name and run_id are used to identify the checkpoint. If you want to script most of this yourself then the important thing to do is initialise a |
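To make the batching point above more concrete, here is a minimal pure-Python sketch of how compositions with different numbers of elements can be combined into one batch. It is not aviary's actual collate function. The function name `collate_compositions` and the example compositions are made up for illustration, and a real implementation would work on tensors of element features rather than element symbols.

```python
from typing import List, Tuple


def collate_compositions(batch: List[List[str]]) -> Tuple[List[str], List[int]]:
    """Flatten variable-length compositions into one list.

    Returns the flattened elements plus, for each element, the index of the
    composition it came from, so per-composition results can be pooled later.
    """
    flat_elements: List[str] = []
    composition_index: List[int] = []
    for i, elements in enumerate(batch):
        flat_elements.extend(elements)
        composition_index.extend([i] * len(elements))
    return flat_elements, composition_index


# Three hypothetical compositions with 2, 2 and 3 elements respectively.
batch = [["Na", "Cl"], ["Fe", "O"], ["Ba", "Ti", "O"]]
flat, index = collate_compositions(batch)
print(flat)
print(index)
```

A plain stacked batch would need every composition to have the same length. Keeping an index alongside the flattened elements avoids padding, which is one reason predictions usually have to go through the data loader rather than being fed to the model one row at a time.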
Ah I see. So I have it almost working I think. I added this to my script:
Now the error I get is that it can't find the checkpoint file (it points to |
It looks like it's working:
Although I'm wondering where the MAE and RMSE are coming from - are they on a validation set that the checkpoint file provides? |
They'll be meaningless here, since you're using a dummy column for the targets when predicting.
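A small sketch of why the reported numbers carry no information in this situation: if the targets are all a dummy value such as zero, MAE and RMSE just summarise the magnitude of the predictions themselves. The prediction values below are invented for illustration, and the helper functions are plain reimplementations, not the ones used by the library.

```python
import math
from typing import Sequence


def mae(preds: Sequence[float], targets: Sequence[float]) -> float:
    """Mean absolute error."""
    return sum(abs(p - t) for p, t in zip(preds, targets)) / len(preds)


def rmse(preds: Sequence[float], targets: Sequence[float]) -> float:
    """Root mean squared error."""
    return math.sqrt(sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(preds))


# Hypothetical model predictions and an all-zero dummy target column.
predictions = [1.0, 0.5, 3.0]
dummy_targets = [0.0] * len(predictions)

mae_value = mae(predictions, dummy_targets)
rmse_value = rmse(predictions, dummy_targets)

# With zero targets, MAE is just the mean absolute prediction.
mean_abs_prediction = sum(abs(p) for p in predictions) / len(predictions)
print(mae_value, rmse_value, mean_abs_prediction)
```

Real error metrics need the true target values for the new materials, which by assumption are not available at prediction time.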
got it - thanks for the help! |
I used roost-example.py and saved the trained model in a pytorch file (e.g., roost.pt). I have tried to load this file and predict as follows:
and I get the following output:
Is it possible to add an example script to perform a prediction from a saved model?
Thank you