Describe the bug
When the model contains a BN layer, the param `bn.num_batches_tracked` is converted to a Python `int` by gRPC, but `trainer.update` can't handle this situation well.
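For context, the value in question starts out as a 0-dim int64 tensor; here is a minimal repro, assuming the transport flattens each state_dict value with `.tolist()` (the exact serialization step is an assumption), which collapses a 0-dim tensor to a bare Python number:

```python
import torch

bn = torch.nn.BatchNorm1d(4)
state = bn.state_dict()
print(state['num_batches_tracked'])           # tensor(0), a 0-dim int64 tensor
print(state['num_batches_tracked'].tolist())  # 0  <- plain int, not a list
```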
A dummy solution:
```python
import torch

def update(self, model_parameters):
    '''Called by the FL client to update the model parameters.

    Arguments:
        model_parameters (dict): PyTorch Module object's state_dict.
    '''
    for key in model_parameters:
        if isinstance(model_parameters[key], list):
            model_parameters[key] = torch.FloatTensor(model_parameters[key])
        elif isinstance(model_parameters[key], int):
            # gRPC turns 0-dim tensors such as num_batches_tracked into ints
            model_parameters[key] = torch.tensor(model_parameters[key],
                                                 dtype=torch.long)
            print(key, model_parameters[key])  # debug
        elif isinstance(model_parameters[key], float):
            model_parameters[key] = torch.tensor(model_parameters[key],
                                                 dtype=torch.float)
    self.ctx.model.load_state_dict(self._param_filter(model_parameters),
                                   strict=False)
```
Or can we solve it before sending the model_param?
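One way to do that, sketched below with hypothetical `pack_state_dict`/`unpack_state_dict` helpers (not the actual FederatedScope API), is to ship dtype and shape next to the data so the receiver can rebuild every entry losslessly, including 0-dim tensors like `num_batches_tracked`:

```python
import torch

def pack_state_dict(state_dict):
    # Hypothetical sender-side hook: keep dtype and shape next to the data
    # so no type information is lost in transit.
    return {
        key: {
            'data': value.tolist(),
            'dtype': str(value.dtype),   # e.g. 'torch.int64'
            'shape': list(value.shape),  # [] for a 0-dim tensor
        }
        for key, value in state_dict.items()
    }

def unpack_state_dict(packed):
    # Receiver-side counterpart: rebuild each tensor exactly as it was sent.
    return {
        key: torch.tensor(entry['data'],
                          dtype=getattr(torch, entry['dtype'].split('.')[-1])
                          ).reshape(entry['shape'])
        for key, entry in packed.items()
    }
```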
There is another type-conversion bug in the aggregator: if the value is an `int`, like 8, `torch.FloatTensor(8)` treats 8 as a size and returns an uninitialized tensor of shape (8,) rather than the scalar 8.
Should we handle these two situations in the message buffer instead of in the aggregator?
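For reference, a short demonstration of the pitfall and of `torch.tensor()` as the safe, uniform replacement:

```python
import torch

# The pitfall: torch.FloatTensor(8) treats a bare int as a *size*,
# returning an uninitialized tensor of shape (8,), not the value 8.
print(torch.FloatTensor(8).shape)          # torch.Size([8]) -- garbage values

# Safe conversion: torch.tensor() always treats its argument as data.
print(torch.tensor(8, dtype=torch.long))   # tensor(8), a 0-dim scalar
print(torch.tensor([1.0, 2.0]))            # works for lists as well
```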
rayrayraykk changed the title from "AttributeError in distributed mode" to "AttributeError in distributed mode -- (Avoid type conversion outside worker)" on Jul 7, 2022.