In [2]:
import os
import os.path as osp

from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
from torch_geometric.utils import negative_sampling
from torch_geometric.nn import GCNConv
import numpy as np
import torch
from torch.nn import Sequential, Linear, ReLU
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score, accuracy_score

from utils import (
    get_link_labels,
    prediction_fairness,
)

from torch_geometric.utils import train_test_split_edges

device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
class GCN(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(in_channels, 128)
        self.conv2 = GCNConv(128, out_channels)

    def encode(self, x, pos_edge_index):
        x = F.relu(self.conv1(x, pos_edge_index))
        x = self.conv2(x, pos_edge_index)
        return x

    def decode(self, z, pos_edge_index, neg_edge_index):
        edge_index = torch.cat([pos_edge_index, neg_edge_index], dim=-1)
        logits = (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1)
        return logits, edge_index

In [7]:
dataset = "citeseer" #"cora" "pubmed"
#dataset = "pubmed"
path = osp.join(osp.dirname(osp.realpath('__file__')), "..", "data", dataset)
dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures())

In [8]:
test_seeds = [0,1,2,3,4,5]
acc_auc = []
fairness = []

In [9]:
delta = 0.16

for random_seed in test_seeds:

    np.random.seed(random_seed)
    data = dataset[0]
    protected_attribute = data.y
    data.train_mask = data.val_mask = data.test_mask = data.y = None
    data = train_test_split_edges(data, val_ratio=0.1, test_ratio=0.2)
    data = data.to(device)

    num_classes = len(np.unique(protected_attribute))
    N = data.num_nodes
    
    
    epochs = 101
    model = GCN(data.num_features, 128).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
    

    Y = torch.LongTensor(protected_attribute).to(device)
    Y_aux = (
        Y[data.train_pos_edge_index[0, :]] != Y[data.train_pos_edge_index[1, :]]
    ).to(device)
    randomization = (
        torch.FloatTensor(epochs, Y_aux.size(0)).uniform_() < 0.5 + delta
    ).to(device)
    
    
    best_val_perf = test_perf = 0
    for epoch in range(1, epochs):
        # TRAINING    
        neg_edges_tr = negative_sampling(
            edge_index=data.train_pos_edge_index,
            num_nodes=N,
            num_neg_samples=data.train_pos_edge_index.size(1) // 2,
        ).to(device)

        if epoch == 1 or epoch % 10 == 0:
            keep = torch.where(randomization[epoch], Y_aux, ~Y_aux)

        model.train()
        optimizer.zero_grad()

        z = model.encode(data.x, data.train_pos_edge_index[:, keep])
        link_logits, _ = model.decode(
            z, data.train_pos_edge_index[:, keep], neg_edges_tr
        )
        tr_labels = get_link_labels(
            data.train_pos_edge_index[:, keep], neg_edges_tr
        ).to(device)
        
        loss = F.binary_cross_entropy_with_logits(link_logits, tr_labels)
        loss.backward()
        optimizer.step()

        # EVALUATION
        model.eval()
        perfs = []
        for prefix in ["val", "test"]:
            pos_edge_index = data[f"{prefix}_pos_edge_index"]
            neg_edge_index = data[f"{prefix}_neg_edge_index"]
            with torch.no_grad():
                z = model.encode(data.x, data.train_pos_edge_index)
                link_logits, edge_idx = model.decode(z, pos_edge_index, neg_edge_index)
            link_probs = link_logits.sigmoid()
            link_labels = get_link_labels(pos_edge_index, neg_edge_index)
            auc = roc_auc_score(link_labels.cpu(), link_probs.cpu())
            perfs.append(auc)

        val_perf, tmp_test_perf = perfs
        if val_perf > best_val_perf:
            best_val_perf = val_perf
            test_perf = tmp_test_perf
        if epoch%10==0:
            log = "Epoch: {:03d}, Loss: {:.4f}, Val: {:.4f}, Test: {:.4f}"
            print(log.format(epoch, loss, best_val_perf, test_perf))

    # FAIRNESS
    auc = test_perf
    cut = [0.45, 0.50, 0.55, 0.60, 0.65, 0.70, 0.75]
    best_acc = 0
    best_cut = 0.5
    for i in cut:
        acc = accuracy_score(link_labels.cpu(), link_probs.cpu() >= i)
        if acc > best_acc:
            best_acc = acc
            best_cut = i
    f = prediction_fairness(
        edge_idx.cpu(), link_labels.cpu(), link_probs.cpu() >= best_cut, Y.cpu()
    )
    acc_auc.append([best_acc * 100, auc * 100])
    fairness.append([x * 100 for x in f])



Epoch: 010, Loss: 0.6321, Val: 0.8013, Test: 0.8013
Epoch: 020, Loss: 0.5780, Val: 0.8013, Test: 0.8013
Epoch: 030, Loss: 0.5361, Val: 0.8313, Test: 0.8288
Epoch: 040, Loss: 0.5245, Val: 0.8635, Test: 0.8454
Epoch: 050, Loss: 0.5166, Val: 0.8635, Test: 0.8454
Epoch: 060, Loss: 0.5044, Val: 0.8692, Test: 0.8665
Epoch: 070, Loss: 0.4885, Val: 0.8722, Test: 0.8743
Epoch: 080, Loss: 0.4775, Val: 0.8830, Test: 0.8823
Epoch: 090, Loss: 0.4743, Val: 0.8961, Test: 0.8892
Epoch: 100, Loss: 0.4770, Val: 0.8977, Test: 0.8894




Epoch: 010, Loss: 0.6360, Val: 0.8405, Test: 0.8435
Epoch: 020, Loss: 0.5798, Val: 0.8405, Test: 0.8435
Epoch: 030, Loss: 0.5487, Val: 0.8405, Test: 0.8435
Epoch: 040, Loss: 0.5216, Val: 0.8767, Test: 0.8415
Epoch: 050, Loss: 0.5001, Val: 0.8920, Test: 0.8669
Epoch: 060, Loss: 0.4961, Val: 0.8938, Test: 0.8814
Epoch: 070, Loss: 0.4874, Val: 0.9042, Test: 0.8888
Epoch: 080, Loss: 0.4693, Val: 0.9071, Test: 0.8890
Epoch: 090, Loss: 0.4722, Val: 0.9071, Test: 0.8890
Epoch: 100, Loss: 0.4762, Val: 0.9071, Test: 0.8890




Epoch: 010, Loss: 0.6303, Val: 0.7701, Test: 0.8078
Epoch: 020, Loss: 0.5771, Val: 0.7701, Test: 0.8078
Epoch: 030, Loss: 0.5364, Val: 0.8251, Test: 0.8534
Epoch: 040, Loss: 0.5270, Val: 0.8572, Test: 0.8720
Epoch: 050, Loss: 0.4997, Val: 0.8629, Test: 0.8790
Epoch: 060, Loss: 0.4947, Val: 0.8641, Test: 0.8806
Epoch: 070, Loss: 0.4953, Val: 0.8648, Test: 0.8825
Epoch: 080, Loss: 0.4920, Val: 0.8648, Test: 0.8825
Epoch: 090, Loss: 0.4816, Val: 0.8717, Test: 0.8878
Epoch: 100, Loss: 0.4821, Val: 0.8803, Test: 0.8959




Epoch: 010, Loss: 0.6283, Val: 0.7644, Test: 0.7867
Epoch: 020, Loss: 0.5703, Val: 0.7751, Test: 0.8004
Epoch: 030, Loss: 0.5423, Val: 0.7925, Test: 0.8303
Epoch: 040, Loss: 0.5162, Val: 0.8275, Test: 0.8581
Epoch: 050, Loss: 0.5157, Val: 0.8281, Test: 0.8573
Epoch: 060, Loss: 0.4946, Val: 0.8417, Test: 0.8755
Epoch: 070, Loss: 0.4928, Val: 0.8537, Test: 0.8862
Epoch: 080, Loss: 0.4839, Val: 0.8556, Test: 0.8881
Epoch: 090, Loss: 0.4876, Val: 0.8563, Test: 0.8865
Epoch: 100, Loss: 0.4734, Val: 0.8563, Test: 0.8865




Epoch: 010, Loss: 0.6353, Val: 0.8114, Test: 0.8214
Epoch: 020, Loss: 0.5762, Val: 0.8114, Test: 0.8214
Epoch: 030, Loss: 0.5319, Val: 0.8222, Test: 0.8503
Epoch: 040, Loss: 0.5264, Val: 0.8242, Test: 0.8551
Epoch: 050, Loss: 0.5078, Val: 0.8305, Test: 0.8532
Epoch: 060, Loss: 0.5015, Val: 0.8321, Test: 0.8493
Epoch: 070, Loss: 0.4934, Val: 0.8451, Test: 0.8627
Epoch: 080, Loss: 0.4834, Val: 0.8683, Test: 0.8863
Epoch: 090, Loss: 0.4762, Val: 0.8756, Test: 0.8960
Epoch: 100, Loss: 0.4772, Val: 0.8771, Test: 0.8967




Epoch: 010, Loss: 0.6300, Val: 0.8302, Test: 0.8228
Epoch: 020, Loss: 0.5519, Val: 0.8302, Test: 0.8228
Epoch: 030, Loss: 0.5149, Val: 0.8564, Test: 0.8443
Epoch: 040, Loss: 0.5169, Val: 0.8601, Test: 0.8435
Epoch: 050, Loss: 0.5063, Val: 0.8601, Test: 0.8435
Epoch: 060, Loss: 0.4945, Val: 0.8694, Test: 0.8419
Epoch: 070, Loss: 0.5070, Val: 0.8858, Test: 0.8538
Epoch: 080, Loss: 0.4836, Val: 0.8981, Test: 0.8615
Epoch: 090, Loss: 0.4858, Val: 0.8990, Test: 0.8614
Epoch: 100, Loss: 0.4833, Val: 0.8998, Test: 0.8694


In [10]:
ma = np.mean(np.asarray(acc_auc), axis=0)
mf = np.mean(np.asarray(fairness), axis=0)

sa = np.std(np.asarray(acc_auc), axis=0)
sf = np.std(np.asarray(fairness), axis=0)

print(f"ACC: {ma[0]:2f} +- {sa[0]:2f}")
print(f"AUC: {ma[1]:2f} +- {sa[1]:2f}")

print(f"DP mix: {mf[0]:2f} +- {sf[0]:2f}")
print(f"EoP mix: {mf[1]:2f} +- {sf[1]:2f}")
print(f"DP group: {mf[2]:2f} +- {sf[2]:2f}")
print(f"EoP group: {mf[3]:2f} +- {sf[3]:2f}")
print(f"DP sub: {mf[4]:2f} +- {sf[4]:2f}")
print(f"EoP sub: {mf[5]:2f} +- {sf[5]:2f}")

ACC: 79.358974 +- 1.037025
AUC: 88.782202 +- 0.904982
DP mix: 45.829692 +- 2.456295
EoP mix: 28.844359 +- 6.723988
DP group: 22.011666 +- 2.428847
EoP group: 21.349084 +- 3.382816
DP sub: 70.934783 +- 2.438119
EoP sub: 69.881827 +- 16.049702
