In [3]:
%pylab inline
%reload_ext autoreload
%autoreload 2
import numpy as np
import pickle
import scipy
from methods import BFGS
from oracles import create_log_reg_oracle
from tqdm import tqdm_notebook

Populating the interactive namespace from numpy and matplotlib


In [4]:
# Fix NumPy's global random state so any randomness used by later cells is reproducible.
np.random.seed(seed=42)

In [5]:
# Dataset name (used in output file names) and the path of its data file.
# The path is derived from the name so the two cannot drift apart.
data_name = 'a2a'
dataset_path = './Datasets/{}.txt'.format(data_name)

In [6]:
# Regularization coefficient: passed to the oracle and added as lmb * I to the initial Hessian.
lmb = 0.001

In [7]:
# Problem dimensions:
#   N -- total number of samples read from the dataset (rows of A)
#   n -- number of nodes
#   m -- number of local samples per node
#   d -- number of features (weights)
N = 2265
n = 15
m = 151
d = 123
# Sample (i, j) is stored at row i*m + j of A, and the full Hessian averages over
# all n*m samples with weight 1/N, so the data must split exactly across the nodes.
assert N == n * m, 'N must equal n * m'

In [8]:
# Read the first N samples of the LIBSVM-format dataset into dense arrays:
#   b[i] -- label of sample i (+1 or -1)
#   A[i] -- feature vector of sample i; each "index:value" token sets
#           A[i, index - 1] (indices in the file are 1-based); absent features stay 0.
# Lines after the first N are not read.
b = np.zeros((N,))
A = np.zeros((N, d))

rows_read = 0
with open(dataset_path, 'r') as f:  # `with` closes the file even if parsing raises
    for i, line in enumerate(f):
        if i >= N:
            break  # only N samples are used; stop instead of scanning the rest
        # str.split() with no argument already strips the trailing newline,
        # so a token can never equal '\n'.
        for c in line.split():
            if c == '+1':
                b[i] = 1
            elif c == '-1':
                b[i] = -1
            else:
                c = c.split(':')
                A[i, int(c[0]) - 1] = float(c[1])
        rows_read += 1

# A too-short file would otherwise leave trailing rows with label 0 silently.
assert rows_read == N, 'expected {} samples in {}, found {}'.format(N, dataset_path, rows_read)

In [9]:
# Perturbation added to every coordinate of the optimum to get the starting point.
shift = np.full(d, 0.1)

In [10]:
# Read the precomputed optimum for the current dataset from '<data_name>_optimum.txt'.
# The name is built from data_name so that switching datasets also switches the optimum
# (presumably a vector of length d, since it is added to `shift` below).
# NOTE(review): the file name does not encode lmb -- confirm it was computed for this lmb.
x_opt = np.loadtxt('{}_optimum.txt'.format(data_name))

In [11]:
# Starting point: the optimum moved by `shift` in every coordinate.
x = np.add(x_opt, shift)

In [12]:
# Hessian of the single-sample loss f_{ij} at x.
def Hessian(x, i, j):
    """Return the Hessian of the logistic loss of sample (i, j) at x.

    Sample (i, j) -- local sample j on node i -- is row l = i*m + j of the global
    data (A, b). The Hessian is alpha * a_l a_l^T with
    alpha = b_l**2 * sigma(z) * sigma(-z), where z = b_l * a_l^T x. This matches
    the Hessian of log(1 + exp(-b_l * a_l^T x)).

    Parameters
    ----------
    x : ndarray of shape (d,)
        Point at which the Hessian is evaluated.
    i : int
        Node index.
    j : int
        Local sample index on node i (0 <= j < m).

    Returns
    -------
    ndarray of shape (d, d)
    """
    l = i * m + j
    a_l = A[l]
    z = b[l] * a_l.dot(x)
    # sigma(z) * sigma(-z) == e / (1 + e)**2 with e = exp(-|z|). Using |z| means the
    # exponent is never positive, so exp cannot overflow; the naive exp(-z) form
    # yields inf/inf = nan for large negative margins.
    e = np.exp(-np.abs(z))
    alpha = b[l] ** 2 * e / (1 + e) ** 2
    return alpha * np.outer(a_l, a_l)

# Hessian of the averaged data term of P at x (the lmb * I part is added separately).
def full_Hessian(x):
    """Return (1/N) * sum of Hessian(x, i, j) over nodes i < n and local samples j < m.

    Vectorized: over rows 0 .. n*m-1 of (A, b), the sum of alpha_l * a_l a_l^T
    equals A_loc^T diag(alpha) A_loc. This avoids building n*m dense d x d
    matrices in a Python loop.

    Parameters
    ----------
    x : ndarray of shape (d,)
        Point at which the Hessian is evaluated.

    Returns
    -------
    ndarray of shape (d, d)
    """
    num_rows = n * m
    A_loc = A[:num_rows]
    b_loc = b[:num_rows]
    z = b_loc * A_loc.dot(x)
    # Overflow-safe sigma(z) * sigma(-z): the exponent -|z| is never positive.
    e = np.exp(-np.abs(z))
    alpha = b_loc ** 2 * e / (1 + e) ** 2
    # (A_loc.T * alpha) scales column l of A_loc.T by alpha_l, i.e. A^T diag(alpha).
    return (A_loc.T * alpha).dot(A_loc) / N

In [13]:
# Initial Hessian approximation: exact data-term Hessian at the starting point x
# plus the regularization contribution lmb * I.
H_0 = full_Hessian(x) + lmb * np.identity(d)

In [14]:
# Build the logistic-regression oracle on the data (A, b) with regularization
# coefficient lmb; the BFGS run below is given this oracle.
# NOTE(review): the exact regularizer form is defined in oracles.create_log_reg_oracle --
# confirm it is consistent with the lmb * I term used in H_0.
oracle = create_log_reg_oracle(A, b, lmb)

In [None]:
# Run BFGS from the shifted starting point x with a constant line-search step,
# limited to 100 iterations (max_time presumably in seconds -- confirm in methods.BFGS).
# NOTE(review): H_0 is computed above but not passed to BFGS here; confirm whether
# BFGS accepts an initial Hessian approximation and whether it should receive H_0.
method = BFGS(
    oracle,
    x,
    tolerance=1e-15,
    stopping_criteria=None,
    line_search_options={'method': 'Constant'},
)
method.run(max_iter=100, max_time=10)

In [16]:
# Save the objective values recorded in the BFGS history, one value per line,
# to '<data_name>_bfgs_lmb=<lmb>.txt'.
np.savetxt(
    fname='{}_bfgs_lmb={}.txt'.format(data_name, lmb),
    X=method.hist['func'],
    fmt='%4.16f',
    delimiter='\n',
)