In [79]:
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.model_selection import train_test_split
# `tensor` is an alias for the torch.tensor factory function. The previous
# `import torch.tensor as tensor` fails on current PyTorch, which has no
# `torch.tensor` submodule.
from torch import tensor

# GPUs to expose, in the order you want them used (e.g. "0,1,2,3").
# This only takes effect if it is set before CUDA is first initialised.
os.environ["CUDA_VISIBLE_DEVICES"] = "2,3,0"

cpu = torch.device('cpu')
# Training device; falls back to CPU so the notebook still runs without a GPU.
dv = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Working directory containing the ACS_PUMS data. Either replace the
# placeholder or set the ACS_PUMS_DIR environment variable.
os.chdir(os.environ.get("ACS_PUMS_DIR", "/yourpath/ACS_PUMS"))


## Load the 2018 ACS person data (split into train/validation sets below)

In [105]:
import folktables
from folktables import ACSDataSource

# 2018 ACS 1-year, person-level survey. download=True allows folktables to
# fetch the survey files (presumably only when they are not already on disk).
data_source = ACSDataSource(survey='person', horizon='1-Year', survey_year='2018')
acs_data = data_source.get_data(download=True)

In [4]:
# ACS person columns used as model inputs. The list order is the column order
# of the feature matrix. 'RELSHIPP' / 'RELP' are intentionally left out.
# NOTE: the commute-mode column is 'JWTR' in the 2018 data used for training
# here; use 'JWTRNS' instead when working with the 2019 (test) data.
features = [
    'AGEP', 'SCHL', 'MAR', 'DIS', 'ESP',
    'CIT', 'MIG', 'MIL', 'ANC', 'NATIVITY',
    'DEAR', 'DEYE', 'DREM', 'SEX', 'RAC1P',
    'PUMA', 'ST', 'OCCP', 'JWTR', 'POWPUMA',
]

In [5]:
# Employment task: positive label where ESR == 1; groups come from SEX.
Employment = folktables.BasicProblem(
    features=features,
    target='ESR',
    target_transform=lambda x: x == 1,
    group='SEX',
    preprocess=folktables.acs.adult_filter,
    # Replace NaN with -1. `nan` must be passed by keyword: the second
    # positional parameter of np.nan_to_num is `copy`, so the former
    # `np.nan_to_num(x, -1)` silently mapped NaN to 0.0.
    postprocess=lambda x: np.nan_to_num(x, nan=-1),
)

In [6]:
# Income task: positive label where PINCP > 50000; groups come from SEX.
Income = folktables.BasicProblem(
    features=features,
    target='PINCP',
    target_transform=lambda x: x > 50000,
    group='SEX',
    preprocess=folktables.acs.adult_filter,
    # Replace NaN with -1. `nan` must be passed by keyword: the second
    # positional parameter of np.nan_to_num is `copy`, so the former
    # `np.nan_to_num(x, -1)` silently mapped NaN to 0.0.
    postprocess=lambda x: np.nan_to_num(x, nan=-1),
)

In [7]:
# Health-insurance task: positive label where HINS2 == 1; groups come from SEX.
HealthInsurance = folktables.BasicProblem(
    features=features,
    target='HINS2',
    target_transform=lambda x: x == 1,
    group='SEX',
    preprocess=folktables.acs.adult_filter,
    # Replace NaN with -1. `nan` must be passed by keyword: the second
    # positional parameter of np.nan_to_num is `copy`, so the former
    # `np.nan_to_num(x, -1)` silently mapped NaN to 0.0.
    postprocess=lambda x: np.nan_to_num(x, nan=-1),
)

In [8]:
# Travel-time task: positive label where JWMNP > 20; groups come from SEX.
TravelTime = folktables.BasicProblem(
    features=features,
    target="JWMNP",
    target_transform=lambda x: x > 20,
    group='SEX',
    preprocess=folktables.acs.adult_filter,
    # Replace NaN with -1. `nan` must be passed by keyword: the second
    # positional parameter of np.nan_to_num is `copy`, so the former
    # `np.nan_to_num(x, -1)` silently mapped NaN to 0.0.
    postprocess=lambda x: np.nan_to_num(x, nan=-1),
)

In [9]:
# Income-to-poverty task: positive label where POVPIP < 250; groups come from SEX.
IncomePovertyRatio = folktables.BasicProblem(
    features=features,
    target='POVPIP',
    target_transform=lambda x: x < 250,
    group='SEX',
    preprocess=folktables.acs.adult_filter,
    # Replace NaN with -1. `nan` must be passed by keyword: the second
    # positional parameter of np.nan_to_num is `copy`, so the former
    # `np.nan_to_num(x, -1)` silently mapped NaN to 0.0.
    postprocess=lambda x: np.nan_to_num(x, nan=-1),
)

In [10]:
# All five problems share `features`, `adult_filter` and group='SEX', so the
# feature matrix `f` and group vector `g` are presumably the same for every
# call; only the label vector differs. As before, f and g keep the values
# returned by the last call (IncomePovertyRatio).
_label_vectors = []
for _problem in (Employment, Income, HealthInsurance, TravelTime, IncomePovertyRatio):
    f, _labels, g = _problem.df_to_numpy(acs_data)
    _label_vectors.append(_labels)
l1, l2, l3, l4, l5 = _label_vectors
del _problem, _labels, _label_vectors

In [11]:
# Stack the five label vectors into one row per task (order l1..l5) and encode
# them as 0/1 integers. A value maps to 0 only when it is falsy.
y = np.asarray([l1, l2, l3, l4, l5], dtype=bool).astype(int)

In [34]:
# Split row positions together with the features, so that labels (y) and
# groups (g) can later be sliced with exactly the same train/validation rows.
ids = np.arange(f.shape[0])
X_train, X_val, in_tr, in_val = train_test_split(
    f, ids, random_state=9, test_size=0.3
)

In [37]:
# One label vector per task for each split, plus the matching group values.
N_tasks = len(y)
y_train = [task_labels[in_tr] for task_labels in y]
y_v = [task_labels[in_val] for task_labels in y]
g_train, g_val = g[in_tr], g[in_val]

5

In [38]:
# Training labels as torch tensors for the training loop. torch.as_tensor
# returns an existing tensor unchanged, so re-running this cell is harmless
# (torch.tensor on a tensor warns and copies).
y_train = [torch.as_tensor(y_train[t]) for t in range(N_tasks)]

## Build the single-task (STL) model — one independent network per prediction task

In [48]:
class STL(nn.Module):
    """Single-task MLP that maps each input row to 2 logits.

    Layers are applied in the order fc1 -> fc4 -> fc2 -> fc3 -> task. Note
    that fc4/bn3 are declared after fc3 but applied second. Every hidden linear
    layer except fc3 is followed by BatchNorm1d, and every hidden layer uses
    ReLU.

    Args:
        d_in: number of input features per row.
    """

    def __init__(self, d_in=50):
        super().__init__()
        # Keep this declaration order: it fixes parameter registration order
        # (state_dict keys) and the sequence of random initialisations.
        self.fc1 = nn.Linear(d_in, 1024)
        self.bn1 = nn.BatchNorm1d(1024)
        self.fc2 = nn.Linear(1024, 512)
        self.bn2 = nn.BatchNorm1d(512)
        self.fc3 = nn.Linear(512, 128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.fc4 = nn.Linear(1024, 1024)
        self.task = nn.Linear(128, 2)

    def forward(self, x):
        """Return 2 logits per row; `x` is expected to be (N, d_in)."""
        h = x
        for linear, norm in ((self.fc1, self.bn1), (self.fc4, self.bn3), (self.fc2, self.bn2)):
            h = F.relu(norm(linear(h)))
        h = F.relu(self.fc3(h))
        return self.task(h)

In [49]:
def fair_loss(output, target, x_control):
    """Worst-group cross-entropy, computed separately for each true label.

    Rows are split into two groups (x_control == 1, called "non-protected" here,
    versus every other value, "protected") and by true label (1 vs 0). For each
    label, the larger of the two groups' mean cross-entropy is taken, and the
    two maxima are summed.

    Args:
        output: logits as accepted by F.cross_entropy, one row per example.
        target: integer class labels, one per row of `output`.
        x_control: group attribute, one per row (here the folktables SEX group
            values; which value is "protected" is only defined by `== 1`).

    Returns:
        Scalar tensor. A (group, label) cell with no rows contributes 0 instead
        of NaN, so the other group's loss is used for that label.
    """
    non_prot = x_control == 1
    prot = ~non_prot

    def _cell_ce(mask):
        # The indices refer to the full batch, so they line up with `output`
        # and `target`. The old code built them inside the group subset and
        # then applied them to the full batch, which selected the wrong rows.
        idx = torch.nonzero(mask, as_tuple=True)[0]
        if idx.numel() == 0:
            # Mean cross-entropy over zero rows is NaN and would poison the max.
            return output.new_zeros(())
        return F.cross_entropy(output.index_select(0, idx), target.index_select(0, idx))

    dl_pos = torch.max(_cell_ce(prot & (target == 1)), _cell_ce(non_prot & (target == 1)))
    dl_neg = torch.max(_cell_ce(prot & (target == 0)), _cell_ce(non_prot & (target == 0)))
    return dl_pos + dl_neg

In [50]:
import torchmetrics

In [51]:


# Module-level torchmetrics accuracy object (stateful metric).
# NOTE(review): torchmetrics >= 0.11 requires a `task=` argument
# (e.g. Accuracy(task="multiclass", num_classes=2)); this no-argument form only
# works on older releases — confirm the installed version.
acc = torchmetrics.Accuracy()
def DM_rate(output, target, x_control):
    """Between-group accuracy gap, summed over the two true labels.

    Rows are split into two groups (x_control == 1 vs any other value) and by
    true label (1 vs 0). For each label, the absolute difference between the
    two groups' accuracies is taken, and the two differences are summed.
    Accuracy is argmax(output) == target.

    Args:
        output: logits of shape (N, C); the predicted class is the argmax over dim 1.
        target: integer class labels, shape (N,).
        x_control: group attribute, shape (N,).

    Returns:
        Scalar float tensor. A label for which either group has no rows
        contributes 0, because its gap is undefined.
    """
    pred = output.argmax(dim=1)
    non_prot = x_control == 1
    prot = ~non_prot

    def _cell_acc(mask):
        # The boolean mask is over the full batch, so pred/target rows line up.
        # The old code built indices inside the group subset and applied them
        # to the full batch, which selected the wrong rows.
        if not bool(mask.any()):
            return None
        return (pred[mask] == target[mask]).float().mean()

    def _gap(label):
        acc_prot = _cell_acc(prot & (target == label))
        acc_non_prot = _cell_acc(non_prot & (target == label))
        if acc_prot is None or acc_non_prot is None:
            return pred.new_zeros((), dtype=torch.float32)
        return torch.abs(acc_prot - acc_non_prot)

    return _gap(1) + _gap(0)

In [46]:
import os

# One independent STL network (wrapped in DataParallel) and one AdamW optimizer per task.
SLs = [nn.DataParallel(STL(d_in=X_train.shape[1])).to(dv) for t in range(N_tasks)]
SL_optis = [optim.AdamW(SLs[t].parameters()) for t in range(N_tasks)]

# Best-checkpoint file per task. The path is relative to the current working
# directory (os.chdir'd to the ACS_PUMS folder at the top of the notebook),
# replacing a hard-coded /home/roy/... path. The directory is created if
# missing so that torch.save does not fail.
MODEL_DIR = 'model'
os.makedirs(MODEL_DIR, exist_ok=True)
spaths = {'path' + str(t): os.path.join(MODEL_DIR, 'Model_stl' + str(t) + '.pt') for t in range(N_tasks)}

## Train one model per task (cross-entropy + fairness loss), keeping the best checkpoint by validation accuracy

In [None]:
criteria = nn.CrossEntropyLoss()
# Per task: [best validation accuracy so far, DM_rate at that epoch].
best_S = [[0, 1] for t in range(N_tasks)]
# Per task: one [validation accuracy, DM_rate] entry per epoch.
All_S = [[] for t in range(N_tasks)]
# Per task: one [cross-entropy, fair_loss] entry per training batch, stored as
# Python floats so each batch's autograd graph is not kept alive.
SL_E = [[] for t in range(N_tasks)]

N_EPOCHS = 50
BATCH_SIZE = 8192

# The validation data never changes, so convert it once instead of once per task and epoch.
X_val_dev = torch.as_tensor(X_val).float().to(dv)
y_val_cpu = [torch.as_tensor(y_v[t]).to(cpu) for t in range(N_tasks)]
g_val_cpu = torch.as_tensor(g_val).to(cpu)

for epoch in range(N_EPOCHS):
    # Training mode: BatchNorm uses batch statistics and updates its running stats.
    for net in SLs:
        net.train()

    # Sequential mini-batches over X_train; the last one may be shorter.
    for start in range(0, len(X_train), BATCH_SIZE):
        stop = start + BATCH_SIZE
        inputs = torch.as_tensor(X_train[start:stop]).float().to(dv)
        if inputs.shape[0] < 2:
            # BatchNorm1d cannot compute training statistics from a single row.
            continue
        # Group attribute for the same rows: g_train and X_train are both
        # indexed by in_tr. (This was the undefined name `xg` before.)
        xc = torch.as_tensor(g_train[start:stop]).to(dv)

        for t in range(N_tasks):
            labels_t = torch.as_tensor(y_train[t][start:stop]).to(dv)
            SL_optis[t].zero_grad()
            out = SLs[t](inputs)
            loss_Sa = criteria(out, labels_t)
            loss_Sf = fair_loss(out, labels_t, xc)
            SL_E[t].append([loss_Sa.item(), loss_Sf.item()])
            loss = loss_Sa + loss_Sf
            loss.backward()
            SL_optis[t].step()

    # Validation in eval mode, so BatchNorm uses its running statistics and is
    # not updated with validation data.
    with torch.no_grad():
        for t in range(N_tasks):
            SLs[t].eval()
            pred_SL = SLs[t](X_val_dev).to(cpu)
            SL_acc = (pred_SL.argmax(dim=1) == y_val_cpu[t]).float().mean().item()
            SL_eo = DM_rate(pred_SL, y_val_cpu[t], g_val_cpu)
            All_S[t].append([SL_acc, SL_eo])
            if SL_acc > best_S[t][0]:
                best_S[t][0] = SL_acc
                best_S[t][1] = SL_eo
                torch.save(SLs[t].state_dict(), spaths['path' + str(t)])
        