In [122]:
import torch
import numpy as np
from FaroLR import loss_robustness, loss_fairness
from utils import get_data

In [123]:
def loss_fairness(X, y, w, tp=False):
	if not tp:
		u_down = torch.sum(1.0 - X[:, 0])
		u_up = torch.sum(torch.sigmoid(torch.matmul(X, w)) * (1.0 - X[:, 0]))
		v_down = torch.sum(X[:, 0])
		v_up = torch.sum(torch.sigmoid(torch.matmul(X, w)) * X[:, 0])
		loss = torch.square((u_up / u_down) - (v_up / v_down))
	else:
		y_flat=y.reshape(-1)
		u_down = torch.sum((1.0 - X[:, 0]) * y_flat)
		u_up = torch.sum(torch.sigmoid(torch.matmul(X, w)) * (1.0 - X[:, 0]) * y_flat)
		v_down = torch.sum(X[:, 0] * y_flat)
		v_up = torch.sum(torch.sigmoid(torch.matmul(X, w)) * X[:, 0] * y_flat)
		loss = torch.square((u_up / u_down) - (v_up / v_down))
	return loss

def loss_robustness(X, y, w):
	gradx = (y.reshape(-1) - torch.sigmoid(torch.matmul(X, w))).reshape(-1, 1) * w
	loss = torch.mean(torch.sum(torch.square(gradx), axis=1))
	return loss

In [124]:
X,y=get_data(data='adult',attr='race')
X=np.hstack([X,np.ones(X.shape[0]).reshape(-1,1)])
s=X[:,0]
X=torch.tensor(X,dtype=torch.float32)
y=torch.tensor(y,dtype=torch.float32)

In [125]:
def grad_robustness(X,y,w):
    w=torch.tensor(w,dtype=torch.float32,requires_grad=True)
    loss=loss_robustness(X,y,w)
    loss.backward()
    return np.array(w.grad.reshape(-1).tolist())

def grad_fairness(X,y,w):
    w=torch.tensor(w,dtype=torch.float32,requires_grad=True)
    loss=loss_fairness(X,y,w,tp=True)
    loss.backward()
    return np.array(w.grad.reshape(-1).tolist())

sigmoid=lambda x: 1/(1+np.exp(-x))

In [126]:
w=np.random.uniform(-1,1,X.shape[1]).tolist()

In [127]:
grad_torch=grad_fairness(X,y,w)

In [143]:
def theoritical_grad_fairness(X,y,w):
    X=X.numpy()
    y=y.numpy()
    w=np.array(w)
    res=np.zeros(len(w))
    
    u=np.sum(1-X[:,0])
    v=np.sum(X[:,0]*y)
        r=sigmoid(np.dot(X[i],w))
        
    return 2*(a_up/a_down-b_up/b_down)*(c_up/c_down-d_up/d_down)

In [144]:
grad_plain=theoritical_grad_fairness(X,y,w)

In [145]:
grad_torch/grad_plain

  """Entry point for launching an IPython kernel.


array([ 2.52041075e+12,             inf, -9.23733328e+11, -2.28742414e+11,
        2.11669620e+13,             inf,             inf,  6.61444467e+12,
       -6.04816671e+11,             inf,             inf,  2.98893233e+12,
       -2.24429552e+12, -2.65152692e+14, -1.48360738e+12, -1.92051479e+12,
                   inf,             inf,            -inf,            -inf,
       -1.73003506e+13,            -inf, -7.40134119e+12,  6.15334871e+11,
        3.16581641e+13,             inf,             inf,             inf,
                  -inf,             inf, -1.50644471e+12,  4.21003868e+14,
       -2.11459165e+12,             inf,  4.21275972e+12, -4.46347608e+11,
                   inf, -2.92419011e+11, -8.52873265e+14,  1.01687814e+12,
                  -inf,             inf, -6.58309015e+12,            -inf,
       -8.69889003e+12,  3.25264857e+13,            -inf,  1.31499966e+13,
       -7.20972022e+10,             inf,  2.22532285e+12, -1.09890450e+12,
        2.65044859e+12,  

In [131]:
# ** Fairness gradient is verified to be correct **

In [132]:
grad_torch=grad_robustness(X,y,w)

In [133]:
def theoritical_grad_robustness(X,y,w):
    X=X.numpy()
    y=y.numpy()
    w=np.array(w)
    
    res=np.zeros(len(w))
    for i in range(0,X.shape[0]):
        a=w*((y[i]-sigmoid(np.dot(w,X[i])))**2)
        b=X[i]*np.dot(w,w)*(y[i]-sigmoid(np.dot(w,X[i])))*sigmoid(np.dot(w,X[i]))*(1.0-sigmoid(np.dot(w,X[i])))
        res+=(a-b)
    return 2/X.shape[0]*res

In [134]:
grad_plain=theoritical_grad_robustness(X,y,w)

In [135]:
grad_torch

array([ 2.29230022,  0.25449067,  1.68026114,  0.41608003, -0.30080062,
        0.21090929,  1.0546999 ,  0.37598765,  0.06875969,  1.27546751,
        0.43849078, -0.33980227, -0.12757343, -0.11775145,  0.04216666,
        0.10916863,  0.22497816,  0.11470972, -0.37913513, -0.18816382,
        0.49170554, -0.12648989, -0.42071754, -0.2798219 , -0.44988957,
        0.21373925,  0.14134668,  0.50585967, -0.39607912,  1.3355689 ,
        0.34252587, -0.37392655, -0.96160495,  0.15492979,  1.91574132,
       -0.02537192,  0.33506823,  0.06648839,  0.37875181,  0.23121132,
       -0.06555521,  0.46059886, -0.37420535, -0.28880155,  0.9889493 ,
       -0.23111458, -0.16850817,  0.37374538,  0.01639301,  0.45510548,
       -0.12649496, -0.49972391,  1.20528448, -0.24065359,  1.26488352,
        0.85468745, -0.542539  , -0.21195427,  0.44957766,  0.21083719,
       -0.42729807, -0.46677196,  0.40189144, -0.1875948 ,  0.08493423,
       -0.06048225, -0.03060159,  0.12595607, -0.16555053, -0.40

In [136]:
grad_plain

array([ 2.2923016 ,  0.25449109,  1.68026082,  0.41608008, -0.30080065,
        0.2109093 ,  1.05469989,  0.37598765,  0.06875969,  1.275468  ,
        0.43849081, -0.33980226, -0.12757342, -0.11775145,  0.04216666,
        0.10916868,  0.22497817,  0.11470971, -0.37913517, -0.18816384,
        0.49170553, -0.12648989, -0.42071753, -0.27982194, -0.44988959,
        0.21373885,  0.14134671,  0.50585971, -0.39607913,  1.33556935,
        0.34252563, -0.37392653, -0.9616051 ,  0.15492979,  1.91574139,
       -0.02537191,  0.33506819,  0.06648863,  0.37875181,  0.23121119,
       -0.0655552 ,  0.46059887, -0.3742054 , -0.28880148,  0.98894926,
       -0.23111461, -0.16850817,  0.3737454 ,  0.01639311,  0.45510548,
       -0.12649495, -0.49972399,  1.20528415, -0.2406536 ,  1.26488377,
        0.85468764, -0.54253904, -0.21195427,  0.44957761,  0.2108372 ,
       -0.42729805, -0.46677195,  0.40189143, -0.1875948 ,  0.08493424,
       -0.06048225, -0.03060159,  0.1259561 , -0.16555054, -0.40

In [137]:
grad_torch/grad_plain

array([0.9999994 , 0.99999837, 1.00000019, 0.99999988, 0.99999992,
       0.99999996, 1.00000001, 1.        , 1.0000001 , 0.99999962,
       0.99999992, 1.00000001, 1.00000006, 1.00000003, 0.99999987,
       0.9999995 , 0.99999995, 1.00000006, 0.99999991, 0.99999987,
       1.00000001, 1.00000003, 1.00000003, 0.99999987, 0.99999996,
       1.00000185, 0.99999976, 0.99999992, 0.99999997, 0.99999966,
       1.00000069, 1.00000005, 0.99999984, 0.99999998, 0.99999997,
       1.00000062, 1.0000001 , 0.9999963 , 1.00000001, 1.00000058,
       1.00000017, 0.99999998, 0.99999986, 1.00000024, 1.00000004,
       0.99999989, 1.00000002, 0.99999995, 0.99999353, 1.00000001,
       1.00000011, 0.99999984, 1.00000027, 0.99999995, 0.9999998 ,
       0.99999978, 0.99999993, 0.99999997, 1.00000011, 0.99999991,
       1.00000004, 1.00000002, 1.00000003, 0.99999999, 0.99999993,
       0.99999995, 0.99999993, 0.99999982, 0.99999993, 0.99999998,
       1.        , 0.99999997, 1.00000003, 0.99999994, 0.99999

In [None]:
# ** Robustness gradient is verified to be correct **