In [1]:
%reload_ext autoreload
%autoreload 2
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import bernoulli
from oracles import LogReg
from methods import Standard_Newton,\
Basic_method, PositiveCase_method, GeneralCase_method
from utils import function_plot_builder, bits_plot_builder, bits_plot_builder1
from methods import CubicMaxNewton, CubicMaxNewtonP
from methods import DINGO, GeneralCase_methodP, PositiveCase_methodP
from easydict import EasyDict
from methods import diana, adiana, dcgd
from utils import loss_logistic, grad

In [2]:
# Seed NumPy's legacy global RNG once, up front, so a Restart & Run All
# reproduces the stochastic runs below (assuming the methods draw from
# np.random — TODO confirm in methods/utils).
SEED = 42
np.random.seed(seed=SEED)

In [3]:
# Dataset to load. The file path is derived from the name so the two cannot
# drift apart when switching datasets.
data_name = 'a2a'
dataset_path = f'./Datasets/{data_name}.txt'

In [4]:
# Regularization coefficient: passed as `reg_coef` to LogReg and as `lamda`
# to the gradient-type methods.
lmb = 0.001

In [5]:
# number of nodes, size of local data, and number of weights
# a2a has N = 2265 samples and d = 123 features. The samples are distributed
# over n nodes with m local samples each, so m is derived from N and n
# instead of being hard-coded, and the split is checked to be exact.
N = 2265
n = 15
m = N // n
d = 123
assert n * m == N, f'N={N} is not divisible by the number of nodes n={n}'

In [6]:
# data reading
# Parse the first N rows of the LIBSVM-format file at `dataset_path` into a
# dense design matrix A (N x d) and a label vector b. Each line has the form
# "<label> <index>:<value> ..." with 1-based feature indices.
b = np.zeros((N,))
A = np.zeros((N, d))

rows_read = 0
with open(dataset_path, 'r') as data_file:  # closed even if parsing fails
    for i, line in enumerate(data_file):
        if i >= N:
            break  # only the first N samples are used; skip the rest of the file
        rows_read += 1
        tokens = line.split()  # split() also strips the trailing newline
        if not tokens:
            continue  # blank line: row i stays all zeros
        label, *features = tokens
        # Map binary LIBSVM labels to {-1, +1}: handles "+1"/"-1" as well as
        # "1"/"0" style files.
        b[i] = 1 if float(label) > 0 else -1
        for feature in features:
            index, value = feature.split(':')
            A[i, int(index) - 1] = float(value)  # LIBSVM indices are 1-based

# Fail loudly instead of silently training on all-zero rows.
if rows_read < N:
    raise ValueError(f'{dataset_path} has only {rows_read} rows, expected N={N}')

In [7]:
# create logistic regression problem
# (A, b) hold the N samples. n, m and d are the node count, local data size
# and number of weights from the sizes cell; presumably LogReg partitions the
# rows across the n nodes — confirm in oracles.LogReg.
logreg = LogReg(
    A=A,
    b=b,
    reg_coef=lmb,
    n=n,
    m=m,
    d=d,
)

In [None]:
# find optimal solution using Newton's method
# Reference solution: run 20 Newton steps starting from the origin, then read
# the result back from `logreg`. Presumably find_optimum stores its final
# iterate on the oracle, and get_optimum returns it — confirm in
# methods.Standard_Newton. The two calls must stay in this order.
# x_opt is the point every gradient-type method below is started near
# (x_opt + shift) and evaluated against (f_opt).
# NOTE(review): n_steps=20 is fixed with no convergence check visible here.
SN = Standard_Newton(logreg)
SN.find_optimum(np.zeros(d), n_steps=20)
x_opt = logreg.get_optimum()

In [9]:
# define shift
# Constant offset added to x_opt to build the starting point of every method.
shift = np.full(d, 0.1)

## Gradient type methods

Note that the gradient-type methods below return the number of bits sent by a single node. To obtain the total number of bits sent by all nodes, multiply the returned value by:

- $2n$ for ADIANA
- $n$ for DIANA
- $n$ for DCGD

In [10]:
class args(EasyDict):
    """Hyper-parameter bundle passed (as ``arg``) to adiana, diana, dcgd and loss_logistic.

    Subclasses EasyDict, so every field is readable both as ``arg.name`` and
    ``arg['name']``. Later cells may also overwrite fields after construction.

    Parameters whose use is visible in this notebook:
        data_name: dataset name, e.g. 'a2a'.
        T: iteration budget (cells pass ``max_iter``).
        node: number of nodes (cells pass ``n``).
        L: smoothness estimate computed from A.
        lamda: regularization coefficient (cells pass ``lmb``).
        comp_method: name of the compression operator. Defaults to 'no_comp';
            cells use 'rand_sparse', 'natural_comp' and 'rand_dithering'.
        r, s: compressor parameters (cells use r = d/4, s = sqrt(d)).
        plotn: cells use 200; presumably a logging/plotting interval — confirm.

    The remaining arguments (eta, alpha, theta_1, theta_2, gamma, beta, prob,
    ID, s_level) are method-specific settings. Their meaning is defined in
    ``methods`` and is not visible here.

    NOTE(review): do not add class attributes or extra methods. easydict's
    __init__ copies non-dunder class attributes into the dict (confirm for the
    installed version).
    """

    def __init__(self, data_name, T, node, L, lamda,
                 eta=0.05, alpha=0.5, theta_1=0.25, theta_2=0.5, gamma=0.5,
                 beta=0.95, prob=1, comp_method='no_comp',
                 r=None, s=None, plotn=100, ID=1, s_level=10):
        super().__init__()
        # Assignment order is preserved; it is also the resulting dict key order.
        fields = (
            ('ID', ID),
            ('data_name', data_name),
            ('T', T),
            ('plotn', plotn),
            ('node', node),
            ('eta', eta),
            ('L', L),
            ('lamda', lamda),
            ('alpha', alpha),
            ('theta_1', theta_1),
            ('theta_2', theta_2),
            ('gamma', gamma),
            ('beta', beta),
            ('prob', prob),
            ('comp_method', comp_method),
            ('r', r),
            ('s', s),
            ('s_level', s_level),
        )
        for key, value in fields:
            # setattr dispatches to EasyDict.__setattr__, exactly like
            # ``self.<key> = value``, so the value is mirrored into the dict.
            setattr(self, key, value)

##### ADIANA

In [11]:
# find estimation of L
# If the data term is the mean logistic loss over the N rows of A, its Hessian
# is bounded by A^T A / (4N), so L = lambda_max(A^T A / N) / 4.
# NOTE(review): lmb is not added here; confirm whether the methods expect L to
# include the regularizer.
H = np.dot(A.T, A) / N          # feature second-moment matrix
temp = np.linalg.eigvalsh(H)    # eigenvalues in ascending order
L = abs(temp[-1]) / 4           # largest eigenvalue, scaled by 1/4

In [12]:
# define parameters of methods
# ADIANA with random sparsification. The overrides go straight to the
# constructor instead of being patched onto `arg` afterwards.
# NOTE(review): r = d/4 is a float (30.75 for d = 123); confirm the compressor
# accepts a non-integer r.
max_iter = 10000
comp_methods = ['rand_sparse']
arg = args(data_name, max_iter, n, L, lmb,
           comp_method=comp_methods[0],
           r=d / 4,
           s=np.sqrt(d),
           plotn=200)

In [None]:
# Start from the optimum perturbed by `shift`. f(x_opt) is passed as f_opt,
# presumably to report suboptimality and to stop once it drops below `tol`.
# com_bits_adiana counts bits for a single node; multiply by 2n for all nodes.
x = x_opt + shift
f_opt_adiana = loss_logistic(A, b, x_opt, arg)
loss_adiana, com_bits_adiana = adiana(A, b, x, arg, f_opt=f_opt_adiana, tol=1e-5)

##### DIANA

In [14]:
# find estimation of L
# If the data term is the mean logistic loss over the N rows of A, its Hessian
# is bounded by A^T A / (4N), so L = lambda_max(A^T A / N) / 4.
# NOTE(review): lmb is not added here; confirm whether the methods expect L to
# include the regularizer.
H = np.dot(A.T, A) / N          # feature second-moment matrix
temp = np.linalg.eigvalsh(H)    # eigenvalues in ascending order
L = abs(temp[-1]) / 4           # largest eigenvalue, scaled by 1/4

In [15]:
# define parameters of methods
# DIANA with natural compression. The overrides go straight to the
# constructor instead of being patched onto `arg` afterwards.
# NOTE(review): max_iter = 10 here (10000 for ADIANA), together with tol=1e-2
# in the run cell, looks like a debugging leftover — confirm before comparing
# methods.
max_iter = 10
comp_methods = ['natural_comp']
arg = args(data_name, max_iter, n, L, lmb,
           comp_method=comp_methods[0],
           r=d / 4,
           s=np.sqrt(d),
           plotn=200)

In [None]:
# Start from the optimum perturbed by `shift`. f(x_opt) is passed as f_opt,
# presumably to report suboptimality and to stop once it drops below `tol`.
# com_bits_diana counts bits for a single node; multiply by n for all nodes.
x = x_opt + shift
f_opt_diana = loss_logistic(A, b, x_opt, arg)
loss_diana, com_bits_diana = diana(A, b, x, arg, f_opt=f_opt_diana, tol=1e-2)

##### DCGD

In [17]:
# find estimation of L
# If the data term is the mean logistic loss over the N rows of A, its Hessian
# is bounded by A^T A / (4N), so L = lambda_max(A^T A / N) / 4.
# NOTE(review): lmb is not added here; confirm whether the methods expect L to
# include the regularizer.
H = np.dot(A.T, A) / N          # feature second-moment matrix
temp = np.linalg.eigvalsh(H)    # eigenvalues in ascending order
L = abs(temp[-1]) / 4           # largest eigenvalue, scaled by 1/4

In [18]:
# define parameters of methods
# DCGD with random dithering. The overrides go straight to the constructor
# instead of being patched onto `arg` afterwards.
# NOTE(review): max_iter = 10 here (10000 for ADIANA), together with tol=1e-2
# in the run cell, looks like a debugging leftover — confirm before comparing
# methods.
max_iter = 10
comp_methods = ['rand_dithering']
arg = args(data_name, max_iter, n, L, lmb,
           comp_method=comp_methods[0],
           r=d / 4,
           s=np.sqrt(d),
           plotn=200)

In [None]:
# Start from the optimum perturbed by `shift`. f(x_opt) is passed as f_opt,
# presumably to report suboptimality and to stop once it drops below `tol`.
# com_bits_dcgd counts bits for a single node; multiply by n for all nodes.
x = x_opt + shift
f_opt_dcgd = loss_logistic(A, b, x_opt, arg)
loss_dcgd, com_bits_dcgd = dcgd(A, b, x, arg, f_opt=f_opt_dcgd, tol=1e-2)