In [None]:
import copy
import os
import time
import warnings
import numpy as np

import torch
from torch.optim.optimizer import Optimizer
import torchvision

In [None]:
class FedProxLocalSolver(Optimizer):
    """Wrap an optimizer and add a proximal pull towards the starting weights.

    On every ``step`` each parameter that has a gradient is additionally
    moved by ``-lr * mu * (w - w_initial)``, where ``w`` is the parameter
    value just before the wrapped optimizer's step and ``w_initial`` is the
    value the parameter had the first time this wrapper stepped it.

    The proximal bookkeeping is kept in dictionaries owned by the wrapper
    rather than in the wrapped optimizer's per-parameter ``state``. Writing
    extra keys into that state before the wrapped optimizer's first step
    makes the state non-empty, which breaks optimizers that decide whether
    to initialise their buffers by checking for an empty state dict.

    ``Optimizer.__init__`` is intentionally not called: ``param_groups`` and
    ``state`` are shared with the wrapped optimizer instead.

    Args:
        optimizer: The optimizer that performs the actual gradient update.
        mu: Non-negative coefficient of the proximal term.

    Raises:
        ValueError: If ``mu`` is negative.
    """

    def __init__(self, optimizer, mu=0):
        if mu < 0.0:
            raise ValueError(f'Invalid mu value: {mu}')
        self.optimizer = optimizer
        self.mu = mu
        # Shared (not copied) so that learning-rate changes and state made
        # through either object are visible through the other.
        self.param_groups = self.optimizer.param_groups
        self.state = self.optimizer.state
        # The base-class initializer is skipped, so expose the attribute that
        # inherited helpers may look up.
        self.defaults = getattr(self.optimizer, 'defaults', {})
        # Wrapper-private bookkeeping, keyed by parameter tensor.
        self._initial_weights = {}
        self._proximal = {}

    def zero_grad(self, *args, **kwargs):
        """Clear gradients by delegating to the wrapped optimizer.

        The parameter groups are shared, so this clears exactly the gradients
        this wrapper steps, without relying on base-class attributes that
        were never initialised on the wrapper.
        """
        return self.optimizer.zero_grad(*args, **kwargs)

    def _update(self, group):
        """Apply the staged proximal terms to the parameters of ``group``.

        Parameters without a gradient, or without a staged proximal term
        (e.g. their gradient only appeared inside the closure), are skipped.
        """
        lr = group['lr']
        for p in group['params']:
            if p.grad is None:
                continue
            term = self._proximal.get(p)
            if term is None:
                continue
            p.data.add_(term, alpha=-lr)

    def step(self, closure=None):
        """Run the wrapped optimizer's step, then apply the proximal pull.

        Args:
            closure: Optional callable forwarded to the wrapped optimizer.

        Returns:
            Whatever the wrapped optimizer's ``step`` returns.
        """
        self._proximal.clear()
        # Stage the proximal terms from the pre-step weights.
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                anchor = self._initial_weights.get(p)
                if anchor is None:
                    # First time this parameter is stepped: remember where
                    # it started.
                    anchor = torch.clone(p.data).detach()
                    self._initial_weights[p] = anchor
                self._proximal[p] = self.mu * (p.data - anchor)

        try:
            loss = self.optimizer.step(closure=closure)
            for group in self.param_groups:
                self._update(group)
        finally:
            # The staged terms are only valid for this step; drop them so
            # they do not hold on to memory.
            self._proximal.clear()
        return loss