In [2]:
import numpy as np
import cvxpy as cp
import pandas as pd
import matplotlib.pyplot as plt

from data_generation import *

## Create dataset

In [18]:
n = 100 # number of data points
m = 1 # number of features
sigma = 0.1
lamb = 2
rho = 1
X, y, w = gen_synthetic_normal(n,m,sigma)

In [19]:
l = lambda X: np.linalg.norm(X@np.linalg.pinv(X)@y-y)**2

In [20]:
print("Shape of X: ", X.shape)
print("Shape of y: ", y.shape)
print("Shape of w: ", w.shape)

Shape of X:  (100, 1)
Shape of y:  (100, 1)
Shape of w:  (1, 1)


## Nominal Problem

In [21]:
def solve_nominal(X, y, lamb=1):
    beta = cp.Variable((m,1))
    log_likelihood = cp.sum(
        cp.multiply(y, X @ beta) - cp.logistic(X @ beta)
    )
    prob = cp.Problem(cp.Maximize(log_likelihood/n - lamb * cp.norm(beta, 2)))
    prob.solve()
    return prob.value, beta.value

In [22]:
loss, beta = solve_nominal(X,y,lamb)
print(loss,beta)

-0.6025550647646747 [[-0.87384067]]


## Robust Problem

https://arxiv.org/pdf/1412.6572.pdf

Check page 4

In [32]:
def solve_robust(X, y, lamb=1,rho=0.1):
    w = cp.Variable((m,1))
    b = cp.Variable(1)
    robust_obj = cp.Minimize(cp.sum(cp.logistic(cp.multiply(y,(rho*cp.norm(w,1) - X@w - b)))) + lamb*cp.square(cp.norm(w,2)))
    prob = cp.Problem(robust_obj)
    prob.solve()
    return prob.value, w.value, b.value

In [33]:
loss, beta, b = solve_robust(X,y,lamb)
print(loss,beta)

DCPError: Problem does not follow DCP rules. Specifically:
The objective is not DCP. Its following subexpressions are not:
[[-0.29915037]
 [-3.88090367]
 [ 2.21373123]
 [ 1.91348189]
 [ 1.71221572]
 [-0.34340797]
 [-2.16399919]
 [ 1.19601845]
 [ 0.38424361]
 [ 0.94441981]
 [ 0.38662849]
 [ 1.64405136]
 [-2.35750371]
 [ 0.85223018]
 [-0.8813186 ]
 [-3.11178457]
 [ 2.34438367]
 [ 2.66965757]
 [-0.83027786]
 [-0.07548711]
 [-0.97654628]
 [ 2.01732266]
 [-4.32094321]
 [ 0.12198457]
 [ 1.5034777 ]
 [ 1.4699799 ]
 [-2.09012213]
 [-1.07340079]
 [ 0.83851692]
 [ 0.39560752]
 [ 2.80485942]
 [-1.63495372]
 [ 4.10657855]
 [-2.02143501]
 [ 4.71456365]
 [ 0.78319156]
 [-0.72605406]
 [-0.45402683]
 [-3.99127321]
 [-0.90769092]
 [ 0.92540298]
 [ 1.32844285]
 [-1.73518033]
 [ 2.29269536]
 [-1.61345056]
 [ 0.3926185 ]
 [ 2.73133287]
 [ 0.1425875 ]
 [ 3.29551527]
 [-1.86458579]
 [-1.33757215]
 [ 1.21742339]
 [ 2.06991996]
 [ 1.86518299]
 [-0.61606043]
 [-2.04913978]
 [ 1.85184179]
 [ 1.69480115]
 [-3.19647627]
 [ 1.76825296]
 [-0.06980137]
 [ 2.62347867]
 [ 1.39616799]
 [-4.00433153]
 [ 2.74778029]
 [-2.02849984]
 [ 2.44622651]
 [-1.75086303]
 [ 0.48720284]
 [ 1.77268097]
 [ 2.78888863]
 [-0.30880558]
 [-1.86559788]
 [ 2.29071982]
 [-0.92588954]
 [-0.8113622 ]
 [-2.7371287 ]
 [-0.79650234]
 [-1.57241516]
 [-2.43740765]
 [ 3.82346203]
 [-2.16705678]
 [ 4.8969447 ]
 [ 0.86174864]
 [-3.53585172]
 [-1.2084533 ]
 [ 0.73481681]
 [ 1.10305569]
 [ 3.98994054]
 [-2.01651812]
 [ 0.37825405]
 [-1.69559279]
 [ 1.82040655]
 [-1.61812817]
 [ 0.08310035]
 [-2.07356335]
 [ 0.52988968]
 [ 1.88253079]
 [-2.09126317]
 [ 1.67889014]] * (0.1 * max(norm1(var975), None, False) + -[[ 0.20603968]
 [ 2.01892835]
 [-1.11130681]
 [-1.00710757]
 [-0.85139398]
 [ 0.16472667]
 [ 1.10318597]
 [-0.62458127]
 [-0.13081594]
 [-0.61827081]
 [-0.23474798]
 [-0.91553023]
 [ 1.2826022 ]
 [-0.43612095]
 [ 0.44927125]
 [ 1.6248419 ]
 [-1.27454666]
 [-1.37130946]
 [ 0.42323111]
 [ 0.04303981]
 [ 0.56570885]
 [-1.10779488]
 [ 2.23547534]
 [-0.02753405]
 [-0.73276236]
 [-0.72145065]
 [ 1.10739202]
 [ 0.61594692]
 [-0.45806213]
 [-0.12521539]
 [-1.50437771]
 [ 0.9316808 ]
 [-2.22755829]
 [ 1.09409874]
 [-2.47159045]
 [-0.45584438]
 [ 0.34310501]
 [ 0.34348118]
 [ 2.04670413]
 [ 0.40334141]
 [-0.45555482]
 [-0.66446964]
 [ 0.88666044]
 [-1.2244264 ]
 [ 0.76380164]
 [-0.23926575]
 [-1.41955496]
 [-0.04828039]
 [-1.7607628 ]
 [ 0.87643474]
 [ 0.71853816]
 [-0.58541504]
 [-1.06838396]
 [-1.04446132]
 [ 0.21599517]
 [ 0.95346135]
 [-0.98519728]
 [-0.88075558]
 [ 1.62281362]
 [-0.94962861]
 [ 0.04056106]
 [-1.35965292]
 [-0.79907362]
 [ 2.17638096]
 [-1.5176625 ]
 [ 1.05645584]
 [-1.35439805]
 [ 1.0029065 ]
 [-0.1590541 ]
 [-0.9190567 ]
 [-1.51212241]
 [ 0.13394375]
 [ 0.95617007]
 [-1.12557792]
 [ 0.44319797]
 [ 0.31675511]
 [ 1.4048846 ]
 [ 0.41809488]
 [ 0.8835787 ]
 [ 1.26839713]
 [-2.08634676]
 [ 1.13335825]
 [-2.61183518]
 [-0.49718427]
 [ 1.78326157]
 [ 0.61088537]
 [-0.35751906]
 [-0.514878  ]
 [-2.04134801]
 [ 1.07116191]
 [-0.22383691]
 [ 0.91320931]
 [-0.87376021]
 [ 0.87176283]
 [-0.06535133]
 [ 1.11495375]
 [-0.3611592 ]
 [-0.97643284]
 [ 1.12676017]
 [-0.81535156]] * var975 + -var976)