In [1]:
using Random
using ScikitLearn 
@sk_import svm:SVR
using ScikitLearn.GridSearch: GridSearchCV
using ScikitLearn: fit!, predict
using ScikitLearn
@sk_import metrics:mean_squared_error
import Distributions: Uniform 
using Distributions
using LinearAlgebra
using MLDataUtils
using PyCall

In [2]:
using Plots

In [129]:
using CSV
using DataFrames

In [3]:
function random_simplex(dim, M)
    t = rand(1:M, dim-1)
    t = append!(t,M)
    t_sort = sort(t)
    p = [t_sort[i+1]- t_sort[i] for i in 1:(dim-1)]
    p = append!(p,t_sort[1])
    shuffle!(p)/M
end
   

function opt_solution_oracle(prob_vec)
    p_plus = prob_vec[1]
    p_nega = prob_vec[2]
    out = min(log(p_plus/p_nega), 10)
    out = max(out, -10)
end

function softmax(x)
#     max = maximum(x)
#     e_x = collect(exp(elem -max) for elem in x)
#     e_x/sum(e_x)
    exp.(x)./sum(exp.(x))
end

function cost_func(w,z)
    log(1+exp(-z*w))
end

function grad_cost_func(w,z)
    -z/(1+ exp(z*w))
end

grad_cost_func (generic function with 1 method)

# Generate samples and fit the approximate oracle

In [4]:
k =2 # dim of prob
d = 1 # dim of solution
n1 = 1000
p_samples = collect(random_simplex(2, n1) for i in 1:n1)
w_samples = collect(opt_solution_oracle(p) for p in p_samples);

idx = shuffle(1:n1)
idx1 = idx[1:Int(0.8*n1)]
idx2 = idx[Int(0.8*n1)+1:end]
p_train = p_samples[idx1]
w_train = w_samples[idx1]
p_test = p_samples[idx2]
w_test = w_samples[idx2];

In [6]:
py"""
from sklearn.svm import SVR
import numpy as np
import numdifftools as nd
#from sklearn.grid_search import GridSearchCV
from sklearn.model_selection import learning_curve,GridSearchCV
from sklearn.metrics import mean_squared_error
def fit_sol_map_SVR(df_p_train, df_w_train, df_p_test, df_w_test):
#     X = df_p_train.values
    X = df_p_train
#     y1 = df_w_train['w1'].values
    y1 = df_w_train
    svr = GridSearchCV(SVR(kernel='rbf'), cv=5,
                   param_grid={"C": [1e-1, 1e0, 1e1, 1e2],
                               "gamma": np.logspace(-2, 2, 10)})
    svr_model1 = svr.fit(X, y1)
    w1_pred = svr.predict(df_p_test)
    print('w1 Mean squared error:', mean_squared_error(df_w_test, w1_pred))
    return svr_model1

def approx_sol_oracle(prob_test, model1, grad=True):
    dim_prob = 2
    fun1 = lambda x: model1.predict(x.reshape(1, dim_prob)).reshape(-1,1)
    w1 = fun1(prob_test)
    w = np.array([w1]).reshape(-1,1)
    if grad == True:
        grad1 = nd.Gradient(fun1)(np.array([[list(prob_test)]]))
        gradw = np.array([grad1])
        return w, gradw
    else:
        return w
def approx_sol_oracle2(prob_test, model1, grad=False):
    dim_prob = 2
    fun1 = lambda x: model1.predict(x.reshape(1, dim_prob)).reshape(-1,1)
    w1 = fun1(prob_test)
    w = np.array([w1]).reshape(-1,1)
    if grad == True:
        grad1 = nd.Gradient(fun1)(np.array([[list(prob_test)]]))
        gradw = np.array([grad1])
        return w, gradw
    else:
        return w
"""

In [7]:
svr1 = py"fit_sol_map_SVR"(p_train, w_train, p_test, w_test)

PyObject GridSearchCV(cv=5, estimator=SVR(),
             param_grid={'C': [0.1, 1.0, 10.0, 100.0],
                         'gamma': array([1.00000000e-02, 2.78255940e-02, 7.74263683e-02, 2.15443469e-01,
       5.99484250e-01, 1.66810054e+00, 4.64158883e+00, 1.29154967e+01,
       3.59381366e+01, 1.00000000e+02])})

## true value of B

In [8]:
Random.seed!(123)
# B_true = rand(0:5,2,6)
B_true = [4  1  0  3  2  5
 2  1  4  4  2  1]

2×6 Array{Int64,2}:
 4  1  0  3  2  5
 2  1  4  4  2  1

# useful functions for nonparametric prescriptive benchmarks 

In [9]:
function grad_prob(B,x)
    temp = 1/(exp(dot(B[1,:]-B[2,:],x)) + exp(dot(B[2,:]-B[1,:],x)))
    [[temp*x, -temp*x], [-temp*x, temp*x]]
end

grad_prob (generic function with 1 method)

In [10]:
function get_batch_index(sample_size, batch_size)
    rand(1:sample_size, batch_size)
end

function get_knn_count(x, x_train,z_train, k)
    num_x = size(x_train)[1]
    dist_list = collect(norm(x - x_train[i,:]) for i in 1:num_x )  
        idx = sortperm(dist_list)[1:k]
        vec = z_train[idx]
        n_plus = length(vec[vec.>0])
        n_minus = length(vec[vec.<0])
        return n_plus, n_minus
end

    function get_kernel_count(x,x_train,z_train, h)
        num_x = size(x_train)[1]
        weight = collect(exp(-(norm(x - x_train[i,:])/h)^2) for i in 1:num_x)
        n_plus = 0
        n_minus = 0
        for i in 1:num_x
            if z_train[i] >0
                n_plus += weight[i]
            else
                n_minus += weight[i]
            end
        end
        return n_plus, n_minus
end 
    

get_kernel_count (generic function with 1 method)

In [11]:
py"""
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
import numpy as np
def get_count_tree(x,x_train, z_train, dep =3 ):
    clf = DecisionTreeClassifier(random_state=0, max_depth = dep)
    clf.fit(x_train,z_train)
    train_nodes = clf.apply(x_train)
    node = clf.apply([x])
    index = np.where(train_nodes == node[0])
    n_plus = 0
    n_minus = 0
#     print(index)
    for ind in index[0]:
        if z_train[ind] > 0:
            n_plus += 1
#             print("plus")
        elif z_train[ind] < 0:
            n_minus += 1
#             print("minus")
    return n_plus, n_minus
# def get_count_rf(x,x_train, z_train, dep =2 ):
#     clf = RandomForestClassifier(random_state=0, max_depth = dep)
#     clf.fit(x_train,z_train)
#     train_nodes = clf.apply(x_train)
#     node = clf.apply([x])
#     index = np.where(train_nodes == node[0])
#     n_plus = 0
#     n_minus = 0
# #     print(index)
#     for ind in index[0]:
#         if z_train[ind] > 0:
#             n_plus += 1
# #             print("plus")
#         elif z_train[ind] < 0:
#             n_minus += 1
# #             print("minus")
#     return n_plus, n_minus
"""

# Generate samples

In [12]:
p = 5
M = 1
num_x0 = 5000
num_t = 2000
num = num_x0 + num_t

7000

In [34]:
x_samples = M*rand(Uniform(0,1), (num,p)); #x without intercept term

In [35]:
x_samples1 = zeros((num,p+1)); #x with intercept term

In [36]:
z_samples = zeros(num)
for i in 1:num
    x_t = append!(x_samples[i,:],1)
    x_samples1[i,:] = x_t
    p_true = softmax(B_true*x_t)
#     println(p_true)
    bern = 2*rand(Bernoulli(p_true[1]),1)[1] - 1
    z_samples[i]= bern    
end   

[0.963205675946752, 0.03679432405324794]
[0.8793990421319782, 0.12060095786802182]
[0.9023120357392462, 0.09768796426075367]
[0.9238910648139094, 0.07610893518609065]
[0.5850753793492612, 0.41492462065073865]
[0.8529636540177036, 0.1470363459822964]
[0.9355307426737441, 0.06446925732625584]
[0.9724880003837815, 0.027511999616218495]
[0.9873513897876895, 0.012648610212310427]
[0.9197850152569323, 0.08021498474306765]
[0.932767118570597, 0.06723288142940291]
[0.9908263327062099, 0.009173667293790178]
[0.9243900535289932, 0.07560994647100687]
[0.9536312001603112, 0.04636879983968877]
[0.8840204736176492, 0.11597952638235076]
[0.9753740561503856, 0.024625943849614394]
[0.9906049814293711, 0.009395018570628862]
[0.8869030051189486, 0.11309699488105142]
[0.9707968956222536, 0.02920310437774635]
[0.9565522622660054, 0.043447737733994504]
[0.9766194363146324, 0.023380563685367537]
[0.7848293520539229, 0.21517064794607718]
[0.9632181266668639, 0.03678187333313596]
[0.985873799257475, 0.01412620

[0.9380014417116342, 0.06199855828836582]
[0.9305712541685319, 0.06942874583146807]
[0.9707452015063625, 0.029254798493637487]
[0.7933430288479881, 0.20665697115201187]
[0.8554429463757933, 0.14455705362420662]
[0.8709194731893233, 0.12908052681067678]
[0.9472795482646091, 0.05272045173539085]
[0.7394358337626493, 0.2605641662373506]
[0.9823048274793331, 0.01769517252066686]
[0.9471240051391381, 0.05287599486086197]
[0.9603705732698197, 0.03962942673018028]
[0.9906181941608074, 0.009381805839192472]
[0.9142214693448597, 0.08577853065514032]
[0.8927447148491977, 0.10725528515080236]
[0.9602973325495556, 0.03970266745044431]
[0.9064318275272757, 0.09356817247272428]
[0.8889191090817871, 0.11108089091821302]
[0.8936534345272364, 0.10634656547276347]
[0.653347040825315, 0.34665295917468497]
[0.7372613793720044, 0.2627386206279957]
[0.9821171675037631, 0.01788283249623679]
[0.922230697860641, 0.07776930213935898]
[0.9831720954164905, 0.01682790458350941]
[0.8403326518194985, 0.1596673481805

[0.6913420358870417, 0.3086579641129583]
[0.819717083893624, 0.1802829161063761]
[0.9738531295300811, 0.026146870469919065]
[0.9933697450562963, 0.0066302549437037665]
[0.9537971566851623, 0.046202843314837795]
[0.9610625774978088, 0.03893742250219119]
[0.979848290272058, 0.02015170972794201]
[0.9550827267381486, 0.04491727326185133]
[0.9934220865296525, 0.006577913470347547]
[0.7436138171964346, 0.2563861828035653]
[0.9411132564891852, 0.05888674351081483]
[0.8378964986828425, 0.16210350131715756]
[0.9432092129382283, 0.05679078706177167]
[0.9466705763346877, 0.0533294236653123]
[0.8252872347408368, 0.1747127652591633]
[0.9903897235446367, 0.009610276455363217]
[0.9704826827220059, 0.029517317277994105]
[0.9199840372642091, 0.08001596273579083]
[0.9223007025172498, 0.07769929748275019]
[0.7860257230299466, 0.21397427697005353]
[0.9595875973899078, 0.04041240261009221]
[0.8972926210157527, 0.10270737898424731]
[0.8756222725024451, 0.12437772749755488]
[0.8930466951234614, 0.10695330487

[0.5055469410094424, 0.49445305899055764]
[0.8774500436332553, 0.12254995636674464]
[0.9389351494902719, 0.06106485050972816]
[0.9750757542634743, 0.024924245736525758]
[0.865421369983906, 0.13457863001609405]
[0.9494348318248641, 0.050565168175135955]
[0.9935047308521976, 0.0064952691478023855]
[0.9644489405190176, 0.03555105948098241]
[0.979722273424001, 0.02027772657599897]
[0.9760212529847585, 0.02397874701524144]
[0.45473420949066, 0.5452657905093401]
[0.8500706995267131, 0.1499293004732869]
[0.9759505081178872, 0.02404949188211286]
[0.9627051228816959, 0.037294877118304114]
[0.9785013368382227, 0.02149866316177732]
[0.7433551518359761, 0.2566448481640239]
[0.9890493853155865, 0.010950614684413572]
[0.956725622595656, 0.043274377404344117]
[0.7359912475306561, 0.26400875246934374]
[0.8496757032686619, 0.15032429673133801]
[0.9155982975661773, 0.08440170243382272]
[0.9155206335276359, 0.08447936647236398]
[0.8096416377548945, 0.1903583622451055]
[0.9769321492088455, 0.0230678507911

[0.9582857028723457, 0.041714297127654364]
[0.949254576857085, 0.05074542314291501]
[0.9443413619541718, 0.05565863804582819]
[0.9235076714358783, 0.07649232856412172]
[0.9676072988699144, 0.032392701130085605]
[0.8975806773342496, 0.1024193226657504]
[0.787311513461025, 0.21268848653897499]
[0.988172319054557, 0.011827680945442943]
[0.6567136448312194, 0.3432863551687806]
[0.9791095001426963, 0.020890499857303752]
[0.9638810691928418, 0.03611893080715815]
[0.9786143045025006, 0.021385695497499295]
[0.929008121670104, 0.07099187832989592]
[0.7050154961265704, 0.29498450387342956]
[0.6842043630906465, 0.31579563690935347]
[0.7591703463137794, 0.24082965368622047]
[0.832009686183138, 0.16799031381686216]
[0.9609920678096775, 0.03900793219032255]
[0.9610895943457001, 0.03891040565429997]
[0.7010550485606593, 0.29894495143934086]
[0.9489899398777557, 0.05101006012224428]
[0.9409935476726446, 0.059006452327355303]
[0.9722949550346227, 0.027705044965377287]
[0.8778836898983685, 0.12211631010

[0.9965070401954623, 0.003492959804537689]
[0.3871349250332766, 0.6128650749667234]
[0.8178587583573835, 0.18214124164261639]
[0.9841849785973261, 0.01581502140267385]
[0.568244527444423, 0.431755472555577]
[0.5912429740047457, 0.4087570259952544]
[0.987598319279964, 0.012401680720036026]
[0.8910434709090275, 0.10895652909097255]
[0.9709852227101509, 0.029014777289849077]
[0.9068503811009617, 0.0931496188990382]
[0.9121245399561274, 0.08787546004387255]
[0.9800366246450454, 0.01996337535495471]
[0.9462630925464177, 0.05373690745358237]
[0.9860181914506196, 0.013981808549380319]
[0.7662369390513013, 0.23376306094869873]
[0.9563526038017456, 0.04364739619825445]
[0.9873188526690472, 0.01268114733095273]
[0.8575022366205384, 0.1424977633794617]
[0.59704966848596, 0.40295033151403997]
[0.9346983761852399, 0.06530162381476007]
[0.9703673657358749, 0.02963263426412499]
[0.9406034380535221, 0.05939656194647789]
[0.9391733130559516, 0.06082668694404841]
[0.9833381064426607, 0.01666189355733934

[0.5415870249579497, 0.45841297504205025]
[0.5995015109666898, 0.4004984890333103]
[0.9959012712010266, 0.004098728798973406]
[0.9785413565438785, 0.021458643456121432]
[0.9782928979987804, 0.021707102001219544]
[0.9189885567040651, 0.08101144329593482]
[0.9240040891045076, 0.07599591089549239]
[0.9402709042380017, 0.059729095761998204]
[0.5833005091212866, 0.4166994908787134]
[0.798984554476858, 0.20101544552314202]
[0.9543194066679602, 0.0456805933320398]
[0.9933348059405799, 0.00666519405942004]
[0.954929733843954, 0.045070266156046104]
[0.954941608377339, 0.045058391622661016]
[0.9216287139047418, 0.07837128609525819]
[0.9884280919146436, 0.011571908085356313]
[0.9104162657791723, 0.08958373422082765]
[0.9728278949541094, 0.02717210504589068]
[0.9649304002100849, 0.03506959978991516]
[0.9823463360339947, 0.017653663966005432]
[0.80591559283811, 0.19408440716188993]
[0.9440046140600005, 0.055995385939999566]
[0.9547588273107313, 0.04524117268926881]
[0.973226128909, 0.02677387109100

[0.8900749701590008, 0.10992502984099914]
[0.5836811203964055, 0.41631887960359437]
[0.8973470162499859, 0.10265298375001396]
[0.42191259989398666, 0.5780874001060134]
[0.9706042591008992, 0.029395740899100707]
[0.4317567800048574, 0.5682432199951427]
[0.7887677069141573, 0.2112322930858428]
[0.8645843940309071, 0.1354156059690928]
[0.9088636200755823, 0.09113637992441773]
[0.7949223646014355, 0.20507763539856455]
[0.5915197391765259, 0.4084802608234742]
[0.8948194561324997, 0.10518054386750023]
[0.9678784265520047, 0.032121573447995246]
[0.9459155246309123, 0.0540844753690877]
[0.9253420981043193, 0.07465790189568057]
[0.9756480192324977, 0.024351980767502373]
[0.79772586770697, 0.20227413229303015]
[0.8486473345964082, 0.15135266540359174]
[0.7111552519568092, 0.2888447480431909]
[0.9779598356413274, 0.022040164358672704]
[0.8794381290896608, 0.12056187091033921]
[0.879973444466523, 0.1200265555334769]
[0.7015500011075803, 0.29844999889241974]
[0.9796879282974343, 0.02031207170256567

[0.906279817785396, 0.09372018221460408]
[0.9264657416621015, 0.07353425833789849]
[0.3791684839327926, 0.6208315160672074]
[0.9636224976443791, 0.03637750235562078]
[0.9785971842809106, 0.021402815719089394]
[0.9774682037906371, 0.022531796209362877]
[0.5081977580337486, 0.49180224196625144]
[0.8616261833258737, 0.13837381667412638]
[0.9804507825112303, 0.019549217488769786]
[0.9928768410936786, 0.007123158906321378]
[0.8802053695092373, 0.11979463049076258]
[0.8510106123024879, 0.14898938769751227]
[0.9844042659093395, 0.015595734090660496]
[0.8848340456029983, 0.11516595439700163]
[0.8642880177984955, 0.1357119822015045]
[0.9505383015825957, 0.049461698417404346]
[0.95813789576518, 0.041862104234819994]
[0.8776123617771167, 0.1223876382228834]
[0.9587284754158015, 0.041271524584198555]
[0.7110403022062515, 0.28895969779374847]
[0.6437979144710749, 0.356202085528925]
[0.7701895858938991, 0.22981041410610084]
[0.8622319358902689, 0.13776806410973105]
[0.9727323513950952, 0.02726764860

[0.8209037224470886, 0.17909627755291135]
[0.8987836906152789, 0.10121630938472116]
[0.9145038576608827, 0.08549614233911736]
[0.9107466039980447, 0.0892533960019553]
[0.5590988079981432, 0.44090119200185696]
[0.9869287175026519, 0.013071282497348086]
[0.4785452274965969, 0.5214547725034031]
[0.970953673220075, 0.029046326779924993]
[0.700332843553863, 0.29966715644613695]
[0.9938493112323359, 0.006150688767664042]
[0.3691985170461755, 0.6308014829538244]
[0.965250139808363, 0.034749860191637054]
[0.9718093705949996, 0.02819062940500033]
[0.8919611359634847, 0.10803886403651526]
[0.9450231830520973, 0.05497681694790269]
[0.9862742475218779, 0.013725752478122147]
[0.5513072929782644, 0.4486927070217356]
[0.9897791728708508, 0.010220827129149113]
[0.9585933548844762, 0.041406645115523895]
[0.9688264208968675, 0.031173579103132495]
[0.9827608027481008, 0.017239197251899183]
[0.9642791272722315, 0.03572087272776839]
[0.6120594382419334, 0.3879405617580665]
[0.7655181522564818, 0.2344818477

[0.7825536778007509, 0.21744632219924906]
[0.8189972746751447, 0.18100272532485528]
[0.9801350865339201, 0.019864913466079822]
[0.7636255566916926, 0.23637444330830748]
[0.9725291433708595, 0.02747085662914057]
[0.8821499423669817, 0.11785005763301828]
[0.9956970130085785, 0.004302986991421599]
[0.8529359172739964, 0.14706408272600352]
[0.8804944549997596, 0.11950554500024031]
[0.4405638741754802, 0.5594361258245198]
[0.8918717388802164, 0.1081282611197836]
[0.9639950946372313, 0.03600490536276863]
[0.48019582771629293, 0.5198041722837071]
[0.9347213020999816, 0.06527869790001838]
[0.8159657789702368, 0.1840342210297631]
[0.4610939828778203, 0.5389060171221797]
[0.9544256046991644, 0.04557439530083562]
[0.9232143921451319, 0.07678560785486807]
[0.9204688463671771, 0.07953115363282304]
[0.9669408637123408, 0.03305913628765929]
[0.8855775164716649, 0.11442248352833508]
[0.9157542465355121, 0.08424575346448793]
[0.9659314614975844, 0.034068538502415646]
[0.9828190284832179, 0.017180971516

[0.9791557502310125, 0.02084424976898745]
[0.9072634323296777, 0.09273656767032241]
[0.986697307284742, 0.013302692715258007]
[0.906353240534405, 0.09364675946559507]
[0.9398695357777791, 0.060130464222220924]
[0.9343963002619801, 0.06560369973802001]
[0.9475770287498383, 0.052422971250161744]
[0.7962984596415034, 0.2037015403584966]
[0.9221928865256415, 0.07780711347435844]
[0.9825739015293663, 0.017426098470633773]
[0.7891402240290598, 0.21085977597094016]
[0.5278594458005657, 0.47214055419943424]
[0.7597098355114474, 0.24029016448855256]
[0.9601231444807082, 0.03987685551929173]
[0.8950937780709277, 0.1049062219290723]
[0.7428553619682954, 0.25714463803170456]
[0.9760553546879599, 0.023944645312040123]
[0.9319233051161734, 0.06807669488382666]
[0.9212019410339567, 0.07879805896604324]
[0.9684878289148193, 0.031512171085180694]
[0.9735846661397888, 0.026415333860211106]
[0.7962945398681828, 0.20370546013181717]
[0.8379280842419211, 0.16207191575807894]
[0.9735706796019925, 0.02642932

[0.43076283758853906, 0.5692371624114608]
[0.963174471567186, 0.03682552843281404]
[0.9900213625766113, 0.009978637423388724]
[0.8796446862915391, 0.1203553137084608]
[0.9590424046118984, 0.04095759538810153]
[0.8646953832939227, 0.13530461670607727]
[0.8283255992850561, 0.1716744007149439]
[0.866196020948362, 0.13380397905163793]
[0.5677300198017374, 0.4322699801982625]
[0.6824465310912153, 0.31755346890878466]
[0.9693845303623235, 0.030615469637676525]
[0.9343042691080281, 0.06569573089197177]
[0.45274241474188903, 0.547257585258111]
[0.7922470163855214, 0.20775298361447864]
[0.7228364868733914, 0.27716351312660853]
[0.9037416646137038, 0.09625833538629634]
[0.9724940751102833, 0.027505924889716715]
[0.9458092438294647, 0.05419075617053533]
[0.9759120450165554, 0.024087954983444612]
[0.8682628424757036, 0.13173715752429635]
[0.9400917356763798, 0.059908264323620154]
[0.7071813308551682, 0.29281866914483184]
[0.6974658394242536, 0.30253416057574645]
[0.8410203923574945, 0.158979607642

[0.9474956949077178, 0.0525043050922822]
[0.7470805430748936, 0.2529194569251064]
[0.8907130113749957, 0.10928698862500416]
[0.847573478054571, 0.15242652194542902]
[0.9359749491078783, 0.06402505089212181]
[0.7642440970585804, 0.2357559029414196]
[0.9469197476846314, 0.05308025231536861]
[0.961858008688674, 0.03814199131132594]
[0.9909031602061247, 0.009096839793875378]
[0.8616915956094104, 0.13830840439058956]
[0.7803203434720433, 0.21967965652795673]
[0.9380274803912042, 0.061972519608795865]
[0.9839106929480392, 0.016089307051960745]
[0.7501642316184142, 0.24983576838158575]
[0.8533920927484845, 0.14660790725151554]
[0.9832414590474994, 0.016758540952500622]
[0.9041315676819623, 0.09586843231803768]
[0.9750129528390084, 0.024987047160991576]
[0.8693361170640054, 0.1306638829359947]
[0.8185199148214221, 0.18148008517857792]
[0.9107096557538179, 0.08929034424618205]
[0.9932330176889064, 0.006766982311093649]
[0.7781715540590282, 0.22182844594097176]
[0.9507640757976541, 0.04923592420

[0.9213264885400141, 0.07867351145998598]
[0.9855079348868735, 0.01449206511312654]
[0.9833633236678434, 0.016636676332156666]
[0.7496820703723891, 0.25031792962761085]
[0.9770672811119386, 0.02293271888806134]
[0.9031192655384828, 0.09688073446151726]
[0.97944823048383, 0.02055176951616991]
[0.5769206945293331, 0.4230793054706668]
[0.8226429749587302, 0.17735702504126968]
[0.5782307592726655, 0.42176924072733446]
[0.5431004693723349, 0.456899530627665]
[0.9920140510113892, 0.007985948988610759]
[0.9376633072031993, 0.062336692796800694]
[0.9284441463769357, 0.07155585362306435]
[0.9788084552149227, 0.021191544785077322]
[0.8739734965761918, 0.12602650342380822]
[0.9913367167683335, 0.00866328323166652]
[0.9706687963050247, 0.029331203694975258]
[0.9851681397411624, 0.014831860258837707]
[0.9474811242412469, 0.05251887575875324]
[0.9215140437815419, 0.0784859562184581]
[0.9678400229952252, 0.032159977004774666]
[0.986191065703438, 0.013808934296562112]
[0.9911135581656523, 0.0088864418

[0.9752925320847771, 0.02470746791522289]
[0.986781919606994, 0.013218080393005918]
[0.9465202644309969, 0.05347973556900305]
[0.6546918905854069, 0.3453081094145932]
[0.9800319205163694, 0.01996807948363057]
[0.8650315796507486, 0.13496842034925127]
[0.8972707070768504, 0.10272929292314951]
[0.9734904207466204, 0.026509579253379583]
[0.9738555020795819, 0.0261444979204181]
[0.9451918643967664, 0.05480813560323362]
[0.9784015846198715, 0.021598415380128515]
[0.9748282283215279, 0.025171771678472177]
[0.9763636123688948, 0.02363638763110521]
[0.9199103135503417, 0.08008968644965833]
[0.7693982097512757, 0.23060179024872435]
[0.9136856644350535, 0.08631433556494644]
[0.9645708615562351, 0.03542913844376495]
[0.9562087185313909, 0.04379128146860912]
[0.8229765993503709, 0.17702340064962907]
[0.9694122765511148, 0.03058772344888521]
[0.6820123157536906, 0.31798768424630947]
[0.873367756537674, 0.1266322434623261]
[0.9933507416245103, 0.006649258375489724]
[0.7357345019022891, 0.26426549809

[0.9789709791913069, 0.021029020808693004]
[0.9828188216575837, 0.01718117834241632]
[0.6795364415137285, 0.3204635584862716]
[0.9963808706677704, 0.003619129332229639]
[0.7794587467958167, 0.2205412532041834]
[0.9288461902789761, 0.07115380972102395]
[0.5562845235370278, 0.44371547646297216]
[0.9744869525632401, 0.025513047436759757]
[0.979713901886822, 0.020286098113178035]
[0.9133937128813917, 0.08660628711860824]
[0.8273796324861583, 0.17262036751384166]
[0.8589902126695467, 0.14100978733045327]
[0.9630864279653748, 0.03691357203462509]
[0.8557697831034505, 0.1442302168965494]
[0.9912328040086582, 0.008767195991341766]
[0.856737542674411, 0.1432624573255889]
[0.9605694071638211, 0.03943059283617884]
[0.9784940523710065, 0.021505947628993403]
[0.6778077245724196, 0.32219227542758044]
[0.5754134815837444, 0.42458651841625555]
[0.9849648273726122, 0.015035172627387862]
[0.9815089199748853, 0.0184910800251147]
[0.8722404274039066, 0.1277595725960934]
[0.7329366842018352, 0.267063315798

[0.9296084418885631, 0.07039155811143687]
[0.9862258859169349, 0.01377411408306512]
[0.9692377804912381, 0.030762219508762033]
[0.9950029936165552, 0.004997006383444767]
[0.6728385930108026, 0.32716140698919727]
[0.98288267580261, 0.017117324197390052]
[0.9127736686956995, 0.08722633130430053]
[0.9577529331504944, 0.042247066849505646]
[0.7160565563372011, 0.28394344366279894]
[0.9451124907522193, 0.05488750924778076]
[0.91444612536981, 0.08555387463019005]
[0.900085285329577, 0.09991471467042297]
[0.7937098999615027, 0.2062901000384974]
[0.9849192428000513, 0.015080757199948601]
[0.6355614340053007, 0.36443856599469937]
[0.7819298082926538, 0.21807019170734623]
[0.94297306779425, 0.05702693220574997]
[0.9320913389637704, 0.06790866103622954]
[0.961912344960155, 0.03808765503984487]
[0.8311562374490146, 0.16884376255098543]
[0.9447897870517638, 0.05521021294823621]
[0.9883388885997217, 0.011661111400278349]
[0.9865614142890665, 0.013438585710933544]
[0.9663659439420961, 0.0336340560579

[0.9243576281732034, 0.07564237182679653]
[0.9595888512422481, 0.04041114875775184]
[0.9746415111370071, 0.025358488862992877]
[0.9733405413833889, 0.026659458616611093]
[0.9533254608047499, 0.04667453919525019]
[0.894966438335317, 0.10503356166468296]
[0.9813138541964126, 0.018686145803587444]
[0.9628582715847961, 0.03714172841520391]
[0.9442706208459767, 0.05572937915402337]
[0.9903373423719819, 0.009662657628018026]
[0.9843284421419528, 0.015671557858047257]
[0.6210032523939201, 0.37899674760608]
[0.8974765000539667, 0.10252349994603321]
[0.9261413222732057, 0.07385867772679428]
[0.9385515136006216, 0.06144848639937839]
[0.9819090350253113, 0.018090964974688713]
[0.9672432904446852, 0.032756709555314796]
[0.9822948390793262, 0.01770516092067377]
[0.835156344004832, 0.16484365599516806]
[0.9713596813286204, 0.028640318671379585]
[0.9336206633954492, 0.06637933660455078]
[0.973080882566801, 0.026919117433199028]
[0.7118266602583962, 0.28817333974160375]
[0.9899217618983256, 0.01007823

[0.9665391413614071, 0.033460858638592855]
[0.6992988040395494, 0.3007011959604507]
[0.8181684359332777, 0.18183156406672227]
[0.827634078547576, 0.172365921452424]
[0.992925690986619, 0.00707430901338104]
[0.5089688703843526, 0.4910311296156475]
[0.8999413714617412, 0.10005862853825875]
[0.9775062143203868, 0.022493785679613146]
[0.7985478227464979, 0.20145217725350203]
[0.9249420696036551, 0.07505793039634485]
[0.8143984570344749, 0.18560154296552517]
[0.7927765194409492, 0.20722348055905082]
[0.9814609298950704, 0.018539070104929615]
[0.7506764679430017, 0.24932353205699836]
[0.9265166986206939, 0.07348330137930606]
[0.925172841150823, 0.07482715884917707]
[0.9698407155025928, 0.030159284497407252]
[0.6491402494213206, 0.35085975057867935]
[0.9734644160429824, 0.026535583957017536]
[0.9893546671475184, 0.010645332852481627]
[0.7628617850634782, 0.23713821493652185]
[0.9600924867044868, 0.0399075132955133]
[0.8590682197995325, 0.14093178020046754]
[0.9404325638097257, 0.0595674361902

[0.9407759963884448, 0.0592240036115552]
[0.8775791613643883, 0.12242083863561173]
[0.965955965172362, 0.03404403482763803]
[0.9217326352427128, 0.07826736475728717]
[0.8614896846599509, 0.138510315340049]
[0.9665338050752162, 0.033466194924783824]
[0.6032140248861916, 0.39678597511380836]
[0.7795475304995654, 0.22045246950043468]
[0.8371990242597858, 0.16280097574021427]
[0.871665065648807, 0.12833493435119311]
[0.9694300459571211, 0.030569954042878923]
[0.94493940413323, 0.05506059586676996]
[0.9628367387622825, 0.03716326123771742]
[0.806792894843509, 0.19320710515649092]
[0.9077953333730674, 0.09220466662693262]
[0.9269823593571623, 0.07301764064283779]
[0.9253792400416851, 0.07462075995831484]
[0.9135063141697944, 0.0864936858302056]
[0.8126725945167799, 0.1873274054832202]
[0.9867473666888554, 0.0132526333111446]
[0.8863798173092313, 0.11362018269076861]
[0.7185348428995636, 0.28146515710043635]
[0.9760889201390083, 0.0239110798609918]
[0.7360008115842502, 0.2639991884157497]
[0.

[0.8387291110857921, 0.161270888914208]
[0.3751672446613185, 0.6248327553386815]
[0.8379278849498152, 0.16207211505018476]
[0.9490759011676386, 0.05092409883236148]
[0.9028948464261971, 0.09710515357380284]
[0.9819529084411442, 0.018047091558855744]
[0.7165211947371078, 0.28347880526289226]
[0.8284059690631206, 0.17159403093687936]
[0.8050184162228305, 0.1949815837771694]
[0.9721994997932196, 0.027800500206780477]
[0.9390380678807181, 0.06096193211928195]
[0.9559099063744337, 0.04409009362556627]
[0.7258436197446758, 0.27415638025532424]
[0.9738356441278901, 0.02616435587210999]
[0.9775977570622342, 0.022402242937765816]
[0.8244217484090991, 0.17557825159090096]
[0.927242644489772, 0.07275735551022812]
[0.9927070550352366, 0.007292944964763429]
[0.9603715778798673, 0.03962842212013281]
[0.6123891211550163, 0.38761087884498363]
[0.8759990465464058, 0.12400095345359427]
[0.8536387281402312, 0.14636127185976872]
[0.9412244793500055, 0.05877552064999449]
[0.9954303294595754, 0.004569670540

[0.8708718781907359, 0.12912812180926417]
[0.88093602126412, 0.11906397873588008]
[0.8351299654673073, 0.16487003453269267]
[0.8984210959179221, 0.10157890408207798]
[0.8859403570509901, 0.11405964294900998]
[0.9905561906699591, 0.009443809330040942]
[0.8049869258225825, 0.19501307417741748]
[0.8933903489640952, 0.1066096510359047]
[0.968026413292618, 0.03197358670738203]
[0.9896869601511497, 0.010313039848850458]
[0.7148701119094839, 0.2851298880905161]
[0.8343325769711973, 0.1656674230288027]
[0.692941664235759, 0.30705833576424113]
[0.9307279824233483, 0.06927201757665165]
[0.6817716307198657, 0.31822836928013437]
[0.9069490118923532, 0.0930509881076467]
[0.9573207947893727, 0.04267920521062731]
[0.9454384990914058, 0.05456150090859429]
[0.8976265001186667, 0.10237349988133324]
[0.9350649134439861, 0.06493508655601389]
[0.9250475275579552, 0.07495247244204481]
[0.9693365712828346, 0.030663428717165445]
[0.8439458664151199, 0.15605413358488007]
[0.8123223161314291, 0.1876776838685708

[0.8291217928667579, 0.17087820713324212]
[0.8883156855915276, 0.11168431440847237]
[0.9574528699929495, 0.04254713000705044]
[0.9715461187300648, 0.028453881269935227]
[0.9879386339943783, 0.012061366005621615]
[0.9650437433178616, 0.03495625668213844]
[0.8924785652921442, 0.10752143470785587]
[0.8088185847052629, 0.191181415294737]
[0.9294022126412753, 0.07059778735872464]
[0.601583288127858, 0.39841671187214206]
[0.9463456156572535, 0.05365438434274655]
[0.8962107647982613, 0.10378923520173855]
[0.7442102248980408, 0.2557897751019592]
[0.80634660761545, 0.19365339238454995]
[0.8261102118630663, 0.17388978813693365]
[0.9889504855664649, 0.01104951443353499]
[0.9910228063260434, 0.008977193673956592]
[0.9810532171467783, 0.0189467828532218]
[0.9853046142636586, 0.014695385736341523]
[0.9795882212813392, 0.02041177871866078]
[0.9688449084488192, 0.031155091551180854]
[0.8964889974354845, 0.10351100256451558]
[0.8457306216527508, 0.15426937834724921]
[0.4560835481141437, 0.5439164518858

[0.9384362690991259, 0.061563730900874046]
[0.8744481925292937, 0.12555180747070632]
[0.9756511413894465, 0.02434885861055348]
[0.9724458637468522, 0.027554136253147815]
[0.9780769722371793, 0.02192302776282074]
[0.9670917771813209, 0.03290822281867907]
[0.7447012859287807, 0.2552987140712194]
[0.810699389145552, 0.18930061085444805]
[0.8296791682157153, 0.17032083178428475]
[0.9733894763143609, 0.026610523685639174]
[0.9061543231893056, 0.0938456768106945]
[0.9742557924687699, 0.025744207531230072]
[0.7680433410323316, 0.2319566589676684]
[0.829552374289768, 0.17044762571023203]
[0.858118319055993, 0.14188168094400694]
[0.9086749345424158, 0.09132506545758423]
[0.8841773075412808, 0.1158226924587193]
[0.7966639330488907, 0.20333606695110928]
[0.5918567096363204, 0.4081432903636795]
[0.9088008650542845, 0.0911991349457156]
[0.9279071770255125, 0.07209282297448752]
[0.975118318032807, 0.024881681967193113]
[0.8124366190183175, 0.1875633809816824]
[0.9861711212618304, 0.01382887873816974

[0.9833814349339564, 0.01661856506604371]
[0.8935215700159344, 0.10647842998406556]
[0.9356720626923927, 0.06432793730760723]
[0.9317025133744529, 0.06829748662554709]
[0.6827331455488026, 0.3172668544511974]
[0.6974990921830353, 0.3025009078169647]
[0.9746001612245067, 0.025399838775493247]
[0.8745179048400991, 0.1254820951599009]
[0.8664748812506003, 0.13352511874939962]
[0.7996202456782417, 0.2003797543217583]
[0.9664058121440066, 0.03359418785599343]
[0.9824850632582538, 0.017514936741746215]
[0.8986651030147403, 0.10133489698525967]
[0.8450688497540345, 0.15493115024596546]
[0.9903595637149589, 0.009640436285041154]
[0.8583325309449384, 0.14166746905506156]
[0.9942636394208344, 0.0057363605791655925]
[0.834978362457382, 0.16502163754261798]
[0.8501993708193156, 0.14980062918068435]
[0.9730413597367784, 0.026958640263221572]
[0.9134302118104768, 0.0865697881895233]
[0.9583861255351834, 0.0416138744648166]
[0.5911538323751103, 0.4088461676248897]
[0.6504931316600732, 0.3495068683399

[0.8848068424571591, 0.11519315754284089]
[0.5724568677975458, 0.42754313220245427]
[0.9194251096663718, 0.08057489033362827]
[0.9713591284346891, 0.028640871565311]
[0.9765149587188654, 0.023485041281134576]
[0.9594585171509509, 0.040541482849049175]
[0.9950112852400105, 0.004988714759989487]
[0.9542722672355937, 0.04572773276440627]
[0.8647917921397911, 0.13520820786020893]
[0.9924489202538437, 0.007551079746156361]
[0.9506271367968685, 0.04937286320313158]
[0.7508592756360029, 0.24914072436399706]
[0.9900948609122979, 0.009905139087702113]
[0.9840187346460149, 0.01598126535398512]
[0.949124724412126, 0.050875275587874]
[0.8505471955627453, 0.14945280443725475]
[0.34525559618577245, 0.6547444038142275]
[0.7899334943968799, 0.21006650560312004]
[0.8311014588037559, 0.16889854119624412]
[0.9641037181685684, 0.03589628183143163]
[0.9307766646126124, 0.06922333538738758]
[0.9934729742307631, 0.006527025769237062]
[0.931567388635423, 0.06843261136457703]
[0.649899423617587, 0.350100576382

# cvxpy layer benchmark
Tried using pycall for the cvxpy layer benchmark, did not work. Same code works in python. 
Dump csv files of training and test sets. 

In [87]:
#PyCall.pyversion

v"3.7.7"

In [128]:
py"""
import cvxpy as cp
import torch 
from cvxpylayers.torch import CvxpyLayer
import numpy as np
import torch.nn as nn
from torch.nn import Module
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.nn.parameter import Parameter

p = 5


def train(iters, Xtrain, ztrain):
    class Net(nn.Module):
        def __init__(self):
          super(Net, self).__init__()
          # First fully connected layer
          self.fc1 = nn.Linear(p+1, 32)
          # Second fully connected layer that outputs 
          self.fc2 = nn.Linear(32, 1)
        def forward(self, x):
            x = nn.LeakyReLU(0.1)(self.fc1(x))
            x = nn.LeakyReLU(0.1)(self.fc2(x))
            return x


    dim0 = ztrain.shape[0]
    ztrain = ztrain.reshape((dim0,1))
    # problem
    w_cvxpy = cp.Variable(1)
    z_cvxpy = cp.Parameter(1)
    objective = cp.logistic(1+ cp.exp(-cp.multiply(z_cvxpy,w_cvxpy))) + 0.001*cp.norm(w_cvxpy)**2 
    problem = cp.Problem(cp.Minimize(objective), [w_cvxpy <= 10, w_cvxpy >= -10 ])
    assert problem.is_dpp()
    net = Net()
    cvxlayer = CvxpyLayer(problem, [z_cvxpy], [w_cvxpy])
    results = []

    optimizer = optim.SGD(net.parameters(), lr=0.01)  
    torch.autograd.set_detect_anomaly(True)
    for i in range(iters):
        out1 = net(Xtrain)
        out, = cvxlayer(out1)
#        loss = torch.mean(torch.log(torch.add(torch.exp(torch.neg(torch.multiply(out,ztrain))), 1)))    
#        optimizer.zero_grad()   # zero the gradient buffers
#        loss.backward()
#        optimizer.step()    # Does the update
#        results.append(loss.item())
#        print("(iter %d) loss: %g " % (i, results[-1]))
    return results, net, cvxlayer

def predict(net, cvxlayer, Xtest, ztest):
    z_pred = net(Xtest)
    w_test, = cvxlayer(z_pred)
    return w_test

def main_cvxlayer(Xtrain, ztrain, Xtest, ztest):
    Xtrain = Xtrain.astype(np.float32)
    ztrain = ztrain.astype(np.float32)
    Xtest = Xtest.astype(np.float32)
    ztest = ztest.astype(np.float32)
    Xtrain, Xtest, ztrain, ztest = map(
        torch.from_numpy, [Xtrain, Xtest, ztrain, ztest])
    results, net, cvxlayer = train(50, Xtrain, ztrain)
#    w_test = predict(net, cvxlayer, Xtest, ztest)
    return net
    

"""

In [116]:
# num_x = 50
# idx3 = rand(1:num_x0, num_x)
# x_train = x_train01[idx3,:]
# z_train = z_train0[idx3,:]
# w_test = py"main_cvxlayer"(x_test1, z_train, x_test1, z_test)
# print(w_test)
# for i in 1:num_t
#     z = z_test[i]
#     w = w_test[i]
#     cost_t = cost_func(w,z)
#     cost_test1 += cost_t
# end  

In [121]:
# Pkg.add("DataFrames")

# Main

In [332]:
idx = shuffle(1:num_x0);

In [333]:
# specify num = training size:
#num = 50
num = 1000

1000

In [334]:

idx1 = idx[1:num]
idx2 = idx[num+1:end];

In [335]:
x_train = x_samples[idx1,:]
x_train1 = x_samples1[idx1, :]
z_train = z_samples[idx1]
x_test = x_samples[idx2,:]
x_test1 = x_samples1[idx2, :]
z_test = z_samples[idx2];

In [336]:
# dump csvs
CSV.write("x_train.csv", Tables.table(x_train1), append = false)

CSV.write("z_train.csv", Tables.table(reshape(z_train, length(z_train), 1)), append = false)

CSV.write("x_test.csv", Tables.table(x_test1), append = false)

CSV.write("z_test.csv", Tables.table(reshape(z_test, length(z_test), 1)), append = false)

"z_test.csv"

# test set performance

## true B

In [337]:
cost_test0 = 0
for i in 1:num_t
    x = append!(x_test[i, :],1)
    z = z_test[i]
    p = softmax(B_true*x)
    w = opt_solution_oracle(p)
    cost_t = cost_func(w[1],z[1])
    cost_test0 += cost_t
end 
cost_test0 = cost_test0/num_t

0.3113324503626359

## Our method

In [488]:
# B = zeros((2,6))
alpha = 0.01
# alpha = 0.1
total_ite = 10 #1000
batch_size = 128
cost_list = zeros(total_ite)
for ite in 1:total_ite
    println("ite ", ite)
    #     println("old B ", B)
    grad = zeros((2,6))
    cost = 0
    b_index = get_batch_index(num_x, batch_size)
    for indx in b_index
        x = append!(x_train[indx, :],1)
    #         print(x)
    #         print(z)
        z = z_train[indx]
        p = softmax(B*x)
        result = py"approx_sol_oracle"(p,svr1)
        w = result[1]
        grad_w = result[2]
    #         println(grad_w)
        grad_c = grad_cost_func(w[1],z[1])
    #         println(grad_c)
        grad_t = grad_c*grad_w
        grad_t =  grad_prob(B,x)[1]*grad_t[1] + grad_prob(B,x)[2]*grad_t[2]
        grad_t = hcat(grad_t[1] , grad_t[2])
        grad += transpose(grad_t)
    end
    grad = grad/batch_size
    #     println("grad ",grad)
    for i in 1:num_x
        x = append!(x_train[i, :],1)
        z = z_train[i]
        p = softmax(B*x)
        result = py"approx_sol_oracle2"(p,svr1)
        w = result[1]
        cost_t = cost_func(w[1],z[1])
        cost += cost_t
    end
    cost = cost/num_x
    println("cost ", cost)
    cost_list[ite] = cost
    B = B - alpha* grad
    #     println("update B ", B)
end

ite 1
cost 0.3494719725879051
ite 2
cost 0.3494153039175525
ite 3
cost 0.34942431410689845
ite 4
cost 0.34946729531846166
ite 5
cost 0.3494968924949367
ite 6
cost 0.3494664351724854
ite 7
cost 0.349490921619074
ite 8
cost 0.34946321466084207
ite 9
cost 0.3494726591163422
ite 10
cost 0.349451665814623


In [489]:
cost_test1 = 0
for i in 1:num_t
    x = append!(x_test[i, :],1)
    z = z_test[i]
    p = softmax(B*x)
    w = opt_solution_oracle(p)
    cost_t = cost_func(w[1],z[1])
    cost_test1 += cost_t
end
cost_test1 = cost_test1/num_t

0.3415587212065811

## cvxpy layer

In [493]:
# load cvxpylayer decisions from csv
w_nn = CSV.read("w_test.csv", datarow=1, DataFrame)
w_nn = convert(Matrix, w_nn);

In [494]:
cost_test2 = 0
for i in 1:num_t
#     print(i)
    z = z_test[i]
    w = w_nn[i]
    cost_t = cost_func(w,z)
    cost_test2 += cost_t
end
cost_test2 = cost_test2/num_t

0.42301068894008304

## nonparametric prescriptive methods

 #### KNN

In [365]:
cost_test3 = 0
for i in 1:num_t
    n_p, n_n = get_kernel_count(x_test[i,:], x_train, z_train, 800)
    p = [n_p/(n_p+n_n), n_n/(n_p+n_n)]
    w = opt_solution_oracle(p)
    z = z_test[i]
    cost_t = cost_func(w[1],z[1])
    cost_test3 += cost_t
end 
cost_test3 = cost_test3/num_t

0.38752768359608064

In [366]:
cost_test4 = 0
for i in 1:num_t
    n_p, n_n = get_kernel_count(x_test[i,:], x_train, z_train, 1)
    p = [n_p/(n_p+n_n), n_n/(n_p+n_n)]
    w = opt_solution_oracle(p)
    z = z_test[i]
    cost_t = cost_func(w[1],z[1])
    cost_test4 += cost_t
end 
cost_test4 = cost_test4/num_t

0.36834323743953507

In [370]:
cost_test5 = 0
for i in 1:num_t
    n_p, n_n = get_kernel_count(x_test[i,:], x_train, z_train, 5)
    p = [n_p/(n_p+n_n), n_n/(n_p+n_n)]
    w = opt_solution_oracle(p)
    z = z_test[i]
    cost_t = cost_func(w[1],z[1])
    cost_test5 += cost_t
end
cost_test5 = cost_test5/num_t

0.38663706822789146

 #### Kernel

In [371]:
cost_test6 = 0
for i in 1:num_t
    n_p, n_n = get_kernel_count(x_test[i,:], x_train, z_train, 0.1)
    p = [n_p/(n_p+n_n), n_n/(n_p+n_n)]
    w = opt_solution_oracle(p)
    z = z_test[i]
    cost_t = cost_func(w[1],z[1])
    cost_test6 += cost_t
end 
cost_test6 = cost_test6/num_t

0.6197008834415776

In [372]:
cost_test6 = 0
for i in 1:num_t
    n_p, n_n = get_kernel_count(x_test[i,:], x_train, z_train, 100)
    p = [n_p/(n_p+n_n), n_n/(n_p+n_n)]
    w = opt_solution_oracle(p)
    z = z_test[i]
    cost_t = cost_func(w[1],z[1])
    cost_test6 += cost_t
end 
cost_test6 = cost_test6/num_t

0.3875254782893816

In [376]:
cost_test6 = 0
for i in 1:num_t
    n_p, n_n = get_kernel_count(x_test[i,:], x_train, z_train, 5)
    p = [n_p/(n_p+n_n), n_n/(n_p+n_n)]
    w = opt_solution_oracle(p)
    z = z_test[i]
    cost_t = cost_func(w[1],z[1])
    cost_test6 += cost_t
end 
cost_test6 = cost_test6/num_t

0.38663706822789146

#### tree

In [377]:
cost_test7 = 0
for i in 1:num_t
    n_p, n_n = py"get_count_tree"(x_test[i,:], x_train, z_train, dep =2)
    p = [n_p/(n_p+n_n), n_n/(n_p+n_n)]
    w = opt_solution_oracle(p)
    z = z_test[i]
    cost_t = cost_func(w[1],z[1])
    cost_test7 += cost_t
end 
cost_test7 = cost_test7/num_t

0.3309671714133481

In [380]:
cost_test8 = 0
for i in 1:num_t
    n_p, n_n = py"get_count_tree"(x_test[i,:], x_train, z_train, dep =3)
    p = [n_p/(n_p+n_n), n_n/(n_p+n_n)]
    w = opt_solution_oracle(p)
    z = z_test[i]
    cost_t = cost_func(w[1],z[1])
    cost_test8 += cost_t
end 
cost_test8 = cost_test8/num_t

0.3686909393733938

## SAA

In [381]:
cost_test9 = 0
for i in 1:num_t
    n_p, n_n = get_kernel_count(x_test[i,:], x_train, z_train, num_x)
    p = [n_p/(n_p+n_n), n_n/(n_p+n_n)]
    w = opt_solution_oracle(p)
    z = z_test[i]
    cost_t = cost_func(w[1],z[1])
    cost_test9 += cost_t
end 
cost_test9 = cost_test9/num_t

0.387518757765629

# Results

| Algorithm   | N=50        | N=100         | N=1000       | 
| :---        |    :----:   |          ---: |         ---: |
| Optimal     | 0.310       | 0.305         |  0.311       |
| Our method  | 0.343       | 0.329         |  0.341       |
| KNN         | 0.395       | 0.349         |  0.368       |
| kernel      | 0.396       | 0.367         |  0.387       |
| tree        | 1.257       | 0.349         |  0.331       |
| cvxlayer    | 0.436       | 0.406         |  0.423       |
| SAA         | 0.396       | 0.368         |  0.388       |