In [5]:
# Implements an optimization proceedure 
# that does not require computing gradients.
# It is based on a gradient free inverse function
# approximation algorithm for univariate functions.
# I test it here on the prime number classification
# problem. For all numbers up to 31. The dataset
# here is pretty small but the model is 100,000 
# parameters or so. I get convergence to about 
# 10^{-6} error in about 25,000 function evaluations
# Results vary from run to run

import numpy as np
import copy
X= [];
np.random.seed()

X.append([0,0,0,0,1])
X.append([0,0,0,1,0])
X.append([0,0,0,1,1])
X.append([0,0,1,0,0])
X.append([0,0,1,0,1])
X.append([0,0,1,1,0])
X.append([0,0,1,1,1])
X.append([0,1,0,0,0])
X.append([0,1,0,0,1])
X.append([0,1,0,1,0])
X.append([0,1,0,1,1])
X.append([0,1,1,0,0])
X.append([0,1,1,0,1])
X.append([0,1,1,1,0])
X.append([0,1,1,1,1])
X.append([1,0,0,0,0])
X.append([1,0,0,0,1])
X.append([1,0,0,1,0])
X.append([1,0,0,1,1])
X.append([1,0,1,0,0])
X.append([1,0,1,0,1])
X.append([1,0,1,1,0])
X.append([1,0,1,1,1])
X.append([1,1,0,0,0])
X.append([1,1,0,0,1])
X.append([1,1,0,1,0])
X.append([1,1,0,1,1])
X.append([1,1,1,0,0])
X.append([1,1,1,0,1])
X.append([1,1,1,1,0])
X.append([1,1,1,1,1])

X=np.array(X)
Y=np.array([1,1,1,0,1,0,1,0,0,0,1,0,1,0,0,0,1,0,1,0,0,0,1,0,0,0,0,0,1,0,1])
W1 = (np.random.random((100,5))-.5)
W2 = (np.random.random((100,100))-.5)
W4 = (np.random.random((100,100))-.5)
W3 = np.random.random((1,100))-.5

def printar(x):
    theS = "[";
    for i in range(0,np.shape(x)[0]):
        theS=theS+"[";
        for j in range(0,np.shape(x)[1]):
            theS=theS+str(x[i,j])+",";
        theS=theS+"]\n";
    theS =theS+"]";
    print(theS)
        

def relu(x):
    return x*(x>=0)

def forwardpass(X,Y,W1,W2,W3,W4):
    ans = 0
    for i in range(0,np.shape(X)[0]):
        h1 = relu(np.dot(W1,X[i,:]));
        h2 = relu(np.dot(W2,h1));
        #h3 = relu(np.dot(W4,h2));
        #h4 = np.dot(W3,h3);
        h4 = np.dot(W3,h2);
        ans += (h4[0]-Y[i])**2
    return ans;
    
import bisect

def nnFunct(x):
    global X
    global Y
    global W1
    global W2
    global W3
    global W4
    global dW1
    global dW2
    global dW3
    global dW4
    return forwardpass(X,Y,W1+x*dW1,W2+x*dW2,W3+x*dW3,W4+dW4*x)

# This function implements a gradient 
# free optimization proceedure for a function
# Of one variable. forward is a function
# pointer, and inverse is an array that 
# stores visited points that parameterizes
# our inverse function approximation
# target is the value we are trying to reach
# with the univariate function, step=0 means
# we will try to go directly to the target, step>0
# means we will take a step of size step toward
# the target. Inverse is assumend to be sorted in 
# order of distance to the target.

def tempF(x):
    return (np.exp(x)-1)**2;

def tempF2(x):
    return np.sin(x);
    
    
def gradFreeOpt(forward, inverse, target=0, step=0.01, randomError=.005):
    # Gets the two closes points to the target and interpolates
    # the inverse function by these points with a line. Gets estimate
    # of input to get target output
    
    # Get current best:
    current = inverse[0,1]
    # get next target
    if step > 0 and step < np.abs(target-current):
        if target < current:
            nextT = -step + current
        if target >=current:
            nextT = current + step
    else:
        nextT = target
    # Compute g inverse estimation
    bob = np.array([[inverse[0,0]**2,inverse[0,0],1],[inverse[1,0]**2,inverse[1,0],1],[inverse[2,0]**2,inverse[2,0],1]])
    sally=np.linalg.solve(bob,inverse[:,1]);
    # If nextT is reachable, since our inverse function approximation is
    # a quadratic function, nextT will be reached at 2 different locations
    # we choose the one closest to the center of our three reference points
    quot = sally[1]**2-(4*sally[0]*(sally[2]-nextT));
    if quot > 0:
        pos = (-sally[1]+np.sqrt(quot))/(2*sally[0])
        neg = (-sally[1]-np.sqrt(quot))/(2*sally[0])
        if np.abs(pos-np.mean(inverse[:,0]))<np.abs(neg-np.mean(inverse[:,0])):
            est =pos;
        else:
            est = neg;
    else:
        # If not reachable set the estimation to the 
        # critical point
        est = -sally[1]/(2*sally[0])
    #slope=(inverse[1,0]-inverse[0,0])/(inverse[1,1]-inverse[0,1])
    #est = slope*(nextT-inverse[0,1])+inverse[0,0]
    
    #print("estimation: "+str(est));
    r=(np.random.randn(1)-.5)*randomError;
    prop = forward(est+r);
    #print("True forward: "+str(prop));
    if np.abs(target -prop) < np.abs(inverse[0,1]-target):
        inverse[1,:]=inverse[0,:];
        inverse[0,:]=[est+r,prop];
    elif np.abs(target-prop) > np.abs(inverse[0,1]-target) and np.abs(target-prop) < np.abs(inverse[1,1]-target):
        inverse[2,:]=inverse[1,:]
        inverse[1,:]=[est+r,prop]
    else :
        inverse[2,:]=[est+r,prop];
    #printar(inverse)
    return inverse



print(str(tempF(-1)));
print(str(tempF(1)));

dW1 = (np.random.random(np.shape(W1))-.5)
dW2 = (np.random.random(np.shape(W2))-.5)
dW3 = np.random.random(np.shape(W3))-.5
dW4 = (np.random.random(np.shape(W4))-.5)
inverse=np.array([[.05,nnFunct(.05)],[-.01,nnFunct(-.01)],[-.05,nnFunct(-.05)]])
printar(inverse)
inverse=inverse[inverse[:,1].argsort()]
printar(inverse)



functevals=0
while np.abs(inverse[0,1]-0) >= .00001 and functevals<100000:
    dW1 = (np.random.random(np.shape(W1))-.5)
    dW2 = (np.random.random(np.shape(W2))-.5)
    dW3 = np.random.random(np.shape(W3))-.5
    dW4 = (np.random.random(np.shape(W4))-.5)
    best2=inverse[0,1];
    abB=np.abs(best2)
    if abB >.05:
        abB=.05
    inverse=np.array([[abB,nnFunct(abB)],[-abB/5,nnFunct(-abB/5)],[-abB,nnFunct(-abB)]])
    inverse=inverse[inverse[:,1].argsort()]
    functevals += 3;
    step=0;
    #counter=0;
    for i in range(0,10):
        best=inverse[0,1]
        inverse = gradFreeOpt(nnFunct,inverse,0,step);
        #if inverse[0,1]==best:
            #counter += 1
        if np.abs(inverse[0,1]) <= .000005:
            break;
        functevals += 1
        #if counter == 15:
            #break
    if np.abs(best2-0)>np.abs(inverse[0,1]-0):
        W1 = W1+inverse[0,0]*dW1;
        W3 = W3+inverse[0,0]*dW3;
        W2 = W2+inverse[0,0]*dW2;
        W4= W4+inverse[0,0]*dW4;
    print(inverse[0,1])
print(functevals)

0.39957640089372803
2.9524924420125593
[[0.05,58.03888736930223,]
[-0.01,61.64284511041864,]
[-0.05,64.84965169784702,]
]
[[0.05,58.03888736930223,]
[-0.01,61.64284511041864,]
[-0.05,64.84965169784702,]
]
35.028593259355326
32.48341465358538
31.860308879036857
30.753712373603705
29.953520473602847
21.79008840478658
18.637262762167545
16.407956106109022
16.183945506008616
15.99163622784181
15.67353442509105
10.527956779670982
10.52139913139614
10.472031333412279
10.165796815176254
10.092899850128454
9.598665681404952
9.262749246510287
9.220285414051723
8.934446559248686
8.933931259830345
8.935668875064819
8.77556710227494
8.780479191803943
8.729415552344047
8.729065520382399
8.715620000016113
8.110097236881407
7.996290191448007
7.87711641736777
7.862832169690015
7.7713276205077
7.662808714105268
7.51042960634516
7.087206165552187
7.069780155161563
7.051022618447171
6.520509916503318
6.492601330432821
6.480532107823184
6.414652891979841
6.419012108594375
6.301137211700336
6.3011343592641

0.22126246184258963
0.2179998256393532
0.21288071683952325
0.21040731549468014
0.20918219426922624
0.20915224707528532
0.20914721025141852
0.2055756668858696
0.20609119796379866
0.20723526864069625
0.21200959156948757
0.20559199245461107
0.20730665449972493
0.20421156911925809
0.20145483261381614
0.20065695595533953
0.1995048716924603
0.1997924996273045
0.19733066729939583
0.1947824567848397
0.19453179373347357
0.19654212006911506
0.194367265455075
0.1955005153816833
0.19313383217634725
0.1912343938837581
0.18812641267004687
0.18811020933967315
0.18883472400997015
0.18782287590718846
0.1876819494602392
0.18721791241421895
0.18715517379921334
0.18678704628852366
0.18657909966607072
0.184433408706895
0.18231351678455826
0.1793476724889289
0.17834404242676444
0.1749578864208448
0.17508409092838395
0.17437904906471155
0.17732917707115953
0.17402696781603613
0.1740863430954539
0.17281319247167476
0.17279999325425907
0.17429228486707163
0.1706043786871031
0.17054047773182618
0.17000731180651

0.02759884485514907
0.02719869864899134
0.02712100855079431
0.028643956503771977
0.026986493479526426
0.028245306604586448
0.026787922492471763
0.02811505124247273
0.030638681600388448
0.026793225153899716
0.025419447709260835
0.02549458462529955
0.02541280210046165
0.02519465306706838
0.024822426766544517
0.0244710695589748
0.024566255155074698
0.024682750693232475
0.025057469299583423
0.024751625921393
0.024639737351136174
0.0246344482828863
0.02483313389795293
0.0247102199227226
0.024004637832885423
0.02335134629306428
0.02315170105549788
0.023220484741434396
0.023084650196452308
0.023084967676677264
0.023071508632757434
0.023166929076414447
0.022415812384087467
0.022446515146612087
0.026490136986232107
0.022612073812066192
0.022425245868127032
0.02267808062288936
0.023133257101524664
0.022448710443397225
0.022181359163652346
0.021950777916424524
0.021945864299583766
0.022118080085353257
0.02194312904127051
0.022465861943274802
0.02316136358401069
0.021889433428521555
0.021943526666

0.0056653061303930385
0.005728601118009668
0.005851481732000555
0.005688585211345585
0.005517436144975099
0.005522422223251351
0.00545687370000197
0.005603965495269876
0.006471824074725693
0.00554203431116024
0.005408422578330049
0.005251665827335387
0.005490882239520305
0.005253997838964704
0.005364774616411077
0.005458190157461671
0.00525550736267174
0.005254572283120579
0.006783385170212679
0.005250222555262438
0.005249111033304064
0.005664898118212942
0.00524787387075123
0.005311869301414573
0.005618284463018236
0.005548031314310736
0.005627628660561754
0.005452673734755375
0.005629451775134247
0.005939319843462021
0.0055503467257760454
0.005546161924792952
0.00554483111944898
0.005737959705992299
0.0058041385117300085
0.005524204515781681
0.005755615572738276
0.0070559817394744385
0.006269727594421063
0.006205667911378904
0.006120887597659961
0.005995515212521039
0.006460310717462349
0.00599970716459008
0.006112298242661812
0.006446933977962348
0.005918914775094647
0.0059159305243

0.0013964649951740292
0.0013817777165410648
0.0013961680332023974
0.0014270156429293597
0.0014552035624043839
0.0014683202874966994
0.0013929493145898632
0.001431158495037001
0.0014175884637831395
0.0014088833602167701
0.0013604519347494716
0.0014441279894965829
0.001367895045279635
0.0013615319663586368
0.0013722402032268547
0.0013608057395663544
0.001404077641373499
0.0013737412126807354
0.0013304591157410105
0.0013301530933237553
0.0014204901530186164
0.0013665394980183604
0.0014190225775268298
0.0013206860928033453
0.0012960262414966628
0.0016108090668668115
0.001351947067346952
0.0013779666779854519
0.0013697145923959393
0.0015648465978297637
0.0013060350682468802
0.0012927435482116547
0.0013071412533236244
0.0013066287867801907
0.001292445237631954
0.0013608429521155515
0.0013081524687610054
0.001384142780706244
0.001305319008471903
0.0012941041882956706
0.001350003453220034
0.0013434891385500512
0.0013585857188107636
0.0012951899738513327
0.001327813650059767
0.00136440075237257

0.0004617408476990757
0.0004581263615302923
0.0004604437118762976
0.00045528849057416894
0.00046337716983695657
0.00045536359155540804
0.0004544482418995428
0.0004534485656598834
0.00048295763150127494
0.0004543039554682909
0.00044944555952163305
0.00044803114580413525
0.0004535865470932954
0.00045208326444838657
0.0004557152218808298
0.0004450880993414506
0.00045599495613377647
0.00044449743032984513
0.00044833587895989457
0.0004385988344617192
0.00043852728964168127
0.00043786416981921733
0.00043412699419083465
0.0004388237149148066
0.00043813252996027427
0.0004435536438181748
0.00043926748400005176
0.00044441354809329525
0.00043781877495073474
0.0004429611853744368
0.00045065765568627094
0.0004487911032384015
0.0004391111000041631
0.00044567675030268776
0.0004416775148881085
0.00044532551895474696
0.0004298939962272203
0.00041838287180506804
0.0004218948914493578
0.0004177966633989862
0.00042776545229769033
0.0004178440461715843
0.0004056816033038554
0.00040493540093152164
0.0004044

9.270469308784245e-05
9.238078186627896e-05
9.084899652150398e-05
9.04623521096434e-05
9.032372681944244e-05
9.032293411266127e-05
9.024093880333175e-05
9.063780767338108e-05
9.059378076338147e-05
9.04872704457627e-05
9.092079002292881e-05
9.071023669563587e-05
8.961675847131812e-05
9.120569197306078e-05
8.977154276203996e-05
8.991940431935255e-05
8.969322281904259e-05
8.911455503818515e-05
8.711584951457237e-05
8.652445272701794e-05
8.360799584498207e-05
8.215514805860146e-05
8.179977278086558e-05
8.215031508442355e-05
8.1159875347175e-05
8.095611462233275e-05
8.131837725886976e-05
7.939945951582299e-05
7.631279560137025e-05
7.648914386178294e-05
7.62184040774932e-05
7.548180481042039e-05
7.547954655043895e-05
7.621449026455053e-05
7.549678383513276e-05
7.510860315625344e-05
7.3878722526308e-05
7.417565324307733e-05
7.402989475075085e-05
7.397212995533583e-05
7.32794710350159e-05
7.345039301587478e-05
7.357975555427182e-05
7.298668521534134e-05
7.331502653685548e-05
7.275821701454667e

1.451911239433563e-05
1.4309419367990487e-05
1.4238415823527766e-05
1.407297468418514e-05
1.3938413961534601e-05
1.3827812328572506e-05
1.3820173420427917e-05
1.353173252686637e-05
1.3540522149160065e-05
1.3544315226033449e-05
1.352700916160948e-05
1.351362158055872e-05
1.343162939714908e-05
1.3459574512932132e-05
1.3362660436634655e-05
1.3232768617497692e-05
1.3245710866842829e-05
1.3220510418455371e-05
1.3218415019227136e-05
1.323236062214121e-05
1.3222587087948978e-05
1.3176670435268143e-05
1.2892710821437263e-05
1.2761618424860772e-05
1.2752846410841335e-05
1.2759881044124365e-05
1.2762463118216024e-05
1.2517666329668748e-05
1.2520645966052006e-05
1.252095770623911e-05
1.2475951620562902e-05
1.2486637442950802e-05
1.2466095463045816e-05
1.2370042582791077e-05
1.2276748737787461e-05
1.2302236252348739e-05
1.2123124235046474e-05
1.197513909223501e-05
1.198281278375539e-05
1.1934081542628438e-05
1.194829587570028e-05
1.1916655091920455e-05
1.1928168424065908e-05
1.1908336029833877e-05