In [1]:
import os
import os.path as osp

from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
from torch_geometric.utils import negative_sampling
from torch_geometric.nn import GCNConv
import numpy as np
import torch
from torch.nn import Sequential, Linear, ReLU
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score, accuracy_score

from utils import (
    get_link_labels,
    prediction_fairness,
)

from torch_geometric.utils import train_test_split_edges

# Default compute device as a plain string; a later cell rebinds `device` to a torch.device.
device = "cuda" if torch.cuda.is_available() else "cpu"

  from .autonotebook import tqdm as notebook_tqdm


In [49]:
from torch_geometric import seed_everything

# Fix the global seed once up front (PyG's seed_everything covers Python, NumPy and PyTorch RNGs).
seed_everything(188)

In [50]:
from torch_geometric.utils import to_undirected, to_networkx, k_hop_subgraph, is_undirected
from torch_geometric.data import Data
from torch_geometric.loader import GraphSAINTRandomWalkSampler
from torch_geometric.seed import seed_everything

In [51]:
from model.gcn import GCN

In [73]:
import numpy as np
import torch
from sklearn.multioutput import MultiOutputClassifier
from torch_sparse import SparseTensor
from sklearn.metrics import (
    roc_auc_score,
    make_scorer,
    balanced_accuracy_score,
)
from sklearn.neural_network import MLPClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn import model_selection, pipeline, metrics

# Metrics
from fairlearn.metrics import (
    demographic_parity_difference,
    equalized_odds_difference,
)
from itertools import combinations_with_replacement

In [74]:
# Rebind `device` as a torch.device object (the import cell defined it as a plain string).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [75]:
from training_args import parse_args
# Project-level hyperparameters; later cells read/write args.in_dim and args.checkpoint_dir.
# NOTE(review): parse_args is project-local — if it calls argparse without an explicit
# argument list it will read the kernel's sys.argv; confirm it is notebook-safe.
args=parse_args()

In [76]:
# Which Planetoid graph to load; "cora" and "pubmed" are the other options.
dataset_name = "citeseer"
# '__file__' is a literal string here (a notebook has no __file__ variable), so this
# resolves relative to the current working directory: <cwd>/../data/<dataset_name>.
notebook_dir = osp.dirname(osp.realpath('__file__'))
path = osp.join(notebook_dir, "..", "data", dataset_name)
# Node features are normalized by the transform every time a graph is accessed.
dataset = Planetoid(path, dataset_name, transform=T.NormalizeFeatures())

In [77]:
pwd

'/home/jupyter/FairDrop'

In [78]:
path

'/home/jupyter/FairDrop/../data/citeseer'

In [79]:
# Seeds for the repeated runs, plus the accumulators each run appends to:
# acc_auc collects [accuracy, AUC] rows, fairness collects the fairness-metric rows.
test_seeds = list(range(6))
acc_auc, fairness = [], []

In [80]:
# Take the first graph of the dataset as the working Data object.
data = dataset[0]

In [81]:
data

Data(x=[3327, 3703], edge_index=[2, 9104], y=[3327], train_mask=[3327], val_mask=[3327], test_mask=[3327])

In [82]:
data

Data(x=[3327, 3703], edge_index=[2, 9104], y=[3327], train_mask=[3327], val_mask=[3327], test_mask=[3327])

In [83]:
# Record the node-feature dimensionality on args (presumably consumed by GCN(args) — confirm in model/gcn.py).
args.in_dim=data.num_features

In [84]:
# Offset added to 0.5 in the per-edge randomization threshold (`uniform_() < 0.5 + delta`)
# used by the multi-seed training loop further down.
delta = 0.16

In [85]:
# The node labels serve as the protected attribute; keep them before they are
# stripped from `data` below.
protected_attribute = data.y
Y = torch.LongTensor(protected_attribute).to(device)

# Drop the node-classification fields, then split the edges into
# train / validation (10%) / test (20%) sets for link prediction.
for node_field in ("train_mask", "val_mask", "test_mask", "y"):
    setattr(data, node_field, None)
data = train_test_split_edges(data, val_ratio=0.1, test_ratio=0.2).to(device)

num_classes = np.unique(protected_attribute).shape[0]
N = data.num_nodes

In [86]:
# All-True boolean mask with one entry per training edge (the next cell rebuilds it
# after symmetrizing the training edges).
data.dtrain_mask = torch.ones(data.train_pos_edge_index.shape[1], dtype=torch.bool)

In [87]:
# Symmetrize the training edges, store them back on `data`, and rebuild the
# all-True training mask so its length matches the (possibly changed) edge count.
train_pos_edge_index = to_undirected(data.train_pos_edge_index)
data.train_pos_edge_index = train_pos_edge_index
num_train_edges = train_pos_edge_index.shape[1]
data.dtrain_mask = torch.ones(num_train_edges, dtype=torch.bool)
assert is_undirected(data.train_pos_edge_index)

In [88]:
# Build the project GCN from the parsed arguments and place it on the compute device.
model = GCN(args).to(device)

In [89]:
# Adam over all model parameters; the two trackers hold the best validation AUC
# seen so far and the test AUC measured at that same epoch.
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
best_val_perf = 0
test_perf = 0

In [90]:

# Link-prediction training loop.
# Each epoch: sample negatives, take one optimizer step on the BCE link loss, then
# score the validation and test edges. `test_perf` always holds the test AUC of the
# epoch with the best validation AUC so far (model selection on validation only).
for epoch in range(1, 1000):
    # TRAINING
    # Fresh negatives every epoch; half as many as there are columns in train_pos_edge_index.
    neg_edges_tr = negative_sampling(
        edge_index=data.train_pos_edge_index,
        num_nodes=N,
        num_neg_samples=data.train_pos_edge_index.size(1) // 2,
    ).to(device)

    model.train()
    optimizer.zero_grad()

    z = model(data.x, data.train_pos_edge_index)
    link_logits = model.decode(z, data.train_pos_edge_index, neg_edges_tr)
    tr_labels = get_link_labels(data.train_pos_edge_index, neg_edges_tr).to(device)

    loss = F.binary_cross_entropy_with_logits(link_logits, tr_labels)
    loss.backward()
    optimizer.step()

    # EVALUATION
    model.eval()
    # The node embeddings depend only on the training graph, not on the split being
    # scored, so encode once per epoch instead of once per split.
    with torch.no_grad():
        z = model(data.x, data.train_pos_edge_index)

    perfs = []
    for prefix in ["val", "test"]:
        pos_edge_index = data[f"{prefix}_pos_edge_index"]
        neg_edge_index = data[f"{prefix}_neg_edge_index"]
        with torch.no_grad():
            link_logits = model.decode(z, pos_edge_index, neg_edge_index)
        link_probs = link_logits.sigmoid()
        link_labels = get_link_labels(pos_edge_index, neg_edge_index)
        auc = roc_auc_score(link_labels.cpu(), link_probs.cpu())
        perfs.append(auc)

    val_perf, tmp_test_perf = perfs
    if val_perf > best_val_perf:
        best_val_perf = val_perf
        test_perf = tmp_test_perf
    if epoch % 100 == 0:
        log = "Epoch: {:03d}, Loss: {:.4f}, Val: {:.4f}, Test: {:.4f}"
        # .item() turns the 0-d loss tensor into a plain float for logging.
        print(log.format(epoch, loss.item(), best_val_perf, test_perf))

Epoch: 100, Loss: 0.3978, Val: 0.7563, Test: 0.7252
Epoch: 200, Loss: 0.3216, Val: 0.8392, Test: 0.8387
Epoch: 300, Loss: 0.3022, Val: 0.8459, Test: 0.8523
Epoch: 400, Loss: 0.2977, Val: 0.8543, Test: 0.8553
Epoch: 500, Loss: 0.2912, Val: 0.8552, Test: 0.8577
Epoch: 600, Loss: 0.2839, Val: 0.8612, Test: 0.8590
Epoch: 700, Loss: 0.2742, Val: 0.8668, Test: 0.8593


KeyboardInterrupt: 

In [None]:
# Snapshot of everything needed to resume: model weights and optimizer state.
ckpt = dict(
    model_state=model.state_dict(),
    optimizer_state=optimizer.state_dict(),
)


In [None]:
# torch.save does not create missing parent directories, so make sure the
# checkpoint directory exists before writing the final checkpoint into it.
os.makedirs(args.checkpoint_dir, exist_ok=True)
torch.save(ckpt, os.path.join(args.checkpoint_dir, 'model_final.pt'))

In [None]:
epoch

In [None]:
data

In [None]:
# Score the held-out test edges with the current model weights.
prefix = "test"
pos_edge_index = data[f"{prefix}_pos_edge_index"]
neg_edge_index = data[f"{prefix}_neg_edge_index"]
with torch.no_grad():
    z = model(data.x, data.train_pos_edge_index)
    link_logits = model.decode(z, pos_edge_index, neg_edge_index)
link_probs = link_logits.sigmoid()
link_labels = get_link_labels(pos_edge_index, neg_edge_index)

# Positive edges followed by negative edges, column-wise
# (presumably the same ordering get_link_labels uses — confirm in utils).
edge_idx = torch.cat([pos_edge_index, neg_edge_index], dim=-1)

# AUC reported is the test AUC at the best-validation epoch of the training loop.
auc = test_perf

# Sweep a few decision thresholds and keep the first one with the highest accuracy.
# Note that the threshold is picked on the test labels themselves.
labels_cpu = link_labels.cpu()
probs_cpu = link_probs.cpu()
cut = [0.45, 0.50, 0.55, 0.60, 0.65, 0.70, 0.75]
best_acc, best_cut = 0, 0.5
for threshold in cut:
    threshold_acc = accuracy_score(labels_cpu, probs_cpu >= threshold)
    if threshold_acc > best_acc:
        best_acc, best_cut = threshold_acc, threshold

# Fairness of the thresholded predictions w.r.t. the protected attribute Y.
f = prediction_fairness(edge_idx.cpu(), labels_cpu, probs_cpu >= best_cut, Y.cpu())

# Both result rows are stored as percentages.
acc_auc.append([best_acc * 100, auc * 100])
fairness.append([metric * 100 for metric in f])

In [None]:
acc_auc

In [None]:
fairness

In [None]:
np.unique(Y.cpu())

In [None]:
def fair_metrics(gt, y, group):
    """Compute fairlearn's two group-fairness gaps for one set of predictions.

    Args:
        gt: Ground-truth labels (passed to fairlearn as ``y_true``).
        y: Predicted labels (passed to fairlearn as ``y_pred``).
        group: Sensitive feature value for each sample.

    Returns:
        dict with two entries: ``"DPd"`` (demographic parity difference) and
        ``"EOd"`` (equalized odds difference).
    """
    dp_gap = demographic_parity_difference(gt, y, sensitive_features=group)
    eo_gap = equalized_odds_difference(gt, y, sensitive_features=group)
    return {"DPd": dp_gap, "EOd": eo_gap}


In [30]:
# CPU copies of everything the fairness computation reads: test edge labels and
# endpoints, the per-node protected attribute, and the thresholded predictions.
test_edge_labels = link_labels.cpu()
test_edge_idx = edge_idx.cpu()
group = Y.cpu()
te_y = link_probs.cpu() >= best_cut

In [31]:
# Protected-attribute value at the source and destination node of every test edge.
src_nodes, dst_nodes = test_edge_idx[0], test_edge_idx[1]
te_dyadic_src, te_dyadic_dst = group[src_nodes], group[dst_nodes]

# SUBGROUP DYADIC
# Every unordered pair of attribute values (a, b) with a <= b, same-value pairs included.
u = [*combinations_with_replacement(np.unique(group), r=2)]

In [32]:
u

[(0, 0),
 (0, 1),
 (0, 2),
 (0, 3),
 (0, 4),
 (0, 5),
 (1, 1),
 (1, 2),
 (1, 3),
 (1, 4),
 (1, 5),
 (2, 2),
 (2, 3),
 (2, 4),
 (2, 5),
 (3, 3),
 (3, 4),
 (3, 5),
 (4, 4),
 (4, 5),
 (5, 5)]

In [33]:
# Index of each test edge's unordered (src group, dst group) pair within `u`.
# `u` lists pairs in sorted (a <= b) form, so ordering the two endpoint values gives a
# direct dictionary key; this replaces a scan over all of `u` for every edge.
pair_to_subgroup = {(int(a), int(b)): k for k, (a, b) in enumerate(u)}
te_sub_diatic = np.asarray(
    [
        pair_to_subgroup[(min(s, d), max(s, d))]
        for s, d in zip(te_dyadic_src.tolist(), te_dyadic_dst.tolist())
    ]
)

In [36]:
# True for test edges whose two endpoints have different protected-attribute values.
te_mixed_dyadic = te_dyadic_src != te_dyadic_dst

fairlearn.metrics.demographic_parity_difference(y_true, y_pred, *, sensitive_features, method='between_groups', sample_weight=None)
Calculate the demographic parity difference.

The demographic parity difference is defined as the difference between the largest and the smallest group-level selection rate, $E[h(X) \mid A = a]$, across all values $a$ of the sensitive feature(s). The demographic parity difference of 0 means that all groups have the same selection rate.

Read more in the User Guide.

Parameters:
y_true (array-like) – Ground truth (correct) labels.

y_pred (array-like) – Predicted labels $h(X)$ returned by the classifier.

sensitive_features – The sensitive features over which demographic parity should be assessed

method (str) – How to compute the differences. See fairlearn.metrics.MetricFrame.difference() for details.

sample_weight (array-like) – The sample weights

Returns:
The demographic parity difference


In [None]:
# Move the fairness inputs to the CPU: test edge labels and endpoints, the per-node
# protected attribute, and the predictions thresholded at the selected cut-off.
test_edge_labels = link_labels.cpu()
test_edge_idx = edge_idx.cpu()
group = Y.cpu()
te_y = link_probs.cpu() >= best_cut

In [46]:
# Protected-attribute value at the source / destination node of every test edge.
te_dyadic_src = group[test_edge_idx[0]]
te_dyadic_dst = group[test_edge_idx[1]]

# SUBGROUP DYADIC
# One sensitive group per unordered pair of attribute values. `u` lists the pairs in
# sorted (a <= b) form, so ordering the two endpoint values yields a direct dictionary
# key; this replaces a scan over all of `u` for every edge.
u = list(combinations_with_replacement(np.unique(group), r=2))
pair_to_subgroup = {(int(a), int(b)): k for k, (a, b) in enumerate(u)}
te_sub_diatic = np.asarray(
    [
        pair_to_subgroup[(min(s, d), max(s, d))]
        for s, d in zip(te_dyadic_src.tolist(), te_dyadic_dst.tolist())
    ]
)

# MIXED DYADIC
# Two sensitive groups: edges joining different attribute values vs. the same value.
te_mixed_dyadic = te_dyadic_src != te_dyadic_dst

# GROUP DYADIC
# Every edge is counted twice, once under its source's group and once under its
# destination's group, so labels and predictions are duplicated to match.
te_gd_dict = fair_metrics(
    np.concatenate([test_edge_labels, test_edge_labels], axis=0),
    np.concatenate([te_y, te_y], axis=0),
    np.concatenate([te_dyadic_src, te_dyadic_dst], axis=0),
)

te_md_dict = fair_metrics(test_edge_labels, te_y, te_mixed_dyadic)

te_sd_dict = fair_metrics(test_edge_labels, te_y, te_sub_diatic)

# Order: mixed (DPd, EOd), group (DPd, EOd), subgroup (DPd, EOd).
fair_list = [
    te_md_dict["DPd"],
    te_md_dict["EOd"],
    te_gd_dict["DPd"],
    te_gd_dict["EOd"],
    te_sd_dict["DPd"],
    te_sd_dict["EOd"],
]

In [47]:
fair_list

[0.42743364127490735,
 0.30759522199684475,
 0.18659453511828356,
 0.20884312910722247,
 0.6431478191093025,
 0.6010928961748634]

In [14]:
epochs = 101
# NOTE(review): this cell builds GCN(num_features, 128) while an earlier cell in this
# notebook builds GCN(args) — confirm which constructor model.gcn.GCN actually accepts.
model = GCN(data.num_features, 128).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)

# Protected attribute per node, then per training edge: the value at each endpoint.
Y = torch.LongTensor(protected_attribute).to(device)
src_label = Y[data.train_pos_edge_index[0, :]]
dst_label = Y[data.train_pos_edge_index[1, :]]

# Boolean mask over training edges: endpoints carry different attribute values ...
Y_diff = (src_label != dst_label).to(device)

# ... and its counterpart: endpoints carry the same attribute value.
Y_same = (src_label == dst_label).to(device)


In [15]:
Y

tensor([3, 1, 5,  ..., 3, 1, 5], device='cuda:0')

In [16]:
torch.sum(Y_diff)

tensor(1706, device='cuda:0')

In [17]:
torch.sum(Y_same)

tensor(4668, device='cuda:0')

In [18]:
# Positions of the training edges whose endpoints differ in the protected attribute.
# Squeeze only the trailing size-1 column: a bare .squeeze() would also collapse a
# single match into a 0-d tensor, on which diff.shape[0] fails.
diff = Y_diff.nonzero().squeeze(1)

In [19]:
# Positions of the training edges whose endpoints share the protected attribute.
# Squeeze only the trailing size-1 column: a bare .squeeze() would also collapse a
# single match into a 0-d tensor, on which same.shape[0] fails.
same = Y_same.nonzero().squeeze(1)
same

tensor([   0,    1,    2,  ..., 6369, 6370, 6373], device='cuda:0')

In [20]:
# Randomly pick `edge_to_delete` training edges, split `ratio`:1 between edges whose
# endpoints differ in the protected attribute and edges whose endpoints agree.
edge_to_delete = 100
ratio = 3
diff_size = int(edge_to_delete * ratio / (ratio + 1))
# Give the remainder to the same-attribute bucket so the two sizes always add up to
# edge_to_delete (truncating both shares independently can lose an edge, e.g. ratio=2).
same_size = edge_to_delete - diff_size

# Sample without replacement from each bucket; slicing silently yields fewer indices
# if a bucket holds fewer edges than requested.
idx_diff = torch.randperm(diff.shape[0])[:diff_size]
df_diff_idx = diff[idx_diff]
idx_same = torch.randperm(same.shape[0])[:same_size]
df_same_idx = same[idx_same]

# Indices (into the training edge list) of all edges selected for deletion.
df_global_idx = torch.cat((df_diff_idx, df_same_idx), 0)

In [21]:
df_same_idx

tensor([ 335, 1389, 1387, 2277, 2125, 2856, 2965, 3766, 1033, 2632,  620, 4441,
         861, 1951, 3376, 5056, 1842,  375, 5372, 5927,  922, 4951, 2834, 1451,
        4695], device='cuda:0')

In [22]:
# Boolean masks over the training edges:
#   df_mask - True for the edges selected for deletion (df_global_idx)
#   dr_mask - True for every remaining edge (exact complement of df_mask)
# Allocate on the index tensor's own device so the indexed assignment never mixes
# devices, and derive dr_mask by negation instead of scattering twice.
num_train_edges = data.train_pos_edge_index.shape[1]
df_mask = torch.zeros(num_train_edges, dtype=torch.bool, device=df_global_idx.device)
df_mask[df_global_idx] = True
df_mask = df_mask.to(device)

dr_mask = ~df_mask

In [23]:
# Edges in S_Df
_, two_hop_edge, _, two_hop_mask = k_hop_subgraph(
        data.train_pos_edge_index[:, df_mask].flatten().unique(), 
        2, 
        data.train_pos_edge_index,
        num_nodes=data.num_nodes)
data.sdf_mask = two_hop_mask

In [24]:
# Nodes in S_Df
_, one_hop_edge, _, one_hop_mask = k_hop_subgraph(
    data.train_pos_edge_index[:, df_mask].flatten().unique(), 
    1, 
    data.train_pos_edge_index,
    num_nodes=data.num_nodes)

In [25]:
# Per-node boolean masks, initially all False, for membership in the 1-hop and 2-hop
# subgraphs; the following cell sets the member nodes to True.
sdf_node_1hop = torch.zeros(data.num_nodes, dtype=torch.bool)
sdf_node_2hop = torch.zeros(data.num_nodes, dtype=torch.bool)

In [26]:
# Mark every node that appears as an endpoint of a 1-hop / 2-hop subgraph edge.
one_hop_nodes = one_hop_edge.flatten().unique()
two_hop_nodes = two_hop_edge.flatten().unique()
sdf_node_1hop[one_hop_nodes] = True
sdf_node_2hop[two_hop_nodes] = True

# Sanity check: the number of marked nodes equals the number of distinct endpoints.
assert sdf_node_1hop.sum() == len(one_hop_nodes)
assert sdf_node_2hop.sum() == len(two_hop_nodes)

In [27]:
# Attach the node-level 1-hop / 2-hop masks to the Data object.
data.sdf_node_1hop_mask = sdf_node_1hop
data.sdf_node_2hop_mask = sdf_node_2hop

In [28]:
data.train_pos_edge_index

tensor([[   0,    1,    1,  ..., 3324, 3325, 3326],
        [ 628,  158,  486,  ..., 2820, 1643,   33]], device='cuda:0')

In [29]:
df_mask

tensor([False, False, False,  ..., False, False, False], device='cuda:0')

In [30]:
two_hop_mask.int()

tensor([0, 0, 1,  ..., 1, 0, 0], device='cuda:0', dtype=torch.int32)

In [31]:
# Symmetrize the training edges while carrying both edge masks along as integer edge
# attributes; after the merge, any nonzero attribute value maps back to True.
symmetrized = to_undirected(
    data.train_pos_edge_index,
    [df_mask.int(), two_hop_mask.int()],
)
train_pos_edge_index = symmetrized[0]
df_mask, two_hop_mask = (edge_attr.bool() for edge_attr in symmetrized[1])
# Retained edges are exactly those not marked for deletion.
dr_mask = ~df_mask

# The symmetrized edges become both the training edges and the graph's edge_index.
data.train_pos_edge_index = data.edge_index = train_pos_edge_index
assert is_undirected(data.train_pos_edge_index)

In [32]:
# Store the edge-level masks on the Data object: sdf_mask marks edges of the 2-hop
# subgraph, df_mask the edges selected for deletion, dr_mask the complement of df_mask.
data.sdf_mask = two_hop_mask
data.df_mask = df_mask
data.dr_mask = dr_mask

In [8]:
# Multi-seed experiment: for every seed, re-split the edges, train a link predictor
# with biased edge dropping, then record accuracy / AUC and the fairness metrics.
# delta shifts the per-edge randomization threshold away from 0.5 (see `randomization`).
delta = 0.16

for random_seed in test_seeds:

    # Seed Python, NumPy and PyTorch together. The edge split, the randomization mask
    # and the model initialization all draw from torch's RNG, which np.random.seed
    # alone does not touch, so a NumPy-only seed would not make the runs repeatable.
    seed_everything(random_seed)
    data = dataset[0]
    protected_attribute = data.y
    data.train_mask = data.val_mask = data.test_mask = data.y = None
    data = train_test_split_edges(data, val_ratio=0.1, test_ratio=0.2)
    data = data.to(device)

    num_classes = len(np.unique(protected_attribute))
    N = data.num_nodes

    epochs = 101
    # NOTE(review): this loop uses GCN(num_features, 128), model.encode(...) and a
    # decode(...) that returns (logits, edge_idx); an earlier cell in this notebook
    # uses GCN(args), model(...) and a decode(...) returning logits only — confirm
    # which interface model.gcn.GCN provides.
    model = GCN(data.num_features, 128).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.005)

    Y = torch.LongTensor(protected_attribute).to(device)
    # True for training edges whose endpoints differ in the protected attribute.
    Y_aux = (
        Y[data.train_pos_edge_index[0, :]] != Y[data.train_pos_edge_index[1, :]]
    ).to(device)
    # One row of per-edge coin flips per epoch, each True with probability 0.5 + delta.
    randomization = (
        torch.FloatTensor(epochs, Y_aux.size(0)).uniform_() < 0.5 + delta
    ).to(device)

    best_val_perf = test_perf = 0
    for epoch in range(1, epochs):
        # TRAINING
        neg_edges_tr = negative_sampling(
            edge_index=data.train_pos_edge_index,
            num_nodes=N,
            num_neg_samples=data.train_pos_edge_index.size(1) // 2,
        ).to(device)

        # Refresh the kept-edge mask on the first epoch and then every 10th epoch.
        # Where the coin flip is True an edge is kept iff its endpoints differ;
        # where it is False, iff they match.
        if epoch == 1 or epoch % 10 == 0:
            keep = torch.where(randomization[epoch], Y_aux, ~Y_aux)

        model.train()
        optimizer.zero_grad()

        z = model.encode(data.x, data.train_pos_edge_index[:, keep])
        link_logits, _ = model.decode(
            z, data.train_pos_edge_index[:, keep], neg_edges_tr
        )
        tr_labels = get_link_labels(
            data.train_pos_edge_index[:, keep], neg_edges_tr
        ).to(device)

        loss = F.binary_cross_entropy_with_logits(link_logits, tr_labels)
        loss.backward()
        optimizer.step()

        # EVALUATION
        model.eval()
        # Embeddings come from the full training graph and are the same for both
        # splits, so encode once per epoch instead of once per split.
        with torch.no_grad():
            z = model.encode(data.x, data.train_pos_edge_index)

        perfs = []
        for prefix in ["val", "test"]:
            pos_edge_index = data[f"{prefix}_pos_edge_index"]
            neg_edge_index = data[f"{prefix}_neg_edge_index"]
            with torch.no_grad():
                link_logits, edge_idx = model.decode(z, pos_edge_index, neg_edge_index)
            link_probs = link_logits.sigmoid()
            link_labels = get_link_labels(pos_edge_index, neg_edge_index)
            auc = roc_auc_score(link_labels.cpu(), link_probs.cpu())
            perfs.append(auc)

        val_perf, tmp_test_perf = perfs
        if val_perf > best_val_perf:
            best_val_perf = val_perf
            test_perf = tmp_test_perf
        if epoch % 10 == 0:
            log = "Epoch: {:03d}, Loss: {:.4f}, Val: {:.4f}, Test: {:.4f}"
            print(log.format(epoch, loss.item(), best_val_perf, test_perf))

    # FAIRNESS
    # NOTE: link_probs / link_labels / edge_idx are the test-split values from the
    # LAST epoch, while `auc` is the test AUC at the best-validation epoch.
    auc = test_perf
    cut = [0.45, 0.50, 0.55, 0.60, 0.65, 0.70, 0.75]
    best_acc = 0
    best_cut = 0.5
    # Keep the first threshold that reaches the highest test accuracy.
    for i in cut:
        acc = accuracy_score(link_labels.cpu(), link_probs.cpu() >= i)
        if acc > best_acc:
            best_acc = acc
            best_cut = i
    f = prediction_fairness(
        edge_idx.cpu(), link_labels.cpu(), link_probs.cpu() >= best_cut, Y.cpu()
    )
    # Both result rows are stored as percentages.
    acc_auc.append([best_acc * 100, auc * 100])
    fairness.append([x * 100 for x in f])



Epoch: 010, Loss: 0.6316, Val: 0.8095, Test: 0.8118


KeyboardInterrupt: 

In [None]:
# Aggregate the per-seed results: column-wise mean and standard deviation.
acc_auc_runs = np.asarray(acc_auc)
fairness_runs = np.asarray(fairness)

ma = np.mean(acc_auc_runs, axis=0)
mf = np.mean(fairness_runs, axis=0)

sa = np.std(acc_auc_runs, axis=0)
sf = np.std(fairness_runs, axis=0)

# ".2f" = two decimal places. The previous "2f" spec only set a minimum field width
# of 2 and therefore printed the default six decimals.
print(f"ACC: {ma[0]:.2f} +- {sa[0]:.2f}")
print(f"AUC: {ma[1]:.2f} +- {sa[1]:.2f}")

# Column order of the fairness rows, as appended by the evaluation code.
fairness_labels = ["DP mix", "EoP mix", "DP group", "EoP group", "DP sub", "EoP sub"]
for col, label in enumerate(fairness_labels):
    print(f"{label}: {mf[col]:.2f} +- {sf[col]:.2f}")