In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

## Fairness Metrics

### 1. Equal Accuracy

In [None]:
def acc_compare(y_truth, y_pred, label, sen_att, sen_type, verbose=True):
    """
    Compare prediction accuracy across the groups of a sensitive attribute.

    Input:
        y_truth: ground truth
        y_pred: predictions
        label: name of the ground truth label (e.g., openness)
        sen_att: sensitive attribute (list of group codes), aligned
                 element-wise with y_truth and y_pred
        sen_type: name of the sensitive attribute of interest;
                  must be 'Gender' or 'Ethnicity'
        verbose: truthy: show plots and intermediate steps;
                 falsy: skip them (the header and the MAE differences
                 are always printed)
    Output:
        dis_mae: dictionary of the MAE difference between group pairs.
                 key: (name of group 1, name of group 2);
                 value: MAE(group 1) - MAE(group 2)
    Raises:
        ValueError: if sen_type is not a supported attribute name.
    """
    genders = {1: 'Male', 2: 'Female'}
    races = {1: 'Asian', 2: 'Caucasian', 3: 'African-American'}
    name_maps = {'Gender': genders, 'Ethnicity': races}

    # Fail fast: previously an unknown sen_type left g_names unbound and the
    # function crashed much later with an unrelated-looking error.
    if sen_type not in name_maps:
        raise ValueError(
            "Unsupported sen_type %r; expected one of %s"
            % (sen_type, sorted(name_maps))
        )
    g_names = name_maps[sen_type]

    print("=======%s=======" % label)
    if verbose:
        print("R2 on Test: %f" % r2_score(y_truth, y_pred))
        print("Corr. on Test: ", pearsonr(y_truth, y_pred))
        print("MSE on Test: %f" % mean_squared_error(y_truth, y_pred))
        print("MAE on Test: %f" % mean_absolute_error(y_truth, y_pred))
        plt.figure()
        plt.scatter(y_truth, y_pred, c="k", alpha=0.4)
        plt.xlabel('Ground Truth')
        plt.ylabel('Prediction')
        plt.title('Test: %s' % label)
        plt.show()

    # get the sensitive attribute info
    colors = ['b', 'r', 'g']
    # Sorted so that the pair keys of the result are deterministic.
    groups = sorted(set(sen_att))
    # One colour per group *position*, used by every plot below, so colours
    # stay consistent and do not depend on the group codes being 1..k ints.
    group_colors = [colors[i % len(colors)] for i in range(len(groups))]

    errs = []
    maes = []
    for idx, g in enumerate(groups):
        y_t_g = [t for t, s in zip(y_truth, sen_att) if s == g]
        y_p_g = [p for p, s in zip(y_pred, sen_att) if s == g]
        if verbose:
            print("----%s----" % g_names[g])
            print("R2 on Test: %f" % r2_score(y_t_g, y_p_g))
            print("Corr. on Test: ", pearsonr(y_t_g, y_p_g))
            print("MSE on Test: %f" % mean_squared_error(y_t_g, y_p_g))
            print("MAE on Test: %f" % mean_absolute_error(y_t_g, y_p_g))
            plt.figure()
            plt.scatter(y_t_g, y_p_g, c=group_colors[idx], alpha=0.4)
            plt.xlabel('Ground Truth')
            plt.ylabel('Prediction')
            plt.title('Test: %s' % label)
            plt.show()
        # Signed per-sample errors (prediction - truth) feed the t-tests.
        errs.append([p - t for p, t in zip(y_p_g, y_t_g)])
        maes.append(mean_absolute_error(y_t_g, y_p_g))

    if verbose:
        # NOTE(review): distplot is deprecated in recent seaborn releases;
        # kept as-is because the seaborn version in use is not visible here.
        for i in range(len(errs)):
            sns.distplot(errs[i], color=group_colors[i])
        plt.title('Error dist: %s' % label)
        plt.show()

        print("T-test")
        for i in range(len(errs)):
            g1 = groups[i]
            for j in range(i + 1, len(errs)):
                g2 = groups[j]
                print("%s-%s: " % (g_names[g1], g_names[g2]), ttest_ind(errs[i], errs[j]))

    print("MAE differences")
    dis_mae = {}
    for i in range(len(maes)):
        g1 = groups[i]
        for j in range(i + 1, len(maes)):
            g2 = groups[j]
            t, p = ttest_ind(errs[i], errs[j])
            mae_gap = maes[i] - maes[j]
            print("%s-%s: " % (g_names[g1], g_names[g2]), mae_gap)
            # 0.05 is the significance threshold for the error t-test.
            if p < 0.05:
                print("Significant! ", p)
            dis_mae[(g_names[g1], g_names[g2])] = mae_gap

    return dis_mae

### 2. Mutual Information (MI) Gain
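Compare the mutual information (MI) between the sensitive attribute and the predictions with the MI between the sensitive attribute and the ground-truth labels. The MI gain is the former minus the latter.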

In [1]:
def sp_mi(y_truth, y_pred, label, sen_att, sen_type):
    """
    Mutual-information gain of the predictions over the ground truth.

    Estimates the mutual information between the sensitive attribute and
    (a) the predictions, (b) the ground-truth labels, and returns (a) - (b).

    Input:
        y_truth: ground truth values
        y_pred: predicted values
        label: unused; kept for a signature parallel to the other metrics
        sen_att: sensitive attribute values, used as the discrete target
        sen_type: unused; kept for a signature parallel to the other metrics
    Output:
        MI(y_pred, sen_att) - MI(y_truth, sen_att)
    """
    # mutual_info_classif receives each series as a single-feature column.
    truth_col = np.array(y_truth).reshape(-1, 1)
    pred_col = np.array(y_pred).reshape(-1, 1)

    mi_pred = mutual_info_classif(pred_col, sen_att, n_neighbors=5, random_state=0)[0]
    mi_truth = mutual_info_classif(truth_col, sen_att, n_neighbors=5, random_state=0)[0]
    return mi_pred - mi_truth

## Adversarial Model

In [None]:
class Net(nn.Module):
    """
    Shared two-layer filter with a discriminator head and a regression head.

    Args:
        disc_n: number of output classes of the discriminator head.
        in_features: width of the input feature vector. When omitted, falls
            back to the notebook-level globals ``feat_ed - feat_st``, which
            must then be defined before the model is constructed.
    """

    def __init__(self, disc_n, in_features=None):
        super().__init__()
        if in_features is None:
            # Backward-compatible default: size taken from the feature slice
            # bounds defined elsewhere in the notebook.
            in_features = feat_ed - feat_st
        self.fc1 = nn.Linear(in_features, 40, bias=True)
        self.fc2 = nn.Linear(40, 20)
        self.discr_linear = nn.Linear(20, disc_n, bias=True)
        self.pred_linear = nn.Linear(20, 1, bias=True)

    def forward(self, x):
        """
        Run the filter and both heads.

        Returns:
            discr: discriminator logits, last dimension of size ``disc_n``.
            pred: regression output with the trailing size-1 dimension removed.
        """
        # filter
        h = F.relu(self.fc1(x))
        filtered = F.relu(self.fc2(h))

        ## discriminator
        discr = self.discr_linear(filtered)

        ## pred
        # squeeze(-1) drops only the feature dimension; a bare squeeze() would
        # also drop the batch dimension for a batch of one and no longer match
        # a target of shape (1,).
        pred = self.pred_linear(filtered).squeeze(-1)

        return discr, pred

In [None]:
## statistics of FI dataset
# Number of feature columns in `d`; the columns at index feat_dim and
# feat_dim + 1 are read below as the sensitive-attribute labels.
feat_dim = 81061
# modality_id selects which contiguous slice of the feature columns is used.
if modality_id == 0:
    feat_st = 0
    feat_ed = 70592
elif modality_id == 1:
    feat_st = 70592
    feat_ed = feat_dim
else:
    # Without this branch feat_st / feat_ed would be left undefined (or stale
    # from an earlier run) and the failure would surface far from its cause.
    raise ValueError("Unsupported modality_id %r; expected 0 or 1" % (modality_id,))

features = torch.from_numpy(d[:,feat_st:feat_ed]).float()
# The stored group codes are shifted down by one before being used as class
# labels (presumably 1-based codes -> zero-based class indices; verify).
if disc_type  == 'Ethnicity':
    disc_n = 3
    class_labels = torch.from_numpy(d[:, feat_dim]).long() - 1
elif disc_type == 'Gender':
    disc_n = 2
    class_labels = torch.from_numpy(d[:, feat_dim+1]).long() - 1
else:
    raise ValueError(
        "Unsupported disc_type %r; expected 'Ethnicity' or 'Gender'" % (disc_type,)
    )
### predict label separately
# NOTE(review): the column offset `pred_type_indx-6` depends on how
# pred_type_indx is defined elsewhere; confirm it selects the intended column.
regress_labels = torch.from_numpy(d[:, pred_type_indx-6]).float()

In [None]:
# Hyper-parameters come first: batch_size is needed to build the batches
# below, so defining it afterwards fails on a fresh kernel.
alpha_disc = 0.0001
batch_size = 500

# First 8000 rows are used for training; the test rows are picked from the
# 8000..9999 range via held_out_idx (defined elsewhere in the notebook).
train_idx, test_idx = np.arange(8000), np.arange(8000,10000)[held_out_idx]
# NOTE(review): the shuffle is unseeded, so batch composition differs per run.
np.random.shuffle(train_idx)

# Integer section count; at least one batch even if batch_size exceeds the
# number of training rows.
train_batches = np.array_split(train_idx, max(1, len(train_idx) // batch_size))
test_features, test_class_labels, test_regress_labels = features[test_idx], class_labels[test_idx], regress_labels[test_idx]

In [None]:
def train(features, class_labels, regress_labels, train_batches, batch_size, alpha_disc, n_epochs=15, lr=0.01):
    """
    Train a ``Net`` with a regression objective and a discriminator penalty.

    Args:
        features: input features, indexable by each entry of train_batches.
        class_labels: sensitive-attribute class indices for the discriminator head.
        regress_labels: regression targets for the prediction head.
        train_batches: iterable of index collections, one per mini-batch.
        batch_size: unused; kept so existing callers keep working.
        alpha_disc: weight of the discriminator loss in the total loss.
        n_epochs: number of passes over train_batches (default 15).
        lr: initial Adam learning rate (default 0.01). It is divided by 10 at
            the start of every epoch, including the first.

    Returns:
        The trained model. It is built as ``Net(disc_n)``, so ``Net`` and
        ``disc_n`` must already be defined in the notebook.
    """
    model = Net(disc_n)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=0.0)
    cross_entropy = nn.CrossEntropyLoss()
    mse = nn.MSELoss()

    ## train the network
    for epoch in range(n_epochs):  # loop over the dataset multiple times

        model.train()

        # NOTE(review): a 10x decay every epoch makes the step size vanish
        # after a few epochs; confirm this schedule is intentional.
        for group in optimizer.param_groups:
            group['lr'] = group['lr'] / 10

        for batch in train_batches:
            batch_features = features[batch]
            b_class_labels = class_labels[batch]
            b_regress_labels = regress_labels[batch]

            discr, pred = model(batch_features)
            discr_loss = cross_entropy(discr, b_class_labels)
            pred_loss = mse(pred, b_regress_labels)
            # The discriminator loss enters with a negative sign, so minimising
            # the total pushes the discriminator's cross-entropy up.
            loss = -discr_loss*alpha_disc + pred_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    return model