In [1]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim 
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader, random_split
import gurobipy as gp
from gurobipy import GRB
from tqdm import tqdm 
from helper import *  

In [2]:
# set parameters
# Device that the model and the initial z vector are moved to in the training cell.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dataset = 'lsac'          # dataset name passed to all_data(dataset=...)
sens_attribute = 'race'   # passed to all_data(sens_attr=...)
nominal = False           # NOTE(review): not referenced by any other cell in this notebook
merit = True              # forwarded as merit= to proj_z_str and epoch_manual_str
num_epochs = 30           # number of epoch_manual_str calls in the training loop
lr_theta = 1e-3           # Adam learning rate for the model parameters
lr_z = 1e-1               # forwarded as lr_z= to epoch_manual_str
epsilon = 1e-2            # forwarded to proj_z_str / epoch_manual_str; semantics defined in helper
delta = 0.5               # forwarded to proj_z_str / epoch_manual_str; semantics defined in helper

# Seed the RNGs this notebook draws from (shuffled DataLoader, nn.Linear init)
# so that Restart & Run All reproduces the logged losses and test metrics.
# torch.manual_seed seeds the generators on all devices, including CUDA.
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)
delta = 0.5 

In [3]:
# --- Training split: dataset object, shuffled mini-batch loader, full tensors ---
train_dataset = all_data(sens_attr=sens_attribute, train=True, dataset=dataset)
train_loader = DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
    drop_last=True,
)
X_train, y_train, sens_train = train_dataset.get_values()
flag = train_dataset.get_flag()                # forwarded unchanged to the helper routines
X_train_moments = train_dataset.get_moments()  # forwarded unchanged to the helper routines

# Pull one mini-batch only to read the input width for the model.
# Note: iterating the shuffled loader advances torch's global RNG.
X, y, s, idx = next(iter(train_loader))
num_feat = X.shape[1]

# --- Test split: iterated in order, keeping the final partial batch ---
test_dataset = all_data(sens_attr=sens_attribute, train=False, dataset=dataset)
df = test_dataset.get_df(train=False)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, drop_last=False)
# The metric helpers receive plain numpy arrays for the test split.
X_test, y_test, sens_test = (t.numpy() for t in test_dataset.get_values())
columns = list(df.columns)

In [4]:
# Define merit attributes 
dataset_features = train_dataset.get_columns()
dataset_features

['decile1b',
 'decile3',
 'lsat',
 'ugpa',
 'zfygpa',
 'zgpa',
 'fulltime',
 'fam_inc',
 'male',
 'tier',
 'race']

In [5]:
# Select the merit attributes by name rather than by hard-coded position, so the
# indices stay correct if the column order returned by get_columns() changes.
# With the column list displayed above this evaluates to [2, 3] (lsat, ugpa);
# list.index raises ValueError if a name is missing, instead of silently picking
# the wrong column.
merit_features = ['lsat', 'ugpa']
merit_attrs = [dataset_features.index(name) for name in merit_features]

In [6]:
# Define model 
class LogisticRegressionModel(nn.Module):
    """Single affine layer mapping ``input_dim`` features to ``output_dim`` scores.

    ``forward`` returns the raw linear scores (logits); no sigmoid is applied
    here, so presumably the loss/prediction helpers apply it — confirm in helper.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # Weight matrix and bias of the linear score function.
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Return ``self.linear(x)``: one unnormalized score row per input row."""
        return self.linear(x)

In [7]:
# Convert the training arrays to float tensors for the helper routines.
X_train = torch.Tensor(X_train)
sens_train = torch.Tensor(list(sens_train))
y_train = torch.Tensor(list(y_train))

# Initial z: one zero per training example (int64, same dtype the original
# torch.tensor([0, ...]) produced), created directly on the target device and
# then projected by proj_z_str before training starts.
z0_train = torch.zeros(len(y_train), dtype=torch.long, device=device)
z0_train = proj_z_str(X_train_moments, merit_attrs, z0_train, X_train, y_train, sens_train, flag, epsilon=epsilon, delta=delta, merit=merit)

model = LogisticRegressionModel(num_feat, 1).to(device)
opt = optim.Adam(model.parameters(), lr=lr_theta)

# Each epoch_manual_str call iterates train_loader and returns the updated z,
# which becomes the starting z for the next epoch. Per-epoch (error, loss)
# pairs are kept in train_history so later cells can inspect/plot them
# without re-running training.
train_history = []
for _ in range(num_epochs):
    train_err, train_loss, z = epoch_manual_str(train_loader, model, z0_train, X_train_moments, X_train, y_train, sens_train, flag, device, merit_attrs, delta=delta, epsilon=epsilon, lr_z=lr_z, opt=opt, merit=merit)
    z0_train = z
    train_history.append((train_err, train_loss))
    print(*("{:.6f}".format(v) for v in (train_err, train_loss)), sep="\t")

Academic license - for non-commercial use only - expires 2024-10-29
Using license file /Users/thodoris/gurobi.lic


100%|██████████| 227/227 [00:00<00:00, 650.29it/s]

Academic license - for non-commercial use only - expires 2024-10-29
Using license file /Users/thodoris/gurobi.lic



 34%|███▍      | 78/227 [00:00<00:00, 777.34it/s]

0.111607	0.502150


100%|██████████| 227/227 [00:00<00:00, 726.66it/s]

Academic license - for non-commercial use only - expires 2024-10-29
Using license file /Users/thodoris/gurobi.lic



 36%|███▌      | 81/227 [00:00<00:00, 805.22it/s]

0.111470	0.338118


100%|██████████| 227/227 [00:00<00:00, 789.31it/s]

Academic license - for non-commercial use only - expires 2024-10-29
Using license file /Users/thodoris/gurobi.lic



 35%|███▌      | 80/227 [00:00<00:00, 792.16it/s]

0.107763	0.309012


100%|██████████| 227/227 [00:00<00:00, 797.45it/s]

Academic license - for non-commercial use only - expires 2024-10-29
Using license file /Users/thodoris/gurobi.lic



 35%|███▌      | 80/227 [00:00<00:00, 793.49it/s]

0.106047	0.300435


100%|██████████| 227/227 [00:00<00:00, 797.78it/s]

Academic license - for non-commercial use only - expires 2024-10-29
Using license file /Users/thodoris/gurobi.lic



 30%|██▉       | 68/227 [00:00<00:00, 679.38it/s]

0.103645	0.293615


100%|██████████| 227/227 [00:00<00:00, 676.92it/s]

Academic license - for non-commercial use only - expires 2024-10-29
Using license file /Users/thodoris/gurobi.lic



 36%|███▌      | 82/227 [00:00<00:00, 815.36it/s]

0.104331	0.293775


100%|██████████| 227/227 [00:00<00:00, 786.01it/s]

Academic license - for non-commercial use only - expires 2024-10-29
Using license file /Users/thodoris/gurobi.lic



 71%|███████▏  | 162/227 [00:00<00:00, 795.59it/s]

0.103576	0.293541


100%|██████████| 227/227 [00:00<00:00, 801.76it/s]

Academic license - for non-commercial use only - expires 2024-10-29
Using license file /Users/thodoris/gurobi.lic



 67%|██████▋   | 152/227 [00:00<00:00, 726.28it/s]

0.103096	0.284328


100%|██████████| 227/227 [00:00<00:00, 743.61it/s]

Academic license - for non-commercial use only - expires 2024-10-29
Using license file /Users/thodoris/gurobi.lic



 72%|███████▏  | 163/227 [00:00<00:00, 815.66it/s]

0.102341	0.289661


100%|██████████| 227/227 [00:00<00:00, 810.31it/s]

Academic license - for non-commercial use only - expires 2024-10-29
Using license file /Users/thodoris/gurobi.lic



 72%|███████▏  | 163/227 [00:00<00:00, 806.54it/s]

0.102135	0.286079


100%|██████████| 227/227 [00:00<00:00, 783.90it/s]

Academic license - for non-commercial use only - expires 2024-10-29
Using license file /Users/thodoris/gurobi.lic



 34%|███▍      | 77/227 [00:00<00:00, 763.04it/s]

0.101448	0.285887


100%|██████████| 227/227 [00:00<00:00, 742.72it/s]

Academic license - for non-commercial use only - expires 2024-10-29
Using license file /Users/thodoris/gurobi.lic



 65%|██████▍   | 147/227 [00:00<00:00, 731.95it/s]

0.102752	0.277529


100%|██████████| 227/227 [00:00<00:00, 720.07it/s]

Academic license - for non-commercial use only - expires 2024-10-29
Using license file /Users/thodoris/gurobi.lic



 32%|███▏      | 73/227 [00:00<00:00, 723.73it/s]

0.100831	0.274146


100%|██████████| 227/227 [00:00<00:00, 721.43it/s]

Academic license - for non-commercial use only - expires 2024-10-29
Using license file /Users/thodoris/gurobi.lic



 67%|██████▋   | 151/227 [00:00<00:00, 746.57it/s]

0.101929	0.281984


100%|██████████| 227/227 [00:00<00:00, 758.23it/s]


Academic license - for non-commercial use only - expires 2024-10-29
Using license file /Users/thodoris/gurobi.lic


 31%|███▏      | 71/227 [00:00<00:00, 704.24it/s]

0.101242	0.283096


100%|██████████| 227/227 [00:00<00:00, 708.06it/s]

Academic license - for non-commercial use only - expires 2024-10-29
Using license file /Users/thodoris/gurobi.lic



 35%|███▌      | 80/227 [00:00<00:00, 790.78it/s]

0.099870	0.283775


100%|██████████| 227/227 [00:00<00:00, 776.34it/s]

Academic license - for non-commercial use only - expires 2024-10-29
Using license file /Users/thodoris/gurobi.lic



 68%|██████▊   | 154/227 [00:00<00:00, 768.42it/s]

0.100831	0.277944


100%|██████████| 227/227 [00:00<00:00, 763.30it/s]

Academic license - for non-commercial use only - expires 2024-10-29





Using license file /Users/thodoris/gurobi.lic


 67%|██████▋   | 151/227 [00:00<00:00, 739.81it/s]

0.101997	0.271749


100%|██████████| 227/227 [00:00<00:00, 757.92it/s]


Academic license - for non-commercial use only - expires 2024-10-29
Using license file /Users/thodoris/gurobi.lic


 33%|███▎      | 75/227 [00:00<00:00, 744.53it/s]

0.099046	0.280307


100%|██████████| 227/227 [00:00<00:00, 747.87it/s]

Academic license - for non-commercial use only - expires 2024-10-29
Using license file /Users/thodoris/gurobi.lic



 69%|██████▊   | 156/227 [00:00<00:00, 771.89it/s]

0.097536	0.268059


100%|██████████| 227/227 [00:00<00:00, 779.92it/s]

Academic license - for non-commercial use only - expires 2024-10-29
Using license file /Users/thodoris/gurobi.lic



 65%|██████▍   | 147/227 [00:00<00:00, 737.65it/s]

0.099389	0.281218


100%|██████████| 227/227 [00:00<00:00, 725.14it/s]

Academic license - for non-commercial use only - expires 2024-10-29
Using license file /Users/thodoris/gurobi.lic



 33%|███▎      | 76/227 [00:00<00:00, 757.53it/s]

0.097879	0.269348


100%|██████████| 227/227 [00:00<00:00, 699.13it/s]

Academic license - for non-commercial use only - expires 2024-10-29
Using license file /Users/thodoris/gurobi.lic



 33%|███▎      | 76/227 [00:00<00:00, 758.01it/s]

0.098565	0.269745


100%|██████████| 227/227 [00:00<00:00, 760.22it/s]


Academic license - for non-commercial use only - expires 2024-10-29
Using license file /Users/thodoris/gurobi.lic


 29%|██▊       | 65/227 [00:00<00:00, 640.03it/s]

0.100076	0.280932


100%|██████████| 227/227 [00:00<00:00, 712.63it/s]

Academic license - for non-commercial use only - expires 2024-10-29
Using license file /Users/thodoris/gurobi.lic



 35%|███▍      | 79/227 [00:00<00:00, 784.08it/s]

0.097261	0.272227


100%|██████████| 227/227 [00:00<00:00, 774.61it/s]

Academic license - for non-commercial use only - expires 2024-10-29
Using license file /Users/thodoris/gurobi.lic



 68%|██████▊   | 154/227 [00:00<00:00, 762.14it/s]

0.098154	0.275918


100%|██████████| 227/227 [00:00<00:00, 773.30it/s]

Academic license - for non-commercial use only - expires 2024-10-29
Using license file /Users/thodoris/gurobi.lic



 68%|██████▊   | 154/227 [00:00<00:00, 763.54it/s]

0.098840	0.277364


100%|██████████| 227/227 [00:00<00:00, 765.62it/s]

Academic license - for non-commercial use only - expires 2024-10-29
Using license file /Users/thodoris/gurobi.lic



 33%|███▎      | 75/227 [00:00<00:00, 744.73it/s]

0.098565	0.265629


100%|██████████| 227/227 [00:00<00:00, 762.76it/s]

Academic license - for non-commercial use only - expires 2024-10-29





Using license file /Users/thodoris/gurobi.lic


 67%|██████▋   | 152/227 [00:00<00:00, 747.00it/s]

0.097742	0.262100


100%|██████████| 227/227 [00:00<00:00, 754.44it/s]

Academic license - for non-commercial use only - expires 2024-10-29
Using license file /Users/thodoris/gurobi.lic





0.098565	0.268232


In [8]:
 # Evaluate Model
model.eval()
yp = get_preds(test_loader, model, device)
results_fairness, results_merit = metrics_str(merit_attrs, columns, X_test, y_test, sens_test, yp)

100%|██████████| 98/98 [00:00<00:00, 1101.09it/s]


In [9]:
results_fairness  # test-set table from metrics_str: overall accuracy (Acc_Total) and the SPD / EOD / DEO columns

Unnamed: 0,Acc_Total,SPD,EOD,DEO
0,0.8989,0.1003,0.0338,0.0852


In [10]:
results_merit  # test-set table from metrics_str with one column per merit attribute (here lsat, ugpa)

Unnamed: 0,lsat,ugpa
0,0.1044,0.1262
