Skip to content

Commit

Permalink
Merge pull request #349 from techwizrd/fix-name-error-in-qfedavg
Browse files Browse the repository at this point in the history
Import SGDSerialClientTrainer in qfedavg
  • Loading branch information
dunzeng committed Jan 31, 2024
2 parents 0ae5c7d + 06709b0 commit edc47f4
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions fedlab/contrib/algorithm/qfedavg.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from .basic_server import SyncServerHandler
from .basic_client import SGDClientTrainer
from .basic_client import SGDSerialClientTrainer


##################
Expand Down Expand Up @@ -40,7 +41,7 @@ def uplink_package(self):
def setup_optim(self, epochs, batch_size, lr, q):
super().setup_optim(epochs, batch_size, lr)
self.q = q

def train(self, model_parameters, train_loader) -> None:
"""Client trains its local model on local dataset.
Args:
Expand Down Expand Up @@ -72,11 +73,12 @@ def train(self, model_parameters, train_loader) -> None:
ret_loss + 1e-10, self.q - 1) * grad.norm(
)**2 + 1.0 / self.lr * np.float_power(ret_loss + 1e-10, self.q)


class qFedAvgSerialClientTrainer(SGDSerialClientTrainer):
def setup_optim(self, epochs, batch_size, lr, q):
super().setup_optim(epochs, batch_size, lr)
self.q = q

def train(self, model_parameters, train_loader) -> None:
"""Client trains its local model on local dataset.
Args:
Expand Down

0 comments on commit edc47f4

Please sign in to comment.