# Overview
In this notebook, we try to replicate two of the paper's experiments, optimizing the logistic regression loss with:

1. Stochastic Local Gradient Descent vs. Stochastic Scaffnew

2. Local Gradient Descent (deterministic) vs. Scaffnew

# Imports and Data splitting 
Please take a look at the 'exp_setup' and 'optimization_utilities' files for the definitions of most functions.

Let's download a crucial package first

In [None]:
# ! pip install git+https://github.com/konstmish/opt_methods.git

In [None]:
import os
import urllib
import math
import random

import numpy as np
from sklearn import datasets
from typing import Sequence, Union, List, Tuple
from numpy.linalg import norm
from tqdm import tqdm

# Seed both random sources used in this notebook (the stdlib `random` module
# and NumPy's global generator) so that a fresh run starts from the same state.
random.seed(69)
np.random.seed(69)

In [None]:
from exp_setup import split_into_batches, download_dataset
# Load the dataset through the project helper.
DATA, LABELS = download_dataset()

# Number of devices used across the notebook.
NUM_DEVICES = 16

# Split samples and labels into per-device shards; the requested shard size is
# ceil(n_samples / NUM_DEVICES).
# NOTE(review): presumably this yields NUM_DEVICES shards with a smaller last one — confirm against exp_setup.split_into_batches.
DEVICES_DATA, DEVICES_LABELS = split_into_batches(
    DATA,
    y=LABELS,
    even_split=False,
    batch_size=int(math.ceil(DATA.shape[0] / NUM_DEVICES)),
)

In [None]:
# Row count of the smallest device shard.
MIN_DEVICE_SIZE = min(shard.shape[0] for shard in DEVICES_DATA)
# Batch size handed to the stochastic oracles: a quarter of the smallest shard.
BATCH_SIZE = MIN_DEVICE_SIZE // 4
print(BATCH_SIZE)

In [None]:
# ESTIMATE THE SMOOTHNESS CONSTANT
from exp_setup import L_estimation
# Smoothness estimate computed by the project helper from the per-device data.
PROBLEM_L = L_estimation(DEVICES_DATA, DEVICES_LABELS)

# Regularization parameter, set relative to L as in the paper: lambda = L * 10^-4.
PROBLEM_LAMBDA = PROBLEM_L * 10 ** -4
print(PROBLEM_L, PROBLEM_LAMBDA)

# Learning rate: the reciprocal of the smoothness estimate.
LEARNING_RATE = 1 / PROBLEM_L

In [None]:
from functools import partial
from exp_setup import lr_loss, stochastic_lr_gradient, stochastic_lr_loss, lr_gradient

# Loss and gradient callables with the regularization strength bound.
DETERMINISTIC_FUNTION = partial(lr_loss, lam=PROBLEM_LAMBDA)
DETERMINISTIC_GRADIENT_FUNCTION = partial(lr_gradient, lam=PROBLEM_LAMBDA)

# Stochastic variants: the batch size is bound as well.
STOCHASTIC_FUNCTION = partial(
    stochastic_lr_loss,
    lam=PROBLEM_LAMBDA,
    batch_size=BATCH_SIZE,
)
STOCHASTIC_GRADIENT_FUNCTION = partial(
    stochastic_lr_gradient,
    lam=PROBLEM_LAMBDA,
    batch_size=BATCH_SIZE,
)

In [None]:
import optimization_utilities as opt
# create the set up: x_0 and seed
def set_up(seed: int = 69) -> np.ndarray:
    """Re-seed the random generators and draw a starting point.

    Both `random` and `np.random` are seeded with `seed`, then a standard
    normal column vector with one row per column of `DATA` is sampled.

    Args:
        seed: value used for both generators; changing it mainly changes the
            starting point that is returned.

    Returns:
        Array of shape (DATA.shape[1], 1).
    """
    for seeder in (random.seed, np.random.seed):
        seeder(seed)
    num_features = DATA.shape[1]
    return np.random.randn(num_features, 1)

# Reference value subtracted from the loss in the convergence criterion below;
# computed on the full dataset with the same regularization strength.
# NOTE(review): the helper is named find_x_true but its result is used as a function value — confirm what it returns.
TRUE_MIN = opt.find_x_true(DATA, LABELS, lam=PROBLEM_LAMBDA)
print(TRUE_MIN)

# Stochastic Case

## Stochastic Local Gradient Descent

In [None]:
import matplotlib.pyplot as plt
def plot_iterations(criterions: List[float],
                    start_index: int = 0,
                    end_index: int = -1,
                    plot_label: str = None,
                    x_label: str = None,
                    y_label: str = None,
                    show: bool = True,
                    ):
    """Plot a slice of a sequence of criterion values on the current figure.

    Args:
        criterions: values to plot; the x coordinate of each value is its index.
        start_index: index of the first value to plot.
        end_index: exclusive end of the plotted slice. Negative values count
            from the end, as in Python slicing, so the default of -1 leaves
            out the last value. Values beyond the end are clamped to
            len(criterions).
        plot_label: legend label of the curve; no label is attached when None.
        x_label: x-axis label; 'iteration' when None.
        y_label: y-axis label; 'criterion (log_{10} scale)' when None.
        show: when True, draw the legend and display the figure.
    """
    num_values = len(criterions)
    # Resolve negative indices without a modulo: the modulo form divides by zero
    # on an empty list and maps end_index == len(criterions) to 0 (empty plot).
    if end_index < 0:
        end_index = max(end_index + num_values, 0)
    end_index = min(end_index, num_values)

    x_values = list(range(start_index, end_index))
    y_values = criterions[start_index:end_index]
    if plot_label is None:
        plt.plot(x_values, y_values)
    else:
        plt.plot(x_values, y_values, label=str(plot_label))

    plt.xlabel('iteration' if x_label is None else x_label)
    # The default must depend on y_label itself (it was previously gated on x_label).
    plt.ylabel('criterion (log_{10} scale)' if y_label is None else y_label)

    if show:
        plt.legend()
        plt.show()

In [None]:
# Iteration budget and number of local steps per communication round.
K = 10 ** 5
NUM_LOCAL_STEPS = 200
COMMUNICATION_ROUNDS = K // NUM_LOCAL_STEPS

# Convergence criterion: mean of lr_loss over the devices minus TRUE_MIN.
CRITERION = lambda x: np.mean(
    [
        lr_loss(shard_data, shard_labels, x, PROBLEM_LAMBDA)
        for shard_data, shard_labels in zip(DEVICES_DATA, DEVICES_LABELS)
    ]
) - TRUE_MIN

# Starting point (default seed) and the criterion value at that point.
X0 = set_up()
INITIAL_VALUE = CRITERION(X0)

In [None]:
import optimization_utilities as opt
import importlib
# Reload the module so that recent edits to optimization_utilities are used.
importlib.reload(opt)

# Local GD driven by the stochastic gradient oracle.
local_gd_xpoints, local_gd_criterions = opt.localGD(
    x_0=X0,
    K=K,
    num_local_steps=NUM_LOCAL_STEPS,
    device_data=DEVICES_DATA,
    device_labels=DEVICES_LABELS,
    function=DETERMINISTIC_FUNTION,
    gradient_function=STOCHASTIC_GRADIENT_FUNCTION,
    mode=CRITERION,
    # NOTE(review): LEARNING_RATE is 1 / PROBLEM_L, so this passes PROBLEM_L itself —
    # confirm that localGD expects the inverse step size for gamma_k.
    gamma_k=lambda _: 1 / LEARNING_RATE,
)

# Put the criterion at the starting point in front of the reported values.
local_gd_criterions = [INITIAL_VALUE] + local_gd_criterions
# Convert the criterion values to log10 'scale'.
local_gd_log_criterions = list(map(np.log10, local_gd_criterions))

In [None]:
# Stochastic Local GD curve on its own figure.
plt.figure(figsize=(14, 10))
plot_iterations(
    criterions=local_gd_log_criterions,
    x_label='communication rounds',
    y_label='f(x) - f(*): log 10',
    show=False,
)
# 20 evenly spaced y ticks between the smallest and largest plotted value.
plt.yticks(
    np.linspace(
        np.min(local_gd_log_criterions),
        np.max(local_gd_log_criterions),
        num=20,
    )
)
plt.title(f"Stochastic Local GD with {NUM_DEVICES} devices and {NUM_LOCAL_STEPS} local steps")
plt.show()

## Stochastic ProxSkip

In [None]:
# Skip probability used for ProxSkip in the stochastic experiment.
# NOTE(review): LEARNING_RATE is 1 / PROBLEM_L, so this equals
# sqrt(PROBLEM_LAMBDA * PROBLEM_L) — confirm this is the intended formula and
# that the value is a valid probability.
PROX_SKIP_PROBABILITY = np.sqrt(PROBLEM_LAMBDA / LEARNING_RATE)
# Generous iteration cap for ProxSkip: K squared.
PROX_SKIP_K = K ** 2
print(PROX_SKIP_PROBABILITY)

In [None]:
import optimization_utilities as opt
import importlib
# Reload the module so that recent edits to optimization_utilities are used.
importlib.reload(opt)

# ProxSkip driven by the stochastic gradient oracle; set_up() re-seeds the
# generators and returns the default starting point.
prox_xpoints, prox_criterions = opt.proxSkipFL(
    x_0=set_up(),
    devices_data=DEVICES_DATA,
    devices_labels=DEVICES_LABELS,
    function=DETERMINISTIC_FUNTION,
    gradient_function=STOCHASTIC_GRADIENT_FUNCTION,
    skip_probability=PROX_SKIP_PROBABILITY,
    communication_rounds=COMMUNICATION_ROUNDS,
    max_iterations=PROX_SKIP_K,
    # NOTE(review): this passes PROBLEM_L (LEARNING_RATE is 1 / PROBLEM_L) —
    # confirm that proxSkipFL expects the inverse step size for gamma_k.
    gamma_k=lambda _: 1 / LEARNING_RATE,
    mode=CRITERION,
    report_by_prox=50,
)

# Put the criterion at the starting point in front of the reported values.
prox_criterions = [INITIAL_VALUE] + prox_criterions
# Clamp at 10^-8 before taking the log so that non-positive values stay finite.
prox_log_criterions = [np.log10(max(value, 10 ** -8)) for value in prox_criterions]

In [None]:
# Rebuild the log-scale curves from the raw criterion lists, which already start
# with INITIAL_VALUE. Prepending log10(INITIAL_VALUE) again here would draw the
# first point twice and add one more copy each time this cell is re-run.
local_gd_log_criterions = [np.log10(c) for c in local_gd_criterions]
prox_log_criterions = [np.log10(max(c, 10 ** -8)) for c in prox_criterions]

In [None]:
# Stochastic Local GD and stochastic ProxSkip on one figure.
plt.figure(figsize=(14, 10))

plot_iterations(
    criterions=local_gd_log_criterions,
    plot_label='local GD',
    x_label='communication rounds',
    y_label='f(x) - f(*): log 10',
    show=False,
)
plot_iterations(
    criterions=prox_log_criterions,
    plot_label='Prox Skip',
    x_label='communication rounds',
    y_label='f(x) - f(*): log 10',
    show=False,
)

# 20 evenly spaced y ticks covering the range of both curves.
plt.yticks(
    np.linspace(
        start=min(np.min(local_gd_log_criterions), np.min(prox_log_criterions)),
        stop=max(np.max(local_gd_log_criterions), np.max(prox_log_criterions)),
        num=20,
    )
)
plt.legend()
plt.title(f'S Prox Skip, p:{round(PROX_SKIP_PROBABILITY, 5)}, SLGD: {NUM_DEVICES} devices, {NUM_LOCAL_STEPS} local steps')
plt.show()

# Deterministic Case

In [None]:
# Same iteration budget as the stochastic case, with 30 local steps per round.
K = 100_000
NUM_LOCAL_STEPS = 30
COMMUNICATION_ROUNDS = K // NUM_LOCAL_STEPS

In [None]:
# Re-seed and redraw the starting point, then evaluate the criterion there.
X0 = set_up()
INITIAL_VALUE = CRITERION(X0)
# Bare expression: shown as the cell output.
INITIAL_VALUE

In [None]:
# Local GD driven by the deterministic gradient oracle.
local_gd_xpoints, local_gd_criterions = opt.localGD(
    x_0=X0,
    K=K,
    num_local_steps=NUM_LOCAL_STEPS,
    device_data=DEVICES_DATA,
    device_labels=DEVICES_LABELS,
    function=DETERMINISTIC_FUNTION,
    gradient_function=DETERMINISTIC_GRADIENT_FUNCTION,
    mode=CRITERION,
    # NOTE(review): `1 / 4 * PROBLEM_L` evaluates to PROBLEM_L / 4 (left-to-right
    # evaluation). The ProxSkip probability of this section is built from
    # 4 * PROBLEM_L, so 1 / (4 * PROBLEM_L) may have been intended — confirm
    # against what localGD expects for gamma_k.
    gamma_k=lambda _: 1 / 4 * PROBLEM_L,
)

# Put the criterion at the starting point in front of the reported values.
local_gd_criterions = [INITIAL_VALUE] + local_gd_criterions
# Convert the criterion values to log10 'scale'.
local_gd_log_criterions = list(map(np.log10, local_gd_criterions))

In [None]:
# Skip probability for the deterministic ProxSkip run: sqrt(lambda / (4 * L)).
PROX_SKIP_PROBABILITY = np.sqrt(PROBLEM_LAMBDA / (4 * PROBLEM_L))
# Generous iteration cap for ProxSkip: K squared.
PROX_SKIP_K = K ** 2
# Bare expression: shown as the cell output.
PROX_SKIP_PROBABILITY

In [None]:
# ProxSkip driven by the deterministic gradient oracle; set_up() re-seeds the
# generators and returns the default starting point.
prox_xpoints, prox_criterions = opt.proxSkipFL(
    x_0=set_up(),
    devices_data=DEVICES_DATA,
    devices_labels=DEVICES_LABELS,
    function=DETERMINISTIC_FUNTION,
    gradient_function=DETERMINISTIC_GRADIENT_FUNCTION,
    skip_probability=PROX_SKIP_PROBABILITY,
    communication_rounds=COMMUNICATION_ROUNDS,
    max_iterations=PROX_SKIP_K,
    # NOTE(review): `1 / 4 * PROBLEM_L` evaluates to PROBLEM_L / 4 (left-to-right
    # evaluation), while PROX_SKIP_PROBABILITY above is built from 4 * PROBLEM_L —
    # confirm whether 1 / (4 * PROBLEM_L) was intended for gamma_k.
    gamma_k=lambda _: 1 / 4 * PROBLEM_L,
    mode=CRITERION,
    report_by_prox=50,
)

# Put the criterion at the starting point in front of the reported values.
prox_criterions = [INITIAL_VALUE] + prox_criterions
# NOTE(review): unlike the stochastic ProxSkip run, values are not clamped before
# the log here; a non-positive criterion would give -inf or nan — confirm this cannot occur.
prox_log_criterions = list(map(np.log10, prox_criterions))

In [None]:
# Deterministic Local GD and ProxSkip on one figure.
plt.figure(figsize=(14, 10))

plot_iterations(
    criterions=local_gd_log_criterions,
    plot_label='local GD',
    x_label='communication rounds',
    y_label='f(x) - f(*): log 10',
    show=False,
)
plot_iterations(
    criterions=prox_log_criterions,
    plot_label='Prox Skip',
    x_label='communication rounds',
    y_label='f(x) - f(*): log 10',
    show=False,
)

# 20 evenly spaced y ticks covering the range of both curves.
plt.yticks(
    np.linspace(
        start=min(np.min(local_gd_log_criterions), np.min(prox_log_criterions)),
        stop=max(np.max(local_gd_log_criterions), np.max(prox_log_criterions)),
        num=20,
    )
)
plt.legend()
plt.title(f'Prox Skip, p:{round(PROX_SKIP_PROBABILITY, 5)}, LGD: {NUM_DEVICES} devices, {NUM_LOCAL_STEPS} local steps')
plt.show()