# Necessities

In [1]:
# ------------------------------------------------------------------
# Imports
# ------------------------------------------------------------------
# Basic data processing libraries
import pandas as pd
import numpy as np
import os
import torch

# Graph data processing libraries
import networkx as nx
from torch_geometric.data import Data
from torch_geometric.utils import from_networkx

# Libraries for (G)NNs
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
import torch.nn as nn
from sklearn.metrics import roc_auc_score, f1_score

# ------------------------------------------------------------------
# Helper functions
# ------------------------------------------------------------------
def show_df_info(df):
    print(df.info())
    print('####### Repeat ####### \n', df.duplicated().any())
    print('####### Count ####### \n', df.nunique())
    print('####### Example ####### \n',df.head())

def label_statics(label_df, label_list):
    print("####### nCount #######")
    for label in label_list:
        print(label_df[label].value_counts())
    print("####### nPercent #######")
    for label in label_list:
        print(label_df[label].value_counts()/label_df.shape[0])

# ------------------------------------------------------------------
# Data stuff
# ------------------------------------------------------------------
base_path = os.getcwd()
input_ali_data_path = os.path.join(base_path, "input_ali_data")

# Load the data files
user_labels_path = os.path.join(input_ali_data_path, "user_labels.csv")
user_edges_path = os.path.join(input_ali_data_path, "user_edge.csv")

# Create dataframes to store the information from the .csv files
user_labels = pd.read_csv(user_labels_path)
user_edges = pd.read_csv(user_edges_path)

# Prepare the data for GNNs
node_features = torch.tensor(user_labels.iloc[:, 1:].values, dtype=torch.float)
edge_index = torch.tensor(user_edges.values, dtype=torch.long).t().contiguous()

# Create torch-geometric data
data = Data(x=node_features, edge_index=edge_index)

num_nodes = node_features.size(0)
num_classes = 2 
num_node_features = data.num_node_features

# Create masks for training, and testing
train_mask = torch.zeros(num_nodes, dtype=torch.bool)
test_mask = torch.zeros(num_nodes, dtype=torch.bool)
val_mask = torch.zeros(num_nodes, dtype=torch.bool)

# 60-20-20 Train and Test data split
num_train = int(num_nodes * 0.6)
num_val = int(num_nodes * 0.8)
train_mask[:num_train] = True
val_mask[num_train:num_val] = True
test_mask[num_val:] = True

data.train_mask = train_mask
data.test_mask = test_mask
data.val_mask = val_mask

# Labels from the data (in this case: Job Classification)
data.y = torch.tensor(user_labels['gender'].values, dtype=torch.long)


# ------------------------------------------------------------------
# Set Device
# ------------------------------------------------------------------

def set_device():
    return torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')

# ------------------------------------------------------------------
# Loss
# ------------------------------------------------------------------

def fairness_aware_loss(output, data, sensitive_attr, alpha=0, beta=0, gamma=0, delta=0):
    target = data.y[data.train_mask]
    # standard_loss = F.cross_entropy(output, target)
    standard_loss = F.nll_loss(output, target)

    labels = data.y[train_mask]
    pos_prob = torch.exp(output[:, 1])
    neg_prob = torch.exp(output[:, 0])
    # pos_prob = torch.sigmoid(output[:, 1])
    # neg_prob = 1 - pos_prob
    predictions = output.argmax(dim=1)

    # Statistical Parity Regularization
    sp_reg = torch.abs(pos_prob[sensitive_attr == 1].mean() - pos_prob[sensitive_attr == 0].mean())

    # # Calculating FPR and TPR for each group
    # fpr_group1 = ((predictions == 1) & (labels == 0) & (sensitive_attr == 1)).float().mean()
    # fpr_group0 = ((predictions == 1) & (labels == 0) & (sensitive_attr == 0)).float().mean()
    # tpr_group1 = ((predictions == 1) & (labels == 1) & (sensitive_attr == 1)).float().mean()
    # tpr_group0 = ((predictions == 1) & (labels == 1) & (sensitive_attr == 0)).float().mean()

    # Treatment Equality Regularization
    fp_diff = (neg_prob * (labels == 0) * (sensitive_attr == 1)).float().mean() - \
              (neg_prob * (labels == 0) * (sensitive_attr == 0)).float().mean()
    fn_diff = (pos_prob * (labels == 1) * (sensitive_attr == 1)).float().mean() - \
              (pos_prob * (labels == 1) * (sensitive_attr == 0)).float().mean()
    treatment_reg = torch.abs(fp_diff) + torch.abs(fn_diff)
    # treatment_reg = torch.abs(fn_diff)

    # fn_group_1 = ((predictions == 0) & (labels == 1) & (sensitive_attr == 1)).sum()
    # fp_group_1 = ((predictions == 1) & (labels == 0) & (sensitive_attr == 1)).sum()

    # fn_group_0 = ((predictions == 0) & (labels == 1) & (sensitive_attr == 0)).sum()
    # fp_group_0 = ((predictions == 1) & (labels == 0) & (sensitive_attr == 0)).sum()
    
    # ratio_group_1 = fn_group_1 / fp_group_1 if fp_group_1 != 0 else torch.tensor(float('inf'))
    # ratio_group_0 = fn_group_0 / fp_group_0 if fp_group_0 != 0 else torch.tensor(float('inf'))
    # treatment_reg = torch.abs(ratio_group_1 - ratio_group_0)

    # Equal Opportunity Difference Regularization
    eod_reg = torch.abs((pos_prob * (labels == 1) * (sensitive_attr == 1)).float().mean() - \
                        (pos_prob * (labels == 1) * (sensitive_attr == 0)).float().mean())

    # Overall Accuracy Equality Difference Regularization
    oaed_reg = torch.abs((pos_prob * (sensitive_attr == 1)).float().mean() - \
                         (pos_prob * (sensitive_attr == 0)).float().mean())

    penalty = alpha + beta + gamma + delta
    
    # Combine losses
    combined_loss = (1-penalty)*standard_loss
    + alpha * sp_reg
    + beta * treatment_reg
    + gamma * eod_reg
    + delta * oaed_reg
    
    return combined_loss

# ------------------------------------------------------------------
# Fairness Metrics
# ------------------------------------------------------------------

def calculate_fairness(label, predictions, sens_attr='Gender', balanced=False):
    """
    Calculate various fairness metrics.

    Args:
    label: Actual labels (binary).
    predictions: Model predictions (binary).
    sens_attr: Binary sensitive attribute for fairness evaluation.

    Returns:
    A dictionary containing SPD, EOD, OAED, and TED values.
    """
    if balanced is False:
        labels = torch.tensor(user_labels[label].values, dtype=torch.long)
        sensitive_attribute = torch.tensor(user_labels[sens_attr].values, dtype=torch.long)
    else:
        labels = torch.tensor(filtered_user_labels[label].values, dtype=torch.long)
        sensitive_attribute = torch.tensor(filtered_user_labels[sens_attr].values, dtype=torch.long)
    
    labels = labels.to(set_device())
    sensitive_attribute = sensitive_attribute.to(set_device())

    predictions = predictions.float()
    labels = labels.float()
    sensitive_attribute = sensitive_attribute.float()

    def statistical_parity_difference():
        prob_group_1 = predictions[sensitive_attribute == 1].mean()
        prob_group_0 = predictions[sensitive_attribute == 0].mean()
        return abs(prob_group_1 - prob_group_0), prob_group_0, prob_group_1

    def equal_opportunity_difference():
        tpr_group_1 = predictions[(labels == 1) & (sensitive_attribute == 1)].mean()
        tpr_group_0 = predictions[(labels == 1) & (sensitive_attribute == 0)].mean()
        return abs(tpr_group_1 - tpr_group_0), tpr_group_0, tpr_group_1

    def overall_accuracy_equality_difference():
        acc_group_1 = (predictions[sensitive_attribute == 1] == labels[sensitive_attribute == 1]).float().mean()
        acc_group_0 = (predictions[sensitive_attribute == 0] == labels[sensitive_attribute == 0]).float().mean()
        return abs(acc_group_1 - acc_group_0), acc_group_0, acc_group_1

    def treatment_equality_difference():
        fn_group_1 = ((predictions == 0) & (labels == 1) & (sensitive_attribute == 1)).sum()
        fp_group_1 = ((predictions == 1) & (labels == 0) & (sensitive_attribute == 1)).sum()

        fn_group_0 = ((predictions == 0) & (labels == 1) & (sensitive_attribute == 0)).sum()
        fp_group_0 = ((predictions == 1) & (labels == 0) & (sensitive_attribute == 0)).sum()

        ratio_group_1 = fn_group_1 / fp_group_1 if fp_group_1 != 0 else float('inf')
        ratio_group_0 = fn_group_0 / fp_group_0 if fp_group_0 != 0 else float('inf')

        return abs(ratio_group_1 - ratio_group_0), ratio_group_0, ratio_group_1, fn_group_1, fp_group_1, fn_group_0, fp_group_0

    # Calculating each fairness metric
    spd, sp_g0, sp_g1 = statistical_parity_difference()
    eod, eod_g0, eod_g1 = equal_opportunity_difference()
    oaed, oaed_g0, oaed_g1 = overall_accuracy_equality_difference()
    ted, ted_g0, ted_g1, fn_group_1, fp_group_1, fn_group_0, fp_group_0 = treatment_equality_difference()

    return {
        'Statistical Parity Difference': spd,
        'Statistical Parity Group with S=0': sp_g0,
        'Statistical Parity Group S=1': sp_g1,
        'Equal Opportunity Difference': eod,
        'Equal Opportunity Group with S=0': eod_g0,
        'Equal Opportunity Group S=1': eod_g1,
        'Overall Accuracy Equality Difference': oaed,
        'Overall Accuracy Group with S=0': oaed_g0,
        'Overall Accuracy Group S=1': oaed_g1,
        'Treatment Equality Difference': ted,
        'Treatment Equality Group with S=0': ted_g0,
        'Treatment Equality Group S=1': ted_g1
        # 'False Negatives Group 1': fn_group_1,
        # 'False Positives Group 1': fp_group_1,
        # 'False Negatives Group 0': fn_group_0,
        # 'False Positives Group 0': fp_group_0
    }

# ------------------------------------------------------------------
# Model Training
# ------------------------------------------------------------------

# Train the model
def training(model, data, optimizer, epochs=2000, fairness=False, alpha=0, beta=0, gamma=0, delta=0):
    model.to(set_device())
    data.to(set_device())
    
    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()
        out = model(data.x, data.edge_index)
        
        if fairness:
            loss = fairness_aware_loss(out[data.train_mask], data, data.x[data.train_mask, -1],
                                       alpha=alpha, beta=beta, gamma=gamma, delta=delta)
            
        else:
            # criterion = torch.nn.CrossEntropyLoss()
            # criterion = torch.nn.BCELoss()
            criterion = torch.nn.NLLLoss()
            loss = criterion(out[data.train_mask], data.y[data.train_mask])

        loss.backward()
        optimizer.step()

        metrics = test(model, data)

        if epoch % 10 == 0:
            print(f'Epoch {epoch} | Loss: {loss.item()} | \n AUC_ROC: {metrics["AUC_ROC"]} | F1 Score: {metrics["F1_Score"]} | SPD: {metrics["parity"]} | EOD: {metrics["equality"]}')

# ------------------------------------------------------------------
# Model Testing
# ------------------------------------------------------------------

# Test the model
def test(model, data, balanced=False):
    # model.to('cpu')
    # data.to('cpu')
    model.to(set_device())
    data.to(set_device())
    
    model.eval()
    with torch.inference_mode():
      out = model(data.x, data.edge_index)

    _, pred = model(data.x, data.edge_index).max(dim=1)
    correct = int(pred[data.test_mask].eq(data.y[data.test_mask]).sum().item())
    accuracy = correct / int(data.test_mask.sum())
    # print(f'Accuracy: {accuracy}')

    # Convert model outputs to binary predictions
    predictions = out.argmax(dim=1)

    fairness_metrics = calculate_fairness(label='GoodCustomer', predictions=predictions, sens_attr='Gender', balanced=balanced)
    fairness_metrics['Accuracy'] = accuracy

    return fairness_metrics

# ------------------------------------------------------------------
# Print Metrics
# ------------------------------------------------------------------

# def print_metrics(metrics):
#     for key, value in metrics.items():
#         print(f"\n{key} : {value:.5f}")

def print_metrics(metrics):
    count = -1

    for key, value in metrics.items():
        count += 1
        if count == 3:
            print(f"\n\n{key} : {value:.5f}")
            count = 0
        else:
            print(f"{key} : {value:.5f}")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def test(model, data, val=True, balanced=False):
    model.to(set_device())
    data.to(set_device())
    
    if val==True:
      mask = data.val_mask
    else:
      mask = data.test_mask

    model.eval()
    with torch.no_grad():
        out = model(data.x, data.edge_index)
        predictions = out.argmax(dim=1)

    # Compute accuracy
    correct = int(predictions[mask].eq(data.y[mask]).sum().item())
    accuracy = correct / int(mask.sum())
    
    # Extract the predictions and the true labels
    y_true = data.y[mask].cpu().numpy()
    y_pred = predictions[mask].cpu().numpy()
    
    # Compute F1 score
    f1 = f1_score(y_true, y_pred, average='binary')

    # Compute AUC-ROC score
    y_probs = out[mask][:, 1].cpu().numpy() 
    auc_roc = roc_auc_score(y_true, y_probs)
    
    fairness_metrics = fair_metric('gender', predictions, 'bin_age')
    fairness_metrics['Accuracy'] = accuracy
    fairness_metrics['F1_Score'] = f1
    fairness_metrics['AUC_ROC'] = auc_roc

    return fairness_metrics

def fair_metric(labels, pred, sens):
	
	labels = user_labels[labels].values
	sens = user_labels[sens].values
	
	idx_s0 = sens==0
	idx_s1 = sens==1

	idx_s0_y1 = np.bitwise_and(idx_s0, labels==1)
	idx_s1_y1 = np.bitwise_and(idx_s1, labels==1)

	parity = abs(sum(pred[idx_s0])/sum(idx_s0)-sum(pred[idx_s1])/sum(idx_s1))
	equality = abs(sum(pred[idx_s0_y1])/sum(idx_s0_y1)-sum(pred[idx_s1_y1])/sum(idx_s1_y1))
    
	return {"parity": parity.item(), "equality": equality.item()}

In [3]:
sens_attribute_tensor = torch.tensor(user_labels['bin_age'].values, dtype=torch.long)
sens_attribute_tensor = sens_attribute_tensor.to(set_device())

# GCN Model

In [28]:
class GCN(nn.Module):
	def __init__(self, nfeat, nhid=128, nclass=2, dropout=0):
		super(GCN, self).__init__()
		self.body = GCN_Body(nfeat,nhid,dropout)
		self.fc = nn.Linear(nhid, nclass)

		for m in self.modules():
			self.weights_init(m)

	def weights_init(self, m):
		if isinstance(m, nn.Linear):
			torch.nn.init.xavier_uniform_(m.weight.data)
			if m.bias is not None:
				m.bias.data.fill_(0.0)

	def forward(self, x, edge_index):
		x = self.body(x, edge_index)
		x = self.fc(x)
		return F.log_softmax(x, dim=1)
		# return x

In [29]:
class GCN_Body(nn.Module):
	def __init__(self, nfeat, nhid, dropout):
		super(GCN_Body, self).__init__()
		self.gc1 = GCNConv(nfeat, nhid)

	def forward(self, x, edge_index):
		x = self.gc1(x, edge_index)
		return x

In [48]:
gcn_model = GCN(data.num_node_features, nhid=16, nclass=2)
optimizer_gcn_model = torch.optim.Adam(gcn_model.parameters(), lr=0.01, weight_decay=1e-5)

In [None]:
training(model=gcn_model, 
         data=data, 
         optimizer=optimizer_gcn_model, 
         fairness=False,  
         epochs=1000)

Epoch 0 | Loss: 0.6151986718177795 | 
 AUC_ROC: 0.4582007971760998 | F1 Score: 0.0014182720718591184 | SPD: 0.0007968544960021973 | EOD: 0.0012943694600835443
Epoch 10 | Loss: 0.5964589715003967 | 
 AUC_ROC: 0.46234220692331923 | F1 Score: 0.0007100591715976331 | SPD: 0.00020969855540897697 | EOD: 0.00034848408540710807
Epoch 20 | Loss: 0.5860309600830078 | 
 AUC_ROC: 0.4672246215855648 | F1 Score: 0.0 | SPD: 2.0969855540897697e-05 | EOD: 4.978344077244401e-05
Epoch 30 | Loss: 0.5811080932617188 | 
 AUC_ROC: 0.4744588521745437 | F1 Score: 0.0 | SPD: 0.0 | EOD: 0.0
Epoch 40 | Loss: 0.5781946182250977 | 
 AUC_ROC: 0.48471953199159956 | F1 Score: 0.0 | SPD: 0.0 | EOD: 0.0
Epoch 50 | Loss: 0.5754379630088806 | 
 AUC_ROC: 0.4979018657505362 | F1 Score: 0.0 | SPD: 0.0 | EOD: 0.0
Epoch 60 | Loss: 0.5725720524787903 | 
 AUC_ROC: 0.5131979524240887 | F1 Score: 0.00023679848448969926 | SPD: 3.1454783311346546e-05 | EOD: 9.956688154488802e-05
Epoch 70 | Loss: 0.5697234272956848 | 
 AUC_ROC: 0.529

In [None]:
print("Here are the values for the GCN model")

metrics_gcn_model = test(gcn_model, data)

print_metrics(metrics_gcn_model)

Here are the values for the GCN model
parity : 0.04591
equality : 0.05184
Accuracy : 0.85353


F1_Score : 0.62449
AUC_ROC : 0.87258


In [49]:
training(model=gcn_model, 
         data=data, 
         optimizer=optimizer_gcn_model, 
         fairness=False,  
         epochs=50)

print("Here are the values for the GCN model")

metrics_gcn_model = test(gcn_model, data)

print_metrics(metrics_gcn_model)

Epoch 0 | Loss: 0.7154015898704529 | 
 AUC_ROC: 0.28981126667042123 | F1 Score: 0.0006988120195667364 | SPD: 0.005784710869193077 | EOD: 0.0011948025785386562
Epoch 10 | Loss: 0.592928946018219 | 
 AUC_ROC: 0.5173302297258584 | F1 Score: 0.00753473039792795 | SPD: 0.0023349570110440254 | EOD: 0.006431572139263153
Epoch 20 | Loss: 0.5235864520072937 | 
 AUC_ROC: 0.7881768579702766 | F1 Score: 0.04535987040037028 | SPD: 0.0013826140202581882 | EOD: 0.003806181252002716
Epoch 30 | Loss: 0.4776657819747925 | 
 AUC_ROC: 0.831259212099218 | F1 Score: 0.17817947062621048 | SPD: 0.013948183506727219 | EOD: 0.012088701128959656
Epoch 40 | Loss: 0.43756547570228577 | 
 AUC_ROC: 0.8557631723362468 | F1 Score: 0.4623223384285331 | SPD: 0.03427240252494812 | EOD: 0.00240248441696167
Here are the values for the GCN model
parity : 0.04365
equality : 0.02798
Accuracy : 0.84200


F1_Score : 0.57217
AUC_ROC : 0.86603


# FAME

In [7]:
class FairnessAwareMessagePassingLayer(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(FairnessAwareMessagePassingLayer, self).__init__(aggr='mean')  
        self.lin = nn.Linear(in_channels, out_channels)
        self.sensitive_attr = sens_attribute_tensor
        self.bias_correction = nn.Parameter(torch.rand(1))

    def forward(self, x, edge_index):        
        # Add self-loops 
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        x = self.lin(x)

        return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x)
    
    def message(self, x_j, edge_index, size):
        row, col = edge_index
        deg = degree(row, size[0], dtype=x_j.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
        
        group_difference = self.sensitive_attr[row] - self.sensitive_attr[col]
        
        # Adjust messages based on statistical parity
        fairness_adjustment = (1 + self.bias_correction * group_difference.view(-1, 1))

        return fairness_adjustment * norm.view(-1, 1) * x_j

    def update(self, aggr_out):
        return aggr_out

In [44]:
class FairMP_GCN(torch.nn.Module):
    def __init__(self, data, layers=1, hidden=128, dropout=0):
        super(FairMP_GCN, self).__init__()
        self.conv1 = FairnessAwareMessagePassingLayer(data.num_node_features, hidden)
        self.convs = torch.nn.ModuleList()
        
        for i in range(layers - 1):
            self.convs.append(FairnessAwareMessagePassingLayer(hidden, hidden))
        
        # self.conv2 = FairnessAwareMessagePassingLayer(hidden, 2)
        self.fc = nn.Linear(hidden, 2)
        self.dropout = dropout

    def forward(self, x, edge_index, *args, **kwargs):
        x = F.relu(self.conv1(x, edge_index))
        # x = F.dropout(x, p=self.dropout, training=self.training)

        for conv in self.convs:
            x = F.relu(conv(x, edge_index))
            x = F.dropout(x, p=self.dropout, training=self.training)

        # x = self.conv2(x, edge_index)
        x = self.fc(x)
        
        return F.log_softmax(x, dim=1)

In [45]:
Fair_gcn_model = FairMP_GCN(data, hidden=16)
optimizer_Fair_gcn_model = torch.optim.Adam(Fair_gcn_model.parameters(), lr=0.01, weight_decay=1e-5)

In [46]:
training(model=Fair_gcn_model, 
         data=data, 
         optimizer=optimizer_Fair_gcn_model, 
         fairness=False,  
         epochs=50)

Epoch 0 | Loss: 0.690586268901825 | 
 AUC_ROC: 0.35402714856739625 | F1 Score: 0.01308564084589321 | SPD: 0.005428011063486338 | EOD: 0.011536752805113792
Epoch 10 | Loss: 0.6365414261817932 | 
 AUC_ROC: 0.41586195891002364 | F1 Score: 0.0 | SPD: 0.0 | EOD: 0.0
Epoch 20 | Loss: 0.5943252444267273 | 
 AUC_ROC: 0.8310719827277167 | F1 Score: 0.13972934316206406 | SPD: 0.0005518160760402679 | EOD: 0.025728903710842133
Epoch 30 | Loss: 0.5547952651977539 | 
 AUC_ROC: 0.8901853575524458 | F1 Score: 0.22962497381102037 | SPD: 0.005978608503937721 | EOD: 0.024777323007583618
Epoch 40 | Loss: 0.5184928774833679 | 
 AUC_ROC: 0.903465677609534 | F1 Score: 0.3414442700156986 | SPD: 0.01719280704855919 | EOD: 0.01047012209892273


In [47]:
print("Here are the values for the GCN model")

metrics_Fair_gcn_model = test(Fair_gcn_model, data)

print_metrics(metrics_Fair_gcn_model)

Here are the values for the GCN model
parity : 0.02237
equality : 0.01013
Accuracy : 0.81073


F1_Score : 0.40389
AUC_ROC : 0.90681


In [37]:
print("Here are the values for the GCN model")

metrics_Fair_gcn_model = test(Fair_gcn_model, data)

print_metrics(metrics_Fair_gcn_model)

Here are the values for the GCN model
parity : 0.03111
equality : 0.02125
Accuracy : 0.82589


F1_Score : 0.48430
AUC_ROC : 0.68408


# A-FAME

In [21]:
import torch
from torch_geometric.nn import MessagePassing
from torch.nn import Linear, Parameter
from torch_geometric.utils import add_self_loops, softmax
import torch.nn.functional as F

class Attention_FairMessagePassing(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(Attention_FairMessagePassing, self).__init__(aggr='add') 
        self.lin = Linear(in_channels, out_channels) 
        self.att = Linear(2 * out_channels, 1) 
        
        self.sensitive_attr = sens_attribute_tensor 
        self.bias_correction = Parameter(torch.rand(1))  

    def forward(self, x, edge_index):
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        x = self.lin(x)

        return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x)

    def message(self, edge_index, x_i, x_j, size_i):
        x_cat = torch.cat([x_i, x_j], dim=-1)  
        alpha = self.att(x_cat)

        row, col = edge_index
        group_difference = self.sensitive_attr[row] - self.sensitive_attr[col]

        fairness_adjustment = self.bias_correction * group_difference.view(-1, 1)
        alpha = alpha + fairness_adjustment

        alpha = softmax(alpha, edge_index[0], num_nodes=size_i)

        return alpha * x_j

    def update(self, aggr_out):
        return aggr_out
    
# GCN class that takes in the data as an input for dimensions of the convolutions
class Fair_Attention_MP_GCN(torch.nn.Module):
    def __init__(self, data, layers=1, hidden=128, dropout=0):
        super(Fair_Attention_MP_GCN, self).__init__()
        self.conv1 = Attention_FairMessagePassing(data.num_node_features, hidden)
        self.convs = torch.nn.ModuleList()
        
        for i in range(layers - 1):
            self.convs.append(Attention_FairMessagePassing(hidden, hidden))
        
        # self.conv2 = Attention_FairMessagePassing(hidden, 2)
        self.fc = Linear(hidden, 2)
        self.dropout = dropout

    def forward(self, x, edge_index, *args, **kwargs):
        x = F.relu(self.conv1(x, edge_index))
        # x = F.dropout(x, p=self.dropout, training=self.training)

        for conv in self.convs:
            x = F.relu(conv(x, edge_index))
            # x = F.dropout(x, p=self.dropout, training=self.training)

        # x = self.conv2(x, edge_index)
        x = self.fc(x)
        
        return F.log_softmax(x, dim=1)

In [22]:
Fair_gat_model = Fair_Attention_MP_GCN(data, hidden=16)
optimizer_Fair_gat_model = torch.optim.Adam(Fair_gat_model.parameters(), lr=0.01, weight_decay=1e-5)

training(model=Fair_gat_model, 
         data=data, 
         optimizer=optimizer_Fair_gat_model, 
         fairness=False,  
         epochs=50)

Epoch 0 | Loss: 0.943186342716217 | 
 AUC_ROC: 0.5154655590396341 | F1 Score: 0.40370963501207063 | SPD: 0.0 | EOD: 0.0
Epoch 10 | Loss: 0.6069545745849609 | 
 AUC_ROC: 0.49094220184921106 | F1 Score: 0.0 | SPD: 0.0 | EOD: 0.0
Epoch 20 | Loss: 0.5302107930183411 | 
 AUC_ROC: 0.780581656083069 | F1 Score: 0.00023679848448969926 | SPD: 9.436435357201844e-05 | EOD: 0.0004480509669519961
Epoch 30 | Loss: 0.4590913653373718 | 
 AUC_ROC: 0.8944414283799714 | F1 Score: 0.0011834319526627217 | SPD: 0.00029357796302065253 | EOD: 0.0013939363416284323
Epoch 40 | Loss: 0.3787717819213867 | 
 AUC_ROC: 0.9833063382707635 | F1 Score: 0.042390548992355795 | SPD: 0.0050068264827132225 | EOD: 0.028081420809030533


In [23]:
print("Here are the values for the FairGAT model")

metrics_Fair_gat_model = test(Fair_gat_model, data)

print_metrics(metrics_Fair_gat_model)

Here are the values for the FairGAT model
parity : 0.01047
equality : 0.06353
Accuracy : 0.76051


F1_Score : 0.10678
AUC_ROC : 0.99429


# GAT Model

In [24]:
class GAT(nn.Module):
	def __init__(self, nfeat, nhid=128, nclass=2, dropout=0):
		super(GAT, self).__init__()
		self.body = GAT_Body(nfeat,nhid,dropout)
		self.fc = nn.Linear(nhid, nclass)

		for m in self.modules():
			self.weights_init(m)

	def weights_init(self, m):
		if isinstance(m, nn.Linear):
			torch.nn.init.xavier_uniform_(m.weight.data)
			if m.bias is not None:
				m.bias.data.fill_(0.0)

	def forward(self, x, edge_index):
		x = self.body(x, edge_index)
		x = self.fc(x)
		return F.log_softmax(x, dim=1)
		# return x

In [25]:
from torch_geometric.nn import GATConv

class GAT_Body(nn.Module):
	def __init__(self, nfeat, nhid, dropout):
		super(GAT_Body, self).__init__()
		self.gc1 = GATConv(nfeat, nhid)

	def forward(self, x, edge_index):
		x = self.gc1(x, edge_index)
		return x

In [26]:
gat_model = GAT(data.num_node_features, nhid=16)
optimizer_gat_model = torch.optim.Adam(gat_model.parameters(), lr=0.01, weight_decay=1e-5)

training(model=gat_model, 
         data=data, 
         optimizer=optimizer_gat_model, 
         fairness=False,  
         epochs=50)

Epoch 0 | Loss: 0.5749228596687317 | 
 AUC_ROC: 0.5852403448181667 | F1 Score: 0.03706230104592997 | SPD: 0.017376001924276352 | EOD: 0.024556681513786316
Epoch 10 | Loss: 0.495073527097702 | 
 AUC_ROC: 0.7977187208225428 | F1 Score: 0.16602067183462532 | SPD: 0.017557242885231972 | EOD: 0.02045833319425583
Epoch 20 | Loss: 0.4108549654483795 | 
 AUC_ROC: 0.8856521094109684 | F1 Score: 0.5775786188685285 | SPD: 0.044105589389801025 | EOD: 0.037001967430114746
Epoch 30 | Loss: 0.3071121275424957 | 
 AUC_ROC: 0.9542107349285638 | F1 Score: 0.7325951557093425 | SPD: 0.07764032483100891 | EOD: 0.0003832578659057617
Epoch 40 | Loss: 0.20482370257377625 | 
 AUC_ROC: 0.9816319846149621 | F1 Score: 0.8563790101964684 | SPD: 0.1038682758808136 | EOD: 0.049774885177612305


In [27]:
print("Here are the values for the GAT model")

metrics_gat_model = test(gat_model, data)

print_metrics(metrics_gat_model)

Here are the values for the GAT model
parity : 0.10312
equality : 0.02911
Accuracy : 0.95993


F1_Score : 0.92187
AUC_ROC : 0.98939


# Compare results

In [None]:
Values for the FairGAT model - same model specifications as the one from Carlos,
but with one A-FAME instead of one GCNConv layer

parity : 0.00706
equality : 0.00207
Accuracy : 0.68000
F1_Score : 0.80952
AUC_ROC : 0.67496