<a href="https://colab.research.google.com/github/NMashalov/FederationLearning/blob/master/Fed.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Reserach on zero-gradient in federation learning

Code was highly inspired by implementation in paper:

https://arxiv.org/pdf/2304.07861.pdf


First we'll work without samples
$$
    f(x,\xi) ={\left(<x,\xi> - \sum_{i=1}^d\xi_i\right) ^2}+ \|x\|_\infty + \delta
$$

$x \in R^d$

$\xi_i \sim U[0,1]$

$\delta \sim \mathbf{N(0,\sigma^2)}$

Optimization set:

$$
    D = \{|x|=1, x > 0\}
$$


Recall, infinity norm is max among all coordinates

$$
    \|x\|_\infty = \max\limits_{i=1,n} x_i
$$

## Algorithm implmentation

In [11]:
import numpy as np

In [12]:
### HYPERPARAMS
# vector dim
D = 20
# number of machines
M = 32
# smoothing gamma
G = 0.001
# number of samples
N = 1000

In [15]:
# N x D matrix
X_train = np.random.uniform(size=(N,D))

In [18]:
def loss(x,xi, noise = 0):
    """
    Compute optimization function

    Argument:
        x: tensor [Machines number, Vector Dim]
        xi: tensor sample from []
        noise:
    """
    return (x @ xi - np.sum(xi,axis=1))  + np.max(x,dim=1) + noise * np.random.randn(D)

Gradient approximation

Two points l1:
$$
\nabla f_\gamma (x, ξ, e) = \frac{d}{2γ}
\left(f_{δ1}
(x + γe, ξ) − f_{δ2}
(x − γe, ξ)\right) sign(e)
$$


Two points l2:
$$
\nabla f_\gamma (x, ξ, e) = \frac{d}{2γ}
\left(f_{δ1}
(x + γe, ξ) − f_{δ2}
(x − γe, ξ)\right) e
$$

In [19]:
def sample_spherical(npoints, ndim=D):
    vec = np.random.randn(ndim, npoints)
    vec /= np.linalg.norm(vec, axis=0)
    return vec

In [21]:
import typing as tp

method_type = tp.Literal['l1','l2']

def calc_grad(x: np.array,xi: np.array, method_type: method_type = 'l2'):
    if method_type == 'l2':
        e = sample_spherical(M)
        grad = D / (2* G) * (loss(x + G * e,xi) - loss(x - G * e,xi)) * e
    elif method_type == 'l1':
        raise NotImplementedError('l1')
        grad = D / (2* G) * (loss(x + G * e,xi) - loss(x - G * e,xi)) * e
    else:
        raise NotImplementedError('No methods')

## Minibatch SGD

In [22]:
def sample_grad(x,batch_size):
    samples_idx = np.random.choice(N, batch_size, replace=True)
    X_sampled = X_train[samples_idx, :]
    return calc_grad(x,X_sampled)

In [24]:
def population_loss(x):
        """
        Compute the population loss for all training samples.

        Argument:
            weight: w
        """
        return loss(x, X_train)

In [25]:
import pandas as pd
T = 1000 # number of oracle calls
K = 10 # communication rounds

eta = 1e-3
beta = 2
alpha = 2
gamma = 0.1

local_batch = 10

w = np.random.randn(D)
w_ag = np.copy(w)

seq = pd.Series(name='loss')
for iter_cnt in range(0, T+1, K):
    w_md = (1/beta) * w + (1-(1/beta))*w_ag
    grad_md = sample_grad(w_md, M*K*local_batch)
    w_ag = w_md - eta * grad_md
    w = (1 - (1/alpha)) * w + (1/alpha) * \
        w_md - gamma * grad_md

  seq = pd.Series(name='loss')


ValueError: ignored

## FED_AVG

In [None]:
def broadcast_avg(pool):
    """
    Helper functions for FedAc and FedAvg, average and broadcast the weights.
    """
    avg = pool.mean(axis=0)
    pool = np.repeat(avg[np.newaxis, :], pool.shape[0], axis=0)
    return pool

In [None]:
import pandas as pd
def fedavg(eta, M, K, T,  record_intvl=512, print_intvl=8192, SEED=0):
        """
        Simulate Federated Averaging (FedAvg, a.k.a. Local-SGD, or Parallel SGD, etc.)

        Arguments:
            eta:    learning rate
            M:      number of workers
            K:      synchronization interval, (i.e., local steps)
            T:      total parallel runtime
            record_intvl:   compute the population loss every record_intvl steps.

        Return:
            A pandas.Series object of population loss evaluated.
        """
        # set of
        np.random.seed(SEED)
        # weights on nodes
        common_init_w = np.random.randn(D)
        #
        w_pool = np.repeat(common_init_w[np.newaxis, :], M, axis=0)

        seq = pd.Series(name='loss')
        for iter_cnt in range(T+1):
            if iter_cnt % K == 0:
                w_pool = broadcast_avg(w_pool)

                if iter_cnt % record_intvl == 0:
                    seq.at[iter_cnt] = loss(w_pool[0, :])

            w_pool -= eta * calc_grad
        return seq

In [None]:
calc_grad()

TypeError: ignored

## Smooth function