From fe09427a2f373988e31377d85c0db00ce4736735 Mon Sep 17 00:00:00 2001 From: tomoe Date: Tue, 6 Jun 2023 15:53:52 +0900 Subject: [PATCH] Update load weights --- multiml/task/pytorch/pytorch_base.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/multiml/task/pytorch/pytorch_base.py b/multiml/task/pytorch/pytorch_base.py index 4f512f5..ec669af 100644 --- a/multiml/task/pytorch/pytorch_base.py +++ b/multiml/task/pytorch/pytorch_base.py @@ -241,13 +241,14 @@ def load_model(self): # partial weights if ':' in model_path: - model_path, partial = model_path.split(':') + model_path, *partial = model_path.split(':') model_dict = self.ml.model.state_dict() state_dict = torch.load(model_path) new_state_dict = {} for key, value in state_dict.items(): - if partial in key: - new_state_dict[key] = value + for ipartial in partial: + if ipartial in key: + new_state_dict[key] = value model_dict.update(new_state_dict) self.ml.model.load_state_dict(model_dict) else: