In [1]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim 
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader, random_split
import gurobipy as gp
from gurobipy import GRB
from tqdm import tqdm 
from datasets import all_data 
from helper import *  
from models import LogisticRegressionModel

In [2]:
# Per-dataset pair of integer positions. Further down, the entry for the
# selected dataset is used to index into dataset.get_columns() to recover the
# names of the "merit" attributes.
merit_attrs_dict = {
    'lsac': [2, 3],
    'crime': [54, 61],
    'compas': [395, 399],
}

In [3]:
# --- experiment selection -------------------------------------------------
# Mode switches.
nominal = False             # True -> epoch_nominal; False -> epoch_flipped_structured
merit = True                # not read by any other cell visible in this notebook

# Which data to use.
dataset_name = 'lsac'       # key into merit_attrs_dict; also part of the CSV / checkpoint file names
sens_attribute = "Race"     # passed to all_data(); 'Gender' -> "gender" checkpoint tag, anything else -> "race"

In [4]:
# Load the raw training features for the selected dataset from the working
# directory and show the column names as the cell output.
train_x = pd.read_csv(f"train_X_{dataset_name}.csv")
train_x.columns

Index(['decile1b', 'decile3', 'lsat', 'ugpa', 'zfygpa', 'zgpa', 'fulltime',
       'fam_inc', 'male', 'tier', 'race'],
      dtype='object')

In [5]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [6]:
# set parameters
num_epochs = 20    # number of passes made by the training loop
lr_theta = 1e-3    # learning rate handed to the Adam optimizer for the model weights
lr_z = 1e-1        # forwarded as `lr_z` to epoch_flipped_structured
epsilon = 0.01     # forwarded to proj_z_structured / epoch_flipped_structured
delta = 0.2        # forwarded to proj_z_structured / epoch_flipped_structured

# Seed the global RNGs before any stochastic step runs (DataLoader shuffling,
# model weight initialisation), so that Restart Kernel -> Run All gives the
# same training trajectory every time instead of a different one per run.
random_seed = 0
np.random.seed(random_seed)
torch.manual_seed(random_seed)

In [7]:
# Training split of the selected dataset for the chosen sensitive attribute.
dataset = all_data(
    sens_attribute=sens_attribute,
    train=True,
    dataset_name=dataset_name,
)

# Positions of the merit attributes for this dataset.
merit_attrs = merit_attrs_dict[dataset_name]

# Shuffled mini-batches of 64; the last incomplete batch is dropped.
train_loader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    drop_last=True,
)

# Whole-dataset views exposed by the dataset object. These are used later for
# the projection step and by the structured training epoch.
train_X_1, train_sens_1, train_y_1 = dataset.get_values()
flag = dataset.get_flag()
train_x_bar = dataset.get_xbar()

# Column names, and the names of the merit attributes looked up by position.
dataset_features = dataset.get_columns()
merit_attrs_strs = [dataset_features[col_idx] for col_idx in merit_attrs]

In [8]:
# Pull a single batch; it is used here only to read the feature dimension.
X, s, y, idx = next(iter(train_loader))
num_feat = X.shape[1]

# Logistic-regression model built with (num_feat, 1) and placed on `device`.
model = LogisticRegressionModel(num_feat, 1).to(device)

# Lazy view of the parameters that require gradients.
# NOTE(review): nothing in the visible cells consumes this; the optimizer
# below is given model.parameters() directly.
model_parameters = (p for p in model.parameters() if p.requires_grad)

# Adam over all model parameters with the configured learning rate.
opt = optim.Adam(model.parameters(), lr=lr_theta)

In [9]:
# Empty holder; no visible cell appends to it.
z_list = []

# Initial z: one integer zero per training example, moved to the compute device.
z0_train = torch.tensor([0] * len(train_y_1)).to(device)

# Wrap the full training arrays as torch tensors (default tensor type).
# NOTE(review): unlike z0_train these are not moved to `device` -- confirm
# proj_z_structured / epoch_flipped_structured cope with that when CUDA is used.
train_X_1 = torch.Tensor(train_X_1)
train_sens_1 = torch.Tensor(list(train_sens_1))
train_y_1 = torch.Tensor(list(train_y_1))

# Project the all-zero starting point with the helper before training begins.
z0_train = proj_z_structured(
    z0_train,
    train_sens_1,
    train_y_1,
    train_X_1,
    train_x_bar,
    flag,
    epsilon,
    delta,
    merit_attrs,
)

Academic license - for non-commercial use only - expires 2023-10-25
Using license file /Users/thodoris/gurobi.lic


In [11]:
# Train for num_epochs passes over train_loader. In the non-nominal mode the
# z returned by one epoch is fed back in as the starting z of the next one.
for _ in range(num_epochs):
    if not nominal:
        train_err, train_loss, z = epoch_flipped_structured(
            train_loader, model, z0_train,
            train_sens_1, train_y_1, train_X_1,
            train_x_bar, flag, merit_attrs,
            epsilon=epsilon, delta=delta, lr_z=lr_z, opt=opt,
        )
        z0_train = z
    else:
        train_err, train_loss = epoch_nominal(train_loader, model, opt=opt)

    # One line per epoch: error and loss, tab-separated, six decimals each.
    print("\t".join("{:.6f}".format(metric) for metric in (train_err, train_loss)))

# Snapshot of the trained model and optimizer to be written to disk.
state = {
    'epoch': num_epochs,
    'state_dict': model.state_dict(),
    'optimizer': opt.state_dict(),
}

# File-name tags: training mode and sensitive attribute.
nom_str = 'nominal' if nominal else 'flipped'
sens_str = 'gender' if sens_attribute == 'Gender' else 'race'

# e.g. "<dataset>_<gender|race>_checkpoint_<nominal|flipped>.pth" in the working directory.
savepath = dataset_name + '_' + sens_str + '_checkpoint_' + nom_str + '.pth'
torch.save(state, savepath)

100%|██████████| 227/227 [00:00<00:00, 570.40it/s]

Academic license - for non-commercial use only - expires 2023-10-25





Using license file /Users/thodoris/gurobi.lic


 69%|██████▊   | 156/227 [00:00<00:00, 766.24it/s]

0.112911	0.326617


100%|██████████| 227/227 [00:00<00:00, 782.29it/s]

Academic license - for non-commercial use only - expires 2023-10-25
Using license file /Users/thodoris/gurobi.lic



 30%|██▉       | 67/227 [00:00<00:00, 660.00it/s]

0.111401	0.315585


100%|██████████| 227/227 [00:00<00:00, 600.64it/s]

Academic license - for non-commercial use only - expires 2023-10-25
Using license file /Users/thodoris/gurobi.lic



 65%|██████▌   | 148/227 [00:00<00:00, 728.47it/s]

0.110989	0.318277


100%|██████████| 227/227 [00:00<00:00, 750.50it/s]

Academic license - for non-commercial use only - expires 2023-10-25
Using license file /Users/thodoris/gurobi.lic



 28%|██▊       | 63/227 [00:00<00:00, 626.98it/s]

0.109822	0.301646


100%|██████████| 227/227 [00:00<00:00, 674.70it/s]

Academic license - for non-commercial use only - expires 2023-10-25
Using license file /Users/thodoris/gurobi.lic



 65%|██████▍   | 147/227 [00:00<00:00, 714.64it/s]

0.109548	0.306887


100%|██████████| 227/227 [00:00<00:00, 717.15it/s]

Academic license - for non-commercial use only - expires 2023-10-25
Using license file /Users/thodoris/gurobi.lic



 65%|██████▌   | 148/227 [00:00<00:00, 721.64it/s]

0.109822	0.308753


100%|██████████| 227/227 [00:00<00:00, 731.17it/s]

Academic license - for non-commercial use only - expires 2023-10-25
Using license file /Users/thodoris/gurobi.lic



 27%|██▋       | 62/227 [00:00<00:00, 614.48it/s]

0.109410	0.308870


100%|██████████| 227/227 [00:00<00:00, 612.77it/s]

Academic license - for non-commercial use only - expires 2023-10-25
Using license file /Users/thodoris/gurobi.lic



 27%|██▋       | 61/227 [00:00<00:00, 606.94it/s]

0.109479	0.307333


100%|██████████| 227/227 [00:00<00:00, 613.69it/s]

Academic license - for non-commercial use only - expires 2023-10-25
Using license file /Users/thodoris/gurobi.lic



 64%|██████▍   | 145/227 [00:00<00:00, 707.11it/s]

0.108793	0.308788


100%|██████████| 227/227 [00:00<00:00, 713.81it/s]

Academic license - for non-commercial use only - expires 2023-10-25
Using license file /Users/thodoris/gurobi.lic



 28%|██▊       | 64/227 [00:00<00:00, 638.42it/s]

0.108175	0.298657


100%|██████████| 227/227 [00:00<00:00, 602.05it/s]

Academic license - for non-commercial use only - expires 2023-10-25
Using license file /Users/thodoris/gurobi.lic



 65%|██████▌   | 148/227 [00:00<00:00, 719.60it/s]

0.107763	0.305692


100%|██████████| 227/227 [00:00<00:00, 741.59it/s]

Academic license - for non-commercial use only - expires 2023-10-25
Using license file /Users/thodoris/gurobi.lic



 32%|███▏      | 73/227 [00:00<00:00, 725.03it/s]

0.107763	0.302287


100%|██████████| 227/227 [00:00<00:00, 743.08it/s]

Academic license - for non-commercial use only - expires 2023-10-25
Using license file /Users/thodoris/gurobi.lic



 22%|██▏       | 50/227 [00:00<00:00, 492.33it/s]

0.105429	0.296216


100%|██████████| 227/227 [00:00<00:00, 549.18it/s]

Academic license - for non-commercial use only - expires 2023-10-25
Using license file /Users/thodoris/gurobi.lic



 28%|██▊       | 64/227 [00:00<00:00, 636.36it/s]

0.107489	0.302165


100%|██████████| 227/227 [00:00<00:00, 670.24it/s]

Academic license - for non-commercial use only - expires 2023-10-25
Using license file /Users/thodoris/gurobi.lic



 54%|█████▍    | 123/227 [00:00<00:00, 628.80it/s]

0.106939	0.302068


100%|██████████| 227/227 [00:00<00:00, 617.83it/s]

Academic license - for non-commercial use only - expires 2023-10-25
Using license file /Users/thodoris/gurobi.lic



 69%|██████▉   | 157/227 [00:00<00:00, 784.16it/s]

0.106459	0.305194


100%|██████████| 227/227 [00:00<00:00, 780.23it/s]

Academic license - for non-commercial use only - expires 2023-10-25
Using license file /Users/thodoris/gurobi.lic



 34%|███▍      | 78/227 [00:00<00:00, 769.97it/s]

0.107420	0.308136


100%|██████████| 227/227 [00:00<00:00, 779.28it/s]

Academic license - for non-commercial use only - expires 2023-10-25
Using license file /Users/thodoris/gurobi.lic



 35%|███▌      | 80/227 [00:00<00:00, 798.09it/s]

0.107626	0.308057


100%|██████████| 227/227 [00:00<00:00, 786.88it/s]

Academic license - for non-commercial use only - expires 2023-10-25
Using license file /Users/thodoris/gurobi.lic



 67%|██████▋   | 151/227 [00:00<00:00, 729.72it/s]

0.108381	0.303356


100%|██████████| 227/227 [00:00<00:00, 730.36it/s]

Academic license - for non-commercial use only - expires 2023-10-25
Using license file /Users/thodoris/gurobi.lic





0.107489	0.304616
