# Fairness in GCNs
In this python notebook, we have explored the use of Graph Convolutional Networks(GCNs) for the [Alibaba](https://tianchi.aliyun.com/dataset/56) dataset.

The dataset was pre-processed using the code from Erasmo Purificato's [CatGCN notebook](https://colab.research.google.com/drive/1zsx4an6BKYhJ_UT-mSl1_qPB-zyjTmrA#scrollTo=xxzSlLj3LDIu).  
This pre-processing provided us with various .csv files, which are used to form graph data.  

The nodes represent the user ids, with the node features being attributes such as buy, gender, student, etc. The edges between the nodes have been created through various relations between the users such as items bought, items clicked on, etc.  

In this notebook, we have only focused on GCNs and used fairness methods from the [AIF360](https://github.com/Trusted-AI/AIF360) framework.

# Imports

### Import necessary libraries

In [1]:
# Basic data processing libraries
import pandas as pd
import numpy as np
import os
import torch

# Graph data processing libraries
import networkx as nx
from torch_geometric.data import Data
from torch_geometric.utils import from_networkx

# Libraries for (G)NNs
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

### Initialise paths

In [2]:
base_path = os.getcwd()
input_ali_data_path = os.path.join(base_path, "input_ali_data")

base_path, input_ali_data_path

('c:\\Users\\Hanna\\iCloudDrive\\Desktop\\Uni\\Master Thesis\\Trials',
 'c:\\Users\\Hanna\\iCloudDrive\\Desktop\\Uni\\Master Thesis\\Trials\\input_ali_data')

---
# Data Stuff

### Load the data from .csv files

In [3]:
# Load the data files
user_labels_path = os.path.join(input_ali_data_path, "user_labels.csv")
user_edges_path = os.path.join(input_ali_data_path, "user_edge.csv")

In [4]:
# Create dataframes to store the information from the .csv files
user_labels = pd.read_csv(user_labels_path)
user_edges = pd.read_csv(user_edges_path)

### Pre-processing the data

In [5]:
# Prepare the data for GNNs
node_features = torch.tensor(user_labels.iloc[:, 1:].values, dtype=torch.float)
edge_index = torch.tensor(user_edges.values, dtype=torch.long).t().contiguous()

In [6]:
def show_df_info(df):
    print(df.info())
    print('####### Repeat ####### \n', df.duplicated().any())
    print('####### Count ####### \n', df.nunique())
    print('####### Example ####### \n',df.head())

In [6]:
node_features

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 1., 1.,  ..., 0., 1., 1.],
        [0., 2., 1.,  ..., 0., 1., 1.],
        ...,
        [0., 0., 1.,  ..., 2., 0., 1.],
        [0., 4., 1.,  ..., 3., 0., 1.],
        [0., 0., 1.,  ..., 0., 0., 1.]])

In [79]:
show_df_info(user_labels)

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 166958 entries, 0 to 166957
Data columns (total 8 columns):
 #   Column   Non-Null Count   Dtype
---  ------   --------------   -----
 0   uid      166958 non-null  int64
 1   gender   166958 non-null  int64
 2   age      166958 non-null  int64
 3   buy      166958 non-null  int64
 4   student  166958 non-null  int64
 5   city     166958 non-null  int64
 6   bin_age  166958 non-null  int64
 7   bin_buy  166958 non-null  int64
dtypes: int64(8)
memory usage: 10.2 MB
None
####### Repeat ####### 
 False
####### Count ####### 
 uid        166958
gender          2
age             7
buy             3
student         2
city            4
bin_age         2
bin_buy         2
dtype: int64
####### Example ####### 
    uid  gender  age  buy  student  city  bin_age  bin_buy
0    0       0    0    0        0     0        0        0
1    1       0    1    1        1     0        1        1
2    2       0    2    1        1     0        1        1
3    3

In [7]:
# Create torch-geometric data
data = Data(x=node_features, edge_index=edge_index)
data

Data(x=[166958, 7], edge_index=[2, 29061406])

In [8]:
num_nodes = node_features.size(0)
num_classes = 2 # Binarised gender values from the data
num_node_features = data.num_node_features

# Create masks for training, and testing
train_mask = torch.zeros(num_nodes, dtype=torch.bool)
test_mask = torch.zeros(num_nodes, dtype=torch.bool)

# 80 - 20 Train and Test data split
num_train = int(num_nodes * 0.8)
train_mask[:num_train] = True
test_mask[num_train:] = True

data.train_mask = train_mask
data.test_mask = test_mask

In [9]:
num_nodes

166958

In [10]:
# Labels from the data (in this case: Gender Classification)
data.y = torch.tensor(user_labels['gender'].values, dtype=torch.long)

---
# Utils

### Function to clone the dataset

In [84]:
def clone(data):
    """
    Create a new cloned torch-geometric data object.

    Args:
    data: Actual data to be cloned.

    Returns:
    A torch-geometric data object.
    """
    clone_data = Data()

    # Copy the data's features and edges
    clone_data.x = data.x.clone()
    clone_data.edge_index = data.edge_index.clone()

    # Mask the data similar to the original train-test split
    clone_data.train_mask = data.train_mask.clone()
    clone_data.test_mask = data.test_mask.clone()

    # Copy the labels
    clone_data.y = data.y.clone()

    return clone_data

### Custom Loss Functions

In [11]:
def weighted_cross_entropy(output, data):
    """
    A custom loss function to calculate a weighted-cross entropy loss.

    Args:
    output: Outputs from the model.
    data: The torch-geometric data object used for the model.

    Returns:
    A weighted cross-entropy loss.
    """
    target = data.y[data.train_mask]
    weights = data.instance_weights[data.train_mask]

    loss = F.cross_entropy(output, target, reduction='none')
    weighted_loss = loss * weights

    return weighted_loss.mean()

In [12]:
import torch
import torch.nn.functional as F

def dual_objective_loss_with_lagrangian_relaxation(output, data, sensitive_attrs, lambda_param, lambda_entropy_coeff):
    """
    Compute the dual-objective loss with Lagrangian relaxation and entropy regularization.
    
    :param outputs: Model predictions.
    :param labels: Ground truth labels.
    :param sensitive_attrs: Sensitive attributes for each instance.
    :param lambda_param: Lagrange multiplier for balancing accuracy and fairness.
    :param lambda_entropy_coeff: Coefficient for the entropy regularization.
    :return: Total loss, prediction loss, fairness loss, entropy regularization term.
    """
    # Prediction loss (e.g., Cross-Entropy for classification)
    target = data.y[data.train_mask]
    pred_loss = F.cross_entropy(output, target)

    labels = data.y[train_mask]
    pos_prob = torch.sigmoid(output[:, 1])

    # Fairness loss (e.g., absolute difference in positive prediction rates between groups)
    # Statistical Parity Regularization
    sp_reg = torch.abs(pos_prob[sensitive_attrs == 1].mean() - pos_prob[sensitive_attrs == 0].mean())

    fairness_loss = torch.abs(sp_reg)
    
    # Entropy regularization for the Lagrange multiplier
    entropy_reg = -lambda_entropy_coeff * (lambda_param * torch.log(lambda_param))

    # Total loss
    total_loss = pred_loss + lambda_param * fairness_loss + entropy_reg

    return total_loss

In [12]:
def fairness_aware_loss(output, data, sensitive_attr, weighted=False, alpha=0.01, beta=0.01, gamma=0.01, delta=0.01):
    """
    Custom loss function to calculate a fairness-aware loss.
    This includes measures for statistical parity, treatment equality, equal opportunity difference, and overall accuracy equality difference.

    Args:
    output: Outputs from the model.
    data: The torch-geometric data object used for the model.
    sensitive_attr: The sensitive attribute in the data (e.g., bin_age).
    weighted: Boolean value indicating re-weighing done to the data or not.
    alpha: Parameter for statistical parity regularizer strength.
    beta: Parameter for treatment equality regularizer strength.
    gamma: Parameter for equal opportunity difference regularizer strength.
    delta: Parameter for overall accuracy equality difference regularizer strength.

    Returns:
    A fairness-aware combined loss.
    """
    if weighted:
        # Weighted cross-entropy loss
        standard_loss = weighted_cross_entropy(output, data)
    else:
        # Standard cross-entropy loss
        target = data.y[data.train_mask]
        standard_loss = F.cross_entropy(output, target)

    labels = data.y[train_mask]
    pos_prob = torch.sigmoid(output[:, 1])
    neg_prob = 1 - pos_prob
    predictions = output.argmax(dim=1)

    # Statistical Parity Regularization
    sp_reg = torch.abs(pos_prob[sensitive_attr == 1].mean() - pos_prob[sensitive_attr == 0].mean())

    # Treatment Equality Regularization
    fp_diff = (neg_prob * (labels == 0) * (sensitive_attr == 1)).float().mean() - \
              (neg_prob * (labels == 0) * (sensitive_attr == 0)).float().mean()
    fn_diff = (pos_prob * (labels == 1) * (sensitive_attr == 1)).float().mean() - \
              (pos_prob * (labels == 1) * (sensitive_attr == 0)).float().mean()
    treatment_reg = torch.abs(fp_diff) + torch.abs(fn_diff)

    # Equal Opportunity Difference Regularization
    eod_reg = torch.abs((pos_prob * (labels == 1) * (sensitive_attr == 1)).float().mean() - \
                        (pos_prob * (labels == 1) * (sensitive_attr == 0)).float().mean())

    # Overall Accuracy Equality Difference Regularization
    oaed_reg = torch.abs((pos_prob * (sensitive_attr == 1)).float().mean() - \
                         (pos_prob * (sensitive_attr == 0)).float().mean())

    # Combine losses
    combined_loss = standard_loss + alpha * sp_reg + beta * treatment_reg + gamma * eod_reg + delta * oaed_reg

    return combined_loss

In [50]:
def fairness_aware_loss(output, data, sensitive_attr, weighted=False, alpha=0.01, beta=0.02):
    """
    Custom loss function to calculate a fairness-aware loss.
    This includes measures for statistical parity, treatment equality, equal opportunity difference, and overall accuracy equality difference.

    Args:
    output: Outputs from the model.
    data: The torch-geometric data object used for the model.
    sensitive_attr: The sensitive attribute in the data (e.g., bin_age).
    weighted: Boolean value indicating re-weighing done to the data or not.
    alpha: Parameter for statistical parity regularizer strength.
    beta: Parameter for treatment equality regularizer strength.
    gamma: Parameter for equal opportunity difference regularizer strength.
    delta: Parameter for overall accuracy equality difference regularizer strength.

    Returns:
    A fairness-aware combined loss.
    """
    if weighted:
        # Weighted cross-entropy loss
        standard_loss = weighted_cross_entropy(output, data)
    else:
        # Standard cross-entropy loss
        target = data.y[data.train_mask]
        standard_loss = F.cross_entropy(output, target)

    labels = data.y[train_mask]
    pos_prob = torch.sigmoid(output[:, 1])
    neg_prob = 1 - pos_prob
    predictions = output.argmax(dim=1)

    # Statistical Parity Regularization
    sp_reg = torch.abs(pos_prob[sensitive_attr == 1].mean() - pos_prob[sensitive_attr == 0].mean())

    # Calculating FPR and TPR for each group
    fpr_group1 = ((predictions == 1) & (labels == 0) & (sensitive_attr == 1)).float().mean()
    fpr_group0 = ((predictions == 1) & (labels == 0) & (sensitive_attr == 0)).float().mean()
    tpr_group1 = ((predictions == 1) & (labels == 1) & (sensitive_attr == 1)).float().mean()
    tpr_group0 = ((predictions == 1) & (labels == 1) & (sensitive_attr == 0)).float().mean()

    # Difference in FPR and TPR between the two groups for Equalized Odds
    fpr_diff = torch.abs(fpr_group1 - fpr_group0)
    tpr_diff = torch.abs(tpr_group1 - tpr_group0)

    # Combine FPR and TPR differences for Equalized Odds Regularization
    equalized_odds_reg = fpr_diff + tpr_diff

    # Treatment Equality Regularization
    fp_diff = (neg_prob * (labels == 0) * (sensitive_attr == 1)).float().mean() - \
              (neg_prob * (labels == 0) * (sensitive_attr == 0)).float().mean()
    fn_diff = (pos_prob * (labels == 1) * (sensitive_attr == 1)).float().mean() - \
              (pos_prob * (labels == 1) * (sensitive_attr == 0)).float().mean()
    # treatment_reg = torch.abs(fp_diff) + torch.abs(fn_diff)
    treatment_reg = torch.abs(fn_diff)

    # Equal Opportunity Difference Regularization
    eod_reg = torch.abs((pos_prob * (labels == 1) * (sensitive_attr == 1)).float().mean() - \
                        (pos_prob * (labels == 1) * (sensitive_attr == 0)).float().mean())

    # Overall Accuracy Equality Difference Regularization
    oaed_reg = torch.abs((pos_prob * (sensitive_attr == 1)).float().mean() - \
                         (pos_prob * (sensitive_attr == 0)).float().mean())

    penalty = alpha + beta
    
    # Combine losses
    combined_loss = (1-penalty)*standard_loss + alpha * equalized_odds_reg + beta * treatment_reg 
    return combined_loss

### Fairness Metrics

In [44]:
def calculate_fairness(label, predictions, sens_attr='bin_age'):
    """
    Calculate various fairness metrics.

    Args:
    label: Actual labels (binary).
    predictions: Model predictions (binary).
    sens_attr: Binary sensitive attribute for fairness evaluation.

    Returns:
    A dictionary containing SPD, EOD, OAED, and TED values.
    """

    labels = torch.tensor(user_labels[label].values, dtype=torch.long)
    sensitive_attribute = torch.tensor(user_labels[sens_attr].values, dtype=torch.long)

    predictions = predictions.float()
    labels = labels.float()
    sensitive_attribute = sensitive_attribute.float()

    def statistical_parity_difference():
        prob_group_1 = predictions[sensitive_attribute == 1].mean()
        prob_group_0 = predictions[sensitive_attribute == 0].mean()
        return abs(prob_group_1 - prob_group_0), prob_group_0, prob_group_1

    def equal_opportunity_difference():
        tpr_group_1 = predictions[(labels == 1) & (sensitive_attribute == 1)].mean()
        tpr_group_0 = predictions[(labels == 1) & (sensitive_attribute == 0)].mean()
        return abs(tpr_group_1 - tpr_group_0)

    def overall_accuracy_equality_difference():
        acc_group_1 = (predictions[sensitive_attribute == 1] == labels[sensitive_attribute == 1]).float().mean()
        acc_group_0 = (predictions[sensitive_attribute == 0] == labels[sensitive_attribute == 0]).float().mean()
        return abs(acc_group_1 - acc_group_0)

    def treatment_equality_difference():
        fn_group_1 = ((predictions == 0) & (labels == 1) & (sensitive_attribute == 1)).sum()
        fp_group_1 = ((predictions == 1) & (labels == 0) & (sensitive_attribute == 1)).sum()

        fn_group_0 = ((predictions == 0) & (labels == 1) & (sensitive_attribute == 0)).sum()
        fp_group_0 = ((predictions == 1) & (labels == 0) & (sensitive_attribute == 0)).sum()

        ratio_group_1 = fn_group_1 / fp_group_1 if fp_group_1 != 0 else float('inf')
        ratio_group_0 = fn_group_0 / fp_group_0 if fp_group_0 != 0 else float('inf')

        return abs(ratio_group_1 - ratio_group_0)

    # Calculating each fairness metric
    spd, sp_g0, sp_g1 = statistical_parity_difference()
    eod = equal_opportunity_difference()
    oaed = overall_accuracy_equality_difference()
    ted = treatment_equality_difference()

    return {
        'Statistical Parity Difference': spd,
        'Statistical Parity Group with S=0': sp_g0,
        'Statistical Parity Group S=1': sp_g1,
        'Equal Opportunity Difference': eod,
        'Overall Accuracy Equality Difference': oaed,
        'Treatment Equality Difference': ted
    }

### Train and Test

In [28]:
# Train the model
def training(model, data, optimizer, epochs, weighted=False, fairness=False, alpha=0.01, beta=0.01, gamma=0.01, delta=0.01):
    """
    Helper function to train the GNN model.

    Args:
    model: Initialized GNN model.
    data: The torch_geometric data used to train the model.
    optimizer: Optimizer used to train the model.
    weighted: Boolean value indicating re-weighing done to the data or not.
    fairness: Boolean value indicating whether to use fairness-aware loss or not.
    alpha: Parameter for statistical parity regularizer strength.
    beta: Parameter for treatment equality regularizer strength.
    gamma: Parameter for equal opportunity difference regularizer strength.
    delta: Parameter for overall accuracy equality difference regularizer strength.

    Returns:
    -
    """
    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()
        out = model(data)
        
        if fairness:
            # loss = fairness_aware_loss(out[data.train_mask], data, data.x[data.train_mask, -1],
                                    #    weighted=weighted, alpha=alpha, beta=beta, gamma=gamma, delta=delta)
            loss = fairness_aware_loss(out[data.train_mask], data, data.x[data.train_mask, -1],
                                       weighted=weighted, alpha=alpha, beta=beta)
        
        elif weighted:
            loss = weighted_cross_entropy(out[data.train_mask], data)
        else:
            criterion = torch.nn.CrossEntropyLoss()
            loss = criterion(out[data.train_mask], data.y[data.train_mask])

        loss.backward()
        optimizer.step()

        if epoch % 10 == 0:
            print(f'Epoch {epoch} | Loss: {loss.item()}')

In [19]:
# Test the model
def test(model, data):
    """
    Helper function to test the trained GNN model.
    Prints the Accuracy, as well as various fairness metrics values.
    For fairness metrics used: Check the calculate_fairness method

    Args:
    model: Trained GNN model.
    data: The torch_geometric data used to train the model.

    Returns:
    -
    """
    model.eval()
    with torch.inference_mode():
      out = model(data)

    _, pred = model(data).max(dim=1)
    correct = int(pred[data.test_mask].eq(data.y[data.test_mask]).sum().item())
    accuracy = correct / int(data.test_mask.sum())
    print(f'Accuracy: {accuracy}')

    # Convert model outputs to binary predictions
    predictions = out.argmax(dim=1)

    # Fairness calculated for gender-classification task with bin_age as the sensitive attribute
    fairness_metrics = calculate_fairness(label='gender', predictions=predictions, sens_attr='bin_age')

    # Print the fairness metrics
    for metric, value in fairness_metrics.items():
        print(f"{metric}: {value}")

    return accuracy, fairness_metrics

---
# GCN Models

In [20]:
# GCN class that takes in the data as an input for dimensions of the convolutions
class GCN(torch.nn.Module):
    def __init__(self, data):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(data.num_node_features, 16)
        self.conv2 = GCNConv(16, 2) # 2 output classes for gender

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.conv1(x, edge_index))
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

### Base Model

In [44]:
from set_device import set_device

In [21]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cpu')

In [22]:
# Instantiate the model, define loss function and optimizer
gcn_model = GCN(data)
gcn_optimizer = torch.optim.Adam(gcn_model.parameters(), lr=0.01)

In [23]:
# Train the first model: GCN, standard data, cross-entropy loss
training(model=gcn_model, data=data, optimizer=gcn_optimizer, epochs=50)

Epoch 0 | Loss: 0.6290466785430908
Epoch 10 | Loss: 0.5474098324775696
Epoch 20 | Loss: 0.48671963810920715
Epoch 30 | Loss: 0.4376523792743683
Epoch 40 | Loss: 0.403105765581131


In [24]:
# Test the first model: GCN, standard data, cross-entropy loss
print("Here are the values for the GCN model with the standard dataset and cross-entropy loss: ")
print()
test(gcn_model, data)
print()

Here are the values for the GCN model with the standard dataset and cross-entropy loss: 

Accuracy: 0.8549053665548635
Statistical Parity Difference: 0.04585549235343933
Equal Opportunity Difference: 0.08249932527542114
Overall Accuracy Equality Difference: 0.0841907262802124
Treatment Equality Difference: 0.5477824211120605



In [114]:
# Instantiate the model, define loss function and optimizer
gcn_model = GCN(data)
gcn_optimizer = torch.optim.Adam(gcn_model.parameters(), lr=0.01)

In [140]:
torch.autograd.set_detect_anomaly(True)

<torch.autograd.anomaly_mode.set_detect_anomaly at 0x15f692310>

### FAMP

In [21]:
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
import torch.nn as nn

In [23]:
class FairnessAwareMessagePassingLayer(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(FairnessAwareMessagePassingLayer, self).__init__(aggr='mean')  
        self.lin = nn.Linear(in_channels, out_channels)
        self.a_fair = nn.Parameter(torch.rand(out_channels)) 
        self.sensitive_attr = torch.tensor(user_labels['bin_age'].values, dtype=torch.float) 
        self.bias_correction = nn.Parameter(torch.rand(1))

    def forward(self, x, edge_index):
        # Add self-loops 
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        x = self.lin(x)

        return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x)
    
    def message(self, x_j, edge_index, size):
        row, col = edge_index
        deg = degree(row, size[0], dtype=x_j.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # Compute statistical parity difference for each edge
        group_difference = self.sensitive_attr[row] - self.sensitive_attr[col]
        
        # Adjust messages based on statistical parity
        fairness_adjustment = (1 + self.bias_correction * group_difference.view(-1, 1))

        return fairness_adjustment * norm.view(-1, 1) * x_j

    def update(self, aggr_out):
        return aggr_out

In [24]:
# GCN class that takes in the data as an input for dimensions of the convolutions
class custom_GCN(torch.nn.Module):
    def __init__(self, data):
        super(custom_GCN, self).__init__()
        self.conv1 = FairnessAwareMessagePassingLayer(data.num_node_features, 16)
        self.conv2 = FairnessAwareMessagePassingLayer(16, 2) # 2 output classes for gender

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.conv1(x, edge_index))
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

In [33]:
# Instantiate the model, define loss function and optimizer
gcn_model = GCN(data)
gcn_optimizer = torch.optim.Adam(gcn_model.parameters(), lr=0.01)

training(model=gcn_model, data=data, optimizer=gcn_optimizer, epochs=50)

# Test the model
print("\nHere are the values for the GCN model with the standard dataset and cross-entropy loss: ")
test(gcn_model, data)

Epoch 0 | Loss: 0.6280730962753296
Epoch 10 | Loss: 0.5667222142219543
Epoch 20 | Loss: 0.5123195052146912
Epoch 30 | Loss: 0.4567728638648987
Epoch 40 | Loss: 0.4058452844619751

Here are the values for the GCN model with the standard dataset and cross-entropy loss: 
Accuracy: 0.8528988979396263
Statistical Parity Difference: 0.04322892427444458
Statistical Parity Group with S=0: 0.131281778216362
Statistical Parity Group S=1: 0.17451070249080658
Equal Opportunity Difference: 0.08256551623344421
Overall Accuracy Equality Difference: 0.08447939157485962
Treatment Equality Difference: 0.6746697425842285
Treatment Equality Group with S=0: 6.108710289001465
Treatment Equality Group S=1: 5.434040546417236


(0.8528988979396263,
 {'Statistical Parity Difference': tensor(0.0432),
  'Statistical Parity Group with S=0': tensor(0.1313),
  'Statistical Parity Group S=1': tensor(0.1745),
  'Equal Opportunity Difference': tensor(0.0826),
  'Overall Accuracy Equality Difference': tensor(0.0845),
  'Treatment Equality Difference': tensor(0.6747),
  'Treatment Equality Group with S=0': tensor(6.1087),
  'Treatment Equality Group S=1': tensor(5.4340)})

In [31]:
model2 = custom_GCN(data)
optimizer2 = torch.optim.Adam(model2.parameters(), lr=0.01)

# Train the model: Custom MP GNN, cross-entropy loss
training(model=model2, data=data, optimizer=optimizer2, epochs=50)

# Test the model
print("\nHere are the values for the GCN model with the standard dataset and cross-entropy loss: ")
test(model2, data)

Epoch 0 | Loss: 0.6909812092781067
Epoch 10 | Loss: 0.688873291015625
Epoch 20 | Loss: 0.685428261756897
Epoch 30 | Loss: 0.6810439825057983
Epoch 40 | Loss: 0.6754062175750732

Here are the values for the GCN model with the standard dataset and cross-entropy loss: 
Accuracy: 0.7761140392908481
Statistical Parity Difference: 0.011807771399617195
Statistical Parity Group with S=0: 0.02442988194525242
Statistical Parity Group S=1: 0.036237653344869614
Equal Opportunity Difference: 0.00039946287870407104
Overall Accuracy Equality Difference: 0.0871589183807373
Treatment Equality Difference: 470.03131103515625
Treatment Equality Group with S=0: 935.5789184570312
Treatment Equality Group S=1: 465.547607421875


(0.7761140392908481,
 {'Statistical Parity Difference': tensor(0.0118),
  'Statistical Parity Group with S=0': tensor(0.0244),
  'Statistical Parity Group S=1': tensor(0.0362),
  'Equal Opportunity Difference': tensor(0.0004),
  'Overall Accuracy Equality Difference': tensor(0.0872),
  'Treatment Equality Difference': tensor(470.0313),
  'Treatment Equality Group with S=0': tensor(935.5789),
  'Treatment Equality Group S=1': tensor(465.5476)})

In [32]:
model2_FairLoss = custom_GCN(data)
optimizer2_FairLoss = torch.optim.Adam(model2_FairLoss.parameters(), lr=0.01)

# Train the second model: Fair loss
training(model=model2_FairLoss, data=data, optimizer=optimizer2_FairLoss, fairness=True, alpha=0.1, epochs=50)

# Test the second model
print("\nHere are the values for the GCN model with the standard dataset and Fairness loss: ")
test(model2_FairLoss, data)

Epoch 0 | Loss: 0.643775999546051
Epoch 10 | Loss: 0.6410472393035889
Epoch 20 | Loss: 0.6383004784584045
Epoch 30 | Loss: 0.634784996509552
Epoch 40 | Loss: 0.6302443742752075

Here are the values for the GCN model with the standard dataset and cross-entropy loss: 
Accuracy: 0.7771322472448491
Statistical Parity Difference: 0.00629039853811264
Statistical Parity Group with S=0: 0.028298821300268173
Statistical Parity Group S=1: 0.034589219838380814
Equal Opportunity Difference: 0.021982602775096893
Overall Accuracy Equality Difference: 0.09217983484268188
Treatment Equality Difference: 66.20721435546875
Treatment Equality Group with S=0: 335.3846130371094
Treatment Equality Group S=1: 401.5918273925781


(0.7771322472448491,
 {'Statistical Parity Difference': tensor(0.0063),
  'Statistical Parity Group with S=0': tensor(0.0283),
  'Statistical Parity Group S=1': tensor(0.0346),
  'Equal Opportunity Difference': tensor(0.0220),
  'Overall Accuracy Equality Difference': tensor(0.0922),
  'Treatment Equality Difference': tensor(66.2072),
  'Treatment Equality Group with S=0': tensor(335.3846),
  'Treatment Equality Group S=1': tensor(401.5918)})

In [36]:
model2_FairLoss = custom_GCN(data)
optimizer2_FairLoss = torch.optim.Adam(model2_FairLoss.parameters(), lr=0.01)

# Train the second model: Fair loss
training(model=model2_FairLoss, data=data, optimizer=optimizer2_FairLoss, fairness=True, alpha=0.2, epochs=100)

# Test the second model
print("\nHere are the values for the GCN model with the standard dataset and Fairness loss: ")
test(model2_FairLoss, data)

Epoch 0 | Loss: 0.5933056473731995
Epoch 10 | Loss: 0.5912958979606628
Epoch 20 | Loss: 0.588941216468811
Epoch 30 | Loss: 0.5858114957809448
Epoch 40 | Loss: 0.5818194150924683
Epoch 50 | Loss: 0.5769309401512146
Epoch 60 | Loss: 0.571225643157959
Epoch 70 | Loss: 0.5649163126945496
Epoch 80 | Loss: 0.5582668781280518
Epoch 90 | Loss: 0.5515241622924805

Here are the values for the GCN model with the standard dataset and Fairness loss: 
Accuracy: 0.8258864398658361
Statistical Parity Difference: 0.015432022511959076
Statistical Parity Group with S=0: 0.07883617281913757
Statistical Parity Group S=1: 0.09426819533109665
Equal Opportunity Difference: 0.054860204458236694
Overall Accuracy Equality Difference: 0.07815015316009521
Treatment Equality Difference: 70.14901733398438
Treatment Equality Group with S=0: 30.57176399230957
Treatment Equality Group S=1: 100.72077941894531


(0.8258864398658361,
 {'Statistical Parity Difference': tensor(0.0154),
  'Statistical Parity Group with S=0': tensor(0.0788),
  'Statistical Parity Group S=1': tensor(0.0943),
  'Equal Opportunity Difference': tensor(0.0549),
  'Overall Accuracy Equality Difference': tensor(0.0782),
  'Treatment Equality Difference': tensor(70.1490),
  'Treatment Equality Group with S=0': tensor(30.5718),
  'Treatment Equality Group S=1': tensor(100.7208)})

'Statistical Parity Difference': tensor(0.0432),


  'Statistical Parity Group with S=0': tensor(0.1313),


  'Statistical Parity Group S=1': tensor(0.1745),

  'Equal Opportunity Difference': tensor(0.0826),

  'Overall Accuracy Equality Difference': tensor(0.0845),

  'Treatment Equality Difference': tensor(0.6747),

  'Treatment Equality Group with S=0': tensor(6.1087),

  'Treatment Equality Group S=1': tensor(5.4340)

In [48]:
model2_FairLoss = custom_GCN(data)
optimizer2_FairLoss = torch.optim.Adam(model2_FairLoss.parameters(), lr=0.01)

# Train the second model: Fair loss
training(model=model2_FairLoss, data=data, optimizer=optimizer2_FairLoss, fairness=True, alpha=0, beta=0.2, epochs=50)

# Test the second model
print("\nHere are the values for the GCN model with the standard dataset and Fairness loss: ")
test(model2_FairLoss, data)

Epoch 0 | Loss: 0.6824272871017456
Epoch 10 | Loss: 0.6796886324882507
Epoch 20 | Loss: 0.6766286492347717
Epoch 30 | Loss: 0.6725429892539978
Epoch 40 | Loss: 0.6672966480255127

Here are the values for the GCN model with the standard dataset and Fairness loss: 
Accuracy: 0.7783600862482032
Statistical Parity Difference: 0.009171690791845322
Statistical Parity Group with S=0: 0.02826736494898796
Statistical Parity Group S=1: 0.03743905574083328
Equal Opportunity Difference: 0.013812802731990814
Overall Accuracy Equality Difference: 0.08987855911254883
Treatment Equality Difference: 156.09625244140625


(0.7783600862482032,
 {'Statistical Parity Difference': tensor(0.0092),
  'Statistical Parity Group with S=0': tensor(0.0283),
  'Statistical Parity Group S=1': tensor(0.0374),
  'Equal Opportunity Difference': tensor(0.0138),
  'Overall Accuracy Equality Difference': tensor(0.0899),
  'Treatment Equality Difference': tensor(156.0963)})

In [51]:
model2_FairLoss = custom_GCN(data)
optimizer2_FairLoss = torch.optim.Adam(model2_FairLoss.parameters(), lr=0.01)

# Train the second model: Fair loss
training(model=model2_FairLoss, data=data, optimizer=optimizer2_FairLoss, fairness=True, alpha=0.1, beta=0.1, epochs=50)

# Test the second model
print("\nHere are the values for the GCN model with the standard dataset and Fairness loss: ")
test(model2_FairLoss, data)

Epoch 0 | Loss: 0.6480738520622253
Epoch 10 | Loss: 0.6397125124931335
Epoch 20 | Loss: 0.6073761582374573
Epoch 30 | Loss: 0.6045373678207397
Epoch 40 | Loss: 0.6016223430633545

Here are the values for the GCN model with the standard dataset and Fairness loss: 
Accuracy: 0.7732690464781984
Statistical Parity Difference: 0.003445802256464958
Statistical Parity Group with S=0: 0.02587680146098137
Statistical Parity Group S=1: 0.029322603717446327
Equal Opportunity Difference: 0.026854470372200012
Overall Accuracy Equality Difference: 0.09461230039596558
Treatment Equality Difference: 239.166015625


(0.7732690464781984,
 {'Statistical Parity Difference': tensor(0.0034),
  'Statistical Parity Group with S=0': tensor(0.0259),
  'Statistical Parity Group S=1': tensor(0.0293),
  'Equal Opportunity Difference': tensor(0.0269),
  'Overall Accuracy Equality Difference': tensor(0.0946),
  'Treatment Equality Difference': tensor(239.1660)})

In [52]:
model2_FairLoss = custom_GCN(data)
optimizer2_FairLoss = torch.optim.Adam(model2_FairLoss.parameters(), lr=0.01)

# Train the second model: Fair loss
training(model=model2_FairLoss, data=data, optimizer=optimizer2_FairLoss, fairness=True, alpha=0.1, beta=0.1, epochs=50)

# Test the second model
print("\nHere are the values for the GCN model with the standard dataset and Fairness loss: ")
test(model2_FairLoss, data)

Epoch 0 | Loss: 0.6303269863128662
Epoch 10 | Loss: 0.6086326241493225
Epoch 20 | Loss: 0.606320321559906
Epoch 30 | Loss: 0.6030330061912537
Epoch 40 | Loss: 0.5989081263542175

Here are the values for the GCN model with the standard dataset and Fairness loss: 
Accuracy: 0.7688667944417824
Statistical Parity Difference: 0.010660281404852867
Statistical Parity Group with S=0: 0.017488859593868256
Statistical Parity Group S=1: 0.028149140998721123
Equal Opportunity Difference: 0.008008614182472229
Overall Accuracy Equality Difference: 0.08765697479248047
Treatment Equality Difference: 837.611083984375


(0.7688667944417824,
 {'Statistical Parity Difference': tensor(0.0107),
  'Statistical Parity Group with S=0': tensor(0.0175),
  'Statistical Parity Group S=1': tensor(0.0281),
  'Equal Opportunity Difference': tensor(0.0080),
  'Overall Accuracy Equality Difference': tensor(0.0877),
  'Treatment Equality Difference': tensor(837.6111)})

In [53]:
model2_FairLoss = custom_GCN(data)
optimizer2_FairLoss = torch.optim.Adam(model2_FairLoss.parameters(), lr=0.01)

# Train the second model: Fair loss
training(model=model2_FairLoss, data=data, optimizer=optimizer2_FairLoss, fairness=True, alpha=0.1, beta=0.1, epochs=50)

# Test the second model
print("\nHere are the values for the GCN model with the standard dataset and Fairness loss: ")
test(model2_FairLoss, data)

Epoch 0 | Loss: 0.6472604274749756
Epoch 10 | Loss: 0.6342494487762451
Epoch 20 | Loss: 0.6073105931282043
Epoch 30 | Loss: 0.6047623753547668
Epoch 40 | Loss: 0.6016395688056946

Here are the values for the GCN model with the standard dataset and Fairness loss: 
Accuracy: 0.7702443699089603
Statistical Parity Difference: 0.002040307968854904
Statistical Parity Group with S=0: 0.023161206394433975
Statistical Parity Group S=1: 0.02520151436328888
Equal Opportunity Difference: 0.026998378336429596
Overall Accuracy Equality Difference: 0.09580826759338379
Treatment Equality Difference: 430.2586669921875


(0.7702443699089603,
 {'Statistical Parity Difference': tensor(0.0020),
  'Statistical Parity Group with S=0': tensor(0.0232),
  'Statistical Parity Group S=1': tensor(0.0252),
  'Equal Opportunity Difference': tensor(0.0270),
  'Overall Accuracy Equality Difference': tensor(0.0958),
  'Treatment Equality Difference': tensor(430.2587)})

In [54]:
model2_FairLoss = custom_GCN(data)
optimizer2_FairLoss = torch.optim.Adam(model2_FairLoss.parameters(), lr=0.01)

# Train the second model: Fair loss
training(model=model2_FairLoss, data=data, optimizer=optimizer2_FairLoss, fairness=True, alpha=0.1, beta=0.1, epochs=50)

# Test the second model
print("\nHere are the values for the GCN model with the standard dataset and Fairness loss: ")
test(model2_FairLoss, data)

Epoch 0 | Loss: 0.6112976670265198
Epoch 10 | Loss: 0.6072863340377808
Epoch 20 | Loss: 0.6045666933059692
Epoch 30 | Loss: 0.6013597846031189
Epoch 40 | Loss: 0.5975881814956665

Here are the values for the GCN model with the standard dataset and Fairness loss: 
Accuracy: 0.7739578342117872
Statistical Parity Difference: 0.005039209499955177
Statistical Parity Group with S=0: 0.024702489376068115
Statistical Parity Group S=1: 0.029741698876023293
Equal Opportunity Difference: 0.02041982114315033
Overall Accuracy Equality Difference: 0.0930333137512207
Treatment Equality Difference: 765.7023315429688


(0.7739578342117872,
 {'Statistical Parity Difference': tensor(0.0050),
  'Statistical Parity Group with S=0': tensor(0.0247),
  'Statistical Parity Group S=1': tensor(0.0297),
  'Equal Opportunity Difference': tensor(0.0204),
  'Overall Accuracy Equality Difference': tensor(0.0930),
  'Treatment Equality Difference': tensor(765.7023)})

In [94]:
# GCN class that takes in the data as an input for dimensions of the convolutions
class custom_GCN2(torch.nn.Module):
    def __init__(self, data):
        super(custom_GCN2, self).__init__()
        self.conv1 = FairnessAwareMessagePassingLayer(data.num_node_features, 16)
        self.conv2 = GCNConv(16, 2) 

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.conv1(x, edge_index))
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

In [21]:
model2 = custom_GCN(data)
optimizer2 = torch.optim.Adam(model2.parameters(), lr=0.01)

# Train the second model: GCN, standard data, cross-entropy loss
training(model=model2, data=data, optimizer=optimizer2, epochs=100)

# Test the second model: GCN, standard data, cross-entropy loss
print("\nHere are the values for the GCN model with the standard dataset and cross-entropy loss: ")
test(model2, data)

Epoch 0 | Loss: 0.6949846744537354
Epoch 10 | Loss: 0.6916375756263733
Epoch 20 | Loss: 0.6884897947311401
Epoch 30 | Loss: 0.6849395036697388
Epoch 40 | Loss: 0.6803786158561707
Epoch 50 | Loss: 0.6746808886528015
Epoch 60 | Loss: 0.6680154204368591
Epoch 70 | Loss: 0.6604992151260376
Epoch 80 | Loss: 0.6523556709289551
Epoch 90 | Loss: 0.6438772082328796

Here are the values for the GCN model with the standard dataset and cross-entropy loss: 
Accuracy: 0.8221130809774796
Statistical Parity Difference: 0.03101031482219696
Equal Opportunity Difference: 0.013422399759292603
Overall Accuracy Equality Difference: 0.07621175050735474
Treatment Equality Difference: 158.93136596679688


(0.8221130809774796,
 {'Statistical Parity Difference': tensor(0.0310),
  'Equal Opportunity Difference': tensor(0.0134),
  'Overall Accuracy Equality Difference': tensor(0.0762),
  'Treatment Equality Difference': tensor(158.9314)})

In [23]:
model2 = custom_GCN(data)
optimizer2 = torch.optim.Adam(model2.parameters(), lr=0.01)

# Train the second model: GCN, standard data, cross-entropy loss
training(model=model2, data=data, optimizer=optimizer2, epochs=50, fairness=True)

# Test the second model: GCN, standard data, cross-entropy loss
print("\nHere are the values for the GCN model with the standard dataset and cross-entropy loss: ")
test(model2, data)

Epoch 0 | Loss: 0.6959884762763977
Epoch 10 | Loss: 0.6938111782073975
Epoch 20 | Loss: 0.6908217668533325
Epoch 30 | Loss: 0.687012791633606
Epoch 40 | Loss: 0.6820933818817139

Here are the values for the GCN model with the standard dataset and cross-entropy loss: 
Accuracy: 0.7715620507906086
Statistical Parity Difference: 0.006483552977442741
Equal Opportunity Difference: 0.011388957500457764
Overall Accuracy Equality Difference: 0.09158873558044434
Treatment Equality Difference: 341.39794921875


(0.7715620507906086,
 {'Statistical Parity Difference': tensor(0.0065),
  'Equal Opportunity Difference': tensor(0.0114),
  'Overall Accuracy Equality Difference': tensor(0.0916),
  'Treatment Equality Difference': tensor(341.3979)})

In [55]:
model2 = custom_GCN(data)
optimizer2 = torch.optim.Adam(model2.parameters(), lr=0.01)

# Train the second model: GCN, standard data, cross-entropy loss
training(model=model2, data=data, optimizer=optimizer2, epochs=50)

# Test the second model: GCN, standard data, cross-entropy loss
print("\nHere are the values for the GCN model with the standard dataset and cross-entropy loss: ")
test(model2, data)

Epoch 0 | Loss: 0.6934686899185181
Epoch 10 | Loss: 0.6908966302871704
Epoch 20 | Loss: 0.6874423027038574
Epoch 30 | Loss: 0.6832926273345947
Epoch 40 | Loss: 0.6780106425285339
Here are the values for the GCN model with the standard dataset and cross-entropy loss: 

Accuracy: 0.7773418782942022
Statistical Parity Difference: 0.004843294620513916
Equal Opportunity Difference: 0.027115054428577423
Overall Accuracy Equality Difference: 0.09278804063796997
Treatment Equality Difference: 184.3036346435547


(0.7773418782942022,
 {'Statistical Parity Difference': tensor(0.0048),
  'Equal Opportunity Difference': tensor(0.0271),
  'Overall Accuracy Equality Difference': tensor(0.0928),
  'Treatment Equality Difference': tensor(184.3036)})

Values from the original model:

Accuracy: 0.8549053665548635


Statistical Parity Difference: 0.04585549235343933


Equal Opportunity Difference: 0.08249932527542114

Overall Accuracy Equality Difference: 0.0841907262802124

Treatment Equality Difference: 0.5477824211120605

In [95]:
model3 = custom_GCN2(data)
optimizer3 = torch.optim.Adam(model2.parameters(), lr=0.01)

# Train the second model: GCN, standard data, cross-entropy loss
training(model=model3, data=data, optimizer=optimizer3, epochs=50)

# Test the second model: GCN, standard data, cross-entropy loss
print("Here are the values for the GCN model with the standard dataset and cross-entropy loss: ")
test(model3, data)

Epoch 0 | Loss: 0.690571665763855
Epoch 10 | Loss: 0.690571665763855
Epoch 20 | Loss: 0.690571665763855


KeyboardInterrupt: 

In [None]:
def compute_statistical_parity_difference(label, predictions, sens_attr='bin_age'):
    """
    Calculate various fairness metrics.

    Args:
    label: Actual labels (binary).
    predictions: Model predictions (binary).
    sens_attr: Binary sensitive attribute for fairness evaluation.

    Returns:
    A dictionary containing SPD, EOD, OAED, and TED values.
    """
    labels = label.detach().clone()
    predictions = predictions.detach().clone()
    labels = torch.tensor(user_labels[labels].values, dtype=torch.long)
    sensitive_attribute = torch.tensor(user_labels[sens_attr].values, dtype=torch.long)

    predictions = predictions.float()
    labels = labels.float()
    sensitive_attribute = sensitive_attribute.float()

    prob_group_1 = predictions[sensitive_attribute == 1].mean()
    prob_group_0 = predictions[sensitive_attribute == 0].mean()
    return abs(prob_group_1 - prob_group_0)


In [48]:
import torch
import torch.nn.functional as F

def dual_objective_loss_with_lagrangian_relaxation(outputs, labels, sensitive_attrs, lambda_param, lambda_entropy_coeff):
    """
    Compute the dual-objective loss with Lagrangian relaxation and entropy regularization.
    
    :param outputs: Model predictions.
    :param labels: Ground truth labels.
    :param sensitive_attrs: Sensitive attributes for each instance.
    :param lambda_param: Lagrange multiplier for balancing accuracy and fairness.
    :param lambda_entropy_coeff: Coefficient for the entropy regularization.
    :return: Total loss, prediction loss, fairness loss, entropy regularization term.
    """
    sensitive_attribute = torch.tensor(user_labels[sensitive_attrs].values, dtype=torch.long)
    sensitive_attribute = sensitive_attribute.float()

    # Prediction loss (e.g., Cross-Entropy for classification)
    pred_loss = F.cross_entropy(outputs, labels)

    # Fairness loss (e.g., absolute difference in positive prediction rates between groups)
    positive_pred_rates_diff = compute_statistical_parity_difference(outputs, sensitive_attribute)
    fairness_loss = torch.abs(positive_pred_rates_diff)
    
    # Entropy regularization for the Lagrange multiplier
    entropy_reg = -lambda_entropy_coeff * (lambda_param * torch.log(lambda_param))

    # Total loss
    total_loss = pred_loss + lambda_param * fairness_loss + entropy_reg

    return total_loss, pred_loss, fairness_loss, entropy_reg


In [51]:
from torch.optim import Adam

sens_attr = 'bin_age'  # Sensitive attribute for fairness-aware training

# Model, optimizer, and data setup
model2 = FairnessAwareMessagePassingLayer(in_channels, out_channels)
optimizer = Adam(model2.parameters(), lr=0.01)
lambda_param = torch.tensor(0.01, requires_grad=True)  # Initial Lagrange multiplier
lambda_optimizer = Adam([lambda_param], lr=0.01)  # Optimizer for the Lagrange multiplier
lambda_entropy_coeff = 0.1  # Coefficient for entropy regularization

for epoch in range(50):  # Number of epochs
    model2.train()
    optimizer.zero_grad()
    lambda_optimizer.zero_grad()
    
    outputs = model2(data.x, data.edge_index)
    total_loss, pred_loss, fairness_loss, entropy_reg = dual_objective_loss_with_lagrangian_relaxation(
        outputs, data.y, sens_attr, lambda_param, lambda_entropy_coeff
    )
    
    # Update model parameters
    total_loss.backward()
    optimizer.step()
    
    # Update the Lagrange multiplier (separate optimizer to allow for ascent)
    (-total_loss).backward()  # Negate for gradient ascent on lambda
    lambda_optimizer.step()
    
    print(f'Epoch: {epoch+1}, Total Loss: {total_loss.item()}, Pred Loss: {pred_loss.item()}, Fairness Loss: {fairness_loss.item()}, Lambda: {lambda_param.item()}')


KeyError: "None of [Index([                       (tensor(-0.0012), tensor(0.0047), tensor(0.0035), tensor(0.0003), tensor(0.0003), tensor(-0.0034), tensor(-0.0015), tensor(-0.0014), tensor(-0.0017), tensor(-0.0006), tensor(0.0004), tensor(-0.0033), tensor(0.0016), tensor(0.0003), tensor(-0.0012), tensor(0.0022), tensor(-5.6961e-05), tensor(0.0014), tensor(-0.0006), tensor(-0.0015), tensor(0.0025), tensor(-0.0049), tensor(0.0018), tensor(0.0015), tensor(-0.0035), tensor(0.0004), tensor(-0.0003), tensor(-0.0007), tensor(0.0012), tensor(0.0008), tensor(0.0033), tensor(-0.0014), tensor(0.0045), tensor(-0.0013), tensor(0.0023), tensor(-0.0036), tensor(0.0007), tensor(-0.0018), tensor(-0.0010), tensor(0.0040), tensor(-0.0004), tensor(0.0018), tensor(-0.0028), tensor(-0.0024), tensor(-0.0028), tensor(-0.0011), tensor(-0.0002), tensor(0.0017), tensor(-0.0003), tensor(0.0022), tensor(0.0017), tensor(0.0007), tensor(0.0024), tensor(-0.0026), tensor(0.0032), tensor(-0.0029), tensor(0.0027), tensor(0.0060), tensor(-0.0012), tensor(-0.0013), tensor(0.0004), tensor(0.0032), tensor(-0.0038), tensor(-0.0023)),\n                                  (tensor(-0.0774), tensor(0.1434), tensor(0.1096), tensor(0.0091), tensor(0.0061), tensor(-0.0759), tensor(-0.0347), tensor(-0.0209), tensor(-0.0423), tensor(-0.0175), tensor(-0.0089), tensor(-0.1165), tensor(0.0724), tensor(0.0022), tensor(-0.0300), tensor(0.0860), tensor(0.0149), tensor(0.0808), tensor(-0.0163), tensor(-0.0586), tensor(0.0450), tensor(-0.1565), tensor(0.0182), tensor(0.0348), tensor(-0.1062), tensor(0.0011), tensor(-0.0135), tensor(-0.0369), tensor(0.0144), tensor(0.0418), tensor(0.1259), tensor(-0.0706), tensor(0.1613), tensor(-0.0508), tensor(0.0676), tensor(-0.1075), tensor(0.0128), tensor(-0.0652), tensor(-0.0319), tensor(0.1262), tensor(-0.0093), tensor(0.0323), tensor(-0.1078), tensor(-0.0937), tensor(-0.0734), tensor(-0.0502), tensor(0.0126), tensor(0.0415), tensor(-0.0115), tensor(0.0917), tensor(0.0490), tensor(-0.0029), tensor(0.0865), tensor(-0.0932), tensor(0.0773), tensor(-0.0719), tensor(0.1001), tensor(0.1790), tensor(-0.0049), tensor(-0.0280), tensor(0.0205), tensor(0.1011), tensor(-0.0935), tensor(-0.0778)),\n                              (tensor(-0.0027), tensor(0.0058), tensor(0.0044), tensor(0.0005), tensor(0.0004), tensor(-0.0029), tensor(-0.0014), tensor(-0.0014), tensor(-0.0018), tensor(-0.0006), tensor(-0.0005), tensor(-0.0041), tensor(0.0024), tensor(0.0002), tensor(-0.0009), tensor(0.0034), tensor(0.0004), tensor(0.0024), tensor(-0.0007), tensor(-0.0023), tensor(0.0022), tensor(-0.0061), tensor(0.0008), tensor(0.0017), tensor(-0.0044), tensor(-0.0001), tensor(-0.0003), tensor(-0.0015), tensor(0.0013), tensor(0.0013), tensor(0.0046), tensor(-0.0023), tensor(0.0061), tensor(-0.0020), tensor(0.0027), tensor(-0.0044), tensor(0.0004), tensor(-0.0022), tensor(-0.0006), tensor(0.0047), tensor(-0.0002), tensor(0.0018), tensor(-0.0040), tensor(-0.0036), tensor(-0.0030), tensor(-0.0016), tensor(1.4068e-05), tensor(0.0017), tensor(-0.0006), tensor(0.0032), tensor(0.0022), tensor(0.0005), tensor(0.0037), tensor(-0.0039), tensor(0.0032), tensor(-0.0030), tensor(0.0040), tensor(0.0067), tensor(-0.0004), tensor(-0.0015), tensor(0.0012), tensor(0.0039), tensor(-0.0042), tensor(-0.0029)),\n                                 (tensor(-0.1088), tensor(0.2267), tensor(0.1610), tensor(0.0227), tensor(0.0182), tensor(-0.0920), tensor(-0.0446), tensor(-0.0385), tensor(-0.0585), tensor(-0.0174), tensor(-0.0290), tensor(-0.1551), tensor(0.0981), tensor(0.0053), tensor(-0.0316), tensor(0.1345), tensor(0.0046), tensor(0.0917), tensor(-0.0352), tensor(-0.0896), tensor(0.0752), tensor(-0.2362), tensor(0.0185), tensor(0.0649), tensor(-0.1672), tensor(-0.0229), tensor(-0.0036), tensor(-0.0498), tensor(0.0380), tensor(0.0440), tensor(0.1737), tensor(-0.0956), tensor(0.2241), tensor(-0.0710), tensor(0.0849), tensor(-0.1709), tensor(0.0098), tensor(-0.0894), tensor(-0.0175), tensor(0.1629), tensor(-0.0265), tensor(0.0554), tensor(-0.1545), tensor(-0.1304), tensor(-0.1148), tensor(-0.0569), tensor(0.0051), tensor(0.0765), tensor(-0.0172), tensor(0.1281), tensor(0.0803), tensor(-0.0070), tensor(0.1361), tensor(-0.1537), tensor(0.1283), tensor(-0.1016), tensor(0.1610), tensor(0.2450), tensor(-0.0081), tensor(-0.0529), tensor(0.0618), tensor(0.1482), tensor(-0.1656), tensor(-0.0880)),\n                                  (tensor(-0.0062), tensor(0.0172), tensor(0.0128), tensor(0.0030), tensor(-0.0019), tensor(-0.0101), tensor(-0.0040), tensor(-0.0031), tensor(-0.0073), tensor(-0.0025), tensor(0.0022), tensor(-0.0116), tensor(0.0046), tensor(-0.0012), tensor(-0.0030), tensor(0.0086), tensor(-0.0013), tensor(0.0062), tensor(-0.0019), tensor(-0.0041), tensor(0.0076), tensor(-0.0164), tensor(0.0032), tensor(0.0024), tensor(-0.0103), tensor(0.0007), tensor(0.0023), tensor(-0.0049), tensor(0.0030), tensor(0.0012), tensor(0.0128), tensor(-0.0049), tensor(0.0172), tensor(-0.0024), tensor(0.0073), tensor(-0.0113), tensor(0.0030), tensor(-0.0079), tensor(-0.0026), tensor(0.0111), tensor(-0.0030), tensor(0.0057), tensor(-0.0102), tensor(-0.0073), tensor(-0.0095), tensor(-0.0017), tensor(0.0028), tensor(0.0064), tensor(0.0008), tensor(0.0093), tensor(0.0048), tensor(-0.0007), tensor(0.0067), tensor(-0.0073), tensor(0.0115), tensor(-0.0094), tensor(0.0091), tensor(0.0217), tensor(-0.0004), tensor(-0.0059), tensor(0.0019), tensor(0.0104), tensor(-0.0145), tensor(-0.0070)),\n                                   (tensor(-0.0224), tensor(0.0599), tensor(0.0458), tensor(0.0095), tensor(0.0026), tensor(-0.0286), tensor(-0.0216), tensor(-0.0220), tensor(-0.0242), tensor(-0.0054), tensor(-0.0049), tensor(-0.0378), tensor(0.0165), tensor(0.0001), tensor(-0.0054), tensor(0.0374), tensor(0.0051), tensor(0.0142), tensor(-0.0064), tensor(-0.0204), tensor(0.0270), tensor(-0.0577), tensor(0.0098), tensor(0.0203), tensor(-0.0447), tensor(-0.0018), tensor(-0.0013), tensor(-0.0183), tensor(0.0240), tensor(0.0098), tensor(0.0399), tensor(-0.0183), tensor(0.0598), tensor(-0.0202), tensor(0.0318), tensor(-0.0424), tensor(0.0021), tensor(-0.0170), tensor(0.0018), tensor(0.0487), tensor(0.0035), tensor(0.0303), tensor(-0.0344), tensor(-0.0336), tensor(-0.0267), tensor(-0.0112), tensor(-0.0051), tensor(0.0126), tensor(-0.0093), tensor(0.0263), tensor(0.0275), tensor(0.0140), tensor(0.0392), tensor(-0.0381), tensor(0.0314), tensor(-0.0367), tensor(0.0413), tensor(0.0648), tensor(-0.0042), tensor(-0.0246), tensor(0.0146), tensor(0.0381), tensor(-0.0493), tensor(-0.0320)),\n                         (tensor(-0.0022), tensor(0.0069), tensor(0.0051), tensor(0.0008), tensor(0.0007), tensor(-0.0034), tensor(-0.0021), tensor(-0.0025), tensor(-0.0025), tensor(-0.0006), tensor(-0.0006), tensor(-0.0043), tensor(0.0020), tensor(0.0005), tensor(-0.0008), tensor(0.0039), tensor(0.0002), tensor(0.0014), tensor(-0.0010), tensor(-0.0024), tensor(0.0032), tensor(-0.0068), tensor(0.0014), tensor(0.0026), tensor(-0.0052), tensor(-0.0003), tensor(-8.6641e-05), tensor(-0.0016), tensor(0.0025), tensor(0.0008), tensor(0.0043), tensor(-0.0020), tensor(0.0064), tensor(-0.0021), tensor(0.0033), tensor(-0.0051), tensor(0.0004), tensor(-0.0019), tensor(-2.7812e-05), tensor(0.0053), tensor(-0.0002), tensor(0.0031), tensor(-0.0040), tensor(-0.0037), tensor(-0.0035), tensor(-0.0013), tensor(-0.0008), tensor(0.0020), tensor(-0.0010), tensor(0.0030), tensor(0.0032), tensor(0.0013), tensor(0.0043), tensor(-0.0045), tensor(0.0041), tensor(-0.0040), tensor(0.0046), tensor(0.0074), tensor(-0.0010), tensor(-0.0025), tensor(0.0017), tensor(0.0045), tensor(-0.0057), tensor(-0.0032)),\n                                  (tensor(-0.0055), tensor(0.0192), tensor(0.0141), tensor(0.0017), tensor(0.0021), tensor(-0.0116), tensor(-0.0066), tensor(-0.0071), tensor(-0.0070), tensor(-0.0020), tensor(-0.0008), tensor(-0.0128), tensor(0.0058), tensor(0.0019), tensor(-0.0032), tensor(0.0102), tensor(0.0009), tensor(0.0042), tensor(-0.0028), tensor(-0.0063), tensor(0.0095), tensor(-0.0193), tensor(0.0054), tensor(0.0076), tensor(-0.0145), tensor(0.0001), tensor(-0.0011), tensor(-0.0037), tensor(0.0067), tensor(0.0026), tensor(0.0121), tensor(-0.0058), tensor(0.0179), tensor(-0.0061), tensor(0.0099), tensor(-0.0142), tensor(0.0018), tensor(-0.0055), tensor(-0.0016), tensor(0.0161), tensor(-0.0004), tensor(0.0090), tensor(-0.0113), tensor(-0.0104), tensor(-0.0103), tensor(-0.0043), tensor(-0.0025), tensor(0.0056), tensor(-0.0030), tensor(0.0083), tensor(0.0088), tensor(0.0040), tensor(0.0116), tensor(-0.0122), tensor(0.0118), tensor(-0.0120), tensor(0.0122), tensor(0.0222), tensor(-0.0042), tensor(-0.0064), tensor(0.0034), tensor(0.0130), tensor(-0.0157), tensor(-0.0098)),\n                                  (tensor(-0.0111), tensor(0.0438), tensor(0.0327), tensor(0.0055), tensor(0.0029), tensor(-0.0272), tensor(-0.0142), tensor(-0.0158), tensor(-0.0171), tensor(-0.0052), tensor(0.0009), tensor(-0.0278), tensor(0.0118), tensor(0.0024), tensor(-0.0075), tensor(0.0225), tensor(-0.0002), tensor(0.0087), tensor(-0.0057), tensor(-0.0139), tensor(0.0231), tensor(-0.0436), tensor(0.0131), tensor(0.0147), tensor(-0.0326), tensor(0.0017), tensor(-0.0005), tensor(-0.0091), tensor(0.0161), tensor(0.0046), tensor(0.0282), tensor(-0.0115), tensor(0.0411), tensor(-0.0123), tensor(0.0217), tensor(-0.0326), tensor(0.0049), tensor(-0.0143), tensor(-0.0033), tensor(0.0344), tensor(-0.0019), tensor(0.0208), tensor(-0.0244), tensor(-0.0222), tensor(-0.0238), tensor(-0.0076), tensor(-0.0039), tensor(0.0143), tensor(-0.0049), tensor(0.0188), tensor(0.0187), tensor(0.0092), tensor(0.0244), tensor(-0.0258), tensor(0.0283), tensor(-0.0272), tensor(0.0267), tensor(0.0517), tensor(-0.0089), tensor(-0.0161), tensor(0.0078), tensor(0.0287), tensor(-0.0375), tensor(-0.0213)),\n                                 (tensor(-0.0060), tensor(0.0213), tensor(0.0152), tensor(0.0040), tensor(0.0002), tensor(-0.0100), tensor(-0.0056), tensor(-0.0067), tensor(-0.0089), tensor(-0.0023), tensor(-0.0002), tensor(-0.0120), tensor(0.0041), tensor(0.0003), tensor(-0.0018), tensor(0.0114), tensor(-0.0012), tensor(0.0029), tensor(-0.0034), tensor(-0.0058), tensor(0.0100), tensor(-0.0195), tensor(0.0034), tensor(0.0059), tensor(-0.0139), tensor(-0.0010), tensor(0.0026), tensor(-0.0057), tensor(0.0079), tensor(-0.0004), tensor(0.0124), tensor(-0.0050), tensor(0.0188), tensor(-0.0040), tensor(0.0093), tensor(-0.0145), tensor(0.0022), tensor(-0.0066), tensor(0.0006), tensor(0.0128), tensor(-0.0024), tensor(0.0100), tensor(-0.0111), tensor(-0.0093), tensor(-0.0107), tensor(-0.0013), tensor(-0.0009), tensor(0.0073), tensor(-0.0018), tensor(0.0094), tensor(0.0089), tensor(0.0026), tensor(0.0107), tensor(-0.0117), tensor(0.0139), tensor(-0.0120), tensor(0.0124), tensor(0.0228), tensor(-0.0021), tensor(-0.0089), tensor(0.0053), tensor(0.0131), tensor(-0.0191), tensor(-0.0081)),\n       ...\n                          (tensor(-0.0040), tensor(0.0119), tensor(0.0089), tensor(0.0013), tensor(0.0010), tensor(-0.0063), tensor(-0.0031), tensor(-0.0039), tensor(-0.0040), tensor(-0.0011), tensor(-0.0007), tensor(-0.0075), tensor(0.0036), tensor(0.0008), tensor(-0.0017), tensor(0.0064), tensor(5.0072e-05), tensor(0.0029), tensor(-0.0017), tensor(-0.0042), tensor(0.0055), tensor(-0.0120), tensor(0.0026), tensor(0.0040), tensor(-0.0091), tensor(-0.0005), tensor(6.5253e-05), tensor(-0.0025), tensor(0.0037), tensor(0.0016), tensor(0.0080), tensor(-0.0035), tensor(0.0112), tensor(-0.0034), tensor(0.0053), tensor(-0.0089), tensor(0.0009), tensor(-0.0038), tensor(-0.0006), tensor(0.0088), tensor(-0.0008), tensor(0.0047), tensor(-0.0072), tensor(-0.0063), tensor(-0.0065), tensor(-0.0022), tensor(-0.0008), tensor(0.0039), tensor(-0.0013), tensor(0.0054), tensor(0.0049), tensor(0.0016), tensor(0.0071), tensor(-0.0077), tensor(0.0074), tensor(-0.0065), tensor(0.0077), tensor(0.0133), tensor(-0.0017), tensor(-0.0038), tensor(0.0026), tensor(0.0077), tensor(-0.0098), tensor(-0.0052)),\n                                  (tensor(-0.0581), tensor(0.2934), tensor(0.2196), tensor(0.0327), tensor(0.0399), tensor(-0.1607), tensor(-0.0516), tensor(-0.1195), tensor(-0.0881), tensor(-0.0207), tensor(-0.0152), tensor(-0.1494), tensor(0.0630), tensor(0.0413), tensor(-0.0373), tensor(0.1300), tensor(-0.0383), tensor(0.0231), tensor(-0.0538), tensor(-0.1040), tensor(0.1675), tensor(-0.2942), tensor(0.0877), tensor(0.1094), tensor(-0.2382), tensor(-0.0233), tensor(0.0303), tensor(-0.0338), tensor(0.1050), tensor(0.0065), tensor(0.1688), tensor(-0.0508), tensor(0.2372), tensor(-0.0627), tensor(0.1026), tensor(-0.2350), tensor(0.0300), tensor(-0.0843), tensor(0.0054), tensor(0.1699), tensor(-0.0494), tensor(0.1159), tensor(-0.1633), tensor(-0.1302), tensor(-0.1935), tensor(-0.0205), tensor(-0.0382), tensor(0.1306), tensor(-0.0211), tensor(0.1049), tensor(0.1220), tensor(0.0434), tensor(0.1668), tensor(-0.1943), tensor(0.2209), tensor(-0.1475), tensor(0.1763), tensor(0.3134), tensor(-0.0725), tensor(-0.0953), tensor(0.0754), tensor(0.1823), tensor(-0.2747), tensor(-0.0830)),\n                                    (tensor(-0.1099), tensor(0.2329), tensor(0.2028), tensor(0.0450), tensor(0.0037), tensor(-0.0823), tensor(-0.0056), tensor(-0.0656), tensor(-0.0490), tensor(-0.0088), tensor(-0.0257), tensor(-0.1134), tensor(0.0800), tensor(-0.0152), tensor(-0.0192), tensor(0.1334), tensor(-0.0337), tensor(0.0821), tensor(-0.0157), tensor(-0.1191), tensor(0.1127), tensor(-0.2548), tensor(0.0259), tensor(0.0374), tensor(-0.2172), tensor(-0.0314), tensor(0.0369), tensor(-0.0569), tensor(0.0543), tensor(0.0555), tensor(0.2086), tensor(-0.0580), tensor(0.2441), tensor(-0.0655), tensor(0.0478), tensor(-0.2050), tensor(0.0015), tensor(-0.1097), tensor(0.0269), tensor(0.1153), tensor(-0.0319), tensor(0.0415), tensor(-0.1548), tensor(-0.1271), tensor(-0.1494), tensor(-0.0141), tensor(0.0273), tensor(0.1029), tensor(0.0169), tensor(0.1060), tensor(0.0570), tensor(0.0011), tensor(0.1588), tensor(-0.1678), tensor(0.1515), tensor(-0.0725), tensor(0.1794), tensor(0.2451), tensor(0.0175), tensor(-0.0685), tensor(0.0884), tensor(0.1234), tensor(-0.2091), tensor(-0.0518)),\n                   (tensor(-0.0030), tensor(0.0088), tensor(0.0062), tensor(0.0017), tensor(-0.0002), tensor(-0.0037), tensor(-0.0017), tensor(-0.0019), tensor(-0.0034), tensor(-0.0010), tensor(-1.4659e-05), tensor(-0.0051), tensor(0.0022), tensor(-0.0002), tensor(-0.0008), tensor(0.0047), tensor(-0.0008), tensor(0.0019), tensor(-0.0014), tensor(-0.0025), tensor(0.0037), tensor(-0.0082), tensor(0.0009), tensor(0.0018), tensor(-0.0055), tensor(-0.0005), tensor(0.0013), tensor(-0.0024), tensor(0.0025), tensor(-0.0001), tensor(0.0057), tensor(-0.0024), tensor(0.0080), tensor(-0.0014), tensor(0.0033), tensor(-0.0060), tensor(0.0009), tensor(-0.0034), tensor(-3.4379e-05), tensor(0.0047), tensor(-0.0016), tensor(0.0032), tensor(-0.0049), tensor(-0.0038), tensor(-0.0044), tensor(-0.0006), tensor(0.0004), tensor(0.0035), tensor(-0.0003), tensor(0.0045), tensor(0.0031), tensor(5.5291e-05), tensor(0.0040), tensor(-0.0047), tensor(0.0058), tensor(-0.0043), tensor(0.0051), tensor(0.0095), tensor(-0.0004), tensor(-0.0033), tensor(0.0023), tensor(0.0054), tensor(-0.0076), tensor(-0.0029)),\n                                  (tensor(-0.0056), tensor(0.0177), tensor(0.0128), tensor(0.0028), tensor(-0.0002), tensor(-0.0092), tensor(-0.0042), tensor(-0.0049), tensor(-0.0068), tensor(-0.0017), tensor(0.0003), tensor(-0.0106), tensor(0.0040), tensor(0.0002), tensor(-0.0021), tensor(0.0089), tensor(-0.0012), tensor(0.0041), tensor(-0.0026), tensor(-0.0046), tensor(0.0081), tensor(-0.0165), tensor(0.0033), tensor(0.0046), tensor(-0.0115), tensor(-0.0008), tensor(0.0022), tensor(-0.0042), tensor(0.0045), tensor(0.0007), tensor(0.0112), tensor(-0.0044), tensor(0.0159), tensor(-0.0030), tensor(0.0073), tensor(-0.0120), tensor(0.0022), tensor(-0.0062), tensor(-0.0008), tensor(0.0110), tensor(-0.0027), tensor(0.0068), tensor(-0.0101), tensor(-0.0075), tensor(-0.0099), tensor(-0.0015), tensor(0.0005), tensor(0.0063), tensor(-0.0005), tensor(0.0081), tensor(0.0063), tensor(0.0006), tensor(0.0083), tensor(-0.0092), tensor(0.0117), tensor(-0.0095), tensor(0.0100), tensor(0.0200), tensor(-0.0015), tensor(-0.0063), tensor(0.0032), tensor(0.0107), tensor(-0.0154), tensor(-0.0063)),\n                                  (tensor(-0.0250), tensor(0.0599), tensor(0.0472), tensor(0.0088), tensor(0.0038), tensor(-0.0245), tensor(-0.0104), tensor(-0.0179), tensor(-0.0176), tensor(-0.0043), tensor(-0.0071), tensor(-0.0349), tensor(0.0195), tensor(0.0013), tensor(-0.0055), tensor(0.0348), tensor(-0.0016), tensor(0.0172), tensor(-0.0076), tensor(-0.0252), tensor(0.0260), tensor(-0.0616), tensor(0.0073), tensor(0.0163), tensor(-0.0485), tensor(-0.0058), tensor(0.0038), tensor(-0.0150), tensor(0.0174), tensor(0.0094), tensor(0.0446), tensor(-0.0183), tensor(0.0592), tensor(-0.0176), tensor(0.0221), tensor(-0.0471), tensor(0.0023), tensor(-0.0217), tensor(0.0022), tensor(0.0383), tensor(-0.0048), tensor(0.0193), tensor(-0.0381), tensor(-0.0333), tensor(-0.0328), tensor(-0.0087), tensor(-0.0005), tensor(0.0208), tensor(-0.0041), tensor(0.0285), tensor(0.0226), tensor(0.0051), tensor(0.0388), tensor(-0.0415), tensor(0.0365), tensor(-0.0271), tensor(0.0423), tensor(0.0635), tensor(-0.0023), tensor(-0.0189), tensor(0.0180), tensor(0.0370), tensor(-0.0499), tensor(-0.0221)),\n                          (tensor(-0.0048), tensor(0.0128), tensor(0.0089), tensor(0.0025), tensor(-0.0013), tensor(-0.0065), tensor(-0.0038), tensor(-0.0022), tensor(-0.0058), tensor(-0.0017), tensor(0.0011), tensor(-0.0086), tensor(0.0032), tensor(-0.0011), tensor(-0.0017), tensor(0.0071), tensor(-0.0003), tensor(0.0041), tensor(-0.0016), tensor(-0.0027), tensor(0.0050), tensor(-0.0115), tensor(0.0015), tensor(0.0022), tensor(-0.0069), tensor(0.0002), tensor(0.0014), tensor(-0.0041), tensor(0.0029), tensor(0.0005), tensor(0.0087), tensor(-0.0040), tensor(0.0124), tensor(-0.0020), tensor(0.0060), tensor(-0.0078), tensor(0.0017), tensor(-0.0052), tensor(-0.0012), tensor(0.0085), tensor(-0.0016), tensor(0.0051), tensor(-0.0071), tensor(-0.0054), tensor(-0.0058), tensor(-0.0014), tensor(0.0015), tensor(0.0039), tensor(-0.0002), tensor(0.0069), tensor(0.0042), tensor(-9.2372e-05), tensor(0.0051), tensor(-0.0054), tensor(0.0077), tensor(-0.0074), tensor(0.0070), tensor(0.0151), tensor(2.9559e-05), tensor(-0.0050), tensor(0.0020), tensor(0.0078), tensor(-0.0105), tensor(-0.0056)),\n                                 (tensor(-0.1317), tensor(0.2251), tensor(0.1706), tensor(0.0437), tensor(-0.0533), tensor(-0.0942), tensor(0.0081), tensor(0.0163), tensor(-0.0598), tensor(-0.0174), tensor(0.0263), tensor(-0.1442), tensor(0.0816), tensor(-0.0449), tensor(-0.0363), tensor(0.1085), tensor(-0.0526), tensor(0.1454), tensor(-0.0176), tensor(-0.0589), tensor(0.0683), tensor(-0.2209), tensor(0.0040), tensor(-0.0127), tensor(-0.1307), tensor(-0.0239), tensor(0.0712), tensor(-0.0587), tensor(-0.0452), tensor(0.0362), tensor(0.2160), tensor(-0.0743), tensor(0.2345), tensor(-0.0034), tensor(0.0303), tensor(-0.1492), tensor(0.0365), tensor(-0.1474), tensor(-0.0392), tensor(0.0861), tensor(-0.0937), tensor(-0.0158), tensor(-0.1640), tensor(-0.0784), tensor(-0.1516), tensor(-0.0052), tensor(0.1130), tensor(0.1136), tensor(0.0705), tensor(0.1462), tensor(0.0014), tensor(-0.1097), tensor(0.0650), tensor(-0.0828), tensor(0.1562), tensor(-0.0590), tensor(0.1242), tensor(0.2799), tensor(0.0574), tensor(-0.0398), tensor(0.0317), tensor(0.1081), tensor(-0.1825), tensor(-0.0229)),\n       (tensor(-0.0004), tensor(0.0014), tensor(0.0010), tensor(0.0001), tensor(7.7193e-05), tensor(-0.0008), tensor(-0.0004), tensor(-0.0004), tensor(-0.0005), tensor(-0.0002), tensor(3.6554e-05), tensor(-0.0009), tensor(0.0004), tensor(8.2265e-05), tensor(-0.0003), tensor(0.0007), tensor(-3.4714e-05), tensor(0.0003), tensor(-0.0002), tensor(-0.0004), tensor(0.0007), tensor(-0.0014), tensor(0.0004), tensor(0.0004), tensor(-0.0010), tensor(3.1380e-05), tensor(1.5369e-05), tensor(-0.0003), tensor(0.0004), tensor(0.0001), tensor(0.0009), tensor(-0.0004), tensor(0.0013), tensor(-0.0003), tensor(0.0006), tensor(-0.0010), tensor(0.0002), tensor(-0.0005), tensor(-0.0002), tensor(0.0010), tensor(-0.0001), tensor(0.0005), tensor(-0.0008), tensor(-0.0007), tensor(-0.0008), tensor(-0.0002), tensor(-5.1753e-05), tensor(0.0005), tensor(-0.0001), tensor(0.0006), tensor(0.0005), tensor(0.0002), tensor(0.0007), tensor(-0.0008), tensor(0.0009), tensor(-0.0008), tensor(0.0008), tensor(0.0016), tensor(-0.0003), tensor(-0.0004), tensor(0.0002), tensor(0.0009), tensor(-0.0011), tensor(-0.0006)),\n                                (tensor(-0.0027), tensor(0.0092), tensor(0.0068), tensor(0.0011), tensor(6.0358e-05), tensor(-0.0057), tensor(-0.0022), tensor(-0.0025), tensor(-0.0034), tensor(-0.0011), tensor(0.0006), tensor(-0.0060), tensor(0.0025), tensor(0.0003), tensor(-0.0017), tensor(0.0044), tensor(-0.0006), tensor(0.0026), tensor(-0.0013), tensor(-0.0026), tensor(0.0045), tensor(-0.0091), tensor(0.0024), tensor(0.0024), tensor(-0.0063), tensor(0.0002), tensor(0.0006), tensor(-0.0018), tensor(0.0021), tensor(0.0008), tensor(0.0063), tensor(-0.0025), tensor(0.0086), tensor(-0.0018), tensor(0.0040), tensor(-0.0066), tensor(0.0014), tensor(-0.0036), tensor(-0.0012), tensor(0.0064), tensor(-0.0014), tensor(0.0033), tensor(-0.0055), tensor(-0.0042), tensor(-0.0055), tensor(-0.0013), tensor(0.0003), tensor(0.0035), tensor(-0.0002), tensor(0.0044), tensor(0.0032), tensor(0.0004), tensor(0.0043), tensor(-0.0048), tensor(0.0063), tensor(-0.0051), tensor(0.0052), tensor(0.0113), tensor(-0.0013), tensor(-0.0029), tensor(0.0012), tensor(0.0058), tensor(-0.0078), tensor(-0.0037))],\n      dtype='object', length=166958)] are in the [columns]"

### Second model: GCN, standard data, fairness-aware loss (alpha=0.01)

In [27]:
# Instantiate the second model, define loss function and optimizer
gcn_model2 = GCN(data)
gcn_optimizer2 = torch.optim.Adam(gcn_model2.parameters(), lr=0.01)

In [25]:
# Train the second model: GCN, standard data, fairness-aware loss (alpha=0.01)
training(model=gcn_model2, data=data, optimizer=gcn_optimizer2, epochs=50, fairness=True)

Epoch 0 | Loss: 0.6220347881317139
Epoch 10 | Loss: 0.5531500577926636
Epoch 20 | Loss: 0.5012115240097046
Epoch 30 | Loss: 0.4466976225376129
Epoch 40 | Loss: 0.4097009301185608


In [30]:
# Test the second model: GCN, standard data, fairness-aware loss (alpha=0.01)
print("Here are the values for the GCN model with the standard dataset and fairness-entropy loss(alpha=0.01): ")
print()
test(gcn_model2, data)

Here are the values for the GCN model with the standard dataset and fairness-entropy loss(alpha=0.01): 

Accuracy: 0.8378653569717297
Statistical Parity Difference: 0.030072450637817383
Equal Opportunity Difference: 0.06199532747268677
Overall Accuracy Equality Difference: 0.08194565773010254
Treatment Equality Difference: 2.69785213470459


(0.8378653569717297,
 {'Statistical Parity Difference': tensor(0.0301),
  'Equal Opportunity Difference': tensor(0.0620),
  'Overall Accuracy Equality Difference': tensor(0.0819),
  'Treatment Equality Difference': tensor(2.6979)})

### Ignore

Test to check if stronger fairness-constraint produces a better model:

In [None]:
gcn_model3 = GCN(data)
gcn_optimizer3 = torch.optim.Adam(gcn_model3.parameters(), lr=0.01)

training(model=gcn_model3, data=data, optimizer=gcn_optimizer3, epochs=30, fairness=True, alpha=0.05)

print("Here are the values for the GCN model with the standard dataset and fairness-entropy loss(alpha=0.05): ")
print()
test(gcn_model3, data)

In [None]:
gcn_model4 = GCN(data)
gcn_optimizer4 = torch.optim.Adam(gcn_model4.parameters(), lr=0.01)

training(model=gcn_model4, data=data, optimizer=gcn_optimizer4, epochs=30, fairness=True, beta=0.005)
# print("\nHere are the values for the GCN model with the standard dataset and fairness-entropy loss(alpha=0.005): ")
print()
test(gcn_model4, data)
print()

In [None]:
print("\nHere are the values for the GCN model with the standard dataset and fairness-entropy loss(alpha=0.005): ")
print()
test(gcn_model4, data)


Here are the values for the GCN model with the standard dataset and fairness-entropy loss(alpha=0.005): 

Accuracy: 0.847747963584092
Statistical Parity Difference: 0.047549135982990265
Equal Opportunity Difference: 0.053668081760406494
Overall Accuracy Equality Difference: 0.08105164766311646
Treatment Equality Difference: 2.612381935119629


### Third model: GCN, re-weighed data, weighted-cross-entropy loss (alpha=0.01)

In [24]:
# Instantiate the third model, define loss function and optimizer
rw_data_gcn_model = GCN(rw_data)
rw_data_gcn_model_optimizer = torch.optim.Adam(rw_data_gcn_model.parameters(), lr=0.01)

In [25]:
# Train the third model: GCN, re-weighed data, weighted-cross entropy loss
training(model=rw_data_gcn_model, data=rw_data, optimizer=rw_data_gcn_model_optimizer, epochs=50, weighted=True)

Epoch 0 | Loss: 0.6312071681022644
Epoch 10 | Loss: 0.5407678484916687
Epoch 20 | Loss: 0.4680361747741699
Epoch 30 | Loss: 0.4228784739971161
Epoch 40 | Loss: 0.3963108956813812


In [26]:
# Test the third model: GCN, re-weighed data, weighted-cross entropy loss
print("Here are the values for the GCN model with the re-weighed dataset and weighted-cross-entropy loss: ")
print()
test(rw_data_gcn_model, data)
print()

Here are the values for the GCN model with the re-weighed dataset and weighted-cross-entropy loss: 

Accuracy: 0.8576305701964543
Statistical Parity Difference: 0.034996867179870605
Equal Opportunity Difference: 0.1110086739063263
Overall Accuracy Equality Difference: 0.08673566579818726
Treatment Equality Difference: 0.6095795631408691



In [None]:
# Test the third model: GCN, re-weighed data, weighted-cross entropy loss
print("Here are the values for the GCN model with the re-weighed dataset and weighted-cross-entropy loss: ")
print()
test(rw_data_gcn_model, data)
print()

Here are the values for the GCN model with the re-weighed dataset and weighted-cross-entropy loss: 

Accuracy: 0.8468495448011499
Statistical Parity Difference: 0.030316658318042755
Equal Opportunity Difference: 0.10094630718231201
Overall Accuracy Equality Difference: 0.08911752700805664
Treatment Equality Difference: 0.06582164764404297


### Fourth model: GCN, re-weighed data, weighted- and fairness-aware cross entropy loss

In [40]:
# Instantiate the fourth model, define loss function and optimizer
rw_data_gcn_model2 = GCN(rw_data)
rw_data_gcn_model_optimizer2 = torch.optim.Adam(rw_data_gcn_model2.parameters(), lr=0.01)

In [28]:
# Train the fourth model: GCN, re-weighed data, weighted- and fairness-aware cross entropy loss
training(model=rw_data_gcn_model2, data=rw_data, optimizer=rw_data_gcn_model_optimizer2, epochs=50, weighted=True, fairness=True)

Epoch 0 | Loss: 0.5843248963356018
Epoch 10 | Loss: 0.5047746896743774
Epoch 20 | Loss: 0.44967859983444214
Epoch 30 | Loss: 0.4109334349632263
Epoch 40 | Loss: 0.3880102038383484


In [29]:
# Test the fourth model: GCN, re-weighed data, weighted- and fairness-aware cross entropy loss
print("Here are the values for the GCN model with the re-weighed dataset and weighted- and fairness-aware cross-entropy loss: ")
print()
test(rw_data_gcn_model2, data)
print()

Here are the values for the GCN model with the re-weighed dataset and weighted- and fairness-aware cross-entropy loss: 

Accuracy: 0.8555342597029229
Statistical Parity Difference: 0.034720584750175476
Equal Opportunity Difference: 0.1112838089466095
Overall Accuracy Equality Difference: 0.08701056241989136
Treatment Equality Difference: 0.6815862655639648



In [31]:
# Train the fourth model: GCN, re-weighed data, weighted- and fairness-aware cross entropy loss
training(model=rw_data_gcn_model2, data=rw_data, optimizer=rw_data_gcn_model_optimizer2, epochs=50, weighted=True,
         fairness=True, alpha=0.01, beta=0.005, gamma=0.015, delta=0.012)

Epoch 0 | Loss: 0.745595395565033
Epoch 10 | Loss: 0.5941913723945618
Epoch 20 | Loss: 0.5534152984619141
Epoch 30 | Loss: 0.5050016641616821
Epoch 40 | Loss: 0.4620010554790497


In [32]:
# Test the fourth model: GCN, re-weighed data, weighted- and fairness-aware cross entropy loss
print("Here are the values for the GCN model with the re-weighed dataset and weighted- and fairness-aware cross-entropy loss: ")
print()
test(rw_data_gcn_model2, data)
print()

Here are the values for the GCN model with the re-weighed dataset and weighted- and fairness-aware cross-entropy loss: 

Accuracy: 0.8442441303306181
Statistical Parity Difference: 0.02938065677881241
Equal Opportunity Difference: 0.09115374088287354
Overall Accuracy Equality Difference: 0.08772581815719604
Treatment Equality Difference: 1.3014488220214844



In [34]:
# Train the fourth model: GCN, re-weighed data, weighted- and fairness-aware cross entropy loss
training(model=rw_data_gcn_model2, data=rw_data, optimizer=rw_data_gcn_model_optimizer2, epochs=50, weighted=True,
         fairness=True, alpha=0.01, beta=0.012, gamma=0.015, delta=0.015)

Epoch 0 | Loss: 0.7213656306266785
Epoch 10 | Loss: 0.5258520841598511
Epoch 20 | Loss: 0.4605921804904938
Epoch 30 | Loss: 0.424991250038147
Epoch 40 | Loss: 0.40024101734161377


In [35]:
# Test the fourth model: GCN, re-weighed data, weighted- and fairness-aware cross entropy loss
print("Here are the values for the GCN model with the re-weighed dataset and weighted- and fairness-aware cross-entropy loss: ")
print()
test(rw_data_gcn_model2, data)
print()

Here are the values for the GCN model with the re-weighed dataset and weighted- and fairness-aware cross-entropy loss: 

Accuracy: 0.8580797795879253
Statistical Parity Difference: 0.042138680815696716
Equal Opportunity Difference: 0.0997222363948822
Overall Accuracy Equality Difference: 0.0845213532447815
Treatment Equality Difference: 0.3458895683288574



In [None]:
# Train the fourth model: GCN, re-weighed data, weighted- and fairness-aware cross entropy loss
training(model=rw_data_gcn_model2, data=rw_data, optimizer=rw_data_gcn_model_optimizer2, epochs=50, weighted=True,
         fairness=True, alpha=0.01, beta=0.012, gamma=0.02, delta=0.015)

In [39]:
# Test the fourth model: GCN, re-weighed data, weighted- and fairness-aware cross entropy loss
print("Here are the values for the GCN model with the re-weighed dataset and weighted- and fairness-aware cross-entropy loss: ")
print()
test(rw_data_gcn_model2, data)
print()

Here are the values for the GCN model with the re-weighed dataset and weighted- and fairness-aware cross-entropy loss: 

Accuracy: 0.8528689506468615
Statistical Parity Difference: 0.04595881700515747
Equal Opportunity Difference: 0.0724097192287445
Overall Accuracy Equality Difference: 0.08293139934539795
Treatment Equality Difference: 1.1784725189208984



In [41]:
# Train the fourth model: GCN, re-weighed data, weighted- and fairness-aware cross entropy loss
training(model=rw_data_gcn_model2, data=rw_data, optimizer=rw_data_gcn_model_optimizer2, epochs=50, weighted=True,
         fairness=True, alpha=0.01, beta=0.012, gamma=0.018, delta=0.018)

Epoch 0 | Loss: 0.6623578667640686
Epoch 10 | Loss: 0.5442702770233154
Epoch 20 | Loss: 0.47677215933799744
Epoch 30 | Loss: 0.42851555347442627
Epoch 40 | Loss: 0.39934295415878296


In [42]:
# Test the fourth model: GCN, re-weighed data, weighted- and fairness-aware cross entropy loss
print("Here are the values for the GCN model with the re-weighed dataset and weighted- and fairness-aware cross-entropy loss: ")
print()
test(rw_data_gcn_model2, data)
print()

Here are the values for the GCN model with the re-weighed dataset and weighted- and fairness-aware cross-entropy loss: 

Accuracy: 0.8569717297556301
Statistical Parity Difference: 0.03602193295955658
Equal Opportunity Difference: 0.11044105887413025
Overall Accuracy Equality Difference: 0.08629560470581055
Treatment Equality Difference: 0.6496486663818359



---
---
---
---
---
# Extras


### DI Remover

In [None]:
# Convert your data to a format suitable for AIF360
dataset = BinaryLabelDataset(df=user_labels, label_names=['gender'], protected_attribute_names=['bin_age'])

# Apply the Disparate Impact Remover
DIR = DisparateImpactRemover(repair_level=1.0)
dataset_transf = DIR.fit_transform(dataset)

# Extract the transformed features
transformed_features = dataset_transf.features

In [None]:
rw_data.x = torch.tensor(transformed_features, dtype=torch.float)

### DI GCN

In [None]:
class GCN2(torch.nn.Module):
    def __init__(self):
        super(GCN2, self).__init__()
        self.conv1 = GCNConv(rw_data.num_node_features, 16)
        self.conv2 = GCNConv(16, 2)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

In [None]:
# Instantiate the model, define loss function and optimizer
gcn_model = GCN(rw_data)
gcn_optimizer = torch.optim.Adam(gcn_model.parameters(), lr=0.01)

In [None]:
training(model=gcn_model, data=rw_data, optimizer=gcn_optimizer, epochs=100)

Epoch 0 | Loss: 2247.96337890625
Epoch 10 | Loss: 1941.6514892578125
Epoch 20 | Loss: 1676.3948974609375


KeyboardInterrupt: 

In [None]:
test(gcn_model, rw_data)

### Loss

In [None]:
def fairness_aware_loss(output, data, sensitive_attr, weighted=False, alpha=0.01):
    """
    A custom loss function to calculate a fairness-aware loss.
    The fairness-factor measures the disparity in predictions between +ve and -ve sensitive attribute group.

    Args:
    output: Outputs from the model.
    data: The torch-geometric data object used for the model.
    sensitive_attr: The sensitive attribute in the data (in our case: bin_age)
    weighted: Boolean value indicating re-weighing done to the data or not.
    alpha: Parameter to control the strength of the fairness regularizer.

    Returns:
    A fairness-aware combined loss.
    """
    if weighted:
        # Call the weighted-cross entropy loss
        standard_loss = weighted_cross_entropy(output, data)
    else:
        # Call standard cross-entropy loss
        target = data.y[data.train_mask]
        standard_loss = F.cross_entropy(output, target)

    pos_prob = torch.sigmoid(output[:, 1])

    fairness_reg = torch.abs(pos_prob[sensitive_attr == 1].mean() - pos_prob[sensitive_attr == 0].mean())
    combined_loss = standard_loss + alpha * fairness_reg

    return combined_loss