
How to predict on new materials with saved pytorch file #63

Closed
sarah-allec opened this issue Dec 7, 2022 · 10 comments

Comments

@sarah-allec

I used `roost-example.py` and saved the trained model to a PyTorch file (e.g., `roost.pt`). I tried to load this file and predict as follows:

```python
import pandas as pd
import torch
from aviary.roost.data import CompositionData

targets = ["E_f"]
tasks = ["regression"]
task_dict = dict(zip(targets, tasks))
df = pd.read_csv("candidate_compositions.csv")
X = CompositionData(df, elem_embedding="matscholar200", task_dict=task_dict)

model = torch.load("models/roost.pt")
y_pred = model.predict(X)
```

and I get the following output:

```
Traceback (most recent call last):
  File "~/opt/anaconda3/envs/aviary/lib/python3.7/site-packages/pandas/core/indexes/base.py", line 3361, in get_loc
    return self._engine.get_loc(casted_key)
  File "pandas/_libs/index.pyx", line 76, in pandas._libs.index.IndexEngine.get_loc
  File "pandas/_libs/index.pyx", line 108, in pandas._libs.index.IndexEngine.get_loc
  File "pandas/_libs/hashtable_class_helper.pxi", line 5198, in pandas._libs.hashtable.PyObjectHashTable.get_item
  File "pandas/_libs/hashtable_class_helper.pxi", line 5206, in pandas._libs.hashtable.PyObjectHashTable.get_item
KeyError: 'E_f'

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "roost-predict.py", line 12, in <module>
    y_pred = model.predict(X)
  File "~/opt/anaconda3/envs/aviary/lib/python3.7/site-packages/torch/autograd/grad_mode.py", line 28, in decorate_context
    return func(*args, **kwargs)
  File "~/opt/anaconda3/envs/aviary/lib/python3.7/site-packages/aviary/core.py", line 357, in predict
    data_loader, disable=True if not verbose else None
  File "~/opt/anaconda3/envs/aviary/lib/python3.7/site-packages/tqdm/std.py", line 1173, in __iter__
    for obj in iterable:
  File "~/opt/anaconda3/envs/aviary/lib/python3.7/site-packages/aviary/roost/data.py", line 126, in __getitem__
    targets.append(Tensor([row[target]]))
  File "~/opt/anaconda3/envs/aviary/lib/python3.7/site-packages/pandas/core/series.py", line 942, in __getitem__
    return self._get_value(key)
  File "~/opt/anaconda3/envs/aviary/lib/python3.7/site-packages/pandas/core/series.py", line 1051, in _get_value
    loc = self.index.get_loc(label)
  File "~/opt/anaconda3/envs/aviary/lib/python3.7/site-packages/pandas/core/indexes/base.py", line 3363, in get_loc
    raise KeyError(key) from err
KeyError: 'E_f'
```

Is it possible to add an example script to perform a prediction from a saved model?

Thank you

@CompRhys
Owner

CompRhys commented Dec 8, 2022

I think the `CompositionData` object requires a dummy target column in the dataframe with the same name as the target the model was trained on. Can you test this? If that's not it, can you share `df.columns` and the `model.model_params` dict?
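A minimal sketch of that suggestion follows. It assumes the candidates are stored in a CSV with only a `composition` column and that the model was trained on `E_f`. In pandas the same step is a one-liner (`df["E_f"] = 0.0`). The stdlib version below works on the raw CSV text, so it can be checked without aviary installed. `add_dummy_targets` is a hypothetical helper, not part of aviary. The fill value is arbitrary because the placeholder targets are not what the model predicts from.

```python
import csv
import io
from typing import Sequence


def add_dummy_targets(csv_text: str, targets: Sequence[str], fill: float = 0.0) -> str:
    """Return `csv_text` with a constant placeholder column for each missing target."""
    reader = csv.DictReader(io.StringIO(csv_text))
    fieldnames = list(reader.fieldnames or [])
    # Only add target columns that are not already present.
    new_cols = [t for t in targets if t not in fieldnames]

    out = io.StringIO()
    writer = csv.DictWriter(out, fieldnames=fieldnames + new_cols, lineterminator="\n")
    writer.writeheader()
    for row in reader:
        for col in new_cols:
            row[col] = fill  # placeholder value, not a real label
        writer.writerow(row)
    return out.getvalue()


print(add_dummy_targets("composition\nFe2O3\nNaCl\n", ["E_f"]))
# composition,E_f
# Fe2O3,0.0
# NaCl,0.0
```

To apply this to a real file, read `candidate_compositions.csv` into a string, pass it through the helper, and write the result back out before loading it with `pd.read_csv`.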

@sarah-allec
Author

I added a dummy column for E_f and I now get the following:

```
  File "roost-predict.py", line 17, in <module>
    y_pred = model.predict(X)
  File "~/opt/anaconda3/envs/aviary/lib/python3.7/site-packages/torch/autograd/grad_mode.py", line 28, in decorate_context
    return func(*args, **kwargs)
  File "~/opt/anaconda3/envs/aviary/lib/python3.7/site-packages/aviary/core.py", line 357, in predict
    data_loader, disable=True if not verbose else None
  File "~/opt/anaconda3/envs/aviary/lib/python3.7/site-packages/tqdm/std.py", line 1173, in __iter__
    for obj in iterable:
  File "~/opt/anaconda3/envs/aviary/lib/python3.7/site-packages/aviary/roost/data.py", line 91, in __getitem__
    material_ids = row[self.identifiers].to_list()
  File "~/opt/anaconda3/envs/aviary/lib/python3.7/site-packages/pandas/core/series.py", line 966, in __getitem__
    return self._get_with(key)
  File "~/opt/anaconda3/envs/aviary/lib/python3.7/site-packages/pandas/core/series.py", line 1006, in _get_with
    return self.loc[key]
  File "~/opt/anaconda3/envs/aviary/lib/python3.7/site-packages/pandas/core/indexing.py", line 931, in __getitem__
    return self._getitem_axis(maybe_callable, axis=axis)
  File "~/opt/anaconda3/envs/aviary/lib/python3.7/site-packages/pandas/core/indexing.py", line 1153, in _getitem_axis
    return self._getitem_iterable(key, axis=axis)
  File "~/opt/anaconda3/envs/aviary/lib/python3.7/site-packages/pandas/core/indexing.py", line 1093, in _getitem_iterable
    keyarr, indexer = self._get_listlike_indexer(key, axis)
  File "~/opt/anaconda3/envs/aviary/lib/python3.7/site-packages/pandas/core/indexing.py", line 1314, in _get_listlike_indexer
    self._validate_read_indexer(keyarr, indexer, axis)
  File "~/opt/anaconda3/envs/aviary/lib/python3.7/site-packages/pandas/core/indexing.py", line 1377, in _validate_read_indexer
    raise KeyError(f"{not_found} not in index")
KeyError: "['material_id'] not in index"
```

The materials I'm interested in are not necessarily in the Materials Project, so they don't have a `material_id`. I tried adding a dummy column, but that gives the same error.

@sarah-allec
Author

Also, df.columns = ['composition', 'E_f'] and model.model_params is below:

```python
{'task_dict': {'E_f': 'regression'}, 'robust': True, 'n_targets': [1], 'out_hidden': [256, 128, 64],
 'trunk_hidden': [1024, 512], 'elem_emb_len': 200, 'elem_fea_len': 64, 'n_graph': 3, 'elem_heads': 3,
 'elem_gate': [256], 'elem_msg': [256], 'cry_heads': 3, 'cry_gate': [256], 'cry_msg': [256]}
```

@CompRhys
Owner

CompRhys commented Dec 8, 2022

```python
class CompositionData(Dataset):
    """Dataset class for the Roost composition model."""

    def __init__(
        self,
        df: pd.DataFrame,
        task_dict: dict[str, str],
        elem_embedding: str = "matscholar200",
        inputs: str = "composition",
        identifiers: Sequence[str] = ("material_id", "composition"),
    ):
        """Data class for Roost models.

        Args:
            df (pd.DataFrame): Pandas dataframe holding input and target values.
            task_dict (dict[str, "regression" | "classification"]): Map from target names to task
                type.
            elem_embedding (str, optional): One of "matscholar200", "cgcnn92", "megnet16",
                "onehot112" or path to a file with custom element embeddings.
                Defaults to "matscholar200".
            inputs (str, optional): df column name holding material compositions.
                Defaults to "composition".
            identifiers (list, optional): df columns for distinguishing data points. Will be
                copied over into the model's output CSV. Defaults to ["material_id", "composition"].
        """
```

I would think that as long as the df has the columns `["material_id", "composition", "E_f"]`, it should work. There's nothing specific to the Materials Project (MP) about `material_id`; it's just an ID for keeping track of materials. Perhaps this confusion could be avoided by not having defaults for `inputs` and `identifiers`.
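Before building the dataset, it can help to check which of these columns are missing. The sketch below uses a hypothetical helper, `missing_columns`, which is not part of aviary. Its defaults copy the `inputs` and `identifiers` defaults from the signature above, and it adds one required column per key of `task_dict`. For materials without an MP ID, any string per row should serve as `material_id`. In pandas, for example: `df["material_id"] = [f"cand-{i}" for i in range(len(df))]`.

```python
from typing import Iterable, List, Mapping, Sequence


def missing_columns(
    columns: Iterable[str],
    task_dict: Mapping[str, str],
    inputs: str = "composition",
    identifiers: Sequence[str] = ("material_id", "composition"),
) -> List[str]:
    """List the columns a CompositionData-style dataset would look up but not find."""
    present = set(columns)
    # Input column, identifier columns, then one column per target name.
    required = [inputs, *identifiers, *task_dict]
    missing: List[str] = []
    for col in required:
        # "composition" appears twice by default, so report each name only once.
        if col not in present and col not in missing:
            missing.append(col)
    return missing


print(missing_columns(["composition", "E_f"], {"E_f": "regression"}))  # ['material_id']
```

Run against the columns from this thread, the helper flags `material_id` both for `['composition', 'E_f']` and for the `material_ids` typo, because only an exact column name counts.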

@sarah-allec
Copy link
Author

Okay, I corrected a typo in the column names (I had used `material_ids` instead of `material_id`), but now I get the following error:

```
Traceback (most recent call last):
  File "roost-predict.py", line 18, in <module>
    y_pred = model.predict(X)
  File "~/opt/anaconda3/envs/aviary/lib/python3.7/site-packages/torch/autograd/grad_mode.py", line 28, in decorate_context
    return func(*args, **kwargs)
  File "~/opt/anaconda3/envs/aviary/lib/python3.7/site-packages/aviary/core.py", line 359, in predict
    preds = self(*inputs)  # forward pass to get model preds
  File "~/opt/anaconda3/envs/aviary/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
TypeError: forward() missing 1 required positional argument: 'cry_elem_idx'
```

@CompRhys
Owner

CompRhys commented Dec 8, 2022

Okay, this error is due to the data loader, which plays an important role: it batches the compositions so that compositions with different numbers of elements can be processed together. I ran all the experiments in the papers using the CLI arguments of the example scripts, so the native scripting experience may need some work and definitely needs more explanation.
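The following plain-Python toy illustrates why this batching step matters. It is not aviary's `collate_batch`, which works on tensors and also builds the element-pair indices for message passing. Only the name `cry_elem_idx` comes from the error above. `collate_compositions` and `pool_per_composition` are hypothetical names. Compositions of different lengths are flattened into one list of elements, and `cry_elem_idx` records which composition each element came from, so per-element results can be pooled back into one value per material. A plain sum stands in for whatever pooling the real model uses. Passing the dataset directly to `predict` appears to skip this collation, which would explain the missing argument.

```python
from typing import List, Sequence, Tuple

# One composition as (element symbol, fractional amount) pairs, e.g. Fe2O3 -> Fe 0.4, O 0.6.
Composition = Sequence[Tuple[str, float]]


def collate_compositions(
    batch: Sequence[Composition],
) -> Tuple[List[str], List[float], List[int]]:
    """Flatten variable-length compositions into flat lists for one batch.

    Returns element symbols, their fractions, and `cry_elem_idx`, which gives,
    for every flattened element, the position of its composition in the batch.
    """
    elements: List[str] = []
    fractions: List[float] = []
    cry_elem_idx: List[int] = []
    for cry_idx, composition in enumerate(batch):
        for symbol, fraction in composition:
            elements.append(symbol)
            fractions.append(fraction)
            cry_elem_idx.append(cry_idx)
    return elements, fractions, cry_elem_idx


def pool_per_composition(
    values: Sequence[float], cry_elem_idx: Sequence[int], n_compositions: int
) -> List[float]:
    """Sum per-element values back into one value per composition."""
    pooled = [0.0] * n_compositions
    for value, idx in zip(values, cry_elem_idx):
        pooled[idx] += value
    return pooled


batch = [[("Fe", 0.4), ("O", 0.6)], [("Na", 0.5), ("Cl", 0.5)], [("Si", 1.0)]]
elements, fractions, cry_elem_idx = collate_compositions(batch)
print(cry_elem_idx)  # [0, 0, 1, 1, 2]
```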

The first quirk here is that you need to wrap the test set in a `Subset`:

```python
test_set = torch.utils.data.Subset(test_set, range(len(test_set)))
```

Then the easiest approach is to call the `results_multitask` utility function on the checkpoint saved during training:

```python
results_dict = results_multitask(
    model_class=Roost,
    model_name=model_name,
    run_id=run_id,
    ensemble_folds=ensemble,
    test_set=test_set,
    data_params=data_params,
    robust=robust,
    task_dict=task_dict,
    device=device,
    eval_type="checkpoint",
    save_results=False,
)
preds = results_dict[target_name]["preds"]
targets = results_dict[target_name]["targets"]
```

where the model_name and run_id are used to identify the checkpoint.

If you want to script most of this yourself, the important step is to initialise a PyTorch `DataLoader` with the `Subset` of the dataset and the `collate_batch` function from `aviary.roost.data`. Then iterate over the `DataLoader`, calling `predict` on each batch in turn. You can use fairly large batch sizes for prediction (>512), so this iteration is quick even on large prediction sets.
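For the do-it-yourself route, the PyTorch side would be roughly `torch.utils.data.DataLoader(test_set, batch_size=512, shuffle=False, collate_fn=collate_batch)`, iterated batch by batch. The exact per-batch call into the model isn't shown in this thread. So the sketch below mirrors only the loop structure in plain Python: slice the data in input order, collate each slice, predict, and concatenate. `collate` and `predict_batch` are hypothetical stand-ins for `collate_batch` and a per-batch model call.

```python
from typing import Callable, Iterator, List, Sequence, TypeVar

T = TypeVar("T")
B = TypeVar("B")


def iter_batches(items: Sequence[T], batch_size: int) -> Iterator[Sequence[T]]:
    """Yield consecutive, non-overlapping slices of `items` in input order."""
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    for start in range(0, len(items), batch_size):
        yield items[start : start + batch_size]


def predict_in_batches(
    items: Sequence[T],
    batch_size: int,
    collate: Callable[[Sequence[T]], B],
    predict_batch: Callable[[B], Sequence[float]],
) -> List[float]:
    """Collate each batch, predict on it, and concatenate the results in order."""
    preds: List[float] = []
    for batch in iter_batches(items, batch_size):
        preds.extend(predict_batch(collate(batch)))
    return preds


# Toy usage: "collate" just copies the slice and the "model" doubles each input.
print(predict_in_batches([1.0, 2.0, 3.0], 2, list, lambda xs: [2 * x for x in xs]))  # [2.0, 4.0, 6.0]
```

Keeping `shuffle=False` (here: slicing in order) means the concatenated predictions line up row for row with the input dataframe.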

@sarah-allec
Author

sarah-allec commented Dec 8, 2022

Ah, I see. I think I have it almost working. I added this to my script:

```python
import torch
from aviary.roost.data import collate_batch
from aviary.roost.model import Roost
from aviary.utils import results_multitask

test_set = torch.utils.data.Subset(X, range(len(X)))
model_name = "roost"
ensemble = 1
run_id = 1
device = None
robust = True
batch_size = 128
workers = 0
data_params = {
    "batch_size": batch_size,
    "num_workers": workers,
    "pin_memory": False,
    "shuffle": False,  # keep batches in input order when predicting
    "collate_fn": collate_batch,
}
if device is None:
    device = "cuda" if torch.cuda.is_available() else "cpu"

results_dict = results_multitask(
    model_class=Roost,
    model_name=model_name,
    run_id=run_id,
    ensemble_folds=ensemble,
    test_set=test_set,
    data_params=data_params,
    robust=robust,
    task_dict=task_dict,
    device=device,
    eval_type="checkpoint",
    save_results=False,
)

target_name = "E_f"  # the target the model was trained on
preds = results_dict[target_name]["preds"]
targets = results_dict[target_name]["targets"]
```

Now the error I get is that it can't find the checkpoint file: it looks in `~/opt/anaconda3/envs/aviary/lib/python3.7/site-packages/models/roost/`, but the file is actually in `~/opt/anaconda3/envs/aviary/lib/python3.7/site-packages/models/roost_s-0_t-1`. This should be an easy fix; I'll work on it and comment when I'm done.
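Going by the two paths in the comment above, the lookup appears to use `models/<model_name>/`. If so, passing `model_name="roost_s-0_t-1"` instead of `"roost"` would presumably point it at the right folder. This is an inference from the paths, not confirmed in the thread. The small stdlib helper below, `find_model_dirs`, is hypothetical. It lists the folders under a `models` directory that start with a given prefix, so you can see which `model_name` values actually have saved runs. It does not assume anything about the checkpoint file names inside those folders.

```python
from pathlib import Path
from typing import List


def find_model_dirs(models_root: str, prefix: str) -> List[str]:
    """Return the sorted names of sub-directories of `models_root` starting with `prefix`."""
    root = Path(models_root)
    if not root.is_dir():
        return []
    return sorted(p.name for p in root.iterdir() if p.is_dir() and p.name.startswith(prefix))


# Example: inspect the models folder next to the installed package (path from this thread).
# print(find_model_dirs("~/opt/anaconda3/envs/aviary/lib/python3.7/site-packages/models", "roost"))
```

`Path` does not expand `~` on its own, so call `Path(...).expanduser()` or use an absolute path when trying this on a real install.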

@sarah-allec
Author

It looks like it's working:

```
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
------------Evaluate model on Test Set------------
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Testing on 320 samples
Evaluating Model

Task: 'E_f' on test set
Model Performance Metrics:
R2 Score: 0.0000
MAE: 2.5439
RMSE: 2.5981
```

I'm wondering where the MAE and RMSE come from, though: are they computed on a validation set that the checkpoint file provides?

@CompRhys
Owner

CompRhys commented Dec 8, 2022

They'll be meaningless here, since you're using a dummy column for the targets when predicting.
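The sketch below makes this concrete by computing R2, MAE and RMSE against a constant dummy target. It is not aviary's metric code. For the R2 edge case it assumes the convention used by scikit-learn's `r2_score` with its default `force_finite=True`: when the true values are constant, R2 is 1.0 for perfect predictions and 0.0 otherwise. That would be consistent with the `R2 Score: 0.0000` reported above. If the dummy column is all zeros, MAE reduces to the mean absolute prediction and RMSE to the root-mean-square prediction. Neither says anything about accuracy.

```python
import math
from typing import Sequence, Tuple


def regression_metrics(
    y_true: Sequence[float], y_pred: Sequence[float]
) -> Tuple[float, float, float]:
    """Return (R2, MAE, RMSE) for paired true and predicted values."""
    n = len(y_true)
    mean = sum(y_true) / n
    errors = [p - t for t, p in zip(y_true, y_pred)]
    ss_res = sum(e * e for e in errors)
    ss_tot = sum((t - mean) ** 2 for t in y_true)
    mae = sum(abs(e) for e in errors) / n
    rmse = math.sqrt(ss_res / n)
    if ss_tot == 0:
        # Constant targets (e.g. a dummy column): R2 is undefined, so fall back to 1.0/0.0.
        r2 = 1.0 if ss_res == 0 else 0.0
    else:
        r2 = 1.0 - ss_res / ss_tot
    return r2, mae, rmse


# With an all-zero dummy target, MAE is just the mean |prediction|.
print(regression_metrics([0.0, 0.0, 0.0, 0.0], [1.0, -1.0, 3.0, -3.0]))  # (0.0, 2.0, 2.23606797749979)
```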

@sarah-allec
Author

Got it, thanks for the help!
