Skip to content

Commit

Permalink
Update load weights
Browse files Browse the repository at this point in the history
  • Loading branch information
tomoe committed Jun 6, 2023
1 parent 44fe5b3 commit fe09427
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions multiml/task/pytorch/pytorch_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,13 +241,14 @@ def load_model(self):

# partial weights
if ':' in model_path:
model_path, partial = model_path.split(':')
model_path, *partial = model_path.split(':')
model_dict = self.ml.model.state_dict()
state_dict = torch.load(model_path)
new_state_dict = {}
for key, value in state_dict.items():
if partial in key:
new_state_dict[key] = value
for ipartial in partial:
if ipartial in key:
new_state_dict[key] = value
model_dict.update(new_state_dict)
self.ml.model.load_state_dict(model_dict)
else:
Expand Down

0 comments on commit fe09427

Please sign in to comment.