Skip to content

Commit

Permalink
Merge pull request #338 from hua-zi/patch-4
Browse files Browse the repository at this point in the history
Update model_maintainer.py
  • Loading branch information
dunzeng committed Nov 6, 2023
2 parents 9f4a354 + d8cb221 commit aff14a5
Showing 1 changed file with 15 additions and 1 deletion.
16 changes: 15 additions & 1 deletion fedlab/core/model_maintainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,20 @@ def model_parameters(self) -> torch.Tensor:
"""Return serialized model parameters."""
return SerializationTool.serialize_model(self._model)

@property
def model_grads(self) -> torch.Tensor:
"""Return serialized model gradients(base on model.state_dict(), Shape is the same as model_parameters)."""
params = self._model.state_dict()
for name, p in self._model.named_parameters():
params[name].grad = p.grad
for key in params:
if params[key].grad is None:
params[key].grad = torch.zeros_like(params[key])
gradients = [param.grad.data.view(-1) for param in params.values()]
m_gradients = torch.cat(gradients)
m_gradients = m_gradients.cpu()
return m_gradients

@property
def model_gradients(self) -> torch.Tensor:
"""Return serialized model gradients."""
Expand Down Expand Up @@ -117,4 +131,4 @@ def set_model(self, parameters: torch.Tensor = None, id: int = None):
if id is None:
super().set_model(parameters)
else:
super().set_model(self.parameters[id])
super().set_model(self.parameters[id])

0 comments on commit aff14a5

Please sign in to comment.