In [1]:
from utils import load_credit, load_german, feature_norm
from torch_geometric.data import Data
from torch_geometric.utils import convert, to_networkx
import torch
import types

In [2]:
def load_dataset(dataset):
    """Load one of the supported graphs and wrap it for the WL analysis.

    Parameters
    ----------
    dataset : str
        Name of the dataset to load: ``'credit'`` or ``'german'``.

    Returns
    -------
    tuple
        ``(wrapped, sens)`` where ``wrapped`` is a ``types.SimpleNamespace``
        carrying ``data`` (a ``Data`` object whose ``x`` no longer contains
        the sensitive column) and ``num_classes``, and ``sens`` is the
        sensitive attribute converted to a long tensor.

    Raises
    ------
    ValueError
        If ``dataset`` is not one of the supported names.
    """
    # Load credit_scoring dataset
    if dataset == 'credit':
        sens_attr = "Age"  # column number after feature process is 1
        sens_idx = 1
        predict_attr = 'NoDefaultNextMonth'
        label_number = 30000
        path_credit = "./dataset/credit"
        adj, features, labels, idx_train, idx_val, idx_test, sens = load_credit(
            dataset, sens_attr, predict_attr,
            path=path_credit, label_number=label_number,
        )
        # Apply feature_norm, then put the raw sensitive column back so it is
        # not affected by the normalization.
        norm_features = feature_norm(features)
        norm_features[:, sens_idx] = features[:, sens_idx]
        features = norm_features

    # Load german dataset
    elif dataset == 'german':
        sens_attr = "Gender"  # column number after feature process is 0
        sens_idx = 0
        predict_attr = "GoodCustomer"
        label_number = 1000
        path_german = "./dataset/german"
        adj, features, labels, idx_train, idx_val, idx_test, sens = load_german(
            dataset, sens_attr, predict_attr,
            path=path_german, label_number=label_number,
        )

    else:
        # Fail with a clear message instead of an UnboundLocalError on `adj`.
        raise ValueError(
            "Unknown dataset {!r}; expected 'credit' or 'german'.".format(dataset)
        )

    edge_index = convert.from_scipy_sparse_matrix(adj)[0]
    # don't include sensitive attributes
    features = torch.cat((features[:, :sens_idx], features[:, sens_idx+1:]), dim=1)
    sens = torch.tensor(sens).long()

    data = Data(
        x=features,
        edge_index=edge_index,
        y=labels.long(),
        edge_attr=torch.ones(edge_index.size(1)),  # unit weight on every edge
    )
    # NOTE(review): the loaders' idx_* values are stored under *_mask names;
    # whether they are index tensors or boolean masks is not visible here.
    data.train_mask = idx_train
    data.val_mask = idx_val
    data.test_mask = idx_test

    wrapped = types.SimpleNamespace()
    wrapped.data = data
    # Assumes labels are integers starting at 0 -- TODO confirm.
    wrapped.num_classes = data.y.max().item() + 1

    return wrapped, sens

In [3]:
from collections import Counter, defaultdict
from hashlib import blake2b

def _hash_label(label, digest_size):
    return blake2b(label.encode("ascii"), digest_size=digest_size).hexdigest()


def _init_node_labels(G, edge_attr, node_attr):
    if node_attr:
        return {u: str(dd[node_attr]) for u, dd in G.nodes(data=True)}
    elif edge_attr:
        return {u: "" for u in G}
    else:
        # use same hash for all nodes if no features
        return {u: "0" for u, deg in G.degree()}


def _neighborhood_aggregate(G, node, node_labels, edge_attr=None):
    """
    Compute new labels for given node by aggregating
    the labels of each node's neighbors.
    """
    label_list = []
    for nbr in G.neighbors(node):
        prefix = "" if edge_attr is None else str(G[node][nbr][edge_attr])
        label_list.append(prefix + node_labels[nbr])
    return node_labels[node] + "".join(sorted(label_list))


def weisfeiler_lehman_graph_hash(
    G, edge_attr=None, node_attr=None, iterations=3, digest_size=16
):
    """Return one Weisfeiler-Lehman hash string summarising the graph ``G``.

    Node labels are initialised from ``node_attr`` / ``edge_attr``, refined
    ``iterations`` times by hashing each node's aggregated neighborhood
    label, and after every round the sorted ``(label, count)`` pairs are
    appended to a running list. The hash of that list is returned.
    ``digest_size`` is forwarded to every hashing call.
    """
    current = _init_node_labels(G, edge_attr, node_attr)

    histogram = []
    for _ in range(iterations):
        # One refinement round: every node's new label is computed from the
        # labels of the previous round (``current`` is rebound only after
        # the whole comprehension has been evaluated).
        current = {
            node: _hash_label(
                _neighborhood_aggregate(G, node, current, edge_attr=edge_attr),
                digest_size,
            )
            for node in G.nodes()
        }
        counts = Counter(current.values())
        # Sort by label so the result does not depend on node order.
        histogram.extend(sorted(counts.items(), key=lambda item: item[0]))

    # Hash the accumulated per-round label counts.
    return _hash_label(str(tuple(histogram)), digest_size)

def weisfeiler_lehman_subgraph_hashes(
    G, edge_attr=None, node_attr=None, iterations=3, digest_size=16
):
    """Return, for every node of ``G``, its list of Weisfeiler-Lehman hashes.

    Node labels are initialised from ``node_attr`` / ``edge_attr`` and then
    refined ``iterations`` times. After each round the node's new hashed
    label is appended to that node's list, so entry ``i`` of a list is the
    node's label after round ``i + 1``. The result is a plain ``dict``
    mapping node -> list of hex digests.
    """
    current = _init_node_labels(G, edge_attr, node_attr)

    per_node_hashes = defaultdict(list)
    for _ in range(iterations):
        refined = {}
        for node in G.nodes():
            # Aggregate using the labels of the previous round only.
            raw = _neighborhood_aggregate(G, node, current, edge_attr=edge_attr)
            digest = _hash_label(raw, digest_size)
            refined[node] = digest
            per_node_hashes[node].append(digest)
        current = refined

    return dict(per_node_hashes)

In [4]:
def partition_hashes(dataset):
    """Group ``(G_idx, y_idx, hash)`` triples into classes of equal hash.

    Parameters
    ----------
    dataset : iterable of 3-tuples
        Each element unpacks to ``(G_idx, y_idx, _hash)``. The hash must be
        hashable (the callers in this notebook pass hex-digest strings).

    Returns
    -------
    list of list of tuple
        One inner list per distinct hash, in order of first appearance; the
        triples inside each list keep their input order.
    """
    # Bucket by hash in a dict instead of scanning every existing partition
    # for each element; dicts preserve insertion order, so the output order
    # matches a first-seen linear scan while running in linear time.
    buckets = {}
    for G_idx, y_idx, _hash in dataset:
        buckets.setdefault(_hash, []).append((G_idx, y_idx, _hash))
    return list(buckets.values())

In [5]:
# Load each base dataset once. Every entry is the (dataset, sens) pair
# returned by load_dataset, keyed by dataset name for reuse in later cells.
loaded_datasets = {}
for dataset_name in ["credit", "german"]:
    loaded_datasets[dataset_name] = load_dataset(dataset_name)

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  idx_features_labels['Gender'][idx_features_labels['Gender'] == 'Female'] = 1
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  idx_features_labels['Gender'][idx_features_labels['Gender'] == 'Male'] = 0
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  idx_features_labels['PurposeOfLoan'][idx_features_labels['PurposeOfLoan'] == val] = i


In [6]:
def eq_class_stats(dataset, name, partition):
    """Print and return summary counts for a partition into classes.

    Parameters
    ----------
    dataset
        Not used by this function; kept in the signature for its callers.
    name : str
        Label inserted into the printed lines.
    partition : sequence of sized collections
        The equivalence classes.

    Returns
    -------
    list
        ``[number of classes, number of classes with exactly one member]``.
    """
    n_classes = len(partition)
    print('num', name, 'eq classes:', n_classes)

    n_singletons = 0
    for eq_class in partition:
        if len(eq_class) == 1:
            n_singletons += 1
    print('num singleton', name, 'classes:', n_singletons)

    return [n_classes, n_singletons]

In [12]:
# Weisfeiler-Lehman equivalence-class statistics per dataset. A name of the
# form "<dataset>-<s>" restricts the statistics to the nodes of <dataset>
# whose sensitive value equals s.
wl_stats = {}

k = 3  # number of WL refinement iterations
# LaTeX column headers for the summary table built from wl_stats.
cols = [r'\# nodes', r'$|{\cal E}_{\text{WL}}^' + str(k) + r'|$', r'$\#_1 ({\cal E}_{\text{WL}}^' + str(k) + r')$']

# The per-node WL hashes depend only on the base graph (and k), not on the
# sensitive-value filter, so compute them once per base dataset and reuse
# them for the "-0" / "-1" variants instead of re-hashing the same graph.
wl_hash_cache = {}

for dataset_name in ["credit", "credit-0", "credit-1", "german"]:

    # Split an optional "-<sensitive value>" suffix off the base dataset name.
    dn = dataset_name
    sens_type = None
    if "-" in dataset_name:
        dn, sens_type = dataset_name.split("-")
        sens_type = int(sens_type)

    dataset, sens = loaded_datasets[dn]

    print("======", dataset_name, "======")

    G_idx = 0  # each dataset is a single graph
    if dn not in wl_hash_cache:
        G_torch = dataset.data
        G_nx = to_networkx(G_torch, node_attrs=['x'])
        wl_hash_cache[dn] = weisfeiler_lehman_subgraph_hashes(G_nx, node_attr='x', iterations=k)
    G_hashes = wl_hash_cache[dn]

    # Keep each node's hash after the final iteration, skipping nodes that
    # do not belong to the requested sensitive group (if one was requested).
    hashed_dataset = [
        (G_idx, node_id, node_hashes[-1])
        for node_id, node_hashes in G_hashes.items()
        if sens_type is None or sens[node_id].item() == sens_type
    ]

    stats = [len(hashed_dataset)]
    wl_eq_classes = partition_hashes(hashed_dataset)
    stats.extend(eq_class_stats(dataset, "wl-{}".format(k), wl_eq_classes))

    wl_stats[dataset_name] = stats

num wl-3 eq classes: 29535
num singleton wl-3 classes: 29367
num wl-3 eq classes: 26874
num singleton wl-3 classes: 26720
num wl-3 eq classes: 2662
num singleton wl-3 classes: 2649
num wl-3 eq classes: 1000
num singleton wl-3 classes: 1000


In [13]:
import pandas as pd

# One row per entry of wl_stats (keys become the index), with the LaTeX
# column headers defined in cols.
pd_wl_stats = pd.DataFrame.from_dict(wl_stats, orient='index', columns=cols)
# escape=False so the LaTeX markup in the column headers is emitted verbatim.
print(pd_wl_stats.to_latex(escape=False))

\begin{tabular}{lrrr}
\toprule
{} &  \# nodes &  $|{\cal E}_{\text{WL}}^3|$ &  $\#_1 ({\cal E}_{\text{WL}}^3)$ \\
\midrule
credit   &     30000 &                       29535 &                            29367 \\
credit-0 &     27315 &                       26874 &                            26720 \\
credit-1 &      2685 &                        2662 &                             2649 \\
german   &      1000 &                        1000 &                             1000 \\
\bottomrule
\end{tabular}

