diff --git a/examples/reaction_prediction/rexgen_direct/utils.py b/examples/reaction_prediction/rexgen_direct/utils.py index d2a9de61..338f13ed 100644 --- a/examples/reaction_prediction/rexgen_direct/utils.py +++ b/examples/reaction_prediction/rexgen_direct/utils.py @@ -528,7 +528,7 @@ def prepare_reaction_center(args, reaction_center_config): n_layers=reaction_center_config['n_layers'], n_tasks=reaction_center_config['n_tasks']) reaction_center_model.load_state_dict( - torch.load(args['center_model_path'])['model_state_dict']) + torch.load(args['center_model_path'], map_location=torch.device('cpu'))['model_state_dict']) reaction_center_model = reaction_center_model.to(args['device']) reaction_center_model.eval()