In [686]:
import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score

def c_generate(x_range):
    bin=len(x_range)
    C=np.random.random((bin,bin))
    for i in range(bin):
        for j in range(bin):
            C[i,j]=abs(x_range[i]-x_range[j])
    return C

def normialise(tem_dist):
    return [tem_dist[i]/sum(tem_dist) for i in range(len(tem_dist))]

def tmp_generator(gamma_dict,num,q_dict,q_num,L):
    bin=gamma_dict[0].shape[0]
    if q_num<=0:
        q=np.matrix(np.ones((bin,bin)))
    else:
        q=q_dict[q_num]
    tmp_gamma=np.zeros((bin,bin))
    tmp_q=np.zeros((bin,bin))
    for i in range(bin):
        for j in range(bin):
            if gamma_dict[num-L].item(i,j) != 0:
                tmp_gamma[i,j]=q.item(i,j)*gamma_dict[num-1].item(i,j)*gamma_dict[num-L-1].item(i,j)/gamma_dict[num-L].item(i,j)
                tmp_q[i,j]=q.item(i,j)*gamma_dict[num-L-1].item(i,j)/gamma_dict[num-L].item(i,j)
            else:
                print('zero division')
                # tmp_gamma[i,j]=q.item(i,j)*gamma_dict[num-1].item(i,j)*gamma_dict[num-L-1].item(i,j)/(1.0e-9)
                # tmp_q[i,j]=q.item(i,j)*gamma_dict[num-L-1].item(i,j)/(1.0e-9)
    return np.matrix(tmp_gamma),np.matrix(tmp_q)     

def assess(bin,f,g,C,V,output):
    bbm1=np.matrix(np.ones(bin)).T
    print('sum of violation of f:',sum(abs(output*bbm1-f)))
    print('sum of violation of g:',sum(abs(output.T*bbm1-g)))
    # print('sum of violation of f:',sum(abs(np.sum(output,1)-np.array(f))))
    # print('sum of violation of g:',sum(abs(np.sum(output,0)-np.array(g))))
    output=output.A1.reshape((bin,bin))
    print('total cost:',sum(sum(output*C)))
    print('entropy:',sum(sum(-output*np.log(output+0.1**3))))
    print('tr violation:',sum(abs(output.T@V)))
    print('============================================')

def plots(x_range,g,f,output):
    fig = plt.figure(figsize=(3,3))
    gs = fig.add_gridspec(2, 2, width_ratios=(4,1), height_ratios=(1,4),left=0.1,right=0.9,bottom=0.1, top=0.9,wspace=0,hspace=0)
    # Create the Axes.
    ax = fig.add_subplot(gs[1, 0])
    ax.pcolormesh(x_range, x_range, output, cmap='Blues')
    ax.set_xlabel(r'supp($X$)',fontsize=10)
    ax.set_ylabel(r'supp($\tilde{X}$)',fontsize=10)#
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax_histx = fig.add_subplot(gs[0, 0], sharex=ax) 
    ax_histy = fig.add_subplot(gs[1, 1], sharey=ax)
    #ax_histx.set_title(r'$Pr[x]$',rotation='horizontal')
    #ax_histy.set_title(r'$Pr[\tilde{x}]$')
    ax_histx.tick_params(axis="x", labelbottom=False)
    ax_histx.tick_params(axis="y", labelleft=False)
    ax_histy.tick_params(axis="x", labelbottom=False)
    ax_histy.tick_params(axis="y", labelleft=False)
    ax_histx.plot(x_range,g,color='tab:blue')
    ax_histy.plot(f,x_range,color='tab:green') 
    return fig
    
def newton(fun,dfun,a, stepmax, tol):
    if abs(fun(a))<=tol: return a
    for step in range(1, stepmax+1):
        b=a-fun(a)/dfun(a)
        if abs(fun(b))<=tol:
            return b
        else:
            a = b
    return b 

# simplist
def baseline(C,e,px,ptx,V,K):
    # V is only used for assessment
    bin=len(px)
    bbm1=np.matrix(np.ones(bin)).T
    #I=np.where(~(V==0))[0].tolist()
    xi=np.exp(-C/e)
    gamma_classic=dict()
    gamma_classic[0]=np.matrix(xi+1.0e-9)
    for repeat in range(K):
        gamma_classic[1+2*repeat]=np.matrix(np.diag((px/(gamma_classic[2*repeat] @ bbm1)).A1))@gamma_classic[2*repeat] #np.diag(dist['x']/sum(gamma_classic.T))@gamma_classic
        gamma_classic[2+2*repeat]=gamma_classic[1+2*repeat]@np.matrix(np.diag((ptx/(gamma_classic[1+2*repeat].T @ bbm1)).A1))

    assess(bin,px,ptx,C,V,gamma_classic[2*K])
    return gamma_classic

# our method | total repair
def total_repair(C,e,px,ptx,V,K):
    bin=len(px)
    bbm1=np.matrix(np.ones(bin)).T
    I=np.where(~(V==0))[0].tolist()
    xi=np.exp(-C/e)
    gamma_dict=dict()
    gamma_dict[0]=np.matrix(xi+1.0e-9)
    gamma_dict[1]=np.matrix(np.diag((px/(gamma_dict[0] @ bbm1)).A1))@gamma_dict[0]
    gamma_dict[2]=gamma_dict[1]@np.matrix(np.diag((ptx/(gamma_dict[1].T @ bbm1)).A1))
    # step 3
    J=np.where(~((gamma_dict[2].T @ V).A1 ==0))[0].tolist()
    nu=np.zeros(bin)
    gamma_dict[3]=np.copy(gamma_dict[2])
    for j in J:
        fun = lambda z: sum(gamma_dict[2].item(i,j)*V.item(i)*np.exp(z*V.item(i)) for i in I)
        dfun = lambda z: sum(gamma_dict[2].item(i,j)*(V.item(i))**2*np.exp(z*V.item(i)) for i in I)
        nu = newton(fun,dfun,0.1,stepmax = 50,tol = 1.0e-5) #bisection(fun, -50,50, stepmax = 25, tol = 1.0e-3)
        for i in I:
            gamma_dict[3][i,j]=np.exp(nu*V.item(i))*gamma_dict[2].item(i,j)
    gamma_dict[3]=np.matrix(gamma_dict[3])

    #=========================
    L=3
    q_dict=dict()
    for loop in range(1,K):
        tmp,q_dict[(loop-1)*L+1]=tmp_generator(gamma_dict,loop*L+1,q_dict,(loop-2)*L+1,L) #np.matrix(gamma_dict[3].A1*gamma_dict[0].A1/gamma_dict[1].A1)
        gamma_dict[loop*L+1]=np.matrix(np.diag((px/(tmp @ bbm1)).A1))@tmp

        tmp,q_dict[(loop-1)*L+2]=tmp_generator(gamma_dict,loop*L+2,q_dict,(loop-2)*L+2,L)  #np.matrix(gamma_dict[4].A1*gamma_dict[1].A1/gamma_dict[2].A1)
        gamma_dict[loop*L+2]=tmp@np.matrix(np.diag((ptx/(tmp.T @ bbm1)).A1))

        # step 3
        tmp,q_dict[(loop-1)*L+3]=tmp_generator(gamma_dict,loop*L+3,q_dict,(loop-2)*L+3,L)  #np.matrix(gamma_dict[5].A1*gamma_dict[2].A1/gamma_dict[3].A1)
        J=np.where(~((abs(np.matrix(tmp).T @ V).A1)<=1.0e-5))[0].tolist()
        gamma_dict[loop*L+3]=np.copy(tmp)
        for j in J:
            fun = lambda z: sum(tmp.item(i,j)*V.item(i)*np.exp(z*V.item(i)) for i in I)
            dfun = lambda z: sum(tmp.item(i,j)*(V.item(i))**2*np.exp(z*V.item(i)) for i in I)
            nu = newton(fun,dfun,0.1,stepmax = 50,tol = 1.0e-5) 
            for i in I:
                gamma_dict[loop*L+3][i,j]=np.exp(nu*V.item(i))*tmp.item(i,j)
        gamma_dict[loop*L+3]=np.matrix(gamma_dict[loop*L+3])

    assess(bin,px,ptx,C,V,gamma_dict[K*L])
    return gamma_dict

# our method | partial repair
def partial_repair(C,e,px,ptx,V,theta_scale,K):
    bin=len(px)
    bbm1=np.matrix(np.ones(bin)).T
    I=np.where(~(V==0))[0].tolist()
    xi=np.exp(-C/e)
    theta=bbm1*theta_scale
    gamma_dict=dict()
    gamma_dict[0]=np.matrix(xi+1.0e-9)
    gamma_dict[1]=np.matrix(np.diag((px/(gamma_dict[0] @ bbm1)).A1))@gamma_dict[0]
    gamma_dict[2]=gamma_dict[1]@np.matrix(np.diag((ptx/(gamma_dict[1].T @ bbm1)).A1))
    # step 3
    Jplus=np.where(~((gamma_dict[2].T @ V).A1 <=theta.A1))[0].tolist()
    Jminus=np.where(~((gamma_dict[2].T @ V).A1>=-theta.A1))[0].tolist()
    gamma_dict[3]=np.copy(gamma_dict[2])
    for j in Jplus:
        fun = lambda z: sum(gamma_dict[2].item(i,j)*V.item(i)*np.exp(-z*V.item(i)) for i in I)-theta.item(j)
        dfun = lambda z: -sum(gamma_dict[2].item(i,j)*(V.item(i))**2*np.exp(-z*V.item(i)) for i in I)
        nu = newton(fun,dfun,0.1,stepmax = 50,tol = 1.0e-5) #bisection(fun, -50,50, stepmax = 25, tol = 1.0e-3)
        for i in I:
            gamma_dict[3][i,j]=np.exp(-nu*V.item(i))*gamma_dict[2].item(i,j)
    for j in Jminus:
        fun = lambda z: sum(gamma_dict[2].item(i,j)*V.item(i)*np.exp(-z*V.item(i)) for i in I)+theta.item(j)
        dfun = lambda z: -sum(gamma_dict[2].item(i,j)*(V.item(i))**2*np.exp(-z*V.item(i)) for i in I)
        nu = newton(fun,dfun,0.1,stepmax = 50,tol = 1.0e-5) #bisection(fun, -50,50, stepmax = 25, tol = 1.0e-3)
        for i in I:
            gamma_dict[3][i,j]=np.exp(-nu*V.item(i))*gamma_dict[2].item(i,j)
    gamma_dict[3]=np.matrix(gamma_dict[3])

    #=========================
    L=3
    q_dict=dict()
    for loop in range(1,K):
        print(loop)
        tmp,q_dict[(loop-1)*L+1]=tmp_generator(gamma_dict,loop*L+1,q_dict,(loop-2)*L+1,L) #np.matrix(gamma_dict[3].A1*gamma_dict[0].A1/gamma_dict[1].A1)
        gamma_dict[loop*L+1]=np.matrix(np.diag((px/(tmp @ bbm1)).A1))@tmp

        tmp,q_dict[(loop-1)*L+2]=tmp_generator(gamma_dict,loop*L+2,q_dict,(loop-2)*L+2,L)  #np.matrix(gamma_dict[4].A1*gamma_dict[1].A1/gamma_dict[2].A1)
        gamma_dict[loop*L+2]=tmp@np.matrix(np.diag((ptx/(tmp.T @ bbm1)).A1))

        # step 3
        tmp,q_dict[(loop-1)*L+3]=tmp_generator(gamma_dict,loop*L+3,q_dict,(loop-2)*L+3,L)  #np.matrix(gamma_dict[5].A1*gamma_dict[2].A1/gamma_dict[3].A1)
        Jplus=np.where(~((np.matrix(tmp).T @ V).A1 <=theta.A1))[0].tolist()
        Jminus=np.where(~((np.matrix(tmp).T @ V).A1 >=-theta.A1))[0].tolist()
        gamma_dict[loop*L+3]=np.copy(tmp)
        for j in Jplus:
            fun = lambda z: sum(tmp.item(i,j)*V.item(i)*np.exp(-z*V.item(i)) for i in I)-theta.item(j)
            dfun = lambda z: -sum(tmp.item(i,j)*(V.item(i))**2*np.exp(-z*V.item(i)) for i in I)
            nu = newton(fun,dfun,0.5,stepmax = 50,tol = 1.0e-5) 
            for i in I:
                gamma_dict[loop*L+3][i,j]=np.exp(-nu*V.item(i))*tmp.item(i,j)
        for j in Jminus:
            fun = lambda z: sum(tmp.item(i,j)*V.item(i)*np.exp(-z*V.item(i)) for i in I)+theta.item(j)
            dfun = lambda z: -sum(tmp.item(i,j)*(V.item(i))**2*np.exp(-z*V.item(i)) for i in I)
            nu = newton(fun,dfun,0.5,stepmax = 50,tol = 1.0e-5) 
            for i in I:
                gamma_dict[loop*L+3][i,j]=np.exp(-nu*V.item(i))*tmp.item(i,j)
        gamma_dict[loop*L+3]=np.matrix(gamma_dict[loop*L+3])

    assess(bin,px,ptx,C,V,gamma_dict[L*K])
    return gamma_dict

def empirical_distribution(sub,x_range):
    bin=len(x_range)
    distrition=np.zeros(bin)
    for i in range(bin):
        subset=sub[sub['X']==x_range[i]] #bin_value=x_range[i] #sub[(sub['X']>=bin_value)&(sub['X']<bin_value+width)]
        if subset.shape[0]>0:
            distrition[i]=sum(subset['W'])
    if sum(distrition)>0:
        return distrition/sum(distrition)
    else:
        return distrition

def projection(df,coupling_matrix,x_range):
    bin=len(x_range)
    coupling=coupling_matrix.A1.reshape((bin,bin))
    df_t=pd.DataFrame(columns=['X','S','W','Y'])
    for i in range(df.shape[0]):
        orig=df.iloc[i]
        loc=np.where(x_range==orig[0])[0][0]
        rows=np.nonzero(coupling[loc,:])[0]
        sub=pd.DataFrame(columns=['X','W'],index=rows)
        sub['X']=x_range[rows]
        sub['W']=coupling[loc,rows]/(sum(coupling[loc,rows]))*orig[2]
        sub['S']=orig[1]
        sub['Y']=orig[3]
        df_t=pd.concat([df_t, samples_groupby(sub)], ignore_index=True)
    return df_t

def plot_rdist(rdist,x_range):
    plt.plot(x_range,rdist['x'],label=r'$Pr[x]$',color='tab:blue')
    plt.plot(x_range,rdist['x_0'],label=r'$Pr[x|s_0]$',alpha=0.3,color='tab:orange')
    plt.plot(x_range,rdist['x_1'],label=r'$Pr[x|s_1]$',alpha=0.3,color='#9f86c0')
    plt.ylabel('PMF',fontsize=14)
    plt.xlabel(r'$supp(X)=supp(\tilde{X})$',fontsize=20)
    plt.legend()
    return plt

# def DisparateImpact(X_test,y_test,y_pred):
#     df_test=pd.DataFrame(np.concatenate((X_test,y_test.reshape(-1,1),y_pred.reshape(-1,1)), axis=1),columns=['X','S','W','Y','f'])
#     numerator=sum(df_test[(df_test['S']==0)&(df_test['f']==1)&(df_test['Y']==1)]['W'])/sum(df_test[(df_test['S']==0)&(df_test['Y']==1)]['W'])
#     denominator=sum(df_test[(df_test['S']==1)&(df_test['f']==1)&(df_test['Y']==1)]['W'])/sum(df_test[(df_test['S']==1)&(df_test['Y']==1)]['W'])
#     return numerator/denominator

# def DisparateImpact(X_test,y_pred):
#     df_test=pd.DataFrame(np.concatenate((X_test,y_pred.reshape(-1,1)), axis=1),columns=['X','S','W','f'])
#     numerator=sum(df_test[(df_test['S']==0)&(df_test['f']==1)]['W'])/sum(df_test[df_test['S']==0]['W'])
#     denominator=sum(df_test[(df_test['S']==1)&(df_test['f']==1)]['W'])/sum(df_test[df_test['S']==1]['W'])
#     if numerator==denominator:
#         return 1
#     return numerator/denominator

def DisparateImpact_postprocess(df_test,y_pred_tmp):
    df_test_tmp=df_test[:]
    df_test_tmp.insert(loc=0, column='f', value=y_pred_tmp)
    numerator=sum(df_test_tmp[(df_test_tmp['S']==0)&(df_test_tmp['f']==1)]['W'])/sum(df_test_tmp[df_test_tmp['S']==0]['W'])
    denominator=sum(df_test_tmp[(df_test_tmp['S']==1)&(df_test_tmp['f']==1)]['W'])/sum(df_test_tmp[df_test_tmp['S']==1]['W'])
    # if numerator==denominator:
    #     return 1
    return numerator/denominator

def rdata_analysis(rdata,x_range,x_name):
    rdist=dict()
    pivot=pd.pivot_table(rdata,index=x_name,values=['W'],aggfunc=[np.sum])[('sum','W')]
    rdist['x']= np.array([pivot[i] for i in x_range])/sum([pivot[i] for i in x_range]) #empirical_distribution(rdata,x_range)
    if rdata[rdata['S']==0].shape[0]>0:
        pivot0=pd.pivot_table(rdata[rdata['S']==0],index=x_name,values=['W'],aggfunc=[np.sum])[('sum','W')]
        rdist['x_0']=np.array([pivot0[i] if i in list(pivot0.index) else 0 for i in x_range])/sum([pivot0[i] if i in list(pivot0.index) else 0 for i in x_range]) #empirical_distribution(rdata[rdata['S']==0],x_range)
    if rdata[rdata['S']==1].shape[0]>0:
        pivot1=pd.pivot_table(rdata[rdata['S']==1],index=x_name,values=['W'],aggfunc=[np.sum])[('sum','W')]
        rdist['x_1']=np.array([pivot1[i] if i in list(pivot1.index) else 0 for i in x_range])/sum([pivot1[i] if i in list(pivot1.index) else 0 for i in x_range]) #empirical_distribution(rdata[rdata['S']==1],x_range)
    return rdist

def c_generate_higher(x_range,weight):
    bin=len(x_range)
    dim=len(x_range[0])
    C=np.random.random((bin,bin))
    for i in range(bin):
        for j in range(bin):
            C[i,j]=sum(weight[d]*abs(x_range[i][d]-x_range[j][d]) for d in range(dim))
    return C

def c_generate(x_range):
    bin=len(x_range)
    C=np.random.random((bin,bin))
    for i in range(bin):
        for j in range(bin):
            C[i,j]=abs(x_range[i]-x_range[j])
    return C

def projection_higher(df,coupling_matrix,x_range,x_list,var_list):
    df=df.drop(columns=x_list)
    dim=len(x_list)
    bin=len(x_range)
    arg_list=[elem for elem in var_list if elem not in x_list]
    df=df[arg_list+['X','S','W','Y']]
    coupling=coupling_matrix.A1.reshape((bin,bin))
    df_t=pd.DataFrame(columns=arg_list+['X','S','W','Y'])
    for i in range(df.shape[0]):
        orig=df.iloc[i]
        loc=np.where([x_range[b]==orig['X'] for b in range(bin)])[0][0]
        #rows=np.nonzero(coupling[loc,:])[0]
        sub_dict={'X':x_range,'W':list(coupling[loc,:]/(sum(coupling[loc,:]))*orig['W'])}
        sub_dict.update({var:[orig[var]]*bin for var in arg_list+['S','Y']})
        sub=pd.DataFrame(data=sub_dict, index=[*range(bin)])
        df_t=pd.concat([df_t,sub],ignore_index=True)#pd.concat([df_t,samples_groupby(sub,x_list)], ignore_index=True)
    return df_t
    df_t=df_t.groupby(by=arg_list+['X','S','Y'],as_index=False).sum()
    for d in range(dim):
        df_t[x_list[d]]=[df_t['X'][r][d] for r in range(df_t.shape[0])]
    return df_t[var_list+['S','W','Y']]

# def postprocess(df_test,coupling_matrix,x_range,x_range_pred):
#     bin=len(x_range)
#     coupling=coupling_matrix.A1.reshape((bin,bin))
#     x_pred_repaired=[]
#     for loc in range(bin):
#         ##loc=np.where([x_range[i]==orig['X'] for i in range(bin)])[0][0]
#         rows=np.nonzero(coupling[loc,:])[0]
#         p0=sum(x_range_pred[:,0]*coupling[loc,rows]/(sum(coupling[loc,rows])))
#         p1=sum(x_range_pred[:,1]*coupling[loc,rows]/(sum(coupling[loc,rows])))
#         if p0>p1:
#             x_pred_repaired+=[0]
#         else:
#             x_pred_repaired+=[1]
#     pred_dict=dict(zip(x_range,x_pred_repaired))
#     f_repaired=[pred_dict[df_test['X'][i]] for i in range(df_test.shape[0])]
#     return np.array(f_repaired)

def postprocess(df,coupling_matrix,x_list,x_range,var_list,var_range,clf):
    dim=len(x_list)
    var_dim=len(var_list)
    bin=len(x_range)
    x_loc_dict=dict(zip(x_range,[*range(bin)]))
    arg_list=[elem for elem in var_list if elem not in x_list]
    coupling=coupling_matrix.A1.reshape((bin,bin))
    pred_repaired=dict()
    for i in range(len(var_range)):
        if var_dim>1:
            var_tmp=pd.Series({var_list[d]:var_range[i][d] for d in range(var_dim)})
            if dim>1:
                loc=x_loc_dict[tuple(var_tmp[x_list])]
            else:
                loc=x_loc_dict[var_tmp[x_list[0]]]
        else:
            var_tmp=pd.Series({var_list[0]:var_range[i]})
            loc=x_loc_dict[var_tmp[x_list[0]]]
        sub=pd.DataFrame(x_range,columns=x_list)
        for arg in arg_list:
            sub[arg]=var_tmp[arg] 
        sub=sub[var_list]
        totalweight=sum(coupling[loc,:])
        pred=int(sum(coupling[loc,:]/totalweight*clf.predict(np.array(sub).reshape(-1,var_dim)))>0.1)
        pred_repaired.update({var_range[i]:pred})
        # prob=clf.predict_log_proba(np.array(sub).reshape(-1,var_dim)) #log is better
        # prob0=sum(prob[:,0]*coupling[loc,:]/totalweight)
        # prob1=sum(prob[:,1]*coupling[loc,:]/totalweight)
        # pred_repaired.update({var_range[i]:int(prob0<prob1)})
    if var_dim>1:
        return np.array([pred_repaired[tuple(df[var_list].iloc[i])] for i in range(df.shape[0])])
    else:
        return np.array([pred_repaired[df[var_list[0]].iloc[i]] for i in range(df.shape[0])])

def postprocess_bary(df,coupling_bary_matrix,x_list,x_range,var_list,var_range,clf):
    bin=len(x_range)
    coupling_bary=coupling_bary_matrix.A1.reshape((bin,bin))
    s0=df[df['S']==0]
    s1=df[df['S']==1]
    pi0=sum(s0['W'])/sum(df['W']) #s0.shape[0]/df.shape[0]
    pi1=sum(s1['W'])/sum(df['W']) #s1.shape[0]/df.shape[0]
    coupling0=np.zeros((bin,bin))
    coupling1=np.zeros((bin,bin))
    for i in range(bin):
        for j in range(bin):
            # assume the distance between every two adjacent x indices is the same
            ind0=int(pi0*i+pi1*j)
            ind1=int(pi0*j+pi1*i)
            coupling0[i,ind0]+=coupling_bary[i,j]
            coupling1[i,ind1]+=coupling_bary[j,i]

    # assess if dist['td{x}_0']==dist['td{x}_1']
    projectedDist_s0=rdata_analysis(projection_higher(s0,np.matrix(coupling0),x_range,x_list,var_list),x_range,'X')['x_0']
    projectedDist_s1=rdata_analysis(projection_higher(s1,np.matrix(coupling1),x_range,x_list,var_list),x_range,'X')['x_1']
    #print('tv distance between projected S-wise distributions',sum(abs(projectedDist_s0-projectedDist_s1))/2)

    s0.insert(loc=0, column='f', value=postprocess(s0,np.matrix(coupling0),x_list,x_range,var_list,var_range,clf))
    s1.insert(loc=0, column='f', value=postprocess(s1,np.matrix(coupling1),x_list,x_range,var_list,var_range,clf))
    s_concate=pd.concat([s0,s1], ignore_index=False)
    s_concate.sort_index()
    return np.array(s_concate['f']),sum(abs(projectedDist_s0-projectedDist_s1))/2

In [687]:
def assess_tv(df,coupling_matrix,x_range,x_list,var_list):
    df_project=projection_higher(df,coupling_matrix,x_range,x_list,var_list)
    rdist=rdata_analysis(df_project,x_range,'X')
    return sum(abs(rdist['x_0']-rdist['x_1']))/2

In [None]:
# assess if dist['td{x}_0']==dist['td{x}_1']
projectedDist_s0=rdata_analysis(projection_higher(s0,np.matrix(coupling0),x_range,x_list,var_list),x_range,'X')['x_0']
projectedDist_s1=rdata_analysis(projection_higher(s1,np.matrix(coupling1),x_range,x_list,var_list),x_range,'X')['x_1']
print('tv distance between projected S-wise distributions',sum(abs(projectedDist_s0-projectedDist_s1))/2)


In [560]:
#gender wise
print(pa)
tv_dist=dict()
for x_name in var_list:
    x_range_single=list(pd.pivot_table(messydata,index=x_name,values=['W'])[('W')].index) 
    dist=rdata_analysis(messydata,x_range_single,x_name)
    tv_dist[x_name]=round(sum(abs(dist['x_0']-dist['x_1']))/2,4)
tv_dist

sex


{'hoursperweek': 0.1819,
 'age': 0.101,
 'education-num': 0.071,
 'capitalgain': 0.0369,
 'capitalloss': 0.0201}

In [559]:
#racewise
print(pa)
tv_dist=dict()
for x_name in var_list:
    x_range_single=list(pd.pivot_table(messydata,index=x_name,values=['W'])[('W')].index) 
    dist=rdata_analysis(messydata,x_range_single,x_name)
    tv_dist[x_name]=round(sum(abs(dist['x_0']-dist['x_1']))/2,4)
tv_dist

sex


{'hoursperweek': 0.1819,
 'age': 0.101,
 'education-num': 0.071,
 'capitalgain': 0.0369,
 'capitalloss': 0.0201}

In [648]:
K=200
e=1.0e-3

var_list=['hoursperweek','age','education-num','capitalgain','capitalloss'] # ,'capitalgain','capitalloss','age','education-num','capitalgain','hoursperweek','capitalgain'
var_dim=len(var_list)
pa='sex'
pa_dict={'Male':1,'Female':0,'White':1,'Black':0}

messydata=pd.read_csv('C:/Users/zhouq/Documents/optimal_transport/adult_csv.csv',usecols=var_list+[pa,'class'])
messydata=messydata.rename(columns={pa:'S','class':'Y'})
messydata['S']=messydata['S'].replace(pa_dict)
messydata['Y']=messydata['Y'].replace({'>50K':1,'<=50K':0})
messydata=messydata[(messydata['S']==1)|(messydata['S']==0)]
for col in var_list+['S','Y']:
    messydata[col]=messydata[col].astype('category')
messydata['W']=1
X=messydata[var_list+['S','W']].to_numpy() # [X,S,W]
y=messydata['Y'].to_numpy() #[Y]

tv_dist=dict()
for x_name in var_list:
    x_range_single=list(pd.pivot_table(messydata,index=x_name,values=['W'])[('W')].index) 
    dist=rdata_analysis(messydata,x_range_single,x_name)
    tv_dist[x_name]=sum(abs(dist['x_0']-dist['x_1']))/2
x_list=[]
for key,val in tv_dist.items():
    if val>0.11:
        x_list+=[key]
x_list=['hoursperweek','age']
print(x_list)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4)
clf=RandomForestClassifier(max_depth=5, random_state=0).fit(X_train[:,0:var_dim],y_train)
df_test=pd.DataFrame(np.concatenate((X_test,y_test.reshape(-1,1)), axis=1),columns=var_list+['S','W','Y'])
df_test=df_test.groupby(by=var_list+['S','Y'],as_index=False).sum()
if len(x_list)>1:
    df_test['X']=[tuple(df_test[x_list].values[r]) for r in range(df_test.shape[0])]
    x_range=list(set(df_test['X']))
    weight=list(1/(df_test[x_list].max()-df_test[x_list].min())) # because 'education-num' range from 1 to 16 while others 1 to 4
    C=c_generate_higher(x_range,weight)
else:
    df_test['X']=df_test[x_list[0]]
    x_range=list(set(df_test['X']))
    C=c_generate(x_range)

bin=len(x_range)
var_range=list(pd.pivot_table(df_test,index=var_list,values=['S','W','Y']).index)
dist=rdata_analysis(df_test,x_range,'X')
dist['t_x']=dist['x'] 
dist['v']=[(dist['x_0'][i]-dist['x_1'][i])/dist['x'][i] for i in range(bin)]

px=np.matrix(dist['x']).T
ptx=np.matrix(dist['t_x']).T
if np.any(dist['x_0']==0):
    p0=np.matrix((dist['x_0']+1.0e-9)/sum(dist['x_0']+1.0e-9)).T
else:
    p0=np.matrix(dist['x_0']).T 
if np.any(dist['x_1']==0):
    p1=np.matrix((dist['x_1']+1.0e-9)/sum(dist['x_1']+1.0e-9)).T
else:
    p1=np.matrix(dist['x_1']).T 
V=np.matrix(dist['v']).T

['hoursperweek', 'age']


In [701]:
sum(abs(p0-p1))

matrix([[0.40801111]])

In [695]:
abs(sum(p0-p1))

matrix([[5.20417043e-18]])

In [680]:
y_pred_bary,=postprocess_bary(df_test,baseline(C,e,p0,p1,V,K)[K*2],x_list,x_range,var_list,var_range,clf)

sum of violation of f: [[0.00410803]]
sum of violation of g: [[1.14491749e-16]]
total cost: 0.20026769724734314
entropy: 3.3258740677245995
tr violation: [[0.21079305]]
tv distance between projected S-wise distributions 0.002031649233070215


In [674]:
y_pred_base=postprocess(df_test,baseline(C,e,px,ptx,V,K)[K*2],x_list,x_range,var_list,var_range,clf)

sum of violation of f: [[2.07105945e-08]]
sum of violation of g: [[8.41340886e-17]]
total cost: 1.814852514784848e-08
entropy: 2.7022608503339
tr violation: [[0.4080111]]


In [690]:
part_matrix=partial_repair(C,e,px,ptx,V,1.0e-4,K)[K*3]

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
sum of violation of f: [[0.00044319]]
sum of violation of g: [[1.87781873e-06]]
total cost: 0.27021074087274827
entropy: 3.558424757252049
tr violation: [[0.0025601]]


In [691]:
assess_tv(df_test,part_matrix,x_range,x_list,var_list)

0.0012526462266495635

In [675]:
y_pred_part2=postprocess(df_test,part_matrix,x_list,x_range,var_list,var_range,clf)

In [668]:
y_pred=clf.predict(np.array(df_test[var_list]))

In [681]:
new_row=pd.Series({'DI':DisparateImpact_postprocess(df_test,y_pred),
                    'f1 macro':f1_score(df_test['Y'], y_pred, average='macro',sample_weight=df_test['W']),
                    'f1 micro':f1_score(df_test['Y'], y_pred, average='micro',sample_weight=df_test['W']),
                    'f1 weighted':f1_score(df_test['Y'], y_pred, average='weighted',sample_weight=df_test['W']),
                    'method':'origin'})
new_row_base=pd.Series({'DI':DisparateImpact_postprocess(df_test,y_pred_base),
                    'f1 macro':f1_score(df_test['Y'], y_pred_base, average='macro',sample_weight=df_test['W']),
                    'f1 micro':f1_score(df_test['Y'], y_pred_base, average='micro',sample_weight=df_test['W']),
                    'f1 weighted':f1_score(df_test['Y'], y_pred_base, average='weighted',sample_weight=df_test['W']),
                    'method':'baseline'})
new_row_bary=pd.Series({'DI':DisparateImpact_postprocess(df_test,y_pred_bary),
                    'f1 macro':f1_score(df_test['Y'], y_pred_bary, average='macro',sample_weight=df_test['W']),
                    'f1 micro':f1_score(df_test['Y'], y_pred_bary, average='micro',sample_weight=df_test['W']),
                    'f1 weighted':f1_score(df_test['Y'], y_pred_bary, average='weighted',sample_weight=df_test['W']),
                    'method':'barycentre'})
new_row_part2=pd.Series({'DI':DisparateImpact_postprocess(df_test,y_pred_part2),
                    'f1 macro':f1_score(df_test['Y'], y_pred_part2, average='macro',sample_weight=df_test['W']),
                    'f1 micro':f1_score(df_test['Y'], y_pred_part2, average='micro',sample_weight=df_test['W']),
                    'f1 weighted':f1_score(df_test['Y'], y_pred_part2, average='weighted',sample_weight=df_test['W']),
                    'method':'partial repair2'})

In [682]:
print(new_row)

DI             0.441977
f1 macro       0.689343
f1 micro       0.819522
f1 weighted    0.793665
method           origin
dtype: object


In [683]:
print(new_row_base)

DI             0.441977
f1 macro       0.689343
f1 micro       0.819522
f1 weighted    0.793665
method         baseline
dtype: object


In [684]:
print(new_row_bary)

DI               0.975972
f1 macro          0.49261
f1 micro         0.591544
f1 weighted      0.608838
method         barycentre
dtype: object


In [685]:
print(new_row_part2)

DI                    0.730197
f1 macro               0.66998
f1 micro              0.789067
f1 weighted           0.772821
method         partial repair2
dtype: object


In [345]:
s0=df_test[df_test['S']==0]
s1=df_test[df_test['S']==1]
projectedDist_s0=rdata_analysis(projection_higher(s0,coupling_matrix,x_range,x_list,var_list),x_range,'X')['x_0']
projectedDist_s1=rdata_analysis(projection_higher(s1,coupling_matrix,x_range,x_list,var_list),x_range,'X')['x_1']
print('tv distance between projected S-wise distributions',sum(abs(projectedDist_s0-projectedDist_s1))/2)

tv distance between projected S-wise distributions 0.001265520417256435


In [599]:
dim=len(x_list)
var_dim=len(var_list)
bin=len(x_range)
x_loc_dict=dict(zip(x_range,[*range(bin)]))
arg_list=[elem for elem in var_list if elem not in x_list]
coupling=coupling_matrix.A1.reshape((bin,bin))
pred_repaired=dict()

In [597]:
coupling_matrix=partial_repair(C,e,px,ptx,V,1.0e-3,K)[K*3]

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
sum of violation of f: [[0.00350642]]
sum of violation of g: [[0.00021465]]
total cost: 0.24851607487542127
entropy: 3.4755895224011546
tr violation: [[0.0246825]]


In [600]:
for i in range(len(var_range)):
    if var_dim>1:
        var_tmp=pd.Series({var_list[d]:var_range[i][d] for d in range(var_dim)})
        if dim>1:
            loc=x_loc_dict[tuple(var_tmp[x_list])]
        else:
            loc=x_loc_dict[var_tmp[x_list[0]]]
    else:
        var_tmp=pd.Series({var_list[0]:var_range[i]})
        loc=x_loc_dict[var_tmp[x_list[0]]]
    sub=pd.DataFrame(x_range,columns=x_list)
    for arg in arg_list:
        sub[arg]=var_tmp[arg] 
    sub=sub[var_list]
    totalweight=sum(coupling[loc,:])
    pred=int(sum(coupling[loc,:]/totalweight*clf.predict(np.array(sub).reshape(-1,var_dim)))>0.5)
    # prob=clf.predict_proba(np.array(sub).reshape(-1,var_dim))
    # prob0=sum(prob[:,0]*coupling[loc,:]/totalweight)
    # prob1=sum(prob[:,1]*coupling[loc,:]/totalweight)
    pred_repaired.update({var_range[i]:pred}) #int(prob0<prob1)

In [606]:
clf.predict_proba(np.array(sub).reshape(-1,var_dim))

array([[0.74909542, 0.25090458],
       [0.36227798, 0.63772202],
       [0.35762101, 0.64237899],
       [0.53673357, 0.46326643],
       [0.58130788, 0.41869212],
       [0.41536291, 0.58463709],
       [0.83281872, 0.16718128],
       [0.56291521, 0.43708479],
       [0.35808551, 0.64191449],
       [0.74837815, 0.25162185],
       [0.35388608, 0.64611392],
       [0.69576468, 0.30423532],
       [0.42176291, 0.57823709],
       [0.56398979, 0.43601021],
       [0.58927582, 0.41072418],
       [0.64446848, 0.35553152],
       [0.35435058, 0.64564942],
       [0.53788534, 0.46211466],
       [0.36549165, 0.63450835],
       [0.8484607 , 0.1515393 ],
       [0.68071464, 0.31928536],
       [0.57961655, 0.42038345],
       [0.80399078, 0.19600922],
       [0.57561796, 0.42438204],
       [0.41428833, 0.58571167]])

In [602]:
int(sum(coupling[loc,:]/totalweight*clf.predict(np.array(sub).reshape(-1,var_dim)))>0.5)

1

In [601]:
pred_repaired

{(0, 0, 3, 0, 0): 0,
 (0, 0, 4, 0, 0): 0,
 (0, 0, 5, 0, 0): 0,
 (0, 0, 5, 1, 0): 0,
 (0, 0, 6, 0, 0): 0,
 (0, 0, 6, 0, 2): 0,
 (0, 0, 6, 1, 0): 0,
 (0, 0, 7, 0, 0): 0,
 (0, 0, 7, 0, 2): 0,
 (0, 0, 7, 1, 0): 0,
 (0, 0, 8, 0, 0): 0,
 (0, 0, 8, 0, 2): 0,
 (0, 0, 9, 0, 0): 0,
 (0, 0, 9, 0, 2): 0,
 (0, 0, 9, 4, 0): 1,
 (0, 0, 10, 0, 0): 0,
 (0, 0, 10, 0, 2): 0,
 (0, 0, 10, 1, 0): 0,
 (0, 0, 11, 0, 0): 0,
 (0, 0, 12, 0, 0): 0,
 (0, 0, 12, 0, 2): 0,
 (0, 0, 12, 0, 4): 0,
 (0, 0, 12, 1, 0): 0,
 (0, 0, 13, 0, 0): 0,
 (0, 0, 13, 0, 2): 1,
 (0, 0, 13, 1, 0): 0,
 (0, 0, 14, 0, 0): 0,
 (0, 1, 3, 0, 0): 0,
 (0, 1, 4, 0, 0): 0,
 (0, 1, 5, 0, 0): 0,
 (0, 1, 6, 0, 0): 0,
 (0, 1, 7, 0, 0): 0,
 (0, 1, 7, 1, 0): 0,
 (0, 1, 8, 0, 0): 0,
 (0, 1, 9, 0, 0): 0,
 (0, 1, 9, 0, 3): 0,
 (0, 1, 10, 0, 0): 0,
 (0, 1, 10, 0, 1): 0,
 (0, 1, 10, 0, 2): 0,
 (0, 1, 10, 0, 4): 0,
 (0, 1, 10, 2, 0): 1,
 (0, 1, 10, 3, 0): 1,
 (0, 1, 11, 0, 0): 0,
 (0, 1, 11, 1, 0): 0,
 (0, 1, 11, 2, 0): 1,
 (0, 1, 12, 0, 0): 0,
 (0, 1, 12, 

In [111]:
var_tmp

education-num    16
dtype: int64

In [114]:
sub=pd.DataFrame(x_range,columns=x_list)
sub

Unnamed: 0,education-num
0,1
1,2
2,3
3,4
4,5
5,6
6,7
7,8
8,9
9,10


In [115]:
for arg in arg_list:
    sub[arg]=var_tmp[arg] 
sub

Unnamed: 0,education-num
0,1
1,2
2,3
3,4
4,5
5,6
6,7
7,8
8,9
9,10


In [120]:
prob=np.array([clf.predict_proba(np.array(sub.loc[r]).reshape(-1,var_dim)) for r in range(bin)])

In [123]:
coupling[loc,:]

array([3.07730275e-14, 3.67173156e-14, 3.85322538e-14, 1.14772642e-04,
       1.94310306e-10, 6.44095012e-06, 5.03035187e-07, 1.12866799e-10,
       8.30214407e-03, 4.05109475e-14, 2.52966596e-32, 5.80352261e-05,
       3.20772003e-26, 1.16886024e-20, 1.60410776e-11, 2.31832146e-03])

In [None]:
report=pd.read_csv('C:/Users/zhouq/Documents/optimal_transport/report_postprocess_sex.csv')
performance=list(report.columns)[:-1]
methods=['origin','baseline','partial_repair2','partial_repair3'] #,'partial_repair4'
#list(set(report['method']))
colors=['#5f0f40','#9a031e','#FF8811','#F4D06F','#9DD9D2']
pivot=pd.pivot_table(report,index=['method'],values=performance,aggfunc=[np.mean,np.std])
pivot

In [None]:
ind=np.arange(len(performance))
width = 0.15

fig, ax1 = plt.subplots(figsize=(12,6))

i=0
for m in methods:
    sub=[pivot[('mean',p)][m] for p in performance]
    err=[pivot[('std',p)][m] for p in performance]
    ax1.bar(ind+width*i,sub,width,yerr=err,color=colors[i],label=m)
    i+=1
    
ax1.legend(['Origin','Baseline',r'$10^{-2}$-repair',r'$10^{-3}$-repair',r'$5\times 10^{-4}$-repair'],fontsize=14,
            framealpha=0.2,bbox_to_anchor=(1.03,1.1),ncol=5,frameon=False)
ax1.set_xlabel('Indices', fontsize=20)
ax1.set_ylabel('Values', fontsize=20)

ax1.set_xticks(ind+width*2)
ax1.set_xticklabels(['DisparateImpact']+[i for i in performance[1:]], fontsize=16)
ax1.tick_params(axis='y', which='major', labelsize=14)

plt.rcParams['pdf.fonttype'] = 42
plt.rcParams['ps.fonttype'] = 42

#plt.savefig("C:/Users/zhouq/Documents/optimal_transport/adult_higher.pdf",bbox_inches='tight')