GLOBAL MODEL TESTING #1417
-
Hi everyone, I have a problem testing the global model. I have completed a training run, and I would like to use the resulting global model and its weights for a simple test: submit an image to it and get a classification.

To do this, I identified admin@nvidia.com/transfer/<JOB_ID>/workspace/app_server/FL_global_model.pt as the global model, but I am not sure it is the right file. If it is not, please let me know where to find the global model after downloading the job via admin.

In Python I declared the model used in training, which is different from the default model (SimpleNetwork, used in hello-pt-tb). However, when I run model.load_state_dict(torch.load(path)), where path points to admin@nvidia.com/transfer/<JOB_ID>/workspace/app_server/FL_global_model.pt, I get the error Missing key(s) in state_dict: ...... This error usually means that the model and the weights do not match. How can I solve it? Thank you in advance for your support.
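A quick way to see why load_state_dict reports missing keys is to compare the parameter names the model expects with the top-level keys of whatever torch.load(path) returned. The helper below is a hypothetical sketch (diff_keys is not an NVFlare or PyTorch function) and assumes the loaded object is a dict.

```python
def diff_keys(model_keys, checkpoint):
    """Compare a model's parameter names with the top-level keys of a loaded checkpoint.

    model_keys: iterable of parameter names, e.g. model.state_dict().keys()
    checkpoint: the object returned by torch.load(path), assumed to be a dict
    """
    expected = set(model_keys)
    found = set(checkpoint.keys())
    # Names the model needs but that are not present at the top level of the file.
    missing = sorted(expected - found)
    # Top-level names in the file that the model does not recognize.
    unexpected = sorted(found - expected)
    return missing, unexpected
```

Usage: missing, unexpected = diff_keys(model.state_dict().keys(), torch.load(path)). If nearly every parameter is listed as missing and the unexpected entries look like metadata (for example a "model" entry), the weights are probably nested one level down rather than stored as a plain state_dict.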
Replies: 2 comments
-
@holgerroth can you help with this?
-
The model is saved in a dict whose layout depends on the persistor you used. You might need to load it with model.load_state_dict(torch.load(path_to_model)["model"]), because the standard PT persistor saves additional meta information together with the model weights. For details, see here:
NVFlare/nvflare/app_opt/pt/model_persistence_format_manager.py, line 31 (commit 94d3695)
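If you are not sure which format a given .pt file uses, a small helper can accept both a plain state_dict and a dict that nests the weights under "model". This is a hypothetical sketch, not part of NVFlare; it only assumes that the weights, when wrapped, live under the "model" key described above.

```python
from collections.abc import Mapping


def unwrap_state_dict(checkpoint, key="model"):
    """Return the weights dict from a loaded checkpoint.

    Assumes the checkpoint is either a plain state_dict or a dict that stores
    the state_dict under `key` (the "model" entry written by the standard PT persistor).
    """
    if not isinstance(checkpoint, Mapping):
        raise TypeError(f"expected a dict-like checkpoint, got {type(checkpoint).__name__}")
    # In a plain state_dict the values are tensors, not dicts, so this check
    # only matches the wrapped format.
    if key in checkpoint and isinstance(checkpoint[key], Mapping):
        return checkpoint[key]
    # Otherwise treat the object as an already-plain state_dict.
    return checkpoint
```

Usage: model.load_state_dict(unwrap_state_dict(torch.load(path_to_model))). To classify an image afterwards, call model.eval(), run the preprocessed image tensor (with a batch dimension) through the model inside torch.no_grad(), and take logits.argmax(dim=1) as the predicted class index.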